├── Data.zip
├── README.md
├── dee
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── base_task.cpython-36.pyc
│   │   ├── dee_helper.cpython-36.pyc
│   │   ├── dee_metric.cpython-36.pyc
│   │   ├── dee_model.cpython-36.pyc
│   │   ├── dee_task.cpython-36.pyc
│   │   ├── event_type.cpython-36.pyc
│   │   ├── ner_model.cpython-36.pyc
│   │   ├── ner_task.cpython-36.pyc
│   │   ├── transformer.cpython-36.pyc
│   │   ├── utils.cpython-36.pyc
│   │   └── utils.cpython-37.pyc
│   ├── base_task.py
│   ├── dee_helper.py
│   ├── dee_metric.py
│   ├── dee_model.py
│   ├── dee_task.py
│   ├── event_type.py
│   ├── ner_model.py
│   ├── ner_task.py
│   ├── transformer.py
│   └── utils.py
├── figs
│   ├── model.png
│   └── result.png
├── run_dee_task.py
├── run_eval.sh
├── run_train.sh
└── train_multi.sh

/Data.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/Data.zip

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Document-level Event Extraction via Heterogeneous Graph-based Interaction Model with a Tracker

Source code for the ACL-IJCNLP 2021 long paper: [Document-level Event Extraction via Heterogeneous Graph-based Interaction Model with a Tracker](https://arxiv.org/abs/2105.14924).

Our code is based on [Doc2EDAG](https://github.com/dolphin-zs/Doc2EDAG).


## 0. Introduction

> Document-level event extraction aims to extract events within a document. Different from sentence-level event extraction, the arguments of an event record may scatter across sentences, which requires a comprehensive understanding of the cross-sentence context. Besides, a document may express several correlated events simultaneously, and recognizing the interdependency among them is fundamental to successful extraction. To tackle the aforementioned two challenges, we propose a novel heterogeneous Graph-based Interaction Model with a Tracker (GIT). A graph-based interaction network is introduced to capture the global context for the scattered event arguments across sentences with different heterogeneous edges. We also decode event records with a Tracker module, which tracks the extracted event records, so that the interdependency among events is taken into consideration. Our approach delivers better results than the state-of-the-art methods, especially in cross-sentence event and multi-event scenarios.


+ Architecture
![model overview](figs/model.png)

+ Overall Results
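![overall results](figs/result.png)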

## 1. Package Description
```
GIT/
├─ dee/
├── __init__.py
├── base_task.py
├── dee_task.py
├── ner_task.py
├── dee_helper.py: data features construction and evaluation utils
├── dee_metric.py: data evaluation utils
├── config.py: process command arguments
├── dee_model.py: GIT model
├── ner_model.py
├── transformer.py: transformer module
├── utils.py: utils
├─ run_dee_task.py: the main entry
├─ train_multi.sh
├─ run_train.sh: script for training (including evaluation)
├─ run_eval.sh: script for evaluation
├─ Exps/: experiment outputs
├─ Data.zip
├─ Data: unzip Data.zip
├─ LICENSE
├─ README.md
```

## 2. Environments

- python (3.6.9)
- cuda (11.1)
- Ubuntu 18.04 (5.4.0-73-generic)

## 3. Dependencies

- numpy (1.19.5)
- torch (1.8.1+cu111)
- pytorch-pretrained-bert (0.4.0)
- dgl-cu111 (0.6.1)
- tensorboardX (2.2)

PS: The environment and dependencies listed here differ from the ones we used in the paper, so the results may be slightly different.

## 4. Preparation

- Unzip Data.zip to obtain a Data folder, where the training/dev/test data are located.

## 5. Training

```bash
>> bash run_train.sh
```

## 6. Evaluation

```bash
>> bash run_eval.sh
```

(Evaluation is also run automatically after training.)

## 7. License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

## 8. Citation

If you use this work or code, please kindly cite the following paper:

```bib
@inproceedings{xu-etal-2021-git,
    title = "Document-level Event Extraction via Heterogeneous Graph-based Interaction Model with a Tracker",
    author = "Runxin Xu and
      Tianyu Liu and
      Lei Li and
      Baobao Chang",
    booktitle = "The Joint Conference of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (ACL-IJCNLP 2021)",
    year = "2021",
    publisher = "Association for Computational Linguistics",
}
```

--------------------------------------------------------------------------------
/dee/__init__.py:
--------------------------------------------------------------------------------


--------------------------------------------------------------------------------
/dee/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/dee/__pycache__/__init__.cpython-36.pyc

--------------------------------------------------------------------------------
/dee/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/dee/__pycache__/__init__.cpython-37.pyc

--------------------------------------------------------------------------------
/dee/__pycache__/base_task.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/dee/__pycache__/base_task.cpython-36.pyc
-------------------------------------------------------------------------------- /dee/__pycache__/dee_helper.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/dee/__pycache__/dee_helper.cpython-36.pyc -------------------------------------------------------------------------------- /dee/__pycache__/dee_metric.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/dee/__pycache__/dee_metric.cpython-36.pyc -------------------------------------------------------------------------------- /dee/__pycache__/dee_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/dee/__pycache__/dee_model.cpython-36.pyc -------------------------------------------------------------------------------- /dee/__pycache__/dee_task.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/dee/__pycache__/dee_task.cpython-36.pyc -------------------------------------------------------------------------------- /dee/__pycache__/event_type.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/dee/__pycache__/event_type.cpython-36.pyc -------------------------------------------------------------------------------- /dee/__pycache__/ner_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/dee/__pycache__/ner_model.cpython-36.pyc -------------------------------------------------------------------------------- /dee/__pycache__/ner_task.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/dee/__pycache__/ner_task.cpython-36.pyc -------------------------------------------------------------------------------- /dee/__pycache__/transformer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/dee/__pycache__/transformer.cpython-36.pyc -------------------------------------------------------------------------------- /dee/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/dee/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /dee/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/dee/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /dee/base_task.py: -------------------------------------------------------------------------------- 1 | # Code Reference: 
(https://github.com/dolphin-zs/Doc2EDAG) 2 | 3 | import logging 4 | import random 5 | import os 6 | import json 7 | import sys 8 | import numpy as np 9 | from datetime import datetime 10 | import torch 11 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 12 | from torch.utils.data.distributed import DistributedSampler 13 | import torch.distributed as dist 14 | import torch.nn.parallel as para 15 | from pytorch_pretrained_bert.optimization import BertAdam 16 | from tqdm import trange, tqdm 17 | from tensorboardX import SummaryWriter 18 | 19 | from .utils import default_dump_pkl, default_dump_json 20 | 21 | PY2 = sys.version_info[0] == 2 22 | PY3 = sys.version_info[0] == 3 23 | if PY2: 24 | import collections 25 | container_abcs = collections 26 | elif PY3: 27 | import collections.abc 28 | container_abcs = collections.abc 29 | 30 | logger = logging.getLogger(__name__) 31 | 32 | class TaskSetting(object): 33 | """Base task setting that can be initialized with a dictionary""" 34 | base_key_attrs = ['data_dir', 'model_dir', 'output_dir'] 35 | base_attr_default_pairs = [ 36 | ('bert_model', 'bert-base-chinese'), 37 | ('train_file_name', 'train.json'), 38 | ('dev_file_name', 'dev.json'), 39 | ('test_file_name', 'test.json'), 40 | ('max_seq_len', 128), 41 | ('train_batch_size', 32), 42 | ('eval_batch_size', 256), 43 | ('learning_rate', 1e-4), 44 | ('num_train_epochs', 3.0), 45 | ('warmup_proportion', 0.1), 46 | ('no_cuda', False), 47 | ('local_rank', -1), 48 | ('seed', 99), 49 | ('gradient_accumulation_steps', 1), 50 | ('optimize_on_cpu', False), 51 | ('fp16', False), 52 | ('loss_scale', 128), 53 | ('cpt_file_name', 'task.cpt'), 54 | ('summary_dir_name', '/root/summary'), 55 | ] 56 | 57 | def __init__(self, key_attrs, attr_default_pairs, **kwargs): 58 | for key_attr in TaskSetting.base_key_attrs: 59 | setattr(self, key_attr, kwargs[key_attr]) 60 | 61 | for attr, val in TaskSetting.base_attr_default_pairs: 62 | setattr(self, attr, val) 63 | 64 | for key_attr in key_attrs: 65 | setattr(self, key_attr, kwargs[key_attr]) 66 | 67 | for attr, val in attr_default_pairs: 68 | if attr in kwargs: 69 | setattr(self, attr, kwargs[attr]) 70 | else: 71 | setattr(self, attr, val) 72 | 73 | def update_by_dict(self, config_dict): 74 | for key, val in config_dict.items(): 75 | setattr(self, key, val) 76 | 77 | def dump_to(self, dir_path, file_name='task_setting.json'): 78 | dump_fp = os.path.join(dir_path, file_name) 79 | default_dump_json(self.__dict__, dump_fp) 80 | 81 | def set_optimizer_params_grad(named_params_optimizer, named_params_model, test_nan=False): 82 | """ 83 | Utility function for optimize_on_cpu and 16-bits training. 
84 | Copy the gradient of the GPU parameters to the CPU/RAM copy of the model 85 | """ 86 | is_nan = False 87 | for (name_opti, param_opti), (name_model, param_model) in zip(named_params_optimizer, named_params_model): 88 | if name_opti != name_model: 89 | logger.error("name_opti != name_model: {} {}".format(name_opti, name_model)) 90 | raise ValueError 91 | if param_model.grad is not None: 92 | if test_nan and torch.isnan(param_model.grad).sum() > 0: 93 | is_nan = True 94 | if param_opti.grad is None: 95 | param_opti.grad = torch.nn.Parameter(param_opti.data.new().resize_(*param_opti.data.size())) 96 | param_opti.grad.data.copy_(param_model.grad.data) 97 | else: 98 | param_opti.grad = None 99 | return is_nan 100 | 101 | 102 | def copy_optimizer_params_to_model(named_params_model, named_params_optimizer): 103 | """ 104 | Utility function for optimize_on_cpu and 16-bits training. 105 | Copy the parameters optimized on CPU/RAM back to the model on GPU 106 | """ 107 | for (name_opti, param_opti), (name_model, param_model) in zip(named_params_optimizer, named_params_model): 108 | if name_opti != name_model: 109 | logger.error("name_opti != name_model: {} {}".format(name_opti, name_model)) 110 | raise ValueError 111 | param_model.data.copy_(param_opti.data) 112 | 113 | class BasePytorchTask(object): 114 | """Basic task to support deep learning models on PyTorch""" 115 | 116 | def __init__(self, setting, only_master_logging=False): 117 | self.setting = setting 118 | self.logger = logging.getLogger(self.__class__.__name__) 119 | self.only_master_logging = only_master_logging 120 | 121 | if self.in_distributed_mode() and not dist.is_initialized(): 122 | # Initializes the distributed backend which will take care of synchronizing nodes/GPUs 123 | dist.init_process_group(backend='nccl') 124 | # dist.init_process_group(backend='gloo') # 3 times slower than nccl for gpu training 125 | torch.cuda.set_device(self.setting.local_rank) 126 | self.logging('World Size {} Rank {}, Local Rank {}, Device Num {}, Device {}'.format( 127 | dist.get_world_size(), dist.get_rank(), self.setting.local_rank, 128 | torch.cuda.device_count(), torch.cuda.current_device() 129 | )) 130 | dist.barrier() 131 | 132 | self._check_setting_validity() 133 | self._init_device() 134 | self.reset_random_seed() 135 | self.summary_writer = None 136 | 137 | # ==> task-specific initialization 138 | # The following functions should be called specifically in inherited classes 139 | 140 | self.custom_collate_fn = None 141 | self.train_examples = None 142 | self.train_features = None 143 | self.train_dataset = None 144 | self.dev_examples = None 145 | self.dev_features = None 146 | self.dev_dataset = None 147 | self.test_examples = None 148 | self.test_features = None 149 | self.test_dataset = None 150 | # self._load_data() 151 | 152 | self.model = None 153 | # self._decorate_model() 154 | 155 | self.optimizer = None 156 | self.num_train_steps = None 157 | self.model_named_parameters = None 158 | # self._init_bert_optimizer() 159 | # (option) self.resume_checkpoint() 160 | 161 | def logging(self, msg, level=logging.INFO): 162 | if self.in_distributed_mode(): 163 | msg = 'Rank {} {}'.format(dist.get_rank(), msg) 164 | if self.only_master_logging: 165 | if self.is_master_node(): 166 | self.logger.log(level, msg) 167 | else: 168 | self.logger.log(level, msg) 169 | 170 | def _check_setting_validity(self): 171 | self.logging('='*20 + 'Check Setting Validity' + '='*20) 172 | self.logging('Setting: {}'.format( 173 | json.dumps(self.setting.__dict__, 
ensure_ascii=False, indent=2) 174 | )) 175 | 176 | # check valid grad accumulate step 177 | if self.setting.gradient_accumulation_steps < 1: 178 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( 179 | self.setting.gradient_accumulation_steps)) 180 | # reset train batch size 181 | self.setting.train_batch_size = int(self.setting.train_batch_size 182 | / self.setting.gradient_accumulation_steps) 183 | 184 | # check output dir 185 | if os.path.exists(self.setting.output_dir) and os.listdir(self.setting.output_dir): 186 | self.logging("Output directory ({}) already exists and is not empty.".format(self.setting.output_dir), 187 | level=logging.WARNING) 188 | os.makedirs(self.setting.output_dir, exist_ok=True) 189 | 190 | # check model dir 191 | if os.path.exists(self.setting.model_dir) and os.listdir(self.setting.model_dir): 192 | self.logging("Model directory ({}) already exists and is not empty.".format(self.setting.model_dir), 193 | level=logging.WARNING) 194 | os.makedirs(self.setting.model_dir, exist_ok=True) 195 | 196 | def _init_device(self): 197 | self.logging('='*20 + 'Init Device' + '='*20) 198 | 199 | # set device 200 | if self.setting.local_rank == -1 or self.setting.no_cuda: 201 | self.device = torch.device("cuda" if torch.cuda.is_available() and not self.setting.no_cuda else "cpu") 202 | self.n_gpu = torch.cuda.device_count() 203 | else: 204 | self.device = torch.device("cuda", self.setting.local_rank) 205 | self.n_gpu = 1 206 | if self.setting.fp16: 207 | self.logging("16-bits training currently not supported in distributed training") 208 | self.setting.fp16 = False # (see https://github.com/pytorch/pytorch/pull/13496) 209 | self.logging("device {} n_gpu {} distributed training {}".format( 210 | self.device, self.n_gpu,self.in_distributed_mode() 211 | )) 212 | 213 | def reset_random_seed(self, seed=None): 214 | if seed is None: 215 | seed = self.setting.seed 216 | self.logging('='*20 + 'Reset Random Seed to {}'.format(seed) + '='*20) 217 | 218 | # set random seeds 219 | random.seed(seed) 220 | np.random.seed(seed) 221 | torch.manual_seed(seed) 222 | if self.n_gpu > 0: 223 | torch.cuda.manual_seed_all(seed) 224 | 225 | def is_master_node(self): 226 | if self.in_distributed_mode(): 227 | if dist.get_rank() == 0: 228 | return True 229 | else: 230 | return False 231 | else: 232 | return True 233 | 234 | def in_distributed_mode(self): 235 | return self.setting.local_rank >= 0 236 | 237 | def _init_summary_writer(self): 238 | if self.is_master_node(): 239 | self.logging('Init Summary Writer') 240 | current_time = datetime.now().strftime('%b%d_%H-%M-%S') 241 | sum_dir = '{}-{}'.format(self.setting.summary_dir_name, current_time) 242 | self.summary_writer = SummaryWriter(sum_dir) 243 | self.logging('Writing summary into {}'.format(sum_dir)) 244 | 245 | if self.in_distributed_mode(): 246 | # TODO: maybe this can be removed 247 | dist.barrier() 248 | 249 | def load_example_feature_dataset(self, load_example_func, convert_to_feature_func, convert_to_dataset_func, 250 | file_name=None, file_path=None): 251 | if file_name is None and file_path is None: 252 | raise Exception('Either file name or file path should be provided') 253 | 254 | if file_path is None: 255 | file_path = os.path.join(self.setting.data_dir, file_name) 256 | 257 | if os.path.exists(file_path): 258 | self.logging('Load example feature dataset from {}'.format(file_path)) 259 | examples = load_example_func(file_path) 260 | features = convert_to_feature_func(examples) 261 | dataset 
= convert_to_dataset_func(features) 262 | else: 263 | self.logging('Warning: file does not exist, {}'.format(file_path)) 264 | examples = None 265 | features = None 266 | dataset = None 267 | 268 | return examples, features, dataset 269 | 270 | def _load_data(self, load_example_func, convert_to_feature_func, convert_to_dataset_func, 271 | load_train=True, load_dev=True, load_test=True): 272 | self.logging('='*20 + 'Load Task Data' + '='*20) 273 | # prepare data 274 | if load_train: 275 | self.logging('Load train portion') 276 | self.train_examples, self.train_features, self.train_dataset = self.load_example_feature_dataset( 277 | load_example_func, convert_to_feature_func, convert_to_dataset_func, 278 | file_name=self.setting.train_file_name 279 | ) 280 | else: 281 | self.logging('Do not load train portion') 282 | 283 | if load_dev: 284 | self.logging('Load dev portion') 285 | self.dev_examples, self.dev_features, self.dev_dataset = self.load_example_feature_dataset( 286 | load_example_func, convert_to_feature_func, convert_to_dataset_func, 287 | file_name=self.setting.dev_file_name 288 | ) 289 | else: 290 | self.logging('Do not load dev portion') 291 | 292 | if load_test: 293 | self.logging('Load test portion') 294 | self.test_examples, self.test_features, self.test_dataset = self.load_example_feature_dataset( 295 | load_example_func, convert_to_feature_func, convert_to_dataset_func, 296 | file_name=self.setting.test_file_name 297 | ) 298 | else: 299 | self.logging('Do not load test portion') 300 | 301 | def reload_data(self, load_example_func, convert_to_feature_func, convert_to_dataset_func, 302 | data_type='return', file_name=None, file_path=None): 303 | """Subclasses should override this function to omit the converter function arguments""" 304 | if data_type.lower() == 'train': 305 | self.train_examples, self.train_features, self.train_dataset = \ 306 | self.load_example_feature_dataset( 307 | load_example_func, convert_to_feature_func, convert_to_dataset_func, 308 | file_name=file_name, file_path=file_path 309 | ) 310 | elif data_type.lower() == 'dev': 311 | self.dev_examples, self.dev_features, self.dev_dataset = \ 312 | self.load_example_feature_dataset( 313 | load_example_func, convert_to_feature_func, convert_to_dataset_func, 314 | file_name=file_name, file_path=file_path 315 | ) 316 | elif data_type.lower() == 'test': 317 | self.test_examples, self.test_features, self.test_dataset = \ 318 | self.load_example_feature_dataset( 319 | load_example_func, convert_to_feature_func, convert_to_dataset_func, 320 | file_name=file_name, file_path=file_path 321 | ) 322 | elif data_type.lower() == 'return': 323 | examples, features, dataset = self.load_example_feature_dataset( 324 | load_example_func, convert_to_feature_func, convert_to_dataset_func, 325 | file_name=file_name, file_path=file_path, 326 | ) 327 | 328 | return examples, features, dataset 329 | else: 330 | raise Exception('Unexpected data type {}'.format(data_type)) 331 | 332 | def _decorate_model(self, parallel_decorate=True): 333 | self.logging('='*20 + 'Decorate Model' + '='*20) 334 | 335 | if self.setting.fp16: 336 | self.model.half() 337 | 338 | self.model.to(self.device) 339 | self.logging('Set model device to {}'.format(str(self.device))) 340 | 341 | if parallel_decorate: 342 | if self.in_distributed_mode(): 343 | self.model = para.DistributedDataParallel(self.model, 344 | device_ids=[self.setting.local_rank], 345 | output_device=self.setting.local_rank) 346 | self.logging('Wrap distributed data parallel') 347 | # self.logging('In 
Distributed Mode, but do not use DistributedDataParallel Wrapper') 348 | elif self.n_gpu > 1: 349 | self.model = para.DataParallel(self.model) 350 | self.logging('Wrap data parallel') 351 | else: 352 | self.logging('Do not wrap parallel layers') 353 | 354 | def _init_bert_optimizer(self): 355 | self.logging('='*20 + 'Init Bert Optimizer' + '='*20) 356 | self.optimizer, self.num_train_steps, self.model_named_parameters = \ 357 | self.reset_bert_optimizer() 358 | 359 | def reset_bert_optimizer(self): 360 | # Prepare optimizer 361 | if self.setting.fp16: 362 | model_named_parameters = [(n, param.clone().detach().to('cpu').float().requires_grad_()) 363 | for n, param in self.model.named_parameters()] 364 | elif self.setting.optimize_on_cpu: 365 | model_named_parameters = [(n, param.clone().detach().to('cpu').requires_grad_()) 366 | for n, param in self.model.named_parameters()] 367 | else: 368 | model_named_parameters = list(self.model.named_parameters()) 369 | 370 | no_decay = ['bias', 'gamma', 'beta'] 371 | optimizer_grouped_parameters = [ 372 | { 373 | 'params': [p for n, p in model_named_parameters if n not in no_decay], 374 | 'weight_decay_rate': 0.01 375 | }, 376 | { 377 | 'params': [p for n, p in model_named_parameters if n in no_decay], 378 | 'weight_decay_rate': 0.0 379 | } 380 | ] 381 | 382 | num_train_steps = int(len(self.train_examples) 383 | / self.setting.train_batch_size 384 | / self.setting.gradient_accumulation_steps 385 | * self.setting.num_train_epochs) 386 | 387 | optimizer = BertAdam(optimizer_grouped_parameters, 388 | lr=self.setting.learning_rate, 389 | warmup=self.setting.warmup_proportion, 390 | t_total=num_train_steps) 391 | 392 | return optimizer, num_train_steps, model_named_parameters 393 | 394 | def prepare_data_loader(self, dataset, batch_size, rand_flag=True): 395 | # prepare data loader 396 | if rand_flag: 397 | data_sampler = RandomSampler(dataset) 398 | else: 399 | data_sampler = SequentialSampler(dataset) 400 | 401 | if self.custom_collate_fn is None: 402 | dataloader = DataLoader(dataset, 403 | batch_size=batch_size, 404 | sampler=data_sampler) 405 | else: 406 | dataloader = DataLoader(dataset, 407 | batch_size=batch_size, 408 | sampler=data_sampler, 409 | collate_fn=self.custom_collate_fn) 410 | 411 | return dataloader 412 | 413 | def prepare_dist_data_loader(self, dataset, batch_size, epoch=0): 414 | # prepare distributed data loader 415 | data_sampler = DistributedSampler(dataset) 416 | data_sampler.set_epoch(epoch) 417 | 418 | if self.custom_collate_fn is None: 419 | dataloader = DataLoader(dataset, 420 | batch_size=batch_size, 421 | sampler=data_sampler) 422 | else: 423 | dataloader = DataLoader(dataset, 424 | batch_size=batch_size, 425 | sampler=data_sampler, 426 | collate_fn=self.custom_collate_fn) 427 | return dataloader 428 | 429 | def get_current_train_batch_size(self): 430 | if self.in_distributed_mode(): 431 | train_batch_size = max(self.setting.train_batch_size // dist.get_world_size(), 1) 432 | else: 433 | train_batch_size = self.setting.train_batch_size 434 | 435 | return train_batch_size 436 | 437 | def set_batch_to_device(self, batch): 438 | # move mini-batch data to the proper device 439 | if isinstance(batch, torch.Tensor): 440 | batch = batch.to(self.device) 441 | 442 | return batch 443 | elif isinstance(batch, dict): 444 | for key, value in batch.items(): 445 | if isinstance(value, torch.Tensor): 446 | batch[key] = value.to(self.device) 447 | elif isinstance(value, dict) or isinstance(value, container_abcs.Sequence): 448 | batch[key] = 
self.set_batch_to_device(value) 449 | 450 | return batch 451 | elif isinstance(batch, container_abcs.Sequence): 452 | # batch = [ 453 | # t.to(self.device) if isinstance(t, torch.Tensor) else t for t in batch 454 | # ] 455 | new_batch = [] 456 | for value in batch: 457 | if isinstance(value, torch.Tensor): 458 | new_batch.append(value.to(self.device)) 459 | elif isinstance(value, dict) or isinstance(value, container_abcs.Sequence): 460 | new_batch.append(self.set_batch_to_device(value)) 461 | else: 462 | new_batch.append(value) 463 | 464 | return new_batch 465 | else: 466 | raise Exception('Unsupported batch type {}'.format(type(batch))) 467 | 468 | def base_train(self, get_loss_func, kwargs_dict1={}, 469 | epoch_eval_func=None, kwargs_dict2={}, base_epoch_idx=0): 470 | assert self.model is not None 471 | 472 | if self.num_train_steps is None: 473 | self.num_train_steps = round( 474 | self.setting.num_train_epochs * len(self.train_examples) / self.setting.train_batch_size 475 | ) 476 | 477 | train_batch_size = self.get_current_train_batch_size() 478 | 479 | self.logging('='*20 + 'Start Base Training' + '='*20) 480 | self.logging("\tTotal examples Num = {}".format(len(self.train_examples))) 481 | self.logging("\tBatch size = {}".format(self.setting.train_batch_size)) 482 | self.logging("\tNum steps = {}".format(self.num_train_steps)) 483 | if self.in_distributed_mode(): 484 | self.logging("\tWorker Batch Size = {}".format(train_batch_size)) 485 | self._init_summary_writer() 486 | 487 | # prepare data loader 488 | train_dataloader = self.prepare_data_loader( 489 | self.train_dataset, self.setting.train_batch_size, rand_flag=True 490 | ) 491 | 492 | # enter train mode 493 | global_step = 0 494 | self.model.train() 495 | 496 | self.logging('Reach the epoch beginning') 497 | for epoch_idx in trange(base_epoch_idx, int(self.setting.num_train_epochs), desc="Epoch"): 498 | iter_desc = 'Iteration' 499 | self.model.train() 500 | if self.in_distributed_mode(): 501 | train_dataloader = self.prepare_dist_data_loader( 502 | self.train_dataset, train_batch_size, epoch=epoch_idx 503 | ) 504 | iter_desc = 'Rank {} {}'.format(dist.get_rank(), iter_desc) 505 | 506 | tr_loss = 0 507 | nb_tr_examples, nb_tr_steps = 0, 0 508 | 509 | if self.only_master_logging: 510 | if self.is_master_node(): 511 | step_batch_iter = enumerate(tqdm(train_dataloader, desc=iter_desc)) 512 | else: 513 | step_batch_iter = enumerate(train_dataloader) 514 | else: 515 | step_batch_iter = enumerate(tqdm(train_dataloader, desc=iter_desc)) 516 | 517 | for step, batch in step_batch_iter: 518 | batch = self.set_batch_to_device(batch) 519 | 520 | # forward 521 | loss = get_loss_func(self, batch, **kwargs_dict1) 522 | 523 | if self.n_gpu > 1: 524 | loss = loss.mean() # mean() to average on multi-gpu. 
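# NOTE (added comment): the rescaling steps below keep gradient magnitudes consistent. For fp16, the loss is multiplied by loss_scale before backward() and the gradients are divided by loss_scale again before the optimizer step; for gradient accumulation, dividing the loss by gradient_accumulation_steps makes the k accumulated backward() calls sum to one full-batch gradient (e.g., gradient_accumulation_steps=4 with train_batch_size=32 runs four forwards of 8 examples each between optimizer steps, since _check_setting_validity already divided the batch size by 4).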
525 | if self.setting.fp16 and self.setting.loss_scale != 1.0: 526 | # rescale loss for fp16 training 527 | # see https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html 528 | loss = loss * self.setting.loss_scale 529 | if self.setting.gradient_accumulation_steps > 1: 530 | loss = loss / self.setting.gradient_accumulation_steps 531 | 532 | # backward 533 | loss.backward() 534 | 535 | loss_scalar = loss.item() 536 | tr_loss += loss_scalar 537 | if self.is_master_node(): 538 | self.summary_writer.add_scalar('Loss', loss_scalar, global_step=global_step) 539 | nb_tr_examples += self.setting.train_batch_size # may not be very accurate due to incomplete batch 540 | nb_tr_steps += 1 541 | if (step + 1) % self.setting.gradient_accumulation_steps == 0: 542 | if self.setting.fp16 or self.setting.optimize_on_cpu: 543 | if self.setting.fp16 and self.setting.loss_scale != 1.0: 544 | # scale down gradients for fp16 training 545 | for param in self.model.parameters(): 546 | param.grad.data = param.grad.data / self.setting.loss_scale 547 | is_nan = set_optimizer_params_grad( 548 | self.model_named_parameters, self.model.named_parameters(), test_nan=True 549 | ) 550 | if is_nan: 551 | self.logging("FP16 TRAINING: Nan in gradients, reducing loss scaling") 552 | self.setting.loss_scale = self.setting.loss_scale / 2 553 | self.model.zero_grad() 554 | continue 555 | self.optimizer.step() 556 | copy_optimizer_params_to_model( 557 | self.model.named_parameters(), self.model_named_parameters 558 | ) 559 | else: 560 | self.optimizer.step() 561 | 562 | self.model.zero_grad() 563 | global_step += 1 564 | 565 | if epoch_eval_func is not None: 566 | epoch_eval_func(self, epoch_idx + 1, **kwargs_dict2) 567 | 568 | def base_eval(self, eval_dataset, get_info_on_batch, reduce_info_type='mean', dump_pkl_path=None, **func_kwargs): 569 | self.logging('='*20 + 'Start Base Evaluation' + '='*20) 570 | self.logging("\tNum examples = {}".format(len(eval_dataset))) 571 | self.logging("\tBatch size = {}".format(self.setting.eval_batch_size)) 572 | self.logging("\tReduce type = {}".format(reduce_info_type)) 573 | 574 | # prepare data loader 575 | eval_dataloader = self.prepare_data_loader( 576 | eval_dataset, self.setting.eval_batch_size, rand_flag=False 577 | ) 578 | 579 | # enter eval mode 580 | total_info = [] 581 | if self.model is not None: 582 | self.model.eval() 583 | 584 | iter_desc = 'Iteration' 585 | if self.in_distributed_mode(): 586 | iter_desc = 'Rank {} {}'.format(dist.get_rank(), iter_desc) 587 | 588 | for step, batch in enumerate(tqdm(eval_dataloader, desc=iter_desc)): 589 | batch = self.set_batch_to_device(batch) 590 | 591 | with torch.no_grad(): 592 | # this func must run batch_info = model(batch_input) 593 | # and metrics is an instance of torch.Tensor with Size([batch_size, ...]) 594 | # to fit the DataParallel and DistributedParallel functionality 595 | batch_info = get_info_on_batch(self, batch, **func_kwargs) 596 | # append metrics from this batch to event_info 597 | if isinstance(batch_info, torch.Tensor): 598 | total_info.append( 599 | batch_info.to(torch.device('cpu')) # collect results in cpu memory 600 | ) 601 | else: 602 | # batch_info is a list of some info on each example 603 | total_info.extend(batch_info) 604 | 605 | if isinstance(total_info[0], torch.Tensor): 606 | # transform event_info to torch.Tensor 607 | total_info = torch.cat(total_info, dim=0) 608 | 609 | # [batch_size, ...] -> [...] 
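# NOTE (added comment): for example, if each batch_info is a [batch_size, num_metrics] tensor, total_info above becomes a [num_examples, num_metrics] tensor, and 'sum'/'mean' below reduce over the example dimension to a [num_metrics] tensor; when get_info_on_batch returns plain per-example objects instead of tensors, only 'none' applies and the collected list is returned unchanged.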
610 | if reduce_info_type.lower() == 'sum': 611 | reduced_info = total_info.sum(dim=0) 612 | elif reduce_info_type.lower() == 'mean': 613 | reduced_info = total_info.mean(dim=0) 614 | elif reduce_info_type.lower() == 'none': 615 | reduced_info = total_info 616 | else: 617 | raise Exception('Unsupported reduce metric type {}'.format(reduce_info_type)) 618 | 619 | if dump_pkl_path is not None: 620 | default_dump_pkl(reduced_info, dump_pkl_path) 621 | 622 | return reduced_info 623 | 624 | def save_checkpoint(self, cpt_file_name=None, epoch=None): 625 | self.logging('='*20 + 'Dump Checkpoint' + '='*20) 626 | if cpt_file_name is None: 627 | cpt_file_name = self.setting.cpt_file_name 628 | cpt_file_path = os.path.join(self.setting.model_dir, cpt_file_name) 629 | self.logging('Dump checkpoint into {}'.format(cpt_file_path)) 630 | 631 | store_dict = { 632 | 'setting': self.setting.__dict__, 633 | } 634 | 635 | if self.model: 636 | if isinstance(self.model, para.DataParallel) or \ 637 | isinstance(self.model, para.DistributedDataParallel): 638 | model_state = self.model.module.state_dict() 639 | else: 640 | model_state = self.model.state_dict() 641 | store_dict['model_state'] = model_state 642 | else: 643 | self.logging('No model state is dumped', level=logging.WARNING) 644 | 645 | if self.optimizer: 646 | store_dict['optimizer_state'] = self.optimizer.state_dict() 647 | else: 648 | self.logging('No optimizer state is dumped', level=logging.WARNING) 649 | 650 | if epoch: 651 | store_dict['epoch'] = epoch 652 | 653 | torch.save(store_dict, cpt_file_path) 654 | 655 | def resume_checkpoint(self, cpt_file_path=None, cpt_file_name=None, 656 | resume_model=True, resume_optimizer=False, strict=False): 657 | self.logging('='*20 + 'Resume Checkpoint' + '='*20) 658 | # decide cpt_file_path to resume 659 | if cpt_file_path is None: # use provided path with highest priority 660 | if cpt_file_name is None: # no path and no name will resort to the default cpt name 661 | cpt_file_name = self.setting.cpt_file_name 662 | cpt_file_path = os.path.join(self.setting.model_dir, cpt_file_name) 663 | elif cpt_file_name is not None: # error when path and name are both provided 664 | raise Exception('Confused about path {} or file name {} to resume'.format( 665 | cpt_file_path, cpt_file_name 666 | )) 667 | 668 | if os.path.exists(cpt_file_path): 669 | self.logging('Resume checkpoint from {}'.format(cpt_file_path)) 670 | elif strict: 671 | raise Exception('Checkpoint does not exist, {}'.format(cpt_file_path)) 672 | else: 673 | self.logging('Checkpoint does not exist, {}'.format(cpt_file_path), level=logging.WARNING) 674 | return 675 | 676 | if torch.cuda.device_count() == 0: 677 | store_dict = torch.load(cpt_file_path, map_location='cpu') 678 | else: 679 | store_dict = torch.load(cpt_file_path, map_location=self.device) 680 | 681 | self.logging('Setting: {}'.format( 682 | json.dumps(store_dict['setting'], ensure_ascii=False, indent=2) 683 | )) 684 | 685 | if resume_model: 686 | if self.model and 'model_state' in store_dict: 687 | if isinstance(self.model, para.DataParallel) or \ 688 | isinstance(self.model, para.DistributedDataParallel): 689 | self.model.module.load_state_dict(store_dict['model_state']) 690 | else: 691 | self.model.load_state_dict(store_dict['model_state']) 692 | self.logging('Resume model successfully') 693 | elif strict: 694 | raise Exception('Resume model failed, dict.keys = {}'.format(store_dict.keys())) 695 | else: 696 | self.logging('Do not resume model') 697 | 698 | if resume_optimizer: 699 | if 
self.optimizer and 'optimizer_state' in store_dict: 700 | self.optimizer.load_state_dict(store_dict['optimizer_state']) 701 | self.logging('Resume optimizer successfully') 702 | elif strict: 703 | raise Exception('Resume optimizer failed, dict.keys = {}'.format(store_dict.keys())) 704 | else: 705 | self.logging('Do not resume optimizer') 706 | 707 | 708 | def average_gradients(model): 709 | """ Gradient averaging. """ 710 | size = float(dist.get_world_size()) 711 | for name, param in model.named_parameters(): 712 | try: 713 | dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM) 714 | param.grad.data /= size 715 | except Exception as e: 716 | logger.error('Error when all_reduce parameter {}, size={}, grad_type={}, error message {}'.format( 717 | name, param.size(), param.grad.data.dtype, repr(e) 718 | )) 719 | -------------------------------------------------------------------------------- /dee/dee_helper.py: -------------------------------------------------------------------------------- 1 | # Code Reference: (https://github.com/dolphin-zs/Doc2EDAG) 2 | 3 | import logging 4 | import os 5 | import re 6 | from collections import defaultdict, Counter 7 | import numpy as np 8 | import torch 9 | 10 | from .dee_metric import measure_event_table_filling 11 | from .event_type import event_type2event_class, BaseEvent, event_type_fields_list, common_fields 12 | from .ner_task import NERExample, NERFeatureConverter 13 | from .utils import default_load_json, default_dump_json, default_dump_pkl, default_load_pkl 14 | 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | class DEEExample(object): 20 | def __init__(self, annguid, detail_align_dict, only_inference=False): 21 | self.guid = annguid 22 | # [sent_text, ...] 23 | self.sentences = detail_align_dict['sentences'] 24 | self.num_sentences = len(self.sentences) 25 | 26 | if only_inference: 27 | # set empty entity/event information 28 | self.only_inference = True 29 | self.ann_valid_mspans = [] 30 | self.ann_mspan2dranges = {} 31 | self.ann_mspan2guess_field = {} 32 | self.recguid_eventname_eventdict_list = [] 33 | self.num_events = 0 34 | self.sent_idx2srange_mspan_mtype_tuples = {} 35 | self.event_type2event_objs = {} 36 | else: 37 | # set event information accordingly 38 | self.only_inference = False 39 | 40 | # [span_text, ...] 41 | self.ann_valid_mspans = detail_align_dict['ann_valid_mspans'] 42 | # span_text -> [drange_tuple, ...] 43 | self.ann_mspan2dranges = detail_align_dict['ann_mspan2dranges'] 44 | # span_text -> guessed_field_name 45 | self.ann_mspan2guess_field = detail_align_dict['ann_mspan2guess_field'] 46 | # [(recguid, event_name, event_dict), ...] 47 | self.recguid_eventname_eventdict_list = detail_align_dict['recguid_eventname_eventdict_list'] 48 | self.num_events = len(self.recguid_eventname_eventdict_list) 49 | 50 | # for creating NER examples 51 | # sentence_index -> [(sent_match_range, match_span, match_type), ...] 
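# NOTE (added comment, hypothetical values): e.g., sent_idx2srange_mspan_mtype_tuples[0] == [((5, 9), '力帆股份', 'CompanyName')] would mean sentence 0 contains the mention '力帆股份' at char range (5, 9) with guessed field 'CompanyName'.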
52 | self.sent_idx2srange_mspan_mtype_tuples = {} 53 | for sent_idx in range(self.num_sentences): 54 | self.sent_idx2srange_mspan_mtype_tuples[sent_idx] = [] 55 | 56 | for mspan in self.ann_valid_mspans: 57 | for drange in self.ann_mspan2dranges[mspan]: 58 | sent_idx, char_s, char_e = drange 59 | sent_mrange = (char_s, char_e) 60 | 61 | sent_text = self.sentences[sent_idx] 62 | if sent_text[char_s: char_e] != mspan: 63 | raise Exception('GUID: {} span range is not correct, span={}, range={}, sent={}'.format( 64 | annguid, mspan, str(sent_mrange), sent_text 65 | )) 66 | 67 | guess_field = self.ann_mspan2guess_field[mspan] 68 | 69 | self.sent_idx2srange_mspan_mtype_tuples[sent_idx].append( 70 | (sent_mrange, mspan, guess_field) 71 | ) 72 | 73 | # for creating event objects 74 | # the length of event_objs should be >= 1 75 | self.event_type2event_objs = {} 76 | for mrecguid, event_name, event_dict in self.recguid_eventname_eventdict_list: 77 | event_class = event_type2event_class[event_name] 78 | event_obj = event_class() 79 | assert isinstance(event_obj, BaseEvent) 80 | event_obj.update_by_dict(event_dict, recguid=mrecguid) 81 | 82 | if event_obj.name in self.event_type2event_objs: 83 | self.event_type2event_objs[event_obj.name].append(event_obj) 84 | else: 85 | self.event_type2event_objs[event_name] = [event_obj] 86 | 87 | def __repr__(self): 88 | dee_str = 'DEEExample (\n' 89 | dee_str += ' guid: {},\n'.format(repr(self.guid)) 90 | 91 | if not self.only_inference: 92 | dee_str += ' span info: (\n' 93 | for span_idx, span in enumerate(self.ann_valid_mspans): 94 | gfield = self.ann_mspan2guess_field[span] 95 | dranges = self.ann_mspan2dranges[span] 96 | dee_str += ' {:2} {:20} {:30} {}\n'.format(span_idx, span, gfield, str(dranges)) 97 | dee_str += ' ),\n' 98 | 99 | dee_str += ' event info: (\n' 100 | event_str_list = repr(self.event_type2event_objs).split('\n') 101 | for event_str in event_str_list: 102 | dee_str += ' {}\n'.format(event_str) 103 | dee_str += ' ),\n' 104 | 105 | dee_str += ' sentences: (\n' 106 | for sent_idx, sent in enumerate(self.sentences): 107 | dee_str += ' {:2} {}\n'.format(sent_idx, sent) 108 | dee_str += ' ),\n' 109 | 110 | dee_str += ')\n' 111 | 112 | return dee_str 113 | 114 | @staticmethod 115 | def get_event_type_fields_pairs(): 116 | return list(event_type_fields_list) 117 | 118 | @staticmethod 119 | def get_entity_label_list(): 120 | visit_set = set() 121 | entity_label_list = [NERExample.basic_entity_label] 122 | 123 | for field in common_fields: 124 | if field not in visit_set: 125 | visit_set.add(field) 126 | entity_label_list.extend(['B-' + field, 'I-' + field]) 127 | 128 | for event_name, fields in event_type_fields_list: 129 | for field in fields: 130 | if field not in visit_set: 131 | visit_set.add(field) 132 | entity_label_list.extend(['B-' + field, 'I-' + field]) 133 | 134 | return entity_label_list 135 | 136 | class DEEExampleLoader(object): 137 | def __init__(self, rearrange_sent_flag, max_sent_len): 138 | self.rearrange_sent_flag = rearrange_sent_flag 139 | self.max_sent_len = max_sent_len 140 | 141 | def rearrange_sent_info(self, detail_align_info): 142 | if 'ann_valid_dranges' not in detail_align_info: 143 | detail_align_info['ann_valid_dranges'] = [] 144 | if 'ann_mspan2dranges' not in detail_align_info: 145 | detail_align_info['ann_mspan2dranges'] = {} 146 | 147 | detail_align_info = dict(detail_align_info) 148 | split_rgx = re.compile('[,::;;))]') 149 | 150 | raw_sents = detail_align_info['sentences'] 151 | doc_text = ''.join(raw_sents) 152 | 
raw_dranges = detail_align_info['ann_valid_dranges'] 153 | raw_sid2span_char_set = defaultdict(lambda: set()) 154 | for raw_sid, char_s, char_e in raw_dranges: 155 | span_char_set = raw_sid2span_char_set[raw_sid] 156 | span_char_set.update(range(char_s, char_e)) 157 | 158 | # try to split long sentences into short ones by comma, colon, semi-colon, or bracket (a mention must not be split apart!) 159 | short_sents = [] 160 | for raw_sid, sent in enumerate(raw_sents): 161 | span_char_set = raw_sid2span_char_set[raw_sid] 162 | if len(sent) > self.max_sent_len: 163 | cur_char_s = 0 164 | for mobj in split_rgx.finditer(sent): 165 | m_char_s, m_char_e = mobj.span() 166 | if m_char_s in span_char_set: 167 | continue 168 | short_sents.append(sent[cur_char_s:m_char_e]) 169 | cur_char_s = m_char_e 170 | short_sents.append(sent[cur_char_s:]) 171 | else: 172 | short_sents.append(sent) 173 | 174 | # merge adjacent short sentences to compact ones that match max_sent_len 175 | comp_sents = [''] 176 | for sent in short_sents: 177 | prev_sent = comp_sents[-1] 178 | if len(prev_sent + sent) <= self.max_sent_len: 179 | comp_sents[-1] = prev_sent + sent 180 | else: 181 | comp_sents.append(sent) 182 | 183 | # get global sentence character base indexes 184 | raw_char_bases = [0] 185 | for sent in raw_sents: 186 | raw_char_bases.append(raw_char_bases[-1] + len(sent)) 187 | comp_char_bases = [0] 188 | for sent in comp_sents: 189 | comp_char_bases.append(comp_char_bases[-1] + len(sent)) 190 | 191 | assert raw_char_bases[-1] == comp_char_bases[-1] == len(doc_text) 192 | 193 | # calculate compact doc ranges 194 | raw_dranges.sort() 195 | raw_drange2comp_drange = {} 196 | prev_comp_sid = 0 197 | for raw_drange in raw_dranges: 198 | raw_drange = tuple(raw_drange) # important because json dump changes tuples to lists 199 | raw_sid, raw_char_s, raw_char_e = raw_drange 200 | raw_char_base = raw_char_bases[raw_sid] 201 | doc_char_s = raw_char_base + raw_char_s 202 | doc_char_e = raw_char_base + raw_char_e 203 | assert doc_char_s >= comp_char_bases[prev_comp_sid] 204 | 205 | cur_comp_sid = prev_comp_sid 206 | for cur_comp_sid in range(prev_comp_sid, len(comp_sents)): 207 | if doc_char_e <= comp_char_bases[cur_comp_sid+1]: 208 | prev_comp_sid = cur_comp_sid 209 | break 210 | comp_char_base = comp_char_bases[cur_comp_sid] 211 | assert comp_char_base <= doc_char_s < doc_char_e <= comp_char_bases[cur_comp_sid+1] 212 | comp_char_s = doc_char_s - comp_char_base 213 | comp_char_e = doc_char_e - comp_char_base 214 | comp_drange = (cur_comp_sid, comp_char_s, comp_char_e) 215 | 216 | raw_drange2comp_drange[raw_drange] = comp_drange 217 | assert raw_sents[raw_drange[0]][raw_drange[1]:raw_drange[2]] == \ 218 | comp_sents[comp_drange[0]][comp_drange[1]:comp_drange[2]] 219 | 220 | # update detailed align info with rearranged sentences 221 | detail_align_info['sentences'] = comp_sents 222 | detail_align_info['ann_valid_dranges'] = [ 223 | raw_drange2comp_drange[tuple(raw_drange)] for raw_drange in detail_align_info['ann_valid_dranges'] 224 | ] 225 | ann_mspan2comp_dranges = {} 226 | for ann_mspan, mspan_raw_dranges in detail_align_info['ann_mspan2dranges'].items(): 227 | comp_dranges = [ 228 | raw_drange2comp_drange[tuple(raw_drange)] for raw_drange in mspan_raw_dranges 229 | ] 230 | ann_mspan2comp_dranges[ann_mspan] = comp_dranges 231 | detail_align_info['ann_mspan2dranges'] = ann_mspan2comp_dranges 232 | 233 | return detail_align_info 234 | 235 | def convert_dict_to_example(self, annguid, detail_align_info, only_inference=False): 236 | if self.rearrange_sent_flag: 237 
| detail_align_info = self.rearrange_sent_info(detail_align_info) 238 | dee_example = DEEExample(annguid, detail_align_info, only_inference=only_inference) 239 | 240 | return dee_example 241 | 242 | def __call__(self, dataset_json_path): 243 | total_dee_examples = [] 244 | annguid_aligninfo_list = default_load_json(dataset_json_path) 245 | for annguid, detail_align_info in annguid_aligninfo_list: 246 | # if self.rearrange_sent_flag: 247 | # detail_align_info = self.rearrange_sent_info(detail_align_info) 248 | # dee_example = DEEExample(annguid, detail_align_info) 249 | dee_example = self.convert_dict_to_example(annguid, detail_align_info) 250 | total_dee_examples.append(dee_example) 251 | 252 | return total_dee_examples 253 | 254 | class DEEFeature(object): 255 | def __init__(self, guid, ex_idx, doc_token_id_mat, doc_token_mask_mat, doc_token_label_mat, 256 | span_token_ids_list, span_dranges_list, event_type_labels, event_arg_idxs_objs_list, 257 | valid_sent_num=None): 258 | self.guid = guid 259 | self.ex_idx = ex_idx # example row index, used for backtracking 260 | self.valid_sent_num = valid_sent_num 261 | 262 | # directly set tensor for dee feature to save memory 263 | # self.doc_token_id_mat = doc_token_id_mat 264 | # self.doc_token_mask_mat = doc_token_mask_mat 265 | # self.doc_token_label_mat = doc_token_label_mat 266 | self.doc_token_ids = torch.tensor(doc_token_id_mat, dtype=torch.long) 267 | self.doc_token_masks = torch.tensor(doc_token_mask_mat, dtype=torch.uint8) # uint8 for mask 268 | self.doc_token_labels = torch.tensor(doc_token_label_mat, dtype=torch.long) 269 | 270 | # sorted by the first drange tuple 271 | # [(token_id, ...), ...] 272 | # span_idx -> span_token_id tuple , list of token ids of span 273 | self.span_token_ids_list = span_token_ids_list 274 | # [[(sent_idx, char_s, char_e), ...], ...] 275 | # span_idx -> [drange tuple, ...] 276 | # self.span_dranges_list[i] contains all the mention spans of self.span_token_ids_list[i] entity 277 | self.span_dranges_list = span_dranges_list 278 | 279 | # [event_type_label, ...] 
280 | # length = the total number of events to be considered 281 | # event_type_label \in {0, 1}, 0: no 1: yes 282 | self.event_type_labels = event_type_labels # length=5 1: has this event type 0: does not have this event type 283 | # event_type is denoted by the index of event_type_labels 284 | # event_type_idx -> event_obj_idx -> event_arg_idx -> span_idx 285 | # if no event objects, event_type_idx -> None 286 | self.event_arg_idxs_objs_list = event_arg_idxs_objs_list 287 | 288 | # event_type_idx -> event_field_idx -> pre_path -> {span_idx, ...} 289 | # pre_path is tuple of span_idx 290 | self.event_idx2field_idx2pre_path2cur_span_idx_set = self.build_dag_info(self.event_arg_idxs_objs_list) 291 | 292 | # event_type_idx -> key_sent_idx_set, used for key-event sentence detection 293 | self.event_idx2key_sent_idx_set, self.doc_sent_labels = self.build_key_event_sent_info() 294 | 295 | def generate_dag_info_for(self, pred_span_token_tup_list, return_miss=False): 296 | token_tup2pred_span_idx = { 297 | token_tup: pred_span_idx for pred_span_idx, token_tup in enumerate(pred_span_token_tup_list) 298 | } 299 | gold_span_idx2pred_span_idx = {} 300 | # pred_span_idx2gold_span_idx = {} 301 | missed_span_idx_list = [] # in terms of self 302 | missed_sent_idx_list = [] # in terms of self 303 | for gold_span_idx, token_tup in enumerate(self.span_token_ids_list): 304 | if token_tup in token_tup2pred_span_idx: 305 | pred_span_idx = token_tup2pred_span_idx[token_tup] 306 | gold_span_idx2pred_span_idx[gold_span_idx] = pred_span_idx 307 | # pred_span_idx2gold_span_idx[pred_span_idx] = gold_span_idx 308 | else: 309 | missed_span_idx_list.append(gold_span_idx) 310 | for gold_drange in self.span_dranges_list[gold_span_idx]: 311 | missed_sent_idx_list.append(gold_drange[0]) 312 | missed_sent_idx_list = list(set(missed_sent_idx_list)) 313 | 314 | pred_event_arg_idxs_objs_list = [] 315 | for event_arg_idxs_objs in self.event_arg_idxs_objs_list: 316 | if event_arg_idxs_objs is None: 317 | pred_event_arg_idxs_objs_list.append(None) 318 | else: 319 | pred_event_arg_idxs_objs = [] 320 | for event_arg_idxs in event_arg_idxs_objs: 321 | pred_event_arg_idxs = [] 322 | for gold_span_idx in event_arg_idxs: 323 | if gold_span_idx in gold_span_idx2pred_span_idx: 324 | pred_event_arg_idxs.append( 325 | gold_span_idx2pred_span_idx[gold_span_idx] 326 | ) 327 | else: 328 | pred_event_arg_idxs.append(None) 329 | 330 | pred_event_arg_idxs_objs.append(tuple(pred_event_arg_idxs)) 331 | pred_event_arg_idxs_objs_list.append(pred_event_arg_idxs_objs) 332 | 333 | # event_idx -> field_idx -> pre_path -> cur_span_idx_set 334 | pred_dag_info = self.build_dag_info(pred_event_arg_idxs_objs_list) 335 | 336 | if return_miss: 337 | return pred_dag_info, missed_span_idx_list, missed_sent_idx_list 338 | else: 339 | return pred_dag_info 340 | 341 | def get_event_args_objs_list(self): 342 | event_args_objs_list = [] 343 | for event_arg_idxs_objs in self.event_arg_idxs_objs_list: 344 | if event_arg_idxs_objs is None: 345 | event_args_objs_list.append(None) 346 | else: 347 | event_args_objs = [] 348 | for event_arg_idxs in event_arg_idxs_objs: 349 | event_args = [] 350 | for arg_idx in event_arg_idxs: 351 | if arg_idx is None: 352 | token_tup = None 353 | else: 354 | token_tup = self.span_token_ids_list[arg_idx] 355 | event_args.append(token_tup) 356 | event_args_objs.append(event_args) 357 | event_args_objs_list.append(event_args_objs) 358 | 359 | return event_args_objs_list 360 | 361 | def build_key_event_sent_info(self): 362 | assert 
len(self.event_type_labels) == len(self.event_arg_idxs_objs_list) 363 | # event_idx -> key_event_sent_index_set 364 | event_idx2key_sent_idx_set = [set() for _ in self.event_type_labels] 365 | for key_sent_idx_set, event_label, event_arg_idxs_objs in zip( 366 | event_idx2key_sent_idx_set, self.event_type_labels, self.event_arg_idxs_objs_list 367 | ): 368 | if event_label == 0: 369 | assert event_arg_idxs_objs is None 370 | else: 371 | for event_arg_idxs_obj in event_arg_idxs_objs: 372 | sent_idx_cands = [] 373 | for span_idx in event_arg_idxs_obj: 374 | if span_idx is None: 375 | continue 376 | span_dranges = self.span_dranges_list[span_idx] 377 | for sent_idx, _, _ in span_dranges: 378 | sent_idx_cands.append(sent_idx) 379 | if len(sent_idx_cands) == 0: 380 | raise Exception('Event {} has no valid spans'.format(str(event_arg_idxs_obj))) 381 | sent_idx_cnter = Counter(sent_idx_cands) 382 | key_sent_idx = sent_idx_cnter.most_common()[0][0] 383 | key_sent_idx_set.add(key_sent_idx) 384 | 385 | doc_sent_labels = [] # 1: key event sentence, 0: otherwise 386 | for sent_idx in range(self.valid_sent_num): # masked sents will be truncated at the model part 387 | sent_labels = [] 388 | for key_sent_idx_set in event_idx2key_sent_idx_set: # this mapping is a list 389 | if sent_idx in key_sent_idx_set: 390 | sent_labels.append(1) 391 | else: 392 | sent_labels.append(0) 393 | doc_sent_labels.append(sent_labels) 394 | 395 | return event_idx2key_sent_idx_set, doc_sent_labels 396 | 397 | @staticmethod 398 | def build_dag_info(event_arg_idxs_objs_list): 399 | # event_idx -> field_idx -> pre_path -> {span_idx, ...} 400 | # pre_path is tuple of span_idx 401 | event_idx2field_idx2pre_path2cur_span_idx_set = [] 402 | for event_idx, event_arg_idxs_list in enumerate(event_arg_idxs_objs_list): 403 | if event_arg_idxs_list is None: 404 | event_idx2field_idx2pre_path2cur_span_idx_set.append(None) 405 | else: 406 | num_fields = len(event_arg_idxs_list[0]) 407 | # field_idx -> pre_path -> {span_idx, ...} 408 | field_idx2pre_path2cur_span_idx_set = [] 409 | for field_idx in range(num_fields): 410 | pre_path2cur_span_idx_set = {} 411 | for event_arg_idxs in event_arg_idxs_list: 412 | pre_path = event_arg_idxs[:field_idx] 413 | span_idx = event_arg_idxs[field_idx] 414 | if pre_path not in pre_path2cur_span_idx_set: 415 | pre_path2cur_span_idx_set[pre_path] = set() 416 | pre_path2cur_span_idx_set[pre_path].add(span_idx) 417 | field_idx2pre_path2cur_span_idx_set.append(pre_path2cur_span_idx_set) 418 | event_idx2field_idx2pre_path2cur_span_idx_set.append(field_idx2pre_path2cur_span_idx_set) 419 | 420 | return event_idx2field_idx2pre_path2cur_span_idx_set 421 | 422 | def is_multi_event(self): 423 | event_cnt = 0 424 | for event_objs in self.event_arg_idxs_objs_list: 425 | if event_objs is not None: 426 | event_cnt += len(event_objs) 427 | if event_cnt > 1: 428 | return True 429 | 430 | return False 431 | 432 | class DEEFeatureConverter(object): 433 | def __init__(self, entity_label_list, event_type_fields_pairs, 434 | max_sent_len, max_sent_num, tokenizer, 435 | ner_fea_converter=None, include_cls=True, include_sep=True): 436 | self.entity_label_list = entity_label_list 437 | self.event_type_fields_pairs = event_type_fields_pairs 438 | self.max_sent_len = max_sent_len 439 | self.max_sent_num = max_sent_num 440 | self.tokenizer = tokenizer 441 | self.truncate_doc_count = 0 # track how many docs have been truncated due to max_sent_num 442 | self.truncate_span_count = 0 # track how many spans have been truncated 443 | 444 | # 
label not in entity_label_list will be default 'O' 445 | # sent_len > max_sent_len will be truncated, and increase ner_fea_converter.truncate_freq 446 | if ner_fea_converter is None: 447 | self.ner_fea_converter = NERFeatureConverter(entity_label_list, self.max_sent_len, tokenizer, 448 | include_cls=include_cls, include_sep=include_sep) 449 | else: 450 | self.ner_fea_converter = ner_fea_converter 451 | 452 | self.include_cls = include_cls 453 | self.include_sep = include_sep 454 | 455 | # prepare entity_label -> entity_index mapping 456 | self.entity_label2index = {} 457 | for entity_idx, entity_label in enumerate(self.entity_label_list): 458 | self.entity_label2index[entity_label] = entity_idx 459 | 460 | # prepare event_type -> event_index and event_index -> event_fields mapping 461 | self.event_type2index = {} 462 | self.event_type_list = [] 463 | self.event_fields_list = [] 464 | for event_idx, (event_type, event_fields) in enumerate(self.event_type_fields_pairs): 465 | self.event_type2index[event_type] = event_idx 466 | self.event_type_list.append(event_type) 467 | self.event_fields_list.append(event_fields) 468 | 469 | def convert_example_to_feature(self, ex_idx, dee_example, log_flag=False): 470 | annguid = dee_example.guid 471 | assert isinstance(dee_example, DEEExample) 472 | 473 | # 1. prepare doc token-level feature 474 | 475 | # Size(num_sent_num, num_sent_len) 476 | doc_token_id_mat = [] # [[token_idx, ...], ...] 477 | doc_token_mask_mat = [] # [[token_mask, ...], ...] 478 | doc_token_label_mat = [] # [[token_label_id, ...], ...] 479 | 480 | for sent_idx, sent_text in enumerate(dee_example.sentences): 481 | if sent_idx >= self.max_sent_num: 482 | # truncate doc whose number of sentences is longer than self.max_sent_num 483 | self.truncate_doc_count += 1 484 | break 485 | 486 | if sent_idx in dee_example.sent_idx2srange_mspan_mtype_tuples: 487 | srange_mspan_mtype_tuples = dee_example.sent_idx2srange_mspan_mtype_tuples[sent_idx] 488 | else: 489 | srange_mspan_mtype_tuples = [] 490 | 491 | # srange_mspan_mtype_tuples in this sentence (span-position,span-text,span-type) 492 | 493 | ner_example = NERExample( 494 | '{}-{}'.format(annguid, sent_idx), sent_text, srange_mspan_mtype_tuples 495 | ) 496 | # sentence truncated count will be recorded incrementally 497 | ner_feature = self.ner_fea_converter.convert_example_to_feature(ner_example, log_flag=log_flag) 498 | 499 | doc_token_id_mat.append(ner_feature.input_ids) 500 | doc_token_mask_mat.append(ner_feature.input_masks) 501 | doc_token_label_mat.append(ner_feature.label_ids) 502 | 503 | # already pad to max_len=128 504 | 505 | assert len(doc_token_id_mat) == len(doc_token_mask_mat) == len(doc_token_label_mat) <= self.max_sent_num 506 | valid_sent_num = len(doc_token_id_mat) 507 | 508 | # 2. 
prepare span feature 509 | # spans are sorted by the first drange 510 | span_token_ids_list = [] 511 | span_dranges_list = [] 512 | mspan2span_idx = {} 513 | for mspan in dee_example.ann_valid_mspans: 514 | if mspan in mspan2span_idx: 515 | continue 516 | 517 | raw_dranges = dee_example.ann_mspan2dranges[mspan] 518 | char_base_s = 1 if self.include_cls else 0 519 | char_max_end = self.max_sent_len - 1 if self.include_sep else self.max_sent_len 520 | span_dranges = [] 521 | for sent_idx, char_s, char_e in raw_dranges: 522 | if char_base_s + char_e <= char_max_end and sent_idx < self.max_sent_num: 523 | span_dranges.append((sent_idx, char_base_s + char_s, char_base_s + char_e)) 524 | else: 525 | self.truncate_span_count += 1 526 | if len(span_dranges) == 0: 527 | # span does not have any valid location in truncated sequences 528 | continue 529 | 530 | span_tokens = self.tokenizer.char_tokenize(mspan) 531 | span_token_ids = tuple(self.tokenizer.convert_tokens_to_ids(span_tokens)) 532 | 533 | mspan2span_idx[mspan] = len(span_token_ids_list) 534 | span_token_ids_list.append(span_token_ids) 535 | span_dranges_list.append(span_dranges) 536 | assert len(span_token_ids_list) == len(span_dranges_list) == len(mspan2span_idx) 537 | 538 | if len(span_token_ids_list) == 0 and not dee_example.only_inference: 539 | logger.warning('Neglect example {}'.format(ex_idx)) 540 | return None 541 | 542 | # 3. prepare doc-level event feature 543 | # event_type_labels: event_type_index -> event_type_exist_sign (1: exist, 0: no) 544 | # event_arg_idxs_objs_list: event_type_index -> event_obj_index -> event_arg_index -> arg_span_token_ids 545 | 546 | event_type_labels = [] # event_type_idx -> event_type_exist_sign (1 or 0) 547 | event_arg_idxs_objs_list = [] # event_type_idx -> event_obj_idx -> event_arg_idx -> span_idx 548 | for event_idx, event_type in enumerate(self.event_type_list): 549 | event_fields = self.event_fields_list[event_idx] 550 | 551 | if event_type not in dee_example.event_type2event_objs: 552 | event_type_labels.append(0) 553 | event_arg_idxs_objs_list.append(None) 554 | else: 555 | event_objs = dee_example.event_type2event_objs[event_type] 556 | 557 | event_arg_idxs_objs = [] 558 | for event_obj in event_objs: 559 | assert isinstance(event_obj, BaseEvent) 560 | 561 | event_arg_idxs = [] 562 | any_valid_flag = False 563 | for field in event_fields: 564 | arg_span = event_obj.field2content[field] 565 | 566 | if arg_span is None or arg_span not in mspan2span_idx: 567 | # arg_span can be none or valid span is truncated 568 | arg_span_idx = None 569 | else: 570 | # when constructing data files, 571 | # must ensure event arg span is covered by the total span collections 572 | arg_span_idx = mspan2span_idx[arg_span] 573 | any_valid_flag = True 574 | 575 | event_arg_idxs.append(arg_span_idx) 576 | 577 | if any_valid_flag: 578 | event_arg_idxs_objs.append(tuple(event_arg_idxs)) 579 | 580 | if event_arg_idxs_objs: 581 | event_type_labels.append(1) 582 | event_arg_idxs_objs_list.append(event_arg_idxs_objs) 583 | else: 584 | event_type_labels.append(0) 585 | event_arg_idxs_objs_list.append(None) 586 | 587 | dee_feature = DEEFeature( 588 | annguid, ex_idx, doc_token_id_mat, doc_token_mask_mat, doc_token_label_mat, 589 | span_token_ids_list, span_dranges_list, event_type_labels, event_arg_idxs_objs_list, 590 | valid_sent_num=valid_sent_num 591 | ) 592 | return dee_feature 593 | 594 | def __call__(self, dee_examples, log_example_num=0): 595 | """Convert examples to features suitable for document-level event 
extraction""" 596 | dee_features = [] 597 | self.truncate_doc_count = 0 598 | self.truncate_span_count = 0 599 | self.ner_fea_converter.truncate_count = 0 600 | 601 | remove_ex_cnt = 0 602 | for ex_idx, dee_example in enumerate(dee_examples): 603 | if ex_idx < log_example_num: 604 | dee_feature = self.convert_example_to_feature(ex_idx-remove_ex_cnt, dee_example, log_flag=True) 605 | else: 606 | dee_feature = self.convert_example_to_feature(ex_idx-remove_ex_cnt, dee_example, log_flag=False) 607 | 608 | if dee_feature is None: 609 | remove_ex_cnt += 1 610 | continue 611 | dee_features.append(dee_feature) 612 | 613 | logger.info('{} documents, ignore {} examples, truncate {} docs, {} sents, {} spans'.format( 614 | len(dee_examples), remove_ex_cnt, 615 | self.truncate_doc_count, self.ner_fea_converter.truncate_count, self.truncate_span_count 616 | )) 617 | 618 | return dee_features 619 | 620 | def convert_dee_features_to_dataset(dee_features): 621 | # just view a list of doc_fea as the dataset, that only requires __len__, __getitem__ 622 | assert len(dee_features) > 0 and isinstance(dee_features[0], DEEFeature) 623 | 624 | return dee_features 625 | 626 | def prepare_doc_batch_dict(doc_fea_list): 627 | doc_batch_keys = ['ex_idx', 'doc_token_ids', 'doc_token_masks', 'doc_token_labels', 'valid_sent_num'] 628 | doc_batch_dict = {} 629 | for key in doc_batch_keys: 630 | doc_batch_dict[key] = [getattr(doc_fea, key) for doc_fea in doc_fea_list] 631 | 632 | return doc_batch_dict 633 | 634 | def measure_dee_prediction(event_type_fields_pairs, features, event_decode_results, 635 | dump_json_path=None, writer=None, epoch=None): 636 | pred_record_mat_list = [] 637 | gold_record_mat_list = [] 638 | for term in event_decode_results: 639 | ex_idx, pred_event_type_labels, pred_record_mat, doc_span_info = term[:4] 640 | pred_record_mat = [ 641 | [ 642 | [ 643 | tuple(arg_tup) if arg_tup is not None else None 644 | for arg_tup in pred_record 645 | ] for pred_record in pred_records 646 | ] if pred_records is not None else None 647 | for pred_records in pred_record_mat 648 | ] 649 | doc_fea = features[ex_idx] 650 | assert isinstance(doc_fea, DEEFeature) 651 | gold_record_mat = [ 652 | [ 653 | [ 654 | tuple(doc_fea.span_token_ids_list[arg_idx]) if arg_idx is not None else None 655 | for arg_idx in event_arg_idxs 656 | ] for event_arg_idxs in event_arg_idxs_objs 657 | ] if event_arg_idxs_objs is not None else None 658 | for event_arg_idxs_objs in doc_fea.event_arg_idxs_objs_list 659 | ] 660 | 661 | pred_record_mat_list.append(pred_record_mat) 662 | gold_record_mat_list.append(gold_record_mat) 663 | 664 | g_eval_res = measure_event_table_filling( 665 | pred_record_mat_list, gold_record_mat_list, event_type_fields_pairs, dict_return=True 666 | ) 667 | 668 | if writer is not None and dump_json_path is not None: 669 | if 'dev' in dump_json_path: 670 | prefix = 'Dev-Pred-' if 'pred' in dump_json_path else 'Dev-Gold-' 671 | else: 672 | prefix = 'Test-Pred-' if 'pred' in dump_json_path else 'Test-Gold-' 673 | writer.add_scalar(prefix+'MicroF1', g_eval_res[-1]['MicroF1'], global_step=epoch) 674 | writer.add_scalar(prefix+'MacroF1', g_eval_res[-1]['MacroF1'], global_step=epoch) 675 | writer.add_scalar(prefix+'MicroPrecision', g_eval_res[-1]['MicroPrecision'], global_step=epoch) 676 | writer.add_scalar(prefix+'MicroRecall', g_eval_res[-1]['MicroRecall'], global_step=epoch) 677 | 678 | event_triggering_tp = [0 for _ in range(5)] 679 | event_triggering_fp = [0 for _ in range(5)] 680 | event_triggering_fn = [0 for _ in 
range(5)] 681 | for term in event_decode_results: 682 | ex_idx, pred_event_type_labels, pred_record_mat, doc_span_info = term[:4] 683 | event_triggering_golden = features[ex_idx].event_type_labels 684 | for et_idx, et in enumerate(pred_event_type_labels): 685 | if pred_event_type_labels[et_idx] == 1: 686 | if event_triggering_golden[et_idx] == 1: 687 | event_triggering_tp[et_idx] += 1 688 | else: 689 | event_triggering_fp[et_idx] += 1 690 | else: 691 | if event_triggering_golden[et_idx] == 1: 692 | event_triggering_fn[et_idx] += 1 693 | 694 | for eidx in range(5): 695 | if event_triggering_tp[eidx]+event_triggering_fp[eidx] != 0: 696 | event_p = event_triggering_tp[eidx] / (event_triggering_tp[eidx]+event_triggering_fp[eidx]) 697 | else: 698 | event_p = 0 699 | if event_triggering_tp[eidx]+event_triggering_fn[eidx] != 0: 700 | event_r = event_triggering_tp[eidx] / (event_triggering_tp[eidx]+event_triggering_fn[eidx]) 701 | else: 702 | event_r = 0 703 | if event_p != 0 and event_r != 0: 704 | event_f1 = 2 * event_p * event_r / (event_p + event_r) 705 | else: 706 | event_f1 = 0 707 | g_eval_res[-1]['event_{}_p'.format(eidx+1)] = event_p 708 | g_eval_res[-1]['event_{}_r'.format(eidx+1)] = event_r 709 | g_eval_res[-1]['event_{}_f1'.format(eidx+1)] = event_f1 710 | 711 | if dump_json_path is not None: 712 | default_dump_json(g_eval_res, dump_json_path) 713 | 714 | return g_eval_res 715 | 716 | def aggregate_task_eval_info(eval_dir_path, target_file_pre='dee_eval', target_file_suffix='.json', 717 | dump_name='total_task_eval.pkl', dump_flag=False): 718 | """Enumerate the evaluation directory to collect all dumped evaluation results""" 719 | logger.info('Aggregate task evaluation info from {}'.format(eval_dir_path)) 720 | data_span_type2model_str2epoch_res_list = {} 721 | for fn in os.listdir(eval_dir_path): 722 | fn_splits = fn.split('.') 723 | if fn.startswith(target_file_pre) and fn.endswith(target_file_suffix) and len(fn_splits) == 6: 724 | _, data_type, span_type, model_str, epoch, _ = fn_splits 725 | 726 | data_span_type = (data_type, span_type) 727 | if data_span_type not in data_span_type2model_str2epoch_res_list: 728 | data_span_type2model_str2epoch_res_list[data_span_type] = {} 729 | model_str2epoch_res_list = data_span_type2model_str2epoch_res_list[data_span_type] 730 | 731 | if model_str not in model_str2epoch_res_list: 732 | model_str2epoch_res_list[model_str] = [] 733 | epoch_res_list = model_str2epoch_res_list[model_str] 734 | 735 | epoch = int(epoch) 736 | fp = os.path.join(eval_dir_path, fn) 737 | eval_res = default_load_json(fp) 738 | 739 | epoch_res_list.append((epoch, eval_res)) 740 | 741 | for data_span_type, model_str2epoch_res_list in data_span_type2model_str2epoch_res_list.items(): 742 | for model_str, epoch_res_list in model_str2epoch_res_list.items(): 743 | epoch_res_list.sort(key=lambda x: x[0]) 744 | 745 | if dump_flag: 746 | dump_fp = os.path.join(eval_dir_path, dump_name) 747 | logger.info('Dumping {} into {}'.format(dump_name, eval_dir_path)) 748 | default_dump_pkl(data_span_type2model_str2epoch_res_list, dump_fp) 749 | 750 | return data_span_type2model_str2epoch_res_list 751 | 752 | def print_total_eval_info(data_span_type2model_str2epoch_res_list, 753 | metric_type='micro', 754 | span_type='pred_span', 755 | model_str='GIT', 756 | target_set='test'): 757 | """Print the final performance by selecting the best epoch on dev set and emitting performance on test set""" 758 | dev_type = 'dev' 759 | test_type = 'test' 760 | avg_type2prf1_keys = { 761 | 'macro': 
('MacroPrecision', 'MacroRecall', 'MacroF1'), 762 | 'micro': ('MicroPrecision', 'MicroRecall', 'MicroF1'), 763 | } 764 | 765 | name_key = 'EventType' 766 | p_key, r_key, f_key = avg_type2prf1_keys[metric_type] 767 | 768 | def get_avg_event_score(epoch_res): 769 | eval_res = epoch_res[1] 770 | avg_event_score = eval_res[-1][f_key] 771 | 772 | return avg_event_score 773 | 774 | dev_model_str2epoch_res_list = data_span_type2model_str2epoch_res_list[(dev_type, span_type)] 775 | test_model_str2epoch_res_list = data_span_type2model_str2epoch_res_list[(test_type, span_type)] 776 | 777 | has_header = False 778 | mstr_bepoch_list = [] 779 | print('=' * 15, 'Final Performance (%) (avg_type={})'.format(metric_type), '=' * 15) 780 | 781 | if model_str not in dev_model_str2epoch_res_list or model_str not in test_model_str2epoch_res_list: 782 | pass 783 | else: 784 | # get the best epoch on dev set 785 | dev_epoch_res_list = dev_model_str2epoch_res_list[model_str] 786 | best_dev_epoch, best_dev_res = max(dev_epoch_res_list, key=get_avg_event_score) 787 | 788 | test_epoch_res_list = test_model_str2epoch_res_list[model_str] 789 | best_test_epoch = None 790 | best_test_res = None 791 | for test_epoch, test_res in test_epoch_res_list: 792 | if test_epoch == best_dev_epoch: 793 | best_test_epoch = test_epoch 794 | best_test_res = test_res 795 | assert best_test_epoch is not None 796 | mstr_bepoch_list.append((model_str, best_test_epoch)) 797 | 798 | if target_set == 'test': 799 | target_eval_res = best_test_res 800 | else: 801 | target_eval_res = best_dev_res 802 | 803 | align_temp = '{:20}' 804 | head_str = align_temp.format('ModelType') 805 | eval_str = align_temp.format(model_str) 806 | head_temp = ' \t {}' 807 | eval_temp = ' \t & {:.1f} & {:.1f} & {:.1f}' 808 | ps = [] 809 | rs = [] 810 | fs = [] 811 | for tgt_event_res in target_eval_res[:-1]: 812 | head_str += align_temp.format(head_temp.format(tgt_event_res[0][name_key])) 813 | p, r, f1 = (100 * tgt_event_res[0][key] for key in [p_key, r_key, f_key]) 814 | eval_str += align_temp.format(eval_temp.format(p, r, f1)) 815 | ps.append(p) 816 | rs.append(r) 817 | fs.append(f1) 818 | 819 | head_str += align_temp.format(head_temp.format('Average')) 820 | ap, ar, af1 = (x for x in [np.mean(ps), np.mean(rs), np.mean(fs)]) 821 | eval_str += align_temp.format(eval_temp.format(ap, ar, af1)) 822 | 823 | head_str += align_temp.format(head_temp.format('Total ({})'.format(metric_type))) 824 | g_avg_res = target_eval_res[-1] 825 | ap, ar, af1 = (100 * g_avg_res[key] for key in [p_key, r_key, f_key]) 826 | eval_str += align_temp.format(eval_temp.format(ap, ar, af1)) 827 | 828 | if not has_header: 829 | print(head_str) 830 | has_header = True 831 | print(eval_str) 832 | 833 | print(mstr_bepoch_list) 834 | return mstr_bepoch_list 835 | 836 | # evaluation dump file name template 837 | # dee_eval.[DataType].[SpanType].[ModelStr].[Epoch].(pkl|json) 838 | decode_dump_template = 'dee_eval.{}.{}.{}.{}.pkl' 839 | eval_dump_template = 'dee_eval.{}.{}.{}.{}.json' 840 | 841 | def resume_decode_results(base_dir, data_type, span_type, model_str, epoch): 842 | decode_fn = decode_dump_template.format(data_type, span_type, model_str, epoch) 843 | decode_fp = os.path.join(base_dir, decode_fn) 844 | logger.info('Resume decoded results from {}'.format(decode_fp)) 845 | decode_results = default_load_pkl(decode_fp) 846 | 847 | return decode_results 848 | 849 | def resume_eval_results(base_dir, data_type, span_type, model_str, epoch): 850 | eval_fn = eval_dump_template.format(data_type, 
span_type, model_str, epoch) 851 | eval_fp = os.path.join(base_dir, eval_fn) 852 | logger.info('Resume eval results from {}'.format(eval_fp)) 853 | eval_results = default_load_json(eval_fp) 854 | 855 | return eval_results 856 | 857 | def print_single_vs_multi_performance(mstr_bepoch_list, base_dir, features, 858 | metric_type='micro', data_type='test', span_type='pred_span'): 859 | model_str2decode_results = {} 860 | for model_str, best_epoch in mstr_bepoch_list: 861 | model_str2decode_results[model_str] = resume_decode_results( 862 | base_dir, data_type, span_type, model_str, best_epoch 863 | ) 864 | 865 | single_eid_set = set([doc_fea.ex_idx for doc_fea in features if not doc_fea.is_multi_event()]) 866 | multi_eid_set = set([doc_fea.ex_idx for doc_fea in features if doc_fea.is_multi_event()]) 867 | event_type_fields_pairs = DEEExample.get_event_type_fields_pairs() 868 | event_type_list = [x for x, y in event_type_fields_pairs] 869 | 870 | name_key = 'EventType' 871 | avg_type2f1_key = { 872 | 'micro': 'MicroF1', 873 | 'macro': 'MacroF1', 874 | } 875 | f1_key = avg_type2f1_key[metric_type] 876 | 877 | model_str2etype_sf1_mf1_list = {} 878 | for model_str, _ in mstr_bepoch_list: 879 | total_decode_results = model_str2decode_results[model_str] 880 | 881 | single_decode_results = [dec_res for dec_res in total_decode_results if dec_res[0] in single_eid_set] 882 | assert len(single_decode_results) == len(single_eid_set) 883 | single_eval_res = measure_dee_prediction( 884 | event_type_fields_pairs, features, single_decode_results 885 | ) 886 | 887 | multi_decode_results = [dec_res for dec_res in total_decode_results if dec_res[0] in multi_eid_set] 888 | assert len(multi_decode_results) == len(multi_eid_set) 889 | multi_eval_res = measure_dee_prediction( 890 | event_type_fields_pairs, features, multi_decode_results 891 | ) 892 | 893 | etype_sf1_mf1_list = [] 894 | for event_idx, (se_res, me_res) in enumerate(zip(single_eval_res[:-1], multi_eval_res[:-1])): 895 | assert se_res[0][name_key] == me_res[0][name_key] == event_type_list[event_idx] 896 | event_type = event_type_list[event_idx] 897 | single_f1 = se_res[0][f1_key] 898 | multi_f1 = me_res[0][f1_key] 899 | 900 | etype_sf1_mf1_list.append((event_type, single_f1, multi_f1)) 901 | g_avg_se_res = single_eval_res[-1] 902 | g_avg_me_res = multi_eval_res[-1] 903 | etype_sf1_mf1_list.append( 904 | ('Total ({})'.format(metric_type), g_avg_se_res[f1_key], g_avg_me_res[f1_key]) 905 | ) 906 | model_str2etype_sf1_mf1_list[model_str] = etype_sf1_mf1_list 907 | 908 | print('=' * 15, 'Single vs. 
Multi (%) (avg_type={})'.format(metric_type), '=' * 15) 909 | align_temp = '{:20}' 910 | head_str = align_temp.format('ModelType') 911 | head_temp = ' \t {}' 912 | eval_temp = ' \t & {:.1f} & {:.1f} ' 913 | for event_type in event_type_list: 914 | head_str += align_temp.format(head_temp.format(event_type)) 915 | head_str += align_temp.format(head_temp.format('Total ({})'.format(metric_type))) 916 | head_str += align_temp.format(head_temp.format('Average')) 917 | print(head_str) 918 | 919 | for model_str, _ in mstr_bepoch_list: 920 | eval_str = align_temp.format(model_str) 921 | sf1s = [] 922 | mf1s = [] 923 | for _, single_f1, multi_f1 in model_str2etype_sf1_mf1_list[model_str]: 924 | eval_str += align_temp.format(eval_temp.format(single_f1*100, multi_f1*100)) 925 | sf1s.append(single_f1) 926 | mf1s.append(multi_f1) 927 | avg_sf1 = np.mean(sf1s[:-1]) 928 | avg_mf1 = np.mean(mf1s[:-1]) 929 | eval_str += align_temp.format(eval_temp.format(avg_sf1*100, avg_mf1*100)) 930 | print(eval_str) -------------------------------------------------------------------------------- /dee/dee_metric.py: -------------------------------------------------------------------------------- 1 | # Code Reference: (https://github.com/dolphin-zs/Doc2EDAG) 2 | 3 | import numpy as np 4 | 5 | 6 | def agg_event_role_tpfpfn_stats(pred_records, gold_records, role_num): 7 | """ 8 | Aggregate TP,FP,FN statistics for a single event prediction of one instance. 9 | pred_records should be formatted as 10 | [(Record Index) 11 | ((Role Index) 12 | argument 1, ... 13 | ), ... 14 | ], where argument 1 should support the '==' operation and the empty argument is None. 15 | """ 16 | role_tpfpfn_stats = [[0] * 3 for _ in range(role_num)] 17 | 18 | if gold_records is None: 19 | if pred_records is not None: # FP 20 | for pred_record in pred_records: 21 | assert len(pred_record) == role_num 22 | for role_idx, arg_tup in enumerate(pred_record): 23 | if arg_tup is not None: 24 | role_tpfpfn_stats[role_idx][1] += 1 25 | else: # ignore TN 26 | pass 27 | else: 28 | if pred_records is None: # FN 29 | for gold_record in gold_records: 30 | assert len(gold_record) == role_num 31 | for role_idx, arg_tup in enumerate(gold_record): 32 | if arg_tup is not None: 33 | role_tpfpfn_stats[role_idx][2] += 1 34 | else: # True Positive at the event level 35 | # sort predicted event records by the non-empty count 36 | # to remove the impact of the record order on evaluation 37 | pred_records = sorted(pred_records, 38 | key=lambda x: sum(1 for a in x if a is not None), 39 | reverse=True) 40 | gold_records = list(gold_records) 41 | 42 | while len(pred_records) > 0 and len(gold_records) > 0: 43 | pred_record = pred_records[0] 44 | assert len(pred_record) == role_num 45 | 46 | # pick the most similar gold record 47 | _tmp_key = lambda gr: sum([1 for pa, ga in zip(pred_record, gr) if pa == ga]) 48 | best_gr_idx = gold_records.index(max(gold_records, key=_tmp_key)) 49 | gold_record = gold_records[best_gr_idx] 50 | 51 | for role_idx, (pred_arg, gold_arg) in enumerate(zip(pred_record, gold_record)): 52 | if gold_arg is None: 53 | if pred_arg is not None: # FP at the role level 54 | role_tpfpfn_stats[role_idx][1] += 1 55 | else: # ignore TN 56 | pass 57 | else: 58 | if pred_arg is None: # FN 59 | role_tpfpfn_stats[role_idx][2] += 1 60 | else: 61 | if pred_arg == gold_arg: # TP 62 | role_tpfpfn_stats[role_idx][0] += 1 63 | else: 64 | role_tpfpfn_stats[role_idx][1] += 1 65 | role_tpfpfn_stats[role_idx][2] += 1 66 | 67 | del pred_records[0] 68 | del gold_records[best_gr_idx] 
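# Illustrative note (not from the original authors): the loop above greedily
# matches each predicted record to its most similar remaining gold record.
# Toy example with role_num=2: pred_records=[('A', 'B')] and
# gold_records=[('A', 'C')] are paired, yielding a TP for role 0 and both an
# FP and an FN for role 1; leftover unmatched records are counted below as
# pure FP (predictions) or pure FN (gold).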
69 | 70 | # remaining FP 71 | for pred_record in pred_records: 72 | assert len(pred_record) == role_num 73 | for role_idx, arg_tup in enumerate(pred_record): 74 | if arg_tup is not None: 75 | role_tpfpfn_stats[role_idx][1] += 1 76 | # remaining FN 77 | for gold_record in gold_records: 78 | assert len(gold_record) == role_num 79 | for role_idx, arg_tup in enumerate(gold_record): 80 | if arg_tup is not None: 81 | role_tpfpfn_stats[role_idx][2] += 1 82 | 83 | return role_tpfpfn_stats 84 | 85 | def agg_event_level_tpfpfn_stats(pred_records, gold_records, role_num): 86 | """ 87 | Get event-level TP,FP,FN 88 | """ 89 | # add role-level statistics as the event-level ones 90 | role_tpfpfn_stats = agg_event_role_tpfpfn_stats( 91 | pred_records, gold_records, role_num 92 | ) 93 | 94 | return list(np.sum(role_tpfpfn_stats, axis=0)) 95 | 96 | def agg_ins_event_role_tpfpfn_stats(pred_record_mat, gold_record_mat, event_role_num_list): 97 | """ 98 | Aggregate TP,FP,FN statistics for a single instance. 99 | A record_mat should be formatted as 100 | [(Event Index) 101 | [(Record Index) 102 | ((Role Index) 103 | argument 1, ... 104 | ), ... 105 | ], ... 106 | ], where argument 1 should support the '==' operation and the empty argument is None. 107 | """ 108 | assert len(pred_record_mat) == len(gold_record_mat) 109 | # tpfpfn_stat: TP, FP, FN 110 | event_role_tpfpfn_stats = [] 111 | for event_idx, (pred_records, gold_records) in enumerate(zip(pred_record_mat, gold_record_mat)): 112 | role_num = event_role_num_list[event_idx] 113 | role_tpfpfn_stats = agg_event_role_tpfpfn_stats(pred_records, gold_records, role_num) 114 | event_role_tpfpfn_stats.append(role_tpfpfn_stats) 115 | 116 | return event_role_tpfpfn_stats 117 | 118 | def agg_ins_event_level_tpfpfn_stats(pred_record_mat, gold_record_mat, event_role_num_list): 119 | assert len(pred_record_mat) == len(gold_record_mat) 120 | # tpfpfn_stat: TP, FP, FN 121 | event_tpfpfn_stats = [] 122 | for event_idx, (pred_records, gold_records, role_num) in enumerate(zip( 123 | pred_record_mat, gold_record_mat, event_role_num_list)): 124 | event_tpfpfn = agg_event_level_tpfpfn_stats(pred_records, gold_records, role_num) 125 | event_tpfpfn_stats.append(event_tpfpfn) 126 | 127 | return event_tpfpfn_stats 128 | 129 | def get_prec_recall_f1(tp, fp, fn): 130 | a = tp + fp 131 | prec = tp / a if a > 0 else 0 132 | b = tp + fn 133 | rec = tp / b if b > 0 else 0 134 | if prec > 0 and rec > 0: 135 | f1 = 2.0 / (1 / prec + 1 / rec) 136 | else: 137 | f1 = 0 138 | return prec, rec, f1 139 | 140 | def measure_event_table_filling(pred_record_mat_list, gold_record_mat_list, event_type_roles_list, avg_type='micro', 141 | dict_return=False): 142 | """ 143 | The record_mat_list is formatted as 144 | [(Document Index) 145 | [(Event Index) 146 | [(Record Index) 147 | ((Role Index) 148 | argument 1, ... 149 | ), ... 150 | ], ... 151 | ], ... 152 | ] 153 | The argument type should support the '==' operation. 154 | Empty arguments and records are set as None. 
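A minimal hypothetical instance: one document containing two event types,
where only the first type (with two roles) is predicted once and its second
role is empty, would look like
pred_record_mat_list = [[[(('A',), None)], None]]
(here ('A',) stands in for an argument's token-id tuple).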
155 | """ 156 | event_role_num_list = [len(roles) for _, roles in event_type_roles_list] 157 | # to store total statistics of TP, FP, FN 158 | total_event_role_stats = [ 159 | [ 160 | [0]*3 for _ in range(role_num) 161 | ] for event_idx, role_num in enumerate(event_role_num_list) 162 | ] 163 | 164 | assert len(pred_record_mat_list) == len(gold_record_mat_list) 165 | for pred_record_mat, gold_record_mat in zip(pred_record_mat_list, gold_record_mat_list): 166 | event_role_tpfpfn_stats = agg_ins_event_role_tpfpfn_stats( 167 | pred_record_mat, gold_record_mat, event_role_num_list 168 | ) 169 | for event_idx, role_num in enumerate(event_role_num_list): 170 | for role_idx in range(role_num): 171 | for sid in range(3): 172 | total_event_role_stats[event_idx][role_idx][sid] += \ 173 | event_role_tpfpfn_stats[event_idx][role_idx][sid] 174 | 175 | per_role_metric = [] 176 | per_event_metric = [] 177 | 178 | num_events = len(event_role_num_list) 179 | g_tpfpfn_stat = [0] * 3 180 | g_prf1_stat = [0] * 3 181 | event_role_eval_dicts = [] 182 | for event_idx, role_num in enumerate(event_role_num_list): 183 | event_tpfpfn = [0] * 3 # tp, fp, fn 184 | event_prf1_stat = [0] * 3 185 | per_role_metric.append([]) 186 | role_eval_dicts = [] 187 | for role_idx in range(role_num): 188 | role_tpfpfn_stat = total_event_role_stats[event_idx][role_idx][:3] 189 | role_prf1_stat = get_prec_recall_f1(*role_tpfpfn_stat) 190 | per_role_metric[event_idx].append(role_prf1_stat) 191 | for mid in range(3): 192 | event_tpfpfn[mid] += role_tpfpfn_stat[mid] 193 | event_prf1_stat[mid] += role_prf1_stat[mid] 194 | 195 | role_eval_dict = { 196 | 'RoleType': event_type_roles_list[event_idx][1][role_idx], 197 | 'Precision': role_prf1_stat[0], 198 | 'Recall': role_prf1_stat[1], 199 | 'F1': role_prf1_stat[2], 200 | 'TP': role_tpfpfn_stat[0], 201 | 'FP': role_tpfpfn_stat[1], 202 | 'FN': role_tpfpfn_stat[2] 203 | } 204 | role_eval_dicts.append(role_eval_dict) 205 | 206 | for mid in range(3): 207 | event_prf1_stat[mid] /= role_num 208 | g_tpfpfn_stat[mid] += event_tpfpfn[mid] 209 | g_prf1_stat[mid] += event_prf1_stat[mid] 210 | 211 | micro_event_prf1 = get_prec_recall_f1(*event_tpfpfn) 212 | macro_event_prf1 = tuple(event_prf1_stat) 213 | if avg_type.lower() == 'micro': 214 | event_prf1_stat = micro_event_prf1 215 | elif avg_type.lower() == 'macro': 216 | event_prf1_stat = macro_event_prf1 217 | else: 218 | raise Exception('Unsupported average type {}'.format(avg_type)) 219 | 220 | per_event_metric.append(event_prf1_stat) 221 | 222 | event_eval_dict = { 223 | 'EventType': event_type_roles_list[event_idx][0], 224 | 'MacroPrecision': macro_event_prf1[0], 225 | 'MacroRecall': macro_event_prf1[1], 226 | 'MacroF1': macro_event_prf1[2], 227 | 'MicroPrecision': micro_event_prf1[0], 228 | 'MicroRecall': micro_event_prf1[1], 229 | 'MicroF1': micro_event_prf1[2], 230 | 'TP': event_tpfpfn[0], 231 | 'FP': event_tpfpfn[1], 232 | 'FN': event_tpfpfn[2], 233 | } 234 | event_role_eval_dicts.append((event_eval_dict, role_eval_dicts)) 235 | 236 | micro_g_prf1 = get_prec_recall_f1(*g_tpfpfn_stat) 237 | macro_g_prf1 = tuple(s / num_events for s in g_prf1_stat) 238 | if avg_type.lower() == 'micro': 239 | g_metric = micro_g_prf1 240 | else: 241 | g_metric = macro_g_prf1 242 | 243 | g_eval_dict = { 244 | 'MacroPrecision': macro_g_prf1[0], 245 | 'MacroRecall': macro_g_prf1[1], 246 | 'MacroF1': macro_g_prf1[2], 247 | 'MicroPrecision': micro_g_prf1[0], 248 | 'MicroRecall': micro_g_prf1[1], 249 | 'MicroF1': micro_g_prf1[2], 250 | 'TP': g_tpfpfn_stat[0], 251 | 'FP': 
g_tpfpfn_stat[1], 252 | 'FN': g_tpfpfn_stat[2], 253 | } 254 | event_role_eval_dicts.append(g_eval_dict) 255 | 256 | if not dict_return: 257 | return g_metric, per_event_metric, per_role_metric 258 | else: 259 | return event_role_eval_dicts 260 | -------------------------------------------------------------------------------- /dee/dee_task.py: -------------------------------------------------------------------------------- 1 | # Code Reference: (https://github.com/dolphin-zs/Doc2EDAG) 2 | 3 | import logging 4 | import os 5 | import torch.optim as optim 6 | import torch.distributed as dist 7 | from itertools import product 8 | 9 | from .dee_helper import logger, DEEExample, DEEExampleLoader, DEEFeatureConverter, \ 10 | convert_dee_features_to_dataset, prepare_doc_batch_dict, measure_dee_prediction, \ 11 | decode_dump_template, eval_dump_template 12 | from .utils import BERTChineseCharacterTokenizer, default_dump_json, default_load_pkl 13 | from .ner_model import BertForBasicNER 14 | from .base_task import TaskSetting, BasePytorchTask 15 | from .event_type import event_type_fields_list 16 | from .dee_model import GITModel 17 | 18 | 19 | class DEETaskSetting(TaskSetting): 20 | base_key_attrs = TaskSetting.base_key_attrs 21 | base_attr_default_pairs = [ 22 | ('train_file_name', 'train.json'), 23 | ('dev_file_name', 'dev.json'), 24 | ('test_file_name', 'test.json'), 25 | ('summary_dir_name', './tmp/Summary'), 26 | ('max_sent_len', 128), 27 | ('max_sent_num', 64), 28 | ('train_batch_size', 64), 29 | ('gradient_accumulation_steps', 8), 30 | ('eval_batch_size', 2), 31 | ('learning_rate', 1e-4), 32 | ('num_train_epochs', 100), 33 | ('no_cuda', False), 34 | ('local_rank', -1), 35 | ('seed', 99), 36 | ('optimize_on_cpu', False), 37 | ('fp16', False), 38 | ('bert_model', 'bert-base-chinese'), # which pretrained bert model to use 39 | ('only_master_logging', True), # if True, only the master process prints logs 40 | ('resume_latest_cpt', True), # whether to resume latest checkpoints when training for fault tolerance 41 | ('cpt_file_name', 'GIT'), # decide the identity of checkpoints, evaluation results, etc. 
42 | ('rearrange_sent', False), # whether to rearrange sentences 43 | ('use_crf_layer', True), # whether to use CRF Layer 44 | ('min_teacher_prob', 0.1), # the minimum prob to use gold spans 45 | ('schedule_epoch_start', 10), # from which epoch the scheduled sampling starts 46 | ('schedule_epoch_length', 10), # the number of epochs to linearly transition to the min_teacher_prob 47 | ('loss_lambda', 0.05), # the proportion of ner loss 48 | ('loss_gamma', 1.0), # the scaling proportion of missed span sentence ner loss 49 | ('seq_reduce_type', 'MaxPooling'), # use 'MaxPooling', 'MeanPooling' or 'AWA' to reduce a tensor sequence 50 | # network parameters (follow Bert Base) 51 | ('hidden_size', 768), 52 | ('dropout', 0.1), 53 | ('ff_size', 1024), # feed-forward mid layer size 54 | ('num_tf_layers', 4), # transformer layer number 55 | # ablation study parameters 56 | ('use_path_mem', True), # whether to use the memory module when expanding paths 57 | ('use_scheduled_sampling', True), # whether to use the scheduled sampling 58 | ('neg_field_loss_scaling', 3.0), # prefer FNs over FPs 59 | ('gcn_layer', 3), # the number of GCN layers 60 | ('ner_num_tf_layers', 8) 61 | ] 62 | 63 | def __init__(self, **kwargs): 64 | super(DEETaskSetting, self).__init__( 65 | self.base_key_attrs, self.base_attr_default_pairs, **kwargs 66 | ) 67 | 68 | 69 | class DEETask(BasePytorchTask): 70 | """Doc-level Event Extraction Task""" 71 | 72 | def __init__(self, dee_setting, load_train=True, load_dev=True, load_test=True, 73 | parallel_decorate=True): 74 | super(DEETask, self).__init__(dee_setting, only_master_logging=dee_setting.only_master_logging) 75 | self.logger = logging.getLogger(self.__class__.__name__) 76 | self.logging('Initializing {}'.format(self.__class__.__name__)) 77 | 78 | self.tokenizer = BERTChineseCharacterTokenizer.from_pretrained(self.setting.bert_model) 79 | self.setting.vocab_size = len(self.tokenizer.vocab) 80 | 81 | # get entity and event label names 82 | self.entity_label_list = DEEExample.get_entity_label_list() 83 | self.event_type_fields_pairs = DEEExample.get_event_type_fields_pairs() # event -> list of entity types 84 | # build example loader 85 | self.example_loader_func = DEEExampleLoader(self.setting.rearrange_sent, self.setting.max_sent_len) 86 | 87 | # build feature converter 88 | self.feature_converter_func = DEEFeatureConverter( 89 | self.entity_label_list, self.event_type_fields_pairs, 90 | self.setting.max_sent_len, self.setting.max_sent_num, self.tokenizer, 91 | include_cls=False, include_sep=False, 92 | ) 93 | 94 | # LOAD DATA 95 | # self.example_loader_func: raw data -> example 96 | # self.feature_converter_func: example -> feature 97 | self._load_data( 98 | self.example_loader_func, self.feature_converter_func, convert_dee_features_to_dataset, 99 | load_train=load_train, load_dev=load_dev, load_test=load_test, 100 | ) 101 | # customized mini-batch producer 102 | self.custom_collate_fn = prepare_doc_batch_dict 103 | 104 | self.setting.num_entity_labels = len(self.entity_label_list) 105 | 106 | ner_model = None 107 | 108 | self.model = GITModel( 109 | self.setting, self.event_type_fields_pairs, ner_model=ner_model, 110 | ) 111 | 112 | self._decorate_model(parallel_decorate=parallel_decorate) 113 | 114 | # prepare optimizer 115 | self.optimizer = optim.Adam(self.model.parameters(), lr=self.setting.learning_rate) 116 | 117 | # # resume option 118 | # if resume_model or resume_optimizer: 119 | # self.resume_checkpoint(resume_model=resume_model, resume_optimizer=resume_optimizer) 120 | 121 | 
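# Scheduled sampling in brief (summary comment, not in the original source):
# for the first schedule_epoch_start epochs the decoder always receives gold
# spans (teacher prob = 1); over the next schedule_epoch_length epochs this
# probability decays linearly towards min_teacher_prob, as implemented in
# reset_teacher_prob / get_teacher_prob below.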
self.min_teacher_prob = None 122 | self.teacher_norm = None 123 | self.teacher_cnt = None 124 | self.teacher_base = None 125 | self.reset_teacher_prob() 126 | 127 | self.logging('Successfully initialized {}'.format(self.__class__.__name__)) 128 | 129 | def reset_teacher_prob(self): 130 | self.min_teacher_prob = self.setting.min_teacher_prob 131 | if self.train_dataset is None: 132 | # avoid crashing when not loading training data 133 | num_step_per_epoch = 500 134 | else: 135 | num_step_per_epoch = int(len(self.train_dataset) / self.setting.train_batch_size) 136 | self.teacher_norm = num_step_per_epoch * self.setting.schedule_epoch_length 137 | self.teacher_base = num_step_per_epoch * self.setting.schedule_epoch_start 138 | self.teacher_cnt = 0 139 | 140 | def get_teacher_prob(self, batch_inc_flag=True): 141 | if self.teacher_cnt < self.teacher_base: 142 | prob = 1 143 | else: 144 | prob = max( 145 | self.min_teacher_prob, (self.teacher_norm - self.teacher_cnt + self.teacher_base) / self.teacher_norm 146 | ) 147 | 148 | if batch_inc_flag: 149 | self.teacher_cnt += 1 150 | 151 | return prob 152 | 153 | def get_event_idx2entity_idx2field_idx(self): 154 | entity_idx2entity_type = {} 155 | for entity_idx, entity_label in enumerate(self.entity_label_list): 156 | if entity_label == 'O': 157 | entity_type = entity_label 158 | else: 159 | entity_type = entity_label[2:] 160 | 161 | entity_idx2entity_type[entity_idx] = entity_type 162 | 163 | event_idx2entity_idx2field_idx = {} 164 | for event_idx, (event_name, field_types) in enumerate(self.event_type_fields_pairs): 165 | field_type2field_idx = {} 166 | for field_idx, field_type in enumerate(field_types): 167 | field_type2field_idx[field_type] = field_idx 168 | 169 | entity_idx2field_idx = {} 170 | for entity_idx, entity_type in entity_idx2entity_type.items(): 171 | if entity_type in field_type2field_idx: 172 | entity_idx2field_idx[entity_idx] = field_type2field_idx[entity_type] 173 | else: 174 | entity_idx2field_idx[entity_idx] = None 175 | 176 | event_idx2entity_idx2field_idx[event_idx] = entity_idx2field_idx 177 | 178 | return event_idx2entity_idx2field_idx 179 | 180 | def get_loss_on_batch(self, doc_batch_dict, features=None): 181 | if features is None: 182 | features = self.train_features 183 | 184 | # teacher_prob = 1 185 | # if use_gold_span, gold spans will be used every time 186 | # else, teacher_prob controls the proportion of batches that use gold spans 187 | if self.setting.use_scheduled_sampling: 188 | use_gold_span = False 189 | teacher_prob = self.get_teacher_prob() 190 | else: 191 | use_gold_span = True 192 | teacher_prob = 1 193 | 194 | try: 195 | loss = self.model( 196 | doc_batch_dict, features, use_gold_span=use_gold_span, train_flag=True, teacher_prob=teacher_prob 197 | ) 198 | except Exception as e: 199 | print('-'*30) 200 | print('Exception occurred when processing ' + 201 | ','.join([features[ex_idx].guid for ex_idx in doc_batch_dict['ex_idx']])) 202 | raise Exception('Cannot get the loss') 203 | 204 | return loss 205 | 206 | def get_event_decode_result_on_batch(self, doc_batch_dict, features=None, use_gold_span=False): 207 | if features is None: 208 | raise Exception('Features must be provided') 209 | 210 | event_idx2entity_idx2field_idx = None 211 | batch_eval_results = self.model( 212 | doc_batch_dict, features, use_gold_span=use_gold_span, train_flag=False, 213 | event_idx2entity_idx2field_idx=event_idx2entity_idx2field_idx, 214 | ) 215 | 216 | return batch_eval_results 217 | 218 | def train(self, save_cpt_flag=True, 
resume_base_epoch=None): 219 | self.logging('=' * 20 + 'Start Training' + '=' * 20) 220 | self.reset_teacher_prob() 221 | 222 | # resume_base_epoch arguments have higher priority over settings 223 | if resume_base_epoch is None: 224 | # whether to resume latest cpt when restarting, very useful for preemptive scheduling clusters 225 | if self.setting.resume_latest_cpt: 226 | resume_base_epoch = self.get_latest_cpt_epoch() 227 | else: 228 | resume_base_epoch = 0 229 | 230 | # resume cpt if possible 231 | if resume_base_epoch > 0: 232 | self.logging('Training starts from epoch {}'.format(resume_base_epoch)) 233 | for _ in range(resume_base_epoch): 234 | self.get_teacher_prob() 235 | self.resume_cpt_at(resume_base_epoch, resume_model=True, resume_optimizer=True) 236 | else: 237 | self.logging('Training starts from scratch') 238 | 239 | self.base_train( 240 | DEETask.get_loss_on_batch, 241 | kwargs_dict1={}, 242 | epoch_eval_func=DEETask.resume_save_eval_at, 243 | kwargs_dict2={ 244 | 'save_cpt_flag': save_cpt_flag, 245 | 'resume_cpt_flag': False, 246 | }, 247 | base_epoch_idx=resume_base_epoch, 248 | ) 249 | 250 | def resume_save_eval_at(self, epoch, resume_cpt_flag=False, save_cpt_flag=True): 251 | if self.is_master_node(): 252 | print('\nPROGRESS: {:.2f}%\n'.format(epoch / self.setting.num_train_epochs * 100)) 253 | self.logging('Current teacher prob {}'.format(self.get_teacher_prob(batch_inc_flag=False))) 254 | 255 | if resume_cpt_flag: 256 | self.resume_cpt_at(epoch) 257 | 258 | if self.is_master_node() and save_cpt_flag: 259 | self.save_cpt_at(epoch) 260 | 261 | eval_tasks = product(['dev', 'test'], [False, True]) 262 | 263 | for task_idx, (data_type, gold_span_flag) in enumerate(eval_tasks): 264 | if self.in_distributed_mode() and task_idx % dist.get_world_size() != dist.get_rank(): 265 | continue 266 | 267 | if data_type == 'test': 268 | features = self.test_features 269 | dataset = self.test_dataset 270 | elif data_type == 'dev': 271 | features = self.dev_features 272 | dataset = self.dev_dataset 273 | else: 274 | raise Exception('Unsupported data type {}'.format(data_type)) 275 | 276 | if gold_span_flag: 277 | span_str = 'gold_span' 278 | else: 279 | span_str = 'pred_span' 280 | 281 | model_str = self.setting.cpt_file_name.replace('.', '~') 282 | 283 | decode_dump_name = decode_dump_template.format(data_type, span_str, model_str, epoch) 284 | eval_dump_name = eval_dump_template.format(data_type, span_str, model_str, epoch) 285 | self.eval(features, dataset, use_gold_span=gold_span_flag, 286 | dump_decode_pkl_name=decode_dump_name, dump_eval_json_name=eval_dump_name, epoch=epoch) 287 | 288 | def save_cpt_at(self, epoch): 289 | self.save_checkpoint(cpt_file_name='{}.cpt.{}'.format(self.setting.cpt_file_name, epoch), epoch=epoch) 290 | 291 | def resume_cpt_at(self, epoch, resume_model=True, resume_optimizer=False): 292 | self.resume_checkpoint(cpt_file_name='{}.cpt.{}'.format(self.setting.cpt_file_name, epoch), 293 | resume_model=resume_model, resume_optimizer=resume_optimizer) 294 | 295 | def get_latest_cpt_epoch(self): 296 | prev_epochs = [] 297 | for fn in os.listdir(self.setting.model_dir): 298 | if fn.startswith('{}.cpt'.format(self.setting.cpt_file_name)): 299 | try: 300 | epoch = int(fn.split('.')[-1]) 301 | prev_epochs.append(epoch) 302 | except Exception as e: 303 | continue 304 | prev_epochs.sort() 305 | 306 | if len(prev_epochs) > 0: 307 | latest_epoch = prev_epochs[-1] 308 | self.logging('Pick latest epoch {} from {}'.format(latest_epoch, str(prev_epochs))) 309 | else: 310 | 
latest_epoch = 0 311 | self.logging('No previous epoch checkpoints, just start from scratch') 312 | 313 | return latest_epoch 314 | 315 | def eval(self, features, dataset, use_gold_span=False, 316 | dump_decode_pkl_name=None, dump_eval_json_name=None, epoch=None): 317 | self.logging('=' * 20 + 'Start Evaluation' + '=' * 20) 318 | 319 | if dump_decode_pkl_name is not None: 320 | dump_decode_pkl_path = os.path.join(self.setting.output_dir, dump_decode_pkl_name) 321 | self.logging('Dumping decode results into {}'.format(dump_decode_pkl_name)) 322 | else: 323 | dump_decode_pkl_path = None 324 | 325 | total_event_decode_results = self.base_eval( 326 | dataset, DEETask.get_event_decode_result_on_batch, 327 | reduce_info_type='none', dump_pkl_path=dump_decode_pkl_path, 328 | features=features, use_gold_span=use_gold_span, 329 | ) 330 | 331 | self.logging('Measure DEE Prediction') 332 | 333 | if dump_eval_json_name is not None: 334 | dump_eval_json_path = os.path.join(self.setting.output_dir, dump_eval_json_name) 335 | self.logging('Dumping eval results into {}'.format(dump_eval_json_name)) 336 | else: 337 | dump_eval_json_path = None 338 | 339 | total_eval_res = measure_dee_prediction( 340 | self.event_type_fields_pairs, features, total_event_decode_results, 341 | dump_json_path=dump_eval_json_path, writer=self.summary_writer, epoch=epoch 342 | ) 343 | 344 | return total_event_decode_results, total_eval_res 345 | 346 | def reevaluate_dee_prediction(self, target_file_pre='dee_eval', target_file_suffix='.pkl', 347 | dump_flag=False): 348 | """Enumerate the evaluation directory to collect all dumped evaluation results""" 349 | eval_dir_path = self.setting.output_dir 350 | logger.info('Re-evaluate dee predictions from {}'.format(eval_dir_path)) 351 | data_span_type2model_str2epoch_res_list = {} 352 | for fn in os.listdir(eval_dir_path): 353 | fn_splits = fn.split('.') 354 | if fn.startswith(target_file_pre) and fn.endswith(target_file_suffix) and len(fn_splits) == 6: 355 | _, data_type, span_type, model_str, epoch, _ = fn_splits 356 | 357 | data_span_type = (data_type, span_type) 358 | if data_span_type not in data_span_type2model_str2epoch_res_list: 359 | data_span_type2model_str2epoch_res_list[data_span_type] = {} 360 | model_str2epoch_res_list = data_span_type2model_str2epoch_res_list[data_span_type] 361 | 362 | if model_str not in model_str2epoch_res_list: 363 | model_str2epoch_res_list[model_str] = [] 364 | epoch_res_list = model_str2epoch_res_list[model_str] 365 | 366 | if data_type == 'dev': 367 | features = self.dev_features 368 | elif data_type == 'test': 369 | features = self.test_features 370 | else: 371 | raise Exception('Unsupported data type {}'.format(data_type)) 372 | 373 | epoch = int(epoch) 374 | fp = os.path.join(eval_dir_path, fn) 375 | self.logging('Re-evaluating {}'.format(fp)) 376 | event_decode_results = default_load_pkl(fp) 377 | total_eval_res = measure_dee_prediction( 378 | event_type_fields_list, features, event_decode_results 379 | ) 380 | 381 | if dump_flag: 382 | fp = fp.rstrip('.pkl') + '.json' 383 | self.logging('Dumping {}'.format(fp)) 384 | default_dump_json(total_eval_res, fp) 385 | 386 | epoch_res_list.append((epoch, total_eval_res)) 387 | 388 | for data_span_type, model_str2epoch_res_list in data_span_type2model_str2epoch_res_list.items(): 389 | for model_str, epoch_res_list in model_str2epoch_res_list.items(): 390 | epoch_res_list.sort(key=lambda x: x[0]) 391 | 392 | return data_span_type2model_str2epoch_res_list 393 | 394 | 395 | 
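# --- Illustrative sketch, not part of the original file ---
# The teacher-forcing schedule that DEETask.get_teacher_prob implements above,
# restated as a self-contained pure function using the DEETaskSetting defaults.
def _sketch_teacher_prob(step, steps_per_epoch, schedule_epoch_start=10,
                         schedule_epoch_length=10, min_teacher_prob=0.1):
    base = steps_per_epoch * schedule_epoch_start
    norm = steps_per_epoch * schedule_epoch_length
    if step < base:
        return 1.0  # always feed gold spans before the schedule starts
    # afterwards, decay linearly from 1.0 down to min_teacher_prob
    return max(min_teacher_prob, (norm - step + base) / norm)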
-------------------------------------------------------------------------------- /dee/event_type.py: -------------------------------------------------------------------------------- 1 | # Code Reference: (https://github.com/dolphin-zs/Doc2EDAG) 2 | 3 | 4 | class BaseEvent(object): 5 | def __init__(self, fields, event_name='Event', key_fields=(), recguid=None): 6 | self.recguid = recguid 7 | self.name = event_name 8 | self.fields = list(fields) 9 | self.field2content = {f: None for f in fields} 10 | self.nonempty_count = 0 11 | self.nonempty_ratio = self.nonempty_count / len(self.fields) 12 | 13 | self.key_fields = set(key_fields) 14 | for key_field in self.key_fields: 15 | assert key_field in self.field2content 16 | 17 | def __repr__(self): 18 | event_str = "\n{}[\n".format(self.name) 19 | event_str += " {}={}\n".format("recguid", self.recguid) 20 | event_str += " {}={}\n".format("nonempty_count", self.nonempty_count) 21 | event_str += " {}={:.3f}\n".format("nonempty_ratio", self.nonempty_ratio) 22 | event_str += "] (\n" 23 | for field in self.fields: 24 | if field in self.key_fields: 25 | key_str = " (key)" 26 | else: 27 | key_str = "" 28 | event_str += " " + field + "=" + str(self.field2content[field]) + ", {}\n".format(key_str) 29 | event_str += ")\n" 30 | return event_str 31 | 32 | def update_by_dict(self, field2text, recguid=None): 33 | self.nonempty_count = 0 34 | self.recguid = recguid 35 | 36 | for field in self.fields: 37 | if field in field2text and field2text[field] is not None: 38 | self.nonempty_count += 1 39 | self.field2content[field] = field2text[field] 40 | else: 41 | self.field2content[field] = None 42 | 43 | self.nonempty_ratio = self.nonempty_count / len(self.fields) 44 | 45 | def field_to_dict(self): 46 | return dict(self.field2content) 47 | 48 | def set_key_fields(self, key_fields): 49 | self.key_fields = set(key_fields) 50 | 51 | def is_key_complete(self): 52 | for key_field in self.key_fields: 53 | if self.field2content[key_field] is None: 54 | return False 55 | 56 | return True 57 | 58 | def is_good_candidate(self): 59 | raise NotImplementedError() 60 | 61 | def get_argument_tuple(self): 62 | args_tuple = tuple(self.field2content[field] for field in self.fields) 63 | return args_tuple 64 | 65 | class EquityFreezeEvent(BaseEvent): 66 | NAME = 'EquityFreeze' 67 | FIELDS = [ 68 | 'EquityHolder', 69 | 'FrozeShares', 70 | 'LegalInstitution', 71 | 'TotalHoldingShares', 72 | 'TotalHoldingRatio', 73 | 'StartDate', 74 | 'EndDate', 75 | 'UnfrozeDate', 76 | ] 77 | 78 | def __init__(self, recguid=None): 79 | super().__init__( 80 | EquityFreezeEvent.FIELDS, event_name=EquityFreezeEvent.NAME, recguid=recguid 81 | ) 82 | self.set_key_fields([ 83 | 'EquityHolder', 84 | 'FrozeShares', 85 | 'LegalInstitution', 86 | ]) 87 | 88 | def is_good_candidate(self, min_match_count=5): 89 | key_flag = self.is_key_complete() 90 | if key_flag: 91 | if self.nonempty_count >= min_match_count: 92 | return True 93 | return False 94 | 95 | class EquityRepurchaseEvent(BaseEvent): 96 | NAME = 'EquityRepurchase' 97 | FIELDS = [ 98 | 'CompanyName', 99 | 'HighestTradingPrice', 100 | 'LowestTradingPrice', 101 | 'RepurchasedShares', 102 | 'ClosingDate', 103 | 'RepurchaseAmount', 104 | ] 105 | 106 | def __init__(self, recguid=None): 107 | super().__init__( 108 | EquityRepurchaseEvent.FIELDS, event_name=EquityRepurchaseEvent.NAME, recguid=recguid 109 | ) 110 | self.set_key_fields([ 111 | 'CompanyName', 112 | ]) 113 | 114 | def is_good_candidate(self, min_match_count=4): 115 | key_flag = self.is_key_complete() 
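# Shared pattern for all event classes in this file (explanatory comment,
# not in the original source): a record is a good candidate only when every
# key field is filled and at least min_match_count fields are non-empty.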
116 | if key_flag: 117 | if self.nonempty_count >= min_match_count: 118 | return True 119 | return False 120 | 121 | class EquityUnderweightEvent(BaseEvent): 122 | NAME = 'EquityUnderweight' 123 | FIELDS = [ 124 | 'EquityHolder', 125 | 'TradedShares', 126 | 'StartDate', 127 | 'EndDate', 128 | 'LaterHoldingShares', 129 | 'AveragePrice', 130 | ] 131 | 132 | def __init__(self, recguid=None): 133 | super().__init__( 134 | EquityUnderweightEvent.FIELDS, event_name=EquityUnderweightEvent.NAME, recguid=recguid 135 | ) 136 | self.set_key_fields([ 137 | 'EquityHolder', 138 | 'TradedShares', 139 | ]) 140 | 141 | def is_good_candidate(self, min_match_count=4): 142 | key_flag = self.is_key_complete() 143 | if key_flag: 144 | if self.nonempty_count >= min_match_count: 145 | return True 146 | return False 147 | 148 | class EquityOverweightEvent(BaseEvent): 149 | NAME = 'EquityOverweight' 150 | FIELDS = [ 151 | 'EquityHolder', 152 | 'TradedShares', 153 | 'StartDate', 154 | 'EndDate', 155 | 'LaterHoldingShares', 156 | 'AveragePrice', 157 | ] 158 | 159 | def __init__(self, recguid=None): 160 | super().__init__( 161 | EquityOverweightEvent.FIELDS, event_name=EquityOverweightEvent.NAME, recguid=recguid 162 | ) 163 | self.set_key_fields([ 164 | 'EquityHolder', 165 | 'TradedShares', 166 | ]) 167 | 168 | def is_good_candidate(self, min_match_count=4): 169 | key_flag = self.is_key_complete() 170 | if key_flag: 171 | if self.nonempty_count >= min_match_count: 172 | return True 173 | return False 174 | 175 | class EquityPledgeEvent(BaseEvent): 176 | NAME = 'EquityPledge' 177 | FIELDS = [ 178 | 'Pledger', 179 | 'PledgedShares', 180 | 'Pledgee', 181 | 'TotalHoldingShares', 182 | 'TotalHoldingRatio', 183 | 'TotalPledgedShares', 184 | 'StartDate', 185 | 'EndDate', 186 | 'ReleasedDate', 187 | ] 188 | 189 | def __init__(self, recguid=None): 190 | # super(EquityPledgeEvent, self).__init__( 191 | super().__init__( 192 | EquityPledgeEvent.FIELDS, event_name=EquityPledgeEvent.NAME, recguid=recguid 193 | ) 194 | self.set_key_fields([ 195 | 'Pledger', 196 | 'PledgedShares', 197 | 'Pledgee', 198 | ]) 199 | 200 | def is_good_candidate(self, min_match_count=5): 201 | key_flag = self.is_key_complete() 202 | if key_flag: 203 | if self.nonempty_count >= min_match_count: 204 | return True 205 | return False 206 | 207 | common_fields = ['StockCode', 'StockAbbr', 'CompanyName'] 208 | 209 | event_type2event_class = { 210 | EquityFreezeEvent.NAME: EquityFreezeEvent, 211 | EquityRepurchaseEvent.NAME: EquityRepurchaseEvent, 212 | EquityUnderweightEvent.NAME: EquityUnderweightEvent, 213 | EquityOverweightEvent.NAME: EquityOverweightEvent, 214 | EquityPledgeEvent.NAME: EquityPledgeEvent, 215 | } 216 | 217 | event_type_fields_list = [ 218 | (EquityFreezeEvent.NAME, EquityFreezeEvent.FIELDS), 219 | (EquityRepurchaseEvent.NAME, EquityRepurchaseEvent.FIELDS), 220 | (EquityUnderweightEvent.NAME, EquityUnderweightEvent.FIELDS), 221 | (EquityOverweightEvent.NAME, EquityOverweightEvent.FIELDS), 222 | (EquityPledgeEvent.NAME, EquityPledgeEvent.FIELDS), 223 | ] 224 | -------------------------------------------------------------------------------- /dee/ner_model.py: -------------------------------------------------------------------------------- 1 | # Code Reference: (https://github.com/dolphin-zs/Doc2EDAG) 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | import math 7 | 8 | from pytorch_pretrained_bert.modeling import PreTrainedBertModel, BertModel 9 | 10 | from . 
import transformer 11 | 12 | 13 | class BertForBasicNER(PreTrainedBertModel): 14 | """BERT model for basic NER functionality. 15 | This module is composed of the BERT model with a linear layer on top of 16 | the output sequences. 17 | 18 | Params: 19 | `config`: a BertConfig class instance with the configuration to build a new model. 20 | `num_entity_labels`: the number of entity classes for the classifier. 21 | 22 | Inputs: 23 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 24 | with the word token indices in the vocabulary. 25 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 26 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 27 | a `sentence B` token (see BERT paper for more details). 28 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 29 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 30 | input sequence length in the current batch. It's the mask that we typically use for attention when 31 | a batch has varying length sentences. 32 | `label_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 33 | with label indices selected in [0, ..., num_labels-1]. 34 | 35 | Outputs: 36 | if `labels` is not `None`: 37 | Outputs the CrossEntropy classification loss of the output with the labels. 38 | if `labels` is `None`: 39 | Outputs the classification logits sequence. 40 | """ 41 | 42 | def __init__(self, config, num_entity_labels): 43 | super(BertForBasicNER, self).__init__(config) 44 | self.bert = BertModel(config) 45 | 46 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 47 | self.classifier = nn.Linear(config.hidden_size, num_entity_labels) 48 | self.apply(self.init_bert_weights) 49 | 50 | self.num_entity_labels = num_entity_labels 51 | 52 | def old_forward(self, input_ids, input_masks, 53 | token_type_ids=None, label_ids=None, 54 | eval_flag=False, eval_for_metric=True): 55 | """Assume input size [batch_size, seq_len]""" 56 | if input_masks.dtype != torch.uint8: 57 | input_masks = input_masks == 1 58 | 59 | enc_seq_out, _ = self.bert(input_ids, 60 | token_type_ids=token_type_ids, 61 | attention_mask=input_masks, 62 | output_all_encoded_layers=False) 63 | # [batch_size, seq_len, hidden_size] 64 | enc_seq_out = self.dropout(enc_seq_out) 65 | # [batch_size, seq_len, num_entity_labels] 66 | seq_logits = self.classifier(enc_seq_out) 67 | 68 | if eval_flag: # if for evaluation purpose 69 | if label_ids is None: 70 | raise Exception('Cannot do evaluation without label info') 71 | else: 72 | if eval_for_metric: 73 | batch_metrics = produce_ner_batch_metrics(seq_logits, label_ids, input_masks) 74 | return batch_metrics 75 | else: 76 | seq_logp = F.log_softmax(seq_logits, dim=-1) 77 | seq_pred = seq_logp.argmax(dim=-1, keepdim=True) # [batch_size, seq_len, 1] 78 | seq_gold = label_ids.unsqueeze(-1) # [batch_size, seq_len, 1] 79 | seq_mask = input_masks.unsqueeze(-1).long() # [batch_size, seq_len, 1] 80 | seq_pred_gold_mask = torch.cat([seq_pred, seq_gold, seq_mask], dim=-1) # [batch_size, seq_len, 3] 81 | return seq_pred_gold_mask 82 | elif label_ids is not None: # if has label_ids, calculate the loss 83 | # [num_valid_token, num_entity_labels] 84 | batch_logits = seq_logits[input_masks, :] 85 | # [num_valid_token], lid \in {0,..., num_entity_labels-1} 86 | batch_labels = label_ids[input_masks] 87 | loss = F.cross_entropy(batch_logits, batch_labels) 88 | 
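# Note (explanatory comment, not in the original source): the boolean-mask
# indexing above keeps only valid (non-padding) positions, so padding tokens
# never contribute to the cross-entropy loss.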
return loss, enc_seq_out 89 | else: # just return seq_pred_logps 90 | return F.log_softmax(seq_logits, dim=-1), enc_seq_out 91 | 92 | def forward(self, input_ids, input_masks, 93 | label_ids=None, train_flag=True, decode_flag=True): 94 | """Assume input size [batch_size, seq_len]""" 95 | if input_masks.dtype != torch.uint8: 96 | input_masks = input_masks == 1 97 | 98 | batch_seq_enc, _ = self.bert(input_ids, 99 | attention_mask=input_masks, 100 | output_all_encoded_layers=False) 101 | # [batch_size, seq_len, hidden_size] 102 | batch_seq_enc = self.dropout(batch_seq_enc) 103 | # [batch_size, seq_len, num_entity_labels] 104 | batch_seq_logits = self.classifier(batch_seq_enc) 105 | 106 | batch_seq_logp = F.log_softmax(batch_seq_logits, dim=-1) 107 | 108 | if train_flag: 109 | batch_logp = batch_seq_logp.view(-1, batch_seq_logp.size(-1)) 110 | batch_label = label_ids.view(-1) 111 | # ner_loss = F.nll_loss(batch_logp, batch_label, reduction='sum') 112 | ner_loss = F.nll_loss(batch_logp, batch_label, reduction='none') 113 | ner_loss = ner_loss.view(label_ids.size()).sum(dim=-1) # [batch_size] 114 | else: 115 | ner_loss = None 116 | 117 | if decode_flag: 118 | batch_seq_preds = batch_seq_logp.argmax(dim=-1) 119 | else: 120 | batch_seq_preds = None 121 | 122 | return batch_seq_enc, ner_loss, batch_seq_preds 123 | 124 | class NERModel(nn.Module): 125 | def __init__(self, config): 126 | super(NERModel, self).__init__() 127 | 128 | self.config = config 129 | # Word Embedding, Word Local Position Embedding 130 | self.token_embedding = NERTokenEmbedding( 131 | config.vocab_size, config.hidden_size, 132 | max_sent_len=config.max_sent_len, dropout=config.dropout 133 | ) 134 | # Multi-layer Transformer Layers to Incorporate Contextual Information 135 | self.token_encoder = transformer.make_transformer_encoder( 136 | config.ner_num_tf_layers, config.hidden_size, ff_size=config.ff_size, dropout=config.dropout 137 | ) 138 | if self.config.use_crf_layer: 139 | self.crf_layer = CRFLayer(config.hidden_size, self.config.num_entity_labels) 140 | else: 141 | # Token Label Classification 142 | self.classifier = nn.Linear(config.hidden_size, self.config.num_entity_labels) 143 | 144 | def forward(self, input_ids, input_masks, 145 | label_ids=None, train_flag=True, decode_flag=True): 146 | """Assume input size [batch_size, seq_len]""" 147 | if input_masks.dtype != torch.uint8: 148 | input_masks = input_masks == 1 149 | if train_flag: 150 | assert label_ids is not None 151 | 152 | # get contextual info 153 | input_emb = self.token_embedding(input_ids) 154 | input_masks = input_masks.unsqueeze(-2) # to fit the transformer code 155 | batch_seq_enc = self.token_encoder(input_emb, input_masks) 156 | 157 | if self.config.use_crf_layer: 158 | ner_loss, batch_seq_preds = self.crf_layer( 159 | batch_seq_enc, seq_token_label=label_ids, batch_first=True, 160 | train_flag=train_flag, decode_flag=decode_flag 161 | ) 162 | else: 163 | # [batch_size, seq_len, num_entity_labels] 164 | batch_seq_logits = self.classifier(batch_seq_enc) 165 | batch_seq_logp = F.log_softmax(batch_seq_logits, dim=-1) 166 | 167 | if train_flag: 168 | batch_logp = batch_seq_logp.view(-1, batch_seq_logp.size(-1)) 169 | batch_label = label_ids.view(-1) 170 | # ner_loss = F.nll_loss(batch_logp, batch_label, reduction='sum') 171 | ner_loss = F.nll_loss(batch_logp, batch_label, reduction='none') 172 | ner_loss = ner_loss.view(label_ids.size()).sum(dim=-1) # [batch_size] 173 | else: 174 | ner_loss = None 175 | 176 | if decode_flag: 177 | batch_seq_preds = 
batch_seq_logp.argmax(dim=-1) 178 | else: 179 | batch_seq_preds = None 180 | 181 | return batch_seq_enc, ner_loss, batch_seq_preds 182 | 183 | class NERTokenEmbedding(nn.Module): 184 | """Add token position information""" 185 | def __init__(self, vocab_size, hidden_size, max_sent_len=256, dropout=0.1): 186 | super(NERTokenEmbedding, self).__init__() 187 | 188 | self.token_embedding = nn.Embedding(vocab_size, hidden_size) 189 | self.pos_embedding = nn.Embedding(max_sent_len, hidden_size) 190 | 191 | self.layer_norm = transformer.LayerNorm(hidden_size) 192 | self.dropout = nn.Dropout(dropout) 193 | 194 | def forward(self, batch_token_ids): 195 | batch_size, sent_len = batch_token_ids.size() 196 | device = batch_token_ids.device 197 | 198 | batch_pos_ids = torch.arange( 199 | sent_len, dtype=torch.long, device=device, requires_grad=False 200 | ) 201 | batch_pos_ids = batch_pos_ids.unsqueeze(0).expand_as(batch_token_ids) 202 | 203 | batch_token_emb = self.token_embedding(batch_token_ids) 204 | batch_pos_emb = self.pos_embedding(batch_pos_ids) 205 | 206 | batch_token_emb = batch_token_emb + batch_pos_emb 207 | 208 | batch_token_out = self.layer_norm(batch_token_emb) 209 | batch_token_out = self.dropout(batch_token_out) 210 | 211 | return batch_token_out 212 | 213 | class CRFLayer(nn.Module): 214 | NEG_LOGIT = -100000. 215 | """ 216 | Conditional Random Field Layer 217 | Reference: 218 | https://pytorch.org/tutorials/beginner/nlp/advanced_tutorial.html#sphx-glr-beginner-nlp-advanced-tutorial-py 219 | The original example codes operate on one sequence, while this version operates on one batch 220 | """ 221 | 222 | def __init__(self, hidden_size, num_entity_labels): 223 | super(CRFLayer, self).__init__() 224 | 225 | self.tag_size = num_entity_labels + 2 # add start tag and end tag 226 | self.start_tag = self.tag_size - 2 227 | self.end_tag = self.tag_size - 1 228 | 229 | # Map token-level hidden state into tag scores 230 | self.hidden2tag = nn.Linear(hidden_size, self.tag_size) 231 | # Transition Matrix 232 | # [i, j] denotes transitioning from j to i 233 | self.trans_mat = nn.Parameter(torch.randn(self.tag_size, self.tag_size)) 234 | self.reset_trans_mat() 235 | 236 | def reset_trans_mat(self): 237 | nn.init.kaiming_uniform_(self.trans_mat, a=math.sqrt(5)) # copy from Linear init 238 | # set parameters that will not be updated during training, but is important 239 | self.trans_mat.data[self.start_tag, :] = self.NEG_LOGIT 240 | self.trans_mat.data[:, self.end_tag] = self.NEG_LOGIT 241 | 242 | def get_log_parition(self, seq_emit_score): 243 | """ 244 | Calculate the log of the partition function 245 | :param seq_emit_score: [seq_len, batch_size, tag_size] 246 | :return: Tensor with Size([batch_size]) 247 | """ 248 | seq_len, batch_size, tag_size = seq_emit_score.size() 249 | # dynamic programming table to store previously summarized tag logits 250 | dp_table = seq_emit_score.new_full( 251 | (batch_size, tag_size), self.NEG_LOGIT, requires_grad=False 252 | ) 253 | dp_table[:, self.start_tag] = 0. 
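        # Forward-algorithm recursion: dp_table[b, j] holds the log-sum of the
        # scores of all tag paths ending in tag j after the tokens consumed so
        # far. Initializing every entry except the start tag to NEG_LOGIT makes
        # the start tag the only valid prefix before the first token.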
254 | 255 | batch_trans_mat = self.trans_mat.unsqueeze(0).expand(batch_size, tag_size, tag_size) 256 | 257 | for token_idx in range(seq_len): 258 | prev_logit = dp_table.unsqueeze(1) # [batch_size, 1, tag_size] 259 | batch_emit_score = seq_emit_score[token_idx].unsqueeze(-1) # [batch_size, tag_size, 1] 260 | cur_logit = batch_trans_mat + batch_emit_score + prev_logit # [batch_size, tag_size, tag_size] 261 | dp_table = log_sum_exp(cur_logit) # [batch_size, tag_size] 262 | batch_logit = dp_table + self.trans_mat[self.end_tag, :].unsqueeze(0) 263 | log_partition = log_sum_exp(batch_logit) # [batch_size] 264 | 265 | return log_partition 266 | 267 | def get_gold_score(self, seq_emit_score, seq_token_label): 268 | """ 269 | Calculate the score of the given sequence label 270 | :param seq_emit_score: [seq_len, batch_size, tag_size] 271 | :param seq_token_label: [seq_len, batch_size] 272 | :return: Tensor with Size([batch_size]) 273 | """ 274 | seq_len, batch_size, tag_size = seq_emit_score.size() 275 | 276 | end_token_label = seq_token_label.new_full( 277 | (1, batch_size), self.end_tag, requires_grad=False 278 | ) 279 | seq_cur_label = torch.cat( 280 | [seq_token_label, end_token_label], dim=0 281 | ).unsqueeze(-1).unsqueeze(-1).expand(seq_len+1, batch_size, 1, tag_size) 282 | 283 | start_token_label = seq_token_label.new_full( 284 | (1, batch_size), self.start_tag, requires_grad=False 285 | ) 286 | seq_prev_label = torch.cat( 287 | [start_token_label, seq_token_label], dim=0 288 | ).unsqueeze(-1).unsqueeze(-1) # [seq_len+1, batch_size, 1, 1] 289 | 290 | seq_trans_score = self.trans_mat.unsqueeze(0).unsqueeze(0).expand(seq_len+1, batch_size, tag_size, tag_size) 291 | # gather according to token label at the current token 292 | gold_trans_score = torch.gather(seq_trans_score, 2, seq_cur_label) # [seq_len+1, batch_size, 1, tag_size] 293 | # gather according to token label at the previous token 294 | gold_trans_score = torch.gather(gold_trans_score, 3, seq_prev_label) # [seq_len+1, batch_size, 1, 1] 295 | batch_trans_score = gold_trans_score.sum(dim=0).squeeze(-1).squeeze(-1) # [batch_size] 296 | 297 | gold_emit_score = torch.gather(seq_emit_score, 2, seq_token_label.unsqueeze(-1)) # [seq_len, batch_size, 1] 298 | batch_emit_score = gold_emit_score.sum(dim=0).squeeze(-1) # [batch_size] 299 | 300 | gold_score = batch_trans_score + batch_emit_score # [batch_size] 301 | 302 | return gold_score 303 | 304 | def viterbi_decode(self, seq_emit_score): 305 | """ 306 | Use viterbi decoding to get prediction 307 | :param seq_emit_score: [seq_len, batch_size, tag_size] 308 | :return: 309 | batch_best_path: [batch_size, seq_len], the best tag for each token 310 | batch_best_score: [batch_size], the corresponding score for each path 311 | """ 312 | seq_len, batch_size, tag_size = seq_emit_score.size() 313 | 314 | dp_table = seq_emit_score.new_full((batch_size, tag_size), self.NEG_LOGIT, requires_grad=False) 315 | dp_table[:, self.start_tag] = 0 316 | backpointers = [] 317 | 318 | for token_idx in range(seq_len): 319 | last_tag_score = dp_table.unsqueeze(-2) # [batch_size, 1, tag_size] 320 | batch_trans_mat = self.trans_mat.unsqueeze(0).expand(batch_size, tag_size, tag_size) 321 | cur_emit_score = seq_emit_score[token_idx].unsqueeze(-1) # [batch_size, tag_size, 1] 322 | cur_trans_score = batch_trans_mat + last_tag_score + cur_emit_score # [batch_size, tag_size, tag_size] 323 | dp_table, cur_tag_bp = cur_trans_score.max(dim=-1) # [batch_size, tag_size] 324 | backpointers.append(cur_tag_bp) 325 | # transition to 
the end tag
326 |         last_trans_arr = self.trans_mat[self.end_tag].unsqueeze(0).expand(batch_size, tag_size)
327 |         dp_table = dp_table + last_trans_arr
328 | 
329 |         # get the best path score and the best tag of the last token
330 |         batch_best_score, best_tag = dp_table.max(dim=-1)  # [batch_size]
331 |         best_tag = best_tag.unsqueeze(-1)  # [batch_size, 1]
332 |         best_tag_list = [best_tag]
333 |         # reversely traverse back pointers to recover the best path
334 |         for last_tag_bp in reversed(backpointers):
335 |             # best_tag Size([batch_size, 1]) records the current tag that can own the highest score
336 |             # last_tag_bp Size([batch_size, tag_size]) records the last best tag that the current tag is based on
337 |             best_tag = torch.gather(last_tag_bp, 1, best_tag)  # [batch_size, 1]
338 |             best_tag_list.append(best_tag)
339 |         batch_start = best_tag_list.pop()
340 |         assert (batch_start == self.start_tag).sum().item() == batch_size
341 |         best_tag_list.reverse()
342 |         batch_best_path = torch.cat(best_tag_list, dim=-1)  # [batch_size, seq_len]
343 | 
344 |         return batch_best_path, batch_best_score
345 | 
346 |     def forward(self, seq_token_emb, seq_token_label=None, batch_first=False,
347 |                 train_flag=True, decode_flag=True):
348 |         """
349 |         Get loss and prediction with CRF support.
350 |         :param seq_token_emb: assume size [seq_len, batch_size, hidden_size] if not batch_first
351 |         :param seq_token_label: assume size [seq_len, batch_size] if not batch_first
352 |         :param batch_first: Flag to denote the meaning of the first dimension
353 |         :param train_flag: whether to calculate the loss
354 |         :param decode_flag: whether to decode the path based on current parameters
355 |         :return:
356 |             nll_loss: negative log-likelihood loss
357 |             batch_best_path: sequence predictions
358 |         """
359 |         if batch_first:
360 |             # CRF assumes the input size of [seq_len, batch_size, hidden_size]
361 |             seq_token_emb = seq_token_emb.transpose(0, 1).contiguous()
362 |             if seq_token_label is not None:
363 |                 seq_token_label = seq_token_label.transpose(0, 1).contiguous()
364 | 
365 |         seq_emit_score = self.hidden2tag(seq_token_emb)  # [seq_len, batch_size, tag_size]
366 |         if train_flag:
367 |             gold_score = self.get_gold_score(seq_emit_score, seq_token_label)  # [batch_size]
368 |             log_partition = self.get_log_parition(seq_emit_score)  # [batch_size]
369 |             nll_loss = log_partition - gold_score
370 |         else:
371 |             nll_loss = None
372 | 
373 |         if decode_flag:
374 |             # Use viterbi decoding to get the current prediction
375 |             # no matter what batch_first is, return size is [batch_size, seq_len]
376 |             batch_best_path, batch_best_score = self.viterbi_decode(seq_emit_score)
377 |         else:
378 |             batch_best_path = None
379 | 
380 |         return nll_loss, batch_best_path
381 | 
382 | # Compute log sum exp in a numerically stable way
383 | def log_sum_exp(batch_logit):
384 |     """
385 |     Calculate the log-sum-exp operation for the last dimension.
386 | :param batch_logit: Size([*, logit_size]), * should at least be 1 387 | :return: Size([*]) 388 | """ 389 | batch_max, _ = batch_logit.max(dim=-1) 390 | batch_broadcast = batch_max.unsqueeze(-1) 391 | return batch_max + \ 392 | torch.log(torch.sum(torch.exp(batch_logit - batch_broadcast), dim=-1)) 393 | 394 | def produce_ner_batch_metrics(seq_logits, gold_labels, masks): 395 | # seq_logits: [batch_size, seq_len, num_entity_labels] 396 | # gold_labels: [batch_size, seq_len] 397 | # masks: [batch_size, seq_len] 398 | batch_size, seq_len, num_entities = seq_logits.size() 399 | 400 | # [batch_size, seq_len, num_entity_labels] 401 | seq_logp = F.log_softmax(seq_logits, dim=-1) 402 | # [batch_size, seq_len] 403 | pred_labels = seq_logp.argmax(dim=-1) 404 | # [batch_size*seq_len, num_entity_labels] 405 | token_logp = seq_logp.view(-1, num_entities) 406 | # [batch_size*seq_len] 407 | token_labels = gold_labels.view(-1) 408 | # [batch_size, seq_len] 409 | seq_token_loss = F.nll_loss(token_logp, token_labels, reduction='none').view(batch_size, seq_len) 410 | 411 | batch_metrics = [] 412 | for bid in range(batch_size): 413 | ex_loss = seq_token_loss[bid, masks[bid]].mean().item() 414 | ex_acc = (pred_labels[bid, masks[bid]] == gold_labels[bid, masks[bid]]).float().mean().item() 415 | ex_pred_lids = pred_labels[bid, masks[bid]].tolist() 416 | ex_gold_lids = gold_labels[bid, masks[bid]].tolist() 417 | ner_tp_set, ner_fp_set, ner_fn_set = judge_ner_prediction(ex_pred_lids, ex_gold_lids) 418 | batch_metrics.append([ex_loss, ex_acc, len(ner_tp_set), len(ner_fp_set), len(ner_fn_set)]) 419 | 420 | return torch.tensor(batch_metrics, dtype=torch.float, device=seq_logits.device) 421 | 422 | def judge_ner_prediction(pred_label_ids, gold_label_ids): 423 | """Very strong assumption on label_id, 0: others, odd: ner_start, even: ner_mid""" 424 | if isinstance(pred_label_ids, torch.Tensor): 425 | pred_label_ids = pred_label_ids.tolist() 426 | if isinstance(gold_label_ids, torch.Tensor): 427 | gold_label_ids = gold_label_ids.tolist() 428 | # element: (ner_start_index, ner_end_index, ner_type_id) 429 | pred_ner_set = set() 430 | gold_ner_set = set() 431 | 432 | pred_ner_sid = None 433 | for idx, ner in enumerate(pred_label_ids): 434 | if pred_ner_sid is None: 435 | if ner % 2 == 1: 436 | pred_ner_sid = idx 437 | continue 438 | else: 439 | prev_ner = pred_label_ids[pred_ner_sid] 440 | if ner == 0: 441 | pred_ner_set.add((pred_ner_sid, idx, prev_ner)) 442 | pred_ner_sid = None 443 | continue 444 | elif ner == prev_ner + 1: # same entity 445 | continue 446 | elif ner % 2 == 1: 447 | pred_ner_set.add((pred_ner_sid, idx, prev_ner)) 448 | pred_ner_sid = idx 449 | continue 450 | else: # ignore invalid subsequence ners 451 | pred_ner_set.add((pred_ner_sid, idx, prev_ner)) 452 | pred_ner_sid = None 453 | pass 454 | if pred_ner_sid is not None: 455 | prev_ner = pred_label_ids[pred_ner_sid] 456 | pred_ner_set.add((pred_ner_sid, len(pred_label_ids), prev_ner)) 457 | 458 | gold_ner_sid = None 459 | for idx, ner in enumerate(gold_label_ids): 460 | if gold_ner_sid is None: 461 | if ner % 2 == 1: 462 | gold_ner_sid = idx 463 | continue 464 | else: 465 | prev_ner = gold_label_ids[gold_ner_sid] 466 | if ner == 0: 467 | gold_ner_set.add((gold_ner_sid, idx, prev_ner)) 468 | gold_ner_sid = None 469 | continue 470 | elif ner == prev_ner + 1: # same entity 471 | continue 472 | elif ner % 2 == 1: 473 | gold_ner_set.add((gold_ner_sid, idx, prev_ner)) 474 | gold_ner_sid = idx 475 | continue 476 | else: # ignore invalid subsequence ners 
477 | gold_ner_set.add((gold_ner_sid, idx, prev_ner)) 478 | gold_ner_sid = None 479 | pass 480 | if gold_ner_sid is not None: 481 | prev_ner = gold_label_ids[gold_ner_sid] 482 | gold_ner_set.add((gold_ner_sid, len(gold_label_ids), prev_ner)) 483 | 484 | ner_tp_set = pred_ner_set.intersection(gold_ner_set) 485 | ner_fp_set = pred_ner_set - gold_ner_set 486 | ner_fn_set = gold_ner_set - pred_ner_set 487 | 488 | return ner_tp_set, ner_fp_set, ner_fn_set 489 | -------------------------------------------------------------------------------- /dee/ner_task.py: -------------------------------------------------------------------------------- 1 | # Code Reference: (https://github.com/dolphin-zs/Doc2EDAG) 2 | 3 | import torch 4 | import logging 5 | import os 6 | import json 7 | from torch.utils.data import TensorDataset 8 | from collections import defaultdict 9 | 10 | from .utils import default_load_json, default_dump_json, EPS, BERTChineseCharacterTokenizer 11 | from .event_type import common_fields, event_type_fields_list 12 | from .ner_model import BertForBasicNER, judge_ner_prediction 13 | from .base_task import TaskSetting, BasePytorchTask 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | class NERExample(object): 18 | basic_entity_label = 'O' 19 | 20 | def __init__(self, guid, text, entity_range_span_types): 21 | self.guid = guid 22 | self.text = text 23 | self.num_chars = len(text) 24 | self.entity_range_span_types = sorted(entity_range_span_types, key=lambda x: x[0]) 25 | 26 | def get_char_entity_labels(self): 27 | char_entity_labels = [] 28 | char_idx = 0 29 | ent_idx = 0 30 | while ent_idx < len(self.entity_range_span_types): 31 | (ent_cid_s, ent_cid_e), ent_span, ent_type = self.entity_range_span_types[ent_idx] 32 | assert ent_cid_s < ent_cid_e <= self.num_chars 33 | 34 | if ent_cid_s > char_idx: 35 | char_entity_labels.append(NERExample.basic_entity_label) 36 | char_idx += 1 37 | elif ent_cid_s == char_idx: 38 | # tmp_ent_labels = [ent_type] * (ent_cid_e - ent_cid_s) 39 | tmp_ent_labels = ['B-' + ent_type] + ['I-' + ent_type] * (ent_cid_e - ent_cid_s - 1) 40 | char_entity_labels.extend(tmp_ent_labels) 41 | char_idx = ent_cid_e 42 | ent_idx += 1 43 | else: 44 | logger.error('Example GUID {}'.format(self.guid)) 45 | logger.error('NER conflicts at char_idx {}, ent_cid_s {}'.format(char_idx, ent_cid_s)) 46 | logger.error(self.text[char_idx - 20:char_idx + 20]) 47 | logger.error(self.entity_range_span_types[ent_idx - 1:ent_idx + 1]) 48 | raise Exception('Unexpected logic error') 49 | 50 | char_entity_labels.extend([NERExample.basic_entity_label] * (self.num_chars - char_idx)) 51 | assert len(char_entity_labels) == self.num_chars 52 | 53 | return char_entity_labels 54 | 55 | @staticmethod 56 | def get_entity_label_list(): 57 | visit_set = set() 58 | entity_label_list = [NERExample.basic_entity_label] 59 | 60 | for field in common_fields: 61 | if field not in visit_set: 62 | visit_set.add(field) 63 | entity_label_list.extend(['B-' + field, 'I-' + field]) 64 | 65 | for event_name, fields in event_type_fields_list: 66 | for field in fields: 67 | if field not in visit_set: 68 | visit_set.add(field) 69 | entity_label_list.extend(['B-' + field, 'I-' + field]) 70 | 71 | return entity_label_list 72 | 73 | def __repr__(self): 74 | ex_str = 'NERExample(guid={}, text={}, entity_info={}'.format( 75 | self.guid, self.text, str(self.entity_range_span_types) 76 | ) 77 | return ex_str 78 | 79 | def load_ner_dataset(dataset_json_path): 80 | total_ner_examples = [] 81 | annguid2detail_align_info = 
default_load_json(dataset_json_path) 82 | for annguid, detail_align_info in annguid2detail_align_info.items(): 83 | sents = detail_align_info['sentences'] 84 | ann_valid_mspans = detail_align_info['ann_valid_mspans'] 85 | ann_valid_dranges = detail_align_info['ann_valid_dranges'] 86 | ann_mspan2guess_field = detail_align_info['ann_mspan2guess_field'] 87 | assert len(ann_valid_dranges) == len(ann_valid_mspans) 88 | 89 | sent_idx2mrange_mspan_mfield_tuples = {} 90 | for drange, mspan in zip(ann_valid_dranges, ann_valid_mspans): 91 | sent_idx, char_s, char_e = drange 92 | sent_mrange = (char_s, char_e) 93 | 94 | sent_text = sents[sent_idx] 95 | assert sent_text[char_s: char_e] == mspan 96 | 97 | guess_field = ann_mspan2guess_field[mspan] 98 | 99 | if sent_idx not in sent_idx2mrange_mspan_mfield_tuples: 100 | sent_idx2mrange_mspan_mfield_tuples[sent_idx] = [] 101 | sent_idx2mrange_mspan_mfield_tuples[sent_idx].append((sent_mrange, mspan, guess_field)) 102 | 103 | for sent_idx in range(len(sents)): 104 | sent_text = sents[sent_idx] 105 | if sent_idx in sent_idx2mrange_mspan_mfield_tuples: 106 | mrange_mspan_mfield_tuples = sent_idx2mrange_mspan_mfield_tuples[sent_idx] 107 | else: 108 | mrange_mspan_mfield_tuples = [] 109 | 110 | total_ner_examples.append( 111 | NERExample('{}-{}'.format(annguid, sent_idx), 112 | sent_text, 113 | mrange_mspan_mfield_tuples) 114 | ) 115 | 116 | return total_ner_examples 117 | 118 | class NERFeature(object): 119 | def __init__(self, input_ids, input_masks, segment_ids, label_ids, seq_len=None): 120 | self.input_ids = input_ids 121 | self.input_masks = input_masks 122 | self.segment_ids = segment_ids 123 | self.label_ids = label_ids 124 | self.seq_len = seq_len 125 | 126 | def __repr__(self): 127 | fea_strs = ['NERFeature(real_seq_len={}'.format(self.seq_len), ] 128 | info_template = ' {:5} {:9} {:5} {:7} {:7}' 129 | fea_strs.append(info_template.format( 130 | 'index', 'input_ids', 'masks', 'seg_ids', 'lbl_ids' 131 | )) 132 | max_print_len = 10 133 | idx = 0 134 | for tid, mask, segid, lid in zip( 135 | self.input_ids, self.input_masks, self.segment_ids, self.label_ids): 136 | fea_strs.append(info_template.format( 137 | idx, tid, mask, segid, lid 138 | )) 139 | idx += 1 140 | if idx >= max_print_len: 141 | break 142 | fea_strs.append(info_template.format( 143 | '...', '...', '...', '...', '...' 
144 | )) 145 | fea_strs.append(')') 146 | 147 | fea_str = '\n'.join(fea_strs) 148 | return fea_str 149 | 150 | class NERFeatureConverter(object): 151 | def __init__(self, entity_label_list, max_seq_len, tokenizer, include_cls=True, include_sep=True): 152 | self.entity_label_list = entity_label_list 153 | self.max_seq_len = max_seq_len # used to normalize sequence length 154 | self.tokenizer = tokenizer 155 | self.entity_label2index = { # for entity label to label index mapping 156 | entity_label: idx for idx, entity_label in enumerate(self.entity_label_list) 157 | } 158 | 159 | self.include_cls = include_cls 160 | self.include_sep = include_sep 161 | 162 | # used to track how many examples have been truncated 163 | self.truncate_count = 0 164 | # used to track the maximum length of input sentences 165 | self.data_max_seq_len = -1 166 | 167 | def convert_example_to_feature(self, ner_example, log_flag=False): 168 | ex_tokens = self.tokenizer.char_tokenize(ner_example.text) 169 | ex_entity_labels = ner_example.get_char_entity_labels() 170 | 171 | assert len(ex_tokens) == len(ex_entity_labels) 172 | 173 | # get valid token sequence length 174 | valid_token_len = self.max_seq_len 175 | if self.include_cls: 176 | valid_token_len -= 1 177 | if self.include_sep: 178 | valid_token_len -= 1 179 | 180 | # truncate according to max_seq_len and record some statistics 181 | self.data_max_seq_len = max(self.data_max_seq_len, len(ex_tokens)) 182 | if len(ex_tokens) > valid_token_len: 183 | ex_tokens = ex_tokens[:valid_token_len] 184 | ex_entity_labels = ex_entity_labels[:valid_token_len] 185 | 186 | self.truncate_count += 1 187 | 188 | basic_label_index = self.entity_label2index[NERExample.basic_entity_label] 189 | 190 | # add bert-specific token 191 | if self.include_cls: 192 | fea_tokens = ['[CLS]'] 193 | fea_token_labels = [NERExample.basic_entity_label] 194 | fea_label_ids = [basic_label_index] 195 | else: 196 | fea_tokens = [] 197 | fea_token_labels = [] 198 | fea_label_ids = [] 199 | 200 | for token, ent_label in zip(ex_tokens, ex_entity_labels): 201 | fea_tokens.append(token) 202 | fea_token_labels.append(ent_label) 203 | 204 | if ent_label in self.entity_label2index: 205 | fea_label_ids.append(self.entity_label2index[ent_label]) 206 | else: 207 | fea_label_ids.append(basic_label_index) 208 | 209 | if self.include_sep: 210 | fea_tokens.append('[SEP]') 211 | fea_token_labels.append(NERExample.basic_entity_label) 212 | fea_label_ids.append(basic_label_index) 213 | 214 | assert len(fea_tokens) == len(fea_token_labels) == len(fea_label_ids) <= self.max_seq_len 215 | 216 | fea_input_ids = self.tokenizer.convert_tokens_to_ids(fea_tokens) 217 | fea_seq_len = len(fea_input_ids) 218 | fea_segment_ids = [0] * fea_seq_len 219 | fea_masks = [1] * fea_seq_len 220 | 221 | # feature is padded to max_seq_len, but fea_seq_len is the real length 222 | while len(fea_input_ids) < self.max_seq_len: 223 | fea_input_ids.append(0) 224 | fea_label_ids.append(0) 225 | fea_masks.append(0) 226 | fea_segment_ids.append(0) 227 | 228 | assert len(fea_input_ids) == len(fea_label_ids) == len(fea_masks) == len(fea_segment_ids) == self.max_seq_len 229 | 230 | if log_flag: 231 | logger.info("*** Example ***") 232 | logger.info("guid: %s" % ner_example.guid) 233 | info_template = '{:8} {:4} {:2} {:2} {:2} {}' 234 | logger.info(info_template.format( 235 | 'TokenId', 'Token', 'Mask', 'SegId', 'LabelId', 'Label' 236 | )) 237 | for tid, token, mask, segid, lid, label in zip( 238 | fea_input_ids, fea_tokens, fea_masks, 239 | 
fea_segment_ids, fea_label_ids, fea_token_labels):
240 |             logger.info(info_template.format(
241 |                 tid, token, mask, segid, lid, label
242 |             ))
243 |             if len(fea_input_ids) > len(fea_tokens):
244 |                 sid = len(fea_tokens)
245 |                 logger.info(info_template.format(
246 |                     fea_input_ids[sid], '[PAD]', fea_masks[sid], fea_segment_ids[sid], fea_label_ids[sid], 'O')
247 |                     + ' x {}'.format(len(fea_input_ids) - len(fea_tokens)))
248 | 
249 |         return NERFeature(fea_input_ids, fea_masks, fea_segment_ids, fea_label_ids, seq_len=fea_seq_len)
250 | 
251 |     def __call__(self, ner_examples, log_example_num=0):
252 |         """Convert examples to features suitable for ner models"""
253 |         self.truncate_count = 0
254 |         self.data_max_seq_len = -1
255 |         ner_features = []
256 | 
257 |         for ex_index, ner_example in enumerate(ner_examples):
258 |             if ex_index < log_example_num:
259 |                 ner_feature = self.convert_example_to_feature(ner_example, log_flag=True)
260 |             else:
261 |                 ner_feature = self.convert_example_to_feature(ner_example, log_flag=False)
262 | 
263 |             ner_features.append(ner_feature)
264 | 
265 |         logger.info('{} examples in total, {} truncated examples, max_sent_len={}'.format(
266 |             len(ner_examples), self.truncate_count, self.data_max_seq_len
267 |         ))
268 | 
269 |         return ner_features
270 | 
271 | def convert_ner_features_to_dataset(ner_features):
272 |     all_input_ids = torch.tensor([f.input_ids for f in ner_features], dtype=torch.long)
273 |     # very important to use the mask type of uint8 to support advanced indexing
274 |     all_input_masks = torch.tensor([f.input_masks for f in ner_features], dtype=torch.uint8)
275 |     all_segment_ids = torch.tensor([f.segment_ids for f in ner_features], dtype=torch.long)
276 |     all_label_ids = torch.tensor([f.label_ids for f in ner_features], dtype=torch.long)
277 |     all_seq_len = torch.tensor([f.seq_len for f in ner_features], dtype=torch.long)
278 |     ner_tensor_dataset = TensorDataset(all_input_ids, all_input_masks, all_segment_ids, all_label_ids, all_seq_len)
279 | 
280 |     return ner_tensor_dataset
281 | 
282 | class NERTaskSetting(TaskSetting):
283 |     def __init__(self, **kwargs):
284 |         ner_key_attrs = []
285 |         ner_attr_default_pairs = [
286 |             ('bert_model', 'bert-base-chinese'),
287 |             ('train_file_name', 'train.json'),
288 |             ('dev_file_name', 'dev.json'),
289 |             ('test_file_name', 'test.json'),
290 |             ('max_seq_len', 128),
291 |             ('train_batch_size', 32),
292 |             ('eval_batch_size', 256),
293 |             ('learning_rate', 2e-5),
294 |             ('num_train_epochs', 3.0),
295 |             ('warmup_proportion', 0.1),
296 |             ('no_cuda', False),
297 |             ('local_rank', -1),
298 |             ('seed', 99),
299 |             ('gradient_accumulation_steps', 1),
300 |             ('optimize_on_cpu', True),
301 |             ('fp16', False),
302 |             ('loss_scale', 128),
303 |             ('cpt_file_name', 'ner_task.cpt'),
304 |             ('summary_dir_name', '/tmp/summary'),
305 |         ]
306 |         super(NERTaskSetting, self).__init__(ner_key_attrs, ner_attr_default_pairs, **kwargs)
307 | 
308 | class NERTask(BasePytorchTask):
309 |     """Named Entity Recognition Task"""
310 | 
311 |     def __init__(self, setting,
312 |                  load_train=True, load_dev=True, load_test=True,
313 |                  build_model=True, parallel_decorate=True,
314 |                  resume_model=False, resume_optimizer=False):
315 |         super(NERTask, self).__init__(setting)
316 |         self.logger = logging.getLogger(self.__class__.__name__)
317 |         self.logging('Initializing {}'.format(self.__class__.__name__))
318 | 
319 |         # initialize entity label list
320 |         self.entity_label_list = NERExample.get_entity_label_list()
321 |         # initialize tokenizer
322 |         self.tokenizer = BERTChineseCharacterTokenizer.from_pretrained(self.setting.bert_model)
323 |         # initialize feature converter
324 |         self.feature_converter_func = NERFeatureConverter(
325 |             self.entity_label_list, self.setting.max_seq_len, self.tokenizer
326 |         )
327 | 
328 |         # load data
329 |         self._load_data(
330 |             load_ner_dataset, self.feature_converter_func, convert_ner_features_to_dataset,
331 |             load_train=load_train, load_dev=load_dev, load_test=load_test
332 |         )
333 | 
334 |         # build model
335 |         if build_model:
336 |             self.model = BertForBasicNER.from_pretrained(self.setting.bert_model, len(self.entity_label_list))
337 |             self.setting.update_by_dict(self.model.config.__dict__)  # BertConfig dictionary
338 |             self._decorate_model(parallel_decorate=parallel_decorate)
339 | 
340 |         # prepare optimizer
341 |         if build_model and load_train:
342 |             self._init_bert_optimizer()
343 | 
344 |         # resume option
345 |         if build_model and (resume_model or resume_optimizer):
346 |             self.resume_checkpoint(resume_model=resume_model, resume_optimizer=resume_optimizer)
347 | 
348 |         self.logging('Successfully initialized {}'.format(self.__class__.__name__))
349 | 
350 |     def reload_data(self, data_type='return', file_name=None, file_path=None, **kwargs):
351 |         """Reload data. Either file_name or file_path needs to be provided.
352 |         data_type: 'return' (default), return (examples, features, dataset);
353 |                    'train', override self.train_xxx;
354 |                    'dev', override self.dev_xxx;
355 |                    'test', override self.test_xxx.
356 |         """
357 |         return super(NERTask, self).reload_data(
358 |             load_ner_dataset, self.feature_converter_func, convert_ner_features_to_dataset,
359 |             data_type=data_type, file_name=file_name, file_path=file_path,
360 |         )
361 | 
362 |     def train(self):
363 |         self.logging('='*20 + 'Start Training' + '='*20)
364 |         self.base_train(get_ner_loss_on_batch)
365 | 
366 |     def eval(self, eval_dataset, eval_save_prefix='', pgm_return_flag=False):
367 |         self.logging('='*20 + 'Start Evaluation' + '='*20)
368 |         # 1. get total prediction info
369 |         # pgm denotes (pred_label, gold_label, token_mask)
370 |         # size = [num_examples, max_seq_len, 3]
371 |         # value = [[(pred_label, gold_label, token_mask), ...], ...]
372 |         total_seq_pgm = self.get_total_prediction(eval_dataset)
373 |         num_examples, max_seq_len, _ = total_seq_pgm.size()
374 | 
375 |         # 2. collect per-entity-label tp, fp, fn counts
376 |         ent_lid2tp_cnt = defaultdict(int)
377 |         ent_lid2fp_cnt = defaultdict(int)
378 |         ent_lid2fn_cnt = defaultdict(int)
379 |         for bid in range(num_examples):
380 |             seq_pgm = total_seq_pgm[bid]  # [max_seq_len, 3]
381 |             seq_pred = seq_pgm[:, 0]  # [max_seq_len]
382 |             seq_gold = seq_pgm[:, 1]
383 |             seq_mask = seq_pgm[:, 2]
384 | 
385 |             seq_pred_lid = seq_pred[seq_mask == 1]  # [seq_len]
386 |             seq_gold_lid = seq_gold[seq_mask == 1]
387 |             ner_tp_set, ner_fp_set, ner_fn_set = judge_ner_prediction(seq_pred_lid, seq_gold_lid)
388 |             for ent_lid2cnt, ex_ner_set in [
389 |                 (ent_lid2tp_cnt, ner_tp_set),
390 |                 (ent_lid2fp_cnt, ner_fp_set),
391 |                 (ent_lid2fn_cnt, ner_fn_set)
392 |             ]:
393 |                 for ent_idx_s, ent_idx_e, ent_lid in ex_ner_set:
394 |                     ent_lid2cnt[ent_lid] += 1
395 | 
396 |         # 3.
calculate per-entity-label metrics and collect global counts 397 | ent_label_eval_infos = [] 398 | g_ner_tp_cnt = 0 399 | g_ner_fp_cnt = 0 400 | g_ner_fn_cnt = 0 401 | # Entity Label Id, 0 for others, odd for BEGIN-ENTITY, even for INSIDE-ENTITY 402 | # using odd is enough to represent the entity type 403 | for ent_lid in range(1, len(self.entity_label_list), 2): 404 | el_name = self.entity_label_list[ent_lid] 405 | el_tp_cnt, el_fp_cnt, el_fn_cnt = ent_lid2tp_cnt[ent_lid], ent_lid2fp_cnt[ent_lid], ent_lid2fn_cnt[ent_lid] 406 | 407 | el_pred_cnt = el_tp_cnt + el_fp_cnt 408 | el_gold_cnt = el_tp_cnt + el_fn_cnt 409 | el_prec = el_tp_cnt / el_pred_cnt if el_pred_cnt > 0 else 0 410 | el_recall = el_tp_cnt / el_gold_cnt if el_gold_cnt > 0 else 0 411 | el_f1 = 2 / (1 / el_prec + 1 / el_recall) if el_prec > EPS and el_recall > EPS else 0 412 | 413 | # per-entity-label evaluation info 414 | el_eval_info = { 415 | 'entity_label_indexes': (ent_lid, ent_lid + 1), 416 | 'entity_label': el_name[2:], # omit 'B-' prefix 417 | 'ner_tp_cnt': el_tp_cnt, 418 | 'ner_fp_cnt': el_fp_cnt, 419 | 'ner_fn_cnt': el_fn_cnt, 420 | 'ner_prec': el_prec, 421 | 'ner_recall': el_recall, 422 | 'ner_f1': el_f1, 423 | } 424 | ent_label_eval_infos.append(el_eval_info) 425 | 426 | # collect global count info 427 | g_ner_tp_cnt += el_tp_cnt 428 | g_ner_fp_cnt += el_fp_cnt 429 | g_ner_fn_cnt += el_fn_cnt 430 | 431 | # 4. summarize total evaluation info 432 | g_ner_pred_cnt = g_ner_tp_cnt + g_ner_fp_cnt 433 | g_ner_gold_cnt = g_ner_tp_cnt + g_ner_fn_cnt 434 | g_ner_prec = g_ner_tp_cnt / g_ner_pred_cnt if g_ner_pred_cnt > 0 else 0 435 | g_ner_recall = g_ner_tp_cnt / g_ner_gold_cnt if g_ner_gold_cnt > 0 else 0 436 | g_ner_f1 = 2 / (1 / g_ner_prec + 1 / g_ner_recall) if g_ner_prec > EPS and g_ner_recall > EPS else 0 437 | 438 | total_eval_info = { 439 | 'eval_name': eval_save_prefix, 440 | 'num_examples': num_examples, 441 | 'ner_tp_cnt': g_ner_tp_cnt, 442 | 'ner_fp_cnt': g_ner_fp_cnt, 443 | 'ner_fn_cnt': g_ner_fn_cnt, 444 | 'ner_prec': g_ner_prec, 445 | 'ner_recall': g_ner_recall, 446 | 'ner_f1': g_ner_f1, 447 | 'per_ent_label_eval': ent_label_eval_infos 448 | } 449 | 450 | self.logging('Evaluation Results\n{:.300s} ...'.format(json.dumps(total_eval_info, indent=4))) 451 | 452 | if eval_save_prefix: 453 | eval_res_fp = os.path.join(self.setting.output_dir, 454 | '{}.eval'.format(eval_save_prefix)) 455 | self.logging('Dump eval results into {}'.format(eval_res_fp)) 456 | default_dump_json(total_eval_info, eval_res_fp) 457 | 458 | if pgm_return_flag: 459 | return total_seq_pgm 460 | else: 461 | return total_eval_info 462 | 463 | def get_total_prediction(self, eval_dataset): 464 | self.logging('='*20 + 'Get Total Prediction' + '='*20) 465 | total_pred_gold_mask = self.base_eval( 466 | eval_dataset, get_ner_pred_on_batch, reduce_info_type='none' 467 | ) 468 | # torch.Tensor(dtype=torch.long, device='cpu') 469 | # size = [batch_size, seq_len, 3] 470 | # value = [[(pred_label, gold_label, token_mask), ...], ...] 
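        # reduce_info_type='none' keeps one un-reduced row per example, so that
        # eval() above can drop padded positions (token_mask == 0) before
        # re-scoring each entity label.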
471 | return total_pred_gold_mask 472 | 473 | def normalize_batch_seq_len(input_seq_lens, *batch_seq_tensors): 474 | batch_max_seq_len = input_seq_lens.max().item() 475 | normed_tensors = [] 476 | for batch_seq_tensor in batch_seq_tensors: 477 | if batch_seq_tensor.dim() == 2: 478 | normed_tensors.append(batch_seq_tensor[:, :batch_max_seq_len]) 479 | elif batch_seq_tensor.dim() == 1: 480 | normed_tensors.append(batch_seq_tensor) 481 | else: 482 | raise Exception('Unsupported batch_seq_tensor dimension {}'.format(batch_seq_tensor.dim())) 483 | 484 | return normed_tensors 485 | 486 | def prepare_ner_batch(batch, resize_len=True): 487 | # prepare batch 488 | input_ids, input_masks, segment_ids, label_ids, input_lens = batch 489 | if resize_len: 490 | input_ids, input_masks, segment_ids, label_ids = normalize_batch_seq_len( 491 | input_lens, input_ids, input_masks, segment_ids, label_ids 492 | ) 493 | 494 | return input_ids, input_masks, segment_ids, label_ids 495 | 496 | def get_ner_loss_on_batch(ner_task, batch): 497 | input_ids, input_masks, segment_ids, label_ids = prepare_ner_batch(batch, resize_len=True) 498 | loss, _ = ner_task.model(input_ids, input_masks, 499 | token_type_ids=segment_ids, 500 | label_ids=label_ids) 501 | 502 | return loss 503 | 504 | def get_ner_metrics_on_batch(ner_task, batch): 505 | input_ids, input_masks, segment_ids, label_ids = prepare_ner_batch(batch, resize_len=True) 506 | batch_metrics = ner_task.model(input_ids, input_masks, 507 | token_type_ids=segment_ids, 508 | label_ids=label_ids, 509 | eval_flag=True, 510 | eval_for_metric=True) 511 | 512 | return batch_metrics 513 | 514 | def get_ner_pred_on_batch(ner_task, batch): 515 | # important to set resize_len to False to maintain the same seq len between batches 516 | input_ids, input_masks, segment_ids, label_ids = prepare_ner_batch(batch, resize_len=False) 517 | batch_seq_pred_gold_mask = ner_task.model(input_ids, input_masks, 518 | token_type_ids=segment_ids, 519 | label_ids=label_ids, 520 | eval_flag=True, 521 | eval_for_metric=False) 522 | # size = [batch_size, max_seq_len, 3] 523 | # value = [[(pred_label, gold_label, token_mask), ...], ...] 524 | return batch_seq_pred_gold_mask 525 | -------------------------------------------------------------------------------- /dee/transformer.py: -------------------------------------------------------------------------------- 1 | # Code Reference: (https://github.com/dolphin-zs/Doc2EDAG) 2 | 3 | import copy 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | import math 9 | 10 | 11 | def clones(module, N): 12 | """Produce N identical layers.""" 13 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) 14 | 15 | class EncoderDecoder(nn.Module): 16 | """ 17 | A standard Encoder-Decoder architecture. Base for this and many 18 | other models. 
19 | """ 20 | 21 | def __init__(self, encoder, decoder, src_embed, tgt_embed, generator): 22 | super(EncoderDecoder, self).__init__() 23 | self.encoder = encoder 24 | self.decoder = decoder 25 | self.src_embed = src_embed 26 | self.tgt_embed = tgt_embed 27 | self.generator = generator 28 | 29 | def forward(self, src, tgt, src_mask, tgt_mask): 30 | """Take in and process masked src and target sequences.""" 31 | return self.decode(self.encode(src, src_mask), src_mask, 32 | tgt, tgt_mask) 33 | 34 | def encode(self, src, src_mask): 35 | return self.encoder(self.src_embed(src), src_mask) 36 | 37 | def decode(self, memory, src_mask, tgt, tgt_mask): 38 | return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask) 39 | 40 | class Generator(nn.Module): 41 | """Define standard linear + softmax generation step.""" 42 | def __init__(self, d_model, vocab): 43 | super(Generator, self).__init__() 44 | self.proj = nn.Linear(d_model, vocab) 45 | 46 | def forward(self, x): 47 | return F.log_softmax(self.proj(x), dim=-1) 48 | 49 | class LayerNorm(nn.Module): 50 | """Construct a layernorm module (See citation for details).""" 51 | 52 | def __init__(self, features, eps=1e-6): 53 | super(LayerNorm, self).__init__() 54 | # self.a_2 = nn.Parameter(torch.ones(features)) 55 | # self.b_2 = nn.Parameter(torch.zeros(features)) 56 | # fit for bert optimizer 57 | self.gamma = nn.Parameter(torch.ones(features)) 58 | self.beta = nn.Parameter(torch.zeros(features)) 59 | self.eps = eps 60 | 61 | def forward(self, x): 62 | mean = x.mean(-1, keepdim=True) 63 | std = x.std(-1, keepdim=True) 64 | # return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 65 | return self.gamma * (x - mean) / (std + self.eps) + self.beta 66 | 67 | 68 | class Encoder(nn.Module): 69 | """"Core encoder is a stack of N layers""" 70 | 71 | def __init__(self, layer, N): 72 | super(Encoder, self).__init__() 73 | self.layers = clones(layer, N) 74 | self.norm = LayerNorm(layer.size) 75 | 76 | def forward(self, x, mask): 77 | """Pass the input (and mask) through each layer in turn.""" 78 | for layer in self.layers: 79 | x = layer(x, mask) 80 | return self.norm(x) 81 | 82 | class SublayerConnection(nn.Module): 83 | """ 84 | A residual connection followed by a layer norm. 85 | Note for code simplicity the norm is first as opposed to last. 
86 | """ 87 | def __init__(self, size, dropout): 88 | super(SublayerConnection, self).__init__() 89 | self.norm = LayerNorm(size) 90 | self.dropout = nn.Dropout(dropout) 91 | 92 | def forward(self, x, sublayer): 93 | """Apply residual connection to any sublayer with the same size.""" 94 | return x + self.dropout(sublayer(self.norm(x))) 95 | 96 | class EncoderLayer(nn.Module): 97 | """Encoder is made up of self-attn and feed forward (defined below)""" 98 | 99 | def __init__(self, size, self_attn, feed_forward, dropout): 100 | super(EncoderLayer, self).__init__() 101 | self.self_attn = self_attn 102 | self.feed_forward = feed_forward 103 | self.sublayer = clones(SublayerConnection(size, dropout), 2) 104 | self.size = size 105 | 106 | def forward(self, x, mask): 107 | """Follow Figure 1 (left) for connections.""" 108 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) 109 | return self.sublayer[1](x, self.feed_forward) 110 | 111 | class Decoder(nn.Module): 112 | """Generic N layer decoder with masking.""" 113 | 114 | def __init__(self, layer, N): 115 | super(Decoder, self).__init__() 116 | self.layers = clones(layer, N) 117 | self.norm = LayerNorm(layer.size) 118 | 119 | def forward(self, x, memory, src_mask, tgt_mask): 120 | for layer in self.layers: 121 | x = layer(x, memory, src_mask, tgt_mask) 122 | return self.norm(x) 123 | 124 | class DecoderLayer(nn.Module): 125 | """Decoder is made of self-attn, src-attn, and feed forward (defined below)""" 126 | 127 | def __init__(self, size, self_attn, src_attn, feed_forward, dropout): 128 | super(DecoderLayer, self).__init__() 129 | self.size = size 130 | self.self_attn = self_attn 131 | self.src_attn = src_attn 132 | self.feed_forward = feed_forward 133 | self.sublayer = clones(SublayerConnection(size, dropout), 3) 134 | 135 | def forward(self, x, memory, src_mask, tgt_mask): 136 | """Follow Figure 1 (right) for connections.""" 137 | m = memory 138 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) 139 | x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask)) 140 | return self.sublayer[2](x, self.feed_forward) 141 | 142 | def subsequent_mask(size): 143 | """Mask out subsequent positions.""" 144 | attn_shape = (1, size, size) 145 | subseq_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8') 146 | return torch.from_numpy(subseq_mask) == 0 147 | 148 | def attention(query, key, value, mask=None, dropout=None): 149 | """Compute 'Scaled Dot Product Attention'""" 150 | d_k = query.size(-1) 151 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) 152 | if mask is not None: 153 | scores = scores.masked_fill(mask == 0, -1e9) 154 | p_attn = F.softmax(scores, dim=-1) 155 | if dropout is not None: 156 | p_attn = dropout(p_attn) 157 | return torch.matmul(p_attn, value), p_attn 158 | 159 | class MultiHeadedAttention(nn.Module): 160 | def __init__(self, h, d_model, dropout=0.1): 161 | """Take in model size and number of heads.""" 162 | super(MultiHeadedAttention, self).__init__() 163 | assert d_model % h == 0 164 | # We assume d_v always equals d_k 165 | self.d_k = d_model // h 166 | self.h = h 167 | self.linears = clones(nn.Linear(d_model, d_model), 4) 168 | self.attn = None 169 | self.dropout = nn.Dropout(p=dropout) 170 | 171 | def forward(self, query, key, value, mask=None): 172 | """Implements Figure 2""" 173 | if mask is not None: 174 | # Same mask applied to all h heads. 
175 | mask = mask.unsqueeze(1) 176 | nbatches = query.size(0) 177 | 178 | # 1) Do all the linear projections in batch from d_model => h x d_k 179 | query, key, value = [ 180 | l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) 181 | for l, x in zip(self.linears, (query, key, value)) 182 | ] 183 | 184 | # 2) Apply attention on all the projected vectors in batch. 185 | x, self.attn = attention(query, key, value, mask=mask, 186 | dropout=self.dropout) 187 | 188 | # 3) "Concat" using a view and apply a final linear. 189 | x = x.transpose(1, 2).contiguous() \ 190 | .view(nbatches, -1, self.h * self.d_k) 191 | 192 | return self.linears[-1](x) 193 | 194 | class PositionwiseFeedForward(nn.Module): 195 | """Implements FFN equation.""" 196 | 197 | def __init__(self, d_model, d_ff, dropout=0.1): 198 | super(PositionwiseFeedForward, self).__init__() 199 | self.w_1 = nn.Linear(d_model, d_ff) 200 | self.w_2 = nn.Linear(d_ff, d_model) 201 | self.dropout = nn.Dropout(dropout) 202 | 203 | def forward(self, x): 204 | return self.w_2(self.dropout(F.relu(self.w_1(x)))) 205 | 206 | class Embeddings(nn.Module): 207 | def __init__(self, d_model, vocab): 208 | super(Embeddings, self).__init__() 209 | self.lut = nn.Embedding(vocab, d_model) 210 | self.d_model = d_model 211 | 212 | def forward(self, x): 213 | return self.lut(x) * math.sqrt(self.d_model) 214 | 215 | class PositionalEncoding(nn.Module): 216 | """Implement the PE function.""" 217 | 218 | def __init__(self, d_model, dropout, max_len=5000): 219 | super(PositionalEncoding, self).__init__() 220 | self.dropout = nn.Dropout(p=dropout) 221 | 222 | # Compute the positional encodings once in log space. 223 | pe = torch.zeros(max_len, d_model) 224 | position = torch.arange(0, max_len).unsqueeze(1) 225 | div_term = torch.exp(torch.arange(0, d_model, 2) * 226 | -(math.log(10000.0) / d_model)) 227 | pe[:, 0::2] = torch.sin(position * div_term) 228 | pe[:, 1::2] = torch.cos(position * div_term) 229 | pe = pe.unsqueeze(0) 230 | self.register_buffer('pe', pe) 231 | 232 | def forward(self, x): 233 | x = x + self.pe[:, :x.size(1)].to(device=x.device) 234 | return self.dropout(x) 235 | 236 | def make_model(src_vocab, tgt_vocab, num_layers=6, d_model=512, d_ff=2048, h=8, dropout=0.1): 237 | """Helper: Construct a model from hyperparameters.""" 238 | c = copy.deepcopy 239 | attn = MultiHeadedAttention(h, d_model) 240 | ff = PositionwiseFeedForward(d_model, d_ff, dropout) 241 | position = PositionalEncoding(d_model, dropout) 242 | model = EncoderDecoder( 243 | Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), num_layers), 244 | Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), num_layers), 245 | nn.Sequential(Embeddings(d_model, src_vocab), c(position)), 246 | nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)), 247 | Generator(d_model, tgt_vocab)) 248 | 249 | # This was important from their code. 250 | # Initialize parameters with Glorot / fan_avg. 
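    # Only weight matrices (p.dim() > 1) get Xavier init; bias vectors keep
    # their PyTorch defaults.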
251 |     for p in model.parameters():
252 |         if p.dim() > 1:
253 |             nn.init.xavier_uniform_(p)
254 |     return model
255 | 
256 | def make_transformer_encoder(num_layers, hidden_size, ff_size=2048, num_att_heads=8, dropout=0.1):
257 |     dcopy = copy.deepcopy
258 |     mh_att = MultiHeadedAttention(num_att_heads, hidden_size, dropout=dropout)
259 |     pos_ff = PositionwiseFeedForward(hidden_size, ff_size, dropout=dropout)
260 | 
261 |     transformer_encoder = Encoder(
262 |         EncoderLayer(hidden_size, dcopy(mh_att), dcopy(pos_ff), dropout=dropout),
263 |         num_layers
264 |     )
265 | 
266 |     return transformer_encoder
--------------------------------------------------------------------------------
/dee/utils.py:
--------------------------------------------------------------------------------
 1 | # Code Reference: (https://github.com/dolphin-zs/Doc2EDAG)
 2 | 
 3 | import json
 4 | import logging
 5 | import pickle
 6 | from pytorch_pretrained_bert import BertTokenizer
 7 | 
 8 | 
 9 | logger = logging.getLogger(__name__)
10 | 
11 | EPS = 1e-10
12 | 
13 | def default_load_json(json_file_path, encoding='utf-8', **kwargs):
14 |     with open(json_file_path, 'r', encoding=encoding) as fin:
15 |         tmp_json = json.load(fin, **kwargs)
16 |     return tmp_json
17 | 
18 | def default_dump_json(obj, json_file_path, encoding='utf-8', ensure_ascii=False, indent=2, **kwargs):
19 |     with open(json_file_path, 'w', encoding=encoding) as fout:
20 |         json.dump(obj, fout,
21 |                   ensure_ascii=ensure_ascii,
22 |                   indent=indent,
23 |                   **kwargs)
24 | 
25 | def default_load_pkl(pkl_file_path, **kwargs):
26 |     with open(pkl_file_path, 'rb') as fin:
27 |         obj = pickle.load(fin, **kwargs)
28 | 
29 |     return obj
30 | 
31 | def default_dump_pkl(obj, pkl_file_path, **kwargs):
32 |     with open(pkl_file_path, 'wb') as fout:
33 |         pickle.dump(obj, fout, **kwargs)
34 | 
35 | def set_basic_log_config():
36 |     logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
37 |                         datefmt='%Y-%m-%d %H:%M:%S',
38 |                         level=logging.INFO)
39 | 
40 | class BERTChineseCharacterTokenizer(BertTokenizer):
41 |     """Customized tokenizer for Chinese financial announcements"""
42 | 
43 |     def __init__(self, vocab_file, do_lower_case=True):
44 |         super(BERTChineseCharacterTokenizer, self).__init__(vocab_file, do_lower_case)
45 | 
46 |     def char_tokenize(self, text, unk_token='[UNK]'):
47 |         """perform pure character-based tokenization"""
48 |         tokens = list(text)
49 |         out_tokens = []
50 |         for token in tokens:
51 |             if token in self.vocab:
52 |                 out_tokens.append(token)
53 |             else:
54 |                 out_tokens.append(unk_token)
55 | 
56 |         return out_tokens
57 | 
58 | def recursive_print_grad_fn(grad_fn, prefix='', depth=0, max_depth=50):
59 |     if depth > max_depth:
60 |         return
61 |     print(prefix, depth, grad_fn.__class__.__name__)
62 |     if hasattr(grad_fn, 'next_functions'):
63 |         for nf in grad_fn.next_functions:
64 |             ngfn = nf[0]
65 |             recursive_print_grad_fn(ngfn, prefix=prefix + '  ', depth=depth+1, max_depth=max_depth)
66 | 
67 | def strtobool(str_val):
68 |     """Convert a string representation of truth to True or False.
69 | 
70 |     True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
71 |     are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
72 |     'str_val' is anything else.
73 | """ 74 | str_val = str_val.lower() 75 | if str_val in ('y', 'yes', 't', 'true', 'on', '1'): 76 | return True 77 | elif str_val in ('n', 'no', 'f', 'false', 'off', '0'): 78 | return False 79 | else: 80 | raise ValueError("invalid truth value %r" % (str_val,)) 81 | -------------------------------------------------------------------------------- /figs/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/figs/model.png -------------------------------------------------------------------------------- /figs/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/figs/result.png -------------------------------------------------------------------------------- /run_dee_task.py: -------------------------------------------------------------------------------- 1 | # Code Reference: (https://github.com/dolphin-zs/Doc2EDAG) 2 | 3 | import argparse 4 | import os 5 | import torch.distributed as dist 6 | 7 | from dee.utils import set_basic_log_config, strtobool 8 | from dee.dee_task import DEETask, DEETaskSetting 9 | from dee.dee_helper import aggregate_task_eval_info, print_total_eval_info, print_single_vs_multi_performance 10 | 11 | set_basic_log_config() 12 | 13 | 14 | def parse_args(in_args=None): 15 | arg_parser = argparse.ArgumentParser() 16 | arg_parser.add_argument('--task_name', type=str, required=True, 17 | help='Take Name') 18 | arg_parser.add_argument('--data_dir', type=str, default='./Data', 19 | help='Data directory') 20 | arg_parser.add_argument('--exp_dir', type=str, default='./Exps', 21 | help='Experiment directory') 22 | arg_parser.add_argument('--save_cpt_flag', type=strtobool, default=True, 23 | help='Whether to save cpt for each epoch') 24 | arg_parser.add_argument('--skip_train', type=strtobool, default=False, 25 | help='Whether to skip training') 26 | arg_parser.add_argument('--eval_model_names', type=str, default='GIT', 27 | help="Models to be evaluated") 28 | arg_parser.add_argument('--re_eval_flag', type=strtobool, default=False, 29 | help='Whether to re-evaluate previous predictions') 30 | 31 | # add task setting arguments 32 | for key, val in DEETaskSetting.base_attr_default_pairs: 33 | if isinstance(val, bool): 34 | arg_parser.add_argument('--' + key, type=strtobool, default=val) 35 | else: 36 | arg_parser.add_argument('--'+key, type=type(val), default=val) 37 | 38 | arg_info = arg_parser.parse_args(args=in_args) 39 | 40 | return arg_info 41 | 42 | 43 | if __name__ == '__main__': 44 | in_argv = parse_args() 45 | 46 | task_dir = os.path.join(in_argv.exp_dir, in_argv.task_name) 47 | if not os.path.exists(task_dir): 48 | os.makedirs(task_dir, exist_ok=True) 49 | 50 | in_argv.model_dir = os.path.join(task_dir, "Model") 51 | in_argv.output_dir = os.path.join(task_dir, "Output") 52 | 53 | # in_argv must contain 'data_dir', 'model_dir', 'output_dir' 54 | dee_setting = DEETaskSetting( 55 | **in_argv.__dict__ 56 | ) 57 | dee_setting.summary_dir_name = os.path.join(task_dir, "Summary") 58 | 59 | # build task 60 | dee_task = DEETask(dee_setting, load_train=not in_argv.skip_train) 61 | 62 | if not in_argv.skip_train: 63 | # dump hyper-parameter settings 64 | if dee_task.is_master_node(): 65 | fn = '{}.task_setting.json'.format(dee_setting.cpt_file_name) 66 | dee_setting.dump_to(task_dir, file_name=fn) 67 | 68 | 
dee_task.train(save_cpt_flag=in_argv.save_cpt_flag) 69 | else: 70 | dee_task.logging('Skip training') 71 | 72 | if dee_task.is_master_node(): 73 | if in_argv.re_eval_flag: 74 | data_span_type2model_str2epoch_res_list = dee_task.reevaluate_dee_prediction(dump_flag=True) 75 | else: 76 | data_span_type2model_str2epoch_res_list = aggregate_task_eval_info(in_argv.output_dir, dump_flag=True) 77 | data_type = 'test' 78 | span_type = 'pred_span' 79 | metric_type = 'micro' 80 | mstr_bepoch_list = print_total_eval_info( 81 | data_span_type2model_str2epoch_res_list, metric_type=metric_type, span_type=span_type, 82 | model_str=in_argv.eval_model_names, 83 | target_set=data_type 84 | ) 85 | print_single_vs_multi_performance( 86 | mstr_bepoch_list, in_argv.output_dir, dee_task.test_features, 87 | metric_type=metric_type, data_type=data_type, span_type=span_type 88 | ) 89 | 90 | # ensure every processes exit at the same time 91 | if dist.is_initialized(): 92 | dist.barrier() 93 | 94 | 95 | 96 | 97 | -------------------------------------------------------------------------------- /run_eval.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | GPU=0 4 | DATA_DIR=./Data 5 | EXP_DIR=./Exps 6 | COMMON_TASK_NAME=try 7 | EVAL_BS=2 8 | NUM_GPUS=1 9 | MODEL_STR=GIT 10 | 11 | echo "---> ${MODEL_STR} Run" 12 | CUDA_VISIBLE_DEVICES=${GPU} ./train_multi.sh ${NUM_GPUS} \ 13 | --data_dir ${DATA_DIR} --exp_dir ${EXP_DIR} --task_name ${COMMON_TASK_NAME} \ 14 | --eval_batch_size ${EVAL_BS} \ 15 | --cpt_file_name ${MODEL_STR} \ 16 | --skip_train True \ 17 | --re_eval_flag False 18 | -------------------------------------------------------------------------------- /run_train.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | GPU=0,1,2,3,4,5,6,7 4 | DATA_DIR=./Data 5 | EXP_DIR=./Exps 6 | COMMON_TASK_NAME=try 7 | RESUME_TRAIN=True 8 | SAVE_CPT=True 9 | N_EPOCH=100 10 | TRAIN_BS=64 11 | EVAL_BS=2 12 | NUM_GPUS=8 13 | GRAD_ACC_STEP=8 14 | MODEL_STR=GIT 15 | 16 | echo "---> ${MODEL_STR} Run" 17 | CUDA_VISIBLE_DEVICES=${GPU} ./train_multi.sh ${NUM_GPUS} --resume_latest_cpt ${RESUME_TRAIN} --save_cpt_flag ${SAVE_CPT} \ 18 | --data_dir ${DATA_DIR} --exp_dir ${EXP_DIR} --task_name ${COMMON_TASK_NAME} --num_train_epochs ${N_EPOCH} \ 19 | --train_batch_size ${TRAIN_BS} --gradient_accumulation_steps ${GRAD_ACC_STEP} --eval_batch_size ${EVAL_BS} \ 20 | --cpt_file_name ${MODEL_STR} 21 | -------------------------------------------------------------------------------- /train_multi.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | NUM_GPUS=$1 4 | shift 5 | 6 | python -m torch.distributed.launch --nproc_per_node ${NUM_GPUS} run_dee_task.py $* 7 | --------------------------------------------------------------------------------
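For quick reference, below is a minimal sketch of driving the sentence-level NER task on its own, outside the distributed launcher scripts above. It is illustrative only: it assumes that `TaskSetting` (defined in `dee/base_task.py`, which is not shown here) accepts `data_dir`, `model_dir`, and `output_dir` keyword arguments in the same way `run_dee_task.py` supplies them, and that `_load_data` exposes the loaded test split as `task.test_dataset`; all paths are placeholders.

```python
# A hypothetical standalone NER run; the full document-level pipeline
# goes through run_dee_task.py and train_multi.sh instead.
from dee.utils import set_basic_log_config
from dee.ner_task import NERTaskSetting, NERTask

set_basic_log_config()

# Assumed TaskSetting attributes (see dee/base_task.py); data_dir should
# contain train.json / dev.json / test.json from Data.zip, and the Exps
# paths are placeholders.
setting = NERTaskSetting(
    data_dir='./Data',
    model_dir='./Exps/ner_try/Model',
    output_dir='./Exps/ner_try/Output',
)

task = NERTask(setting, load_train=True, load_dev=True, load_test=True)
task.train()  # wraps base_train(get_ner_loss_on_batch)
task.eval(task.test_dataset, eval_save_prefix='test')  # dumps test.eval into output_dir
```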