├── .gitignore ├── LICENSE ├── README.md ├── bre_eval.py ├── config.py ├── config ├── 4re_generative_model.json ├── bre_generative_model.json ├── generative_model.json └── ree_generative_model.json ├── constants.py ├── convert_grit.py ├── copy_bart.py ├── data.py ├── evaluate.py ├── model.py ├── raw_scripts ├── README.md ├── go_proc_doc.sh ├── go_proc_keys.sh ├── proc_keys.py ├── proc_texts.py ├── process_all_keys.py ├── process_all_keys.sh ├── process_scirex.py └── process_scirex.sh ├── ree_eval.py ├── requirements.txt ├── sagcopy.py ├── scirex_eval.py ├── train.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__ 2 | *.pyc 3 | data/* 4 | trained_models/* 5 | bert/* 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Kung-hsiang, Huang (Steeve) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/document-level-entity-based-extraction-as/role-filler-entity-extraction-on-muc-4)](https://paperswithcode.com/sota/role-filler-entity-extraction-on-muc-4?p=document-level-entity-based-extraction-as)
2 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/document-level-entity-based-extraction-as/binary-relation-extraction-on-scirex)](https://paperswithcode.com/sota/binary-relation-extraction-on-scirex?p=document-level-entity-based-extraction-as)
3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/document-level-entity-based-extraction-as/4-ary-relation-extraction-on-scirex)](https://paperswithcode.com/sota/4-ary-relation-extraction-on-scirex?p=document-level-entity-based-extraction-as)
4 | # TempGen
5 | Source code for the EMNLP 2021 paper [Document-level Entity-based Extraction as Template Generation](https://arxiv.org/abs/2109.04901).
6 |
7 | ## Dependencies
8 |
9 | All the required packages are listed in `requirements.txt`. To install all the dependencies, run
10 |
11 | ```
12 | conda create -n tg python=3.7
13 | conda activate tg
14 | pip install -r requirements.txt
15 | ```
16 |
17 |
18 | ## Data
19 |
20 | All data lives in the `./data` directory. The processed REE output can be found at `data/muc34/proc_output/`. Files matching the pattern `ree_*.json` are the train, dev, and test sets for role-filler entity extraction in our in-house representation. These files are converted from `grit_*.json`, the train, dev, and test files copied from [GRIT's repo](https://github.com/xinyadu/grit_doc_event_entity/). The conversion script is `convert_grit.py`. An example of converting GRIT data into our in-house format is:
21 |
22 | ```
23 | python convert_grit.py --input_path data/muc34/proc_output/grit_train.json --output_path data/muc34/proc_output/ree_train.json
24 | ```
25 |
26 | As for SciREX, we downloaded the original dataset `data/scirex/release_data.tar.gz` from [the original SciREX repo](https://github.com/allenai/SciREX/tree/master/scirex_dataset). The extracted train, dev, and test files are located in `data/scirex/release_data`. These original data are transformed into our internal representation using `raw_scripts/process_scirex.sh` and stored in `data/scirex/proc_output`. The binary RE files have no suffix, while the 4-ary RE files carry the suffix `_4ary`.
27 |
28 | ### Pre-processing
29 |
30 | We adapted some of the pre-processing code from [Du et al. 2021](https://arxiv.org/abs/2008.09249). To produce our training data, navigate to `raw_scripts` and extract documents by running
31 |
32 | ```
33 | bash go_proc_doc.sh
34 | ```
35 |
36 | with __Python 2.7 !!__.
(Previous work uses Python 2.7 for this pre-processing step. I will upgrade this script later when I have time.)
37 |
38 | Then, use __Python 3.6__ or above to run the second pre-processing script, which combines annotations and documents.
39 |
40 | ```
41 | bash process_all_keys.sh
42 | ```
43 |
44 | Please refer to `raw_scripts/README.md` for more details about the data format.
45 |
46 |
47 | ## Training
48 |
49 | Our formulation of document-level IE as a template generation task allows the same model architecture to be applied to role-filler entity extraction, binary relation extraction, and 4-ary relation extraction. Therefore, the same script `train.py` can be used to train models for all three tasks. The only difference when training a model for each task is the config file.
50 |
51 | Role-filler entity extraction
52 | ```
53 | python train.py -c config/ree_generative_model.json
54 | ```
55 | Binary relation extraction
56 | ```
57 | python train.py -c config/bre_generative_model.json
58 | ```
59 | 4-ary relation extraction
60 | ```
61 | python train.py -c config/4re_generative_model.json
62 | ```
63 |
64 | The key difference between these config files is the `task` field: role-filler entity extraction has `task: ree`, binary relation extraction has `task: bre`, and 4-ary relation extraction has `task: 4re` (event template extraction, configured in `config/generative_model.json`, has `task: ete`). To enable or disable the Topk Copy mechanism, set `use_copy` to `true` or `false`.
65 |
66 |
67 |
68 |
69 | ## Evaluation
70 |
71 | The evaluation scripts for MUC-4 REE and SciREX RE are `ree_eval.py` and `scirex_eval.py`, which are copied over from the [GRIT repo](https://github.com/xinyadu/grit_doc_event_entity/) and the [SciREX repo](https://github.com/allenai/SciREX), respectively.
72 |
73 | To run evaluation on trained models, execute the `evaluate.py` script as follows:
74 | ```
75 | python evaluate.py --gpu 0 --checkpoint $PATH_TO_MODEL/best.mdl
76 | ```
77 | Passing `--gpu -1` runs evaluation on CPU.
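Going back to the config switches described under Training, the following is a minimal sketch of how a config variant with the copy mechanism turned off might be derived. It assumes only the flat JSON layout of the files in `config/`; the helper `make_variant`, the trimmed-down `base` dict, and the `./output/ree_no_copy` path are hypothetical and not part of this repo.

```python
import json
import os
import tempfile


def make_variant(base_config: dict, **overrides) -> dict:
    """Return a copy of a flat config dict with selected fields overridden.

    Hypothetical helper for illustration; the repo itself only reads
    the JSON file passed to train.py with -c.
    """
    variant = dict(base_config)  # shallow copy is enough for a flat config
    variant.update(overrides)
    return variant


# Trimmed-down stand-in for config/ree_generative_model.json.
base = {"task": "ree", "use_copy": True, "k": 10}

# Same task, but with the copy mechanism disabled and a separate output path.
no_copy = make_variant(base, use_copy=False, output_path="./output/ree_no_copy")

# Write the variant to a temporary directory so the sketch has no side effects
# on the repo's own config folder.
out_dir = tempfile.mkdtemp()
variant_path = os.path.join(out_dir, "ree_no_copy.json")
with open(variant_path, "w", encoding="utf-8") as w:
    json.dump(no_copy, w, indent=2)

print(variant_path)
```

A file written this way could then be passed to training in the same way as the shipped configs, i.e. via `python train.py -c` followed by its path.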
78 | 79 | The trained models can be downloaded from [here](https://drive.google.com/drive/folders/1D6-0mM7n3JeqXzspBtNWi6fQC4mJHdSb?usp=sharing) for reproduction purposes. 80 | 81 | The structure of this repo is based on OneIE (https://blender.cs.illinois.edu/software/oneie/) 82 | 83 | ## Citation 84 | ```bibtex 85 | @inproceedings{huang-etal-2021-tempgen, 86 | title = "Document-level Entity-based Extraction as Template Generation", 87 | author = "Huang, Kung-Hsiang and 88 | Tang, Sam and 89 | Peng, Nanyun", 90 | booktitle = "The 2021 Conference on Empirical Methods in Natural Language Processing (EMNLP)", 91 | year = "2021", 92 | address = "Online", 93 | publisher = "Association for Computational Linguistics", 94 | } 95 | ``` 96 | -------------------------------------------------------------------------------- /bre_eval.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copied over from https://github.com/allenai/SciREX/blob/master/scirex/evaluation_scripts/scirex_relation_evaluate.py 3 | ''' 4 | from typing import Dict, List 5 | import os 6 | from itertools import combinations 7 | import pandas as pd 8 | from copy import deepcopy 9 | 10 | BASEPATH = os.getenv("RESULT_EXTRACTION_BASEPATH", ".") 11 | 12 | available_entity_types_sciERC = ["Material", "Metric", "Task", "Generic", "OtherScientificTerm", "Method"] 13 | map_available_entity_to_true = {"Material": "dataset", "Metric": "metric", "Task": "task", "Method": "model_name"} 14 | map_true_entity_to_available = {v: k for k, v in map_available_entity_to_true.items()} 15 | 16 | used_entities = list(map_available_entity_to_true.keys()) 17 | true_entities = list(map_available_entity_to_true.values()) 18 | 19 | def has_all_mentions(doc, relation): 20 | has_mentions = all(len(doc["coref"][x[1]]) > 0 for x in relation) 21 | return has_mentions 22 | 23 | def compute_mapping(predicted_relations: List[Dict[str, str]], 24 | gold_entities: Dict[str, List], 25 | doc_tokens: 
List[str]): 26 | ''' 27 | Each relation in predicted_relations is a dict with two elements (for binary relation). e.g. 28 | { 29 | 'Metric': 'accuracy', 30 | 'Task': 'Natural language inference', 31 | } 32 | ''' 33 | # make a copy so we don't alter the original data 34 | gold_entities = deepcopy(gold_entities) 35 | predicted_mentions = set([mention for relation in predicted_relations for mention in relation.values()]) 36 | 37 | # # Assign each mention to one gold entity. 38 | predicted_mention2gold_entity_name : Dict[str, str] = {} 39 | for predicted_mention in predicted_mentions: 40 | gold_entity_name_to_pop = None 41 | for gold_entity_name, gold_mention_spans in gold_entities.items(): 42 | gold_mentions = { ' '.join(doc_tokens[start_tok:end_tok]) for (start_tok, end_tok) in gold_mention_spans} 43 | if predicted_mention in gold_mentions: 44 | gold_entity_name_to_pop = gold_entity_name 45 | predicted_mention2gold_entity_name[predicted_mention] = gold_entity_name 46 | break 47 | # Make sure each gold entity is only assigned once. 
48 | if gold_entity_name_to_pop is not None: 49 | gold_entities.pop(gold_entity_name_to_pop) 50 | 51 | else: 52 | print(f"Cannot find span for {predicted_mention}") 53 | 54 | 55 | return predicted_mention2gold_entity_name 56 | 57 | 58 | 59 | 60 | def bre_eval(predicted_relations, gold_data): 61 | 62 | all_metrics = [] 63 | 64 | for types in combinations(used_entities, 2): 65 | for doc in gold_data: 66 | relations = predicted_relations[doc["doc_id"]] 67 | 68 | mapping = compute_mapping(relations, doc['coref'], doc["words"]) 69 | 70 | for relation in relations: 71 | for entity_type, entity_name in relation.items(): 72 | relation[entity_type] = mapping.get(entity_name, entity_name) 73 | 74 | # each iteration only evaluate those of corresponding types 75 | relations = set([tuple((t, x[t]) for t in types) for x in relations if all(t in x.keys() for t in types)]) 76 | 77 | gold_relations = [tuple((t, x[t]) for t in types) for x in doc['n_ary_relations']] 78 | gold_relations = set([x for x in gold_relations if has_all_mentions(doc, x)]) 79 | 80 | matched = relations & gold_relations 81 | 82 | metrics = { 83 | "p": len(matched) / (len(relations) + 1e-7), 84 | "r": len(matched) / (len(gold_relations) + 1e-7), 85 | } 86 | metrics["f1"] = 2 * metrics["p"] * metrics["r"] / (metrics["p"] + metrics["r"] + 1e-7) 87 | 88 | if len(gold_relations) > 0: 89 | all_metrics.append(metrics) 90 | 91 | all_metrics = pd.DataFrame(all_metrics) 92 | print("Relation Metrics n=2") 93 | print(all_metrics.describe().loc['mean'][['p', 'r', 'f1']]) 94 | 95 | # take the mean value 96 | return all_metrics.describe().loc['mean'][['p', 'r', 'f1']].to_dict() -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import os 4 | from constants import * 5 | 6 | from transformers import AutoConfig 7 | 8 | class Config(object): 9 | def __init__(self, 
**kwargs): 10 | self.coref = kwargs.pop('coref', False) 11 | # bert 12 | self.bert_model_name = kwargs.pop('bert_model_name', 'bert-large-cased') 13 | self.bert_cache_dir = kwargs.pop('bert_cache_dir', None) 14 | self.extra_bert = kwargs.pop('extra_bert', -1) 15 | self.use_extra_bert = kwargs.pop('use_extra_bert', False) 16 | # model 17 | # self.multi_piece_strategy = kwargs.pop('multi_piece_strategy', 'first') 18 | self.bert_dropout = kwargs.pop('bert_dropout', .5) 19 | self.linear_dropout = kwargs.pop('linear_dropout', .4) 20 | self.linear_bias = kwargs.pop('linear_bias', True) 21 | self.linear_activation = kwargs.pop('linear_activation', 'relu') 22 | 23 | # decoding 24 | self.max_position_embeddings = kwargs.pop('max_position_embeddings', 2048) 25 | self.num_beams = kwargs.pop('num_beams', 4) 26 | self.decoding_method = kwargs.pop('decoding_method', "greedy") 27 | 28 | # files 29 | self.train_file = kwargs.pop('train_file', None) 30 | self.dev_file = kwargs.pop('dev_file', None) 31 | self.test_file = kwargs.pop('test_file', None) 32 | self.valid_pattern_path = kwargs.pop('valid_pattern_path', None) 33 | self.log_path = kwargs.pop('log_path', './log') 34 | self.output_path = kwargs.pop('output_path', './output') 35 | self.grit_dev_file = kwargs.pop('grit_dev_file', None) 36 | self.grit_test_file = kwargs.pop('grit_test_file', None) 37 | 38 | # training 39 | self.accumulate_step = kwargs.pop('accumulate_step', 1) 40 | self.batch_size = kwargs.pop('batch_size', 10) 41 | self.eval_batch_size = kwargs.pop('eval_batch_size', 5) 42 | self.max_epoch = kwargs.pop('max_epoch', 50) 43 | self.max_length = kwargs.pop('max_length', 128) 44 | self.learning_rate = kwargs.pop('learning_rate', 1e-3) 45 | self.bert_learning_rate = kwargs.pop('bert_learning_rate', 1e-5) 46 | self.weight_decay = kwargs.pop('weight_decay', 0.001) 47 | self.bert_weight_decay = kwargs.pop('bert_weight_decay', 0.00001) 48 | self.warmup_epoch = kwargs.pop('warmup_epoch', 5) 49 | self.grad_clipping = 
kwargs.pop('grad_clipping', 5.0) 50 | self.SOT_weights = kwargs.pop('SOT_weights', 100) 51 | self.permute_slots = kwargs.pop('permute_slots', False) 52 | 53 | self.task = kwargs.pop('task',EVENT_TEMPLATE_EXTRACTION) # task cannot be empty 54 | 55 | # others 56 | self.use_gpu = kwargs.pop('use_gpu', True) 57 | self.gpu_device = kwargs.pop('gpu_device', 0) 58 | self.seed = kwargs.pop('seed', 0) 59 | self.use_copy = kwargs.pop('use_copy', False) 60 | self.use_SAGCopy = kwargs.pop('use_SAGCopy', False) 61 | self.k = kwargs.pop('k', 12) 62 | 63 | 64 | 65 | @classmethod 66 | def from_dict(cls, dict_obj): 67 | """Creates a Config object from a dictionary. 68 | Args: 69 | dict_obj (Dict[str, Any]): a dict where keys are 70 | """ 71 | config = cls() 72 | for k, v in dict_obj.items(): 73 | setattr(config, k, v) 74 | return config 75 | 76 | @classmethod 77 | def from_json_file(cls, path): 78 | with open(path, 'r', encoding='utf-8') as r: 79 | return cls.from_dict(json.load(r)) 80 | 81 | def to_dict(self): 82 | output = copy.deepcopy(self.__dict__) 83 | return output 84 | 85 | def save_config(self, path): 86 | """Save a configuration object to a file. 87 | :param path (str): path to the output file or its parent directory. 
88 | """ 89 | if os.path.isdir(path): 90 | path = os.path.join(path, 'config.json') 91 | print('Save config to {}'.format(path)) 92 | with open(path, 'w', encoding='utf-8') as w: 93 | w.write(json.dumps(self.to_dict(), indent=2, 94 | sort_keys=True)) 95 | @property 96 | def bert_config(self): 97 | 98 | 99 | return AutoConfig.from_pretrained(self.bert_model_name, 100 | cache_dir=self.bert_cache_dir, 101 | max_position_embeddings=self.max_position_embeddings) 102 | 103 | -------------------------------------------------------------------------------- /config/4re_generative_model.json: -------------------------------------------------------------------------------- 1 | 2 | { 3 | "debug": true, 4 | "accumulate_step": 32, 5 | "batch_size": 1, 6 | "eval_batch_size": 4, 7 | 8 | "bert_cache_dir": "./bert", 9 | "bert_dropout": 0.0, 10 | "bert_learning_rate": 5e-05, 11 | "bert_model_name": "facebook/bart-base", 12 | "bert_weight_decay": 1e-05, 13 | 14 | 15 | "decoding_method": "beam_search", 16 | "train_file": "data/scirex/proc_output/train_4ary.json", 17 | "dev_file": "data/scirex/proc_output/dev_4ary.json", 18 | "test_file": "data/scirex/proc_output/test_4ary.json", 19 | 20 | 21 | "SOT_weights":25, 22 | "permute_slots":false, 23 | "max_position_embeddings":512, 24 | "grad_clipping": 5.0, 25 | "task":"4re", 26 | "scirex_dev_file":"data/scirex/release_data/dev.jsonl", 27 | "scirex_test_file":"data/scirex/release_data/test.jsonl", 28 | "log_path": "./log/4re_bart_gen", 29 | "k":10, 30 | "max_epoch": 50, 31 | "max_length": 1024, 32 | 33 | "num_beams":4, 34 | 35 | "output_path": "./output/4re_bart_gen", 36 | 37 | "use_copy":true, 38 | "use_SAGCopy":false, 39 | "use_extra_bert": true, 40 | "use_gpu": true, 41 | "gpu_device": 1, 42 | "warmup_epoch": 5, 43 | "weight_decay": 0.001 44 | } -------------------------------------------------------------------------------- /config/bre_generative_model.json: -------------------------------------------------------------------------------- 
1 | 2 | { 3 | "debug": true, 4 | "accumulate_step": 32, 5 | "batch_size": 1, 6 | "eval_batch_size": 4, 7 | 8 | "bert_cache_dir": "./bert", 9 | "bert_dropout": 0.0, 10 | "bert_learning_rate": 5e-05, 11 | "bert_model_name": "facebook/bart-base", 12 | "bert_weight_decay": 1e-05, 13 | 14 | 15 | "decoding_method": "beam_search", 16 | "train_file": "data/scirex/proc_output/train.json", 17 | "dev_file": "data/scirex/proc_output/dev.json", 18 | "test_file": "data/scirex/proc_output/test.json", 19 | 20 | 21 | "SOT_weights":25, 22 | "permute_slots":false, 23 | "max_position_embeddings":512, 24 | "grad_clipping": 5.0, 25 | "task":"bre", 26 | "scirex_dev_file":"data/scirex/release_data/dev.jsonl", 27 | "scirex_test_file":"data/scirex/release_data/test.jsonl", 28 | "log_path": "./log/bre_bart_gen", 29 | "k":10, 30 | "max_epoch": 150, 31 | "max_length": 1024, 32 | 33 | "num_beams":4, 34 | 35 | "output_path": "./output/bre_bart_gen", 36 | 37 | "use_copy":true, 38 | "use_SAGCopy":false, 39 | "use_extra_bert": true, 40 | "use_gpu": true, 41 | "gpu_device": 3, 42 | "warmup_epoch": 5, 43 | "weight_decay": 0.001 44 | } -------------------------------------------------------------------------------- /config/generative_model.json: -------------------------------------------------------------------------------- 1 | 2 | { 3 | "debug": true, 4 | "accumulate_step": 8, 5 | "batch_size": 2, 6 | "eval_batch_size": 2, 7 | 8 | "bert_cache_dir": "./bert", 9 | "bert_dropout": 0.0, 10 | "bert_learning_rate": 1e-05, 11 | "bert_model_name": "facebook/bart-base", 12 | "bert_weight_decay": 1e-05, 13 | 14 | 15 | "decoding_method": "greedy", 16 | "train_file": "data/muc34/proc_output/train.json", 17 | "dev_file": "data/muc34/proc_output/dev.json", 18 | "test_file": "data/muc34/proc_output/test.json", 19 | "early_stop_patient": 15, 20 | "early_stop_use": false, 21 | 22 | 23 | "SOT_weights":5, 24 | "permute_slots":true, 25 | "max_position_embeddings":1024, 26 | "grad_clipping": 5.0, 27 | "task":"ete", 28 | 
29 | 30 | "log_path": "./log/bart_gen", 31 | 32 | "max_epoch": 100, 33 | "max_length": 512, 34 | 35 | "num_beams":4, 36 | 37 | "output_path": "./output/bart_gen", 38 | 39 | 40 | "use_extra_bert": true, 41 | "use_gpu": true, 42 | "gpu_device": 2, 43 | "warmup_epoch": 5, 44 | "weight_decay": 0.001 45 | } -------------------------------------------------------------------------------- /config/ree_generative_model.json: -------------------------------------------------------------------------------- 1 | 2 | { 3 | "debug": true, 4 | "accumulate_step": 16, 5 | "batch_size": 2, 6 | "eval_batch_size": 4, 7 | 8 | "bert_cache_dir": "./bert", 9 | "bert_dropout": 0.0, 10 | "bert_learning_rate": 5e-05, 11 | "bert_model_name": "facebook/bart-base", 12 | "bert_weight_decay": 1e-05, 13 | 14 | 15 | "decoding_method": "beam_search", 16 | "train_file": "data/muc34/proc_output/ree_train.json", 17 | "dev_file": "data/muc34/proc_output/ree_dev.json", 18 | "test_file": "data/muc34/proc_output/ree_test.json", 19 | 20 | 21 | "SOT_weights":5, 22 | "permute_slots":false, 23 | "max_position_embeddings":1024, 24 | "grad_clipping": 5.0, 25 | "task":"ree", 26 | "grit_dev_file":"data/muc34/proc_output/grit_dev.json", 27 | "grit_test_file":"data/muc34/proc_output/grit_test.json", 28 | "log_path": "./log/ree_bart_gen", 29 | "k":10, 30 | "max_epoch": 150, 31 | "max_length": 512, 32 | 33 | "num_beams":4, 34 | 35 | "output_path": "./output/ree_bart_gen", 36 | "use_copy":true, 37 | "use_SAGCopy":false, 38 | 39 | "use_extra_bert": true, 40 | "use_gpu": true, 41 | "gpu_device": 0, 42 | "warmup_epoch": 5, 43 | "weight_decay": 0.001 44 | } -------------------------------------------------------------------------------- /constants.py: -------------------------------------------------------------------------------- 1 | START_OF_SLOT_NAME= '' 2 | END_OF_SLOT_NAME= '' 3 | START_OF_ENTITY= '' 4 | END_OF_ENTITY= '' 5 | START_OF_TEMPLATE='' 6 | END_OF_TEMPLATE= '' 7 | SPECIAL_TOKENS = [START_OF_SLOT_NAME, 
END_OF_SLOT_NAME, START_OF_ENTITY, END_OF_ENTITY, START_OF_TEMPLATE, END_OF_TEMPLATE] 8 | 9 | # these variables are for decoding 10 | SLOT_NAME_TAG=0 11 | ENTITY_TAG=1 12 | 13 | ROLE_FILLER_ENTITY_EXTRACTION='ree' 14 | EVENT_TEMPLATE_EXTRACTION='ete' 15 | BINARY_RELATION_EXTRACTION='bre' 16 | FOUR_ARY_RELATION_EXTRACTION='4re' 17 | 18 | PERP_IND='PerpInd' 19 | PERP_ORG='PerpOrg' 20 | TARGET='Target' 21 | VICTIM='Victim' 22 | WEAPON='Weapon' 23 | REE_ROLES = [PERP_IND, PERP_ORG, TARGET, VICTIM, WEAPON] 24 | -------------------------------------------------------------------------------- /convert_grit.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import nltk 4 | # these are for splitting doctext to sentences 5 | nltk.download('punkt') 6 | sent_tokenizer = nltk.data.load('tokenizers/punkt/english.pickle') 7 | 8 | def process_entities(entities): 9 | 10 | ''' 11 | [ 12 | [ 13 | ['guerrillas', 37], 14 | ['guerrilla column', 349] 15 | ], 16 | [ 17 | ['apple', 45] 18 | ], 19 | [ 20 | ['banana', 60] 21 | ] 22 | ] 23 | -> [['guerrillas, guerrilla column'], ['apple'], ['banana']] 24 | ''' 25 | 26 | res = [] 27 | for entity in entities: 28 | 29 | # take only the string 30 | res.append([mention[0] for mention in entity]) 31 | 32 | return res 33 | 34 | def convert(doc, capitalize=False): 35 | ''' 36 | doc: a dictionary that has the following format: 37 | 38 | {'docid': 'TST1-MUC3-0001', 39 | 'doctext': 'the guatemala army denied today that guerrillas attacked the "santo tomas" presidential farm, located on the pacific side, where president cerezo has been staying since 2 february. a report published by the "cerigua" news agency -- mouthpiece of the guatemalan national revolutionary unity (urng) -- whose main offices are in mexico, says that a guerrilla column attacked the farm 2 days ago. 
however, armed forces spokesman colonel luis arturo isaacs said that the attack, which resulted in the death of a civilian who was passing by at the time of the skirmish, was not against the farm, and that president cerezo is safe and sound. he added that on 3 february president cerezo met with the diplomatic corps accredited in guatemala. the government also issued a communique describing the rebel report as "false and incorrect," and stressing that the president was never in danger. col isaacs said that the guerrillas attacked the "la eminencia" farm located near the "santo tomas" farm, where they burned the facilities and stole food. a military patrol clashed with a rebel column and inflicted three casualties, which were taken away by the guerrillas who fled to the mountains, isaacs noted. he also reported that guerrillas killed a peasant in the city of flores, in the northern el peten department, and burned a tank truck.', 40 | 'extracts': {'PerpInd': [[['guerrillas', 37], ['guerrilla column', 349]]], 41 | 'PerpOrg': [[['guatemalan national revolutionary unity', 253], 42 | ['urng', 294]]], 43 | 'Target': [[['"santo tomas" presidential farm', 61], 44 | ['presidential farm', 75]], 45 | [['farm', 88], ['"la eminencia" farm', 947]], 46 | [['facilities', 1026]], 47 | [['tank truck', 1341], ['truck', 1346]]], 48 | 'Victim': [[['cerezo', 139]]], 49 | 'Weapon': []}} 50 | 51 | capitalize: whether to capitalize doctext or not 52 | ''' 53 | 54 | res = { 55 | 'docid': doc['docid'], 56 | 'document': doc['doctext'], # the raw text document. 57 | 'annotation': [] # A list of templates. In role-filler entity extraction, we only have one template for each don't care about this. 
58 |     }
59 |
60 |     if capitalize:
61 |         # split doctext into sentences
62 |         sentences = sent_tokenizer.tokenize(doc['doctext'])
63 |         capitalized_doctext = ' '.join([sent.capitalize() for sent in sentences])
64 |         res['document'] = capitalized_doctext
65 |
66 |     # process "\n\n" and "\n" https://github.com/xinyadu/grit_doc_event_entity/blob/master/data/muc/scripts/preprocess.py
67 |     # paragraphs = doc['doctext'].split("\n\n")
68 |     # paragraphs_no_n = []
69 |     # for para in paragraphs:
70 |     #     para = " ".join(para.split("\n"))
71 |     #     paragraphs_no_n.append(para)
72 |     # doc_text_no_n = " ".join(paragraphs_no_n)
73 |
74 |     # TODO: add "tags" in the document
75 |     # res['document'] = doc_text_no_n
76 |
77 |     annotation = doc['extracts']
78 |     for role, entities in annotation.items():
79 |         # make sure entities is not an empty list
80 |         if entities:
81 |             # make sure res['annotation'] has one dictionary
82 |             if len(res['annotation']) == 0:
83 |                 res['annotation'].append({})
84 |             res['annotation'][0][role] = process_entities(entities)
85 |
86 |     return res
87 |
88 | if __name__ == '__main__':
89 |
90 |     p = argparse.ArgumentParser("Convert GRIT input data into our format.")
91 |
92 |     p.add_argument('--input_path', type=str, help="input file in GRIT format.")
93 |     p.add_argument('--output_path', type=str, help="path to store the output json file.")
94 |     p.add_argument('--capitalize', action="store_true", help="whether to capitalize the first char of each sentence")
95 |     args = p.parse_args()
96 |
97 |     # the GRIT input holds one JSON document per line
98 |     with open(args.input_path, 'r') as f:
99 |         grit_inputs = [json.loads(l) for l in f.readlines()]
100 |
101 |     all_processed_doc = dict()
102 |
103 |     # iterate through and process all GRIT documents
104 |     for grit_doc in grit_inputs:
105 |
106 |         processed = convert(grit_doc, args.capitalize)
107 |         doc_id = processed.pop('docid')
108 |         all_processed_doc[doc_id] = processed
109 |
110 |     with open(args.output_path, 'w') as f:
111 |         f.write(json.dumps(all_processed_doc))
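To make the effect of the conversion concrete, here is a minimal, self-contained sketch of what `convert_grit.py` appears to do to a single record. It re-implements `process_entities` and `convert` in simplified form (the `capitalize` option and the `nltk` dependency are left out), and the record `grit_doc` is a shortened, hypothetical example rather than real MUC data.

```python
import json


def process_entities(entities):
    """Keep only the mention strings, dropping the character offsets."""
    return [[mention[0] for mention in entity] for entity in entities]


def convert(doc):
    """Simplified sketch of the GRIT -> in-house conversion (no capitalization)."""
    res = {'docid': doc['docid'], 'document': doc['doctext'], 'annotation': []}
    for role, entities in doc['extracts'].items():
        if entities:  # roles with no entities are skipped entirely
            if not res['annotation']:
                res['annotation'].append({})  # REE uses a single template
            res['annotation'][0][role] = process_entities(entities)
    return res


# Hypothetical, shortened GRIT-style record (offsets are character positions).
grit_doc = {
    'docid': 'DOC-0001',
    'doctext': 'guerrillas attacked the farm.',
    'extracts': {
        'PerpInd': [[['guerrillas', 0]]],
        'Target': [[['farm', 24]]],
        'Weapon': [],
    },
}

converted = convert(grit_doc)
doc_id = converted.pop('docid')  # the doc id becomes the key of the output dict
print(json.dumps({doc_id: converted}, indent=2))
```

In this sketch the empty `Weapon` role does not appear in the output at all, and each entity is reduced to a list of its mention strings, which matches the behaviour of the script above as far as it can be read from the source.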
-------------------------------------------------------------------------------- /copy_bart.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copied over from https://huggingface.co/transformers/_modules/transformers/models/bart/modeling_bart.html 3 | ''' 4 | import torch 5 | import torch.nn.functional as F 6 | from transformers import BartForConditionalGeneration, BartModel, BartConfig 7 | from transformers.modeling_outputs import Seq2SeqLMOutput 8 | class CopyBartForConditionalGeneration(BartForConditionalGeneration): 9 | 10 | def __init__(self, config: BartConfig): 11 | super().__init__(config) 12 | 13 | self.selected_heads = None 14 | 15 | def forward( 16 | self, 17 | input_ids=None, 18 | attention_mask=None, 19 | decoder_input_ids=None, 20 | decoder_attention_mask=None, 21 | head_mask=None, 22 | decoder_head_mask=None, 23 | encoder_outputs=None, 24 | past_key_values=None, 25 | inputs_embeds=None, 26 | decoder_inputs_embeds=None, 27 | labels=None, 28 | use_cache=None, 29 | output_attentions=None, 30 | output_hidden_states=None, 31 | return_dict=None, 32 | ): 33 | r""" 34 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 35 | Labels for computing the masked language modeling loss. Indices should either be in ``[0, ..., 36 | config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored 37 | (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``. 
38 | 39 | Returns: 40 | """ 41 | assert output_attentions or self.model.config.output_attentions, "output_attentions must be true" 42 | 43 | # original outputs 44 | outputs = self.model(input_ids, 45 | attention_mask=attention_mask, 46 | decoder_input_ids=decoder_input_ids, 47 | encoder_outputs=encoder_outputs, 48 | decoder_attention_mask=decoder_attention_mask, 49 | head_mask=head_mask, 50 | decoder_head_mask=decoder_head_mask, 51 | past_key_values=past_key_values, 52 | inputs_embeds=inputs_embeds, 53 | decoder_inputs_embeds=decoder_inputs_embeds, 54 | use_cache=use_cache, 55 | output_attentions=output_attentions, 56 | output_hidden_states=output_hidden_states, 57 | return_dict=return_dict,) 58 | 59 | if input_ids is None: 60 | input_ids = self._cache_input_ids 61 | 62 | # if self.selected_heads is None: 63 | # take the cross attention non-linear function 64 | cross_attention_non_linear = self.model.decoder.layers[-1].encoder_attn.out_proj.weight # (emb_dim, emb_dim) 65 | cross_attention_non_linear_sum = cross_attention_non_linear.view(self.config.decoder_attention_heads, -1).abs().sum(1) # (num_heads) 66 | _, selected_heads = torch.topk(cross_attention_non_linear_sum, k=self._k) 67 | self.selected_heads = selected_heads 68 | 69 | encoder_last_hidden_state = outputs.encoder_last_hidden_state # (batch, seq, hidden) 70 | decoder_last_hidden_state = outputs[0] #(batch, decoding_seq, hidden ) 71 | 72 | 73 | # compute lm logits based on attention 74 | last_cross_attentions = outputs.cross_attentions[-1] # (batch_size, num_heads, decoding_seq_length, encoding_seq_length). 
75 | 76 | 77 | lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias #(batch_size, decoding_seq_length, emb_dim) 78 | 79 | 80 | 81 | cross_attentions_aggregate = last_cross_attentions[:,self.selected_heads,:,:].mean(dim=1) #(batch, decoding_seq_length, encoding_seq_length) 82 | 83 | 84 | dummy_input_ids = input_ids.unsqueeze(-1).expand(-1, -1, lm_logits.size(1)).transpose(1,2) # (batch, decoding_seq_length, encoding_seq_length) 85 | copy_logits = torch.zeros_like(lm_logits) # (batch, decoding_seq_length, emb_dim) 86 | copy_logits.scatter_add_(dim=2, index=dummy_input_ids, src=cross_attentions_aggregate) 87 | 88 | 89 | p_gen = torch.bmm(decoder_last_hidden_state, encoder_last_hidden_state.mean(dim=1).unsqueeze(dim=-1)) # (batch, decoding_seq, 1) 90 | p_gen = torch.sigmoid(p_gen) 91 | 92 | 93 | lm_logits = F.softmax(lm_logits, dim=-1) * p_gen + copy_logits * (1 - p_gen)#(batch_size, decoding_seq_length, emb_dim) 94 | 95 | 96 | 97 | 98 | masked_lm_loss = None 99 | if labels is not None: 100 | # compute loss mask and fill -100 with 0 101 | loss_mask = labels != -100 102 | labels.masked_fill_(~loss_mask, 0) 103 | # use negative log likelihood 104 | gold_probs = torch.gather(lm_logits, 2, labels.unsqueeze(2)).squeeze(2) 105 | eps = 1e-7 # for safe log 106 | masked_lm_loss = - torch.log(gold_probs + eps) * self._loss_weight[labels] 107 | masked_lm_loss = (masked_lm_loss * loss_mask).mean() 108 | 109 | 110 | 111 | if not return_dict: 112 | output = (lm_logits,) + outputs[1:] 113 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 114 | 115 | 116 | return Seq2SeqLMOutput( 117 | loss=masked_lm_loss, 118 | logits=lm_logits, 119 | past_key_values=outputs.past_key_values, 120 | decoder_hidden_states=outputs.decoder_hidden_states, 121 | decoder_attentions=outputs.decoder_attentions, 122 | cross_attentions=outputs.cross_attentions, 123 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, 124 | 
encoder_hidden_states=outputs.encoder_hidden_states, 125 | encoder_attentions=outputs.encoder_attentions, 126 | ) 127 | 128 | 129 | 130 | 131 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from constants import * 3 | from collections import namedtuple 4 | from util import token2sub_tokens 5 | import json 6 | import torch 7 | 8 | instance_fields = [ 9 | 'doc_id', 'input_ids', 'attention_mask','decoder_input_chunks', 'input_tokens','document' 10 | ] 11 | 12 | batch_fields = [ 13 | 'doc_ids', 'input_ids', 'attention_masks','decoder_input_chunks', 'input_tokens','document' 14 | ] 15 | 16 | Instance = namedtuple('Instance', field_names=instance_fields, 17 | defaults=[None] * len(instance_fields)) 18 | Batch = namedtuple('Batch', field_names=batch_fields, 19 | defaults=[None] * len(batch_fields)) 20 | 21 | class IEDataset(Dataset): 22 | def __init__(self, path, max_length=128, gpu=False): 23 | """ 24 | :param path (str): path to the data file. 25 | :param max_length (int): max sentence length. 26 | :param gpu (bool): use GPU (default=False). 27 | :param ignore_title (bool): Ignore sentences that are titles (default=False). 28 | """ 29 | self.path = path 30 | self.data = [] 31 | self.max_length = max_length 32 | self.gpu = gpu 33 | 34 | self.load_data() 35 | 36 | def __len__(self): 37 | return len(self.data) 38 | 39 | def __getitem__(self, item): 40 | return self.data[item] 41 | 42 | 43 | def load_data(self): 44 | """Load data from file.""" 45 | overlength_num = title_num = 0 46 | with open(self.path, 'r', encoding='utf-8') as r: 47 | 48 | self.data = json.loads(r.read()) 49 | 50 | 51 | def create_decoder_input_chunks(self, templates, tokenizer): 52 | 53 | ''' 54 | `templates` is a list of dict. 55 | [ 56 | { 57 | MESSAGE-TEMPLATE': '1', 58 | 'INCIDENT-DATE': '28 AUG 89', 59 | ... 
60 | }, 61 | { 62 | 'MESSAGE-TEMPLATE': '2', 63 | 'INCIDENT-DATE': '- 30 AUG 89', 64 | ... 65 | } 66 | ] 67 | Parse the templates and create a chunk of ids 68 | [tokenizer.eos_token_id, [[template_1_entity_1],[template_1_entity_2], ...],[[template_2_entity_1],[template_2_entity_2],...], tokenizer.sep_token_id ] 69 | ''' 70 | 71 | 72 | # Bart uses the eos_token_id as the starting token for decoder_input_ids generation. If past_key_values is used, optionally only the last decoder_input_ids have to be input (see past_key_values) 73 | res = [] 74 | for template in templates: 75 | current_template_chunk = [] 76 | for entity_key, entity_values in template.items(): 77 | 78 | 79 | if isinstance(entity_values, list): 80 | for entity_value in entity_values: 81 | 82 | # Add " " so that the token will be the same subtoken as the input document 83 | mentions = [[START_OF_ENTITY, " " + mention.strip(" ") +" ", END_OF_ENTITY] for mention in entity_value ] 84 | entity = [] 85 | 86 | # create a chunk for each mention 87 | for mention in mentions: 88 | 89 | entity_tokens = [START_OF_SLOT_NAME, entity_key, END_OF_SLOT_NAME] + mention 90 | 91 | mention_chunk = [] 92 | for entity_token in entity_tokens: 93 | mention_chunk += token2sub_tokens(tokenizer, entity_token) 94 | entity.append(mention_chunk) 95 | current_template_chunk.append(entity) 96 | else: 97 | raise NotImplementedError 98 | 99 | res.append(current_template_chunk) 100 | 101 | 102 | return res 103 | 104 | 105 | def numberize(self, tokenizer, vocabs): 106 | """Numberize word pieces, labels, etc. 107 | :param tokenizer: Bert tokenizer. 108 | :param vocabs (dict): a dict of vocabularies. 
109 | """ 110 | 111 | 112 | data = [] 113 | for doc_id, content in self.data.items(): 114 | 115 | document = content['document'] 116 | annotation = content['annotation'] 117 | 118 | 119 | input_ids = tokenizer([document], max_length=self.max_length, truncation=True)['input_ids'][0] 120 | 121 | pad_num = self.max_length - len(input_ids) 122 | attn_mask = [1] * len(input_ids) + [0] * pad_num 123 | input_ids = input_ids + [tokenizer.pad_token_id] * pad_num 124 | 125 | 126 | decoder_input_chunks = self.create_decoder_input_chunks(annotation, tokenizer) 127 | 128 | 129 | assert len(input_ids) == self.max_length, len(input_ids) 130 | 131 | input_tokens = tokenizer.decode(input_ids) 132 | # print("decoder_input_chunks", decoder_input_chunks) 133 | instance = Instance( 134 | doc_id=doc_id, 135 | input_ids=input_ids, 136 | attention_mask=attn_mask, 137 | decoder_input_chunks=decoder_input_chunks, 138 | input_tokens=input_tokens, 139 | document=document 140 | ) 141 | data.append(instance) 142 | self.data = data 143 | 144 | def collate_fn(self, batch): 145 | batch_input_ids = [] 146 | batch_attention_masks = [] 147 | batch_decoder_input_chunks = [] 148 | batch_input_tokens = [] 149 | batch_document = [] 150 | 151 | doc_ids = [inst.doc_id for inst in batch] 152 | 153 | for inst in batch: 154 | batch_input_ids.append(inst.input_ids) 155 | batch_attention_masks.append(inst.attention_mask) 156 | batch_decoder_input_chunks.append(inst.decoder_input_chunks) 157 | batch_input_tokens.append(inst.input_tokens) 158 | batch_document.append(inst.document) 159 | 160 | if self.gpu: 161 | batch_input_ids = torch.cuda.LongTensor(batch_input_ids) 162 | batch_attention_masks = torch.cuda.FloatTensor(batch_attention_masks) 163 | 164 | else: 165 | batch_input_ids = torch.LongTensor(batch_input_ids) 166 | batch_attention_masks = torch.FloatTensor(batch_attention_masks) 167 | 168 | return Batch( 169 | doc_ids=doc_ids, 170 | input_ids=batch_input_ids, 171 | attention_masks=batch_attention_masks, 
172 | decoder_input_chunks=batch_decoder_input_chunks, 173 | input_tokens=batch_input_tokens, 174 | document=batch_document 175 | ) -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from argparse import ArgumentParser 5 | 6 | import numpy as np 7 | import tqdm 8 | import torch 9 | from torch.utils.data import DataLoader 10 | from transformers import AutoTokenizer 11 | 12 | from model import GenerativeModel 13 | from config import Config 14 | from data import IEDataset 15 | from constants import * 16 | from util import * 17 | import ree_eval 18 | import scirex_eval 19 | 20 | # configuration 21 | parser = ArgumentParser() 22 | parser.add_argument('--gpu', type=int, required=True) 23 | parser.add_argument('--checkpoint', type=str, required=True) 24 | args = parser.parse_args() 25 | 26 | use_gpu = args.gpu > -1 27 | checkpoint = torch.load(args.checkpoint, map_location=f'cuda:{args.gpu}' if use_gpu else 'cpu') 28 | config = Config.from_dict(checkpoint['config']) 29 | 30 | # set GPU device 31 | config.gpu_device = args.gpu 32 | config.use_gpu = use_gpu 33 | # fix random seed 34 | random.seed(config.seed) 35 | np.random.seed(config.seed) 36 | torch.manual_seed(config.seed) 37 | torch.backends.cudnn.enabled = False 38 | 39 | if use_gpu and config.gpu_device >= 0: 40 | torch.cuda.set_device(config.gpu_device) 41 | 42 | # datasets 43 | model_name = config.bert_model_name 44 | 45 | tokenizer = AutoTokenizer.from_pretrained(model_name, 46 | cache_dir=config.bert_cache_dir) 47 | tokenizer.add_tokens(SPECIAL_TOKENS) 48 | # special_tokens_dict = {'additional_special_tokens': SPECIAL_TOKENS} 49 | # num_added_toks = tokenizer.add_special_tokens(special_tokens_dict) 50 | 51 | 52 | # print('==============Prepare Training Set=================') 53 | # train_set = IEDataset(config.train_file, 
max_length=config.max_length, gpu=use_gpu) 54 | print('==============Prepare Dev Set=================') 55 | dev_set = IEDataset(config.dev_file, max_length=config.max_length, gpu=use_gpu) 56 | print('==============Prepare Test Set=================') 57 | test_set = IEDataset(config.test_file, max_length=config.max_length, gpu=use_gpu) 58 | vocabs = {} 59 | 60 | # print('==============Prepare Training Set=================') 61 | # train_set.numberize(tokenizer, vocabs) 62 | print('==============Prepare Dev Set=================') 63 | dev_set.numberize(tokenizer, vocabs) 64 | print('==============Prepare Test Set=================') 65 | test_set.numberize(tokenizer, vocabs) 66 | 67 | if config.task == ROLE_FILLER_ENTITY_EXTRACTION: 68 | grit_dev = read_grit_gold_file(config.grit_dev_file) 69 | grit_test = read_grit_gold_file(config.grit_test_file) 70 | elif config.task in {BINARY_RELATION_EXTRACTION, FOUR_ARY_RELATION_EXTRACTION}: 71 | scirex_dev = read_scirex_gold_file(config.scirex_dev_file) 72 | scirex_test = read_scirex_gold_file(config.scirex_test_file) 73 | 74 | 75 | dev_batch_num = len(dev_set) // config.eval_batch_size + \ 76 | (len(dev_set) % config.eval_batch_size != 0) 77 | test_batch_num = len(test_set) // config.eval_batch_size + \ 78 | (len(test_set) % config.eval_batch_size != 0) 79 | 80 | output_dir = '/'.join(args.checkpoint.split('/')[:-1]) 81 | dev_result_file = os.path.join(output_dir, 'dev.out.json') 82 | test_result_file = os.path.join(output_dir, 'test.out.json') 83 | # initialize the model 84 | 85 | model = GenerativeModel(config, vocabs) 86 | model.load_bert(model_name, cache_dir=config.bert_cache_dir, tokenizer=tokenizer) 87 | 88 | if not model_name.startswith('roberta'): 89 | model.bert.resize_token_embeddings(len(tokenizer)) 90 | 91 | model.load_state_dict(checkpoint['model'], strict=True) 92 | 93 | if use_gpu: 94 | model.cuda(device=config.gpu_device) 95 | epoch = 1000 96 | # dev set 97 | progress = tqdm.tqdm(total=dev_batch_num, 
ncols=75, 98 | desc='Dev {}'.format(epoch)) 99 | 100 | dev_gold_outputs, dev_pred_outputs, dev_input_tokens, dev_doc_ids, dev_documents = [], [], [], [], [] 101 | 102 | for batch in DataLoader(dev_set, batch_size=config.eval_batch_size, 103 | shuffle=False, collate_fn=dev_set.collate_fn): 104 | progress.update(1) 105 | outputs = model.predict(batch, tokenizer,epoch=epoch) 106 | decoder_inputs_outputs = generate_decoder_inputs_outputs(batch, tokenizer, model, use_gpu, config.max_position_embeddings, task=config.task) 107 | dev_pred_outputs.extend(outputs['decoded_ids'].tolist()) 108 | dev_gold_outputs.extend(decoder_inputs_outputs['decoder_labels'].tolist()) 109 | dev_input_tokens.extend(batch.input_tokens) 110 | dev_doc_ids.extend(batch.doc_ids) 111 | dev_documents.extend(batch.document) 112 | progress.close() 113 | 114 | dev_result = { 115 | 'pred_outputs': dev_pred_outputs, 116 | 'gold_outputs': dev_gold_outputs, 117 | 'input_tokens': dev_input_tokens, 118 | 'doc_ids': dev_doc_ids, 119 | 'documents': dev_documents 120 | } 121 | with open(dev_result_file ,'w') as f: 122 | f.write(json.dumps(dev_result)) 123 | 124 | 125 | if config.task == EVENT_TEMPLATE_EXTRACTION: 126 | dev_scores = 0 127 | elif config.task == ROLE_FILLER_ENTITY_EXTRACTION: 128 | ree_preds = construct_outputs_for_ceaf(dev_pred_outputs, dev_input_tokens, dev_doc_ids, tokenizer) 129 | dev_scores = ree_eval.ree_eval(ree_preds, grit_dev) 130 | elif config.task == BINARY_RELATION_EXTRACTION: 131 | bre_preds = construct_outputs_for_scirex(dev_pred_outputs, dev_documents, dev_doc_ids, tokenizer, task=BINARY_RELATION_EXTRACTION) 132 | dev_scores = scirex_eval.scirex_eval(bre_preds, scirex_dev, cardinality=2) 133 | elif config.task == FOUR_ARY_RELATION_EXTRACTION: 134 | bre_preds = construct_outputs_for_scirex(dev_pred_outputs, dev_documents, dev_doc_ids, tokenizer, task=FOUR_ARY_RELATION_EXTRACTION) 135 | dev_scores = scirex_eval.scirex_eval(bre_preds, scirex_dev, cardinality=4) 136 | else: 137 | raise 
NotImplementedError 138 | save_model = False 139 | 140 | 141 | 142 | 143 | # test set 144 | progress = tqdm.tqdm(total=test_batch_num, ncols=75, 145 | desc='Test {}'.format(epoch)) 146 | test_gold_outputs, test_pred_outputs, test_input_tokens, test_doc_ids, test_documents = [], [], [], [], [] 147 | test_loss = 0 148 | 149 | for batch in DataLoader(test_set, batch_size=config.eval_batch_size, shuffle=False, 150 | collate_fn=test_set.collate_fn): 151 | progress.update(1) 152 | outputs = model.predict(batch, tokenizer, epoch=epoch) 153 | decoder_inputs_outputs = generate_decoder_inputs_outputs(batch, tokenizer, model, use_gpu, config.max_position_embeddings, task=config.task) 154 | 155 | test_pred_outputs.extend(outputs['decoded_ids'].tolist()) 156 | test_gold_outputs.extend(decoder_inputs_outputs['decoder_labels'].tolist()) 157 | test_input_tokens.extend(batch.input_tokens) 158 | test_doc_ids.extend(batch.doc_ids) 159 | test_documents.extend(batch.document) 160 | progress.close() 161 | 162 | 163 | # currently use negative dev loss as validation criteria 164 | if config.task == EVENT_TEMPLATE_EXTRACTION: 165 | # TODO: call the official evaluator 166 | test_scores = 0 167 | elif config.task == ROLE_FILLER_ENTITY_EXTRACTION: 168 | ree_preds = construct_outputs_for_ceaf(test_pred_outputs, test_input_tokens, test_doc_ids, tokenizer) 169 | test_scores = ree_eval.ree_eval(ree_preds, grit_test) 170 | elif config.task == BINARY_RELATION_EXTRACTION: 171 | bre_preds = construct_outputs_for_scirex(test_pred_outputs, test_documents, test_doc_ids, tokenizer, task=BINARY_RELATION_EXTRACTION) 172 | test_scores = scirex_eval.scirex_eval(bre_preds, scirex_test, cardinality=2) 173 | elif config.task == FOUR_ARY_RELATION_EXTRACTION: 174 | bre_preds = construct_outputs_for_scirex(test_pred_outputs, test_documents, test_doc_ids, tokenizer, task=FOUR_ARY_RELATION_EXTRACTION) 175 | test_scores = scirex_eval.scirex_eval(bre_preds, scirex_test, cardinality=4) 176 | else: 177 | raise 
NotImplementedError 178 | 179 | test_result = { 180 | 'pred_outputs': test_pred_outputs, 181 | 'gold_outputs': test_gold_outputs, 182 | 'input_tokens': test_input_tokens, 183 | 'doc_ids': test_doc_ids, 184 | 'documents': test_documents 185 | } 186 | with open(test_result_file,'w') as f: 187 | f.write(json.dumps(test_result)) 188 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from transformers import AutoModelForSeq2SeqLM 5 | from transformers import BeamSearchScorer, LogitsProcessorList 6 | 7 | from copy_bart import CopyBartForConditionalGeneration 8 | from sagcopy import SAGCopyBartForConditionalGeneration 9 | from constants import * 10 | 11 | 12 | 13 | class GenerativeModel(nn.Module): 14 | def __init__(self, 15 | config, 16 | vocabs): 17 | super().__init__() 18 | 19 | 20 | # vocabularies 21 | self.vocabs = vocabs 22 | 23 | 24 | # BERT encoder 25 | bert_config = config.bert_config 26 | bert_config.output_hidden_states = True 27 | self.bert_dim = bert_config.hidden_size 28 | self.extra_bert = config.extra_bert 29 | self.use_extra_bert = config.use_extra_bert 30 | if self.use_extra_bert: 31 | self.bert_dim *= 2 32 | self.bert_config = bert_config 33 | self.bert_dropout = nn.Dropout(p=config.bert_dropout) 34 | self.max_position_embeddings = config.max_position_embeddings 35 | self.num_beams = config.num_beams 36 | self.decoding_method = config.decoding_method 37 | self.SOT_weights = config.SOT_weights 38 | self.max_length = config.max_length 39 | self.use_copy = config.use_copy 40 | self.use_SAGCopy = config.use_SAGCopy 41 | self._k = config.k 42 | # TODO: may need to tune weight for padding token 43 | 44 | # self.decoder_criteria = torch.nn.CrossEntropyLoss() 45 | 46 | def load_bert(self, name, cache_dir=None, tokenizer=None): 47 | """Load the pre-trained LM (used in training phrase) 48 | 
:param name (str): pre-trained LM name 49 | :param cache_dir (str): path to the LM cache directory 50 | """ 51 | print('Loading pre-trained LM {}'.format(name)) 52 | 53 | 54 | if self.use_copy: 55 | self.bert = CopyBartForConditionalGeneration.from_pretrained(name, cache_dir=cache_dir, output_attentions=True) 56 | self.bert._k = self._k 57 | elif self.use_SAGCopy: 58 | self.bert = SAGCopyBartForConditionalGeneration.from_pretrained(name, cache_dir=cache_dir, output_attentions=True, output_hidden_states=True) 59 | else: 60 | self.bert = AutoModelForSeq2SeqLM.from_pretrained(name, cache_dir=cache_dir) 61 | 62 | 63 | def forward(self, batch, decoder_input_ids=None, decoder_labels=None, decoder_masks=None, logger=None, tag=None, step=None, tokenizer=None): 64 | 65 | res = {} 66 | # increase weight for 67 | vocab_size = len(tokenizer) 68 | 69 | weight = torch.ones(vocab_size).to(batch.input_ids.device) 70 | self.bert._loss_weight = weight 71 | self.bert._vocab_size = vocab_size 72 | 73 | if self.use_copy or self.use_SAGCopy: 74 | bart_outputs = self.encode(batch, decoder_input_ids=decoder_input_ids, decoder_labels=decoder_labels) 75 | else: 76 | bart_outputs = self.encode(batch, decoder_input_ids=decoder_input_ids) 77 | 78 | # if labels provided, assign loss 79 | if decoder_labels is not None: 80 | 81 | if self.use_copy or self.use_SAGCopy: 82 | weight[tokenizer.convert_tokens_to_ids(START_OF_TEMPLATE)] = self.SOT_weights 83 | loss = bart_outputs.loss 84 | else: 85 | weight[tokenizer.convert_tokens_to_ids(START_OF_TEMPLATE)] = self.SOT_weights 86 | # weight[tokenizer.eos_token_id] = 0.05 87 | loss = torch.nn.functional.cross_entropy(input=bart_outputs.logits.view(-1, vocab_size), target=decoder_labels.view(-1), weight=weight) 88 | 89 | res['loss'] = loss 90 | 91 | return res 92 | 93 | def encode(self, batch, decoder_input_ids=None, decoder_labels=None, decoder_masks=None): 94 | ''' 95 | Encode the input documents 96 | ''' 97 | 98 | return 
self.bert(input_ids=batch.input_ids, 99 | attention_mask=batch.attention_masks, #1 for tokens that are not masked, 0 for tokens that are masked. 100 | decoder_input_ids=decoder_input_ids, # For translation and summarization training, decoder_input_ids should be provided. If no decoder_input_ids is provided, the model will create this tensor by shifting the input_ids to the right for denoising pre-training following the paper. 101 | labels=decoder_labels, 102 | # decoder_attention_mask=decoder_masks, #Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default. 103 | return_dict=True, 104 | output_hidden_states=True, 105 | 106 | ) 107 | 108 | def beam_search(self, batch, num_beams, decoding_length, decoder_token_masks=None): 109 | ''' 110 | Adapted from https://huggingface.co/transformers/main_classes/model.html?highlight=beamsearchscorer 111 | Do stardard beam search 112 | ''' 113 | beam_scorer = BeamSearchScorer( 114 | batch_size=batch.input_ids.size(0), 115 | max_length=decoding_length, 116 | num_beams=num_beams, 117 | device=self.bert.device, 118 | ) 119 | 120 | logits_processor = LogitsProcessorList([]) 121 | 122 | # seems that this is required if our model is a encoder-decoder architecture. 123 | model_kwargs = { 124 | "encoder_outputs": self.bert.get_encoder()(batch.input_ids.repeat_interleave(num_beams, dim=0), batch.attention_masks.repeat_interleave(num_beams, dim=0), return_dict=True), 125 | } 126 | # huggingface beamsearch workaround 127 | self.bert._cache_input_ids = batch.input_ids 128 | 129 | # create token for start decoding. 
130 | decoder_input_ids = torch.ones((num_beams * batch.input_ids.size(0), 1), device=self.bert.device, dtype=torch.long) 131 | decoder_input_ids = decoder_input_ids * self.bert.config.decoder_start_token_id 132 | 133 | decoded_ids = self.bert.beam_search(decoder_input_ids, beam_scorer, max_length=decoding_length, logits_processor=logits_processor, **model_kwargs) 134 | 135 | return decoded_ids 136 | 137 | 138 | def predict(self, batch, tokenizer, epoch=None): 139 | self.eval() 140 | 141 | 142 | 143 | with torch.no_grad(): 144 | 145 | decoding_length = self.max_position_embeddings-1 146 | # in early epochs (epoch < 10) the model generates trash, so keep decoding short 147 | if epoch is not None and epoch < 10: 148 | decoding_length = 10 149 | 150 | 151 | # only tokens present in the input document and the special tokens can be decoded. 152 | # (batch, num_tokens) 153 | decoder_token_masks = torch.zeros(batch.input_ids.size(0), len(tokenizer) ,device=batch.input_ids.device, dtype=torch.bool) 154 | 155 | for batch_idx, input_ids in enumerate(batch.input_ids): 156 | decoder_token_masks[batch_idx, input_ids] = 1 157 | 158 | # TODO: these can be cached in the __init__ function so we don't need to do it repeatedly. 
159 | decoder_token_masks[:, tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)] = 1 160 | decoder_token_masks[:, tokenizer.eos_token_id] = 1 161 | decoder_token_masks[:, tokenizer.bos_token_id] = 1 162 | for role in REE_ROLES: 163 | decoder_token_masks[:, tokenizer.encode(role, add_special_tokens=False)] = 1 164 | 165 | if self.decoding_method == 'greedy': 166 | # Adapted part of the code from https://huggingface.co/blog/encoder-decoder 167 | decoded_ids = torch.LongTensor([[self.bert.config.decoder_start_token_id] * len(batch.input_ids)]).to(batch.input_ids.device).reshape(-1,1) 168 | 169 | # pass input_ids to encoder and to decoder and pass BOS token to decoder to retrieve first logit 170 | bart_outputs = self.bert(batch.input_ids, attention_mask=batch.attention_masks, decoder_input_ids=decoded_ids, use_cache=True, return_dict=True) 171 | 172 | # encode encoder input_ids once 173 | encoded_sequence = (bart_outputs.encoder_last_hidden_state,) 174 | 175 | 176 | # get next token id and append it to decoded list 177 | lm_logits = bart_outputs.logits 178 | next_decoder_input_ids = torch.argmax(lm_logits[:, -1:], axis=-1) 179 | decoded_ids = torch.cat([decoded_ids, next_decoder_input_ids], axis=-1) 180 | 181 | # use past_key_values to speed up decoding 182 | past_key_values = bart_outputs.past_key_values 183 | 184 | 185 | # only those token present in the input document and the special tokens can be decoded. 
186 | 187 | for i in range(decoding_length): 188 | 189 | bart_outputs = self.bert(batch.input_ids, encoder_outputs=encoded_sequence, past_key_values=past_key_values, decoder_input_ids=next_decoder_input_ids, use_cache=True, return_dict=True) 190 | lm_logits = bart_outputs.logits 191 | 192 | # TODO: this is incorrect, will implement in the future if necessary 193 | # lm_logits[:,-1] = lm_logits[:,-1] * decoder_token_masks 194 | past_key_values = bart_outputs.past_key_values 195 | 196 | # pick the highest-probability token again 197 | next_decoder_input_ids = torch.argmax(lm_logits[:, -1:], axis=-1) 198 | # concat again 199 | decoded_ids = torch.cat([decoded_ids, next_decoder_input_ids], axis=-1) 200 | 201 | if torch.all(next_decoder_input_ids == tokenizer.eos_token_id): 202 | break 203 | # decoded_ids = self.bert.generate(input_ids=batch.input_ids, attention_mask=batch.attention_masks, max_length=100) 204 | elif self.decoding_method == "beam_search": 205 | decoded_ids = self.beam_search(batch, num_beams=4, decoding_length=decoding_length, decoder_token_masks=decoder_token_masks) 206 | else: 207 | raise NotImplementedError 208 | res = { 209 | 'decoded_ids':decoded_ids 210 | } 211 | 212 | 213 | self.train() 214 | return res 215 | 216 | -------------------------------------------------------------------------------- /raw_scripts/README.md: -------------------------------------------------------------------------------- 1 | # Proc scripts for the raw MUC 2 | 3 | The pre-processing is composed of two parts: 4 | 5 | 1. Process documents: This step loops over the corpus and generates a list of JSON records. Each record corresponds to a document. The script is adapted from https://github.com/xinyadu/grit_doc_event_entity/tree/master/data/muc/raw_files/raw_scripts. 6 | 7 | ``` 8 | bash go_proc_doc.sh 9 | ``` 10 | 11 | 12 | 2. 
Process annotations: This step gathers all the annotations and the documents processed in the first step to create a JSON file for each split: `train.json`, `dev.json` and `test.json`. 13 | 14 | ``` 15 | bash process_all_keys.sh 16 | ``` 17 | 18 | The generated JSON file has the following format: 19 | 20 | ``` 21 | { 22 | "doc_id":{ 23 | "document": "...", 24 | "annotation": [ 25 | { 26 | 'MESSAGE-TEMPLATE': '1', 27 | 'INCIDENT-DATE': '28 AUG 89', 28 | ... 29 | }, 30 | { 31 | 'MESSAGE-TEMPLATE': '2', 32 | 'INCIDENT-DATE': '- 30 AUG 89', 33 | ... 34 | } 35 | ] 36 | 37 | }, 38 | ... 39 | } 40 | 41 | ``` 42 | 43 | __Note that you need Python 2.7 for the first step and Python 3.6 for the second!!__ Will fix this issue when I have time. 44 | 45 | -------------------------------------------------------------------------------- /raw_scripts/go_proc_doc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # mkdir -p proc_output 3 | set -eu 4 | data_dir=../data/muc34 5 | output_dir=${data_dir}/proc_output 6 | train_output_path=${output_dir}/doc_train 7 | dev_output_path=${output_dir}/doc_dev 8 | test_output_path=${output_dir}/doc_test 9 | 10 | if [ ! 
-d ${output_dir} ]; then 11 | mkdir ${output_dir} 12 | fi 13 | 14 | # train 15 | cat ../data/muc34/TASK/CORPORA/dev/dev-* | python proc_texts.py > ${train_output_path} 16 | 17 | # dev 18 | cat ../data/muc34/TASK/CORPORA/tst1/tst1-muc3 | python proc_texts.py > ${dev_output_path} 19 | cat ../data/muc34/TASK/CORPORA/tst2/tst2-muc4 | python proc_texts.py >> ${dev_output_path} 20 | 21 | # test 22 | cat ../data/muc34/TASK/CORPORA/tst3/tst3-muc4 | python proc_texts.py > ${test_output_path} 23 | cat ../data/muc34/TASK/CORPORA/tst4/tst4-muc4 | python proc_texts.py >> ${test_output_path} 24 | 25 | -------------------------------------------------------------------------------- /raw_scripts/go_proc_keys.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # mkdir -p proc_output 3 | 4 | set -eu 5 | 6 | # function compose_html() { 7 | # tag=$1 8 | # table=$2 9 | # err=$3 10 | # ( 11 | # echo "$tag keys" 12 | # if [[ $(cat $err | wc -c) -gt 0 ]]; then 13 | # echo "Warnings during processing" 14 | # echo "" 15 | # cat $err 16 | # echo "" 17 | # fi 18 | # echo "Keys: Left=original, Right=processed (JSON format)
" 19 | # cat $table 20 | # ) > proc_output/keys_${tag}.html 21 | # } 22 | 23 | 24 | # cat data/TASK/CORPORA/dev/key-dev-* | python scripts/proc_keys.py --format sidebyside 1>out.html 2>err.log 25 | # compose_html dev out.html err.log 26 | 27 | # cat data/TASK/CORPORA/testsets/key-tst* | python scripts/proc_keys.py --format sidebyside 1>out.html 2>err.log 28 | # compose_html tst out.html err.log 29 | 30 | # train 31 | cat ../muc34/TASK/CORPORA/dev/key-dev-* | python proc_keys.py > ../proc_output/keys_train 32 | 33 | # dev 34 | cat ../muc34/TASK/CORPORA/tst1/key-tst* | python proc_keys.py > ../proc_output/keys_dev 35 | cat ../muc34/TASK/CORPORA/tst2/key-tst* | python proc_keys.py >> ../proc_output/keys_dev 36 | 37 | # test 38 | cat ../muc34/TASK/CORPORA/tst3/key-tst* | python proc_keys.py > ../proc_output/keys_test 39 | cat ../muc34/TASK/CORPORA/tst4/key-tst* | python proc_keys.py >> ../proc_output/keys_test 40 | 41 | -------------------------------------------------------------------------------- /raw_scripts/proc_keys.py: -------------------------------------------------------------------------------- 1 | """ 2 | Process the crazy-ass MUC keyfile format into a reasonable JSON thing 3 | cat data/TASK/CORPORA/dev/key-dev-0* | python proc_keys.py 4 | 5 | for unit tests (http://pytest.org/) 6 | py.test proc_keys.py 7 | """ 8 | import sys,re 9 | import json 10 | from pprint import pprint 11 | 12 | def cleankey(keystr): 13 | return re.sub(r'[^A-Z]+', '_', keystr).strip('_').lower() 14 | 15 | def clean_docid(value): 16 | return re.sub(r'\s*\(.*$','', value) 17 | 18 | ALL_KEYS = """ 19 | MESSAGE: ID 20 | MESSAGE: TEMPLATE 21 | INCIDENT: DATE 22 | INCIDENT: LOCATION 23 | INCIDENT: TYPE 24 | INCIDENT: STAGE OF EXECUTION 25 | INCIDENT: INSTRUMENT ID 26 | INCIDENT: INSTRUMENT TYPE 27 | PERP: INCIDENT CATEGORY 28 | PERP: INDIVIDUAL ID 29 | PERP: ORGANIZATION ID 30 | PERP: ORGANIZATION CONFIDENCE 31 | PHYS TGT: ID 32 | PHYS TGT: TYPE 33 | PHYS TGT: NUMBER 34 | PHYS TGT: FOREIGN 
NATION 35 | PHYS TGT: EFFECT OF INCIDENT 36 | PHYS TGT: TOTAL NUMBER 37 | HUM TGT: NAME 38 | HUM TGT: DESCRIPTION 39 | HUM TGT: TYPE 40 | HUM TGT: NUMBER 41 | HUM TGT: FOREIGN NATION 42 | HUM TGT: EFFECT OF INCIDENT 43 | HUM TGT: TOTAL NUMBER 44 | """.strip().split('\n') 45 | 46 | ALL_KEYS = set(cleankey(k) for k in ALL_KEYS) 47 | 48 | KEY_WHITELIST = """ 49 | perp_individual_id 50 | perp_organization_id 51 | phys_tgt_id 52 | hum_tgt_name 53 | hum_tgt_description 54 | incident_instrument_id 55 | """.split() 56 | 57 | KEY_WHITELIST = set(KEY_WHITELIST) 58 | 59 | assert KEY_WHITELIST <= ALL_KEYS 60 | 61 | cur_docid = None 62 | def warning(s): 63 | global cur_docid 64 | print>>sys.stderr, "WARNING docid=%s | %s" % (cur_docid, s) 65 | 66 | def yield_keyvals(chunk): 67 | """ 68 | Processes the raw MUC "key file" format. Parses one entry ("chunk"). 69 | Yields a sequence of (key,value) pairs. 70 | A single key can be repeated many times. 71 | This function cleans up key names, but passes the values through as-is. 72 | """ 73 | curkey = None 74 | for line in chunk.split('\n'): 75 | if line.startswith(';'): 76 | yield 'comment', line 77 | continue 78 | middle = 33 ## Different in dev vs test files... this is the minimum size to get all keys. 79 | keytext = line[:middle].strip() 80 | valtext = line[middle:].strip() 81 | if not keytext: 82 | ## it's a continuation 83 | assert curkey 84 | else: 85 | curkey = cleankey(keytext) 86 | assert curkey in ALL_KEYS 87 | 88 | yield curkey, valtext 89 | 90 | def parse_values(keyvals): 91 | """ 92 | Takes key,value pairs as input, where the values are unparsed. 93 | Filter down to the slots we want, and parse their values as well. 
94 | """ 95 | for key,value in keyvals: 96 | if key=='message_id': 97 | yield key, clean_docid(value) 98 | continue 99 | if key=='message_template': 100 | if re.search(r'^\d+$', value): 101 | yield key, int(value) 102 | elif value == '*': 103 | yield key, value 104 | elif re.search(r'^\d+ \(OPTIONAL\)$', value): 105 | yield key, int(value.split()[0]) 106 | yield 'message_template_optional', True 107 | else: 108 | assert False, "bad message_template format" 109 | continue 110 | 111 | # if key in KEY_WHITELIST: 112 | if True: # accecpt all values for now 113 | if value == '*': 114 | continue 115 | 116 | if value == '-': 117 | yield key, None 118 | continue 119 | 120 | if '"' not in value: 121 | warning("apparent data error, missing quotes. adding back in. value was ||| %s" % value) 122 | value = '"' + value + '"' 123 | 124 | value = parse_one_value(value) 125 | yield key,value 126 | 127 | def parse_one_value(namestr): 128 | """ 129 | Returns a dictionary with 'type' either 130 | 'simple_strings' ==> has a field 'strings' 131 | 'colon_clause' ==> has two fields 'strings_lhs' and 'strings_rhs' 132 | Furthermore, has 'optional':true if this valueline is optional, which I think means the entity is optional. 133 | (There is only one example of a colon clause having optional=true; I suspect it's an annotation error.) 134 | """ 135 | 136 | global cur_docid 137 | # Fix bugs in the data 138 | if cur_docid == "DEV-MUC3-0604" and "BODYGUARD OF EL ESPECTADOR" in namestr: 139 | # DEV-MUC3-0604 (MDESC) 140 | # ? ("BODYGUARD OF EL ESPECTADOR'S CHIEF OF DISTRIBUTION IN MEDELLIN" / "BODYGUARD"): "PEDRO LUIS OSORIO" 141 | namestr = '''? "BODYGUARD OF EL ESPECTADOR'S CHIEF OF DISTRIBUTION IN MEDELLIN" / "BODYGUARD" / "PEDRO LUIS OSORIO"''' 142 | if namestr == 'MACHINEGUNS"': 143 | # DEV-MUC3-0217 144 | namestr = '"' + namestr 145 | 146 | d = {} 147 | match = re.search(r'\? 
*(.*)', namestr) 148 | if match: 149 | d['optional'] = True 150 | namestr = match.group(1) 151 | 152 | if ':' in namestr: 153 | 154 | print>>sys.stderr, namestr 155 | assert len(re.findall(':', namestr))==1 156 | lhs,rhs = re.split(r' *: *', namestr) 157 | lhs_value = parse_strings_possibly_with_alternations(lhs) 158 | rhs_value = parse_strings_possibly_with_alternations(rhs) 159 | d.update({'type':'colon_clause', 'strings_lhs': lhs_value, 'strings_rhs': rhs_value}) 160 | return d 161 | 162 | else: 163 | strings = parse_strings_possibly_with_alternations(namestr) 164 | d.update({'type':'simple_strings', 'strings': strings}) 165 | return d 166 | 167 | def parse_strings_possibly_with_alternations(namestr): 168 | namestr = namestr.strip() 169 | assert ':' not in namestr, namestr 170 | assert not namestr.startswith('?') 171 | parts = re.split(' */ *', namestr) 172 | parts = [ss.strip() for ss in parts] 173 | strings = [] 174 | for ss in parts: 175 | if ss == '-': 176 | # We should see this only inside a colon clause. There are a few of these, e.g. 177 | # 21. HUM TGT: NUMBER -: "ORLANDO LETELIER" 178 | strings.append(None) 179 | continue 180 | if not (ss[0]=='"' and ss[-1]=='"'): 181 | warning("WTF ||| " + ss) 182 | ss = ss[1:-1] 183 | ss = ss.decode('string_escape') # They seem to use C-style backslash escaping 184 | ss = ss.strip() 185 | strings.append(ss) 186 | return strings 187 | 188 | def test_parsestrings(): 189 | f = parse_strings_possibly_with_alternations 190 | s = '"CAR DEALERSHIP"' 191 | assert set(f(s)) == {"CAR DEALERSHIP"} 192 | s = '"TUPAC AMARU REVOLUTIONARY MOVEMENT" / "MRTA"' 193 | assert set(f(s)) == {"TUPAC AMARU REVOLUTIONARY MOVEMENT","MRTA"} 194 | 195 | def test_parse_one_value(): 196 | s = '"U.S. JOURNALIST": "BERNARDETTE PARDO"' 197 | d = parse_one_value(s) 198 | assert d['strings_lhs'] == ["U.S. 
JOURNALIST"] 199 | assert d['strings_rhs'] == ["BERNARDETTE PARDO"] 200 | 201 | def fancy_json_print(keyvals): 202 | lines = [ json.dumps(kv, sort_keys=True) for kv in keyvals ] 203 | s = "" 204 | s += "[\n " 205 | s += ",\n ".join(lines) 206 | s += "\n]" 207 | return s 208 | 209 | if __name__=='__main__': 210 | 211 | import argparse 212 | p = argparse.ArgumentParser() 213 | p.add_argument('--format', default='jsonpp', choices=['jsonpp','sidebyside']) 214 | args = p.parse_args() 215 | 216 | if args.format=='sidebyside': 217 | print """ 218 | 219 | """ 220 | print "" 221 | 222 | 223 | data = sys.stdin.read() 224 | lines = data.split('\n') 225 | lines = [L for L in lines if not re.search(r'^\s*;', L)] ## comments 226 | data = '\n'.join(lines) 227 | chunks = re.split(r'\n\n+|\n(?=0\. )', data) 228 | chunks = [c.strip() for c in chunks if c.strip()] 229 | 230 | for chunk in chunks: 231 | #print "==="; print chunk 232 | 233 | keyvals1 = list(yield_keyvals(chunk)) 234 | assert all(k in ALL_KEYS or k=='comment' for k,v in keyvals1) 235 | cur_docid = clean_docid(dict(keyvals1)['message_id']) 236 | #print "===", cur_docid 237 | #print "--- raw key/val pairs"; pprint(keyvals1); print 238 | 239 | keyvals2 = list(parse_values(keyvals1)) 240 | #print "--- parsed values"; pprint(keyvals2); print 241 | 242 | if args.format == 'jsonpp': 243 | print "%%%" 244 | print fancy_json_print(keyvals2) 245 | 246 | elif args.format == 'sidebyside': 247 | print "
{chunk}
{json}
".format(chunk=chunk, json=fancy_json_print(keyvals2)) 248 | 249 | else: assert False 250 | 251 | -------------------------------------------------------------------------------- /raw_scripts/proc_texts.py: -------------------------------------------------------------------------------- 1 | """ 2 | Split and clean up MUC article text files into JSON document objects 3 | """ 4 | import sys, re, json 5 | 6 | data = sys.stdin.read() 7 | matches = list(re.finditer(r'(DEV-\S+) *\(([^\)]*)\)', data)) 8 | has_source = bool(matches) 9 | if not matches: 10 | matches = list(re.finditer(r'(TST\d+-\S+)', data)) 11 | #print matches 12 | 13 | doc_infos = [] 14 | 15 | for match in matches: 16 | docid = match.group(1) 17 | d = {'docid':docid, 'char_start':match.end(), 'char_before':match.start()} 18 | if has_source: 19 | d['source'] = match.group(2) 20 | doc_infos.append(d) 21 | 22 | for i in range(len(doc_infos)-1): 23 | doc_infos[i]['char_end'] = doc_infos[i+1]['char_before'] 24 | doc_infos[-1]['char_end'] = len(data) 25 | 26 | #from pprint import pprint 27 | #pprint(doc_infos[:5]) 28 | 29 | for d in doc_infos: 30 | raw_text = data[d['char_start']:d['char_end']].strip() 31 | 32 | # issue: there are sometimes recursive (multiple?) datelines. we only get the first in that case. 
33 | 34 | tag_re = r'\[[^\]]+\]' 35 | tags_re= '(?:%s\s+)+' % tag_re 36 | full_re = r'^(.*?)--\s+(%s)(.*)' % tags_re 37 | m = re.search(full_re, raw_text, re.DOTALL) 38 | if not m: 39 | print raw_text[:1000] 40 | assert False 41 | 42 | dateline = m.group(1).replace('\n',' ').strip() 43 | tags = m.group(2).replace('\n',' ') 44 | text = m.group(3) 45 | 46 | #print dateline 47 | #print tags 48 | #print text[:50].replace('\n',' ') 49 | #print raw_text[:500] 50 | 51 | assert tags.upper() == tags 52 | tags = re.findall(tag_re, tags) 53 | tags = [x.lstrip('[').rstrip(']').lower() for x in tags] 54 | 55 | d['dateline'] = dateline 56 | d['tags'] = tags 57 | 58 | text = text.strip() 59 | text = text.replace('[','(').replace(']',')') ## should be easier for WSJPTB parsers, right...? 60 | #print text 61 | 62 | d['text'] = text 63 | 64 | print json.dumps(d) 65 | 66 | 67 | -------------------------------------------------------------------------------- /raw_scripts/process_all_keys.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from glob import glob 4 | import re 5 | from collections import defaultdict 6 | 7 | def clean_docid(value): 8 | return re.sub(r'\s*\(.*$','', value) 9 | 10 | def cleankey(key): 11 | return re.sub(r'[^A-Z|\s]+', '', key).strip().replace(' ','-') 12 | 13 | ALL_KEYS = """ 14 | MESSAGE: ID 15 | MESSAGE: TEMPLATE 16 | INCIDENT: DATE 17 | INCIDENT: LOCATION 18 | INCIDENT: TYPE 19 | INCIDENT: STAGE OF EXECUTION 20 | INCIDENT: INSTRUMENT ID 21 | INCIDENT: INSTRUMENT TYPE 22 | PERP: INCIDENT CATEGORY 23 | PERP: INDIVIDUAL ID 24 | PERP: ORGANIZATION ID 25 | PERP: ORGANIZATION CONFIDENCE 26 | PHYS TGT: ID 27 | PHYS TGT: TYPE 28 | PHYS TGT: NUMBER 29 | PHYS TGT: FOREIGN NATION 30 | PHYS TGT: EFFECT OF INCIDENT 31 | PHYS TGT: TOTAL NUMBER 32 | HUM TGT: NAME 33 | HUM TGT: DESCRIPTION 34 | HUM TGT: TYPE 35 | HUM TGT: NUMBER 36 | HUM TGT: FOREIGN NATION 37 | HUM TGT: EFFECT OF INCIDENT 38 | HUM 
TGT: TOTAL NUMBER 39 | """.strip().split('\n') 40 | 41 | ALL_KEYS = set(cleankey(key) for key in ALL_KEYS) 42 | SINGLE_VALUE_KEYS = set(cleankey(key) for key in ['MESSAGE: ID','MESSAGE: TEMPLATE','INCIDENT: DATE','INCIDENT: LOCATION','INCIDENT: TYPE','INCIDENT: STAGE OF EXECUTION','PERP: INCIDENT CATEGORY','PHYS TGT: TOTAL NUMBER']) 43 | 44 | def gather_key_vals(chunk): 45 | """ 46 | Processes the raw MUC "key file" format. Parses one entry ("chunk"). 47 | Returns a dictionary of key-values. 48 | A single key can be repeated many times. 49 | This function cleans up key names, but passes the values through as-is. 50 | 51 | """ 52 | res = defaultdict(list) 53 | curkey = None 54 | for line in chunk.split('\n'): 55 | if line.startswith(';'): 56 | 57 | continue 58 | middle = 33 ## Different in dev vs test files... this is the minimum size to get all keys. 59 | keytext = line[:middle].strip() 60 | valtext = line[middle:].strip() 61 | if not keytext: 62 | ## it's a continuation 63 | assert curkey 64 | else: 65 | curkey = cleankey(keytext) 66 | assert curkey in ALL_KEYS, (curkey, line) 67 | 68 | # if it's message_id then clean value 69 | if curkey == cleankey('MESSAGE: ID'): 70 | valtext = clean_docid(valtext) 71 | 72 | 73 | # elif curkey == cleankey('MESSAGE: TEMPLATE') and '(OPTIONAL)' in valtext: 74 | # valtext = valtext.replace('(OPTIONAL)','').strip(' ') 75 | 76 | # do not append empty vals 77 | if valtext not in ['*','-']: 78 | if curkey in SINGLE_VALUE_KEYS: 79 | res[curkey] = valtext 80 | else: 81 | res[curkey].append([val.strip() for val in valtext.split('/')]) 82 | 83 | return res 84 | 85 | def combine(new_dict, all_dict): 86 | doc_id = new_dict[cleankey('MESSAGE: ID')] 87 | # remove id from message 88 | new_dict.pop(cleankey('MESSAGE: ID')) 89 | # nothing in the dictionary 90 | if len(new_dict) == 0: 91 | # put empty list 92 | all_dict[doc_id] = [] 93 | else: 94 | all_dict[doc_id].append(new_dict) 95 | return all_dict 96 | 97 | def parse_key_file(key_file): 98 | 
with open(key_file, 'r') as f: 99 | key_lines = f.readlines() 100 | 101 | key_lines = [L.strip('\n') for L in key_lines if not re.search(r'^\s*;', L)] ## comments 102 | 103 | data = '\n'.join(key_lines) 104 | 105 | # each chunk corresponds to a template annotation 106 | chunks = re.split(r'\n\n+|\n(?=0\. )', data) 107 | chunks = [c.strip() for c in chunks if c.strip()] 108 | 109 | # key by doc_id 110 | all_chunk_dict = defaultdict(list) 111 | 112 | 113 | for chunk in chunks: 114 | keyvals = gather_key_vals(chunk) 115 | all_chunk_dict = combine(keyvals, all_chunk_dict) 116 | 117 | return all_chunk_dict 118 | 119 | def add_documents(all_key_file_dict, input_corpus): 120 | ''' 121 | Attach the source document text to each document's annotations. 122 | ''' 123 | 124 | res = {} 125 | 126 | with open(input_corpus, 'r') as f: 127 | documents = [json.loads(l) for l in f.readlines()] 128 | 129 | for doc_id, annotation in all_key_file_dict.items(): 130 | 131 | input_document = [ '\n'.join([document['dateline']]+ document['tags'] + [document['text']]) for document in documents if document['docid'] == doc_id][0] 132 | 133 | res[doc_id] = { 134 | 'document': input_document, 135 | 'annotation': annotation 136 | } 137 | 138 | return res 139 | 140 | if __name__ == '__main__': 141 | p = argparse.ArgumentParser() 142 | p.add_argument('--input_corpus', type=str, help="corpus file containing MUC documents, one JSON object per line") 143 | p.add_argument('--input_pattern', type=str, help="input pattern prefix that will be passed to glob to fetch key file names.") 144 | p.add_argument('--output_path',type=str, help="path to store the output json file.") 145 | args = p.parse_args() 146 | 147 | key_files = glob(f"{args.input_pattern}*") 148 | all_key_file_dict = {} 149 | for key_file in key_files: 150 | one_key_file_dict = parse_key_file(key_file) 151 | all_key_file_dict.update(one_key_file_dict) 152 | # attach document 153 | all_key_file_dict = add_documents(all_key_file_dict, args.input_corpus) 154 | 155 | with
open(args.output_path, 'w') as f: 156 | f.write(json.dumps(all_key_file_dict)) -------------------------------------------------------------------------------- /raw_scripts/process_all_keys.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # mkdir -p proc_output 3 | 4 | set -eu 5 | 6 | data_dir=../data/muc34 7 | output_dir=${data_dir}/proc_output 8 | train_output_path=${output_dir}/train.json 9 | dev_output_path=${output_dir}/dev.json 10 | test_output_path=${output_dir}/test.json 11 | 12 | if [ ! -d ${output_dir} ]; then 13 | mkdir ${output_dir} 14 | fi 15 | 16 | 17 | python process_all_keys.py --input_corpus ${output_dir}/doc_train --input_pattern ../data/muc34/TASK/CORPORA/dev/key-dev- --output_path ${train_output_path} 18 | python process_all_keys.py --input_corpus ${output_dir}/doc_dev --input_pattern ../data/muc34/TASK/CORPORA/tst[12]/key-tst --output_path ${dev_output_path} 19 | python process_all_keys.py --input_corpus ${output_dir}/doc_test --input_pattern ../data/muc34/TASK/CORPORA/tst[34]/key-tst --output_path ${test_output_path} 20 | -------------------------------------------------------------------------------- /raw_scripts/process_scirex.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any, List 2 | import argparse 3 | import json 4 | 5 | from itertools import combinations 6 | 7 | # copied over from https://github.com/allenai/SciREX/blob/master/scirex_utilities/entity_utils.py 8 | available_entity_types_sciERC = ["Material", "Metric", "Task", "Generic", "OtherScientificTerm", "Method"] 9 | map_available_entity_to_true = {"Material": "dataset", "Metric": "metric", "Task": "task", "Method": "model_name"} 10 | map_true_entity_to_available = {v: k for k, v in map_available_entity_to_true.items()} 11 | used_entities = list(map_available_entity_to_true.keys()) 12 | 13 | def generate_relations(doc: Dict[str, Any], cardinality) -> 
List[Dict]: 14 | ''' 15 | Break down 4-ary relations into binary relations. 16 | ''' 17 | def get_mentions(clusters, doc_tokens, entity_name): 18 | res = [] 19 | cluster = clusters[entity_name] 20 | # TODO: set does not preserve order so currently I use for loop to keep the order of mentions while removing duplicates. 21 | for start_tok, end_tok in cluster: 22 | mention = ' '.join(doc_tokens[start_tok:end_tok]) 23 | if mention not in res: 24 | res.append(mention) 25 | return res 26 | 27 | res = [] 28 | for types in combinations(used_entities, cardinality): 29 | relations = [tuple((t, x[t]) for t in types) for x in doc['n_ary_relations']] 30 | 31 | # make sure each entity has at least one cluster and make (entity_1, entity_2, relation) unique 32 | relations = set([x for x in relations if has_all_mentions(doc, x)]) 33 | 34 | 35 | for relation in relations: 36 | current_relation_dict = {} 37 | for entity in relation: 38 | entity_type, entity_name = entity 39 | entity_mentions = get_mentions(doc['coref'], doc['words'], entity_name) 40 | current_relation_dict[entity_type] = [entity_mentions] # we need to make it a list of list to comply with the convention in data.py 41 | 42 | res.append(current_relation_dict) 43 | return res 44 | 45 | 46 | def has_all_mentions(doc: Dict[str, Any], relation): 47 | ''' 48 | 49 | Make sure each entity has at least one mention. 
50 | ''' 51 | has_mentions = all(len(doc["coref"][x[1]]) > 0 for x in relation) 52 | return has_mentions 53 | 54 | 55 | def tokens_to_string(tokens): 56 | return ' '.join(tokens) 57 | 58 | def process_document(doc: Dict[str, Any], cardinality: int) -> Dict[str, Any]: 59 | 60 | assert cardinality in {2, 4}, "Only support binary and 4-ary relations" 61 | 62 | relations = generate_relations(doc, cardinality) 63 | 64 | 65 | 66 | doctext = tokens_to_string(doc['words']) 67 | return { 68 | 'doc_id': doc['doc_id'], 69 | 'document': doctext, 70 | 'annotation': relations, 71 | 72 | } 73 | 74 | def process_file(input_path: str, output_path: str, cardinality:int): 75 | 76 | with open(input_path, 'r') as f: 77 | input_data = [json.loads(l) for l in f.readlines()] 78 | 79 | processed_docs = {} 80 | for input_doc in input_data: 81 | processed_data = process_document(input_doc, cardinality) 82 | doc_id = processed_data.pop('doc_id') 83 | 84 | processed_docs[doc_id] = processed_data 85 | 86 | # store in json format 87 | with open(output_path, 'w') as f: 88 | json.dump(processed_docs, f) 89 | 90 | 91 | if __name__ == '__main__': 92 | p = argparse.ArgumentParser() 93 | 94 | p.add_argument('--input_path', type=str) 95 | p.add_argument('--output_path',type=str) 96 | p.add_argument('--cardinality',type=int) 97 | args = p.parse_args() 98 | 99 | process_file(args.input_path, args.output_path, args.cardinality) -------------------------------------------------------------------------------- /raw_scripts/process_scirex.sh: -------------------------------------------------------------------------------- 1 | data_dir=../data/scirex 2 | raw_dir=${data_dir}/release_data 3 | output_dir=${data_dir}/proc_output 4 | train_output_path=${output_dir}/train.json 5 | dev_output_path=${output_dir}/dev.json 6 | test_output_path=${output_dir}/test.json 7 | 8 | if [ ! 
-d ${output_dir} ]; then 9 | mkdir ${output_dir} 10 | fi 11 | 12 | 13 | python process_scirex.py --input_path $raw_dir/train.jsonl --output_path $train_output_path --cardinality 2 14 | python process_scirex.py --input_path $raw_dir/dev.jsonl --output_path $dev_output_path --cardinality 2 15 | python process_scirex.py --input_path $raw_dir/test.jsonl --output_path $test_output_path --cardinality 2 16 | 17 | python process_scirex.py --input_path $raw_dir/train.jsonl --output_path ${output_dir}/train_4ary.json --cardinality 4 18 | python process_scirex.py --input_path $raw_dir/dev.jsonl --output_path ${output_dir}/dev_4ary.json --cardinality 4 19 | python process_scirex.py --input_path $raw_dir/test.jsonl --output_path ${output_dir}/test_4ary.json --cardinality 4 -------------------------------------------------------------------------------- /ree_eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import re 3 | import string 4 | import json 5 | import argparse 6 | from scipy.optimize import linear_sum_assignment # https://docs.scipy.org/doc/scipy-0.18.1/reference/generated/scipy.optimize.linear_sum_assignment.html 7 | from collections import OrderedDict 8 | tag2role = OrderedDict({'perp_individual_id': "PerpInd", 'perp_organization_id': "PerpOrg", 'phys_tgt_id': "Target", 'hum_tgt_name': "Victim", 'incident_instrument_id': "Weapon"}) 9 | 10 | 11 | 12 | def f1(p_num, p_den, r_num, r_den, beta=1): 13 | p = 0 if p_den == 0 else p_num / float(p_den) 14 | r = 0 if r_den == 0 else r_num / float(r_den) 15 | return 0 if p + r == 0 else (1 + beta * beta) * p * r / (beta * beta * p + r) 16 | 17 | 18 | def phi_strict(c1, c2): 19 | # similarity: if c2 (pred) is subset of c1 (gold) return 1 20 | for m in c2: 21 | if m not in c1: 22 | return 0 23 | return 1 24 | 25 | 26 | def phi_prop(c1, c2): 27 | # # similarity: len(overlap of c2 (pred) and c1 (gold)) / len(c2) 28 | return len([m for m in c1 if m in c2]) / len(c2) 
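# Editorial sketch (not part of ree_eval.py): the two similarity functions above can be compared on toy clusters. The mention strings and the names `gold`, `pred_subset` and `pred_mixed` are made-up illustrations, and both functions are restated here only so the snippet runs on its own. Since phi_prop divides by len(c2), an empty predicted cluster would presumably raise ZeroDivisionError, so callers may want to filter those out first.

```python
def phi_strict(c1, c2):
    # 1 only if every predicted mention (c2) also appears in the gold cluster (c1)
    return int(all(m in c1 for m in c2))


def phi_prop(c1, c2):
    # number of gold mentions that also occur in the prediction, over len(pred)
    return len([m for m in c1 if m in c2]) / len(c2)


gold = ["farabundo marti national liberation front", "fmln"]
pred_subset = ["fmln"]               # every predicted mention is in gold
pred_mixed = ["fmln", "guerrillas"]  # one spurious mention

for pred in (pred_subset, pred_mixed):
    print(pred, phi_strict(gold, pred), phi_prop(gold, pred))
```

# In this sketch the strict score drops to 0 as soon as one predicted mention is spurious, while the proportional score only halves.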
29 | 30 | def ceaf(clusters, gold_clusters, phi_similarity): 31 | # !!! need to comment the next line, the conll-2012 eval ignore singletons 32 | # clusters = [c for c in clusters if len(c) != 1] 33 | scores = np.zeros((len(gold_clusters), len(clusters))) 34 | for i in range(len(gold_clusters)): 35 | for j in range(len(clusters)): 36 | scores[i, j] = phi_similarity(gold_clusters[i], clusters[j]) 37 | # matching = linear_assignment(-scores) # [deprecated] linear_assignment from sklearn 38 | # similarity = sum(scores[matching[:, 0], matching[:, 1]]) 39 | row_ind, col_ind = linear_sum_assignment(-scores) 40 | similarity = sum(scores[row_ind, col_ind]) 41 | return similarity, len(clusters), similarity, len(gold_clusters) 42 | 43 | 44 | def eval_ceaf_base(preds, golds, phi_similarity, docids=[]): 45 | result = OrderedDict() 46 | all_keys = list(role for _, role in tag2role.items()) + ["micro_avg"] 47 | for key in all_keys: 48 | result[key] = {"p_num": 0, "p_den": 0, "r_num": 0, "r_den": 0, "p": 0, "r": 0, "f1": 0} 49 | 50 | if not docids: 51 | for docid in golds: 52 | docids.append(docid) 53 | 54 | for docid in docids: 55 | pred = preds[docid] 56 | gold = golds[docid] 57 | 58 | for role in gold: 59 | pred_clusters = [] 60 | gold_clusters = [] 61 | for entity in gold[role]: 62 | gold_c = [] 63 | for mention in entity: 64 | gold_c.append(mention) 65 | gold_clusters.append(gold_c) 66 | 67 | for entity in pred[role]: 68 | pred_c = [] 69 | for mention in entity: 70 | pred_c.append(mention) 71 | pred_clusters.append(pred_c) 72 | 73 | pn, pd, rn, rd = ceaf(pred_clusters, gold_clusters, phi_similarity) 74 | result[role]["p_num"] += pn 75 | result[role]["p_den"] += pd 76 | result[role]["r_num"] += rn 77 | result[role]["r_den"] += rd 78 | 79 | result["micro_avg"]["p_num"] = sum(result[role]["p_num"] for _, role in tag2role.items()) 80 | result["micro_avg"]["p_den"] = sum(result[role]["p_den"] for _, role in tag2role.items()) 81 | result["micro_avg"]["r_num"] = 
sum(result[role]["r_num"] for _, role in tag2role.items()) 82 | result["micro_avg"]["r_den"] = sum(result[role]["r_den"] for _, role in tag2role.items()) 83 | 84 | 85 | for key in all_keys: 86 | result[key]["p"] = 0 if result[key]["p_num"] == 0 else result[key]["p_num"] / float(result[key]["p_den"]) 87 | result[key]["r"] = 0 if result[key]["r_num"] == 0 else result[key]["r_num"] / float(result[key]["r_den"]) 88 | result[key]["f1"] = f1(result[key]["p_num"], result[key]["p_den"], result[key]["r_num"], result[key]["r_den"]) 89 | 90 | return result 91 | 92 | 93 | def normalize_string(s): 94 | """Lower text and remove punctuation, articles and extra whitespace.""" 95 | def remove_articles(text): 96 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) 97 | return re.sub(regex, ' ', text) 98 | def white_space_fix(text): 99 | return ' '.join(text.split()) 100 | def remove_punc(text): 101 | exclude = set(string.punctuation) 102 | return ''.join(ch for ch in text if ch not in exclude) 103 | def lower(text): 104 | return text.lower() 105 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 106 | 107 | def eval_ceaf(preds, golds, docids=[]): 108 | # normalization mention strings 109 | for docid in preds: 110 | for role in preds[docid]: 111 | for idx in range(len(preds[docid][role])): 112 | for idy in range(len(preds[docid][role][idx])): 113 | preds[docid][role][idx][idy] = normalize_string(preds[docid][role][idx][idy]) 114 | for docid in golds: 115 | for role in golds[docid]: 116 | for idx in range(len(golds[docid][role])): 117 | for idy in range(len(golds[docid][role][idx])): 118 | golds[docid][role][idx][idy] = normalize_string(golds[docid][role][idx][idy]) 119 | 120 | results_strict = eval_ceaf_base(preds, golds, phi_strict, docids) 121 | results_prop = eval_ceaf_base(preds, golds, phi_prop, docids) 122 | 123 | final_results = OrderedDict() 124 | final_results["strict"] = results_strict 125 | final_results["prop"] = results_prop 126 | 127 | return final_results 
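# Editorial sketch (not part of ree_eval.py): eval_ceaf normalizes every mention before scoring, so differences in case, punctuation and articles are not counted as mismatches. The helper names `normalize_mention` and `normalize_extracts`, the document id and the sample strings are illustrative assumptions; the normalization steps mirror normalize_string above.

```python
import re
import string

_PUNCT = set(string.punctuation)


def normalize_mention(s):
    # lower-case, drop punctuation, drop articles, collapse whitespace
    s = s.lower()
    s = ''.join(ch for ch in s if ch not in _PUNCT)
    s = re.sub(r'\b(a|an|the)\b', ' ', s)
    return ' '.join(s.split())


def normalize_extracts(extracts):
    # docid -> role -> list of entities -> list of mention strings
    return {
        docid: {
            role: [[normalize_mention(m) for m in entity] for entity in entities]
            for role, entities in roles.items()
        }
        for docid, roles in extracts.items()
    }


preds = {"10001": {"PerpOrg": [["The F.M.L.N.", "FMLN"]]}}
normalized = normalize_extracts(preds)
print(normalized)
```

# Unlike eval_ceaf, which rewrites preds and golds in place, this sketch returns a new dictionary and leaves its input untouched.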
128 | 129 | def ree_eval(preds: dict, golds: dict): 130 | docids = [] 131 | results = eval_ceaf(preds, golds, docids) 132 | all_keys = list(role for _, role in tag2role.items()) + ["micro_avg"] 133 | str_print = [] 134 | for key in all_keys: 135 | if key == "micro_avg": 136 | print("***************** {} *****************".format(key)) 137 | else: 138 | print("================= {} =================".format(key)) 139 | 140 | str_print += [results["strict"][key]["p"] * 100, results["strict"][key]["r"] * 100, results["strict"][key]["f1"] * 100] 141 | print("P: {:.2f}%, R: {:.2f}%, F1: {:.2f}%".format(results["strict"][key]["p"] * 100, results["strict"][key]["r"] * 100, results["strict"][key]["f1"] * 100)) # phi_strict 142 | # print("phi_prop: P: {:.2f}%, R: {:.2f}%, F1: {:.2f}%".format(results["prop"][key]["p"] * 100, results["prop"][key]["r"] * 100, results["prop"][key]["f1"] * 100)) 143 | print() 144 | str_print= ["{:.2f}".format(r) for r in str_print] 145 | print("print: {}".format(" ".join(str_print))) 146 | 147 | return results['strict'] #['micro_avg']['f1'] 148 | 149 | 150 | if __name__ == "__main__": 151 | parser = argparse.ArgumentParser() 152 | parser.add_argument("--pred_file", default=None, type=str, required=False, help="preds output file") 153 | parser.add_argument("--gold_file", default="./data/muc/processed/test.json", type=str, required=False, help="gold file") 154 | args = parser.parse_args() 155 | 156 | ## get pred and gold extracts 157 | preds = OrderedDict() 158 | golds = OrderedDict() 159 | with open(args.pred_file, encoding="utf-8") as f: 160 | out_dict = json.load(f) 161 | for docid in out_dict: 162 | preds[docid] = out_dict[docid]["pred_extracts"] 163 | 164 | with open(args.gold_file, encoding="utf-8") as f: 165 | for line in f: 166 | line = json.loads(line) 167 | docid = str(int(line["docid"].split("-")[0][-1])*10000 + int(line["docid"].split("-")[-1])) 168 | 169 | extracts_raw = line["extracts"] 170 | 171 | extracts = OrderedDict() 172 | for 
role, entitys_raw in extracts_raw.items(): 173 | extracts[role] = [] 174 | for entity_raw in entitys_raw: 175 | entity = [] 176 | for mention_offset_pair in entity_raw: 177 | entity.append(mention_offset_pair[0]) 178 | if entity: 179 | extracts[role].append(entity) 180 | golds[docid] = extracts 181 | 182 | # import ipdb; ipdb.set_trace() 183 | docids = [] 184 | results = eval_ceaf(preds, golds, docids) 185 | all_keys = list(role for _, role in tag2role.items()) + ["micro_avg"] 186 | str_print = [] 187 | for key in all_keys: 188 | if key == "micro_avg": 189 | print("***************** {} *****************".format(key)) 190 | else: 191 | print("================= {} =================".format(key)) 192 | 193 | str_print += [results["strict"][key]["p"] * 100, results["strict"][key]["r"] * 100, results["strict"][key]["f1"] * 100] 194 | print("P: {:.2f}%, R: {:.2f}%, F1: {:.2f}%".format(results["strict"][key]["p"] * 100, results["strict"][key]["r"] * 100, results["strict"][key]["f1"] * 100)) # phi_strict 195 | # print("phi_prop: P: {:.2f}%, R: {:.2f}%, F1: {:.2f}%".format(results["prop"][key]["p"] * 100, results["prop"][key]["r"] * 100, results["prop"][key]["f1"] * 100)) 196 | print() 197 | str_print= ["{:.2f}".format(r) for r in str_print] 198 | print("print: {}".format(" ".join(str_print))) 199 | 200 | 201 | 202 | 203 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.6.0 2 | transformers==4.3.0 3 | pandas==1.2.3 4 | hydra-core==1.0.6 5 | omegaconf==2.0.6 6 | scipy==1.6.1 7 | tabulate==0.8.9 8 | sentencepiece==0.1.95 9 | nltk==3.5 -------------------------------------------------------------------------------- /sagcopy.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copied over from https://huggingface.co/transformers/_modules/transformers/models/bart/modeling_bart.html 3 | ''' 4 
| from typing import Optional, Tuple 5 | import random 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from transformers import BartForConditionalGeneration, BartModel, BartConfig 10 | from transformers.models.bart.modeling_bart import BartDecoder, BartDecoderLayer, BartAttention, _expand_mask, shift_tokens_right 11 | from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutputWithPastAndCrossAttentions, Seq2SeqModelOutput, BaseModelOutput 12 | class SAGCopyBartForConditionalGeneration(BartForConditionalGeneration): 13 | 14 | def __init__(self, config: BartConfig): 15 | super().__init__(config) 16 | self.model = SAGCopyBartModel(config) 17 | 18 | def forward( 19 | self, 20 | input_ids=None, 21 | attention_mask=None, 22 | decoder_input_ids=None, 23 | decoder_attention_mask=None, 24 | head_mask=None, 25 | decoder_head_mask=None, 26 | encoder_outputs=None, 27 | past_key_values=None, 28 | inputs_embeds=None, 29 | decoder_inputs_embeds=None, 30 | labels=None, 31 | use_cache=None, 32 | output_attentions=None, 33 | output_hidden_states=None, 34 | return_dict=None, 35 | ): 36 | r""" 37 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 38 | Labels for computing the masked language modeling loss. Indices should either be in ``[0, ..., 39 | config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored 40 | (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
41 | 42 | Returns: 43 | """ 44 | assert output_attentions or self.model.config.output_attentions, "output_attentions must be true" 45 | 46 | # original outputs 47 | outputs = self.model(input_ids, 48 | attention_mask=attention_mask, 49 | decoder_input_ids=decoder_input_ids, 50 | encoder_outputs=encoder_outputs, 51 | decoder_attention_mask=decoder_attention_mask, 52 | head_mask=head_mask, 53 | decoder_head_mask=decoder_head_mask, 54 | past_key_values=past_key_values, 55 | inputs_embeds=inputs_embeds, 56 | decoder_inputs_embeds=decoder_inputs_embeds, 57 | use_cache=use_cache, 58 | output_attentions=output_attentions, 59 | output_hidden_states=output_hidden_states, 60 | return_dict=return_dict,) 61 | 62 | if input_ids is None: 63 | input_ids = self._cache_input_ids 64 | 65 | encoder_last_hidden_state = outputs.encoder_last_hidden_state # (batch, seq, hidden) 66 | decoder_last_hidden_state = outputs[0] #(batch, decoding_seq, hidden ) 67 | 68 | 69 | # compute lm logits based on attention 70 | last_cross_attentions = outputs.cross_attentions[-1] # (batch_size, num_heads, decoding_seq_length, encoding_seq_length). 
71 | 72 | 73 | lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias #(batch_size, decoding_seq_length, emb_dim) 74 | cross_attentions_aggregate = last_cross_attentions.mean(dim=1) #(batch_size, decoding_seq_length, encoding_seq_length) 75 | dummy_input_ids = input_ids.unsqueeze(-1).expand(-1, -1, lm_logits.size(1)).transpose(1,2) # (batch, decoding_seq_length, encoding_seq_length) 76 | copy_logits = torch.zeros_like(lm_logits) # (batch, decoding_seq_length, emb_dim) 77 | copy_logits.scatter_add_(dim=2, index=dummy_input_ids, src=cross_attentions_aggregate) 78 | 79 | 80 | p_gen = torch.bmm(decoder_last_hidden_state, encoder_last_hidden_state.mean(dim=1).unsqueeze(dim=-1)) # (batch, decoding_seq, 1) 81 | p_gen = torch.sigmoid(p_gen) 82 | 83 | 84 | lm_logits = F.softmax(lm_logits, dim=-1) * p_gen + copy_logits * (1 - p_gen)#(batch_size, decoding_seq_length, emb_dim) 85 | 86 | 87 | 88 | 89 | masked_lm_loss = None 90 | if labels is not None: 91 | # compute loss mask and fill -100 with 0 92 | loss_mask = labels != -100 93 | labels.masked_fill_(~loss_mask, 0) 94 | # use negative log likelihood 95 | gold_probs = torch.gather(lm_logits, 2, labels.unsqueeze(2)).squeeze(2) 96 | eps = 1e-7 # for safe log 97 | masked_lm_loss = - torch.log(gold_probs + eps) * self._loss_weight[labels] 98 | masked_lm_loss = (masked_lm_loss * loss_mask).mean() 99 | 100 | 101 | if not return_dict: 102 | output = (lm_logits,) + outputs[1:] 103 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 104 | 105 | 106 | return Seq2SeqLMOutput( 107 | loss=masked_lm_loss, 108 | logits=lm_logits, 109 | past_key_values=outputs.past_key_values, 110 | decoder_hidden_states=outputs.decoder_hidden_states, 111 | decoder_attentions=outputs.decoder_attentions, 112 | cross_attentions=outputs.cross_attentions, 113 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, 114 | encoder_hidden_states=outputs.encoder_hidden_states, 115 | 
encoder_attentions=outputs.encoder_attentions, 116 | ) 117 | 118 | 119 | class SAGCopyBartModel(BartModel): 120 | def __init__(self, config: BartConfig): 121 | super().__init__(config) 122 | self.decoder = SAGCopyBartDecoder(config, self.shared) 123 | 124 | def forward( 125 | self, 126 | input_ids=None, 127 | attention_mask=None, 128 | decoder_input_ids=None, 129 | decoder_attention_mask=None, 130 | head_mask=None, 131 | decoder_head_mask=None, 132 | encoder_outputs=None, 133 | past_key_values=None, 134 | inputs_embeds=None, 135 | decoder_inputs_embeds=None, 136 | use_cache=None, 137 | output_attentions=None, 138 | output_hidden_states=None, 139 | return_dict=None, 140 | ): 141 | 142 | # different to other models, Bart automatically creates decoder_input_ids from 143 | # input_ids if no decoder_input_ids are provided 144 | if decoder_input_ids is None and decoder_inputs_embeds is None: 145 | decoder_input_ids = shift_tokens_right( 146 | input_ids, self.config.pad_token_id, self.config.decoder_start_token_id 147 | ) 148 | 149 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 150 | output_hidden_states = ( 151 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 152 | ) 153 | use_cache = use_cache if use_cache is not None else self.config.use_cache 154 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 155 | 156 | if encoder_outputs is None: 157 | encoder_outputs = self.encoder( 158 | input_ids=input_ids, 159 | attention_mask=attention_mask, 160 | head_mask=head_mask, 161 | inputs_embeds=inputs_embeds, 162 | output_attentions=True, 163 | output_hidden_states=output_hidden_states, 164 | return_dict=return_dict, 165 | ) 166 | 167 | # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True 168 | elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): 169 | 170 | 
encoder_outputs = BaseModelOutput( 171 | last_hidden_state=encoder_outputs[0], 172 | hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, 173 | attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, 174 | ) 175 | 176 | # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) 177 | decoder_outputs = self.decoder( 178 | input_ids=decoder_input_ids, 179 | attention_mask=decoder_attention_mask, 180 | encoder_hidden_states=encoder_outputs[0], 181 | encoder_attention_mask=attention_mask, 182 | head_mask=decoder_head_mask, 183 | encoder_head_mask=head_mask, 184 | past_key_values=past_key_values, 185 | inputs_embeds=decoder_inputs_embeds, 186 | use_cache=use_cache, 187 | output_attentions=output_attentions, 188 | output_hidden_states=output_hidden_states, 189 | return_dict=return_dict, 190 | encoder_attentions=encoder_outputs[2] 191 | ) 192 | 193 | if not return_dict: 194 | return decoder_outputs + encoder_outputs 195 | 196 | return Seq2SeqModelOutput( 197 | last_hidden_state=decoder_outputs.last_hidden_state, 198 | past_key_values=decoder_outputs.past_key_values, 199 | decoder_hidden_states=decoder_outputs.hidden_states, 200 | decoder_attentions=decoder_outputs.attentions, 201 | cross_attentions=decoder_outputs.cross_attentions, 202 | encoder_last_hidden_state=encoder_outputs.last_hidden_state, 203 | encoder_hidden_states=encoder_outputs.hidden_states, 204 | encoder_attentions=encoder_outputs.attentions, 205 | ) 206 | 207 | class SAGCopyBartDecoder(BartDecoder): 208 | """ 209 | Transformer decoder consisting of *config.decoder_layers* layers. 
Each layer is a :class:`BartDecoderLayer` 210 | 211 | Args: 212 | config: BartConfig 213 | embed_tokens (torch.nn.Embedding): output embedding 214 | """ 215 | 216 | def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None): 217 | super().__init__(config) 218 | self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers-1)] + [SAGCopyBARTDecoderLayer(config)]) 219 | 220 | 221 | 222 | def forward( 223 | self, 224 | input_ids=None, 225 | attention_mask=None, 226 | encoder_hidden_states=None, 227 | encoder_attention_mask=None, 228 | head_mask=None, 229 | encoder_head_mask=None, 230 | past_key_values=None, 231 | inputs_embeds=None, 232 | use_cache=None, 233 | output_attentions=None, 234 | output_hidden_states=None, 235 | return_dict=None, 236 | encoder_attentions=None 237 | ): 238 | r""" 239 | Args: 240 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): 241 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you 242 | provide it. 243 | 244 | Indices can be obtained using :class:`~transformers.BartTokenizer`. See 245 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` 246 | for details. 247 | 248 | `What are input IDs? <../glossary.html#input-ids>`__ 249 | attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 250 | Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: 251 | 252 | - 1 for tokens that are **not masked**, 253 | - 0 for tokens that are **masked**. 254 | 255 | `What are attention masks? <../glossary.html#attention-mask>`__ 256 | encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`): 257 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention 258 | of the decoder. 
259 | encoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, encoder_sequence_length)`, `optional`): 260 | Mask to avoid performing cross-attention on padding token indices of encoder input_ids. Mask values 261 | selected in ``[0, 1]``: 262 | 263 | - 1 for tokens that are **not masked**, 264 | - 0 for tokens that are **masked**. 265 | 266 | `What are attention masks? <../glossary.html#attention-mask>`__ 267 | head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): 268 | Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: 269 | 270 | - 1 indicates the head is **not masked**, 271 | - 0 indicates the head is **masked**. 272 | 273 | encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): 274 | Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention 275 | on hidden heads. Mask values selected in ``[0, 1]``: 276 | 277 | - 1 indicates the head is **not masked**, 278 | - 0 indicates the head is **masked**. 279 | 280 | past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 281 | Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up 282 | decoding. 283 | 284 | If :obj:`past_key_values` are used, the user can optionally input only the last 285 | :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of 286 | shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, 287 | sequence_length)`.
288 | inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 289 | Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded 290 | representation. This is useful if you want more control over how to convert :obj:`input_ids` indices 291 | into associated vectors than the model's internal embedding lookup matrix. 292 | output_attentions (:obj:`bool`, `optional`): 293 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under 294 | returned tensors for more detail. 295 | output_hidden_states (:obj:`bool`, `optional`): 296 | Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors 297 | for more detail. 298 | return_dict (:obj:`bool`, `optional`): 299 | Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. 300 | """ 301 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 302 | output_hidden_states = ( 303 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 304 | ) 305 | use_cache = use_cache if use_cache is not None else self.config.use_cache 306 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 307 | 308 | # retrieve input_ids and inputs_embeds 309 | if input_ids is not None and inputs_embeds is not None: 310 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") 311 | elif input_ids is not None: 312 | input_shape = input_ids.size() 313 | input_ids = input_ids.view(-1, input_shape[-1]) 314 | elif inputs_embeds is not None: 315 | input_shape = inputs_embeds.size()[:-1] 316 | else: 317 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") 318 | 319 | # past_key_values_length 320 | past_key_values_length = 
past_key_values[0][0].shape[2] if past_key_values is not None else 0 321 | 322 | if inputs_embeds is None: 323 | inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale 324 | 325 | attention_mask = self._prepare_decoder_attention_mask( 326 | attention_mask, input_shape, inputs_embeds, past_key_values_length 327 | ) 328 | 329 | # expand encoder attention mask 330 | if encoder_hidden_states is not None and encoder_attention_mask is not None: 331 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 332 | encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) 333 | 334 | # embed positions 335 | positions = self.embed_positions(input_shape, past_key_values_length) 336 | 337 | hidden_states = inputs_embeds + positions 338 | hidden_states = self.layernorm_embedding(hidden_states) 339 | 340 | hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) 341 | 342 | # decoder layers 343 | all_hidden_states = () if output_hidden_states else None 344 | all_self_attns = () if output_attentions else None 345 | all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None 346 | next_decoder_cache = () if use_cache else None 347 | 348 | # check if head_mask has a correct number of layers specified if desired 349 | if head_mask is not None: 350 | assert head_mask.size()[0] == ( 351 | len(self.layers) 352 | ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." 
353 | for idx, decoder_layer in enumerate(self.layers): 354 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) 355 | if output_hidden_states: 356 | all_hidden_states += (hidden_states,) 357 | dropout_probability = random.uniform(0, 1) 358 | if self.training and (dropout_probability < self.layerdrop): 359 | continue 360 | 361 | past_key_value = past_key_values[idx] if past_key_values is not None else None 362 | 363 | if getattr(self.config, "gradient_checkpointing", False) and self.training: 364 | 365 | if use_cache: 366 | logger.warn( 367 | "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " 368 | "`use_cache=False`..." 369 | ) 370 | use_cache = False 371 | 372 | def create_custom_forward(module): 373 | def custom_forward(*inputs): 374 | # None for past_key_value 375 | return module(*inputs, output_attentions, use_cache) 376 | return custom_forward 377 | def create_custom_forward_last(module): 378 | def custom_forward_last(*inputs): 379 | # None for past_key_value 380 | return module(*inputs, output_attentions, use_cache, encoder_attentions) 381 | return custom_forward_last 382 | 383 | 384 | if idx != len(self.layers) -1: 385 | layer_outputs = torch.utils.checkpoint.checkpoint( 386 | create_custom_forward(decoder_layer), 387 | hidden_states, 388 | attention_mask, 389 | encoder_hidden_states, 390 | encoder_attention_mask, 391 | head_mask[idx] if head_mask is not None else None, 392 | encoder_head_mask[idx] if encoder_head_mask is not None else None, 393 | None, 394 | ) 395 | else: 396 | 397 | layer_outputs = torch.utils.checkpoint.checkpoint( 398 | create_custom_forward_last(decoder_layer), 399 | hidden_states, 400 | attention_mask, 401 | encoder_hidden_states, 402 | encoder_attention_mask, 403 | head_mask[idx] if head_mask is not None else None, 404 | encoder_head_mask[idx] if encoder_head_mask is not None else None, 405 | None, 406 | ) 407 | 408 | else: 409 | 410 | if idx != len(self.layers) -1: 411 | 
layer_outputs = decoder_layer( 412 | hidden_states, 413 | attention_mask=attention_mask, 414 | encoder_hidden_states=encoder_hidden_states, 415 | encoder_attention_mask=encoder_attention_mask, 416 | layer_head_mask=(head_mask[idx] if head_mask is not None else None), 417 | encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), 418 | past_key_value=past_key_value, 419 | output_attentions=output_attentions, 420 | use_cache=use_cache, 421 | ) 422 | else: 423 | 424 | layer_outputs = decoder_layer( 425 | hidden_states, 426 | attention_mask=attention_mask, 427 | encoder_hidden_states=encoder_hidden_states, 428 | encoder_attention_mask=encoder_attention_mask, 429 | layer_head_mask=(head_mask[idx] if head_mask is not None else None), 430 | encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), 431 | past_key_value=past_key_value, 432 | output_attentions=output_attentions, 433 | use_cache=use_cache, 434 | encoder_attentions=encoder_attentions, 435 | ) 436 | hidden_states = layer_outputs[0] 437 | 438 | if use_cache: 439 | next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) 440 | 441 | if output_attentions: 442 | all_self_attns += (layer_outputs[1],) 443 | 444 | if encoder_hidden_states is not None: 445 | all_cross_attentions += (layer_outputs[2],) 446 | 447 | # add hidden states from the last decoder layer 448 | if output_hidden_states: 449 | all_hidden_states += (hidden_states,) 450 | 451 | next_cache = next_decoder_cache if use_cache else None 452 | 453 | if not return_dict: 454 | return tuple( 455 | v 456 | for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] 457 | if v is not None 458 | ) 459 | return BaseModelOutputWithPastAndCrossAttentions( 460 | last_hidden_state=hidden_states, 461 | past_key_values=next_cache, 462 | hidden_states=all_hidden_states, 463 | attentions=all_self_attns, 464 | cross_attentions=all_cross_attentions, 465 | ) 
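The decoder above forwards `encoder_attentions` only to its last layer, whose cross-attention turns them into per-token centrality scores (see `compute_centrality_representations` in `SAGCopyBartAttention`). The following is a rough single-example sketch of that score computation in plain Python, not the repository's implementation. It assumes the attention matrix is already averaged over heads and has non-zero rows; the name `centrality_scores` is hypothetical, and the learned projection `w_p` is left out.

```python
from typing import List


def centrality_scores(attn: List[List[float]], iterations: int = 3) -> List[float]:
    """Hypothetical helper: TextRank-style scores from one head-averaged attention matrix.

    attn[i][j] is the attention that token i pays to token j.
    """
    n = len(attn)
    # Row-normalise so that each row of the transition matrix sums to 1.
    T = [[a / sum(row) for a in row] for row in attn]
    # Initial score of token j: total attention it receives (column sum).
    scores = [sum(T[i][j] for i in range(n)) for j in range(n)]
    # Propagate the scores through T, i.e. scores <- T @ scores.
    for _ in range(iterations):
        scores = [sum(T[i][j] * scores[j] for j in range(n)) for i in range(n)]
    return scores


if __name__ == "__main__":
    attn = [[0.5, 0.5], [1.0, 0.0]]
    print(centrality_scores(attn, iterations=1))
```

In the model, each resulting score is a single scalar per source token; it is mapped to `embed_dim` by a bias-free linear layer and added to the cross-attention keys.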
466 | 467 | class SAGCopyBARTDecoderLayer(BartDecoderLayer): 468 | def __init__(self, config: BartConfig): 469 | super().__init__(config) 470 | 471 | self.encoder_attn = SAGCopyBartAttention( 472 | self.embed_dim, 473 | config.decoder_attention_heads, 474 | dropout=config.attention_dropout, 475 | is_decoder=True, 476 | ) 477 | 478 | def forward( 479 | self, 480 | hidden_states: torch.Tensor, 481 | attention_mask: Optional[torch.Tensor] = None, 482 | encoder_hidden_states: Optional[torch.Tensor] = None, 483 | encoder_attention_mask: Optional[torch.Tensor] = None, 484 | layer_head_mask: Optional[torch.Tensor] = None, 485 | encoder_layer_head_mask: Optional[torch.Tensor] = None, 486 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 487 | output_attentions: Optional[bool] = False, 488 | use_cache: Optional[bool] = True, 489 | encoder_attentions: torch.Tensor = None, 490 | ): 491 | """ 492 | Args: 493 | hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` 494 | attention_mask (:obj:`torch.FloatTensor`): attention mask of size 495 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 496 | encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` 497 | encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size 498 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 499 | layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size 500 | `(config.encoder_attention_heads,)`. 501 | encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of 502 | size `(config.encoder_attention_heads,)`. 
503 | past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states 504 | output_attentions (:obj:`bool`, `optional`): 505 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under 506 | returned tensors for more detail. 507 | """ 508 | residual = hidden_states 509 | 510 | # Self Attention 511 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 512 | self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None 513 | # add present self-attn cache to positions 1,2 of present_key_value tuple 514 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 515 | hidden_states=hidden_states, 516 | past_key_value=self_attn_past_key_value, 517 | attention_mask=attention_mask, 518 | layer_head_mask=layer_head_mask, 519 | output_attentions=output_attentions, 520 | ) 521 | hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) 522 | hidden_states = residual + hidden_states 523 | hidden_states = self.self_attn_layer_norm(hidden_states) 524 | 525 | # Cross-Attention Block 526 | cross_attn_present_key_value = None 527 | cross_attn_weights = None 528 | if encoder_hidden_states is not None: 529 | residual = hidden_states 530 | 531 | # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple 532 | cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None 533 | hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( 534 | hidden_states=hidden_states, 535 | key_value_states=encoder_hidden_states, 536 | attention_mask=encoder_attention_mask, 537 | layer_head_mask=encoder_layer_head_mask, 538 | past_key_value=cross_attn_past_key_value, 539 | output_attentions=output_attentions, 540 | encoder_attentions=encoder_attentions 541 | ) 542 | hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) 543 | hidden_states 
= residual + hidden_states 544 | hidden_states = self.encoder_attn_layer_norm(hidden_states) 545 | 546 | # add cross-attn to positions 3,4 of present_key_value tuple 547 | present_key_value = present_key_value + cross_attn_present_key_value 548 | 549 | # Fully Connected 550 | residual = hidden_states 551 | hidden_states = self.activation_fn(self.fc1(hidden_states)) 552 | hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) 553 | hidden_states = self.fc2(hidden_states) 554 | hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) 555 | hidden_states = residual + hidden_states 556 | hidden_states = self.final_layer_norm(hidden_states) 557 | 558 | outputs = (hidden_states,) 559 | 560 | if output_attentions: 561 | outputs += (self_attn_weights, cross_attn_weights) 562 | 563 | if use_cache: 564 | outputs += (present_key_value,) 565 | 566 | return outputs 567 | 568 | class SAGCopyBartAttention(BartAttention): 569 | """Multi-headed attention from 'Attention Is All You Need' paper, with centrality scores added to the cross-attention keys""" 570 | 571 | def __init__( 572 | self, 573 | embed_dim: int, 574 | num_heads: int, 575 | dropout: float = 0.0, 576 | is_decoder: bool = False, 577 | bias: bool = True, 578 | ): 579 | super().__init__(embed_dim, num_heads, dropout, is_decoder, bias) 580 | 581 | # the same w_p as in the paper 582 | self.w_p = torch.nn.Linear(1, embed_dim, bias=False) 583 | 584 | def compute_centrality_representations(self, encoder_attentions, iterations=0): 585 | ''' 586 | encoder_attentions: tuple of per-layer tensors, each (batch_size, num_heads, seq_len, seq_len); only the last layer is used 587 | ''' 588 | 589 | encoder_attentions = encoder_attentions[-1].mean(1) #(batch, seq_len, seq_len) 590 | T = encoder_attentions / encoder_attentions[-1].sum(1).unsqueeze(1) #(batch, seq_len, seq_len) 591 | scores = T.sum(1) #(batch, seq_len) 592 | 593 | # do TextRank update 594 | for _ in range(iterations): 595 | scores = torch.bmm(T, scores.unsqueeze(2)).squeeze(2) 596 | scores = scores.unsqueeze(2) 597 |
representations = self.w_p(scores) 598 | return representations 599 | 600 | def forward( 601 | self, 602 | hidden_states: torch.Tensor, 603 | key_value_states: Optional[torch.Tensor] = None, 604 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 605 | attention_mask: Optional[torch.Tensor] = None, 606 | layer_head_mask: Optional[torch.Tensor] = None, 607 | output_attentions: bool = False, 608 | encoder_attentions: torch.Tensor = None, 609 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 610 | """Input shape: Batch x Time x Channel""" 611 | 612 | # if key_value_states are provided this layer is used as a cross-attention layer 613 | # for the decoder 614 | is_cross_attention = key_value_states is not None 615 | bsz, tgt_len, embed_dim = hidden_states.size() 616 | 617 | # get query proj 618 | query_states = self.q_proj(hidden_states) * self.scaling 619 | # get key, value proj 620 | if is_cross_attention and past_key_value is not None: 621 | # reuse k,v, cross_attentions 622 | key_states = past_key_value[0] 623 | value_states = past_key_value[1] 624 | elif is_cross_attention: 625 | # cross_attentions plus (Whhi + wpscorei) 626 | key_states = self._shape(self.k_proj(key_value_states) + self.compute_centrality_representations(encoder_attentions, iterations=3), -1, bsz) 627 | value_states = self._shape(self.v_proj(key_value_states), -1, bsz) 628 | elif past_key_value is not None: 629 | # reuse k, v, self_attention 630 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 631 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 632 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 633 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 634 | else: 635 | # self_attention 636 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 637 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 638 | 639 | if self.is_decoder: 640 | # if cross_attention save 
Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 641 | # Further calls to cross_attention layer can then reuse all cross-attention 642 | # key/value_states (first "if" case) 643 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 644 | # all previous decoder key/value_states. Further calls to uni-directional self-attention 645 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 646 | # if encoder bi-directional self-attention `past_key_value` is always `None` 647 | past_key_value = (key_states, value_states) 648 | 649 | proj_shape = (bsz * self.num_heads, -1, self.head_dim) 650 | query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) 651 | key_states = key_states.view(*proj_shape) 652 | value_states = value_states.view(*proj_shape) 653 | 654 | src_len = key_states.size(1) 655 | attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) 656 | 657 | assert attn_weights.size() == ( 658 | bsz * self.num_heads, 659 | tgt_len, 660 | src_len, 661 | ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" 662 | 663 | if attention_mask is not None: 664 | assert attention_mask.size() == ( 665 | bsz, 666 | 1, 667 | tgt_len, 668 | src_len, 669 | ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" 670 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask 671 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 672 | 673 | attn_weights = F.softmax(attn_weights, dim=-1) 674 | 675 | if layer_head_mask is not None: 676 | assert layer_head_mask.size() == ( 677 | self.num_heads, 678 | ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" 679 | attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, 
self.num_heads, tgt_len, src_len) 680 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 681 | 682 | if output_attentions: 683 | # this operation is a bit awkward, but it's required to 684 | # make sure that attn_weights keeps its gradient. 685 | # In order to do so, attn_weights has to be reshaped 686 | # twice and reused in the following 687 | attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 688 | attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) 689 | else: 690 | attn_weights_reshaped = None 691 | 692 | attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training) 693 | 694 | attn_output = torch.bmm(attn_probs, value_states) 695 | 696 | assert attn_output.size() == ( 697 | bsz * self.num_heads, 698 | tgt_len, 699 | self.head_dim, 700 | ), f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}" 701 | 702 | attn_output = ( 703 | attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) 704 | .transpose(1, 2) 705 | .reshape(bsz, tgt_len, embed_dim) 706 | ) 707 | 708 | attn_output = self.out_proj(attn_output) 709 | 710 | return attn_output, attn_weights_reshaped, past_key_value -------------------------------------------------------------------------------- /scirex_eval.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copied over from https://github.com/allenai/SciREX/blob/master/scirex/evaluation_scripts/scirex_relation_evaluate.py 3 | ''' 4 | from typing import Dict, Tuple, List 5 | import os 6 | from collections import namedtuple 7 | from itertools import combinations 8 | import pandas as pd 9 | from copy import deepcopy 10 | 11 | BASEPATH = os.getenv("RESULT_EXTRACTION_BASEPATH", ".") 12 | 13 | available_entity_types_sciERC = ["Material", "Metric", "Task", "Generic", "OtherScientificTerm", "Method"] 14 | map_available_entity_to_true = {"Material":
"dataset", "Metric": "metric", "Task": "task", "Method": "model_name"} 15 | map_true_entity_to_available = {v: k for k, v in map_available_entity_to_true.items()} 16 | 17 | used_entities = list(map_available_entity_to_true.keys()) 18 | true_entities = list(map_available_entity_to_true.values()) 19 | 20 | def has_all_mentions(doc, relation): 21 | has_mentions = all(len(doc["coref"][x[1]]) > 0 for x in relation) 22 | return has_mentions 23 | 24 | def compute_mapping(predicted_relations: List[Dict[str, str]], 25 | gold_entities: Dict[str, List], 26 | doc_tokens: List[str]): 27 | ''' 28 | Each relation in predicted_relations is a dict with two elements (for binary relation). e.g. 29 | { 30 | 'Metric': 'accuracy', 31 | 'Task': 'Natural language inference', 32 | } 33 | ''' 34 | # make a copy so we don't alter the original data 35 | gold_entities = deepcopy(gold_entities) 36 | predicted_mentions = set([mention for relation in predicted_relations for mention in relation.values()]) 37 | 38 | # # Assign each mention to one gold entity. 39 | predicted_mention2gold_entity_name : Dict[str, str] = {} 40 | for predicted_mention in predicted_mentions: 41 | gold_entity_name_to_pop = None 42 | for gold_entity_name, gold_mention_spans in gold_entities.items(): 43 | gold_mentions = { ' '.join(doc_tokens[start_tok:end_tok]) for (start_tok, end_tok) in gold_mention_spans} 44 | if predicted_mention in gold_mentions: 45 | gold_entity_name_to_pop = gold_entity_name 46 | predicted_mention2gold_entity_name[predicted_mention] = gold_entity_name 47 | break 48 | # Make sure each gold entity is only assigned once. 
49 | if gold_entity_name_to_pop is not None: 50 | gold_entities.pop(gold_entity_name_to_pop) 51 | 52 | else: 53 | print(f"Cannot find span for {predicted_mention}") 54 | 55 | 56 | return predicted_mention2gold_entity_name 57 | 58 | 59 | 60 | 61 | def scirex_eval(predicted_relations, gold_data, cardinality:int): 62 | 63 | all_metrics = [] 64 | 65 | for types in combinations(used_entities, cardinality): 66 | for doc in gold_data: 67 | relations = predicted_relations[doc["doc_id"]] 68 | 69 | mapping = compute_mapping(relations, doc['coref'], doc["words"]) 70 | 71 | for relation in relations: 72 | for entity_type, entity_name in relation.items(): 73 | relation[entity_type] = mapping.get(entity_name, entity_name) 74 | 75 | # each iteration only evaluates relations of the corresponding types 76 | relations = set([tuple((t, x[t]) for t in types) for x in relations if all(t in x.keys() for t in types)]) 77 | 78 | gold_relations = [tuple((t, x[t]) for t in types) for x in doc['n_ary_relations']] 79 | gold_relations = set([x for x in gold_relations if has_all_mentions(doc, x)]) 80 | 81 | matched = relations & gold_relations 82 | 83 | metrics = { 84 | "p": len(matched) / (len(relations) + 1e-7), 85 | "r": len(matched) / (len(gold_relations) + 1e-7), 86 | } 87 | metrics["f1"] = 2 * metrics["p"] * metrics["r"] / (metrics["p"] + metrics["r"] + 1e-7) 88 | 89 | if len(gold_relations) > 0: 90 | all_metrics.append(metrics) 91 | 92 | all_metrics = pd.DataFrame(all_metrics) 93 | print(f"Relation Metrics n={cardinality}") 94 | print(all_metrics.describe().loc['mean'][['p', 'r', 'f1']]) 95 | 96 | # take the mean value 97 | return all_metrics.describe().loc['mean'][['p', 'r', 'f1']].to_dict() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import time 4 | import random 5 | from argparse import ArgumentParser 6 | 7 | import numpy as np 8 | import tqdm 9
| import torch 10 | from torch.utils.data import DataLoader 11 | from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup, Adafactor 12 | 13 | from model import GenerativeModel 14 | from config import Config 15 | from data import IEDataset 16 | from constants import * 17 | from util import * 18 | import ree_eval 19 | import scirex_eval 20 | 21 | # configuration 22 | parser = ArgumentParser() 23 | parser.add_argument('-c', '--config', default='config/generative_model.json') 24 | args = parser.parse_args() 25 | config = Config.from_json_file(args.config) 26 | print(config.to_dict()) 27 | 28 | # fix random seed 29 | random.seed(config.seed) 30 | np.random.seed(config.seed) 31 | torch.manual_seed(config.seed) 32 | torch.backends.cudnn.enabled = False 33 | 34 | # set GPU device 35 | use_gpu = config.use_gpu 36 | if use_gpu and config.gpu_device >= 0: 37 | torch.cuda.set_device(config.gpu_device) 38 | 39 | # output 40 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 41 | log_dir = os.path.join(config.log_path, timestamp) 42 | if not os.path.exists(log_dir): 43 | os.makedirs(log_dir) 44 | # logger = Logger(log_dir) 45 | output_dir = os.path.join(config.output_path, timestamp) 46 | if not os.path.exists(output_dir): 47 | os.makedirs(output_dir) 48 | log_file = os.path.join(output_dir, 'log.txt') 49 | with open(log_file, 'w', encoding='utf-8') as w: 50 | w.write(json.dumps(config.to_dict()) + '\n') 51 | print('Log file: {}'.format(log_file)) 52 | best_model = os.path.join(output_dir, 'best.mdl') 53 | train_result_file = os.path.join(output_dir, 'result.train.json') 54 | dev_result_file = os.path.join(output_dir, 'result.dev.json') 55 | test_result_file = os.path.join(output_dir, 'result.test.json') 56 | 57 | # datasets 58 | model_name = config.bert_model_name 59 | 60 | tokenizer = AutoTokenizer.from_pretrained(model_name, 61 | cache_dir=config.bert_cache_dir) 62 | 63 | tokenizer.add_tokens(SPECIAL_TOKENS) 64 | # special_tokens_dict = 
{'additional_special_tokens': SPECIAL_TOKENS} 65 | # num_added_toks = tokenizer.add_special_tokens(special_tokens_dict) 66 | 67 | 68 | print('==============Prepare Training Set=================') 69 | train_set = IEDataset(config.train_file, max_length=config.max_length, gpu=use_gpu) 70 | print('==============Prepare Dev Set=================') 71 | dev_set = IEDataset(config.dev_file, max_length=config.max_length, gpu=use_gpu) 72 | print('==============Prepare Test Set=================') 73 | test_set = IEDataset(config.test_file, max_length=config.max_length, gpu=use_gpu) 74 | vocabs = {} 75 | 76 | print('==============Prepare Training Set=================') 77 | train_set.numberize(tokenizer, vocabs) 78 | print('==============Prepare Dev Set=================') 79 | dev_set.numberize(tokenizer, vocabs) 80 | print('==============Prepare Test Set=================') 81 | test_set.numberize(tokenizer, vocabs) 82 | 83 | if config.task == ROLE_FILLER_ENTITY_EXTRACTION: 84 | grit_dev = read_grit_gold_file(config.grit_dev_file) 85 | grit_test = read_grit_gold_file(config.grit_test_file) 86 | elif config.task in {BINARY_RELATION_EXTRACTION, FOUR_ARY_RELATION_EXTRACTION}: 87 | scirex_dev = read_scirex_gold_file(config.scirex_dev_file) 88 | scirex_test = read_scirex_gold_file(config.scirex_test_file) 89 | 90 | batch_num = len(train_set) // (config.batch_size * config.accumulate_step) + \ 91 | (len(train_set) % (config.batch_size * config.accumulate_step) != 0) 92 | dev_batch_num = len(dev_set) // config.eval_batch_size + \ 93 | (len(dev_set) % config.eval_batch_size != 0) 94 | test_batch_num = len(test_set) // config.eval_batch_size + \ 95 | (len(test_set) % config.eval_batch_size != 0) 96 | 97 | # initialize the model 98 | 99 | model = GenerativeModel(config, vocabs) 100 | 101 | model.load_bert(model_name, cache_dir=config.bert_cache_dir, tokenizer=tokenizer) 102 | 103 | if not model_name.startswith('roberta'): 104 | model.bert.resize_token_embeddings(len(tokenizer)) 105 | 
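The `batch_num` expression above is a ceiling division: the number of optimizer updates per epoch when every update consumes `batch_size * accumulate_step` examples. A minimal sketch of the same arithmetic follows; the helper name `num_update_steps` is hypothetical and does not exist in this repository.

```python
def num_update_steps(num_examples: int, batch_size: int, accumulate_step: int) -> int:
    """Hypothetical helper: optimizer updates per epoch with gradient accumulation."""
    examples_per_update = batch_size * accumulate_step
    # Floor division, plus one extra update when a partial window remains.
    return num_examples // examples_per_update + (num_examples % examples_per_update != 0)


if __name__ == "__main__":
    # 10 examples, 2 per mini-batch, 2 mini-batches per update.
    print(num_update_steps(10, 2, 2))
```

The same count is what the learning-rate schedule multiplies by `warmup_epoch` and `max_epoch`, so it should match the number of times `optimizer.step()` actually runs in one epoch.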
106 | if use_gpu: 107 | model.cuda(device=config.gpu_device) 108 | 109 | # optimizer 110 | param_groups = [ 111 | { 112 | 'params': [p for n, p in model.named_parameters() if n.startswith('bert')], 113 | 'lr': config.bert_learning_rate, 'weight_decay': config.bert_weight_decay 114 | }, 115 | { 116 | 'params': [p for n, p in model.named_parameters() if not n.startswith('bert') 117 | and 'crf' not in n and 'global_feature' not in n], 118 | 'lr': config.learning_rate, 'weight_decay': config.weight_decay 119 | }, 120 | { 121 | 'params': [p for n, p in model.named_parameters() if not n.startswith('bert') 122 | and ('crf' in n or 'global_feature' in n)], 123 | 'lr': config.learning_rate, 'weight_decay': 0 124 | } 125 | ] 126 | if model.bert.config.name_or_path.startswith('t5'): 127 | optimizer = Adafactor(params=param_groups) 128 | else: 129 | optimizer = AdamW(params=param_groups) 130 | schedule = get_linear_schedule_with_warmup(optimizer, 131 | num_warmup_steps=batch_num*config.warmup_epoch, 132 | num_training_steps=batch_num*config.max_epoch) 133 | 134 | # model state 135 | state = dict(model=model.state_dict(), 136 | config=config.to_dict(), 137 | vocabs=vocabs) 138 | 139 | 140 | best_dev = -np.inf 141 | current_step = 0 142 | best_epoch = 0 143 | print('================Start Training================') 144 | for epoch in range(config.max_epoch): 145 | 146 | progress = tqdm.tqdm(total=batch_num, ncols=75, 147 | desc='Train {}'.format(epoch)) 148 | optimizer.zero_grad() 149 | train_gold_outputs, train_pred_outputs, train_input_tokens, train_doc_ids, train_input_ids = [], [], [], [], [] 150 | training_loss = 0 151 | for batch_idx, batch in enumerate(DataLoader( 152 | train_set, batch_size=config.batch_size , 153 | shuffle=True, drop_last=False, collate_fn=train_set.collate_fn)): 154 | 155 | decoder_inputs_outputs = generate_decoder_inputs_outputs(batch, tokenizer, model, use_gpu, config.max_position_embeddings, permute_slots=config.permute_slots, task=config.task) 156 | 
decoder_input_ids = decoder_inputs_outputs['decoder_input_ids'] 157 | 158 | decoder_labels = decoder_inputs_outputs['decoder_labels'] 159 | decoder_masks = decoder_inputs_outputs['decoder_masks'] 160 | 161 | loss = model(batch, decoder_input_ids, decoder_labels, tokenizer=tokenizer)['loss'] 162 | current_step += 1 163 | loss = loss * (1 / config.accumulate_step) 164 | training_loss += loss.item() 165 | loss.backward() 166 | 167 | 168 | train_gold_outputs.extend(decoder_inputs_outputs['decoder_labels'].tolist()) 169 | train_input_ids.extend(decoder_input_ids.tolist()) 170 | 171 | 172 | if (batch_idx + 1) % config.accumulate_step == 0: 173 | progress.update(1) 174 | torch.nn.utils.clip_grad_norm_( 175 | model.parameters(), config.grad_clipping) 176 | optimizer.step() 177 | schedule.step() 178 | optimizer.zero_grad() 179 | # flush gradients left over from an incomplete accumulation window 180 | if (batch_idx + 1) % config.accumulate_step != 0: 181 | progress.update(1) 182 | torch.nn.utils.clip_grad_norm_( 183 | model.parameters(), config.grad_clipping) 184 | optimizer.step() 185 | schedule.step() 186 | optimizer.zero_grad() 187 | 188 | print("training loss", training_loss) 189 | train_result = { 190 | 'pred_outputs': train_pred_outputs, 191 | 'gold_outputs': train_gold_outputs, 192 | 'input_tokens': train_input_tokens, 193 | 'decoder_input_ids': train_input_ids, 194 | 'doc_ids': train_doc_ids 195 | } 196 | with open( train_result_file + f'_{epoch}','w') as f: 197 | f.write(json.dumps(train_result)) 198 | 199 | progress.close() 200 | if config.max_epoch <= 50 or epoch % max(1, config.max_epoch // 150) == 0 : 201 | # dev set 202 | progress = tqdm.tqdm(total=dev_batch_num, ncols=75, 203 | desc='Dev {}'.format(epoch)) 204 | 205 | dev_gold_outputs, dev_pred_outputs, dev_input_tokens, dev_doc_ids, dev_documents = [], [], [], [], [] 206 | 207 | for batch in DataLoader(dev_set, batch_size=config.eval_batch_size, 208 | shuffle=False, collate_fn=dev_set.collate_fn): 209 | progress.update(1) 210 | outputs = model.predict(batch,
tokenizer,epoch=epoch) 211 | decoder_inputs_outputs = generate_decoder_inputs_outputs(batch, tokenizer, model, use_gpu, config.max_position_embeddings, task=config.task) 212 | dev_pred_outputs.extend(outputs['decoded_ids'].tolist()) 213 | dev_gold_outputs.extend(decoder_inputs_outputs['decoder_labels'].tolist()) 214 | dev_input_tokens.extend(batch.input_tokens) 215 | dev_doc_ids.extend(batch.doc_ids) 216 | dev_documents.extend(batch.document) 217 | progress.close() 218 | 219 | dev_result = { 220 | 'pred_outputs': dev_pred_outputs, 221 | 'gold_outputs': dev_gold_outputs, 222 | 'input_tokens': dev_input_tokens, 223 | 'doc_ids': dev_doc_ids, 224 | 'documents': dev_documents 225 | } 226 | with open( dev_result_file + f'_{epoch}','w') as f: 227 | f.write(json.dumps(dev_result)) 228 | 229 | # TODO: call the official evaluator 230 | 231 | if config.task == ROLE_FILLER_ENTITY_EXTRACTION: 232 | ree_preds = construct_outputs_for_ceaf(dev_pred_outputs, dev_input_tokens, dev_doc_ids, tokenizer) 233 | dev_scores = ree_eval.ree_eval(ree_preds, grit_dev) 234 | elif config.task == BINARY_RELATION_EXTRACTION: 235 | bre_preds = construct_outputs_for_scirex(dev_pred_outputs, dev_documents, dev_doc_ids, tokenizer, task=BINARY_RELATION_EXTRACTION) 236 | dev_scores = scirex_eval.scirex_eval(bre_preds, scirex_dev, cardinality=2) 237 | elif config.task == FOUR_ARY_RELATION_EXTRACTION: 238 | bre_preds = construct_outputs_for_scirex(dev_pred_outputs, dev_documents, dev_doc_ids, tokenizer, task=FOUR_ARY_RELATION_EXTRACTION) 239 | dev_scores = scirex_eval.scirex_eval(bre_preds, scirex_dev, cardinality=4) 240 | else: 241 | raise NotImplementedError 242 | save_model = False 243 | 244 | 245 | if config.task == ROLE_FILLER_ENTITY_EXTRACTION: 246 | current_dev_score = dev_scores['micro_avg']['f1'] 247 | save_model = current_dev_score > best_dev 248 | elif config.task in {BINARY_RELATION_EXTRACTION, FOUR_ARY_RELATION_EXTRACTION}: 249 | current_dev_score = dev_scores['f1'] 250 | save_model = 
current_dev_score > best_dev 251 | if save_model: 252 | best_dev = current_dev_score 253 | best_epoch = epoch 254 | print('Saving best model') 255 | torch.save(state, best_model) 256 | 257 | 258 | if save_model: 259 | # test set 260 | progress = tqdm.tqdm(total=test_batch_num, ncols=75, 261 | desc='Test {}'.format(epoch)) 262 | test_gold_outputs, test_pred_outputs, test_input_tokens, test_doc_ids, test_documents = [], [], [], [], [] 263 | test_loss = 0 264 | 265 | for batch in DataLoader(test_set, batch_size=config.eval_batch_size, shuffle=False, 266 | collate_fn=test_set.collate_fn): 267 | progress.update(1) 268 | outputs = model.predict(batch, tokenizer, epoch=epoch) 269 | decoder_inputs_outputs = generate_decoder_inputs_outputs(batch, tokenizer, model, use_gpu, config.max_position_embeddings, task=config.task) 270 | 271 | test_pred_outputs.extend(outputs['decoded_ids'].tolist()) 272 | test_gold_outputs.extend(decoder_inputs_outputs['decoder_labels'].tolist()) 273 | test_input_tokens.extend(batch.input_tokens) 274 | test_doc_ids.extend(batch.doc_ids) 275 | test_documents.extend(batch.document) 276 | progress.close() 277 | 278 | 279 | 280 | if config.task == ROLE_FILLER_ENTITY_EXTRACTION: 281 | ree_preds = construct_outputs_for_ceaf(test_pred_outputs, test_input_tokens, test_doc_ids, tokenizer) 282 | test_scores = ree_eval.ree_eval(ree_preds, grit_test) 283 | elif config.task == BINARY_RELATION_EXTRACTION: 284 | bre_preds = construct_outputs_for_scirex(test_pred_outputs, test_documents, test_doc_ids, tokenizer, task=BINARY_RELATION_EXTRACTION) 285 | test_scores = scirex_eval.scirex_eval(bre_preds, scirex_test, cardinality=2) 286 | elif config.task == FOUR_ARY_RELATION_EXTRACTION: 287 | bre_preds = construct_outputs_for_scirex(test_pred_outputs, test_documents, test_doc_ids, tokenizer, task=FOUR_ARY_RELATION_EXTRACTION) 288 | test_scores = scirex_eval.scirex_eval(bre_preds, scirex_test, cardinality=4) 289 | else: 290 | raise NotImplementedError 291 | 292 | 
test_result = { 293 | 'pred_outputs': test_pred_outputs, 294 | 'gold_outputs': test_gold_outputs, 295 | 'input_tokens': test_input_tokens, 296 | 'doc_ids': test_doc_ids, 297 | 'documents': test_documents 298 | } 299 | with open( test_result_file + f'_{epoch}','w') as f: 300 | f.write(json.dumps(test_result)) 301 | 302 | result = json.dumps( 303 | {'epoch': epoch, 'dev': dev_scores, 'test': test_scores}) 304 | with open(log_file, 'a', encoding='utf-8') as w: 305 | w.write(result + '\n') 306 | print('Log file', log_file) 307 | if config.task == ROLE_FILLER_ENTITY_EXTRACTION: 308 | get_best_score(log_file, 'micro_avg') 309 | elif config.task in {BINARY_RELATION_EXTRACTION, FOUR_ARY_RELATION_EXTRACTION}: 310 | get_best_score_bre(log_file) 311 | print(config.to_dict()) -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | from collections import OrderedDict 4 | import json 5 | from tabulate import tabulate 6 | from typing import Dict, List, Tuple 7 | import re 8 | 9 | from constants import * 10 | 11 | 12 | 13 | 14 | def token2sub_tokens(tokenizer, token): 15 | """ 16 | Take in a string value and use tokenizer to tokenize it into subtokens. 17 | Return a list of sub tokens. 
18 | """ 19 | res = [] 20 | for sub_token in tokenizer.tokenize(token): 21 | # make sure it's not an empty string 22 | if len(sub_token) > 0: 23 | res.append(tokenizer.convert_tokens_to_ids(sub_token)) 24 | return res 25 | 26 | def format_inputs_outputs(flattened_seqs, tokenizer, use_gpu, max_position_embeddings): 27 | 28 | max_seq_len = max([len(seq) for seq in flattened_seqs]) 29 | 30 | # cannot exceed the number of position embeddings 31 | max_seq_len = min(max_position_embeddings, max_seq_len) 32 | 33 | # create padding & mask 34 | decoder_input_ids = [] 35 | decoder_masks = [] 36 | decoder_labels = [] 37 | 38 | 39 | for flattened_seq in flattened_seqs: 40 | 41 | # minus 1 because mask should match the length of input_ids 42 | mask = [1] * len(flattened_seq) + [0] * (max_seq_len - len(flattened_seq)-1) 43 | 44 | # padding. 45 | flattened_seq += [tokenizer.pad_token_id] * (max_seq_len - len(flattened_seq)) 46 | 47 | 48 | # make sure they do not exceed max_seq_len - 1 (mask) and max_seq_len (sequence) 49 | mask = mask[:max_seq_len-1] 50 | flattened_seq = flattened_seq[:max_seq_len] 51 | 52 | input_ids = flattened_seq[:-1] 53 | labels = flattened_seq[1:] 54 | 55 | # Hugging Face models ignore labels equal to -100 when computing the loss, so padding positions are set to -100. 
56 | labels = [l if l != tokenizer.pad_token_id else -100 for l in labels] 57 | 58 | decoder_input_ids.append(input_ids) 59 | decoder_labels.append(labels) 60 | decoder_masks.append(mask) 61 | 62 | # form tensor 63 | if use_gpu: 64 | decoder_input_ids = torch.cuda.LongTensor(decoder_input_ids) 65 | decoder_labels = torch.cuda.LongTensor(decoder_labels) 66 | decoder_masks = torch.cuda.FloatTensor(decoder_masks) 67 | 68 | else: 69 | decoder_input_ids = torch.LongTensor(decoder_input_ids) 70 | decoder_labels = torch.LongTensor(decoder_labels) 71 | decoder_masks = torch.FloatTensor(decoder_masks) 72 | 73 | 74 | res = { 75 | 'decoder_input_ids': decoder_input_ids, 76 | 'decoder_labels': decoder_labels, 77 | 'decoder_masks': decoder_masks 78 | } 79 | return res 80 | 81 | 82 | 83 | def generate_decoder_inputs_outputs(batch, tokenizer, model, use_gpu, max_position_embeddings, permute_slots=False, task=ROLE_FILLER_ENTITY_EXTRACTION): 84 | ''' 85 | Process decoder_input_chunks and produce a dictionary with keys decoder_input_ids and decoder_labels. 86 | decoder_input_chunks is a list where each element correspond to annotation of a document. 87 | ''' 88 | decoder_input_chunks = batch.decoder_input_chunks 89 | 90 | flattened_seqs = [] 91 | 92 | for decoder_input_chunk in decoder_input_chunks: 93 | ''' 94 | decoder_input_chunk: [[[template_1_entity_1],[template_1_entity_2], ..., ],[, [template_2_entit_1],[template_2_entity_2]] ] 95 | ''' 96 | 97 | flatten_entities = [] 98 | # shuffle templates 99 | 100 | for template in decoder_input_chunk: 101 | # shuffle the slots in each template. 
102 | if permute_slots: 103 | template = template.copy() 104 | random.shuffle(template) 105 | 106 | # if BRE, we need to determine which mention to take beforehand 107 | if task in {BINARY_RELATION_EXTRACTION, FOUR_ARY_RELATION_EXTRACTION}: 108 | # assuming each entity has a different first mention, we use it as the key of a map that determines 109 | # which mention to sample 110 | 111 | first_mention2mention_idx :Dict[Tuple, int] = {} 112 | for entity in template: 113 | first_mention2mention_idx[tuple(entity[0])] = random.randint(0, len(entity)-1) # randint includes the boundaries on both sides. 114 | 115 | flatten_entities.append(tokenizer.convert_tokens_to_ids(START_OF_TEMPLATE)) 116 | for entity in template: 117 | if task == ROLE_FILLER_ENTITY_EXTRACTION: 118 | mention_chunk = random.choice(entity) 119 | elif task in {BINARY_RELATION_EXTRACTION, FOUR_ARY_RELATION_EXTRACTION}: 120 | mention_idx = first_mention2mention_idx[tuple(entity[0])] 121 | mention_chunk = entity[mention_idx] 122 | else: 123 | raise NotImplementedError 124 | 125 | for sub_token in mention_chunk: 126 | flatten_entities.append(sub_token) 127 | # 128 | flatten_entities.append(tokenizer.convert_tokens_to_ids(END_OF_TEMPLATE)) 129 | ''' 130 | flattened_seq should look like [tokenizer.eos_token_id, tokenizer.bos_token_id, , , slot, name, , , entity, ,, ..., , tokenizer.eos_token_id] 131 | ''' 132 | if model.bert.config.name_or_path.startswith('facebook/bart') or model.bert.config.name_or_path.startswith('sshleifer/distilbart'): 133 | flattened_seq = [model.bert.config.decoder_start_token_id, tokenizer.bos_token_id] + flatten_entities + [tokenizer.eos_token_id] 134 | elif model.bert.config.name_or_path.startswith('t5') or model.bert.config.name_or_path.startswith('google/pegasus') : 135 | # t5 does not have in the decoded string 136 | flattened_seq = [model.bert.config.decoder_start_token_id] + flatten_entities + [tokenizer.eos_token_id] 137 | elif 
model.bert.config.decoder._name_or_path.startswith('roberta'): 138 | flattened_seq = [model.bert.config.decoder_start_token_id] + flatten_entities + [tokenizer.eos_token_id] 139 | else: 140 | print("model name ", model.bert.config) 141 | raise NotImplementedError 142 | 143 | 144 | flattened_seqs.append(flattened_seq) 145 | 146 | 147 | res = format_inputs_outputs(flattened_seqs, tokenizer, use_gpu, max_position_embeddings) 148 | 149 | return res 150 | 151 | def construct_outputs_for_scirex(preds, input_documents, doc_ids, tokenizer, task): 152 | res = dict() 153 | 154 | if task == BINARY_RELATION_EXTRACTION: 155 | cardinality = 2 156 | elif task == FOUR_ARY_RELATION_EXTRACTION: 157 | cardinality = 4 158 | else: 159 | raise NotImplementedError 160 | 161 | for predicted_id_sequence, input_document, doc_id in zip(preds, input_documents, doc_ids): 162 | # convert id to tokens 163 | predicted_sequence = tokenizer.decode(predicted_id_sequence) 164 | res[doc_id] = extract_relations_from_sequence(predicted_sequence, input_document, cardinality) 165 | 166 | return res 167 | 168 | 169 | def extract_relations_from_sequence(predicted_sequence: str, input_document: str, cardinality: int = 2): 170 | 171 | predicted_relations : List[Dict[str, str]] = [] 172 | 173 | # remove the first 174 | predicted_sequence = predicted_sequence[4:] 175 | 176 | # we should not decode beyond the second 177 | try: 178 | first_eos_index = predicted_sequence.index('') 179 | predicted_sequence = predicted_sequence[:first_eos_index] 180 | except: 181 | pass 182 | 183 | predicted_relation_sequences = predicted_sequence.replace('','').replace('','').split('') 184 | 185 | for seq in predicted_relation_sequences: 186 | 187 | entity_types = re.findall('(.*?(?=))',seq) 188 | entity_types = [et.strip() for et in entity_types] 189 | entity_names = [entity_name.strip() for entity_name in re.findall('(.*?(?=))',seq)] 190 | entity_names = [en.strip() for en in entity_names] 191 | 192 | if len(entity_types) == 
len(entity_names) == cardinality and \ 193 | dict(zip(entity_types, entity_names)) not in predicted_relations: 194 | predicted_relations.append(dict(zip(entity_types, entity_names))) 195 | 196 | 197 | return predicted_relations 198 | 199 | def construct_outputs_for_ceaf(preds, input_documents, doc_ids, tokenizer): 200 | ''' 201 | 202 | input_documents: a list of decoded documents (str) 203 | 204 | ''' 205 | res = OrderedDict() 206 | for predicted_id_sequence, input_document, doc_id in zip(preds, input_documents, doc_ids): 207 | 208 | # decode the predicted ids into a string 209 | predicted_sequence = tokenizer.decode(predicted_id_sequence) 210 | 211 | # GRIT normalizes doc ids this way (reason unknown); mirror it so predictions and golds share keys 212 | doc_id = str(int(doc_id.split("-")[0][-1])*10000 + int(doc_id.split("-")[-1])) 213 | 214 | # transform into doc 215 | res[doc_id] = event_templates_to_ceaf(predicted_sequence, input_document) 216 | 217 | return res 218 | 219 | 220 | def event_templates_to_ceaf(event_template_sequence: str, input_document: str): 221 | ''' 222 | Turns a sequence of event templates into a dictionary 223 | e.g. 
224 | PerpIndsalvadoran rightist sectorsPerpIndsoldiersVictimhector oqueli colindresVictimhilda flores 225 | -> { 226 | 'PerpInd':[ 227 | [ 228 | ["salvadoran rightist sectors"], 229 | 230 | ], 231 | [ 232 | ["soldiers"] 233 | ] 234 | ], 235 | 'Victim':[ 236 | [ 237 | ['hector oqueli colindres'], 238 | ], 239 | [ 240 | ['hilda flores'] 241 | ] 242 | ] 243 | 244 | } 245 | ''' 246 | 247 | # remove the first 248 | event_template_sequence = event_template_sequence[4:] 249 | 250 | # we should not decode beyond the second 251 | try: 252 | first_eos_index = event_template_sequence.index('') 253 | event_template_sequence = event_template_sequence[:first_eos_index] 254 | except ValueError: 255 | pass 256 | res = { 257 | 'PerpInd':[], 258 | 'PerpOrg':[], 259 | 'Target':[], 260 | 'Victim':[], 261 | 'Weapon':[] 262 | } 263 | prev_slot_name = None 264 | prev_tag = None # this is for determining whether a mention is in the same entity cluster as the previous mention 265 | try: 266 | while event_template_sequence: 267 | 268 | # if encountered these, skip 269 | if event_template_sequence.startswith(START_OF_TEMPLATE): 270 | event_template_sequence = event_template_sequence[len(START_OF_TEMPLATE):] 271 | continue 272 | elif event_template_sequence.startswith(''): 273 | event_template_sequence = event_template_sequence[len(''):] 274 | continue 275 | 276 | elif event_template_sequence.startswith(START_OF_SLOT_NAME): 277 | if END_OF_SLOT_NAME in event_template_sequence: 278 | end_of_slot_name_index = event_template_sequence.index(END_OF_SLOT_NAME) 279 | current_slot_name = event_template_sequence[len(START_OF_SLOT_NAME):end_of_slot_name_index] 280 | slot_name_length = len(current_slot_name) 281 | 282 | event_template_sequence = event_template_sequence[len(START_OF_SLOT_NAME)+len(END_OF_SLOT_NAME)+slot_name_length:] 283 | 284 | current_slot_name = current_slot_name.strip() 285 | # if the current slot name is not valid, set it to None 286 | if current_slot_name not in res.keys(): 287 | 
current_slot_name = None 288 | continue 289 | 290 | prev_tag = SLOT_NAME_TAG 291 | new_slot_name_set = True 292 | 293 | 294 | 295 | else: 296 | # the slot-name tag is never closed, so the sequence is malformed; stop decoding 297 | break 298 | 299 | elif event_template_sequence.startswith(START_OF_ENTITY): 300 | if END_OF_ENTITY in event_template_sequence: 301 | end_of_entity_index = event_template_sequence.index(END_OF_ENTITY) 302 | mention = event_template_sequence[len(START_OF_ENTITY): end_of_entity_index].strip() 303 | # skip past the closing tag; slicing by its index stays correct even though mention was stripped 304 | event_template_sequence = event_template_sequence[end_of_entity_index + len(END_OF_ENTITY):] 305 | 306 | else: 307 | # grab whatever we have left in the sequence and append it to the current result. 308 | mention = event_template_sequence[len(START_OF_ENTITY): ] 309 | event_template_sequence = '' 310 | 311 | # the extracted mention string must be part of the input document for the role-filler entity extraction task 312 | if mention in input_document: 313 | 314 | # if the previous tag is an entity, the current mention and the previous mention belong to the same entity cluster 315 | if prev_tag == ENTITY_TAG: 316 | # append the current mention to the last entity cluster 317 | res[current_slot_name][-1].append(mention) 318 | else: 319 | # append a new cluster 320 | res[current_slot_name].append([mention]) 321 | 322 | 323 | 324 | prev_tag = ENTITY_TAG 325 | 326 | else: 327 | # if nothing matches, drop one character and move forward 328 | event_template_sequence = event_template_sequence[1:] 329 | 330 | except Exception as e: 331 | 332 | print(event_template_sequence) 333 | 334 | return res 335 | 336 | 337 | def read_grit_gold_file(file: str): 338 | golds = OrderedDict() 339 | with open(file, encoding="utf-8") as f: 340 | for line in f: 341 | line = json.loads(line) 342 | docid = str(int(line["docid"].split("-")[0][-1])*10000 + int(line["docid"].split("-")[-1])) 343 | 344 | extracts_raw = 
line["extracts"] 345 | 346 | extracts = OrderedDict() 347 | for role, entitys_raw in extracts_raw.items(): 348 | extracts[role] = [] 349 | for entity_raw in entitys_raw: 350 | entity = [] 351 | for mention_offset_pair in entity_raw: 352 | entity.append(mention_offset_pair[0]) 353 | if entity: 354 | extracts[role].append(entity) 355 | golds[docid] = extracts 356 | return golds 357 | 358 | def read_scirex_gold_file(file: str) : 359 | return [json.loads(line) for line in open(file)] 360 | 361 | def construct_table(result): 362 | def format_string(score): 363 | return f'{score*100:.2f}' 364 | 365 | table = [["role", "prec", "rec",'f1']] 366 | for key, values in result.items(): 367 | table.append( [key, format_string(values['p']), format_string(values['r']), format_string(values['f1']) ]) 368 | 369 | return tabulate(table, headers="firstrow", tablefmt="grid") 370 | 371 | def get_best_score(log_file: str, role: str): 372 | 373 | with open(log_file, 'r', encoding='utf-8') as r: 374 | config = r.readline() 375 | 376 | best_scores = [] 377 | best_dev_score = 0 378 | for line in r: 379 | record = json.loads(line) 380 | dev = record['dev'] 381 | test = record['test'] 382 | epoch = record['epoch'] 383 | 384 | if dev[role]['f1'] > best_dev_score: 385 | best_dev_score = dev[role]['f1'] 386 | best_scores = [dev, test, epoch] 387 | 388 | print('Best Epoch: {}'.format(best_scores[-1])) 389 | 390 | best_dev, best_test, epoch = best_scores 391 | print("Dev") 392 | print(construct_table(best_dev)) 393 | print("Test") 394 | print(construct_table(best_test)) 395 | 396 | def get_best_score_bre(log_file: str): 397 | 398 | with open(log_file, 'r', encoding='utf-8') as r: 399 | 400 | best_scores = [] 401 | best_dev_score = 0 402 | for line in r: 403 | record = json.loads(line) 404 | dev = record['dev'] 405 | test = record['test'] 406 | epoch = record['epoch'] 407 | 408 | if dev['f1'] > best_dev_score: 409 | best_dev_score = dev['f1'] 410 | best_scores = [dev, test, epoch] 411 | 412 | 
print('Best Epoch: {}'.format(best_scores[-1])) 413 | 414 | best_dev, best_test, epoch = best_scores 415 | print("Dev") 416 | print(best_dev) 417 | print("Test") 418 | print(best_test) --------------------------------------------------------------------------------
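The padding and label-shifting step in `format_inputs_outputs` (util.py) can be hard to follow from the flattened listing. Below is a minimal, dependency-free sketch of the same idea on plain Python lists. The name `shift_and_mask`, the constant `IGNORE_INDEX`, and the token ids in the example are hypothetical and are not part of the repository. Unlike the original, this sketch copies each sequence rather than padding it in place, and it derives the mask from the pad id, which matches the original only when the pad id never occurs inside a real sequence.

```python
from typing import Dict, List

# Label value that PyTorch's cross-entropy loss skips by default.
IGNORE_INDEX = -100


def shift_and_mask(seqs: List[List[int]], pad_id: int, max_len: int) -> Dict[str, List[List[int]]]:
    """Pad or truncate each sequence, then split it into decoder inputs and labels."""
    # The batch length is the longest sequence, capped at max_len.
    target_len = min(max_len, max(len(s) for s in seqs))
    input_ids, labels, masks = [], [], []
    for seq in seqs:
        # Copy the sequence (hypothetical choice), pad it, then truncate it.
        padded = (list(seq) + [pad_id] * (target_len - len(seq)))[:target_len]
        inputs = padded[:-1]   # decoder input: everything except the last token
        targets = padded[1:]   # labels: everything except the first token
        input_ids.append(inputs)
        # Padding positions must not contribute to the loss.
        labels.append([t if t != pad_id else IGNORE_INDEX for t in targets])
        masks.append([1 if t != pad_id else 0 for t in inputs])
    return {
        "decoder_input_ids": input_ids,
        "decoder_labels": labels,
        "decoder_masks": masks,
    }


# Two toy sequences with made-up ids; 1 plays the role of the pad id.
example = shift_and_mask([[2, 0, 11, 12, 2], [2, 0, 13, 2]], pad_id=1, max_len=16)
```

In the repository, the resulting lists are additionally converted to tensors and moved to the GPU when `use_gpu` is set; the sketch leaves that step out so it runs without PyTorch.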