├── LICENSE ├── README.md ├── convert_dyiepp_to_sentence.py ├── convert_et_result.py ├── data └── raw_data │ └── dyiepp_ace2005 │ ├── event.schema │ ├── test.json │ ├── train.json │ └── val.json ├── data_convert ├── __init__.py ├── convert_text_to_target.py ├── format │ ├── __init__.py │ └── text2target.py ├── task_format │ ├── __init__.py │ └── event_extraction.py └── utils.py ├── evaluation.py ├── extraction ├── __init__.py ├── event_schema.py ├── extract_constraint.py ├── extraction_metrics.py ├── label_tree.py └── predict_parser │ ├── __init__.py │ ├── predict_parser.py │ └── target_predict_parser.py ├── requirements.txt ├── run_arg_predict.bash ├── run_seq2seq.py ├── run_seq2seq_span.bash ├── run_tri_predict.bash └── seq2seq ├── __init__.py ├── constrained_seq2seq.py ├── label_smoother_sum.py ├── sentence_splitter.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # GDAP 2 | Code for ***[Generating Disentangled Arguments with Prompts: A Simple Event Extraction Framework that Works](https://arxiv.org/abs/2110.04525)*** 3 | 4 | ## Environment 5 | 6 | - Python (verified: v3.8) 7 | - CUDA (verified: v11.1) 8 | - Packages (see [requirements.txt](./requirements.txt)) 9 | 10 | ## Usage 11 | 12 | ### Preprocessing 13 | We follow [dygiepp](https://github.com/dwadden/dygiepp) for data preprocessing. 14 | 15 | - `text2et`: Event Type Detection 16 | - `ettext2tri`: Trigger Extraction 17 | - `etrttext2role`: Argument Extraction 18 | 19 | ```bash 20 | # data processed by dyiepp 21 | data/raw_data/dyiepp_ace2005 22 | ├── event.schema 23 | ├── test.json 24 | ├── train.json 25 | └── val.json 26 | 27 | # data processed by data_convert.convert_text_to_target 28 | data/text2target/dyiepp_ace2005_ettext2tri_subtype 29 | ├── event.schema 30 | ├── test.json 31 | ├── train.json 32 | └── val.json 33 | ``` 34 | Useful commands: 35 | 36 | ```bash 37 | python -m data_convert.convert_text_to_target # data/raw_data -> data/text2target 38 | python convert_dyiepp_to_sentence.py data/raw_data/dyiepp_ace2005 # doc -> sentence, used in evaluation 39 | ``` 40 | 41 | ### Training 42 | Relevant scripts: 43 | 44 | - `run_seq2seq.py`: Python entry point, modified from transformers/examples/seq2seq/run_seq2seq.py 45 | - `run_seq2seq_span.bash`: Model training script; output is written to a log file. 46 | 47 | Example (see the above two files for more details): 48 | 49 | ```bash 50 | # ace05 event type detection with t5-base; metric_format is eval_trigger-F1 51 | bash run_seq2seq_span.bash --data=dyiepp_ace2005_text2et_subtype --model=t5-base --format=et --metric_format=eval_trigger-F1 52 | 53 | # ace05 trigger extraction with t5-base 54 | bash run_seq2seq_span.bash --data=dyiepp_ace2005_ettext2tri_subtype --model=t5-base --format=tri --metric_format=eval_trigger-F1 55 | 56 | # ace05 argument extraction with t5-base 57 | bash run_seq2seq_span.bash --data=dyiepp_ace2005_etrttext2role_subtype --model=t5-base --format=role --metric_format=eval_role-F1 58 | 59 | ``` 60 | 61 | Trained models are saved in the `models/` folder. 62 | 63 | Event type detection uses the same output format and metric_format as trigger extraction, so the event type detection results are included in the eval_trigger-* and test_trigger-* entries of the log. 64 | 65 | ### Evaluation 66 | - `run_tri_predict.bash`: trigger extraction evaluation and inference script. 67 | - `run_arg_predict.bash`: argument extraction evaluation and inference script. 68 | 69 | ## If you find this repo helpful...
70 | Please give us a :star: and cite our paper as 71 | ```bibtex 72 | @inproceedings{si2021-GDAP, 73 | title={Generating Disentangled Arguments with Prompts: A Simple Event Extraction Framework that Works}, 74 | author={Jinghui Si and Xutan Peng and Chen Li and Haotian Xu and Jianxin Li}, 75 | booktitle={IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 76 | year={2022} 77 | } 78 | ``` 79 | 80 | > This project borrows code from [Text2Event](https://github.com/luyaojie/text2event) 81 | -------------------------------------------------------------------------------- /convert_dyiepp_to_sentence.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | import sys 4 | from os import path 5 | import json 6 | import collections 7 | 8 | 9 | def main(): 10 | 11 | output_dir = sys.argv[1] 12 | for fold in ["train", "dev", "test"]: 13 | g_convert = open(path.join(output_dir, fold + "_convert.json"), "w") 14 | with open(path.join(output_dir, fold + ".json"), "r") as g: 15 | print('convert %s to %s' % ( 16 | path.join(output_dir, fold + ".json"), 17 | path.join(output_dir, fold + "_convert.json") 18 | )) 19 | for line in g: 20 | line = json.loads(line) 21 | sentences = line["sentences"] 22 | ner = line["ner"] 23 | relations = line["relations"] 24 | events = line["events"] 25 | sentence_start = line["_sentence_start"] 26 | doc_key = line["doc_key"] 27 | 28 | assert len(sentences) == len(ner) == len( 29 | relations) == len(events) == len(sentence_start) 30 | 31 | for sentence, ner, relation, event, s_start in zip(sentences, ner, relations, events, sentence_start): 32 | sentence_annotated = collections.OrderedDict() 33 | sentence_annotated["sentence"] = sentence 34 | sentence_annotated["s_start"] = s_start 35 | sentence_annotated["ner"] = ner 36 | sentence_annotated["relation"] = relation 37 | sentence_annotated["event"] = event 38 | 39 | g_convert.write(json.dumps( 40 | sentence_annotated, default=int) + "\n") 41 | 42 | 43 | if __name__ == "__main__": 44 | main() 45 | -------------------------------------------------------------------------------- /convert_et_result.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | import codecs 3 | import argparse 4 | import json 5 | 6 | from extraction.event_schema import EventSchema 7 | from extraction.predict_parser.target_predict_parser import ETPredictParser 8 | from data_convert.format.text2target import type_start, type_end 9 | 10 | 11 | parser = argparse.ArgumentParser(description='Convert et result') 12 | 13 | parser.add_argument('--et_pred_file', type=str) 14 | parser.add_argument('--et_text_file', type=str) # name of the text file that corresponds to the preds file 15 | parser.add_argument('--et_output_file', type=str) 16 | parser.add_argument('--schema_file', type=str) 17 | parser.add_argument('--mode', type=str, default="role") 18 | args = parser.parse_args() 19 | 20 | 21 | def read_file(file_name): 22 | return [line.strip() for line in open(file_name).readlines()] 23 | 24 | 25 | def et_text2role(schema, et_list, text): 26 | et_rt_dict = schema.type_role_dict 27 | source_list = [] 28 | target_list = [] 29 | # iterate over the predicted event types 30 | for event_type in et_list: 31 | for role_type in et_rt_dict[event_type]: 32 | source_text = event_type + " " + role_type + " " + text 33 | target_text = "" 34 | 35 | source_list.append(source_text) 36 | target_list.append(target_text) 37 | 38 | # uniformly wrap all targets with the start/end markers 39 | for i in
range(len(target_list)): 40 | target_list[i] = f'{type_start} ' + target_list[i] + f' {type_end}' 41 | 42 | return source_list, target_list 43 | 44 | def et_text2tri(et_list, text): 45 | source_list = [] 46 | target_list = [] 47 | # 遍历event 48 | for event_type in et_list: 49 | source_text = event_type + " " + text 50 | target_text = "" 51 | 52 | source_list.append(source_text) 53 | target_list.append(target_text) 54 | 55 | # 在所有的target 上统一加上起始位置 56 | for i in range(len(target_list)): 57 | target_list[i] = f'{type_start} ' + target_list[i] + f' {type_end}' 58 | 59 | return source_list, target_list 60 | 61 | 62 | if __name__ == "__main__": 63 | 64 | label_schema = EventSchema.read_from_file( 65 | filename=args.schema_file 66 | ) 67 | 68 | # 采用解析评估函数对结果文件进行解析 69 | pred_reader = ETPredictParser(schema=label_schema) 70 | event_list, _ = pred_reader.decode( 71 | gold_list=[], 72 | pred_list=read_file(args.et_pred_file), 73 | text_list=[json.loads(line)['text'] 74 | for line in read_file(args.et_text_file)] 75 | ) 76 | 77 | # 输出文件 78 | event_output = codecs.open(args.et_output_file, 'w', 'UTF-8') 79 | 80 | for item in event_list: 81 | text = item["text"] 82 | event_list = item["pred_event"] 83 | 84 | if args.mode == "role": 85 | source_list, target_list = et_text2role(schema=label_schema, et_list=event_list, text=text) 86 | else: # trigger 87 | source_list, target_list = et_text2tri(et_list=event_list, text=text) 88 | 89 | # 将处理后的信息写入文件 90 | assert len(source_list) == len(target_list) 91 | for i in range(len(source_list)): 92 | event_output.write(json.dumps( 93 | {'text': source_list[i], 'event': target_list[i]}, ensure_ascii=False) + '\n') 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | -------------------------------------------------------------------------------- /data/raw_data/dyiepp_ace2005/event.schema: -------------------------------------------------------------------------------- 1 | ["Divorce", "Appeal", "Sentence", "Charge-Indict", "Phone-Write", "Be-Born", "End-Org", "Arrest-Jail", "Declare-Bankruptcy", "Injure", "Fine", "Start-Position", "Demonstrate", "Trial-Hearing", "Merge-Org", "Meet", "Pardon", "Marry", "Sue", "End-Position", "Execute", "Nominate", "Elect", "Transport", "Transfer-Money", "Start-Org", "Acquit", "Extradite", "Die", "Convict", "Release-Parole", "Transfer-Ownership", "Attack"] 2 | ["Target", "Org", "Entity", "Vehicle", "Attacker", "Place", "Destination", "Artifact", "Agent", "Seller", "Person", "Adjudicator", "Plaintiff", "Beneficiary", "Victim", "Origin", "Instrument", "Giver", "Prosecutor", "Recipient", "Defendant", "Buyer"] 3 | {"Transport": ["Vehicle", "Place", "Destination", "Origin", "Artifact", "Agent"], "Transfer-Ownership": ["Beneficiary", "Place", "Artifact", "Buyer", "Seller"], "Sue": ["Plaintiff", "Place", "Defendant", "Adjudicator"], "Divorce": ["Person", "Place"], "Nominate": ["Person", "Agent"], "Attack": ["Target", "Attacker", "Victim", "Place", "Instrument"], "Convict": ["Place", "Defendant", "Adjudicator"], "Meet": ["Place", "Entity"], "Charge-Indict": ["Prosecutor", "Place", "Defendant", "Adjudicator"], "Acquit": ["Defendant", "Adjudicator"], "Elect": ["Person", "Place", "Entity"], "Be-Born": ["Person", "Place"], "Declare-Bankruptcy": ["Place", "Org"], "Appeal": ["Plaintiff", "Place", "Adjudicator"], "End-Position": ["Person", "Place", "Entity"], "Pardon": ["Place", "Defendant", "Adjudicator"], "Fine": ["Place", "Entity", "Adjudicator"], "Execute": ["Person", "Place", "Agent"], "Extradite": ["Origin", "Agent", "Destination"], "Start-Position": 
["Person", "Place", "Entity"], "Trial-Hearing": ["Prosecutor", "Place", "Defendant", "Adjudicator"], "Arrest-Jail": ["Person", "Place", "Agent"], "Release-Parole": ["Person", "Place", "Entity"], "Phone-Write": ["Place", "Entity"], "Die": ["Victim", "Place", "Agent", "Person", "Instrument"], "End-Org": ["Place", "Org"], "Marry": ["Person", "Place"], "Injure": ["Victim", "Place", "Instrument", "Agent"], "Sentence": ["Place", "Defendant", "Adjudicator"], "Demonstrate": ["Place", "Entity"], "Start-Org": ["Place", "Org", "Agent"], "Merge-Org": ["Org"], "Transfer-Money": ["Place", "Recipient", "Beneficiary", "Giver"]} 4 | -------------------------------------------------------------------------------- /data/raw_data/dyiepp_ace2005/test.json: -------------------------------------------------------------------------------- 1 | {"sentences": [["AFP_ENG_20030401.0476"], ["NEWS", "STORY"], ["20030401"], ["Energy", "regulator", "named", "new", "head", "of", "Britain", "'s", "FSA", "finance", "watchdog"], ["LONDON", ",", "April", "1", "(", "AFP", ")"], ["British", "Chancellor", "of", "the", "Exchequer", "Gordon", "Brown", "on", "Tuesday", "named", "the", "current", "head", "of", "the", "country", "'s", "energy", "regulator", "as", "the", "new", "chairman", "of", "finance", "watchdog", "the", "Financial", "Services", "Authority", "(", "FSA", ")", "."], ["Former", "senior", "banker", "Callum", "McCarthy", "begins", "what", "is", "one", "of", "the", "most", "important", "jobs", "in", "London", "'s", "financial", "world", "in", "September", ",", "when", "incumbent", "Howard", "Davies", "steps", "down", "."], ["Davies", "is", "leaving", "to", "become", "chairman", "of", "the", "London", "School", "of", "Economics", ",", "one", "of", "the", "best", "-", "known", "parts", "of", "the", "University", "of", "London", "."], ["Brown", "said", "McCarthy", "would", "bring", "to", "the", "FSA", "\"", "an", "unrivalled", "combination", "of", "experience", "of", "the", "public", "sector", ",", "experience", "as", "a", "senior", "practitioner", "in", "the", "financial", "services", "industry", ",", "and", "chairman", "of", "a", "highly", "successful", "regulator", ".", "\""], ["McCarthy", "currently", "heads", "the", "Office", "of", "Gas", "and", "Electricity", "Markets", ",", "or", "Ofgem", ",", "which", "regulates", "Britain", "'s", "privatised", "energy", "industry", "."], ["As", "well", "as", "previously", "holding", "senior", "positions", "at", "Barclays", "Bank", ",", "BZW", "and", "Kleinwort", "Benson", ",", "McCarthy", "was", "formerly", "a", "top", "civil", "servant", "at", "the", "Department", "of", "Trade", "and", "Industry", "."], ["Under", "Davies", "'s", "watch", ",", "the", "FSA", "became", "Britain", "'s", "sole", "regulatory", "body", "for", "financial", "services", "in", "November", "2001", ",", "taking", "on", "the", "various", "functions", "of", "nine", "separate", "agencies", "."]], "ner": [[], [], [], [], [[15, 15, "GPE"], [20, 20, "ORG"]], [[22, 22, "GPE"], [23, 23, "PER"], [26, 26, "ORG"], [27, 28, "PER"], [34, 34, "PER"], [37, 37, "GPE"], [40, 40, "ORG"], [44, 44, "PER"], [47, 47, "ORG"], [49, 51, "ORG"]], [[58, 58, "PER"], [59, 60, "PER"], [71, 71, "GPE"], [79, 79, "PER"], [80, 81, "PER"]], [[85, 85, "PER"], [90, 90, "PER"], [93, 96, "ORG"], [107, 109, "ORG"]], [[111, 111, "PER"], [113, 113, "PER"], [118, 118, "ORG"], [134, 134, "PER"], [139, 139, "ORG"], [142, 142, "PER"], [147, 147, "ORG"]], [[150, 150, "PER"], [154, 159, "ORG"], [162, 162, "ORG"], [166, 166, "GPE"], [170, 170, "ORG"]], 
[[180, 181, "ORG"], [183, 183, "ORG"], [185, 186, "ORG"], [188, 188, "PER"], [194, 194, "PER"], [197, 201, "ORG"]], [[204, 204, "PER"], [209, 209, "ORG"], [211, 211, "GPE"], [215, 215, "ORG"], [231, 231, "ORG"]]], "relations": [[], [], [], [], [], [[23, 23, 26, 26, "ORG-AFF.Employment"], [34, 34, 40, 40, "ORG-AFF.Employment"], [44, 44, 49, 51, "ORG-AFF.Membership"]], [], [[90, 90, 93, 96, "ORG-AFF.Employment"]], [[142, 142, 147, 147, "ORG-AFF.Membership"]], [[150, 150, 154, 159, "ORG-AFF.Employment"]], [[188, 188, 180, 181, "ORG-AFF.Employment"], [188, 188, 183, 183, "ORG-AFF.Employment"], [188, 188, 185, 186, "ORG-AFF.Employment"], [194, 194, 197, 201, "ORG-AFF.Employment"]], [[215, 215, 211, 211, "PART-WHOLE.Subsidiary"]]], "events": [[], [], [], [], [], [[[31, "Personnel.Nominate"], [34, 34, "Person"]]], [[[56, "Personnel.End-Position"], [59, 60, "Person"]], [[61, "Personnel.Start-Position"], [59, 60, "Person"]]], [[[87, "Personnel.End-Position"], [85, 85, "Person"]], [[89, "Personnel.Start-Position"], [85, 85, "Person"], [93, 96, "Entity"]]], [], [], [[[175, "Personnel.End-Position"], [180, 181, "Entity"], [183, 183, "Entity"], [185, 186, "Entity"], [188, 188, "Person"]], [[190, "Personnel.End-Position"], [188, 188, "Person"], [197, 201, "Entity"]]], []], "_sentence_start": [0, 1, 3, 4, 15, 22, 56, 85, 111, 150, 172, 203], "doc_key": "AFP_ENG_20030401.0476", "dataset": "ace-event"} -------------------------------------------------------------------------------- /data/raw_data/dyiepp_ace2005/train.json: -------------------------------------------------------------------------------- 1 | {"sentences": [["CNN_CF_20030303.1900.00"], ["STORY"], ["2003", "-", "03", "-", "03T19:00:00", "-", "05:00"], ["New", "Questions", "About", "Attacking", "Iraq", ";", "Is", "Torturing", "Terrorists", "Necessary", "?"], ["BEGALA", "Well", ",", "we", "'ll", "debate", "that", "later", "on", "in", "the", "show", "."], ["We", "'ll", "have", "a", "couple", "of", "experts", "come", "out", ",", "so", "I", "'ll", "withhold", "my", "comments", "until", "then", "."], ["Even", "as", "the", "secretary", "of", "homeland", "security", "was", "putting", "his", "people", "on", "high", "alert", "last", "month", ",", "a", "30-foot", "Cuban", "patrol", "boat", "with", "four", "heavily", "armed", "men", "landed", "on", "American", "shores", ",", "utterly", "undetected", "by", "the", "Coast", "Guard", "Secretary", "Ridge", "now", "leads", "."], ["Now", ",", "why", "has", "our", "president", "placed", "homeland", "security", "in", "the", "hands", "of", "Republican", "political", "hacks", "instead", "of", "professionals", ",", "by", "the", "way", "?", "Attorney", "General", "John", "Ashcroft", ",", "for", "example", ",", "is", "a", "career", "politician", "."], ["He", "lost", "an", "election", "to", "a", "dead", "man", "."], ["Secretary", "of", "Homeland", "Security", "Tom", "Ridge", "is", "another", "career", "politician", "who", "was", "passed", "over", "by", "Mr.", "Bush", "for", "the", "vice", "presidency", "."], ["And", "Deputy", "Secretary", "of", "Homeland", "Security", "Asa", "Hutchinson", "is", "yet", "another", "career", "politician", "and", "a", "graduate", "of", "the", "disgraceful", "Bob", "Jones", "University", "."], ["Apparently", ",", "Mr.", "Bush", "only", "turns", "to", "professionals", "when", "it", "'s", "really", "important", ",", "like", "political", "consulting", "."], ["NOVAK", "Paul", ",", "as", "I", "understand", "your", "definition", "of", "a", "political", "--", "of", "a", "professional", 
"politician", "based", "on", "that", "is", "somebody", "who", "is", "elected", "to", "public", "office", "."], ["Now", "in", "your", "administration", ",", "the", "Clinton", "administration", ",", "there", "were", "these", "members", "of", "the", "cabinet", "who", "by", "your", "definition", "were", "professional", "politicians", "--", "Lloyd", "Bentsen", ",", "Les", "Aspin", ",", "William", "S.", "Cohen", ",", "Janet", "Reno", ",", "Bruce", "Babbitt", ",", "Mike", "Espy", ",", "Dan", "Glickman", ",", "Norman", "Mineta", ",", "Henry", "Cisneros", ",", "Federico", "Pena", ",", "Bill", "Richardson", ",", "Richard", "Riley", ",", "12", "of", "them", ",", "not", "to", "mention", "former", "Democratic", "National", "Chairman", "Ron", "Brown", ",", "and", "one", "of", "the", "great", "professional", "politicians", "of", "all", "time", ",", "Bill", "Daly", "."], ["BEGALA", "And", "you", "know", "what", ",", "they", "did", "a", "hell", "of", "a", "job", "for", "our", "country", "."], ["And", "these", "bozos", "let", "four", "armed", "Cubans", "land", "on", "our", "shores", "when", "they", "'re", "trying", "to", "make", "a", "high", "terrorist", "alert", "."], ["Our", "president", "has", "put", "homeland", "security", "in", "the", "hands", "of", "failed", "Republican", "hacks", "."], ["Hire", "professionals", ",", "Mr.", "President", "."], ["NOVAK", "So", "it", "'s", "OK", "--", "it", "'s", "OK", "to", "have", "professional", "politicians", "at", "the", "Justice", "Department", "and", "the", "Pentagon", "..."], ["BEGALA", "Janet", "Reno", "was", "a", "career", "prosecutor", "."], ["NOVAK", "Just", "a", "minute", ",", "let", "me", "finish", "my", "sentence", ",", "please", "."], ["It", "'s", "OK", "to", "put", "Democratic", "career", "politicians", "at", "the", "Pentagon", "and", "the", "Justice", "Department", "if", "they", "'re", "Democrats", "but", "not", "if", "they", "'re", "Republicans", ",", "is", "that", "right", "?"], ["BEGALA", "No", ",", "the", "difference", "is", "Janet", "Reno", "was", "a", "career", "prosecutor", "."], ["(", "CROSSTALK", ")"], ["BEGALA", "John", "Ashcroft", "was", "n't", "half", "the", "woman", "that", "Janet", "Reno", "was", "."], ["NOVAK", "Another", "potential", "Democratic", "presidential", "candidate", "pondered", "whether", "to", "run", ",", "and", "he", "made", "his", "announcement", ",", "and", "surprise", ",", "surprise", ",", "the", "answer", "was", "no", "."], ["Senator", "Christopher", "Dodd", "of", "Connecticut", "made", "the", "announcement", "today", "that", "he", "would", "not", "be", "the", "10th", "candidate", "for", "the", "nomination", "."], ["Why", "not", "?", "Surely", "Chris", "Dodd", "is", "at", "least", "more", "credible", "than", "Carol", "Moseley", "-", "Braun", "or", "Dennis", "Kucinich", "."], ["He", "explained", "he", "could", "better", "spend", "the", "next", "two", "years", "on", "homeland", "security", ",", "the", "economy", "and", "judicial", "nominations", "."], ["I", "guess", "that", "means", "harassing", "Tom", "Ridge", ",", "fighting", "tax", "cuts", "and", "obstructing", "President", "Bush", "'s", "plans", "to", "reform", "judiciary", "."], ["Some", "two", "years", "."], ["BEGALA", "God", ",", "I", "hope", "so", "."], ["that", "'s", "all", "I", "can", "say", "."], ["I", "love", "Senator", "Dodd", "."], ["He", "would", "have", "brought", "a", "lot", "to", "the", "race", "."], ["He", "brings", "a", "lot", "to", "the", "Senate", "and", "to", "the", "debate", "."], ["And", "I", "'m", "glad", "that", "he", "'s", "going", "to", "be", 
"fighting", "those", "fights", "."], ["Well", ",", "the", "Commerce", "Department", "reports", "today", "that", "consumer", "spending", "declined", "in", "January", "and", "that", "the", "manufacturing", "sector", "slowed", "in", "February", "."], ["In", "all", ",", "two", "million", "Americans", "have", "lost", "their", "jobs", "under", "President", "Bush", "so", "far", ",", "not", "to", "mention", "three", "of", "them", "being", "the", "top", "three", "leaders", "of", "his", "economic", "team", "."], ["Meanwhile", ",", "the", "deficit", "now", "at", "$", "300", "billion", "."], ["It", "could", "swell", "to", "as", "much", "as", "$", "500", "billion", "if", "we", "go", "to", "war", "in", "Iraq", "."], ["Mr.", "Bush", "apparently", "is", "untroubled", "by", "this", "fiscal", "collapse", "."], ["One", "of", "his", "aides", "tells", "the", "current", "issue", "of", "\"", "TIME", ",", "\"", "magazine", ",", "quote", ",", "\"", "even", "if", "it", "'s", "$", "500", "billion", ",", "so", "what", "?", "\"", "Of", "course", ",", "even", "a", "$", "500", "billion", "deficit", "number", "does", "n't", "count", "the", "$", "3", "trillion", "Mr.", "Bush", "is", "robbing", "from", "the", "Social", "Security", "trust", "fund", "."], ["You", "know", ",", "we", "should", "have", "known", ",", "every", "time", "George", "W.", "Bush", "gets", "in", "trouble", ",", "he", "borrows", "from", "a", "trust", "fund", "."], ["It", "'s", "been", "his", "whole", "life", ",", "you", "know", "."], ["NOVAK", "Let", "me", "give", "you", "two", "economic", "facts", "of", "like", ",", "which", "you", "should", "know", ",", "even", "if", "you", "wo", "n't", "recognize", "them", "."], ["Number", "one", ",", "there", "is", "absolutely", "no", "relationship", "between", "the", "deficit", "and", "unemployment", "."], ["They", "do", "n't", "go", "together", "."], ["And", "number", "two", ",", "there", "is", "no", "Social", "Security", "fund", ",", "Virginia", ".", "There", "just", "is", "n't", "one", "."], ["BEGALA", "Because", "Bush", "squandered", "it", "on", "his", "tax", "cuts", "."], ["NOVAK", "There", "never", "has", "been", "one", "."], ["BEGALA", "There", "has", "been", "one", "for", "60", "years", "."]], "ner": [[], [], [], [], [[20, 20, "PER"]], [[39, 39, "PER"]], [[55, 55, "PER"], [57, 58, "ORG"], [62, 62, "PER"], [71, 71, "GPE"], [73, 73, "VEH"], [78, 78, "PER"], [81, 81, "GPE"], [82, 82, "LOC"], [88, 89, "ORG"], [90, 90, "PER"], [91, 91, "PER"]], [[100, 100, "PER"], [102, 102, "GPE"], [108, 108, "ORG"], [110, 110, "PER"], [113, 113, "PER"], [119, 119, "PER"], [121, 122, "PER"], [130, 130, "PER"]], [[139, 139, "PER"]], [[141, 141, "PER"], [143, 144, "ORG"], [145, 146, "PER"], [150, 150, "PER"], [156, 156, "PER"], [157, 157, "PER"]], [[165, 165, "PER"], [167, 168, "ORG"], [169, 170, "PER"], [175, 175, "PER"], [178, 178, "PER"], [182, 184, "ORG"]], [[188, 188, "PER"], [189, 189, "PER"], [193, 193, "PER"]], [[204, 204, "PER"], [205, 205, "PER"], [219, 219, "PER"]], [[235, 235, "ORG"], [238, 238, "PER"], [239, 239, "ORG"], [244, 244, "PER"], [247, 247, "PER"], [254, 254, "PER"], [256, 257, "PER"], [259, 260, "PER"], [262, 264, "PER"], [266, 267, "PER"], [269, 270, "PER"], [272, 273, "PER"], [275, 276, "PER"], [278, 279, "PER"], [281, 282, "PER"], [284, 285, "PER"], [287, 288, "PER"], [290, 291, "PER"], [301, 301, "ORG"], [303, 303, "PER"], [304, 305, "PER"], [318, 319, "PER"]], [[321, 321, "PER"], [336, 336, "GPE"]], [[340, 340, "PER"], [344, 344, "PER"], [348, 348, "LOC"], [357, 357, "PER"]], [[361, 361, "PER"], [364, 
364, "GPE"], [371, 371, "ORG"], [372, 372, "PER"]], [[375, 375, "PER"], [377, 377, "PER"], [378, 378, "PER"]], [[380, 380, "PER"], [392, 392, "PER"], [395, 396, "ORG"], [399, 399, "ORG"]], [[401, 401, "PER"], [402, 403, "PER"], [407, 407, "PER"]], [[409, 409, "PER"]], [[427, 427, "ORG"], [429, 429, "PER"], [432, 432, "ORG"], [435, 436, "ORG"], [440, 440, "PER"], [446, 446, "PER"]], [[452, 452, "PER"], [458, 459, "PER"], [463, 463, "PER"]], [], [[468, 468, "PER"], [469, 470, "PER"], [475, 475, "PER"], [477, 478, "PER"]], [[481, 481, "PER"], [484, 484, "ORG"], [486, 486, "PER"]], [[508, 508, "PER"], [509, 510, "PER"], [512, 512, "GPE"], [524, 524, "PER"]], [[533, 534, "PER"], [541, 544, "PER"], [546, 547, "PER"]], [[560, 560, "GPE"]], [[574, 575, "PER"], [582, 582, "PER"], [583, 583, "PER"]], [], [[594, 594, "PER"]], [], [[610, 610, "PER"], [611, 611, "PER"]], [], [[629, 629, "ORG"]], [], [[652, 653, "ORG"], [657, 657, "PER"], [666, 666, "ORG"]], [[676, 676, "PER"], [682, 682, "PER"], [683, 683, "PER"], [697, 697, "PER"], [701, 701, "PER"]], [], [[729, 729, "GPE"]], [[731, 731, "PER"], [732, 732, "PER"]], [[744, 744, "PER"], [751, 754, "ORG"], [788, 788, "PER"], [789, 789, "PER"]], [[809, 811, "PER"]], [], [[833, 833, "PER"]], [], [], [[888, 888, "PER"]], [[896, 896, "PER"], [898, 898, "PER"]], [[906, 906, "PER"]], [[913, 913, "PER"]]], "relations": [[], [], [], [], [], [], [[55, 55, 57, 58, "ORG-AFF.Membership"], [71, 71, 73, 73, "ART.User-Owner-Inventor-Manufacturer"], [78, 78, 73, 73, "ART.User-Owner-Inventor-Manufacturer"], [82, 82, 81, 81, "PART-WHOLE.Geographical"]], [[110, 110, 108, 108, "ORG-AFF.Membership"]], [], [[141, 141, 143, 144, "ORG-AFF.Membership"]], [[165, 165, 167, 168, "ORG-AFF.Membership"], [178, 178, 182, 184, "ORG-AFF.Student-Alum"]], [], [], [[238, 238, 239, 239, "ORG-AFF.Membership"]], [], [[344, 344, 348, 348, "PHYS.Located"]], [[372, 372, 371, 371, "ORG-AFF.Membership"]], [], [[392, 392, 395, 396, "ORG-AFF.Employment"], [392, 392, 399, 399, "ORG-AFF.Employment"]], [], [], [[429, 429, 432, 432, "ORG-AFF.Employment"], [429, 429, 435, 436, "ORG-AFF.Employment"]], [], [], [], [[486, 486, 484, 484, "ORG-AFF.Membership"]], [[509, 510, 512, 512, "ORG-AFF.Employment"]], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], []], "events": [[], [], [], [], [], [], [[[79, "Movement.Transport"], [73, 73, "Vehicle"], [78, 78, "Artifact"], [82, 82, "Destination"]]], [], [[[135, "Personnel.Elect"], [139, 139, "Person"]]], [], [], [], [[[227, "Personnel.Elect"]]], [], [], [[[345, "Movement.Transport"], [340, 340, "Agent"], [344, 344, "Artifact"], [348, 348, "Destination"]]], [], [[[374, "Personnel.Start-Position"], [375, 375, "Person"]]], [], [], [], [[[426, "Personnel.Start-Position"], [429, 429, "Person"], [432, 432, "Entity"], [435, 436, "Entity"]]], [], [], [], [], [[[527, "Personnel.Nominate"], [524, 524, "Person"]]], [], [[[567, "Personnel.Nominate"]]], [], [], [], [], [], [], [], [], [], [], [], [[[727, "Conflict.Attack"], [729, 729, "Place"]]], [], [], [], [], [], [], [], [], [], [], []], "_sentence_start": [0, 1, 2, 9, 20, 33, 52, 95, 132, 141, 163, 186, 204, 232, 321, 338, 360, 374, 380, 401, 409, 422, 452, 465, 468, 481, 508, 529, 549, 569, 590, 594, 601, 608, 613, 623, 635, 649, 671, 703, 713, 731, 741, 799, 823, 833, 857, 871, 877, 896, 906, 913], "doc_key": "CNN_CF_20030303.1900.00", "dataset": "ace-event"} -------------------------------------------------------------------------------- /data/raw_data/dyiepp_ace2005/val.json: 
-------------------------------------------------------------------------------- 1 | {"sentences": [["CNN_CF_20030303.1900.02"], ["STORY"], ["2003", "-", "03", "-", "03T19:00:00", "-", "05:00"], ["New", "Questions", "About", "Attacking", "Iraq", ";", "Is", "Torturing", "Terrorists", "Necessary", "?"], ["NOVAK", "Welcome", "back", "."], ["Orders", "went", "out", "today", "to", "deploy", "17,000", "U.S.", "Army", "soldiers", "in", "the", "Persian", "Gulf", "region", "."], ["The", "army", "'s", "entire", "first", "Calvary", "division", "based", "at", "Fort", "Hood", ",", "Texas", ",", "would", "join", "the", "quarter", "million", "U.S.", "forces", "already", "in", "the", "region", "."], ["We", "'re", "talking", "about", "possibilities", "of", "full", "scale", "war", "with", "former", "Congressman", "Tom", "Andrews", ",", "Democrat", "of", "Maine", "."], ["He", "'s", "now", "national", "director", "of", "Win", "Without", "War", ",", "and", "former", "Congressman", "Bob", "Dornan", ",", "Republican", "of", "California", "."], ["BEGALA", "Bob", ",", "one", "of", "the", "reasons", "I", "think", "so", "many", "Americans", "are", "worried", "about", "this", "war", "and", "so", "many", "people", "around", "the", "world", "do", "n't", "want", "to", "go", "is", "there", "have", "been", "a", "lot", "of", "problems", "with", "credibility", "from", "this", "administration", "."], ["Our", "president", "has", "repeatedly", ",", "for", "example", ",", "relied", "on", "a", "man", "whom", "you", "'re", "aware", ",", "Hussein", "Kamel", ",", "Saddam", "Hussein", "'s", "son", "-", "in", "-", "law", ",", "leader", "of", "the", "Iraq", "arms", "program", "who", "defected", "for", "a", "time", "."], ["And", "gave", "us", "a", "whole", "lot", "of", "information", "and", "then", "went", "home", "and", "his", "father", "-", "in", "-", "law", "killed", "him", "."], ["Bad", "move", "."], ["But", "while", "he", "was", "here", ",", "he", "gave", "us", "a", "whole", "lot", "of", "information", "."], ["Gave", "us", "a", "whole", "lot", "of", "information", "."], ["Well", ",", "our", "president", "told", "us", "that", "information", "proves", "that", "the", "dictator", "had", "chemical", "weapons", ",", "which", "is", "true", "."], ["But", "what", "we", "just", "learned", "this", "week", "from", "\"", "Newsweek", "\"", "magazine", "which", "got", "a", "hold", "of", "the", "debriefings", ",", "is", "that", "he", "also", "told", "us", "it", "was", "destroyed", "back", "in", "1995", "."], ["Why", "has", "n't", "our", "president", "told", "us", "that", "?"], ["Why", "do", "we", "have", "to", "learn", "it", "from", "\"", "Newsweek", "\"", "?"], ["DORNAN", "I", "do", "n't", "believe", "that", "he", "believed", "it", "was", "all", "destroyed", "."], ["The", "fact", "that", "this", "guy", "was", "such", "an", "idiot", "to", "go", "back", "and", "let", "his", "father", "-", "in", "-", "law", "kill", "him", "shows", "he", "was", "n't", "the", "most", "stable", "of", "people", "."], ["But", "the", "things", "that", "..."], ["BEGALA", "Good", "point", "."], ["But", "should", "n't", "our", "president", "have", "told", "us", "what", "the", "CIA", "told", "him", "."], ["Why", "do", "we", "learn", "from", "\"", "Newsweek?\"Should", "he", "level", "with", "us", "?"], ["DORNAN", "Paul", ",", "look", ",", "the", "problem", "is", "I", "would", "stipulate", "all", "four", "of", "us", "hates", "war", ".", "Any", "rational", "person", "hates", "war", "."], ["Bush", "is", "n't", "sitting", "in", "that", "White", "House", "not", "thinking", 
"about", "the", "body", "bags", "coming", "home", "with", "great", "young", "men", "."], ["Clinton", "suffered", "greatly", "over", "the", "19", "Rangers", "that", "died", ",", "18", "on", "the", "3rd", "of", "October", "and", "Matt", "Reersen", "(", "ph", ")", "three", "days", "later", "."], ["I", "visited", "all", "their", "families", "."], ["I", "was", "at", "the", "medal", "of", "honor", "ceremony", "for", "the", "kids", "."], ["Let", "me", "tell", "you", ",", "what", "trips", "to", "Walter", "Reed", "taught", "me", "was", ",", "that", "whoever", "thought", "up", "the", "term", ",", "the", "law", "of", "unintended", "consequences", "it", "pertains", "to", "war", "."], ["I", "am", "shook", "over", "the", "aftermath", "."], ["But", ",", "this", "guy", "is", "a", "monster", ",", "a", "mini", "-", "me", "Hitler", "."], ["He", "will", "blow", "a", "city", "off", "the", "earth", "in", "a", "minute", "if", "he", "can", "get", "the", "hold", "of", "the", "means", "to", "do", "it", "."], ["NOVAK", "Tom", "Andrews", ",", "I", "think", "we", "all", "realize", "that", "a", "government", "does", "n't", "go", "to", "war", "a", "nation", "goes", "to", "war", "."], ["And", "so", "I", "would", "like", "you", "to", "take", "a", "look", "at", "the", "CNN/\"USA", "TODAY\"", "/", "Gallup", "poll", ",", "taken", "last", "week", ",", "should", "U.S.", "troops", "to", "go", "to", "Iraq", "to", "remove", "Saddam", "Hussein", "from", "power", "."], ["Take", "a", "look", "at", "it", "."], ["Favor", "59", "%", ",", "opposed", "37", "%", ",", "that", "'s", "a", "vastly", "larger", "support", "than", "President", "Bush", "Senior", "had", "in", "getting", "the", "U.S.", "troops", "out", "of", "Kuwait", "before", "that", "war", "started", ".", "That", "'s", "pretty", "good", "support", "is", "n't", "it", "?"], ["ANDREWS", "Now", ",", "Bob", ",", "come", "on", "you", "do", "n't", "really", "buy", "this", "."], ["I", "mean", ",", "listen", ",", "this", "is", "the", "oldest", "trick", "in", "the", "book", "."], ["You", "can", "have", "a", "general", "question", "like", "this", "that", "could", "mean", "anything", "and", "ask", "people", "and", "they", "give", "you", "what", "comes", "off", "the", "top", "of", "their", "head", "."], ["But", ",", "ask", "them", "another", "question", ",", "ask", "them", "what", "they", "think", "about", "spending", "$", "1.3", "trillion", "in", "destroying", "this", "economy", "."], ["Ask", "them", "about", "going", "and", "not", "just", "a", "war", ",", "Bob", ",", "but", "an", "invasion", "and", "occupying", "for", "up", "to", "10", "years", "a", "sovereign", "Arab", "nation", "in", "the", "midst", "of", "one", "of", "the", "most", "distable", "and", "volatile", "regions", "in", "the", "world", "."], ["Ask", "them", "how", "they", "feel", "about", "getting", "bogged", "down", "."], ["Well", ",", "I", "'ve", "seen", "some", "of", "the", "figures", "."], ["Once", "you", "start", "telling", "Americans", "the", "story", ",", "--", "the", "administration", "refuse", "today", "tell", "us", "story", "."], ["They", "'re", "not", "coming", "forward", "and", "telling", "us", "what", "the", "risks", "are", ",", "what", "the", "costs", "are", ",", "how", "many", "years", "we", "might", "be", "in", ",", "the", "possibility", "of", "us", "getting", "bogged", "down", ",", "because", "what", "Americans", "know", "that", ",", "they", "'re", "opposed", "to", "this", "war", "."], ["The", "more", "they", "learn", "about", "this", "invasion", ",", "the", "more", "they", "learn", "about", "this", 
"occupation", ",", "the", "less", "they", "support", "it", "."], ["DORNAN", "Tom", ",", "you", "know", "what", "liberals", "want", "."], ["They", "do", "n't", "want", "a", "smoking", "gun", ",", "they", "want", "a", "smoking", "city", "."], ["The", "Clinton", "people", "all", "say", "..."], ["BEGALA", "That", "'s", "going", "to", "have", "to", "be", "last", "war", ",", "unfair", "and", "unfortunate", "as", "that", "is", ",", "I", "am", "sorry", ",", "they", "'re", "telling", "us", "we", "'re", "out", "of", "time", "."], ["Former", "Congressman", ",", "Bob", "Dornan", "from", "California", "..."], ["DORNAN", "You", "'re", "not", "going", "to", "get", "a", "smoking", "city", "."]], "ner": [[], [], [], [], [[20, 20, "PER"]], [[31, 31, "GPE"], [32, 32, "ORG"], [33, 33, "PER"], [36, 37, "LOC"], [38, 38, "LOC"]], [[41, 41, "ORG"], [44, 46, "ORG"], [49, 50, "GPE"], [52, 52, "GPE"], [59, 59, "GPE"], [60, 60, "PER"], [64, 64, "LOC"]], [[77, 77, "PER"], [78, 79, "PER"], [81, 81, "PER"], [83, 83, "GPE"]], [[89, 89, "PER"], [91, 93, "ORG"], [97, 97, "PER"], [98, 99, "PER"], [101, 101, "PER"], [103, 103, "GPE"]], [[105, 105, "PER"], [106, 106, "PER"], [116, 116, "PER"], [125, 125, "PER"], [128, 128, "LOC"], [146, 146, "ORG"]], [[149, 149, "PER"], [165, 166, "PER"], [168, 169, "PER"], [171, 175, "PER"], [177, 177, "PER"], [180, 180, "GPE"], [182, 182, "ORG"]], [[200, 200, "GPE"], [203, 207, "PER"]], [], [], [], [[240, 240, "PER"], [248, 248, "PER"], [251, 251, "WEA"]], [[266, 268, "ORG"]], [[294, 294, "PER"]], [[308, 308, "ORG"]], [[311, 311, "PER"]], [[328, 328, "PER"], [332, 332, "PER"], [339, 343, "PER"], [354, 354, "PER"]], [], [[361, 361, "PER"]], [[369, 369, "PER"], [375, 375, "ORG"]], [[385, 385, "ORG"]], [[391, 391, "PER"], [392, 392, "PER"], [411, 411, "PER"]], [[415, 415, "PER"], [421, 422, "FAC"], [434, 434, "PER"]], [[436, 436, "PER"], [442, 442, "PER"], [453, 454, "PER"]], [[466, 466, "PER"]], [[478, 478, "PER"]], [[488, 489, "FAC"]], [], [[521, 521, "PER"], [524, 524, "PER"], [527, 529, "PER"], [530, 530, "PER"]], [[536, 536, "GPE"], [539, 539, "LOC"]], [[556, 556, "PER"], [557, 558, "PER"], [567, 567, "GPE"], [574, 574, "GPE"]], [[591, 591, "ORG"], [591, 592, "ORG"], [594, 594, "ORG"], [602, 602, "GPE"], [603, 603, "PER"], [607, 607, "GPE"], [610, 611, "PER"]], [], [[636, 636, "PER"], [637, 637, "PER"], [638, 638, "PER"], [643, 643, "GPE"], [644, 644, "PER"], [647, 647, "GPE"]], [[662, 662, "PER"], [665, 665, "PER"]], [], [[704, 704, "PER"]], [], [[750, 750, "PER"], [764, 764, "PER"], [765, 765, "GPE"], [777, 777, "LOC"], [780, 780, "LOC"]], [], [], [[806, 806, "GPE"], [812, 812, "ORG"]], [[855, 855, "PER"]], [], [[888, 888, "PER"], [889, 889, "PER"], [894, 894, "PER"]], [[909, 909, "GPE"]], [[912, 912, "PER"], [913, 913, "PER"]], [[917, 917, "PER"]], [[950, 950, "PER"], [952, 953, "PER"], [955, 955, "GPE"]], [[957, 957, "PER"], [966, 966, "GPE"]]], "relations": [[], [], [], [], [], [[32, 32, 31, 31, "PART-WHOLE.Subsidiary"], [33, 33, 32, 32, "ORG-AFF.Employment"], [33, 33, 38, 38, "PHYS.Located"], [36, 37, 38, 38, "PART-WHOLE.Geographical"]], [[44, 46, 41, 41, "PART-WHOLE.Subsidiary"], [44, 46, 49, 50, "GEN-AFF.Org-Location"], [49, 50, 52, 52, "PART-WHOLE.Geographical"], [60, 60, 59, 59, "ORG-AFF.Employment"], [60, 60, 64, 64, "PHYS.Located"]], [[81, 81, 83, 83, "GEN-AFF.Citizen-Resident-Religion-Ethnicity"]], [[89, 89, 91, 93, "ORG-AFF.Membership"], [101, 101, 103, 103, "GEN-AFF.Citizen-Resident-Religion-Ethnicity"]], [[125, 125, 128, 128, 
"GEN-AFF.Citizen-Resident-Religion-Ethnicity"]], [[168, 169, 171, 175, "PER-SOC.Family"], [177, 177, 182, 182, "ORG-AFF.Membership"], [182, 182, 180, 180, "PART-WHOLE.Subsidiary"]], [], [], [], [], [[248, 248, 251, 251, "ART.User-Owner-Inventor-Manufacturer"]], [], [], [], [], [], [], [], [], [], [], [[415, 415, 421, 422, "PHYS.Located"]], [], [], [], [], [], [], [[536, 536, 539, 539, "PART-WHOLE.Geographical"]], [], [[603, 603, 602, 602, "ORG-AFF.Employment"], [603, 603, 607, 607, "PHYS.Located"], [610, 611, 607, 607, "PHYS.Located"]], [], [[644, 644, 643, 643, "ORG-AFF.Employment"], [644, 644, 647, 647, "PHYS.Located"]], [], [], [], [], [[777, 777, 780, 780, "PART-WHOLE.Geographical"]], [], [], [], [], [], [], [], [], [], [[952, 953, 955, 955, "GEN-AFF.Citizen-Resident-Religion-Ethnicity"]], []], "events": [[], [], [], [], [], [[[29, "Movement.Transport"], [33, 33, "Artifact"], [38, 38, "Destination"]]], [], [[[74, "Conflict.Attack"]], [[76, "Personnel.End-Position"], [78, 79, "Person"], [83, 83, "Entity"]]], [[[96, "Personnel.End-Position"], [98, 99, "Person"], [103, 103, "Entity"]]], [[[121, "Conflict.Attack"]]], [[[184, "Personnel.End-Position"], [177, 177, "Person"], [180, 180, "Entity"]]], [[[199, "Movement.Transport"], [200, 200, "Destination"]], [[208, "Life.Die"], [200, 200, "Place"], [203, 207, "Agent"]]], [], [], [], [], [], [], [], [], [[[334, "Movement.Transport"], [328, 328, "Artifact"]], [[344, "Life.Die"], [339, 343, "Agent"]]], [], [], [], [], [[[407, "Conflict.Attack"]], [[413, "Conflict.Attack"]]], [], [[[444, "Life.Die"], [442, 442, "Victim"], [453, 454, "Victim"]]], [[[463, "Contact.Meet"], [466, 466, "Entity"]]], [], [[[486, "Movement.Transport"]], [[509, "Conflict.Attack"]]], [], [], [[[534, "Conflict.Attack"], [539, 539, "Place"]]], [[[572, "Conflict.Attack"]], [[577, "Conflict.Attack"], [574, 574, "Attacker"]]], [[[605, "Movement.Transport"], [603, 603, "Artifact"], [607, 607, "Destination"]]], [], [[[641, "Movement.Transport"], [637, 637, "Agent"], [644, 644, "Artifact"], [647, 647, "Origin"]], [[650, "Conflict.Attack"]]], [], [], [], [], [[[748, "Conflict.Attack"]], [[754, "Conflict.Attack"]]], [], [], [], [[[864, "Conflict.Attack"]]], [[[872, "Conflict.Attack"]]], [], [], [], [[[926, "Conflict.Attack"]]], [], []], "_sentence_start": [0, 1, 2, 9, 20, 24, 40, 66, 85, 105, 148, 189, 211, 214, 229, 237, 257, 290, 299, 311, 324, 356, 361, 365, 379, 391, 415, 436, 462, 468, 480, 511, 518, 532, 556, 579, 615, 621, 662, 676, 690, 718, 740, 782, 792, 802, 819, 866, 888, 897, 911, 917, 949, 957], "doc_key": "CNN_CF_20030303.1900.02", "dataset": "ace-event"} -------------------------------------------------------------------------------- /data_convert/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RingBDStack/GDAP/f6acafafd473cd1ccb57d79aaf95be14de7646fd/data_convert/__init__.py -------------------------------------------------------------------------------- /data_convert/convert_text_to_target.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | import os 4 | import json 5 | from collections import Counter, defaultdict 6 | from data_convert.format.text2target import ETRTText2Role, Text2ET, ETText2Tri 7 | from data_convert.task_format.event_extraction import Event, DyIEPP 8 | from data_convert.utils import read_file, check_output, data_counter_to_table, get_schema, output_schema 9 | from nltk.corpus 
import stopwords 10 | 11 | english_stopwords = set(stopwords.words('english') + ["'s", "'re", "%"]) 12 | 13 | def convert_file_tuple(file_tuple, data_class=Event, target_class=ETRTText2Role, 14 | output_folder='data/text2target/framenet', 15 | ignore_nonevent=False, zh=False, 16 | mark_tree=False, type_format='subtype'): 17 | counter = defaultdict(Counter) 18 | data_counter = defaultdict(Counter) 19 | 20 | event_schema_set = set() 21 | 22 | span_output_folder = output_folder 23 | 24 | if not os.path.exists(span_output_folder): 25 | os.makedirs(span_output_folder) 26 | 27 | for in_filename, output_filename in file_tuple(output_folder): 28 | span_event_output = open(output_filename + '.json', 'w') 29 | 30 | for line in read_file(in_filename): 31 | document = data_class(json.loads(line.strip())) # each line is one JSON document 32 | for sentence in document.generate_sentence(type_format=type_format): 33 | 34 | if ignore_nonevent and len(sentence['events']) == 0: 35 | continue 36 | 37 | # collect schema information and update the statistics 38 | for event in sentence['events']: 39 | event_schema_set = event_schema_set | get_schema(event) # set of (ET, RT) pairs; merged here and later grouped by ET into a dict 40 | sep = '' if zh else ' ' 41 | predicate = sep.join([sentence['tokens'][index] 42 | for index in event['tokens']]) # text of the trigger word 43 | counter['pred'].update([predicate]) 44 | counter['type'].update([event['type']]) # event type 45 | data_counter[in_filename].update(['event']) 46 | for argument in event['arguments']: 47 | data_counter[in_filename].update(['argument']) 48 | counter['role'].update([argument[0]]) 49 | 50 | data_counter[in_filename].update(['sentence']) 51 | 52 | # these target classes treat dev/test data differently from training data 53 | if (target_class == ETText2Tri or target_class == ETRTText2Role) and "train" not in in_filename: 54 | # generate the span target 55 | span_source_list, span_target_list = target_class.annotate_span( 56 | tokens=sentence['tokens'], 57 | predicate_arguments=sentence['events'], 58 | zh=zh, 59 | mark_tree=mark_tree, 60 | isTest=True 61 | ) 62 | # print("================================") 63 | else: 64 | # generate the span target 65 | span_source_list, span_target_list = target_class.annotate_span( 66 | tokens=sentence['tokens'], 67 | predicate_arguments=sentence['events'], 68 | zh=zh, 69 | mark_tree=mark_tree 70 | ) 71 | 72 | # write the processed span results to the file 73 | assert len(span_source_list) == len(span_target_list) 74 | for i in range(len(span_source_list)): 75 | span_event_output.write( 76 | json.dumps({'text': span_source_list[i], 'event': span_target_list[i]}, ensure_ascii=False) + '\n') 77 | 78 | span_event_output.close() 79 | 80 | check_output(output_filename) 81 | print('\n') 82 | 83 | # after train/dev/test conversion, write the overall schema information to a file 84 | output_schema(event_schema_set, output_file=os.path.join( 85 | span_output_folder, 'event.schema')) 86 | print('Pred:', len(counter['pred']), counter['pred'].most_common(10)) 87 | print('Type:', len(counter['type']), counter['type'].most_common(10)) 88 | print('Role:', len(counter['role']), counter['role'].most_common(10)) 89 | print(data_counter_to_table(data_counter)) 90 | print('\n\n\n') 91 | 92 | 93 | def convert_dyiepp_event(output_folder='data/text2target/ace2005_event', type_format='subtype', 94 | ignore_nonevent=False, mark_tree=False, target_class=ETRTText2Role): 95 | from data_convert.task_format.event_extraction import DyIEPP_ace2005_file_tuple 96 | convert_file_tuple(file_tuple=DyIEPP_ace2005_file_tuple, 97 | output_folder=output_folder, 98 | ignore_nonevent=ignore_nonevent, 99 | mark_tree=mark_tree, 100 | type_format=type_format, 101 | data_class=DyIEPP, 102 |
target_class = target_class 103 | ) 104 | 105 | if __name__ == "__main__": 106 | type_format_name = 'subtype' 107 | 108 | 109 | # ET + RT + src -> ( (Role) (Role) ) 110 | convert_dyiepp_event("data/text2target/dyiepp_ace2005_etrttext2role_%s" % type_format_name, 111 | type_format=type_format_name, 112 | ignore_nonevent=False, mark_tree=False, target_class=ETRTText2Role 113 | ) 114 | 115 | # src -> ((ET)(ET)) 116 | convert_dyiepp_event("data/text2target/dyiepp_ace2005_text2et_%s" % type_format_name, 117 | type_format=type_format_name, 118 | ignore_nonevent=False, mark_tree=False, target_class=Text2ET 119 | ) 120 | 121 | 122 | # ET + src -> ((Tri)(Tri)) 123 | convert_dyiepp_event("data/text2target/dyiepp_ace2005_ettext2tri_%s" % type_format_name, 124 | type_format=type_format_name, 125 | ignore_nonevent=False, mark_tree=False, target_class=ETText2Tri 126 | ) 127 | 128 | 129 | 130 | -------------------------------------------------------------------------------- /data_convert/format/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RingBDStack/GDAP/f6acafafd473cd1ccb57d79aaf95be14de7646fd/data_convert/format/__init__.py -------------------------------------------------------------------------------- /data_convert/format/text2target.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | import random 4 | from typing import Dict, List 5 | 6 | from extraction.event_schema import EventSchema 7 | 8 | 9 | # 通过span获取到具体的文本内容(start, end位置) 10 | def get_str_from_tokens(tokens, sentence, separator=' '): 11 | start, end_exclude = tokens[0], tokens[-1] + 1 12 | return separator.join(sentence[start:end_exclude]) 13 | 14 | # T5 tokenizer使用的特殊符号 15 | type_start = '' 16 | type_end = '' 17 | role_start = '' 18 | role_end = '' 19 | 20 | class TargetFormat: 21 | @staticmethod 22 | def annotate_spans(tokens: List[str], predicate_arguments: List[Dict], zh=False): pass 23 | 24 | 25 | # et + rt + src -> ((Role)(Role)) 论元抽取任务, 在给定schema文件的情况下才可以使用该方案 26 | class ETRTText2Role(TargetFormat): 27 | 28 | @staticmethod 29 | def annotate_span(tokens, predicate_arguments, mark_tree=False, zh=False, isTest=False): 30 | """ 31 | src: et + + RT + + source text 为 t5-base 分隔符 32 | traget: ((Role)) 只对role进行生成, 如果同一个et同样的rt存在多个role, 则生成 ((Role)(Role)...) 33 | :param tokens: 34 | US President George W. Bush told Canadian Prime Minister Jean Chretien by telephone Monday that he looked forward 35 | to seeing him at the upcoming summit of major industrialized nations and Russia , the White House said Tuesday . 
36 | :param predicate_arguments: 37 | 38 | :return: 39 | """ 40 | 41 | token_separator = '' if zh else ' ' 42 | 43 | event_str_rep_list = list() 44 | 45 | source_list = [] 46 | target_list = [] 47 | 48 | # 若出现单句多个相同 type 事件, 则将其合并 49 | et_rt_index_dict = {} 50 | et_rt_span_dict = {} # 若出现span重复的情况则进行合并 51 | 52 | 53 | # 加载 schema文件中的信息, 目前每次都会读取文件, 很低效, 后续将这一步改为传参优化处理 54 | if zh: 55 | schema = EventSchema.read_from_file("data/raw_data/duee/event.schema") 56 | et_rt_dict = schema.type_role_dict 57 | schema_et_list = schema.type_list 58 | else: 59 | schema = EventSchema.read_from_file("data/raw_data/dyiepp_ace2005/event.schema") 60 | et_rt_dict = schema.type_role_dict 61 | schema_et_list = schema.type_list 62 | 63 | et_set = set() # 文本已经包含的事件类型 64 | 65 | # 遍历event 66 | for predicate_argument in predicate_arguments: 67 | event_type = predicate_argument['type'] 68 | et_set.add(event_type) 69 | 70 | # 遍历role 71 | for role_type, role_tokens in predicate_argument['arguments']: 72 | if role_type == event_type: 73 | continue 74 | 75 | if event_type + "-" + role_type not in et_rt_span_dict: 76 | et_rt_span_dict[event_type + "-" + role_type] = set() 77 | 78 | role_text = get_str_from_tokens(role_tokens, tokens, separator=token_separator) 79 | 80 | # 判断role是否出现过 81 | if role_text in et_rt_span_dict[event_type + "-" + role_type]: continue 82 | et_rt_span_dict[event_type + "-" + role_type].add(role_text) # 将当前role加入set中 83 | 84 | # role_str = ' '.join([type_start, role_type, role_text, type_end]) 85 | role_str = ' '.join([type_start, role_text, type_end]) 86 | 87 | # 在此处修改 source 与 target 格式 88 | source_text = event_type + " " + role_type + " " + token_separator.join(tokens) 89 | target_text = role_str 90 | 91 | # 判断当前ET-RT类型是否重复 92 | if event_type + "-" + role_type in et_rt_index_dict: 93 | # print("duplicate event:", event_type, " ", tokens, " ", predicate_arguments) 94 | target_list[et_rt_index_dict[event_type + "-" + role_type]] += " " + target_text 95 | else: 96 | source_list.append(source_text) 97 | target_list.append(target_text) 98 | et_rt_index_dict[event_type + "-" + role_type] = len(source_list) - 1 # 将位置进行记录 99 | 100 | # print(et_rt_dict) 101 | # 补全在当前事件类型下, 未出现的rt样例 102 | for role_type in et_rt_dict[event_type]: 103 | if event_type + "-" + role_type not in et_rt_span_dict: 104 | et_rt_span_dict[event_type + "-" + role_type] = set() 105 | 106 | source_text = event_type + " " + role_type + " " + token_separator.join(tokens) 107 | target_text = "" 108 | source_list.append(source_text) 109 | target_list.append(target_text) 110 | 111 | et_rt_index_dict[event_type + "-" + role_type] = len(source_list) - 1 # 将补全的位置进行记录 112 | 113 | # negative sample on role train data 114 | if len(predicate_arguments) > 0 and not isTest: 115 | for tmp_et in random.sample(set(schema_et_list) - et_set, 4): 116 | for role_type in et_rt_dict[tmp_et]: 117 | source_text = tmp_et + " " + role_type + " " + token_separator.join(tokens) 118 | target_text = "" 119 | source_list.append(source_text) 120 | target_list.append(target_text) 121 | 122 | # 在所有的target 上统一加上起始位置 123 | for i in range(len(target_list)): 124 | target_list[i] = f'{type_start} ' + target_list[i] + f' {type_end}' 125 | 126 | return source_list, target_list 127 | 128 | # src -> ((ET)) 事件检测任务, 129 | class Text2ET(TargetFormat): 130 | 131 | @staticmethod 132 | def annotate_span(tokens, predicate_arguments, mark_tree=False, zh=False): 133 | """ 134 | src: source text 135 | traget: ((ET)) 只对ET进行生成, ((ET)(ET)...) 136 | :param tokens: 137 | US President George W. 
Bush told Canadian Prime Minister Jean Chretien by telephone Monday that he looked forward 138 | to seeing him at the upcoming summit of major industrialized nations and Russia , the White House said Tuesday . 139 | :param predicate_arguments: 140 | 141 | :return: 142 | """ 143 | 144 | token_separator = '' if zh else ' ' 145 | 146 | et_list = [] 147 | 148 | # 去除重复的事件类型 149 | et_set = set() 150 | 151 | # 遍历event 152 | for predicate_argument in predicate_arguments: 153 | event_type = predicate_argument['type'] 154 | if event_type in et_set: continue 155 | else: et_set.add(event_type) 156 | 157 | # 在此处修改 source 与 target 格式 158 | 159 | et_text = f'{type_start} ' + event_type + f' {type_end}' 160 | # et_text = event_type 161 | et_list.append(et_text) 162 | 163 | source_text = token_separator.join(tokens) 164 | target_text = f'{type_start} ' + " ".join(et_list) + f' {type_end}' 165 | # target_text = " ".join(et_list) 166 | 167 | 168 | return [source_text], [target_text] 169 | 170 | # et + src -> ((tri)) 遍历事件类型, 针对每个给定的事件类型生成触发词(如果包含) 171 | class ETText2Tri(TargetFormat): 172 | 173 | @staticmethod 174 | def annotate_span(tokens, predicate_arguments, mark_tree=False, zh=False, isTest = False): 175 | """ 176 | src: et + + source text 为 t5-base 分隔符 177 | traget: ((Tri)) 只对tri进行生成, 如果同一个et存在多个Tri, 则生成 ((Tri)(Tri)...) 178 | :param tokens: 179 | US President George W. Bush told Canadian Prime Minister Jean Chretien by telephone Monday that he looked forward 180 | to seeing him at the upcoming summit of major industrialized nations and Russia , the White House said Tuesday . 181 | :param predicate_arguments: 182 | 183 | :return: 184 | """ 185 | 186 | token_separator = '' if zh else ' ' 187 | 188 | event_str_rep_list = list() 189 | 190 | source_list = [] 191 | target_list = [] 192 | 193 | # 若出现单句多个相同 type 事件, 则将其合并 194 | et_tri_dict = {} # {et + src: tri_list} 195 | et_set = set() # 文本已经包含的事件类型 196 | 197 | # 加载 schema文件中的信息, 目前每次都会读取文件, 很低效, 后续将这一步改为传参优化处理 198 | et_list = EventSchema.read_from_file("data/raw_data/dyiepp_ace2005/event.schema").type_list 199 | 200 | # 针对事件类型遍历制作训练样本 201 | for et in et_list: 202 | et_tri_dict[et + " " + token_separator.join(tokens)] = set() 203 | 204 | 205 | # 遍历event 206 | for predicate_argument in predicate_arguments: 207 | et = predicate_argument['type'] 208 | et_set.add(et) 209 | 210 | tri_text = get_str_from_tokens(predicate_argument['tokens'], tokens, separator=token_separator) # 此处的 predicate_argument['tokens'] 为 [start, end](多个单词), 或者[start] (一个单词) 211 | 212 | et_tri_dict[et + " " + token_separator.join(tokens)].add(tri_text) 213 | 214 | 215 | for src, tri_set in et_tri_dict.items(): 216 | if not tri_set: continue # 过滤掉不包含触发词的样本 217 | source_list.append(src) 218 | tmp_list = [] 219 | for tri_text in tri_set: 220 | tmp_list.append(' '.join([type_start, tri_text, type_end])) 221 | target_text = f'{type_start} ' + " ".join(tmp_list) + f' {type_end}' 222 | target_list.append(target_text) 223 | 224 | # negative sample on tri train data 225 | if not isTest: 226 | for tmp_et in random.sample(set(et_list) - et_set, 6): 227 | source_list.append(tmp_et + " " + token_separator.join(tokens)) 228 | target_list.append(f'{type_start} ' + " ".join([]) + f' {type_end}') 229 | 230 | 231 | return source_list, target_list 232 | 233 | 234 | if __name__ == "__main__": 235 | pass 236 | -------------------------------------------------------------------------------- /data_convert/task_format/__init__.py: -------------------------------------------------------------------------------- 1 | 
#!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | 4 | 5 | if __name__ == "__main__": 6 | pass 7 | -------------------------------------------------------------------------------- /data_convert/task_format/event_extraction.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | import os 4 | 5 | class TaskFormat: 6 | def generate_sentence(self): pass 7 | 8 | class DyIEPP(TaskFormat): 9 | def __init__(self, doc_json): 10 | self.doc_key = doc_json['doc_key'] 11 | self.sentences = doc_json['sentences'] 12 | self.ner = doc_json['ner'] 13 | self.relations = doc_json['relations'] 14 | self.events = doc_json['events'] 15 | self.sentence_start = doc_json.get('sentence_start', doc_json['_sentence_start']) 16 | 17 | def generate_sentence(self, type_format='subtype'): 18 | for sentence, events_in_sentence, sentence_start in zip(self.sentences, self.events, self.sentence_start): 19 | events = list() 20 | for event in events_in_sentence: # 一个trigger代表一个事件的形式 21 | trigger, event_type = event[0] 22 | trigger -= sentence_start # sentence_start 指的是该句在整个document中的位置 23 | 24 | suptype, subtype = event_type.split('.') # XX.YY 类别信息, XX为大类别, YY为小类别 25 | 26 | if type_format == 'subtype': 27 | event_type = subtype 28 | elif type_format == 'suptype': 29 | event_type = suptype 30 | else: 31 | event_type = suptype + type_format + subtype 32 | 33 | arguments = list() 34 | for start, end, role in event[1:]: 35 | start -= sentence_start 36 | end -= sentence_start 37 | arguments += [[role, list(range(start, end + 1))]] 38 | 39 | event = {'type': event_type, 'tokens': [trigger], 'arguments': arguments} 40 | 41 | events += [event] 42 | yield {'tokens': sentence, 'events': events} 43 | 44 | 45 | class Event(TaskFormat): 46 | """ 47 | { 48 | "doc_id": "NYT_ENG_20130914.0094", 49 | "sent_id": "NYT_ENG_20130914.0094-1", 50 | "tokens": ["LARGO", "\u2014", "A", "judge", "on", "Friday", "refused", "to", "stop", "''", "Hiccup", "Girl", "\u2019'", "Jennifer", "Mee", "from", "giving", "interviews", "to", "the", "media", "in", "the", "final", "days", "before", "her", "murder", "trial", "."], 51 | "entities": [ 52 | {"entity_id": "NYT_ENG_20130914.0094-1-8-42", "entity_type": "PER", "mention_type": "NOM", "start": 3, "end": 4, "text": "judge"}, 53 | {"entity_id": "NYT_ENG_20130914.0094-1-39-2096", "entity_type": "PER", "mention_type": "NAM", "start": 10, "end": 12, "text": "Hiccup Girl"}, 54 | {"entity_id": "NYT_ENG_20130914.0094-1-39-48", "entity_type": "PER", "mention_type": "NAM", "start": 13, "end": 15, "text": "Jennifer Mee"}, 55 | {"entity_id": "NYT_ENG_20130914.0094-1-39-54", "entity_type": "PER", "mention_type": "PRO", "start": 26, "end": 27, "text": "her"} 56 | ], 57 | "relations": [], 58 | "events": [ 59 | { 60 | "event_id": "NYT_ENG_20130914.0094-1-25-998", "event_type": "contact", "event_subtype": "broadcast", 61 | "trigger": {"text": "interviews", "start": 17, "end": 18}, 62 | "arguments": [ 63 | {"entity_id": "NYT_ENG_20130914.0094-1-39-48", "role": "entity", "text": "Jennifer Mee"} 64 | ] 65 | }, 66 | { 67 | "event_id": "NYT_ENG_20130914.0094-1-1040-1019", "event_type": "justice", "event_subtype": "trialhearing", 68 | "trigger": {"text": "trial", "start": 28, "end": 29}, 69 | "arguments": [{"entity_id": "NYT_ENG_20130914.0094-1-39-54", "role": "defendant", "text": "her"} 70 | ] 71 | } 72 | ], 73 | "start": 231, "end": 380, 74 | "text": "LARGO \u2014 A judge on Friday refused to stop ''Hiccup Girl\u2019' Jennifer Mee from giving 
interviews to the media in the final days before her murder trial." 75 | } 76 | """ 77 | 78 | def __init__(self, doc_json): 79 | self.doc_key = doc_json['doc_id'] 80 | self.sentence = doc_json['tokens'] 81 | self.entities = {entity['id']: entity for entity in doc_json['entity_mentions']} 82 | self.relations = doc_json['relation_mentions'] 83 | self.events = doc_json['event_mentions'] 84 | # self.sentence_start = doc_json['start'] 85 | # self.sentence_end = doc_json['end'] 86 | # self.text = doc_json['text'] 87 | 88 | def generate_sentence(self, type_format='subtype'): 89 | events = list() 90 | 91 | for event in self.events: 92 | arguments = list() 93 | for argument in event['arguments']: 94 | argument_entity = self.entities[argument['entity_id']] 95 | arguments += [[argument['role'], list(range(argument_entity['start'], argument_entity['end']))]] 96 | 97 | suptype, subtype = event['event_type'].split(':') 98 | 99 | if type_format == 'subtype': 100 | event_type = subtype 101 | elif type_format == 'suptype': 102 | event_type = suptype 103 | else: 104 | event_type = suptype + type_format + subtype 105 | 106 | events += [{ 107 | 'type': event_type, 108 | 'tokens': list(range(event['trigger']['start'], event['trigger']['end'])), 109 | 'arguments': arguments 110 | }] 111 | 112 | yield {'tokens': self.sentence, 'events': events} 113 | 114 | 115 | def DyIEPP_ace2005_file_tuple(output_folder): 116 | if not os.path.exists(output_folder): 117 | os.makedirs(output_folder, exist_ok=True) 118 | 119 | conll_2012_folder = "data/raw_data/dyiepp_ace2005" 120 | 121 | file_tuple = [ 122 | (conll_2012_folder + "/train.json", output_folder + '/train'), 123 | (conll_2012_folder + "/dev.json", output_folder + '/val'), 124 | (conll_2012_folder + "/test.json", output_folder + '/test'), 125 | ] 126 | 127 | return file_tuple 128 | 129 | 130 | 131 | if __name__ == "__main__": 132 | pass 133 | -------------------------------------------------------------------------------- /data_convert/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | import json 4 | from collections import defaultdict 5 | import codecs 6 | 7 | from tabulate import tabulate 8 | 9 | 10 | def read_file(filename): 11 | from tqdm import tqdm 12 | num_lines = sum(1 for _ in open(filename, 'r')) 13 | with open(filename, 'r') as f: 14 | for line in tqdm(f, total=num_lines): 15 | yield line 16 | 17 | 18 | def check_output(filename, line_num=2): 19 | import os 20 | os.system('tail -n %s %s*' % (line_num, filename)) 21 | 22 | 23 | def data_counter_to_table(data_counter): 24 | table = list() 25 | for filename, file_counter in data_counter.items(): 26 | table += [[filename, file_counter['sentence'], 27 | file_counter['event'], file_counter['argument']]] 28 | return tabulate(table, headers=['file', '#sent', '#event', '#arg']) 29 | 30 | 31 | def get_schema(event): 32 | event_type = event['type'] 33 | if len(event['arguments']) == 0: 34 | return {(event_type, None)} 35 | return set([(event_type, argument[0]) for argument in event['arguments']]) 36 | 37 | 38 | def output_schema(event_schema_set, output_file): 39 | event_type_list = list(set([schema[0] for schema in event_schema_set])) 40 | argument_role_list = list(set([schema[1] for schema in event_schema_set])) 41 | 42 | if None in argument_role_list: 43 | # Same Event only Type without argument 44 | argument_role_list.remove(None) 45 | 46 | event_type_set_dict = defaultdict(set) 47 | 48 | for event_type, arg_role in 
event_schema_set: 49 | if arg_role is None: 50 | continue 51 | event_type_set_dict[event_type].add(arg_role) 52 | 53 | event_type_list_dict = defaultdict(list) 54 | 55 | for event_type in event_type_set_dict: 56 | event_type_list_dict[event_type] = list( 57 | event_type_set_dict[event_type]) 58 | 59 | with codecs.open(output_file, 'w', 'UTF-8') as output: 60 | output.write(json.dumps(event_type_list, ensure_ascii=False) + '\n') 61 | output.write(json.dumps(argument_role_list, ensure_ascii=False) + '\n') 62 | output.write(json.dumps(event_type_list_dict, ensure_ascii=False) + '\n') 63 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import sys 5 | import numpy as np 6 | from copy import deepcopy 7 | from pprint import pprint 8 | from extraction.event_schema import EventSchema 9 | from extraction.predict_parser.target_predict_parser import RolePredictParser, TriPredictParser 10 | 11 | 12 | def read_file(file_name): 13 | return [line.strip() for line in open(file_name).readlines()] 14 | 15 | 16 | def generate_sentence_dyiepp(filename, type_format='subtype'): 17 | for line in open(filename): 18 | instance = json.loads(line) 19 | sentence = instance['sentence'] 20 | sentence_start = instance.get( 21 | 's_start', instance.get('_sentence_start')) 22 | events = instance['event'] 23 | 24 | # 不进行去重 25 | trigger_list = list() 26 | role_list = list() 27 | 28 | # 进行去重 29 | trigger_set = set() 30 | role_set = set() 31 | 32 | for event in events: 33 | trigger, event_type = event[0] 34 | trigger -= sentence_start 35 | 36 | suptype, subtype = event_type.split('.') 37 | 38 | if type_format == 'subtype': 39 | event_type = subtype 40 | elif type_format == 'suptype': 41 | event_type = suptype 42 | else: 43 | event_type = suptype + type_format + subtype 44 | 45 | # trigger_list += [(event_type, (trigger, trigger))] 46 | trigger_list += [(event_type, sentence[trigger])] 47 | trigger_set.add((event_type, sentence[trigger])) 48 | for start, end, role in event[1:]: 49 | start -= sentence_start 50 | end -= sentence_start 51 | role_list += [(event_type, role, " ".join(sentence[start: end+1]))] 52 | role_set.add((event_type, role, " ".join(sentence[start: end+1]))) 53 | 54 | # yield ' '.join(sentence), trigger_list, role_list # 不进行去重 55 | yield ' '.join(sentence), list(trigger_set), list(role_set) # 进行去重 56 | 57 | def generate_sentence_text2target(filename, pred_reader): 58 | text_gold_dict = {} 59 | event_list, _ = pred_reader.decode( 60 | gold_list=read_file(filename), 61 | pred_list=read_file(filename), 62 | text_list=[json.loads(line)['text'] 63 | for line in read_file(filename)], 64 | ) 65 | # print(event_list) 66 | for item in event_list: 67 | if item["text"] in text_gold_dict: 68 | # print("Warning: text duplicate , text: ", item["text"]) 69 | text_gold_dict[item["text"]][0] += item['gold_event'] 70 | text_gold_dict[item["text"]][1] += item['gold_role'] 71 | else: 72 | text_gold_dict[item["text"]] = [item['gold_event'], item['gold_role']] 73 | # print(text_gold_dict) 74 | 75 | gold_list = [] 76 | for text, events in text_gold_dict.items(): 77 | gold_list.append([text, events[0], events[1]]) 78 | return gold_list 79 | 80 | def match_sublist(the_list, to_match): 81 | """ 82 | :param the_list: [1, 2, 3, 4, 5, 6, 1, 2, 4, 5] 83 | :param to_match: [1, 2] 84 | :return: 85 | [(0, 1), (6, 7)] 86 | """ 87 | len_to_match = len(to_match) 88 
| matched_list = list() 89 | for index in range(len(the_list) - len_to_match + 1): 90 | if to_match == the_list[index:index + len_to_match]: 91 | matched_list += [(index, index + len_to_match - 1)] 92 | return matched_list 93 | 94 | def record_to_offset(instance): 95 | """ 96 | Find Role's offset using closest matched with trigger work. 97 | :param instance: 98 | :return: 99 | """ 100 | trigger_list = list() 101 | role_list = list() 102 | 103 | token_list = instance['text'].split() 104 | 105 | trigger_matched_set = set() 106 | for record in instance['pred_record']: 107 | event_type = record['type'] 108 | trigger = record['trigger'] 109 | matched_list = match_sublist(token_list, trigger.split()) 110 | 111 | trigger_offset = None 112 | for matched in matched_list: 113 | if matched not in trigger_matched_set: 114 | trigger_list += [(event_type, matched)] 115 | trigger_offset = matched 116 | trigger_matched_set.add(matched) 117 | break 118 | 119 | # No trigger word, skip the record 120 | if trigger_offset is None: 121 | break 122 | 123 | for _, role_type, text_str in record['roles']: 124 | matched_list = match_sublist(token_list, text_str.split()) 125 | if len(matched_list) == 1: 126 | role_list += [(event_type, role_type, matched_list[0])] 127 | elif len(matched_list) == 0: 128 | sys.stderr.write("[Cannot reconstruct]: %s %s\n" % 129 | (text_str, token_list)) 130 | else: 131 | abs_distances = [abs(match[0] - trigger_offset[0]) 132 | for match in matched_list] 133 | closest_index = np.argmin(abs_distances) 134 | role_list += [(event_type, role_type, 135 | matched_list[closest_index])] 136 | 137 | return instance['text'], trigger_list, role_list 138 | 139 | class Metric: 140 | def __init__(self): 141 | self.tp = 0. 142 | self.gold_num = 0. 143 | self.pred_num = 0. 144 | 145 | @staticmethod 146 | def safe_div(a, b): 147 | if b == 0.: 148 | return 0. 
149 | else: 150 | return a / b 151 | 152 | def compute_f1(self, prefix=''): 153 | tp = self.tp 154 | pred_num = self.pred_num 155 | gold_num = self.gold_num 156 | p, r = self.safe_div(tp, pred_num), self.safe_div(tp, gold_num) 157 | return {prefix + 'tp': tp, 158 | prefix + 'gold': gold_num, 159 | prefix + 'pred': pred_num, 160 | prefix + 'P': p * 100, 161 | prefix + 'R': r * 100, 162 | prefix + 'F1': self.safe_div(2 * p * r, p + r) * 100 163 | } 164 | 165 | def count_instance(self, gold_list, pred_list, verbose=False, text=None): 166 | if verbose: 167 | print("Gold:", gold_list) 168 | print("Pred:", pred_list) 169 | self.gold_num += len(gold_list) 170 | self.pred_num += len(pred_list) 171 | 172 | dup_gold_list = deepcopy(gold_list) 173 | for pred in pred_list: 174 | if pred in dup_gold_list: 175 | self.tp += 1 176 | dup_gold_list.remove(pred) 177 | else: 178 | print("text: ", text) 179 | print("gold_list: ", gold_list) 180 | print("no tp pred:", pred) 181 | pass 182 | 183 | def main(): 184 | parser = argparse.ArgumentParser() 185 | parser.add_argument('--text_file', type=str) 186 | parser.add_argument('--pred_file', type=str) 187 | parser.add_argument('--gold_file', type=str) 188 | parser.add_argument('--schema_file', type=str) 189 | 190 | parser.add_argument('--format', type=str, default="dyiepp") 191 | parser.add_argument('--verbose', action='store_true') 192 | parser.add_argument('--decoding_format', type=str, default='noetrtspan') 193 | options = parser.parse_args() 194 | 195 | 196 | label_schema = EventSchema.read_from_file( 197 | filename=options.schema_file 198 | ) 199 | 200 | decoding_format_dict = { 201 | 'role': RolePredictParser, 202 | 'tri': TriPredictParser 203 | } 204 | 205 | # 替换为自己的predict parser 206 | pred_reader = decoding_format_dict[options.decoding_format](schema=label_schema) 207 | 208 | 209 | trigger_metric = Metric() 210 | argument_metric = Metric() 211 | 212 | # Reconstruct the offset of predicted event records. 213 | text_filename = options.text_file 214 | pred_filename = options.pred_file 215 | gold_filename = options.gold_file 216 | print("pred_filename: ", pred_filename) 217 | print("gold_filename: ", gold_filename) 218 | 219 | # 离线评估 220 | # 在此处处理的时候, 需要将 et、rt特殊处理的部分进行添加, 以及src相同的部分进行合并 221 | event_list, _ = pred_reader.decode( 222 | gold_list=[], 223 | pred_list=read_file(pred_filename), 224 | text_list=[json.loads(line)['text'] 225 | for line in read_file(text_filename)], 226 | ) 227 | # print(event_list[0]) 228 | 229 | # text 中空格一类的做key会有影响, 后续可考虑用id来指代 230 | text_pred_dict = {} # 构建 text: ([tri_list][role_list]) 类型的字典 231 | text_gold_dict = {} 232 | 233 | for item in event_list: 234 | if item["text"] in text_pred_dict: 235 | # print("Warning: text duplicate , text: ", item["text"]) 236 | text_pred_dict[item["text"]][0] += item['pred_event'] 237 | text_pred_dict[item["text"]][1] += item['pred_role'] 238 | else: 239 | text_pred_dict[item["text"]] = [item['pred_event'], item['pred_role']] 240 | 241 | # print(text_pred_dict) 242 | 243 | # Read gold event annotation with offsets. 
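    # Two gold formats are supported below: 'dyiepp' reads the sentence-level
    # *_convert.json produced by convert_dyiepp_to_sentence.py, while any other
    # value reads a text2target-style test file and reuses the predict parser to
    # recover gold triggers and roles from the serialized targets.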
244 | if options.format == 'dyiepp': 245 | gold_list = [event for event in generate_sentence_dyiepp(gold_filename)] # 根据dyiepp预处理后的文件获取gold 246 | else: 247 | # 使用 text2target文件处理, pred_num原因低在于test文件在制作时候自动过滤了未出现事件类型的句子, 因此需要引入pred中有结果而gold中无结果的句子进行计数 248 | gold_list = generate_sentence_text2target(gold_filename, pred_reader) # 根据text2target预处理后的文件获取gold 249 | 250 | # print("gold_list: ", gold_list) 251 | 252 | # 遍历计算tp 253 | gold_text_set = set() 254 | for gold in gold_list: 255 | if gold[0] in text_pred_dict: 256 | trigger_metric.count_instance( 257 | gold_list=gold[1], 258 | pred_list=text_pred_dict[gold[0]][0], 259 | verbose=options.verbose, 260 | text=gold[0] 261 | ) 262 | argument_metric.count_instance( 263 | gold_list=gold[2], 264 | pred_list=text_pred_dict[gold[0]][1], 265 | verbose=options.verbose, 266 | text=gold[0] 267 | ) 268 | else: 269 | # print(gold) 270 | trigger_metric.count_instance( 271 | gold_list=gold[1], 272 | pred_list=[], 273 | verbose=options.verbose, 274 | text=gold[0] 275 | ) 276 | argument_metric.count_instance( 277 | gold_list=gold[2], 278 | pred_list=[], 279 | verbose=options.verbose, 280 | text=gold[0] 281 | ) 282 | 283 | # 计算未在gold却在pred中的样本数量 284 | 285 | trigger_result = trigger_metric.compute_f1(prefix='result-trig-') 286 | role_result = argument_metric.compute_f1(prefix='result-role-') 287 | 288 | pprint(trigger_result) 289 | pprint(role_result) 290 | 291 | 292 | if __name__ == "__main__": 293 | main() 294 | -------------------------------------------------------------------------------- /extraction/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RingBDStack/GDAP/f6acafafd473cd1ccb57d79aaf95be14de7646fd/extraction/__init__.py -------------------------------------------------------------------------------- /extraction/event_schema.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | import json 4 | from collections import defaultdict 5 | from typing import List 6 | 7 | 8 | class EventSchema: 9 | def __init__(self, type_list, role_list, type_role_dict): 10 | self.type_list = type_list 11 | self.role_list = role_list 12 | self.type_role_dict = type_role_dict 13 | 14 | @staticmethod 15 | def read_from_file(filename): 16 | lines = open(filename).readlines() 17 | type_list = json.loads(lines[0]) 18 | role_list = json.loads(lines[1]) 19 | type_role_dict = json.loads(lines[2]) 20 | return EventSchema(type_list, role_list, type_role_dict) 21 | 22 | def write_to_file(self, filename): 23 | with open(filename, 'w') as output: 24 | output.write(json.dumps(self.type_list) + '\n') 25 | output.write(json.dumps(self.role_list) + '\n') 26 | output.write(json.dumps(self.type_role_dict) + '\n') 27 | 28 | 29 | def merge_schema(schema_list: List[EventSchema]): 30 | type_set = set() 31 | role_set = set() 32 | type_role_dict = defaultdict(list) 33 | 34 | for schema in schema_list: 35 | 36 | for type_name in schema.type_list: 37 | type_set.add(type_name) 38 | 39 | for role_name in schema.role_list: 40 | role_set.add(role_name) 41 | 42 | for type_name in schema.type_role_dict: 43 | type_role_dict[type_name] += schema.type_role_dict[type_name] 44 | 45 | for type_name in type_role_dict: 46 | type_role_dict[type_name] = list(set(type_role_dict[type_name])) 47 | 48 | return EventSchema(type_list=list(type_set), 49 | role_list=list(role_set), 50 | type_role_dict=type_role_dict, 51 | ) 52 | 
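
# Illustrative usage sketch (not part of the original module): an event.schema file,
# as written by data_convert.utils.output_schema, holds exactly three JSON lines
# (the event type list, the role list, and the type->role mapping), which
# read_from_file parses positionally. The path below is the ACE2005 schema shipped
# under data/raw_data.
if __name__ == "__main__":
    schema = EventSchema.read_from_file("data/raw_data/dyiepp_ace2005/event.schema")
    print(len(schema.type_list), "event types,", len(schema.role_list), "roles")
    print("Roles of 'Attack':", schema.type_role_dict.get("Attack", []))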
-------------------------------------------------------------------------------- /extraction/extract_constraint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | from typing import List, Dict 4 | 5 | from data_convert.format.text2target import type_start, type_end 6 | from extraction.label_tree import get_label_name_tree 7 | 8 | import os 9 | 10 | debug = True if 'DEBUG' in os.environ else False 11 | debug_step = True if 'DEBUG_STEP' in os.environ else False 12 | 13 | 14 | def match_sublist(the_list, to_match): 15 | """ 16 | 17 | :param the_list: [1, 2, 3, 4, 5, 6, 1, 2, 4, 5] 18 | :param to_match: [1, 2] 19 | :return: 20 | [(0, 1), (6, 7)] 21 | """ 22 | len_to_match = len(to_match) 23 | matched_list = list() 24 | for index in range(len(the_list) - len_to_match): 25 | if to_match == the_list[index:index + len_to_match]: 26 | matched_list += [(index, index + len_to_match - 1)] 27 | return matched_list 28 | 29 | 30 | def find_bracket_position(generated_text, _type_start, _type_end): 31 | bracket_position = {_type_start: list(), _type_end: list()} 32 | for index, char in enumerate(generated_text): 33 | if char in bracket_position: 34 | bracket_position[char] += [index] 35 | return bracket_position 36 | 37 | # 根据generated 找到 src_sequence中与其匹配的tokens并进行返回 38 | def generated_search_src_sequence(generated, src_sequence, end_sequence_search_tokens=None): 39 | print(generated, src_sequence) if debug else None 40 | 41 | if len(generated) == 0: 42 | # It has not been generated yet. All SRC are valid. 43 | return src_sequence 44 | 45 | matched_tuples = match_sublist(the_list=src_sequence, to_match=generated) 46 | 47 | valid_token = list() 48 | for _, end in matched_tuples: 49 | next_index = end + 1 50 | if next_index < len(src_sequence): 51 | valid_token += [src_sequence[next_index]] 52 | 53 | if end_sequence_search_tokens: 54 | valid_token += end_sequence_search_tokens 55 | 56 | return valid_token 57 | 58 | 59 | def get_constraint_decoder(tokenizer, type_schema, decoding_schema, source_prefix=None): 60 | 61 | decoding_format_dict = { 62 | 'role': RoleConstraintDecoder, 63 | 'et': ETConstraintDecoder, 64 | 'tri': TriConstraintDecoder, 65 | } 66 | 67 | if decoding_schema in decoding_format_dict: 68 | return decoding_format_dict[decoding_schema](tokenizer=tokenizer, type_schema=type_schema, source_prefix=source_prefix) 69 | else: 70 | raise NotImplementedError( 71 | 'Type Schema %s, Decoding Schema %s do not map to constraint decoder.' 
% ( 72 | decoding_schema, decoding_schema) 73 | ) 74 | 75 | 76 | class ConstraintDecoder: 77 | def __init__(self, tokenizer, source_prefix): 78 | self.tokenizer = tokenizer 79 | self.source_prefix = source_prefix 80 | self.source_prefix_tokenized = tokenizer.encode(source_prefix, 81 | add_special_tokens=False) if source_prefix else [] 82 | 83 | def get_state_valid_tokens(self, src_sentence: List[str], tgt_generated: List[str]) -> List[str]: 84 | pass 85 | 86 | def constraint_decoding(self, src_sentence, tgt_generated): 87 | if self.source_prefix_tokenized: 88 | # Remove Source Prefix for Generation 89 | src_sentence = src_sentence[len(self.source_prefix_tokenized):] 90 | 91 | if debug: 92 | print("Src:", self.tokenizer.convert_ids_to_tokens(src_sentence)) 93 | print("Tgt:", self.tokenizer.convert_ids_to_tokens(tgt_generated)) 94 | 95 | valid_token_ids = self.get_state_valid_tokens( 96 | src_sentence.tolist(), 97 | tgt_generated.tolist() 98 | ) 99 | 100 | if debug: 101 | print('========================================') 102 | print('valid tokens:', self.tokenizer.convert_ids_to_tokens( 103 | valid_token_ids), valid_token_ids) 104 | if debug_step: 105 | input() 106 | 107 | # return self.tokenizer.convert_tokens_to_ids(valid_tokens) 108 | return valid_token_ids 109 | 110 | # ET + RT + Src -> ((Role)(Role)), ETRTText2Role 使用 111 | class RoleConstraintDecoder(ConstraintDecoder): 112 | def __init__(self, tokenizer, type_schema, *args, **kwargs): 113 | super().__init__(tokenizer, *args, **kwargs) 114 | self.tree_end = '' 115 | self.type_schema = type_schema 116 | self.type_tree = get_label_name_tree(type_schema.role_list, 117 | tokenizer=self.tokenizer, 118 | end_symbol=self.tree_end) 119 | self.type_start = self.tokenizer.convert_tokens_to_ids([type_start])[0] 120 | self.type_end = self.tokenizer.convert_tokens_to_ids([type_end])[0] 121 | 122 | def check_state(self, tgt_generated): 123 | if tgt_generated[-1] == self.tokenizer.pad_token_id: # t5-base 124 | return 'start', -1 125 | 126 | special_token_set = {self.type_start, self.type_end} 127 | special_index_token = list( 128 | filter(lambda x: x[1] in special_token_set, list(enumerate(tgt_generated)))) 129 | # print(special_index_token) 130 | last_special_index, last_special_token = special_index_token[-1] 131 | 132 | if len(special_index_token) == 1: 133 | if last_special_token != self.type_start: 134 | return 'error', 0 135 | 136 | bracket_position = find_bracket_position( 137 | tgt_generated, _type_start=self.type_start, _type_end=self.type_end) 138 | start_number, end_number = len(bracket_position[self.type_start]), len( 139 | bracket_position[self.type_end]) # 计算左右括号的数量 140 | 141 | if start_number == end_number: 142 | return 'end_generate', -1 143 | if start_number == end_number + 1: 144 | state = 'start_first_generation' 145 | elif start_number == end_number + 2: 146 | state = 'generate_span' 147 | else: 148 | state = 'error' 149 | return state, last_special_index 150 | 151 | def search_prefix_tree_and_sequence(self, generated: List[str], prefix_tree: Dict, src_sentence: List[str], 152 | end_sequence_search_tokens: List[str] = None): 153 | """ 154 | Generate Text Span 155 | :param generated: 156 | :param prefix_tree: 157 | :param src_sentence: 158 | :param end_sequence_search_tokens: 159 | :return: 160 | """ 161 | tree = prefix_tree 162 | for index, token in enumerate(generated): 163 | tree = tree[token] 164 | is_tree_end = len(tree) == 1 and self.tree_end in tree 165 | 166 | if is_tree_end: 167 | valid_token = 
generated_search_src_sequence( 168 | generated=generated[index + 1:], 169 | src_sequence=src_sentence, 170 | end_sequence_search_tokens=end_sequence_search_tokens, 171 | ) 172 | return valid_token 173 | 174 | if self.tree_end in tree: 175 | try: 176 | valid_token = generated_search_src_sequence( 177 | generated=generated[index + 1:], 178 | src_sequence=src_sentence, 179 | end_sequence_search_tokens=end_sequence_search_tokens, 180 | ) 181 | return valid_token 182 | except IndexError: 183 | # Still search tree 184 | continue 185 | 186 | valid_token = list(tree.keys()) 187 | return valid_token 188 | 189 | def get_state_valid_tokens(self, src_sentence, tgt_generated): 190 | """ 191 | 192 | :param src_sentence: ET RT src 193 | :param tgt_generated: 194 | :return: 195 | List[str], valid token list 196 | """ 197 | old_src = src_sentence 198 | if self.tokenizer.eos_token_id in src_sentence: 199 | if src_sentence.count(self.tokenizer.eos_token_id) > 1: # 有新增的 200 | first_index = src_sentence.index(self.tokenizer.eos_token_id) # index函数会定位第一个出现的位置 201 | second_index = first_index + 1 + src_sentence[first_index + 1:].index(self.tokenizer.eos_token_id) # 注意要加上 first_index + 1的偏移 202 | third_index = second_index + 1 + src_sentence[second_index + 1: ].index(self.tokenizer.eos_token_id) 203 | src_sentence = src_sentence[second_index + 1: third_index] # 输入端 原句 src 204 | 205 | else: 206 | src_sentence = src_sentence[:src_sentence.index(self.tokenizer.eos_token_id)] 207 | # print("eos < 1 in src_sentence:", src_sentence) 208 | 209 | state, index = self.check_state(tgt_generated) 210 | 211 | print("State: %s" % state) if debug else None 212 | 213 | if state == 'error': 214 | print("Error:") 215 | print("Old src:", old_src) 216 | # print("first_index:", first_index) 217 | # print("second_index:", second_index) 218 | print("Src:", src_sentence) 219 | print("Tgt:", tgt_generated) 220 | valid_tokens = [self.tokenizer.eos_token_id] # t5-base 使用 221 | # valid_tokens = [self.tokenizer.sep_token_id] # uer/t5-small-chinese-cluecorpussmall 使用 [SEP] 222 | 223 | elif state == 'start': 224 | valid_tokens = [self.type_start] 225 | 226 | elif state == 'start_first_generation': 227 | valid_tokens = [self.type_start, self.type_end] 228 | 229 | elif state == 'generate_span': 230 | 231 | if tgt_generated[-1] == self.type_end: 232 | raise RuntimeError('Invalid %s in %s' % 233 | (self.type_end, tgt_generated)) 234 | else: 235 | valid_tokens = generated_search_src_sequence( 236 | generated=tgt_generated[index + 1:], 237 | src_sequence=src_sentence, 238 | end_sequence_search_tokens=[self.type_end], 239 | ) 240 | 241 | elif state == 'end_generate': 242 | valid_tokens = [self.tokenizer.eos_token_id] 243 | # valid_tokens = [self.tokenizer.sep_token_id] # uer/t5-small-chinese-cluecorpussmall 使用 244 | 245 | else: 246 | raise NotImplementedError( 247 | 'State `%s` for %s is not implemented.' 
% (state, self.__class__)) 248 | 249 | print("Valid: %s" % valid_tokens) if debug else None 250 | return valid_tokens 251 | 252 | # Src -> ((ET)(ET)), Text2ET 使用 253 | class ETConstraintDecoder(ConstraintDecoder): 254 | def __init__(self, tokenizer, type_schema, *args, **kwargs): 255 | super().__init__(tokenizer, *args, **kwargs) 256 | self.tree_end = '' 257 | self.type_tree = get_label_name_tree(type_schema.type_list, 258 | tokenizer=self.tokenizer, 259 | end_symbol=self.tree_end) 260 | self.type_start = self.tokenizer.convert_tokens_to_ids([type_start])[0] 261 | self.type_end = self.tokenizer.convert_tokens_to_ids([type_end])[0] 262 | 263 | def check_state(self, tgt_generated): 264 | if tgt_generated[-1] == self.tokenizer.pad_token_id: # t5-base 265 | return 'start', -1 266 | 267 | special_token_set = {self.type_start, self.type_end} 268 | special_index_token = list( 269 | filter(lambda x: x[1] in special_token_set, list(enumerate(tgt_generated)))) 270 | # print(special_index_token) 271 | last_special_index, last_special_token = special_index_token[-1] 272 | 273 | if len(special_index_token) == 1: 274 | if last_special_token != self.type_start: 275 | return 'error', 0 276 | 277 | bracket_position = find_bracket_position( 278 | tgt_generated, _type_start=self.type_start, _type_end=self.type_end) 279 | start_number, end_number = len(bracket_position[self.type_start]), len( 280 | bracket_position[self.type_end]) # 计算左右括号的数量 281 | 282 | if start_number == end_number: 283 | return 'end_generate', -1 284 | if start_number == end_number + 1: 285 | state = 'start_first_generation' 286 | elif start_number == end_number + 2: 287 | state = 'generate_span' 288 | else: 289 | state = 'error' 290 | return state, last_special_index 291 | 292 | def search_prefix_tree(self, generated: List[str], prefix_tree: Dict, 293 | end_sequence_search_tokens: List[str] = None): 294 | """ 295 | Generate Text Span 296 | :param generated: 297 | :param prefix_tree: 298 | :param src_sentence: 299 | :param end_sequence_search_tokens: 300 | :return: 301 | """ 302 | tree = prefix_tree 303 | for index, token in enumerate(generated): 304 | tree = tree[token] 305 | is_tree_end = len(tree) == 1 and self.tree_end in tree 306 | 307 | if is_tree_end: 308 | return end_sequence_search_tokens 309 | 310 | if self.tree_end in tree: 311 | return end_sequence_search_tokens 312 | 313 | valid_token = list(tree.keys()) 314 | return valid_token 315 | 316 | def get_state_valid_tokens(self, src_sentence, tgt_generated): 317 | """ 318 | 319 | :param src_sentence: ET RT src 320 | :param tgt_generated: 321 | :return: 322 | List[str], valid token list 323 | """ 324 | 325 | state, index = self.check_state(tgt_generated) 326 | 327 | # print("State: %s" % state) if debug else None 328 | # print("State: %s" % state) 329 | 330 | if state == 'error': 331 | print("Error:") 332 | print("Src:", src_sentence) 333 | print("Tgt:", tgt_generated) 334 | valid_tokens = [self.tokenizer.eos_token_id] # t5-base 使用 335 | # valid_tokens = [self.tokenizer.sep_token_id] # uer/t5-small-chinese-cluecorpussmall 使用 [SEP] 336 | 337 | elif state == 'start': 338 | valid_tokens = [self.type_start] 339 | 340 | elif state == 'start_first_generation': 341 | valid_tokens = [self.type_start, self.type_end] 342 | 343 | elif state == 'generate_span': 344 | 345 | if tgt_generated[-1] == self.type_start: 346 | # Start Event Label 347 | return list(self.type_tree.keys()) 348 | 349 | elif tgt_generated[-1] == self.type_end: 350 | raise RuntimeError('Invalid %s in %s' % 351 | 
(self.type_end, tgt_generated)) 352 | 353 | else: 354 | valid_tokens = self.search_prefix_tree( 355 | generated=tgt_generated[index + 1:], 356 | prefix_tree=self.type_tree, 357 | end_sequence_search_tokens=[self.type_end] 358 | ) 359 | 360 | elif state == 'end_generate': 361 | valid_tokens = [self.tokenizer.eos_token_id] 362 | # valid_tokens = [self.tokenizer.sep_token_id] # uer/t5-small-chinese-cluecorpussmall 使用 363 | 364 | else: 365 | raise NotImplementedError( 366 | 'State `%s` for %s is not implemented.' % (state, self.__class__)) 367 | 368 | print("Valid: %s" % valid_tokens) if debug else None 369 | return valid_tokens 370 | 371 | # ET + Src -> ((Tri)(Tri)), ETText2Tri 使用 372 | class TriConstraintDecoder(ConstraintDecoder): 373 | def __init__(self, tokenizer, type_schema, *args, **kwargs): 374 | super().__init__(tokenizer, *args, **kwargs) 375 | self.tree_end = '' 376 | self.type_schema = type_schema 377 | self.type_tree = get_label_name_tree(type_schema.role_list, 378 | tokenizer=self.tokenizer, 379 | end_symbol=self.tree_end) 380 | self.type_start = self.tokenizer.convert_tokens_to_ids([type_start])[0] 381 | self.type_end = self.tokenizer.convert_tokens_to_ids([type_end])[0] 382 | 383 | def check_state(self, tgt_generated): 384 | if tgt_generated[-1] == self.tokenizer.pad_token_id: # t5-base 385 | return 'start', -1 386 | 387 | special_token_set = {self.type_start, self.type_end} 388 | special_index_token = list( 389 | filter(lambda x: x[1] in special_token_set, list(enumerate(tgt_generated)))) 390 | # print(special_index_token) 391 | last_special_index, last_special_token = special_index_token[-1] 392 | 393 | if len(special_index_token) == 1: 394 | if last_special_token != self.type_start: 395 | return 'error', 0 396 | 397 | bracket_position = find_bracket_position( 398 | tgt_generated, _type_start=self.type_start, _type_end=self.type_end) 399 | start_number, end_number = len(bracket_position[self.type_start]), len( 400 | bracket_position[self.type_end]) # 计算左右括号的数量 401 | 402 | if start_number == end_number: 403 | return 'end_generate', -1 404 | if start_number == end_number + 1: 405 | state = 'start_first_generation' 406 | elif start_number == end_number + 2: 407 | state = 'generate_span' 408 | else: 409 | state = 'error' 410 | return state, last_special_index 411 | 412 | 413 | def get_state_valid_tokens(self, src_sentence, tgt_generated): 414 | """ 415 | 416 | :param src_sentence: ET src 417 | :param tgt_generated: 418 | :return: 419 | List[str], valid token list 420 | """ 421 | old_src = src_sentence 422 | if self.tokenizer.eos_token_id in src_sentence: 423 | if src_sentence.count(self.tokenizer.eos_token_id) > 1: # 有新增的 424 | first_index = src_sentence.index(self.tokenizer.eos_token_id) # index函数会定位第一个出现的位置 425 | second_index = first_index + 1 + src_sentence[first_index + 1:].index(self.tokenizer.eos_token_id) # 注意要加上 first_index + 1的偏移 426 | src_sentence = src_sentence[first_index + 1: second_index] # 输入端 原句 src 427 | 428 | else: 429 | src_sentence = src_sentence[:src_sentence.index(self.tokenizer.eos_token_id)] 430 | # print("eos < 1 in src_sentence:", src_sentence) 431 | 432 | state, index = self.check_state(tgt_generated) 433 | 434 | print("State: %s" % state) if debug else None 435 | 436 | if state == 'error': 437 | print("Error:") 438 | print("Old src:", old_src) 439 | # print("first_index:", first_index) 440 | # print("second_index:", second_index) 441 | print("Src:", src_sentence) 442 | print("Tgt:", tgt_generated) 443 | valid_tokens = [self.tokenizer.eos_token_id] # 
t5-base 使用 444 | # valid_tokens = [self.tokenizer.sep_token_id] # uer/t5-small-chinese-cluecorpussmall 使用 [SEP] 445 | 446 | elif state == 'start': 447 | valid_tokens = [self.type_start] 448 | 449 | elif state == 'start_first_generation': 450 | valid_tokens = [self.type_start, self.type_end] 451 | 452 | elif state == 'generate_span': 453 | 454 | if tgt_generated[-1] == self.type_end: 455 | raise RuntimeError('Invalid %s in %s' % 456 | (self.type_end, tgt_generated)) 457 | else: 458 | valid_tokens = generated_search_src_sequence( 459 | generated=tgt_generated[index + 1:], 460 | src_sequence=src_sentence, 461 | end_sequence_search_tokens=[self.type_end], 462 | ) 463 | 464 | elif state == 'end_generate': 465 | valid_tokens = [self.tokenizer.eos_token_id] 466 | # valid_tokens = [self.tokenizer.sep_token_id] # uer/t5-small-chinese-cluecorpussmall 使用 467 | 468 | else: 469 | raise NotImplementedError( 470 | 'State `%s` for %s is not implemented.' % (state, self.__class__)) 471 | 472 | print("Valid: %s" % valid_tokens) if debug else None 473 | return valid_tokens 474 | 475 | -------------------------------------------------------------------------------- /extraction/extraction_metrics.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from extraction.event_schema import EventSchema 4 | from extraction.predict_parser.predict_parser import Metric 5 | from extraction.predict_parser.target_predict_parser import RolePredictParser, ETPredictParser, TriPredictParser 6 | 7 | decoding_format_dict = { 8 | 'role': RolePredictParser, 9 | 'et': ETPredictParser, 10 | 'tri': TriPredictParser, 11 | } 12 | 13 | 14 | def get_predict_parser(format_name): 15 | return decoding_format_dict[format_name] 16 | 17 | 18 | def eval_pred(predict_parser, gold_list, pred_list, text_list=None, raw_list=None): 19 | 20 | well_formed_list, counter = predict_parser.decode( 21 | gold_list, pred_list, text_list, raw_list) 22 | 23 | event_metric = Metric() 24 | role_metric = Metric() 25 | 26 | for instance in well_formed_list: 27 | event_metric.count_instance(instance['gold_event'], 28 | instance['pred_event']) 29 | role_metric.count_instance(instance['gold_role'], 30 | instance['pred_role'], 31 | verbose=False) 32 | 33 | trigger_result = event_metric.compute_f1(prefix='trigger-') 34 | role_result = role_metric.compute_f1(prefix='role-') 35 | 36 | result = dict() 37 | result.update(trigger_result) # 将trigger_result 添加到result中 38 | result.update(role_result) # 将role_result 添加到result中 39 | result['AVG-F1'] = trigger_result.get('trigger-F1', 0.) + \ 40 | role_result.get('role-F1', 0.) 
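    # Note: despite its name, 'AVG-F1' is stored as the plain sum of trigger-F1 and
    # role-F1 (there is no division by two); a higher value still means better
    # combined trigger and role performance.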
41 | result.update(counter) 42 | return result 43 | 44 | 45 | def get_extract_metrics(pred_lns: List[str], tgt_lns: List[str], label_constraint: EventSchema, decoding_format='tree'): 46 | predict_parser = get_predict_parser(format_name=decoding_format)( 47 | schema=label_constraint) 48 | return eval_pred( 49 | predict_parser=predict_parser, 50 | gold_list=tgt_lns, 51 | pred_list=pred_lns 52 | ) 53 | -------------------------------------------------------------------------------- /extraction/label_tree.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from transformers import AutoTokenizer 4 | 5 | 6 | def list_dictionary(d, n_tab=-1): 7 | if isinstance(d, list): 8 | for i in d: 9 | list_dictionary(i, n_tab) 10 | elif isinstance(d, dict): 11 | n_tab += 1 12 | for key, value in d.items(): 13 | if key == '': 14 | print("{}{}".format(" " * n_tab, key)) 15 | else: 16 | print("{}{}".format(" " * n_tab, key)) 17 | list_dictionary(value, n_tab) 18 | else: 19 | print("{}{}".format("\t" * n_tab, d)) 20 | 21 | 22 | def print_tree(tree): 23 | list_dictionary(tree) 24 | 25 | 26 | def get_label_name_tree(label_name_list, tokenizer, end_symbol=''): 27 | sub_token_tree = dict() 28 | 29 | label_tree = dict() 30 | for typename in label_name_list: 31 | after_tokenized = tokenizer.encode(typename, add_special_tokens=False) 32 | label_tree[typename] = after_tokenized 33 | 34 | for _, sub_label_seq in label_tree.items(): 35 | parent = sub_token_tree 36 | for value in sub_label_seq: 37 | if value not in parent: 38 | parent[value] = dict() 39 | parent = parent[value] 40 | 41 | parent[end_symbol] = None 42 | 43 | return sub_token_tree 44 | 45 | 46 | class PrefixTree: 47 | def __init__(self, label_name_list, tokenizer, end_symbol=''): 48 | self.label_name_list = label_name_list 49 | self._tokenizer = tokenizer 50 | self.label_name_tree = get_label_name_tree( 51 | label_name_list, tokenizer, end_symbol) 52 | self._end_symbol = end_symbol 53 | 54 | def is_end_of_tree(self, tree: Dict): 55 | return len(tree) == 1 and self._end_symbol in tree 56 | 57 | 58 | if __name__ == "__main__": 59 | event_subtype_name = [line.strip() for line in """Die 60 | Marry 61 | Divorce 62 | Injure 63 | Transfer-Ownership 64 | Transfer-Money 65 | Transport 66 | Start-Org 67 | Be-Born 68 | End-Org 69 | Declare-Bankruptcy 70 | Merge-Org 71 | Attack 72 | Demonstrate 73 | Meet 74 | Phone-Write 75 | Start-Position 76 | End-Position 77 | Nominate 78 | Elect 79 | Arrest-Jail 80 | Release-Parole 81 | Charge-Indict 82 | Trial-Hearing 83 | Sue 84 | Convict 85 | Sentence 86 | Fine 87 | Execute 88 | Extradite 89 | Acquit 90 | Pardon 91 | Appeal""".split('\n')] 92 | 93 | test_tokenizer = AutoTokenizer.from_pretrained('t5-base') 94 | 95 | suptype_tree = get_label_name_tree( 96 | event_subtype_name, test_tokenizer) 97 | # role_tree = get_label_name_tree(ACEEventMetaData.event_role_name) 98 | print_tree(suptype_tree) 99 | # print_tree(role_tree) 100 | -------------------------------------------------------------------------------- /extraction/predict_parser/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | -------------------------------------------------------------------------------- /extraction/predict_parser/predict_parser.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | from copy import 
deepcopy 4 | from typing import List, Counter, Tuple 5 | 6 | EVENT_EXTRACTION_KEYS = ["trigger-P", "trigger-R", "trigger-F1", 7 | "role-P", "role-R", "role-F1"] 8 | 9 | 10 | class PredictParser: 11 | def __init__(self, schema): 12 | self.predicate_set = schema.type_list 13 | self.role_set = schema.role_list 14 | 15 | def decode(self, gold_list, pred_list, text_list=None, raw_list=None) -> Tuple[List, Counter]: 16 | """ 17 | 18 | :param gold_list: 19 | :param pred_list: 20 | :param text_list: 21 | :param raw_list: 22 | :return: 23 | dict: 24 | pred_event -> [(type1, trigger1), (type2, trigger2), ...] 25 | gold_event -> [(type1, trigger1), (type2, trigger2), ...] 26 | pred_role -> [(type1, role1, argument1), (type2, role2, argument2), ...] 27 | gold_role -> [(type1, role1, argument1), (type2, role2, argument2), ...] 28 | Counter: 29 | """ 30 | pass 31 | 32 | @staticmethod 33 | def count_multi_event_role_in_instance(instance, counter): 34 | if len(instance['gold_event']) != len(set(instance['gold_event'])): 35 | counter.update(['multi-same-event-gold']) 36 | 37 | if len(instance['gold_role']) != len(set(instance['gold_role'])): 38 | counter.update(['multi-same-role-gold']) 39 | 40 | if len(instance['pred_event']) != len(set(instance['pred_event'])): 41 | counter.update(['multi-same-event-pred']) 42 | 43 | if len(instance['pred_role']) != len(set(instance['pred_role'])): 44 | counter.update(['multi-same-role-pred']) 45 | 46 | 47 | class Metric: 48 | def __init__(self): 49 | self.tp = 0. 50 | self.gold_num = 0. 51 | self.pred_num = 0. 52 | 53 | @staticmethod 54 | def safe_div(a, b): 55 | if b == 0.: 56 | return 0. 57 | else: 58 | return a / b 59 | 60 | def compute_f1(self, prefix=''): 61 | tp = self.tp 62 | pred_num = self.pred_num 63 | gold_num = self.gold_num 64 | p, r = self.safe_div(tp, pred_num), self.safe_div(tp, gold_num) 65 | return {prefix + 'tp': tp, 66 | prefix + 'gold': gold_num, 67 | prefix + 'pred': pred_num, 68 | prefix + 'P': p * 100, 69 | prefix + 'R': r * 100, 70 | prefix + 'F1': self.safe_div(2 * p * r, p + r) * 100 71 | } 72 | 73 | def count_instance(self, gold_list, pred_list, verbose=False): 74 | if verbose: 75 | print("Gold:", gold_list) 76 | print("Pred:", pred_list) 77 | self.gold_num += len(gold_list) 78 | self.pred_num += len(pred_list) 79 | 80 | dup_gold_list = deepcopy(gold_list) 81 | for pred in pred_list: 82 | if pred in dup_gold_list: 83 | self.tp += 1 84 | dup_gold_list.remove(pred) 85 | -------------------------------------------------------------------------------- /extraction/predict_parser/target_predict_parser.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from typing import Tuple, List, Dict 3 | 4 | from nltk.tree import ParentedTree 5 | import re 6 | 7 | from extraction.predict_parser.predict_parser import PredictParser 8 | 9 | 10 | type_start = '' 11 | type_end = '' 12 | role_start = '' 13 | role_end = '' 14 | 15 | 16 | left_bracket = '【' 17 | right_bracket = '】' 18 | brackets = left_bracket + right_bracket 19 | 20 | split_bracket = re.compile(r"") # t5-base/mt5-base 21 | specical_str = "" 22 | 23 | 24 | def add_space(text): 25 | """ 26 | add space between special token 27 | :param text: 28 | :return: 29 | """ 30 | new_text_list = list() 31 | for item in zip(split_bracket.findall(text), split_bracket.split(text)[1:]): # 此处将第一个左括号 左边的token全部去掉 (如果以[CLS]开头则会被丢弃) 32 | new_text_list += item 33 | return ' '.join(new_text_list) # 空格组合对于中文效果? 
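
# NOTE: type_start / type_end / role_start / role_end and the split_bracket pattern
# appear as empty strings in this copy; the original angle-bracket special tokens
# seem to have been stripped on export. Judging from the token-id comment in
# run_seq2seq.py (32099 / 32098 are T5's <extra_id_0> / <extra_id_1>), they are
# presumably T5 sentinel tokens. Under that assumption, a trigger target such as
#     '<extra_id_0> <extra_id_0> interviews <extra_id_1> <extra_id_1>'
# is re-spaced around each sentinel by add_space(), and the sentinels are later
# mapped onto the single bracket pair 【 / 】 by convert_bracket() inside each
# parser's decode(), so the span logic below only has to balance one kind of bracket.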
34 | 35 | def find_bracket_num(tree_str): 36 | """ 37 | Count Bracket Number, 0 indicate num_left = num_right 38 | :param tree_str: 39 | :return: 40 | """ 41 | count = 0 42 | for char in tree_str: 43 | if char == left_bracket: 44 | count += 1 45 | elif char == right_bracket: 46 | count -= 1 47 | else: 48 | pass 49 | return count 50 | 51 | 52 | def check_well_form(tree_str): 53 | return find_bracket_num(tree_str) == 0 54 | 55 | 56 | def clean_text(tree_str): 57 | count = 0 58 | sum_count = 0 59 | 60 | tree_str_list = tree_str.split() 61 | # bracket_num = find_bracket_num(tree_str_list) 62 | # bracket_num = find_bracket_num(tree_str_list) 63 | 64 | for index, char in enumerate(tree_str_list): 65 | if char == left_bracket: 66 | count += 1 67 | sum_count += 1 68 | elif char == right_bracket: 69 | count -= 1 70 | sum_count += 1 71 | else: 72 | pass 73 | if count == 0 and sum_count > 0: 74 | return ' '.join(tree_str_list[:index + 1]) 75 | return ' '.join(tree_str_list) 76 | 77 | 78 | def add_bracket(tree_str): # 补全不够的右括号 79 | """ 80 | add right bracket to fill ill-formed 81 | :param tree_str: 82 | :return: 83 | """ 84 | tree_str_list = tree_str.split() 85 | bracket_num = find_bracket_num(tree_str_list) 86 | tree_str_list += [right_bracket] * bracket_num 87 | return ' '.join(tree_str_list) 88 | 89 | 90 | def get_tree_str(tree): 91 | """ 92 | get str from event tree 93 | :param tree: 94 | :return: 95 | """ 96 | str_list = list() 97 | for element in tree: 98 | if isinstance(element, str): 99 | str_list += [element] 100 | return ' '.join(str_list) 101 | 102 | 103 | # ET + RT + Src -> ((Role)(Role)), ETRTText2Role 使用 104 | class RolePredictParser(PredictParser): 105 | 106 | def decode(self, gold_list, pred_list, text_list=None, raw_list=None) -> Tuple[List[Dict], Counter]: 107 | """ 108 | 109 | :param gold_list: 110 | :param pred_list: 111 | :param text_list: 112 | :param raw_list: 113 | :return: 114 | dict: 115 | pred_event -> [(type1, trigger1), (type2, trigger2), ...] 116 | gold_event -> [(type1, trigger1), (type2, trigger2), ...] 117 | pred_role -> [(type1, role1, argument1), (type2, role2, argument2), ...] 118 | gold_role -> [(type1, role1, argument1), (type2, role2, argument2), ...] 
119 | Counter: 120 | """ 121 | counter = Counter() 122 | well_formed_list = [] 123 | 124 | def convert_bracket(_text): 125 | _text = add_space(_text) 126 | for start in [role_start, type_start]: 127 | _text = _text.replace(start, left_bracket) 128 | for end in [role_end, type_end]: 129 | _text = _text.replace(end, right_bracket) 130 | return _text 131 | 132 | if gold_list is None or len(gold_list) == 0: # 不存在标注信息的情况下根据pred_list 全部置为空 133 | gold_list = ["%s%s" % (type_start, type_end)] * len(pred_list) 134 | 135 | if text_list is None: 136 | text_list = [None] * len(pred_list) 137 | 138 | if raw_list is None: 139 | raw_list = [None] * len(pred_list) 140 | 141 | for gold, pred, text, raw_data in zip(gold_list, pred_list, text_list, raw_list): 142 | # print(gold) 143 | # print("*************************************") 144 | gold = convert_bracket(gold) 145 | pred = convert_bracket(pred) 146 | 147 | gold = clean_text(gold) 148 | pred = clean_text(pred) 149 | # print(gold) 150 | 151 | et = None # 事件类型 152 | rt = None # 角色类型 153 | if text and specical_str in text: 154 | et = text.split(specical_str)[0].strip() 155 | rt = text.split(specical_str)[1].strip() 156 | text = text.split(specical_str)[-1].strip() 157 | 158 | instance = {'gold': gold, 159 | 'pred': pred, 160 | 'gold_tree': None, 161 | 'text': text, 162 | 'raw_data': raw_data 163 | } 164 | 165 | # 重点修改部分 166 | instance['pred_event'], instance['pred_role'], instance['pred_record'] = self.get_event_list( 167 | span_str=instance["pred"], 168 | text=instance['text'], 169 | et = et, 170 | rt = rt 171 | ) 172 | instance['gold_event'], instance['gold_role'], instance['gold_record'] = self.get_event_list( 173 | span_str=instance["gold"], 174 | text=instance['text'], 175 | et = et, 176 | rt = rt 177 | ) 178 | 179 | 180 | # span中该部分无意义 181 | counter.update(['gold_tree']) 182 | counter.update(['pred_tree']) 183 | counter.update(['well-formed']) 184 | 185 | self.count_multi_event_role_in_instance(instance=instance, counter=counter) 186 | 187 | well_formed_list += [instance] 188 | 189 | return well_formed_list, counter 190 | 191 | 192 | def get_event_list(self, span_str, text=None, et=None, rt=None): 193 | 194 | event_list = list() 195 | role_list = list() 196 | record_list = list() 197 | 198 | # 将target结果格式化处理 199 | spans = [] 200 | for item in span_str.replace(left_bracket, "").split(right_bracket): 201 | t = item.strip() 202 | if len(t) > 0: spans.append(t) 203 | 204 | cur_et_type = et # 在 span 生成时候使用 205 | for span_item in spans: 206 | 207 | if len(span_item) == 0: 208 | continue 209 | 210 | span_text = span_item 211 | 212 | # role text 213 | if text is not None and span_text not in text: continue 214 | role_list += [(cur_et_type, rt, span_text)] 215 | 216 | record = {'roles': role_list, 'type': event_list, 'trigger': None} 217 | 218 | 219 | record_list += [record] 220 | 221 | return event_list, role_list, record_list 222 | 223 | # Src -> ((ET)(ET)), Text2ET 使用 224 | class ETPredictParser(PredictParser): 225 | 226 | def decode(self, gold_list, pred_list, text_list=None, raw_list=None) -> Tuple[List[Dict], Counter]: 227 | """ 228 | 229 | :param gold_list: 230 | :param pred_list: 231 | :param text_list: 232 | :param raw_list: 233 | :return: 234 | dict: 235 | pred_event -> [(type1, trigger1), (type2, trigger2), ...] 236 | gold_event -> [(type1, trigger1), (type2, trigger2), ...] 237 | pred_role -> [(type1, role1, argument1), (type2, role2, argument2), ...] 238 | gold_role -> [(type1, role1, argument1), (type2, role2, argument2), ...] 
239 | Counter: 240 | """ 241 | counter = Counter() 242 | well_formed_list = [] 243 | 244 | def convert_bracket(_text): 245 | _text = add_space(_text) 246 | for start in [role_start, type_start]: 247 | _text = _text.replace(start, left_bracket) 248 | for end in [role_end, type_end]: 249 | _text = _text.replace(end, right_bracket) 250 | return _text 251 | 252 | if gold_list is None or len(gold_list) == 0: # 不存在标注信息的情况下根据pred_list 全部置为空 253 | gold_list = ["%s%s" % (type_start, type_end)] * len(pred_list) 254 | 255 | if text_list is None: 256 | text_list = [None] * len(pred_list) 257 | 258 | if raw_list is None: 259 | raw_list = [None] * len(pred_list) 260 | 261 | for gold, pred, text, raw_data in zip(gold_list, pred_list, text_list, raw_list): 262 | # print(gold) 263 | # print("*************************************") 264 | gold = convert_bracket(gold) 265 | pred = convert_bracket(pred) 266 | 267 | gold = clean_text(gold) 268 | pred = clean_text(pred) 269 | # print(gold) 270 | 271 | et = None # 事件类型 272 | rt = None # 角色类型 273 | if text and specical_str in text: 274 | et = text.split(specical_str)[0].strip() 275 | rt = text.split(specical_str)[1].strip() 276 | text = text.split(specical_str)[-1].strip() 277 | 278 | instance = {'gold': gold, 279 | 'pred': pred, 280 | 'gold_tree': None, 281 | 'text': text, 282 | 'raw_data': raw_data 283 | } 284 | 285 | # 重点修改部分 286 | instance['pred_event'], instance['pred_role'], instance['pred_record'] = self.get_event_list( 287 | span_str=instance["pred"], 288 | text=instance['text'], 289 | et = et, 290 | rt = rt 291 | ) 292 | instance['gold_event'], instance['gold_role'], instance['gold_record'] = self.get_event_list( 293 | span_str=instance["gold"], 294 | text=instance['text'], 295 | et = et, 296 | rt = rt 297 | ) 298 | 299 | 300 | # span中该部分无意义 301 | counter.update(['gold_tree']) 302 | counter.update(['pred_tree']) 303 | counter.update(['well-formed']) 304 | 305 | self.count_multi_event_role_in_instance(instance=instance, counter=counter) 306 | 307 | well_formed_list += [instance] 308 | 309 | return well_formed_list, counter 310 | 311 | 312 | def get_event_list(self, span_str, text=None, et=None, rt=None): 313 | 314 | event_list = list() 315 | role_list = list() 316 | record_list = list() 317 | 318 | # 将target结果格式化处理 319 | spans = [] 320 | for item in span_str.replace(left_bracket, "").split(right_bracket): 321 | t = item.strip() 322 | if len(t) > 0: spans.append(t) 323 | 324 | cur_et_type = et # 在 span 生成时候使用 325 | for span_item in spans: 326 | 327 | if len(span_item) == 0: 328 | continue 329 | 330 | span_text = span_item 331 | 332 | event_list += [(span_text)] 333 | 334 | record = {'roles': role_list, 'type': event_list, 'trigger': None} 335 | 336 | record_list += [record] 337 | 338 | return event_list, role_list, record_list 339 | 340 | # ET + Src -> ((Tri)(Tri)), ETText2Tri 使用 341 | class TriPredictParser(PredictParser): 342 | 343 | def decode(self, gold_list, pred_list, text_list=None, raw_list=None) -> Tuple[List[Dict], Counter]: 344 | counter = Counter() 345 | well_formed_list = [] 346 | 347 | def convert_bracket(_text): 348 | _text = add_space(_text) 349 | for start in [role_start, type_start]: 350 | _text = _text.replace(start, left_bracket) 351 | for end in [role_end, type_end]: 352 | _text = _text.replace(end, right_bracket) 353 | return _text 354 | 355 | if gold_list is None or len(gold_list) == 0: # 不存在标注信息的情况下根据pred_list 全部置为空 356 | gold_list = ["%s%s" % (type_start, type_end)] * len(pred_list) 357 | 358 | if text_list is None: 359 | text_list 
= [None] * len(pred_list) 360 | 361 | if raw_list is None: 362 | raw_list = [None] * len(pred_list) 363 | 364 | for gold, pred, text, raw_data in zip(gold_list, pred_list, text_list, raw_list): 365 | # print(gold) 366 | # print("*************************************") 367 | gold = convert_bracket(gold) 368 | pred = convert_bracket(pred) 369 | 370 | gold = clean_text(gold) 371 | pred = clean_text(pred) 372 | # print(gold) 373 | 374 | et = None # 事件类型 375 | rt = None # 角色类型 376 | if text and specical_str in text: 377 | et = text.split(specical_str)[0].strip() 378 | rt = text.split(specical_str)[1].strip() 379 | text = text.split(specical_str)[-1].strip() 380 | 381 | instance = {'gold': gold, 382 | 'pred': pred, 383 | 'gold_tree': None, 384 | 'text': text, 385 | 'raw_data': raw_data 386 | } 387 | 388 | # 重点修改部分 389 | instance['pred_event'], instance['pred_role'], instance['pred_record'] = self.get_event_list( 390 | span_str=instance["pred"], 391 | text=instance['text'], 392 | et = et, 393 | rt = rt 394 | ) 395 | instance['gold_event'], instance['gold_role'], instance['gold_record'] = self.get_event_list( 396 | span_str=instance["gold"], 397 | text=instance['text'], 398 | et = et, 399 | rt = rt 400 | ) 401 | 402 | 403 | # span中该部分无意义 404 | counter.update(['gold_tree']) 405 | counter.update(['pred_tree']) 406 | counter.update(['well-formed']) 407 | 408 | self.count_multi_event_role_in_instance(instance=instance, counter=counter) 409 | 410 | well_formed_list += [instance] 411 | 412 | return well_formed_list, counter 413 | 414 | def get_event_list(self, span_str, text=None, et=None, rt=None): 415 | 416 | event_list = list() 417 | role_list = list() 418 | record_list = list() 419 | 420 | # 将target结果格式化处理 421 | spans = [] 422 | for item in span_str.replace(left_bracket, "").split(right_bracket): 423 | t = item.strip() 424 | if len(t) > 0: spans.append(t) 425 | 426 | cur_et_type = et # 在 span 生成时候使用 427 | for span_item in spans: 428 | if len(span_item) == 0: 429 | continue 430 | span_text = span_item 431 | 432 | # role text 433 | if text is not None and span_text not in text: continue 434 | event_list += [(cur_et_type, span_text)] 435 | 436 | record = {'roles': role_list, 'type': event_list, 'trigger': None} 437 | 438 | 439 | record_list += [record] 440 | 441 | return event_list, role_list, record_list 442 | 443 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.7.1 2 | transformers==4.4.2 3 | anytree~=2.8.0 4 | tensorboard 5 | scikit-learn 6 | seqeval 7 | psutil 8 | sacrebleu~=1.4.14 9 | rouge-score 10 | tensorflow_datasets 11 | matplotlib 12 | git-python==1.0.3 13 | faiss-cpu 14 | streamlit 15 | elasticsearch 16 | nltk~=3.5 17 | pandas 18 | datasets >= 1.1.3 19 | fire 20 | pytest 21 | conllu 22 | sentencepiece != 0.1.92 23 | protobuf 24 | numpy~=1.19.2 25 | tabulate~=0.8.7 26 | filelock~=3.0.12 27 | dataclasses~=0.6 28 | rich~=9.8.2 29 | -------------------------------------------------------------------------------- /run_arg_predict.bash: -------------------------------------------------------------------------------- 1 | export device="0" 2 | export et_model_path="XXX" 3 | export role_model_path="XXX" 4 | export et_data_name=dyiepp_ace2005_et_subtype_span 5 | export task_name="event" 6 | export batch=16 7 | export constraint_decoding="--constraint_decoding" 8 | 9 | export decoding_format='role' 10 | 11 | et_data_folder=data/text2target/${et_data_name} 
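
# Pipeline sketch (same steps as the commands below): (1) convert the event-type
# predictions of the ET model into a role-prediction input file with
# convert_et_result.py, (2) run constrained role prediction with the trained role
# model, (3) score the output with evaluation.py against both the dyiepp-style and
# the text2target-style gold files. The et_model_path / role_model_path values above
# are "XXX" placeholders and must point to trained checkpoints before this script is run.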
12 | 13 | # et result 转化 14 | python convert_et_result.py \ 15 | --et_pred_file=${et_model_path}/test_preds_seq2seq.txt \ 16 | --et_text_file=${et_data_folder}/test.json \ 17 | --et_output_file=${et_data_folder}/test_pre.json \ 18 | --schema_file=${et_data_folder}/event.schema 19 | 20 | # # 依赖上述的转化文件进行 role predict 预测, decode_format = noetrtspan 21 | CUDA_VISIBLE_DEVICES=${device} python run_seq2seq.py \ 22 | --do_predict --task=${task_name} --predict_with_generate \ 23 | --validation_file=${et_data_folder}/val.json \ 24 | --test_file=${et_data_folder}/test_pre.json \ 25 | --event_schema=${et_data_folder}/event.schema \ 26 | --model_name_or_path=${role_model_path} \ 27 | --output_dir="${role_model_path}"_test \ 28 | --source_prefix="span: " \ 29 | ${constraint_decoding} \ 30 | --per_device_eval_batch_size=${batch} \ 31 | --decoding_format ${decoding_format} 32 | 33 | # evaluate 评估, 结果打印在控制台 34 | python evaluation.py \ 35 | --text_file=${et_data_folder}/test_pre.json \ 36 | --pred_file="${role_model_path}"_test/test_preds_seq2seq.txt \ 37 | --gold_file="data/raw_data/dyiepp_ace2005/test_convert.json" \ 38 | --schema_file=${et_data_folder}/event.schema \ 39 | --decoding_format ${decoding_format} \ 40 | --format="dyiepp" > "${role_model_path}"_test/total_dyiepp_result.txt 41 | 42 | # evaluate 评估, 结果打印在控制台(non-events 评估) 43 | python evaluation.py \ 44 | --text_file=${et_data_folder}/test_pre.json \ 45 | --pred_file="${role_model_path}"_test/test_preds_seq2seq.txt \ 46 | --gold_file="data/text2target/dyiepp_ace2005_etrttext2role_subtype/test.json" \ 47 | --schema_file=${et_data_folder}/event.schema \ 48 | --decoding_format ${decoding_format} \ 49 | --format="text2target" > "${role_model_path}"_test/total_text2target_result.txt 50 | 51 | 52 | -------------------------------------------------------------------------------- /run_seq2seq.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for sequence to sequence. 18 | """ 19 | # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. 
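
# GDAP-specific pieces imported below: the EventSchema label constraint, the
# et/tri/role predict parsers behind get_extract_metrics, and the constrained
# trainer that applies decoding-time vocabulary restrictions; the parser and
# constraint decoder in use are selected by the --decoding_format argument
# (role / et / tri), as passed by the run_*_predict.bash scripts.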
20 | 21 | import logging 22 | import os 23 | import re 24 | import sys 25 | from dataclasses import dataclass, field 26 | from typing import Optional 27 | 28 | import nltk # Here to have a nice missing dependency error message early on 29 | import numpy as np 30 | from datasets import load_dataset, load_metric 31 | 32 | import transformers 33 | from filelock import FileLock 34 | from transformers import ( 35 | AutoConfig, 36 | AutoModelForSeq2SeqLM, 37 | AutoTokenizer, 38 | DataCollatorForSeq2Seq, 39 | HfArgumentParser, 40 | MBartTokenizer, 41 | default_data_collator, 42 | set_seed 43 | ) 44 | from transformers.trainer_utils import get_last_checkpoint, is_main_process 45 | 46 | from extraction.event_schema import EventSchema 47 | from extraction.extraction_metrics import decoding_format_dict, get_extract_metrics 48 | from seq2seq.constrained_seq2seq import ConstraintSeq2SeqTrainingArguments, ConstraintSeq2SeqTrainer 49 | 50 | # if not os.path.exists("~/nltk_data/tokenizers/punkt"): 51 | # print('Start download punk') 52 | # with FileLock(".lock") as lock: 53 | # nltk.download("punkt", quiet=True) 54 | 55 | logger = logging.getLogger(__name__) 56 | 57 | 58 | @dataclass 59 | class ModelArguments: 60 | """ 61 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 62 | """ 63 | 64 | model_name_or_path: str = field( 65 | metadata={ 66 | "help": "Path to pretrained model or model identifier from huggingface.co/models"} 67 | ) 68 | config_name: Optional[str] = field( 69 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 70 | ) 71 | tokenizer_name: Optional[str] = field( 72 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 73 | ) 74 | cache_dir: Optional[str] = field( 75 | default=None, 76 | metadata={ 77 | "help": "Where to store the pretrained models downloaded from huggingface.co"}, 78 | ) 79 | use_fast_tokenizer: bool = field( 80 | default=False, 81 | metadata={ 82 | "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 83 | ) 84 | # !!! must use non-fast version 85 | # fast: " " -> [32099, 3, 32098, 1] 86 | # non-fast: " " -> [32099, 32098, 1] 87 | model_revision: str = field( 88 | default="main", 89 | metadata={ 90 | "help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 91 | ) 92 | use_auth_token: bool = field( 93 | default=False, 94 | metadata={ 95 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 96 | "with private models)." 97 | }, 98 | ) 99 | 100 | 101 | @dataclass 102 | class DataTrainingArguments: 103 | """ 104 | Arguments pertaining to what data we are going to input our model for training and eval. 105 | """ 106 | 107 | task: str = field( 108 | default="summarization", 109 | metadata={ 110 | "help": "The name of the task, should be summarization (or summarization_{dataset} for evaluating " 111 | "pegasus) or translation (or translation_{xx}_to_{yy})." 
112 | }, 113 | ) 114 | dataset_name: Optional[str] = field( 115 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 116 | ) 117 | dataset_config_name: Optional[str] = field( 118 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 119 | ) 120 | text_column: Optional[str] = field( 121 | default=None, 122 | metadata={ 123 | "help": "The name of the column in the datasets containing the full texts (for summarization)."}, 124 | ) 125 | summary_column: Optional[str] = field( 126 | default=None, 127 | metadata={ 128 | "help": "The name of the column in the datasets containing the summaries (for summarization)."}, 129 | ) 130 | train_file: Optional[str] = field( 131 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} 132 | ) 133 | validation_file: Optional[str] = field( 134 | default=None, 135 | metadata={ 136 | "help": "An optional input evaluation data file to evaluate the metrics (rouge/sacreblue) on " 137 | "(a jsonlines or csv file)." 138 | }, 139 | ) 140 | test_file: Optional[str] = field( 141 | default=None, 142 | metadata={ 143 | "help": "An optional input test data file to evaluate the metrics (rouge/sacreblue) on " 144 | "(a jsonlines or csv file)." 145 | }, 146 | ) 147 | overwrite_cache: bool = field( 148 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 149 | ) 150 | preprocessing_num_workers: Optional[int] = field( 151 | default=None, 152 | metadata={"help": "The number of processes to use for the preprocessing."}, 153 | ) 154 | max_source_length: Optional[int] = field( 155 | default=1024, 156 | metadata={ 157 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 158 | "than this will be truncated, sequences shorter will be padded." 159 | }, 160 | ) 161 | max_target_length: Optional[int] = field( 162 | default=128, 163 | metadata={ 164 | "help": "The maximum total sequence length for target text after tokenization. Sequences longer " 165 | "than this will be truncated, sequences shorter will be padded." 166 | }, 167 | ) 168 | val_max_target_length: Optional[int] = field( 169 | default=None, 170 | metadata={ 171 | "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer " 172 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." 173 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " 174 | "during ``evaluate`` and ``predict``." 175 | }, 176 | ) 177 | pad_to_max_length: bool = field( 178 | default=False, 179 | metadata={ 180 | "help": "Whether to pad all samples to model maximum sentence length. " 181 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " 182 | "efficient on GPU but very bad for TPU." 183 | }, 184 | ) 185 | max_train_samples: Optional[int] = field( 186 | default=None, 187 | metadata={ 188 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 189 | "value if set." 190 | }, 191 | ) 192 | max_val_samples: Optional[int] = field( 193 | default=None, 194 | metadata={ 195 | "help": "For debugging purposes or quicker training, truncate the number of validation examples to this " 196 | "value if set." 
197 | }, 198 | ) 199 | max_test_samples: Optional[int] = field( 200 | default=None, 201 | metadata={ 202 | "help": "For debugging purposes or quicker training, truncate the number of test examples to this " 203 | "value if set." 204 | }, 205 | ) 206 | source_lang: Optional[str] = field( 207 | default=None, metadata={"help": "Source language id for translation."}) 208 | target_lang: Optional[str] = field( 209 | default=None, metadata={"help": "Target language id for translation."}) 210 | num_beams: Optional[int] = field( 211 | default=None, 212 | metadata={ 213 | "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " 214 | "which is used during ``evaluate`` and ``predict``." 215 | }, 216 | ) 217 | ignore_pad_token_for_loss: bool = field( 218 | default=True, 219 | metadata={ 220 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." 221 | }, 222 | ) 223 | source_prefix: Optional[str] = field( 224 | default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."} 225 | ) 226 | 227 | def __post_init__(self): 228 | if self.dataset_name is None and self.train_file is None and self.validation_file is None: 229 | raise ValueError( 230 | "Need either a dataset name or a training/validation file.") 231 | else: 232 | if self.train_file is not None: 233 | extension = self.train_file.split(".")[-1] 234 | assert extension in [ 235 | "csv", "json"], "`train_file` should be a csv or a json file." 236 | if self.validation_file is not None: 237 | extension = self.validation_file.split(".")[-1] 238 | assert extension in [ 239 | "csv", "json"], "`validation_file` should be a csv or a json file." 240 | if not self.task.startswith("summarization") and not self.task.startswith( 241 | "translation") and not self.task.startswith('event'): 242 | raise ValueError( 243 | "`task` should be summarization, summarization_{dataset}, translation or translation_{xx}_to_{yy}." 244 | ) 245 | if self.val_max_target_length is None: 246 | self.val_max_target_length = self.max_target_length 247 | 248 | # Start Code for Event Extraction 249 | decoding_format: str = field( 250 | default='tree', 251 | metadata={"help": "Decoding Format, valid in %s" % 252 | decoding_format_dict.keys()} 253 | ) 254 | event_schema: str = field( 255 | default=None, metadata={"help": "The input event schema file."} 256 | ) 257 | # End Code for Event Extraction 258 | 259 | 260 | summarization_name_mapping = { 261 | "amazon_reviews_multi": ("review_body", "review_title"), 262 | "big_patent": ("description", "abstract"), 263 | "cnn_dailymail": ("article", "highlights"), 264 | "orange_sum": ("text", "summary"), 265 | "pn_summary": ("article", "summary"), 266 | "psc": ("extract_text", "summary_text"), 267 | "samsum": ("dialogue", "summary"), 268 | "thaisum": ("body", "summary"), 269 | "xglue": ("news_body", "news_title"), 270 | "xsum": ("document", "summary"), 271 | "wiki_summary": ("article", "highlights"), 272 | } 273 | 274 | event_extraction_name_mapping = { 275 | "event": ("text", "event") 276 | } 277 | 278 | 279 | def main(): 280 | # See all possible arguments in src/transformers/training_args.py 281 | # or by passing the --help flag to this script. 282 | # We now keep distinct sets of args, for a cleaner separation of concerns. 
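    # Three argument groups are parsed: ModelArguments (checkpoint/tokenizer options),
    # DataTrainingArguments (data files plus the event-specific --decoding_format and
    # --event_schema fields), and ConstraintSeq2SeqTrainingArguments, which extends the
    # usual Seq2Seq training arguments with --constraint_decoding and --label_smoothing_sum.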
283 | 284 | parser = HfArgumentParser( 285 | (ModelArguments, DataTrainingArguments, ConstraintSeq2SeqTrainingArguments)) 286 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 287 | # If we pass only one argument to the script and it's the path to a json file, 288 | # let's parse it to get our arguments. 289 | model_args, data_args, training_args = parser.parse_json_file( 290 | json_file=os.path.abspath(sys.argv[1])) 291 | else: 292 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 293 | 294 | print(model_args) 295 | print(data_args) 296 | print(training_args) 297 | 298 | # Detecting last checkpoint. 299 | last_checkpoint = None 300 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 301 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 302 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 303 | raise ValueError( 304 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 305 | "Use --overwrite_output_dir to overcome." 306 | ) 307 | elif last_checkpoint is not None: 308 | logger.info( 309 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 310 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 311 | ) 312 | 313 | # Setup logging 314 | logging.basicConfig( 315 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 316 | datefmt="%m/%d/%Y %H:%M:%S", 317 | handlers=[logging.StreamHandler(sys.stdout)], 318 | ) 319 | logger.setLevel(logging.INFO if is_main_process( 320 | training_args.local_rank) else logging.WARN) 321 | 322 | # Log on each process the small summary: 323 | logger.warning( 324 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 325 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 326 | ) 327 | # Set the verbosity to info of the Transformers logger (on main process only): 328 | if is_main_process(training_args.local_rank): 329 | transformers.utils.logging.set_verbosity_info() 330 | logger.info("Training/evaluation parameters %s", training_args) 331 | 332 | # Set seed before initializing model. 333 | set_seed(training_args.seed) 334 | 335 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) 336 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 337 | # (the dataset will be downloaded automatically from the datasets Hub). 338 | # 339 | # For CSV/JSON files in the summarization task, this script will use the first column for the full texts and the 340 | # second column for the summaries (unless you specify column names for this with the `text_column` and 341 | # `summary_column` arguments). 342 | # For translation, only JSON files are supported, with one field named "translation" containing two keys for the 343 | # source and target languages (unless you adapt what follows). 344 | # 345 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 346 | # download the dataset. 347 | if data_args.dataset_name is not None: 348 | # Downloading and loading a dataset from the hub. 
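        # Note: the bash scripts in this repo pass local --train_file/--validation_file/
        # --test_file JSON files instead of a hub dataset name, so the `else` branch below
        # is the path normally taken.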
349 | datasets = load_dataset(data_args.dataset_name, 350 | data_args.dataset_config_name) 351 | else: 352 | data_files = {} 353 | if data_args.train_file is not None: 354 | data_files["train"] = data_args.train_file 355 | extension = data_args.train_file.split(".")[-1] 356 | if data_args.validation_file is not None: 357 | data_files["validation"] = data_args.validation_file 358 | extension = data_args.validation_file.split(".")[-1] 359 | if data_args.test_file is not None: 360 | data_files["test"] = data_args.test_file 361 | extension = data_args.test_file.split(".")[-1] 362 | datasets = load_dataset(extension, data_files=data_files) 363 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 364 | # https://huggingface.co/docs/datasets/loading_datasets.html. 365 | 366 | # Load pretrained model and tokenizer 367 | # 368 | # Distributed training: 369 | # The .from_pretrained methods guarantee that only one local process can concurrently 370 | # download model & vocab. 371 | config = AutoConfig.from_pretrained( 372 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 373 | cache_dir=model_args.cache_dir, 374 | revision=model_args.model_revision, 375 | use_auth_token=True if model_args.use_auth_token else None, 376 | mirror='tuna', 377 | ) 378 | 379 | # !!! Sometimes default max_length is setting to 20. 380 | config.max_length = data_args.max_target_length 381 | 382 | tokenizer = AutoTokenizer.from_pretrained( 383 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 384 | cache_dir=model_args.cache_dir, 385 | use_fast=model_args.use_fast_tokenizer, 386 | revision=model_args.model_revision, 387 | use_auth_token=True if model_args.use_auth_token else None, 388 | mirror='tuna', 389 | ) 390 | 391 | to_remove_token_list = list() 392 | if tokenizer.bos_token: 393 | to_remove_token_list += [tokenizer.bos_token] 394 | if tokenizer.eos_token: 395 | to_remove_token_list += [tokenizer.eos_token] 396 | if tokenizer.pad_token: 397 | to_remove_token_list += [tokenizer.pad_token] 398 | 399 | model = AutoModelForSeq2SeqLM.from_pretrained( 400 | model_args.model_name_or_path, 401 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 402 | config=config, 403 | cache_dir=model_args.cache_dir, 404 | revision=model_args.model_revision, 405 | use_auth_token=True if model_args.use_auth_token else None, 406 | # mirror='tuna', # 预训练模型为 uer/t5-base-chinese-cluecorpussmall 时会报错 407 | ) 408 | 409 | 410 | if tokenizer.encode(" ") != [32099, 32098, 1]: 411 | # For non-t5 tokenizer 412 | tokenizer.add_special_tokens( 413 | {"additional_special_tokens": ["", ""]}) 414 | model.resize_token_embeddings(len(tokenizer)) 415 | 416 | # Set decoder_start_token_id 417 | if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer): 418 | model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang] 419 | if model.config.decoder_start_token_id is None: 420 | raise ValueError( 421 | "Make sure that `config.decoder_start_token_id` is correctly defined") 422 | 423 | prefix = data_args.source_prefix if data_args.source_prefix is not None else "" 424 | 425 | # Preprocessing the datasets. 426 | # We need to tokenize inputs and targets. 
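    # For the "event" task the input/target columns default to ("text", "event") via
    # event_extraction_name_mapping above. The tokenizer.encode(...) check earlier appears
    # to verify that the two structure-marker tokens encode to single ids: in the T5
    # vocabulary 32099 is <extra_id_0>, 32098 is <extra_id_1> and 1 is </s>; for other
    # tokenizers the markers are added as special tokens and the embeddings are resized.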
427 | if training_args.do_train: 428 | column_names = datasets["train"].column_names 429 | elif training_args.do_eval: 430 | column_names = datasets["validation"].column_names 431 | elif training_args.do_predict: 432 | column_names = datasets["test"].column_names 433 | else: 434 | logger.info( 435 | "There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.") 436 | return 437 | 438 | # For translation we set the codes of our source and target languages (only useful for mBART, the others will 439 | # ignore those attributes). 440 | if data_args.task.startswith("translation"): 441 | if data_args.source_lang is not None: 442 | tokenizer.src_lang = data_args.source_lang 443 | if data_args.target_lang is not None: 444 | tokenizer.tgt_lang = data_args.target_lang 445 | 446 | # Start Code for Event Extraction 447 | if data_args.task.startswith("event"): 448 | decoding_type_schema = EventSchema.read_from_file( 449 | data_args.event_schema) 450 | else: 451 | decoding_type_schema = None 452 | # End Code for Event Extraction 453 | 454 | # To serialize preprocess_function below, each of those four variables needs to be defined (even if we won't use 455 | # them all). 456 | source_lang, target_lang, text_column, summary_column = None, None, None, None 457 | 458 | if data_args.task.startswith("summarization"): 459 | # Get the column names for input/target. 460 | dataset_columns = summarization_name_mapping.get( 461 | data_args.dataset_name, None) 462 | if data_args.text_column is None: 463 | text_column = dataset_columns[0] if dataset_columns is not None else column_names[0] 464 | else: 465 | text_column = data_args.text_column 466 | if data_args.summary_column is None: 467 | summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1] 468 | else: 469 | summary_column = data_args.summary_column 470 | # Start Code for Event Extraction 471 | elif data_args.task.startswith("event"): 472 | dataset_columns = event_extraction_name_mapping.get( 473 | data_args.dataset_name, None) 474 | if data_args.text_column is None: 475 | text_column = dataset_columns[0] if dataset_columns is not None else column_names[0] 476 | else: 477 | text_column = data_args.text_column 478 | if data_args.summary_column is None: 479 | summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1] 480 | else: 481 | summary_column = data_args.summary_column 482 | # End Code for Event Extraction 483 | else: 484 | # Get the language codes for input/target. 485 | lang_search = re.match( 486 | "translation_([a-z]+)_to_([a-z]+)", data_args.task) 487 | if data_args.source_lang is not None: 488 | source_lang = data_args.source_lang.split("_")[0] 489 | else: 490 | assert ( 491 | lang_search is not None 492 | ), "Provide a source language via --source_lang or rename your task 'translation_xx_to_yy'." 493 | source_lang = lang_search.groups()[0] 494 | 495 | if data_args.target_lang is not None: 496 | target_lang = data_args.target_lang.split("_")[0] 497 | else: 498 | assert ( 499 | lang_search is not None 500 | ), "Provide a target language via --target_lang or rename your task 'translation_xx_to_yy'." 501 | target_lang = lang_search.groups()[1] 502 | 503 | # Temporarily set max_target_length for training. 
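    # preprocess_function below prepends the source_prefix (e.g. "span: ") to every input,
    # tokenizes targets inside tokenizer.as_target_tokenizer(), and, when padding to
    # max_length, replaces pad ids in the labels with -100 so they are ignored by the loss.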
504 | max_target_length = data_args.max_target_length 505 | padding = "max_length" if data_args.pad_to_max_length else False 506 | 507 | if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"): 508 | logger.error( 509 | "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for" 510 | f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory" 511 | ) 512 | 513 | def preprocess_function(examples): 514 | if data_args.task.startswith("translation"): 515 | inputs = [ex[source_lang] for ex in examples["translation"]] 516 | targets = [ex[target_lang] for ex in examples["translation"]] 517 | else: 518 | inputs = examples[text_column] 519 | targets = examples[summary_column] 520 | 521 | inputs = [prefix + inp for inp in inputs] # 在每个句子前面增加 source_prefix 前缀标识 522 | model_inputs = tokenizer( 523 | inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) 524 | 525 | # Setup the tokenizer for targets 526 | # 处理target数据 527 | with tokenizer.as_target_tokenizer(): 528 | labels = tokenizer( 529 | targets, max_length=max_target_length, padding=padding, truncation=True) 530 | 531 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore 532 | # padding in the loss. 533 | if padding == "max_length" and data_args.ignore_pad_token_for_loss: 534 | labels["input_ids"] = [ 535 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] # 用-100来进行填充, -100部分不进行loss的计算 536 | ] 537 | 538 | model_inputs["labels"] = labels["input_ids"] 539 | return model_inputs 540 | 541 | if training_args.do_train: 542 | train_dataset = datasets["train"] 543 | if data_args.max_train_samples is not None: 544 | train_dataset = train_dataset.select( 545 | range(data_args.max_train_samples)) 546 | train_dataset = train_dataset.map( 547 | preprocess_function, 548 | batched=True, 549 | num_proc=data_args.preprocessing_num_workers, 550 | remove_columns=column_names, 551 | load_from_cache_file=not data_args.overwrite_cache, 552 | ) 553 | 554 | if training_args.do_eval: 555 | max_target_length = data_args.val_max_target_length 556 | eval_dataset = datasets["validation"] 557 | if data_args.max_val_samples is not None: 558 | eval_dataset = eval_dataset.select( 559 | range(data_args.max_val_samples)) 560 | eval_dataset = eval_dataset.map( 561 | preprocess_function, 562 | batched=True, 563 | num_proc=data_args.preprocessing_num_workers, 564 | remove_columns=column_names, 565 | load_from_cache_file=not data_args.overwrite_cache, 566 | ) 567 | 568 | if training_args.do_predict: 569 | max_target_length = data_args.val_max_target_length 570 | test_dataset = datasets["test"] 571 | if data_args.max_test_samples is not None: 572 | test_dataset = test_dataset.select( 573 | range(data_args.max_test_samples)) 574 | test_dataset = test_dataset.map( 575 | preprocess_function, 576 | batched=True, 577 | num_proc=data_args.preprocessing_num_workers, 578 | remove_columns=column_names, 579 | load_from_cache_file=not data_args.overwrite_cache, 580 | ) 581 | 582 | # Data collator 583 | label_pad_token_id = - \ 584 | 100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 585 | if data_args.pad_to_max_length: 586 | data_collator = default_data_collator 587 | else: 588 | data_collator = DataCollatorForSeq2Seq( 589 | tokenizer, 590 | model=model, 591 | label_pad_token_id=label_pad_token_id, 592 | 
pad_to_multiple_of=8 if training_args.fp16 else None, 593 | ) 594 | 595 | # Metric 596 | # metric_name = "rouge" if data_args.task.startswith("summarization") else "sacrebleu" 597 | # metric = load_metric(metric_name) 598 | 599 | def postprocess_text(preds, labels): 600 | preds = [pred.strip() for pred in preds] 601 | labels = [label.strip() for label in labels] 602 | 603 | # rougeLSum expects newline after each sentence 604 | if metric_name == "rouge": 605 | preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] 606 | labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] 607 | else: # sacrebleu 608 | labels = [[label] for label in labels] 609 | 610 | return preds, labels 611 | 612 | def compute_metrics(eval_preds): 613 | preds, labels = eval_preds 614 | if isinstance(preds, tuple): 615 | preds = preds[0] 616 | decoded_preds = tokenizer.batch_decode( 617 | preds, skip_special_tokens=False) 618 | if data_args.ignore_pad_token_for_loss: 619 | # Replace -100 in the labels as we can't decode them. 620 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 621 | decoded_labels = tokenizer.batch_decode( 622 | labels, skip_special_tokens=False) 623 | 624 | def clean_str(x_str): 625 | for to_remove_token in to_remove_token_list: 626 | x_str = x_str.replace(to_remove_token, '') 627 | return x_str.strip() 628 | 629 | decoded_preds = [clean_str(x) for x in decoded_preds] 630 | decoded_labels = [clean_str(x) for x in decoded_labels] 631 | 632 | # Some simple post-processing 633 | # decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) 634 | 635 | # if metric_name == "rouge": 636 | # result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True) 637 | # # Extract a few results from ROUGE 638 | # result = {key: value.mid.fmeasure * 100 for key, value in result.items()} 639 | # else: 640 | # result = metric.compute(predictions=decoded_preds, references=decoded_labels) 641 | # result = {"bleu": result["score"]} 642 | 643 | result = get_extract_metrics( 644 | pred_lns=decoded_preds, 645 | tgt_lns=decoded_labels, 646 | label_constraint=decoding_type_schema, 647 | decoding_format=data_args.decoding_format, 648 | ) 649 | 650 | prediction_lens = [np.count_nonzero( 651 | pred != tokenizer.pad_token_id) for pred in preds] 652 | result["gen_len"] = np.mean(prediction_lens) 653 | result = {k: round(v, 4) for k, v in result.items()} 654 | return result 655 | 656 | # Initialize our Trainer 657 | trainer = ConstraintSeq2SeqTrainer( 658 | model=model, 659 | args=training_args, 660 | train_dataset=train_dataset if training_args.do_train else None, 661 | eval_dataset=eval_dataset if training_args.do_eval else None, 662 | tokenizer=tokenizer, 663 | data_collator=data_collator, 664 | compute_metrics=compute_metrics if training_args.predict_with_generate else None, 665 | decoding_type_schema=decoding_type_schema, 666 | decoding_format=data_args.decoding_format, 667 | source_prefix=prefix, 668 | ) 669 | 670 | # Training 671 | if training_args.do_train: 672 | # if last_checkpoint is not None: 673 | # checkpoint = last_checkpoint 674 | # elif os.path.isdir(model_args.model_name_or_path): 675 | # checkpoint = model_args.model_name_or_path 676 | # else: 677 | # checkpoint = None 678 | # TODO fix better about max_length 679 | checkpoint = None 680 | 681 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 682 | trainer.save_model() # Saves the tokenizer too for easy upload 683 | 684 | # train_pred_results = 
trainer.predict( 685 | # train_dataset, 686 | # metric_key_prefix="train", 687 | # max_length=data_args.val_max_target_length, 688 | # num_beams=data_args.num_beams, 689 | # ) 690 | 691 | output_train_file = os.path.join( 692 | training_args.output_dir, "train_results.txt") 693 | if trainer.is_world_process_zero(): 694 | with open(output_train_file, "w") as writer: 695 | logger.info("***** Train results *****") 696 | for key, value in sorted(train_result.metrics.items()): 697 | logger.info(f" {key} = {value}") 698 | writer.write(f"{key} = {value}\n") 699 | 700 | # Need to save the state, since Trainer.save_model saves only the tokenizer with the model 701 | trainer.state.save_to_json(os.path.join( 702 | training_args.output_dir, "trainer_state.json")) 703 | 704 | # if training_args.predict_with_generate: 705 | # train_preds = tokenizer.batch_decode( 706 | # train_pred_results.predictions, skip_special_tokens=False, clean_up_tokenization_spaces=True 707 | # ) 708 | # train_preds = [pred.replace('', '').replace('', '').replace('', '').strip() 709 | # for pred in train_preds] 710 | # output_train_preds_file = os.path.join(training_args.output_dir, "train_preds_seq2seq.txt") 711 | # with open(output_train_preds_file, "w") as writer: 712 | # writer.write("\n".join(train_preds)) 713 | 714 | # Evaluation 715 | results = {} 716 | if training_args.do_eval: 717 | logger.info("*** Evaluate ***") 718 | 719 | results = trainer.evaluate( 720 | max_length=data_args.val_max_target_length, num_beams=data_args.num_beams) 721 | results = {k: round(v, 4) for k, v in results.items()} 722 | 723 | eval_results = trainer.predict( 724 | eval_dataset, 725 | metric_key_prefix="eval", 726 | max_length=data_args.val_max_target_length, 727 | num_beams=data_args.num_beams, 728 | ) 729 | 730 | output_eval_file = os.path.join( 731 | training_args.output_dir, "eval_results_seq2seq.txt") 732 | if trainer.is_world_process_zero(): 733 | with open(output_eval_file, "w") as writer: 734 | logger.info("***** Eval results *****") 735 | for key, value in sorted(results.items()): 736 | logger.info(f" {key} = {value}") 737 | writer.write(f"{key} = {value}\n") 738 | 739 | if training_args.predict_with_generate: 740 | eval_preds = tokenizer.batch_decode( 741 | eval_results.predictions, skip_special_tokens=False, clean_up_tokenization_spaces=True 742 | ) 743 | eval_preds = [pred.replace('', '').replace('', '').replace('', '').strip() 744 | for pred in eval_preds] 745 | output_test_preds_file = os.path.join( 746 | training_args.output_dir, "eval_preds_seq2seq.txt") 747 | with open(output_test_preds_file, "w") as writer: 748 | writer.write("\n".join(eval_preds)) 749 | 750 | if training_args.do_predict: 751 | logger.info("*** Test ***") 752 | 753 | test_results = trainer.predict( 754 | test_dataset, 755 | metric_key_prefix="test", 756 | max_length=data_args.val_max_target_length, 757 | num_beams=data_args.num_beams, 758 | ) 759 | test_metrics = test_results.metrics 760 | test_metrics["test_loss"] = round(test_metrics["test_loss"], 4) 761 | 762 | output_test_result_file = os.path.join( 763 | training_args.output_dir, "test_results_seq2seq.txt") 764 | if trainer.is_world_process_zero(): 765 | with open(output_test_result_file, "w") as writer: 766 | logger.info("***** Test results *****") 767 | for key, value in sorted(test_metrics.items()): 768 | logger.info(f" {key} = {value}") 769 | writer.write(f"{key} = {value}\n") 770 | 771 | if training_args.predict_with_generate: 772 | test_preds = tokenizer.batch_decode( 773 | 
test_results.predictions, skip_special_tokens=False, clean_up_tokenization_spaces=True 774 | ) 775 | test_preds = [pred.replace('', '').replace('', '').replace('', '').strip() 776 | for pred in test_preds] 777 | output_test_preds_file = os.path.join( 778 | training_args.output_dir, "test_preds_seq2seq.txt") 779 | with open(output_test_preds_file, "w") as writer: 780 | writer.write("\n".join(test_preds)) 781 | 782 | return results 783 | 784 | 785 | def _mp_fn(index): 786 | # For xla_spawn (TPUs) 787 | main() 788 | 789 | 790 | if __name__ == "__main__": 791 | main() 792 | -------------------------------------------------------------------------------- /run_seq2seq_span.bash: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding:utf-8 -*- 3 | EXP_ID=$(date +%F-%H-%M-$RANDOM) 4 | export CUDA_VISIBLE_DEVICES="0" 5 | export batch_size="16" 6 | export model_name=t5-base 7 | export data_name=dyiepp_ace2005_subtype 8 | 9 | export lr=5e-5 # t5-large 10 | export task_name="event" 11 | export seed="421" 12 | export lr_scheduler=constant_with_warmup 13 | export label_smoothing="0" 14 | export epoch=25 15 | export decoding_format='treespan' 16 | export eval_steps=500 17 | export warmup_steps=2000 18 | export constraint_decoding='--constraint_decoding' 19 | export metric_format=eval_role-F1 20 | 21 | OPTS=$(getopt -o b:d:m:i:t:s:l:f: --long batch:,device:,model:,data:,task:,seed:,lr:,lr_scheduler:,label_smoothing:,epoch:,format:,eval_steps:,metric_format:,warmup_steps:,pretrain:,wo_constraint_decoding -n 'parse-options' -- "$@") 22 | 23 | if [ $? != 0 ]; then 24 | echo "Failed parsing options." >&2 25 | exit 1 26 | fi 27 | 28 | eval set -- "$OPTS" 29 | 30 | while true; do 31 | case "$1" in 32 | -b | --batch) 33 | batch_size="$2" 34 | shift 35 | shift 36 | ;; 37 | -d | --device) 38 | CUDA_VISIBLE_DEVICES="$2" 39 | shift 40 | shift 41 | ;; 42 | -m | --model) 43 | model_name="$2" 44 | shift 45 | shift 46 | ;; 47 | -i | --data) 48 | data_name="$2" 49 | shift 50 | shift 51 | ;; 52 | -t | --task) 53 | task_name="$2" 54 | shift 55 | shift 56 | ;; 57 | -s | --seed) 58 | seed="$2" 59 | shift 60 | shift 61 | ;; 62 | -l | --lr) 63 | lr="$2" 64 | shift 65 | shift 66 | ;; 67 | -f | --format) 68 | decoding_format="$2" 69 | shift 70 | shift 71 | ;; 72 | --lr_scheduler) 73 | lr_scheduler="$2" 74 | shift 75 | shift 76 | ;; 77 | --label_smoothing) 78 | label_smoothing="$2" 79 | shift 80 | shift 81 | ;; 82 | --epoch) 83 | epoch="$2" 84 | shift 85 | shift 86 | ;; 87 | --eval_steps) 88 | eval_steps="$2" 89 | shift 90 | shift 91 | ;; 92 | --metric_format) 93 | metric_format="$2" 94 | shift 95 | shift 96 | ;; 97 | --warmup_steps) 98 | warmup_steps="$2" 99 | shift 100 | shift 101 | ;; 102 | --wo_constraint_decoding) 103 | constraint_decoding="" 104 | shift 105 | ;; 106 | --) 107 | shift 108 | break 109 | ;; 110 | *) 111 | echo "$1" not recognize. 
112 | exit 113 | ;; 114 | esac 115 | done 116 | 117 | # google/mt5-base -> google_mt5-base 118 | model_name_log=$(echo ${model_name} | sed -s "s/\//_/g") 119 | 120 | model_folder=models/span_${EXP_ID}_${model_name_log}_${decoding_format}_${data_name}_${lr_scheduler}_lr${lr}_ls${label_smoothing}_${batch_size}_wu${warmup_steps} 121 | data_folder=data/text2target/${data_name} 122 | 123 | output_dir=${model_folder} 124 | mkdir -p ${output_dir} 125 | 126 | CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} python run_seq2seq.py \ 127 | --do_train --do_eval --do_predict ${constraint_decoding} \ 128 | --label_smoothing_sum=False \ 129 | --use_fast_tokenizer=False \ 130 | --evaluation_strategy steps \ 131 | --predict_with_generate \ 132 | --metric_for_best_model ${metric_format} \ 133 | --save_total_limit 1 \ 134 | --load_best_model_at_end \ 135 | --max_source_length=256 \ 136 | --max_target_length=128 \ 137 | --num_train_epochs=${epoch} \ 138 | --task=${task_name} \ 139 | --train_file=${data_folder}/train.json \ 140 | --validation_file=${data_folder}/val.json \ 141 | --test_file=${data_folder}/test.json \ 142 | --event_schema=${data_folder}/event.schema \ 143 | --per_device_train_batch_size=${batch_size} \ 144 | --per_device_eval_batch_size=$((batch_size * 4)) \ 145 | --output_dir=${output_dir}/span_pretrain \ 146 | --logging_dir=${output_dir}/span_pretrain_log \ 147 | --model_name_or_path=${model_name} \ 148 | --learning_rate=${lr} \ 149 | --lr_scheduler_type=${lr_scheduler} \ 150 | --label_smoothing_factor=${label_smoothing} \ 151 | --eval_steps ${eval_steps} \ 152 | --decoding_format ${decoding_format} \ 153 | --warmup_steps ${warmup_steps} \ 154 | --source_prefix="span: " \ 155 | --seed=${seed} --disable_tqdm True >${output_dir}/span_pretrain.log 2>${output_dir}/span_pretrain.log 156 | 157 | 158 | 159 | -------------------------------------------------------------------------------- /run_tri_predict.bash: -------------------------------------------------------------------------------- 1 | export device="0" 2 | export et_model_path="XXX" 3 | export tri_model_path="XXX" 4 | export et_data_name=dyiepp_ace2005_text2et_subtype_span 5 | export task_name="event" 6 | export batch=16 7 | export constraint_decoding="--constraint_decoding" 8 | export decoding_format='tri' 9 | 10 | et_data_folder=data/text2target/${et_data_name} 11 | 12 | 13 | # et result 转化 14 | python convert_et_result.py \ 15 | --et_pred_file=${et_model_path}/test_preds_seq2seq.txt \ 16 | --et_text_file=${et_data_folder}/test.json \ 17 | --et_output_file=${et_data_folder}/test_pre_et.json \ 18 | --schema_file=${et_data_folder}/event.schema \ 19 | --mode="tri" 20 | 21 | # # 依赖上述的转化文件进行 role predict 预测, decode_format = noetrtspan 22 | CUDA_VISIBLE_DEVICES=${device} python run_seq2seq.py \ 23 | --do_predict --task=${task_name} --predict_with_generate \ 24 | --validation_file=${et_data_folder}/val.json \ 25 | --test_file=${et_data_folder}/test_pre_et.json \ 26 | --event_schema=${et_data_folder}/event.schema \ 27 | --model_name_or_path=${tri_model_path} \ 28 | --output_dir="${tri_model_path}"_ettest \ 29 | --source_prefix="span: " \ 30 | ${constraint_decoding} \ 31 | --per_device_eval_batch_size=${batch} \ 32 | --decoding_format ${decoding_format} 33 | 34 | # evaluate 评估, 结果输出到文件 35 | python evaluation.py \ 36 | --text_file=${et_data_folder}/test_pre_et.json \ 37 | --pred_file="${tri_model_path}"_ettest/test_preds_seq2seq.txt \ 38 | --gold_file="data/raw_data/dyiepp_ace2005/test_convert.json" \ 39 | --schema_file=${et_data_folder}/event.schema \ 
40 | --decoding_format ${decoding_format} \ 41 | --format="dyiepp" > "${tri_model_path}"_ettest/total_dyiepp_result.txt 42 | 43 | # evaluate 评估, 结果打印在控制台 44 | python evaluation.py \ 45 | --text_file=${et_data_folder}/test_pre_et.json \ 46 | --pred_file="${tri_model_path}"_ettest/test_preds_seq2seq.txt \ 47 | --gold_file="data/raw_data/dyiepp_ace2005/test_convert.json" \ 48 | --schema_file=${et_data_folder}/event.schema \ 49 | --decoding_format ${decoding_format} \ 50 | --format="dyiepp" 51 | 52 | 53 | 54 | -------------------------------------------------------------------------------- /seq2seq/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.insert(1, os.path.dirname(os.path.realpath(__file__))) 5 | -------------------------------------------------------------------------------- /seq2seq/constrained_seq2seq.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | import logging 4 | import os 5 | 6 | import torch 7 | torch.set_printoptions(profile="full") # 输出tensor全部内容, 输出到控制台调试时使用 8 | import torch.nn as nn 9 | from dataclasses import dataclass, field 10 | from typing import Union, List, Callable, Dict, Tuple, Any, Optional 11 | import numpy as np 12 | from torch.cuda.amp import autocast 13 | 14 | from transformers import ( 15 | PreTrainedTokenizer, 16 | EvalPrediction, 17 | Seq2SeqTrainer, 18 | Seq2SeqTrainingArguments, 19 | ) 20 | 21 | from extraction.event_schema import EventSchema 22 | from extraction.extract_constraint import get_constraint_decoder 23 | from extraction.extraction_metrics import get_extract_metrics 24 | from seq2seq.label_smoother_sum import SumLabelSmoother 25 | from seq2seq.utils import lmap 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | def add_logging_file(training_args): 31 | fh = logging.FileHandler(os.path.join(training_args.output_dir.rstrip(os.sep) + '.log')) 32 | fh.setLevel(logging.DEBUG) 33 | logger.addHandler(fh) 34 | 35 | 36 | def decode_tree_str(sequences: Union[List[int], List[List[int]], "np.ndarray", "torch.Tensor"], 37 | tokenizer: PreTrainedTokenizer) -> List[str]: 38 | def clean_tree_text(x): 39 | return x.replace('', '').replace('', '').replace('', '').strip() 40 | 41 | sequences = np.where(sequences != -100, sequences, tokenizer.pad_token_id) 42 | 43 | str_list = tokenizer.batch_decode(sequences, skip_special_tokens=False) 44 | return lmap(clean_tree_text, str_list) 45 | 46 | 47 | def build_compute_extract_metrics_event_fn(decoding_type_schema: EventSchema, 48 | decoding_format: str, 49 | tokenizer: PreTrainedTokenizer) -> Callable[[EvalPrediction], Dict]: 50 | def non_pad_len(tokens: np.ndarray) -> int: 51 | return np.count_nonzero(tokens != tokenizer.pad_token_id) 52 | 53 | def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]: 54 | return decode_tree_str(pred.predictions, tokenizer), decode_tree_str(pred.label_ids, tokenizer) 55 | 56 | def extraction_metrics(pred: EvalPrediction) -> Dict: 57 | pred_str, label_str = decode_pred(pred) 58 | extraction = get_extract_metrics(pred_lns=pred_str, tgt_lns=label_str, label_constraint=decoding_type_schema, 59 | decoding_format=decoding_format) 60 | # rouge: Dict = calculate_rouge(pred_str, label_str) 61 | summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1) 62 | extraction.update({"gen_len": summ_len}) 63 | # extraction.update( ) 64 | return extraction 65 | 66 | compute_metrics_fn = 
extraction_metrics 67 | return compute_metrics_fn 68 | 69 | 70 | @dataclass 71 | class ConstraintSeq2SeqTrainingArguments(Seq2SeqTrainingArguments): 72 | """ 73 | Parameters: 74 | constraint_decoding (:obj:`bool`, `optional`, defaults to :obj:`False`): 75 | Whether to use Constraint Decoding 76 | structure_weight (:obj:`float`, `optional`, defaults to :obj:`None`): 77 | """ 78 | constraint_decoding: bool = field(default=False, metadata={"help": "Whether to Constraint Decoding or not."}) 79 | label_smoothing_sum: bool = field(default=False, 80 | metadata={"help": "Whether to use sum token loss for label smoothing"}) 81 | 82 | 83 | class ConstraintSeq2SeqTrainer(Seq2SeqTrainer): 84 | def __init__(self, decoding_type_schema=None, decoding_format='tree', source_prefix=None, *args, **kwargs): 85 | super().__init__(*args, **kwargs) 86 | 87 | self.decoding_format = decoding_format 88 | self.decoding_type_schema = decoding_type_schema # 在event任务中为 event schema 89 | 90 | # Label smoothing by sum token loss, different from different Label smootheing 91 | if self.args.label_smoothing_sum and self.args.label_smoothing_factor != 0: 92 | self.label_smoother = SumLabelSmoother(epsilon=self.args.label_smoothing_factor) 93 | print('Using %s' % self.label_smoother) 94 | elif self.args.label_smoothing_factor != 0: 95 | print('Using %s' % self.label_smoother) 96 | else: 97 | self.label_smoother = None 98 | 99 | if self.args.constraint_decoding: 100 | self.constraint_decoder = get_constraint_decoder(tokenizer=self.tokenizer, 101 | type_schema=self.decoding_type_schema, 102 | decoding_schema=self.decoding_format, 103 | source_prefix=source_prefix) 104 | else: 105 | self.constraint_decoder = None 106 | 107 | def prediction_step( 108 | self, 109 | model: nn.Module, 110 | inputs: Dict[str, Union[torch.Tensor, Any]], 111 | prediction_loss_only: bool, 112 | ignore_keys: Optional[List[str]] = None, 113 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: 114 | """ 115 | Perform an evaluation step on :obj:`model` using obj:`inputs`. 116 | 117 | Subclass and override to inject custom behavior. 118 | 119 | Args: 120 | model (:obj:`nn.Module`): 121 | The model to evaluate. 122 | inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`): 123 | The inputs and targets of the model. 124 | 125 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the 126 | argument :obj:`labels`. Check your model's documentation for all accepted arguments. 127 | prediction_loss_only (:obj:`bool`): 128 | Whether or not to return the loss only. 129 | 130 | Return: 131 | Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and 132 | labels (each being optional). 
133 | """ 134 | 135 | def prefix_allowed_tokens_fn(batch_id, sent): 136 | # print(self.tokenizer.convert_ids_to_tokens(inputs['labels'][batch_id])) 137 | src_sentence = inputs['input_ids'][batch_id] 138 | # print("input_ids:", inputs.keys()) 139 | # print("src_sentece:", src_sentence) 140 | # print("sent:", sent) 141 | # print(self.constraint_decoder.constraint_decoding(src_sentence=src_sentence, tgt_generated=sent)) 142 | return self.constraint_decoder.constraint_decoding(src_sentence=src_sentence, 143 | tgt_generated=sent) 144 | 145 | if not self.args.predict_with_generate or prediction_loss_only: 146 | return super().prediction_step( 147 | model=model, 148 | inputs=inputs, 149 | prediction_loss_only=prediction_loss_only, 150 | ignore_keys=ignore_keys, 151 | prefix_allowed_tokens_fn=prefix_allowed_tokens_fn if self.constraint_decoder else None, 152 | ) 153 | 154 | has_labels = "labels" in inputs 155 | inputs = self._prepare_inputs(inputs) 156 | 157 | gen_kwargs = { 158 | "max_length": self._max_length if self._max_length is not None else self.model.config.max_length, 159 | "num_beams": self._num_beams if self._num_beams is not None else self.model.config.num_beams, 160 | "prefix_allowed_tokens_fn": prefix_allowed_tokens_fn if self.constraint_decoder else None, 161 | } 162 | 163 | # 调试信息 164 | # print("input_ids:", inputs["input_ids"]) 165 | # print("attention_mask:", inputs["attention_mask"]) 166 | # print("gen_kwargs: ", gen_kwargs) 167 | 168 | generated_tokens = self.model.generate( 169 | inputs["input_ids"], 170 | attention_mask=inputs["attention_mask"], 171 | **gen_kwargs, 172 | ) 173 | # print("before: ", generated_tokens[0]) 174 | # in case the batch is shorter than max length, the output should be padded 175 | if generated_tokens.shape[-1] < gen_kwargs["max_length"]: 176 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"]) 177 | # print("after: ", generated_tokens[0]) 178 | 179 | with torch.no_grad(): 180 | if self.use_amp: 181 | with autocast(): 182 | outputs = model(**inputs) 183 | else: 184 | outputs = model(**inputs) 185 | if has_labels: 186 | if self.label_smoother is not None: 187 | loss = self.label_smoother(outputs, inputs["labels"]).mean().detach() 188 | else: 189 | loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach() 190 | else: 191 | loss = None 192 | 193 | if self.args.prediction_loss_only: 194 | return loss, None, None 195 | 196 | labels = inputs["labels"] 197 | 198 | if labels.shape[-1] < gen_kwargs["max_length"]: 199 | labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"]) 200 | 201 | return loss, generated_tokens, labels 202 | 203 | 204 | def main(): pass 205 | 206 | 207 | if __name__ == "__main__": 208 | main() 209 | -------------------------------------------------------------------------------- /seq2seq/label_smoother_sum.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | import torch 4 | from dataclasses import dataclass 5 | 6 | 7 | @dataclass 8 | class SumLabelSmoother: 9 | """ 10 | Adds label-smoothing on a pre-computed output from a Transformers model. 11 | 12 | Args: 13 | epsilon (:obj:`float`, `optional`, defaults to 0.1): 14 | The label smoothing factor. 15 | ignore_index (:obj:`int`, `optional`, defaults to -100): 16 | The index in the labels to ignore when computing the loss. 
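
    Unlike the stock Transformers ``LabelSmoother``, the per-token losses are summed rather
    than averaged over the number of active (non-padded) tokens, i.e.
    ``loss = (1 - epsilon) * sum(nll_loss) + epsilon / vocab_size * sum(smoothed_loss)``
    (see the commented-out ``num_active_elements`` normalization in ``__call__`` below).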
17 | """ 18 | 19 | epsilon: float = 0.1 20 | ignore_index: int = -100 21 | 22 | def __call__(self, model_output, labels): 23 | logits = model_output["logits"] if isinstance(model_output, dict) else model_output[0] 24 | log_probs = -torch.nn.functional.log_softmax(logits, dim=-1) 25 | if labels.dim() == log_probs.dim() - 1: 26 | labels = labels.unsqueeze(-1) 27 | 28 | padding_mask = labels.eq(self.ignore_index) 29 | # In case the ignore_index is -100, the gather will fail, so we replace labels by 0. The padding_mask 30 | # will ignore them in any case. 31 | labels.clamp_min_(0) 32 | nll_loss = log_probs.gather(dim=-1, index=labels) 33 | smoothed_loss = log_probs.sum(dim=-1, keepdim=True) 34 | 35 | nll_loss.masked_fill_(padding_mask, 0.0) 36 | smoothed_loss.masked_fill_(padding_mask, 0.0) 37 | 38 | # Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded): 39 | # num_active_elements = padding_mask.numel() - padding_mask.long().sum() 40 | nll_loss = nll_loss.sum() # / num_active_elements 41 | smoothed_loss = smoothed_loss.sum() # / (num_active_elements * log_probs.shape[-1]) 42 | eps_i = self.epsilon / log_probs.size(-1) 43 | return (1 - self.epsilon) * nll_loss + eps_i * smoothed_loss 44 | -------------------------------------------------------------------------------- /seq2seq/sentence_splitter.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import os 15 | import re 16 | 17 | from filelock import FileLock 18 | 19 | try: 20 | import nltk 21 | 22 | NLTK_AVAILABLE = True 23 | except (ImportError, ModuleNotFoundError): 24 | NLTK_AVAILABLE = False 25 | 26 | # if NLTK_AVAILABLE: 27 | # if not os.path.exists("~/nltk_data/tokenizers/punkt"): 28 | # print('Start download punk') 29 | # with FileLock(".lock") as lock: 30 | # nltk.download("punkt", quiet=True) 31 | # else: 32 | # print('Skip download punk') 33 | 34 | 35 | def add_newline_to_end_of_each_sentence(x: str) -> str: 36 | """This was added to get rougeLsum scores matching published rougeL scores for BART and PEGASUS.""" 37 | re.sub("", "", x) # remove pegasus newline char 38 | assert NLTK_AVAILABLE, "nltk must be installed to separate newlines between sentences. (pip install nltk)" 39 | return "\n".join(nltk.sent_tokenize(x)) 40 | -------------------------------------------------------------------------------- /seq2seq/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import itertools 16 | import json 17 | import linecache 18 | import math 19 | import os 20 | import pickle 21 | import socket 22 | from logging import getLogger 23 | from pathlib import Path 24 | from typing import Callable, Dict, Iterable, List, Tuple, Union 25 | 26 | import git 27 | import numpy as np 28 | import torch 29 | import torch.distributed as dist 30 | from rouge_score import rouge_scorer, scoring 31 | from sacrebleu import corpus_bleu 32 | from torch import nn 33 | from torch.utils.data import Dataset, Sampler 34 | 35 | from seq2seq.sentence_splitter import add_newline_to_end_of_each_sentence 36 | from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer 37 | from transformers.file_utils import cached_property 38 | 39 | try: 40 | from fairseq.data.data_utils import batch_by_size 41 | 42 | FAIRSEQ_AVAILABLE = True 43 | except (ImportError, ModuleNotFoundError): 44 | FAIRSEQ_AVAILABLE = False 45 | 46 | 47 | def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100): 48 | """From fairseq""" 49 | if target.dim() == lprobs.dim() - 1: 50 | target = target.unsqueeze(-1) 51 | nll_loss = -lprobs.gather(dim=-1, index=target) 52 | smooth_loss = -lprobs.sum(dim=-1, keepdim=True) 53 | if ignore_index is not None: 54 | pad_mask = target.eq(ignore_index) 55 | nll_loss.masked_fill_(pad_mask, 0.0) 56 | smooth_loss.masked_fill_(pad_mask, 0.0) 57 | else: 58 | nll_loss = nll_loss.squeeze(-1) 59 | smooth_loss = smooth_loss.squeeze(-1) 60 | 61 | nll_loss = nll_loss.sum() # mean()? Scared to break other math. 
62 | smooth_loss = smooth_loss.sum() 63 | eps_i = epsilon / lprobs.size(-1) 64 | loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss 65 | return loss, nll_loss 66 | 67 | 68 | def lmap(f: Callable, x: Iterable) -> List: 69 | """list(map(f, x))""" 70 | return list(map(f, x)) 71 | 72 | 73 | def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict: 74 | """Uses sacrebleu's corpus_bleu implementation.""" 75 | return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)} 76 | 77 | 78 | def build_compute_metrics_fn(task_name: str, tokenizer: PreTrainedTokenizer) -> Callable[[EvalPrediction], Dict]: 79 | def non_pad_len(tokens: np.ndarray) -> int: 80 | return np.count_nonzero(tokens != tokenizer.pad_token_id) 81 | 82 | def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]: 83 | pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True) 84 | label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True) 85 | pred_str = lmap(str.strip, pred_str) 86 | label_str = lmap(str.strip, label_str) 87 | return pred_str, label_str 88 | 89 | def summarization_metrics(pred: EvalPrediction) -> Dict: 90 | pred_str, label_str = decode_pred(pred) 91 | rouge: Dict = calculate_rouge(pred_str, label_str) 92 | summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1) 93 | rouge.update({"gen_len": summ_len}) 94 | return rouge 95 | 96 | def translation_metrics(pred: EvalPrediction) -> Dict: 97 | pred_str, label_str = decode_pred(pred) 98 | bleu: Dict = calculate_bleu(pred_str, label_str) 99 | gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1) 100 | bleu.update({"gen_len": gen_len}) 101 | return bleu 102 | 103 | compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics 104 | return compute_metrics_fn 105 | 106 | 107 | def trim_batch( 108 | input_ids, 109 | pad_token_id, 110 | attention_mask=None, 111 | ): 112 | """Remove columns that are populated exclusively by pad_token_id""" 113 | keep_column_mask = input_ids.ne(pad_token_id).any(dim=0) 114 | if attention_mask is None: 115 | return input_ids[:, keep_column_mask] 116 | else: 117 | return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask]) 118 | 119 | 120 | class AbstractSeq2SeqDataset(Dataset): 121 | def __init__( 122 | self, 123 | tokenizer, 124 | data_dir, 125 | max_source_length, 126 | max_target_length, 127 | type_path="train", 128 | n_obs=None, 129 | prefix="", 130 | **dataset_kwargs 131 | ): 132 | super().__init__() 133 | self.src_file = Path(data_dir).joinpath(type_path + ".source") 134 | self.tgt_file = Path(data_dir).joinpath(type_path + ".target") 135 | self.len_file = Path(data_dir).joinpath(type_path + ".len") 136 | if os.path.exists(self.len_file): 137 | self.src_lens = pickle_load(self.len_file) 138 | self.used_char_len = False 139 | else: 140 | self.src_lens = self.get_char_lens(self.src_file) 141 | self.used_char_len = True 142 | self.max_source_length = max_source_length 143 | self.max_target_length = max_target_length 144 | assert min(self.src_lens) > 0, f"found empty line in {self.src_file}" 145 | self.tokenizer = tokenizer 146 | self.prefix = prefix if prefix is not None else "" 147 | 148 | if n_obs is not None: 149 | self.src_lens = self.src_lens[:n_obs] 150 | self.pad_token_id = self.tokenizer.pad_token_id 151 | self.dataset_kwargs = dataset_kwargs 152 | dataset_kwargs.update({"add_prefix_space": True} if isinstance(self.tokenizer, BartTokenizer) else {}) 153 | 154 | def __len__(self): 155 | 
return len(self.src_lens) 156 | 157 | @staticmethod 158 | def get_char_lens(data_file): 159 | return [len(x) for x in Path(data_file).open().readlines()] 160 | 161 | @cached_property 162 | def tgt_lens(self): 163 | """Length in characters of target documents""" 164 | return self.get_char_lens(self.tgt_file) 165 | 166 | def make_sortish_sampler(self, batch_size, distributed=False, shuffle=True, **kwargs): 167 | if distributed: 168 | return DistributedSortishSampler(self, batch_size, shuffle=shuffle, **kwargs) 169 | else: 170 | return SortishSampler(self.src_lens, batch_size, shuffle=shuffle) 171 | 172 | def make_dynamic_sampler(self, max_tokens_per_batch=1024, **kwargs): 173 | assert FAIRSEQ_AVAILABLE, "Dynamic batch size requires `pip install fairseq`" 174 | assert not self.used_char_len, "You must call python make_len_file.py before calling make_dynamic_sampler" 175 | sorted_indices = list(self.make_sortish_sampler(1024, shuffle=False)) 176 | 177 | def num_tokens_in_example(i): 178 | return min(self.src_lens[i], self.max_target_length) 179 | 180 | # call fairseq cython function 181 | batch_sampler: List[List[int]] = batch_by_size( 182 | sorted_indices, 183 | num_tokens_fn=num_tokens_in_example, 184 | max_tokens=max_tokens_per_batch, 185 | required_batch_size_multiple=64, 186 | ) 187 | shuffled_batches = [batch_sampler[i] for i in np.random.permutation(range(len(batch_sampler)))] 188 | # move the largest batch to the front to OOM quickly (uses an approximation for padding) 189 | approximate_toks_per_batch = [max(self.src_lens[i] for i in batch) * len(batch) for batch in shuffled_batches] 190 | largest_batch_idx = np.argmax(approximate_toks_per_batch) 191 | shuffled_batches[0], shuffled_batches[largest_batch_idx] = ( 192 | shuffled_batches[largest_batch_idx], 193 | shuffled_batches[0], 194 | ) 195 | return shuffled_batches 196 | 197 | def __getitem__(self, item): 198 | raise NotImplementedError("You must implement this") 199 | 200 | def collate_fn(self, batch): 201 | raise NotImplementedError("You must implement this") 202 | 203 | 204 | class LegacySeq2SeqDataset(AbstractSeq2SeqDataset): 205 | def __getitem__(self, index) -> Dict[str, torch.Tensor]: 206 | """Call tokenizer on src and tgt_lines""" 207 | index = index + 1 # linecache starts at 1 208 | source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n") 209 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") 210 | assert source_line, f"empty source line for index {index}" 211 | assert tgt_line, f"empty tgt line for index {index}" 212 | source_inputs = self.encode_line(self.tokenizer, source_line, self.max_source_length) 213 | target_inputs = self.encode_line(self.tokenizer, tgt_line, self.max_target_length) 214 | 215 | source_ids = source_inputs["input_ids"].squeeze() 216 | target_ids = target_inputs["input_ids"].squeeze() 217 | src_mask = source_inputs["attention_mask"].squeeze() 218 | return { 219 | "input_ids": source_ids, 220 | "attention_mask": src_mask, 221 | "labels": target_ids, 222 | } 223 | 224 | def encode_line(self, tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"): 225 | """Only used by LegacyDataset""" 226 | return tokenizer( 227 | [line], 228 | max_length=max_length, 229 | padding="max_length" if pad_to_max_length else None, 230 | truncation=True, 231 | return_tensors=return_tensors, 232 | **self.dataset_kwargs, 233 | ) 234 | 235 | def collate_fn(self, batch) -> Dict[str, torch.Tensor]: 236 | input_ids = torch.stack([x["input_ids"] for x in batch]) 237 
| masks = torch.stack([x["attention_mask"] for x in batch]) 238 | target_ids = torch.stack([x["labels"] for x in batch]) 239 | pad_token_id = self.pad_token_id 240 | y = trim_batch(target_ids, pad_token_id) 241 | source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks) 242 | batch = { 243 | "input_ids": source_ids, 244 | "attention_mask": source_mask, 245 | "labels": y, 246 | } 247 | return batch 248 | 249 | 250 | class Seq2SeqDataset(AbstractSeq2SeqDataset): 251 | """A dataset that calls prepare_seq2seq_batch.""" 252 | 253 | def __getitem__(self, index) -> Dict[str, str]: 254 | index = index + 1 # linecache starts at 1 255 | source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n") 256 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") 257 | assert source_line, f"empty source line for index {index}" 258 | assert tgt_line, f"empty tgt line for index {index}" 259 | return {"tgt_texts": tgt_line, "src_texts": source_line, "id": index - 1} 260 | 261 | def collate_fn(self, batch) -> Dict[str, torch.Tensor]: 262 | """Call prepare_seq2seq_batch.""" 263 | batch_encoding: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch( 264 | [x["src_texts"] for x in batch], 265 | tgt_texts=[x["tgt_texts"] for x in batch], 266 | max_length=self.max_source_length, 267 | max_target_length=self.max_target_length, 268 | return_tensors="pt", 269 | **self.dataset_kwargs, 270 | ).data 271 | batch_encoding["ids"] = torch.tensor([x["id"] for x in batch]) 272 | return batch_encoding 273 | 274 | 275 | class Seq2SeqDataCollator: 276 | def __init__(self, tokenizer, data_args, tpu_num_cores=None): 277 | self.tokenizer = tokenizer 278 | self.pad_token_id = tokenizer.pad_token_id 279 | assert ( 280 | self.pad_token_id is not None 281 | ), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined." 
282 | self.data_args = data_args 283 | self.tpu_num_cores = tpu_num_cores 284 | self.dataset_kwargs = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {} 285 | if data_args.src_lang is not None: 286 | self.dataset_kwargs["src_lang"] = data_args.src_lang 287 | if data_args.tgt_lang is not None: 288 | self.dataset_kwargs["tgt_lang"] = data_args.tgt_lang 289 | 290 | def __call__(self, batch) -> Dict[str, torch.Tensor]: 291 | if hasattr(self.tokenizer, "prepare_seq2seq_batch"): 292 | batch = self._encode(batch) 293 | input_ids, attention_mask, labels = ( 294 | batch["input_ids"], 295 | batch["attention_mask"], 296 | batch["labels"], 297 | ) 298 | else: 299 | input_ids = torch.stack([x["input_ids"] for x in batch]) 300 | attention_mask = torch.stack([x["attention_mask"] for x in batch]) 301 | labels = torch.stack([x["labels"] for x in batch]) 302 | 303 | labels = trim_batch(labels, self.pad_token_id) 304 | input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask) 305 | 306 | batch = { 307 | "input_ids": input_ids, 308 | "attention_mask": attention_mask, 309 | "labels": labels, 310 | } 311 | return batch 312 | 313 | def _shift_right_t5(self, input_ids): 314 | # shift inputs to the right 315 | shifted_input_ids = input_ids.new_zeros(input_ids.shape) 316 | shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() 317 | shifted_input_ids[..., 0] = self.pad_token_id 318 | return shifted_input_ids 319 | 320 | def _encode(self, batch) -> Dict[str, torch.Tensor]: 321 | batch_encoding = self.tokenizer.prepare_seq2seq_batch( 322 | [x["src_texts"] for x in batch], 323 | tgt_texts=[x["tgt_texts"] for x in batch], 324 | max_length=self.data_args.max_source_length, 325 | max_target_length=self.data_args.max_target_length, 326 | padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack 327 | return_tensors="pt", 328 | **self.dataset_kwargs, 329 | ) 330 | return batch_encoding.data 331 | 332 | 333 | class SortishSampler(Sampler): 334 | "Go through the text data by order of src length with a bit of randomness. From fastai repo." 335 | 336 | def __init__(self, data, batch_size, shuffle=True): 337 | self.data, self.bs, self.shuffle = data, batch_size, shuffle 338 | 339 | def __len__(self) -> int: 340 | return len(self.data) 341 | 342 | def __iter__(self): 343 | return iter(sortish_sampler_indices(self.data, self.bs, shuffle=self.shuffle)) 344 | 345 | 346 | def sortish_sampler_indices(data: List, bs: int, shuffle=True) -> np.array: 347 | "Go through the text data by order of src length with a bit of randomness. From fastai repo." 348 | if not shuffle: 349 | return np.argsort(np.array(data) * -1) 350 | 351 | def key_fn(i): 352 | return data[i] 353 | 354 | idxs = np.random.permutation(len(data)) 355 | sz = bs * 50 356 | ck_idx = [idxs[i: i + sz] for i in range(0, len(idxs), sz)] 357 | sort_idx = np.concatenate([sorted(s, key=key_fn, reverse=True) for s in ck_idx]) 358 | sz = bs 359 | ck_idx = [sort_idx[i: i + sz] for i in range(0, len(sort_idx), sz)] 360 | max_ck = np.argmax([key_fn(ck[0]) for ck in ck_idx]) # find the chunk with the largest key, 361 | ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0] # then make sure it goes first. 
362 | sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int64) 363 | sort_idx = np.concatenate((ck_idx[0], sort_idx)) 364 | return sort_idx 365 | 366 | 367 | class DistributedSortishSampler(Sampler): 368 | """Copied from torch DistributedSampler""" 369 | 370 | def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True, shuffle=True): 371 | if num_replicas is None: 372 | if not dist.is_available(): 373 | raise RuntimeError("Requires distributed package to be available") 374 | num_replicas = dist.get_world_size() 375 | if rank is None: 376 | if not dist.is_available(): 377 | raise RuntimeError("Requires distributed package to be available") 378 | rank = dist.get_rank() 379 | self.dataset = dataset 380 | self.num_replicas = num_replicas 381 | self.rank = rank 382 | self.epoch = 0 383 | if add_extra_examples: 384 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 385 | self.total_size = self.num_samples * self.num_replicas 386 | else: 387 | self.total_size = len(dataset) 388 | self.num_samples = len(self.available_indices) 389 | self.batch_size = batch_size 390 | self.add_extra_examples = add_extra_examples 391 | self.shuffle = shuffle 392 | 393 | def __iter__(self) -> Iterable: 394 | g = torch.Generator() 395 | g.manual_seed(self.epoch) 396 | 397 | sortish_data = [self.dataset.src_lens[i] for i in self.available_indices] 398 | sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size, shuffle=self.shuffle) 399 | indices = [self.available_indices[i] for i in sortish_indices] 400 | assert len(indices) == self.num_samples 401 | return iter(indices) 402 | 403 | @cached_property 404 | def available_indices(self) -> np.array: 405 | indices = list(range(len(self.dataset))) 406 | # add extra samples to make it evenly divisible 407 | indices += indices[: (self.total_size - len(indices))] 408 | assert len(indices) == self.total_size 409 | # subsample 410 | available_indices = indices[self.rank: self.total_size: self.num_replicas] 411 | return available_indices 412 | 413 | def __len__(self): 414 | return self.num_samples 415 | 416 | def set_epoch(self, epoch): 417 | self.epoch = epoch 418 | 419 | 420 | logger = getLogger(__name__) 421 | 422 | 423 | def use_task_specific_params(model, task): 424 | """Update config with summarization specific params.""" 425 | task_specific_params = model.config.task_specific_params 426 | 427 | if task_specific_params is not None: 428 | pars = task_specific_params.get(task, {}) 429 | logger.info(f"setting model.config to task specific params for {task}:\n {pars}") 430 | logger.info("note: command line args may override some of these") 431 | model.config.update(pars) 432 | 433 | 434 | def pickle_load(path): 435 | """pickle.load(path)""" 436 | with open(path, "rb") as f: 437 | return pickle.load(f) 438 | 439 | 440 | def pickle_save(obj, path): 441 | """pickle.dump(obj, path)""" 442 | with open(path, "wb") as f: 443 | return pickle.dump(obj, f) 444 | 445 | 446 | def flatten_list(summary_ids: List[List]): 447 | return [x for x in itertools.chain.from_iterable(summary_ids)] 448 | 449 | 450 | def save_git_info(folder_path: str) -> None: 451 | """Save git information to output_dir/git_log.json""" 452 | repo_infos = get_git_info() 453 | save_json(repo_infos, os.path.join(folder_path, "git_log.json")) 454 | 455 | 456 | def save_json(content, path, indent=4, **json_dump_kwargs): 457 | with open(path, "w") as f: 458 | json.dump(content, f,
indent=indent, sort_keys=True, **json_dump_kwargs) 459 | 460 | 461 | def load_json(path): 462 | with open(path) as f: 463 | return json.load(f) 464 | 465 | 466 | def get_git_info(): 467 | try: 468 | repo = git.Repo(search_parent_directories=True) 469 | repo_infos = { 470 | "repo_id": str(repo), 471 | "repo_sha": str(repo.head.object.hexsha), 472 | "repo_branch": str(repo.active_branch), 473 | "hostname": str(socket.gethostname()), 474 | } 475 | return repo_infos 476 | except TypeError: 477 | return { 478 | "repo_id": None, 479 | "repo_sha": None, 480 | "repo_branch": None, 481 | "hostname": None, 482 | } 483 | 484 | 485 | ROUGE_KEYS = ["rouge1", "rouge2", "rougeL", "rougeLsum"] 486 | 487 | 488 | def extract_rouge_mid_statistics(dct): 489 | new_dict = {} 490 | for k1, v1 in dct.items(): 491 | mid = v1.mid 492 | new_dict[k1] = {stat: round(getattr(mid, stat), 4) for stat in ["precision", "recall", "fmeasure"]} 493 | return new_dict 494 | 495 | 496 | def calculate_rouge( 497 | pred_lns: List[str], 498 | tgt_lns: List[str], 499 | use_stemmer=True, 500 | rouge_keys=ROUGE_KEYS, 501 | return_precision_and_recall=False, 502 | bootstrap_aggregation=True, 503 | newline_sep=True, 504 | ) -> Dict: 505 | """Calculate rouge using rouge_scorer package. 506 | Args: 507 | pred_lns: list of summaries generated by the model 508 | tgt_lns: list of ground-truth summaries (e.g. contents of val.target) 509 | use_stemmer: Bool indicating whether Porter stemmer should be used to 510 | strip word suffixes to improve matching. 511 | rouge_keys: which metrics to compute, defaults to rouge1, rouge2, rougeL, rougeLsum 512 | return_precision_and_recall: (False) whether to also return precision and recall. 513 | bootstrap_aggregation: whether to do the typical bootstrap resampling of scores. Defaults to True; if False 514 | this function returns a collections.defaultdict[metric: list of values for each observation for each subscore] 515 | newline_sep: (default=True) whether to add a newline between sentences. This is essential for calculating rougeL 516 | on multi-sentence summaries (CNN/DM dataset).
517 | Returns: 518 | Dict[score: value] if aggregate else defaultdict(list) keyed by rouge_keys 519 | """ 520 | scorer = rouge_scorer.RougeScorer(rouge_keys, use_stemmer=use_stemmer) 521 | aggregator = scoring.BootstrapAggregator() 522 | for pred, tgt in zip(tgt_lns, pred_lns): 523 | # rougeLsum expects "\n" separated sentences within a summary 524 | if newline_sep: 525 | pred = add_newline_to_end_of_each_sentence(pred) 526 | tgt = add_newline_to_end_of_each_sentence(tgt) 527 | scores = scorer.score(pred, tgt) 528 | aggregator.add_scores(scores) 529 | 530 | if bootstrap_aggregation: 531 | result = aggregator.aggregate() 532 | if return_precision_and_recall: 533 | return extract_rouge_mid_statistics(result) # here we return dict 534 | else: 535 | return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()} 536 | 537 | else: 538 | return aggregator._scores # here we return defaultdict(list) 539 | 540 | 541 | # Utilities for freezing parameters and checking whether they are frozen 542 | 543 | 544 | def freeze_params(model: nn.Module): 545 | """Set requires_grad=False for each of model.parameters()""" 546 | for par in model.parameters(): 547 | par.requires_grad = False 548 | 549 | 550 | def freeze_embeds(model): 551 | """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5.""" 552 | model_type = model.config.model_type 553 | 554 | if model_type == "t5": 555 | freeze_params(model.shared) 556 | for d in [model.encoder, model.decoder]: 557 | freeze_params(d.embed_tokens) 558 | elif model_type == "fsmt": 559 | for d in [model.model.encoder, model.model.decoder]: 560 | freeze_params(d.embed_positions) 561 | freeze_params(d.embed_tokens) 562 | else: 563 | freeze_params(model.model.shared) 564 | for d in [model.model.encoder, model.model.decoder]: 565 | freeze_params(d.embed_positions) 566 | freeze_params(d.embed_tokens) 567 | 568 | 569 | def grad_status(model: nn.Module) -> Iterable: 570 | return (par.requires_grad for par in model.parameters()) 571 | 572 | 573 | def any_requires_grad(model: nn.Module) -> bool: 574 | return any(grad_status(model)) 575 | 576 | 577 | def assert_all_frozen(model): 578 | model_grads: List[bool] = list(grad_status(model)) 579 | n_require_grad = sum(lmap(int, model_grads)) 580 | npars = len(model_grads) 581 | assert not any(model_grads), f"{n_require_grad / npars:.1%} of {npars} weights require grad" 582 | 583 | 584 | def assert_not_all_frozen(model): 585 | model_grads: List[bool] = list(grad_status(model)) 586 | npars = len(model_grads) 587 | assert any(model_grads), f"none of {npars} weights require grad" 588 | 589 | 590 | def parse_numeric_n_bool_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float, bool]]: 591 | """ 592 | Parse an argv list of unspecified command line args to a dict. 593 | Assumes all values are either numeric or boolean in the form of true/false. 
594 | """ 595 | result = {} 596 | assert len(unparsed_args) % 2 == 0, f"got odd number of unparsed args: {unparsed_args}" 597 | num_pairs = len(unparsed_args) // 2 598 | for pair_num in range(num_pairs): 599 | i = 2 * pair_num 600 | assert unparsed_args[i].startswith("--") 601 | if unparsed_args[i + 1].lower() == "true": 602 | value = True 603 | elif unparsed_args[i + 1].lower() == "false": 604 | value = False 605 | else: 606 | try: 607 | value = int(unparsed_args[i + 1]) 608 | except ValueError: 609 | value = float(unparsed_args[i + 1]) # this can raise another informative ValueError 610 | 611 | result[unparsed_args[i][2:]] = value 612 | return result 613 | 614 | 615 | def write_txt_file(ordered_tgt, path): 616 | f = Path(path).open("w") 617 | for ln in ordered_tgt: 618 | f.write(ln + "\n") 619 | f.flush() 620 | 621 | 622 | def chunks(lst, n): 623 | """Yield successive n-sized chunks from lst.""" 624 | for i in range(0, len(lst), n): 625 | yield lst[i: i + n] 626 | 627 | 628 | def check_output_dir(args, expected_items=0): 629 | """ 630 | Checks whether to bail out if output_dir already exists and has more than expected_items in it 631 | `args`: needs to have the following attributes of `args`: 632 | - output_dir 633 | - do_train 634 | - overwrite_output_dir 635 | `expected_items`: normally 0 (default) - i.e. empty dir, but in some cases a few files are expected (e.g. recovery from OOM) 636 | """ 637 | if ( 638 | os.path.exists(args.output_dir) 639 | and len(os.listdir(args.output_dir)) > expected_items 640 | and args.do_train 641 | and not args.overwrite_output_dir 642 | ): 643 | raise ValueError( 644 | f"Output directory ({args.output_dir}) already exists and " 645 | f"has {len(os.listdir(args.output_dir))} items in it (expected {expected_items} items). " 646 | "Use --overwrite_output_dir to overcome." 647 | ) 648 | --------------------------------------------------------------------------------