├── README.md
├── assets
│   ├── Figure1.jpg
│   ├── Figure1.pdf
│   ├── IEInstruct.png
│   ├── closedie.png
│   ├── example.png
│   ├── inference.png
│   ├── on-demandie.png
│   └── openie.png
├── data
│   ├── CoNLL2003
│   │   ├── processed
│   │   │   ├── rel_des.json
│   │   │   └── train.json
│   │   └── readme.md
│   ├── Readme.md
│   ├── ace2005-en
│   │   ├── label2id_des.json
│   │   └── role2id_des.json
│   └── tacred
│       ├── .DS_Store
│       └── rel_des.json
├── requirements.txt
├── scripts
│   ├── generate_dpo_sample_data.sh
│   ├── generate_mixtural_dpo_data.sh
│   ├── generate_mixtural_train_data.sh
│   ├── generate_test_set.sh
│   ├── generate_unified_data.sh
│   ├── tasks
│   │   ├── ROBUST
│   │   │   ├── __pycache__
│   │   │   │   └── desc.cpython-310.pyc
│   │   │   ├── data
│   │   │   │   ├── README.md
│   │   │   │   └── results.txt
│   │   │   ├── desc.py
│   │   │   ├── load_data_fs.py
│   │   │   ├── reformat_results_open.py
│   │   │   └── src
│   │   │       ├── analysis
│   │   │       │   ├── ctk_similarity.py
│   │   │       │   ├── hws_distance.py
│   │   │       │   └── tag_dict.json
│   │   │       ├── diversified_filter.py
│   │   │       ├── robust_scorer.py
│   │   │       ├── utils
│   │   │       │   └── CaRB
│   │   │       │       ├── .gitignore
│   │   │       │       ├── LICENSE
│   │   │       │       ├── README.md
│   │   │       │       ├── __init__.py
│   │   │       │       ├── carb.py
│   │   │       │       ├── matcher.py
│   │   │       │       ├── oie_readers
│   │   │       │       │   ├── __init__.py
│   │   │       │       │   ├── argument.py
│   │   │       │       │   ├── benchmarkGoldReader.py
│   │   │       │       │   ├── clausieReader.py
│   │   │       │       │   ├── extraction.py
│   │   │       │       │   ├── goldReader.py
│   │   │       │       │   ├── oieReader.py
│   │   │       │       │   ├── ollieReader.py
│   │   │       │       │   ├── openieFiveReader.py
│   │   │       │       │   ├── openieFourReader.py
│   │   │       │       │   ├── propsReader.py
│   │   │       │       │   ├── reVerbReader.py
│   │   │       │       │   ├── split_corpus.py
│   │   │       │       │   ├── stanfordReader.py
│   │   │       │       │   └── tabReader.py
│   │   │       │       ├── pr_plot.py
│   │   │       │       └── requirements.txt
│   │   │       ├── __init__.py
│   │   │       └── __pycache__
│   │   │           └── __init__.cpython-310.pyc
│   │   ├── ace2005-eae
│   │   │   ├── __pycache__
│   │   │   │   └── desc.cpython-310.pyc
│   │   │   ├── desc.py
│   │   │   └── load_data_fs.py
│   │   ├── ace2005-ed
│   │   │   ├── __pycache__
│   │   │   │   ├── desc.cpython-310.pyc
│   │   │   │   └── desc.cpython-311.pyc
│   │   │   ├── desc.py
│   │   │   └── load_data_fs.py
│   │   ├── ace2005-ner
│   │   │   ├── __pycache__
│   │   │   │   ├── desc.cpython-310.pyc
│   │   │   │   └── desc.cpython-311.pyc
│   │   │   ├── desc.py
│   │   │   └── load_data_fs.py
│   │   ├── conll-2003
│   │   │   ├── __pycache__
│   │   │   │   └── desc.cpython-310.pyc
│   │   │   ├── desc.py
│   │   │   └── load_data_fs.py
│   │   ├── fewnerd
│   │   │   ├── load_data_fs.py
│   │   │   └── open_evaluate.py
│   │   ├── fewrel
│   │   │   ├── __pycache__
│   │   │   │   └── desc.cpython-310.pyc
│   │   │   ├── desc.py
│   │   │   └── load_data_fs.py
│   │   ├── instructions.json
│   │   ├── matres
│   │   │   ├── load_data_fs.py
│   │   │   └── open_eval.py
│   │   ├── maven-arg
│   │   │   ├── __pycache__
│   │   │   │   └── desc.cpython-310.pyc
│   │   │   ├── desc.py
│   │   │   └── load_data_fs.py
│   │   ├── maven-ed
│   │   │   ├── __pycache__
│   │   │   │   └── desc.cpython-310.pyc
│   │   │   ├── desc.py
│   │   │   └── load_data_fs.py
│   │   ├── maven-ere
│   │   │   ├── __pycache__
│   │   │   │   └── desc.cpython-310.pyc
│   │   │   ├── desc.py
│   │   │   └── load_data_fs.py
│   │   ├── ondemandie
│   │   │   ├── evaluation
│   │   │   │   ├── rougel_for_content.py
│   │   │   │   └── sim_for_header.py
│   │   │   ├── exact_match
│   │   │   │   ├── README.md
│   │   │   │   ├── app.py
│   │   │   │   ├── exact_match.py
│   │   │   │   └── requirements.txt
│   │   │   ├── load_data_fs.py
│   │   │   ├── o_generate_evaluatefile.py
│   │   │   └── rouge
│   │   │       ├── README.md
│   │   │       ├── app.py
│   │   │       ├── requirements.txt
│   │   │       └── rouge.py
│   │   ├── ontonote5
│   │   │   ├── __pycache__
│   │   │   │   ├── desc.cpython-310.pyc
│   │   │   │   └── label_encoding.cpython-310.pyc
│   │   │   ├── desc.py
│   │   │   ├── label_encoding.py
│   │   │   └── load_data_fs.py
│   │   ├── openie4
│   │   │   ├── __pycache__
│   │   │   │   ├── desc.cpython-310.pyc
│   │   │   │   └── desc.cpython-38.pyc
│   │   │   ├── desc.py
│   │   │   └── load_data_fs.py
│   │   ├── rams
│   │   │   ├── __pycache__
│   │   │   │   └── desc.cpython-310.pyc
│   │   │   ├── desc.py
│   │   │   └── load_data_fs.py
│   │   ├── readme.md
│   │   ├── richere-eae
│   │   │   ├── load_data_fs.py
│   │   │   └── open_evaluation.py
│   │   ├── richere-ed
│   │   │   ├── load_data_fs.py
│   │   │   └── open_eval.py
│   │   ├── semeval
│   │   │   ├── load_data_fs.py
│   │   │   └── open_evaluate.py
│   │   └── tacred
│   │       ├── __pycache__
│   │       │   └── desc.cpython-310.pyc
│   │       ├── desc.py
│   │       └── load_data_fs.py
│   └── utils
│       ├── DPO
│       │   ├── __pycache__
│       │   │   └── ref_query.cpython-310.pyc
│       │   ├── compute_metric_4OpenInstruct.py
│       │   ├── generate_dpo_data_TuluFormat.py
│       │   ├── merge.py
│       │   └── ref_query.py
│       ├── fewshot_testdatasets.py
│       ├── filter_train_NA_data.py
│       ├── gpts
│       │   ├── GenerateIdx.py
│       │   ├── GenerateInstance4GPT.py
│       │   ├── GenerateInstance4GPT_Cot.py
│       │   ├── Postprocessing.py
│       │   ├── Prompt.py
│       │   ├── PromptCot.py
│       │   ├── __pycache__
│       │   │   ├── Prompt.cpython-310.pyc
│       │   │   ├── Prompt.cpython-38.pyc
│       │   │   ├── PromptCot.cpython-310.pyc
│       │   │   ├── PromptCot.cpython-38.pyc
│       │   │   └── desc.cpython-38.pyc
│       │   ├── gpt-4.py
│       │   ├── run.sh
│       │   ├── template
│       │   │   ├── __pycache__
│       │   │   │   ├── desc4openie.cpython-310.pyc
│       │   │   │   └── desc4openie.cpython-38.pyc
│       │   │   ├── desc.py
│       │   │   ├── desc4openie.py
│       │   │   └── template_generate.py
│       │   └── test.py
│       ├── mixture_task_fs_tuluformat.py
│       ├── mixture_task_fs_tuluformat_Analysis.py
│       └── reformat_tuluv2.py
├── train4llama
│   ├── LICENSE
│   ├── README_TULU.md
│   ├── ds_configs
│   │   ├── stage2.conf
│   │   ├── stage3_no_offloading.conf
│   │   ├── stage3_no_offloading_accelerate.conf
│   │   ├── stage3_offloading.conf
│   │   └── stage3_offloading_accelerate.conf
│   ├── eval
│   │   ├── __pycache__
│   │   │   ├── dispatch_openai_requests.cpython-310.pyc
│   │   │   ├── templates.cpython-310.pyc
│   │   │   └── utils.cpython-310.pyc
│   │   ├── dispatch_openai_requests.py
│   │   ├── predict.py
│   │   ├── templates.py
│   │   └── utils.py
│   ├── open_instruct
│   │   ├── __pycache__
│   │   │   └── dpo_utils.cpython-310.pyc
│   │   ├── dpo_tune.py
│   │   ├── dpo_utils.py
│   │   ├── finetune.py
│   │   ├── finetune_trainer.py
│   │   ├── get_statistics.py
│   │   ├── gradio_demo.py
│   │   ├── gradio_demo_chat.py
│   │   ├── instruction_encode_templates.py
│   │   ├── merge_lora.py
│   │   ├── reformat_datasets.py
│   │   └── safe_save_trainer.py
│   ├── requirements.txt
│   ├── scripts
│   │   ├── dpo_train_with_accelerate.sh
│   │   ├── eval.sh
│   │   ├── finetune_with_accelerate.sh
│   │   └── predict.sh
│   └── weight-diff-requirements.txt
└── unified_data
    └── test_format
        ├── fewshot_test_history
        │   ├── MATRES.jsonl
        │   ├── ROBUST.jsonl
        │   ├── few-nerd-supervised.jsonl
        │   └── semeval.jsonl
        └── zeroshot
            ├── MATRES.jsonl
            ├── ROBUST.jsonl
            ├── few-nerd-supervised.jsonl
            ├── ondemand.jsonl
            └── semeval.jsonl
/assets/Figure1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/assets/Figure1.jpg
--------------------------------------------------------------------------------
/assets/Figure1.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/assets/Figure1.pdf
--------------------------------------------------------------------------------
/assets/IEInstruct.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/assets/IEInstruct.png
--------------------------------------------------------------------------------
/assets/closedie.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/assets/closedie.png
--------------------------------------------------------------------------------
/assets/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/assets/example.png
--------------------------------------------------------------------------------
/assets/inference.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/assets/inference.png
--------------------------------------------------------------------------------
/assets/on-demandie.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/assets/on-demandie.png
--------------------------------------------------------------------------------
/assets/openie.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/assets/openie.png
--------------------------------------------------------------------------------
/data/CoNLL2003/processed/rel_des.json:
--------------------------------------------------------------------------------
1 | {
2 | "Person": [
3 | "PER",
4 | "person"
5 | ],
6 | "Organization": [
7 | "ORG",
8 | "organization"
9 | ],
10 | "Location": [
11 | "LOC",
12 | "location"
13 | ],
14 | "Miscellaneous": [
15 | "MISC",
16 | "name of miscellaneous"
17 | ]
18 | }
--------------------------------------------------------------------------------
/data/CoNLL2003/readme.md:
--------------------------------------------------------------------------------
1 | Each line contains four fields: the word, its part-of-speech tag, its chunk tag and its named entity tag.
2 | Words tagged with O are outside of named entities and the I-XXX tag is used for words inside a named entity of type XXX. Whenever two entities of type XXX are immediately next to each other, the first word of the second entity will be tagged B-XXX in order to show that it starts another entity.
3 | The data contains entities of four types: persons (PER), organizations (ORG), locations (LOC) and miscellaneous names (MISC).
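4 | 
5 | For example, the sentence "U.N. official Ekeus heads for Baghdad ." from the CoNLL-2003 shared task is annotated as:
6 | 
7 | ```
8 | U.N. NNP I-NP I-ORG
9 | official NN I-NP O
10 | Ekeus NNP I-NP I-PER
11 | heads VBZ I-VP O
12 | for IN I-PP O
13 | Baghdad NNP I-NP I-LOC
14 | . . O O
15 | ```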
--------------------------------------------------------------------------------
/data/Readme.md:
--------------------------------------------------------------------------------
1 | ## Original Datasets
2 | Our script is capable of processing the following datasets.
3 |
4 |
5 |
6 |
7 | Among our investigated tasks, the copyright of ``TACRED``, ``ACE 2005``, and ``RichERE`` belongs to the ``LDC``, and we access them through our LDC membership. All the other datasets are open-source, and we strictly adhere to their licenses.
8 |
9 | ### Prepare datasets for training and testing
10 |
13 | You can download the original data into this directory and modify the ``input_dir`` of the corresponding task in the ``./scripts/generate_unified_data.sh`` script to generate the full version of the IEInstruct dataset.
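14 | 
15 | For example, pointing the TACRED loader at data downloaded into this directory might look like the following (an illustrative sketch; the exact flags are defined by each task's ``load_data_fs.py``):
16 | 
17 | ```bash
18 | python tasks/tacred/load_data_fs.py \
19 |     --input_dir ../data/tacred
20 | ```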
--------------------------------------------------------------------------------
/data/ace2005-en/role2id_des.json:
--------------------------------------------------------------------------------
1 | {
2 | "Time-Within": "Time-Within is the most common association between a temporal expression and a event.",
3 | "Instrument": "The instrument used in the attack or used to killed or used to inflict the harm",
4 | "Target": "The target of the event.",
5 | "Place": "Where the Event takes place",
6 | "Artifact": "The person doing the traveling or the artifact being transported",
7 | "Destination": "Where the person is extradited to, the destination",
8 | "Attacker": "The attacking/instigating agent",
9 | "Origin": "Where the Event/Person originated",
10 | "Agent": "The agent(PER/ORG/GPE/FAC) in the event",
11 | "Victim": "The harmed person(s) / The person who died",
12 | "Entity": "the ORG/GPE",
13 | "Time-Holds": "Use Time-Holds when the context explicitly states that the event lasts for the entire time interval.",
14 | "Vehicle": "The vehicle used to transport the person or artifact",
15 | "Beneficiary": "The agent that benefits from the transaction",
16 | "Buyer": "The buying agent",
17 | "Person": "The person in the event",
18 | "Position": "JOB-TITLE",
19 | "Time-Starting": "when the context explicitly indicates that the event begins at a given time.",
20 | "Time-After": "when the context explicitly states that the event occurs after the given time interval",
21 | "Time-Before": "when the context explicitly states that the event occurs before the given time interval.",
22 | "Time-Ending": "when the context explicitly indicates that the event ends at a given time.",
23 | "Seller": "The selling agent",
24 | "Org": "The organization in the event",
25 | "Giver": "The donating agent",
26 | "Recipient": "The recipient agent",
27 | "Prosecutor": "The prosecuting agent",
28 | "Money": "The amount given, donated or loaned",
29 | "Defendant": "The convicted agent(s)",
30 | "Time-At-Beginning": "Time-At-Beginning means something happened at the beginning of some time period.",
31 | "Time-At-End": "Time-At-End means something happened at the end of some time period",
32 | "Plaintiff": "The suing agent",
33 | "Adjudicator": "the judge or court",
34 | "Sentence": "The sentence that has been leveled against the DEFENDANTARG following conviction",
35 | "Crime": "The crime for which the Justice Event has been undertaken",
36 | "Price": "The job which the PERSONNEL Event is concerned with"
37 | }
--------------------------------------------------------------------------------
/data/tacred/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/data/tacred/.DS_Store
--------------------------------------------------------------------------------
/data/tacred/rel_des.json:
--------------------------------------------------------------------------------
1 | {
2 | "org:founded": "The organization, or geopolitical entity that was founded by the assigned organization. ",
3 | "org:subsidiaries": "Organizations that are subsidiaries of the assigned organization (the inverse of org:parents).",
4 | "per:date_of_birth": "The date on which the assigned person was born. ",
5 | "per:cause_of_death": "The explicit cause of death for the assigned person.",
6 | "per:age": "A reported age of the assigned person.",
7 | "per:stateorprovince_of_birth": "The geopolitical entity at state or province level in which the assigned person was born.",
8 | "per:countries_of_residence": "All countries in which the assigned person has lived",
9 | "per:country_of_birth": " The country in which the assigned person was born.",
10 | "per:stateorprovinces_of_residence": "Geopolitical entities at the state or province level in which the assigned person has lived.",
11 | "org:website": "An official top level URL for the organization's website.",
12 | "per:cities_of_residence": "Geopolitical entities at the level of city, town, or village in which the assigned person has lived.",
13 | "per:parents": "The parents of the assigned person.",
14 | "per:employee_of": "The organizations or geopolitical entities (governments) of which the assigned person has been an employee or member.",
15 | "no_relation": "There is no relation between the subject and object entity.",
16 | "per:city_of_birth": "The geopolitical entity at the municipality level (city, town, or village) in which the assigned person was born.",
17 | "org:parents": ": Organizations or geopolitical entities of which the assigned organization is a subsidiary (the inverse of org:subsidiaries).",
18 | "org:political/religious_affiliation": "Ideological groups with which the organization is associated.",
19 | "per:schools_attended": "Any school (college, high school, university, etc.) that the assigned person has attended. ",
20 | "per:country_of_death": " The country in which the assigned person died.",
21 | "per:children": "The children of the assigned person, including adopted and step-children.",
22 | "org:top_members/employees": "The persons in high-level, leading positions at the assigned organization.",
23 | "per:date_of_death": "The date of the assigned person's death.",
24 | "org:members": "Organizations or Geopolitical entities that are members of the assigned organization (the inverse of org:member_of).",
25 | "org:alternate_names": "Any name used to refer to the assigned organization that is distinct from the 'official' name.",
26 | "per:religion": "The religion to which the assigned person has belonged.",
27 | "org:member_of": "Organizations or geopolitical entities of which the assigned organization is a member itself (the inverse of org:members). ",
28 | "org:city_of_headquarters": "Location of the headquarters of the assigned organization at the city, town, or village level.",
29 | "per:origin": "The nationality and/or ethnicity of the assigned person. ",
30 | "org:shareholders": "Any organization, person, or geopolitical entity that holds shares (majority or not) of the organization. ",
31 | "per:charges": "The charges or crimes (alleged or convicted) of the assigned person.",
32 | "per:title": "Official or unofficial name(s) of the employment or membership positions that have been held by the assigned person.",
33 | "org:number_of_employees/members": "The total number of people who are employed by or have membership in an organization.",
34 | "org:dissolved": "The date on which the assigned organization was dissolved.",
35 | "org:country_of_headquarters": "Countries in which the headquarters of the assigned organization are located.",
36 | "per:alternate_names": "Names used to refer to the assigned person that are distinct from the 'official' name",
37 | "per:siblings": "The brothers and sisters of the assigned person.",
38 | "org:stateorprovince_of_headquarters": "Location of the headquarters of the query organization at the state or province level.",
39 | "per:spouse": "The spouse(s) of the assigned person.",
40 | "per:other_family": "Family other than siblings, parents, children, and spouse (or former spouse).",
41 | "per:city_of_death": "The geopolitical entity at the level of city, town, village in which the assigned person died.",
42 | "per:stateorprovince_of_death": "The geopolitical entity at state or province level in which the assigned person died.",
43 | "org:founded_by": "The person, organization, or geopolitical entity that founded the assigned organization. "
44 | }
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch<=2.0.1
2 | scipy
3 | packaging
4 | sentencepiece
5 | datasets
6 | deepspeed>=0.10.0
7 | accelerate>=0.21.0,<0.23.0 # 0.23.0 will cause an incorrect learning rate schedule when using deepspeed, which is likely caused by https://github.com/huggingface/accelerate/commit/727d624322c67db66a43c559d8c86414d5ffb537
8 | peft>=0.4.0
9 | bitsandbytes>=0.41.1
10 | evaluate>=0.4.0
11 | tokenizers>=0.13.3
12 | protobuf
13 | # The Transformers library (v4.34.0) still has a bug with left padding,
14 | # which significantly affects inference and thus our evaluation performance (e.g., MMLU and TruthfulQA).
15 | # The following PR is a temporary fix for it but has not been merged yet.
16 | # See https://github.com/huggingface/transformers/pull/25284
17 | # However, that PR is not compatible with the latest version of the Transformers library (v4.34.0).
18 | # To incorporate it, we forked the Transformers library and made some changes to keep it compatible with the latest version.
19 | git+https://github.com/yizhongw/transformers.git@left_padding
20 | openai>=1.5.0
21 | tiktoken
22 | rouge_score
23 | tensorboard
24 | wandb
25 | gradio==3.50.2
26 | termcolor
27 | jsonlines
28 | unidic-lite
29 | einops
30 | flash-attn==2.2.2
31 | auto-gptq
32 | fire
33 | alpaca-eval==0.5.3
34 | # for human eval web app
35 | flask
36 | vllm
37 | openpyxl
--------------------------------------------------------------------------------
/scripts/generate_dpo_sample_data.sh:
--------------------------------------------------------------------------------
1 |
2 | python utils/DPO/generate_dpo_data_TuluFormat.py \
3 | --unified_data_dir ../unified_data \
4 | --hold_in_datasets conll-2003 ace2005-ner ontonote5 fewrel tacred ace2005-ed maven-ed ace2005-eae RAMS-eae maven-eae MAVEN-ERE openie4\
5 | --ondemandIE_dir ../unified_data/ondemandIE \
6 | --ondemand_cot_rate 0.3 \
7 | --Limit_total_data 52000 \
8 | --Limit_dataset 4000 \
9 | --version "DPO" \
10 | --reserve_all_gptdata 0 \
11 | --WORD_Limit 1200 \
12 |
13 |
14 |
--------------------------------------------------------------------------------
/scripts/generate_mixtural_dpo_data.sh:
--------------------------------------------------------------------------------
1 |
2 | for p in {1..5}
3 | do
4 | python utils/DPO/compute_metric_4OpenInstruct.py \
5 | --input_path ../unified_data/train_mixture/sample4dpo_results/ADELIE-SFT/T_1.0_$p/mix_vDPO
6 | done
7 |
8 | python utils/DPO/merge.py
--------------------------------------------------------------------------------
/scripts/generate_mixtural_train_data.sh:
--------------------------------------------------------------------------------
1 |
2 | #过滤训练数据中的NA,比例为 valid:na=0.8:0.2
3 | python utils/filter_train_NA_data.py
4 |
5 |
6 | #生成 tuluv2 数据
7 | python utils/reformat_tuluv2.py
8 |
9 | #生成 ondemandie 数据
10 | python tasks/ondemandie/load_data_fs.py
11 |
12 |
13 | #生成 IEInstruct 数据, output_path: /ADELIE/unified_data/train_mixture
14 | python utils/mixture_task_fs_tuluformat.py \
15 | --unified_data_dir ../unified_data \
16 | --hold_in_datasets conll-2003 ace2005-ner ontonote5 fewrel tacred ace2005-ed maven-ed ace2005-eae RAMS-eae MAVEN-ERE maven-eae ee other_ner other_rc re openie4\
17 | --ondemandIE_dir ../unified_data/ondemandIE \
18 | --ondemand_cot_rate 0.2 \
19 | --general_training_file ../unified_data/tuluv2/train.jsonl \
20 | --Limit_total_data 400000 \
21 | --General_rate 0.8 \
22 | --Limit_dataset 5000 \
23 | --version "IEInstruct" \
24 | --general_filter 0 \
25 | --reserve_all_gptdata 0
26 |
27 |
--------------------------------------------------------------------------------
/scripts/generate_test_set.sh:
--------------------------------------------------------------------------------
1 | # 生成 fewshot 和 zeroshot 测试数据
2 | python utils/fewshot_testdatasets.py \
3 | --input_dir ../unified_data_longshot \
4 | --output_dir ../unified_data/test_format_32shot \
5 | --hold_out_datasets few-nerd-supervised semeval RichERE-ed RichERE-eae MATRES ROBUST \
6 | --num_shot 4
7 |
--------------------------------------------------------------------------------
/scripts/tasks/ROBUST/__pycache__/desc.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/scripts/tasks/ROBUST/__pycache__/desc.cpython-310.pyc
--------------------------------------------------------------------------------
/scripts/tasks/ROBUST/data/README.md:
--------------------------------------------------------------------------------
1 | Please refer to the download link in the home directory.
--------------------------------------------------------------------------------
/scripts/tasks/ROBUST/data/results.txt:
--------------------------------------------------------------------------------
1 | The carb scores AUC: 0.41274842767295605, P: 0.6247916666666666, R: 0.49396933962264145, F1: 0.5517316482252379
2 | The robust scores AUC: 0.26252044025157234, P: 0.41670833333333335, R: 0.3520542452830189, F1: 0.3816625363291085
3 |
--------------------------------------------------------------------------------
/scripts/tasks/ROBUST/reformat_results_open.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import re
4 | import argparse
5 |
6 |
7 | def get_args(id, pred):
8 | # print("ID",id,len(pred))
9 | try:
10 | prediction = pred[id]["output"].split("[Answer]: ")[1]
11 | except:
12 | prediction = pred[id]["output"]
13 | pattern = re.compile("\((.*)\)")
14 | triple = re.findall(pattern, prediction)
15 | # print("PRED:",prediction)
16 | # print("TRIPLE: ",triple)
17 | args = []
18 | for vert in triple:
19 | arg = vert.split("; ")
20 | args.append(arg)
21 |
22 | # print("ARGS:",args)
23 | return args
24 |
25 |
26 | def reformat(input_file):
27 | pred = []
28 | with open(input_file) as f:
29 | for line in f.readlines():
30 | instance = json.loads(line.strip())
31 | pred.append(instance)
32 |
33 | with open(
34 | "/SSD_DATA/qyj/Alignment_on_IE_tasks/scripts/tasks/ROBUST/data/ROBUST.json", "r"
35 | ) as f:
36 | gold = json.load(f)
37 | f.close()
38 |
39 | result = []
40 | idx = -1
41 | for instance in gold:
42 | # ori_sentence
43 | idx += 1
44 | args = get_args(idx, pred)
45 | instance["ori_args"] = args
46 | # para
47 | for i, para in enumerate(instance["paraphrases"]):
48 | idx += 1
49 | args = get_args(idx, pred)
50 | instance["paraphrases"][i]["args"] = args
51 | result.append(instance)
52 |
53 | print("idx:", idx)
54 | print("pred num:", len(pred))
55 | return result
56 |
57 |
58 | if __name__ == "__main__":
59 | parser = argparse.ArgumentParser(description="tacred")
60 | parser.add_argument(
61 | "--input_dir",
62 | type=str,
63 | default="/SSD_DATA/qyj/Alignment_on_IE_tasks/scripts/tasks/ROBUST/data/",
64 | )
65 | args = parser.parse_args()
66 |
67 | result = reformat(args.input_dir)
68 | out_file = open("result.json", "w")
69 | json.dump(result, out_file, indent=1)
70 | out_file.close()
71 |
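72 | # Usage note: despite its name, --input_dir is opened as a single JSONL file of
73 | # model predictions whose "output" field holds extractions of the form
74 | # "[Answer]: (arg1; predicate; arg2)(...)". Example invocation (illustrative path):
75 | #   python reformat_results_open.py --input_dir predictions.jsonl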
--------------------------------------------------------------------------------
/scripts/tasks/ROBUST/src/analysis/ctk_similarity.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/scripts/tasks/ROBUST/src/analysis/ctk_similarity.py
--------------------------------------------------------------------------------
/scripts/tasks/ROBUST/src/analysis/tag_dict.json:
--------------------------------------------------------------------------------
1 | {
2 | "ROOT": 0,
3 | "S": 1,
4 | "NP": 2,
5 | "NNP": 3,
6 | "VP": 4,
7 | "VBD": 5,
8 | ":": 6,
9 | "SBAR": 7,
10 | "PP": 8,
11 | "IN": 9,
12 | "DT": 10,
13 | "NN": 11,
14 | ",": 12,
15 | "CD": 13,
16 | "NNS": 14,
17 | "JJ": 15,
18 | ".": 16,
19 | "CC": 17,
20 | "VBG": 18,
21 | "TO": 19,
22 | "VB": 20,
23 | "WHPP": 21,
24 | "WHNP": 22,
25 | "WDT": 23,
26 | "WP": 24,
27 | "ADJP": 25,
28 | "VBN": 26,
29 | "ADVP": 27,
30 | "RB": 28,
31 | "PRP": 29,
32 | "PRT": 30,
33 | "RP": 31,
34 | "PRP$": 32,
35 | "MD": 33,
36 | "JJR": 34,
37 | "NNPS": 35,
38 | "QP": 36,
39 | "POS": 37,
40 | "VBZ": 38,
41 | "PRN": 39,
42 | "-LRB-": 40,
43 | "-RRB-": 41,
44 | "X": 42,
45 | "VBP": 43,
46 | "JJS": 44,
47 | "RBS": 45,
48 | "FW": 46,
49 | "CONJP": 47,
50 | "FRAG": 48,
51 | "WHADVP": 49,
52 | "WRB": 50,
53 | "$": 51,
54 | "WP$": 52,
55 | "SINV": 53,
56 | "RBR": 54,
57 | "UCP": 55,
58 | "EX": 56,
59 | "``": 57,
60 | "''": 58,
61 | "#": 59,
62 | "PDT": 60,
63 | "SQ": 61,
64 | "NX": 62,
65 | "NAC": 63,
66 | "RRC": 64,
67 | "SYM": 65,
68 | "LS": 66,
69 | "SBARQ": 67,
70 | "WHADJP": 68,
71 | "INTJ": 69,
72 | "LST": 70,
73 | "UH": 71,
74 | "NML": 72,
75 | "NFP": 73,
76 | "ADD": 74,
77 | "HYPH": 75
78 | }
79 |
--------------------------------------------------------------------------------
/scripts/tasks/ROBUST/src/diversified_filter.py:
--------------------------------------------------------------------------------
1 | ### Prepare annotation files
2 | """ Rules for creating the dataset:
3 | Retain at most 3 sentences with a heuristic segmentation method:
4 | (1) For the pair of sentences s1, s2 with the highest BLEU score in the remaining set:
5 |     a. if the length of one of them is less than 2/3 of the original sentence, remove that one;
6 |     b. otherwise, remove the one with the highest sum of BLEU scores with all other sentences.
7 | (2) Retain at least 3 sentences.
8 | """
9 | import copy
10 | import json
11 | 
12 | import numpy as np
13 | import sacrebleu
14 | 
15 | # NOTE: this snippet runs inside the annotation-preparation pipeline: `result`
16 | # (the paraphrase samples) and the output paths `result_ann_para` and
17 | # `result_ann_para_self` are defined by the surrounding code.
18 | SELF_IDS = list(range(10, 1282, 1282 // 20))
19 | n_sents = 0
20 | n_sents_para = 0
21 | ann_para = []
22 | ann_para_records = []
23 | ann_para_self = []
24 | for i, x in enumerate(result):
25 |     # paraphrase sample
26 |     x_para = {
27 |         "ori_sent": x["ori_sent"],
28 |         "paraphrases": [p["sent"] for p in x["paraphrases"]],
29 |     }
30 | 
31 |     x_para_record = {
32 |         "ori_sent": x["ori_sent"],
33 |         "paraphrases": [],
34 |         "bleus": None,
35 |     }
36 | 
37 |     # apply heuristic rules
38 |     _all_sents = sorted([p["sent"] for p in x["paraphrases"]], key=lambda _ss: len(_ss.split()))
39 |     _all_sents_records = sorted([p["sent"] for p in x["paraphrases"]], key=lambda _ss: len(_ss.split()))
40 |     ## length rule: iterate over a copy, since we pop from _all_sents inside the loop
41 |     for _s in list(_all_sents):
42 |         if len(_s.split()) < len(x["ori_sent"].split()) * (2 / 3) and len(_all_sents) > 3:
43 |             _all_sents_records[_all_sents_records.index(_s)] = (_s, 'discard_length', len(_s.split()))
44 |             _all_sents.pop(_all_sents.index(_s))
45 |     ## bleu score rule
46 |     if len(_all_sents) > 3:
47 |         _sents_bleu = np.zeros([len(_all_sents), len(_all_sents)])
48 |         for ii in range(len(_all_sents)):
49 |             for iii in range(ii + 1, len(_all_sents)):
50 |                 _b1 = sacrebleu.corpus_bleu([_all_sents[ii]], [[_all_sents[iii]]])
51 |                 _b2 = sacrebleu.corpus_bleu([_all_sents[iii]], [[_all_sents[ii]]])
52 |                 _sents_bleu[ii][iii] = _b1.score
53 |                 _sents_bleu[iii][ii] = _b2.score
54 |         x_para_record["bleus"] = copy.copy(_sents_bleu).tolist()
55 |         while len(_all_sents) > 3:
56 |             _m1, _m2 = [_[0] for _ in np.where(_sents_bleu == _sents_bleu.max())]
57 |             _sum_m1 = (_sents_bleu[_m1, :].sum() + _sents_bleu[:, _m1].sum()) / 2
58 |             _sum_m2 = (_sents_bleu[_m2, :].sum() + _sents_bleu[:, _m2].sum()) / 2
59 |             # drop the sentence of the pair with the higher total BLEU overlap
60 |             _drop, _sum = (_m1, _sum_m1) if _sum_m1 > _sum_m2 else (_m2, _sum_m2)
61 |             # locate the sentence in the records list (whose indices no longer
62 |             # match _all_sents once items have been removed) before popping it
63 |             _rec_idx = _all_sents_records.index(_all_sents[_drop])
64 |             _all_sents_records[_rec_idx] = (_all_sents_records[_rec_idx], 'discard_bleu', _sum)
65 |             _all_sents.pop(_drop)
66 |             _sents_bleu = np.delete(_sents_bleu, _drop, axis=0)
67 |             _sents_bleu = np.delete(_sents_bleu, _drop, axis=1)
68 |     x_para["paraphrases"] = _all_sents
69 |     x_para_record["paraphrases"] = _all_sents_records
70 |     ann_para.append(x_para)
71 |     ann_para_records.append(x_para_record)
72 | 
73 |     # paraphrase self-annotation sample
74 |     if i in SELF_IDS:
75 |         x_para_self = {
76 |             "ori_sent": x["ori_sent"],
77 |             "paraphrases": [
78 |                 {
79 |                     "sent": _s,
80 |                     "annotation_sent": " ",
81 |                     "annotation_args": [
82 |                         {"arg1": " ", "pred": " ", "arg2": " "},
83 |                         {"arg1": " ", "pred": " ", "arg2": " "},
84 |                     ],
85 |                 }
86 |                 for _s in _all_sents
87 |             ],
88 |         }
89 |         ann_para_self.append(x_para_self)
90 |     n_sents += 1
91 |     n_sents_para += len(x_para["paraphrases"])
92 | 
93 | # prepare shuffled all samples
94 | np.random.shuffle(ann_para)
95 | np.random.shuffle(ann_para_self)
96 | # save
97 | with open(result_ann_para, "w") as f:
98 |     json.dump(ann_para, f, indent=4)
99 | with open(result_ann_para[:-4] + 'records.json', "w") as f:
100 |     json.dump(ann_para_records, f, indent=4)
101 | with open(result_ann_para_self, "w") as f:
102 |     json.dump(ann_para_self, f, indent=4)
103 | print(f"total # of instances: {n_sents}")
104 | print(f"total # of paraphrases: {n_sents_para}")
--------------------------------------------------------------------------------
/scripts/tasks/ROBUST/src/robust_scorer.py:
--------------------------------------------------------------------------------
1 | """ Calculate the robust scorers of P, R, F1 for the predicted file and gold file of unified json format.
2 | """
3 |
4 | import sys
5 |
6 | sys.path.append("src/utils/CaRB")
7 |
8 | import os
9 | import json
10 | import argparse
11 | import numpy as np
12 | from utils.CaRB.carb import Benchmark
13 | from utils.CaRB.oie_readers.extraction import Extraction
14 | from utils.CaRB.matcher import Matcher
15 | from operator import itemgetter
16 |
17 |
18 | def read_unified(fn):
19 | """Build inputs of CaRB format, each element for a robust sample containing an original sentences with multiple paraphrases"""
20 | with open(fn, "r") as f:
21 | f_data = json.load(f)
22 |
23 | results = {} # ori sent as keys
24 | for X in f_data:
25 | # for CaRB sentence
26 | sample = {}
27 | d_carb = {}
28 | for ori_args in X["ori_args"]:
29 | curExtraction = Extraction(
30 | pred=ori_args[0], head_pred_index=-1, sent=X["ori_sent"], confidence=-1
31 | )
32 | for i in range(1, len(ori_args)):
33 | curExtraction.addArg(ori_args[i])
34 | d_carb[X["ori_sent"]] = d_carb.get(X["ori_sent"], []) + [curExtraction]
35 | sample[X["ori_sent"]] = d_carb
36 |
37 | # for paraphrases
38 | for para in X["paraphrases"]:
39 | d_para = {}
40 | for para_args in para["args"]:
41 | curExtraction = Extraction(
42 | pred=para_args[0],
43 | head_pred_index=-1,
44 | sent=para["sent"],
45 | confidence=-1,
46 | )
47 | for i in range(1, len(para_args)):
48 | curExtraction.addArg(para_args[i])
49 | d_para[para["sent"]] = d_para.get(para["sent"], []) + [curExtraction]
50 | sample[para["sent"]] = d_para
51 | results[X["ori_sent"]] = sample
52 | return results
53 |
54 |
55 | def main(gold_file, pred_file, save_dir):
56 | """ """
57 | pred = read_unified(pred_file)
58 | gold = read_unified(gold_file)
59 |
60 | all_scores = []
61 | for ori_sent in pred:
62 | cur_scores = []
63 | for sent in pred[ori_sent]:
64 | p_extraction = pred[ori_sent][sent]
65 | g_extraction = gold[ori_sent][sent]
66 | # evaluate one by one
67 | bench = Benchmark()
68 | bench.gold = g_extraction
69 | auc, [p, r, f1] = bench.compare(
70 | p_extraction, Matcher.binary_linient_tuple_match
71 | )
72 | # keep the scores of carb in the first position
73 | if sent == ori_sent:
74 | cur_scores.insert(0, [auc, p, r, f1])
75 | # print("Pred: ", [(ext.pred, ext.args) for ext in p_extraction[sent]])
76 | # print("Gold", [(ext.pred, ext.args) for ext in g_extraction[sent]])
77 | # print([auc, p, r, f1])
78 | else:
79 | cur_scores.append([auc, p, r, f1])
80 | all_scores.append(cur_scores)
81 |
82 | print(len(all_scores))
83 | # get carb scores
84 | carb_scores = np.array(([sco[0] for sco in all_scores]))
85 | carb_auc = carb_scores[:, 0].mean()
86 | carb_p = carb_scores[:, 1].mean()
87 | carb_r = carb_scores[:, 2].mean()
88 | carb_f1 = Benchmark.f1(carb_p, carb_r)
89 |
90 |     # get the robust scores with the worst F1 across the original sentence and its paraphrases
91 | robust_scores = [min(score, key=itemgetter(3)) for score in all_scores]
92 | robust_scores = np.array(robust_scores)
93 | robust_auc = robust_scores[:, 0].mean()
94 | robust_p = robust_scores[:, 1].mean()
95 | robust_r = robust_scores[:, 2].mean()
96 | # robust_f1 = robust_scores[:, 3].mean()
97 | robust_f1 = Benchmark.f1(robust_p, robust_r)
98 |
99 | print(
100 | "The carb scores AUC: {}, P: {}, R: {}, F1: {}".format(
101 | carb_auc, carb_p, carb_r, carb_f1
102 | )
103 | )
104 | print(
105 | "The robust scores AUC: {}, P: {}, R: {}, F1: {}".format(
106 | robust_auc, robust_p, robust_r, robust_f1
107 | )
108 | )
109 |
110 | os.makedirs(save_dir, exist_ok=True)
111 | with open(os.path.join(save_dir, "results.txt"), "w") as f:
112 | f.write(
113 | "The carb scores AUC: {}, P: {}, R: {}, F1: {}\n".format(
114 | carb_auc, carb_p, carb_r, carb_f1
115 | )
116 | )
117 | f.write(
118 | "The robust scores AUC: {}, P: {}, R: {}, F1: {}\n".format(
119 | robust_auc, robust_p, robust_r, robust_f1
120 | )
121 | )
122 |
123 |
124 | if __name__ == "__main__":
125 | parser = argparse.ArgumentParser()
126 | parser.add_argument(
127 | "--pred_file",
128 | "-pred",
129 | type=str,
130 | default="result.json",
131 | help="The predicted json file with the unified ROBUST format.",
132 | )
133 | parser.add_argument(
134 | "--gold_file",
135 | "-gold",
136 | type=str,
137 | default="../../../data/openie6/ROBUST.json",
138 | help="The gold ROBUST json file.",
139 | )
140 | parser.add_argument(
141 | "--save_dir",
142 | "-s",
143 | type=str,
144 | default="data/",
145 | help="The directory to save results.",
146 | )
147 | args = parser.parse_args()
148 |
149 | main(args.gold_file, args.pred_file, args.save_dir)
150 |
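151 | # Example usage (illustrative paths, matching the argparse defaults above):
152 | #   python robust_scorer.py -pred result.json -gold ../../../data/openie6/ROBUST.json -s data/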
--------------------------------------------------------------------------------
/scripts/tasks/ROBUST/src/utils/CaRB/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__
2 |
3 |
4 |
--------------------------------------------------------------------------------
/scripts/tasks/ROBUST/src/utils/CaRB/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Data Analytics and Intelligence Research (DAIR) Group, IIT Delhi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/scripts/tasks/ROBUST/src/utils/CaRB/README.md:
--------------------------------------------------------------------------------
1 | # CaRB - A Crowdsourced Benchmark for Open IE
2 |
3 | CaRB : ***C***rowdsourced ***a***utomatic open ***R***elation extraction ***B***enchmark
4 |
5 |
6 | ## Introduction
7 |
8 | CaRB is a dataset cum evaluation framework for benchmarking Open Information Extraction systems.
9 |
10 | The details of this benchmark are elaborated in our [EMNLP 2019 Paper](https://www.aclweb.org/anthology/D19-1651/).
11 |
12 | ### Citing
13 | If you use this software, please cite:
14 | ```
15 | @inproceedings{bhardwaj-etal-2019-carb,
16 | title = "{C}a{RB}: A Crowdsourced Benchmark for Open {IE}",
17 | author = "Bhardwaj, Sangnie and
18 | Aggarwal, Samarth and
19 | Mausam, Mausam",
20 | booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)",
21 | month = nov,
22 | year = "2019",
23 | address = "Hong Kong, China",
24 | publisher = "Association for Computational Linguistics",
25 | url = "https://www.aclweb.org/anthology/D19-1651",
26 | doi = "10.18653/v1/D19-1651",
27 | pages = "6263--6268",
28 | abstract = "Open Information Extraction (Open IE) systems have been traditionally evaluated via manual annotation. Recently, an automated evaluator with a benchmark dataset (OIE2016) was released {--} it scores Open IE systems automatically by matching system predictions with predictions in the benchmark dataset. Unfortunately, our analysis reveals that its data is rather noisy, and the tuple matching in the evaluator has issues, making the results of automated comparisons less trustworthy. We contribute CaRB, an improved dataset and framework for testing Open IE systems. To the best of our knowledge, CaRB is the first crowdsourced Open IE dataset and it also makes substantive changes in the matching code and metrics. NLP experts annotate CaRB{'}s dataset to be more accurate than OIE2016. Moreover, we find that on one pair of Open IE systems, CaRB framework provides contradictory results to OIE2016. Human assessment verifies that CaRB{'}s ranking of the two systems is the accurate ranking. We release the CaRB framework along with its crowdsourced dataset.",
29 | }
30 | ```
31 |
32 | ### Contact
33 | Leave us a note at
34 | ```samarthaggarwal2510 (at) gmail (dot) com```
35 |
36 | ## Requirements
37 |
38 | * Python 3
39 | * See required python packages [here](requirements.txt).
40 |
41 |
42 |
43 | ## Evaluating an Open IE Extractor
44 |
45 | Currently, we support the following Open IE output formats:
46 |
47 | * [ClausIE](https://www.mpi-inf.mpg.de/departments/databases-and-information-systems/software/clausie/)
48 | * [OLLIE](http://knowitall.github.io/ollie/)
49 | * [OpenIE-4](https://github.com/allenai/openie-standalone)
50 | * [OpenIE-5](https://github.com/allenai/openie-standalone)
51 | * [PropS](http://u.cs.biu.ac.il/~stanovg/props.html)
52 | * [ReVerb](http://reverb.cs.washington.edu/)
53 | * [Stanford Open IE](http://nlp.stanford.edu/software/openie.html)
54 | * Tab Separated - Read simple tab format file, where each line consists of:
55 |     sent, prob, pred, arg1, arg2, ...
56 |
57 | To evaluate your OpenIE system:
58 |
59 | 1. Run your extractor over the [dev sentences](data/dev.txt) or [test sentences](data/test.txt) and store the output into "*your_output*.txt"
60 |
61 | 2. Depending on your output format, you can get a precision-recall curve by running [carb.py](carb.py):
62 | ```
63 | Usage:
64 | python carb.py --gold=GOLD_OIE --out=OUTPUT_FILE (--stanford=STANFORD_OIE | --ollie=OLLIE_OIE |--reverb=REVERB_OIE | --clausie=CLAUSIE_OIE | --openiefour=OPENIEFOUR_OIE | --props=PROPS_OIE)
65 |
66 | Options:
67 | --gold=GOLD_OIE The gold reference Open IE file (by default, it should be under ./oie_corpus/all.oie).
68 | --out=OUTPUT_FILE The output file, into which the precision recall curve will be written.
69 | --clausie=CLAUSIE_OIE Read ClausIE format from file CLAUSIE_OIE.
70 | --ollie=OLLIE_OIE Read OLLIE format from file OLLIE_OIE.
71 | --openiefour=OPENIEFOUR_OIE Read Open IE 4 format from file OPENIEFOUR_OIE.
72 | --openiefive=OPENIEFIVE_OIE Read Open IE 5 format from file OPENIEFIVE_OIE.
73 | --props=PROPS_OIE Read PropS format from file PROPS_OIE
74 | --reverb=REVERB_OIE Read ReVerb format from file REVERB_OIE
75 | --stanford=STANFORD_OIE Read Stanford format from file STANFORD_OIE
76 | --tabbed=TABBED_OIE Read tabbed format from file TABBED_OIE
77 | ```
78 |
79 | ## Evaluating Existing Systems
80 |
81 | In the course of this work we tested the above mentioned Open IE parsers against our benchmark.
82 | We provide the output files (i.e., Open IE extractions) of each of these
83 | systems in [system_outputs/test](system_outputs/test).
84 | You can give each of these files to [carb.py](carb.py), to get the corresponding precision recall curve.
85 |
86 | For example, to evaluate OpenIE-4 output, run:
87 | ```
88 | python carb.py --gold=data/gold/test.tsv --out=dump/OpenIE-4.dat --openiefour=system_outputs/test/openie4_output.txt
89 | ```
90 |
91 | ## Plotting
92 |
93 | You can plot together multiple outputs of [carb.py](carb.py), by using [pr_plot.py](pr_plot.py):
94 |
95 | ```
96 | Usage:
97 | pr_plot --in=DIR_NAME --out=OUTPUT_FILENAME
98 |
99 | Options:
100 | --in=DIR_NAME Folder in which to search for *.dat files, all of which should be in a P/R column format (outputs from benchmark.py).
101 | --out=OUTPUT_FILENAME Output filename, filetype will determine the format. Possible formats: pdf, pgf, png
102 | ```
103 |
104 | ### References
105 |
106 | 1. Creating a large benchmark for Open Information Extraction - Stanovsky and Dagan, 2016
107 | 2. Analysing Errors of Open Information Extraction Systems - Schneider et al., 2017
108 | 3. Wire57 : A Fine-Grained Benchmark for Open Information Extraction - Lechelle et al., 2018
109 |
110 |
111 |
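112 | ## Tab Separated Example
113 | 
114 | An illustrative line in the tab separated format described above, following the stated field order sent, prob, pred, arg1, arg2, ... (`<TAB>` marks a tab character):
115 | 
116 | ```
117 | He served as the first Prime Minister .<TAB>0.95<TAB>served as<TAB>He<TAB>the first Prime Minister
118 | ```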
--------------------------------------------------------------------------------
/scripts/tasks/ROBUST/src/utils/CaRB/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/scripts/tasks/ROBUST/src/utils/CaRB/__init__.py
--------------------------------------------------------------------------------
/scripts/tasks/ROBUST/src/utils/CaRB/oie_readers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/scripts/tasks/ROBUST/src/utils/CaRB/oie_readers/__init__.py
--------------------------------------------------------------------------------
/scripts/tasks/ROBUST/src/utils/CaRB/oie_readers/argument.py:
--------------------------------------------------------------------------------
1 | import nltk
2 | from operator import itemgetter
3 |
4 | class Argument:
5 | def __init__(self, arg):
6 | self.words = [x for x in arg[0].strip().split(' ') if x]
7 |         self.posTags = list(map(itemgetter(1), nltk.pos_tag(self.words)))  # list() so the tags survive repeated use in Python 3
8 | self.indices = arg[1]
9 | self.feats = {}
10 |
11 | def __str__(self):
12 | return "({})".format('\t'.join(map(str,
13 | [escape_special_chars(' '.join(self.words)),
14 | str(self.indices)])))
15 |
16 | COREF = 'coref'
17 |
18 | ## Helper functions
19 | def escape_special_chars(s):
20 | return s.replace('\t', '\\t')
21 |
22 |
--------------------------------------------------------------------------------
/scripts/tasks/ROBUST/src/utils/CaRB/oie_readers/benchmarkGoldReader.py:
--------------------------------------------------------------------------------
1 | """ Usage:
2 | benchmarkGoldReader --in=INPUT_FILE
3 |
4 | Read a tab-formatted file.
5 | Each line consists of:
6 |     sent, pred, arg1, arg2, ...
7 |
8 | """
9 |
10 | from oie_readers.oieReader import OieReader
11 | from oie_readers.extraction import Extraction
12 | from docopt import docopt
13 | import logging
14 |
15 | logging.basicConfig(level = logging.DEBUG)
16 |
17 | class BenchmarkGoldReader(OieReader):
18 |
19 | def __init__(self):
20 | self.name = 'BenchmarkGoldReader'
21 |
22 | def read(self, fn):
23 | """
24 | Read a tabbed format line
25 | Each line consists of:
26 |         sent, pred, arg1, arg2, ...
27 | """
28 | d = {}
29 | ex_index = 0
30 | with open(fn) as fin:
31 | for line in fin:
32 | if not line.strip():
33 | continue
34 | data = line.strip().split('\t')
35 | text, rel = data[:2]
36 | curExtraction = Extraction(pred = rel.strip(),
37 | head_pred_index = None,
38 | sent = text.strip(),
39 | confidence = 1.0,
40 | question_dist = "./question_distributions/dist_wh_sbj_obj1.json",
41 | index = ex_index)
42 | ex_index += 1
43 |
44 | for arg in data[2:]:
45 | curExtraction.addArg(arg.strip())
46 |
47 | d[text] = d.get(text, []) + [curExtraction]
48 | self.oie = d
49 |
50 |
51 | if __name__ == "__main__":
52 | args = docopt(__doc__)
53 | input_fn = args["--in"]
54 | tr = BenchmarkGoldReader()
55 | tr.read(input_fn)
56 |
--------------------------------------------------------------------------------
/scripts/tasks/ROBUST/src/utils/CaRB/oie_readers/clausieReader.py:
--------------------------------------------------------------------------------
1 | """ Usage:
2 | --in=INPUT_FILE --out=OUTPUT_FILE [--debug]
3 |
4 | Convert to tabbed format
5 | """
6 | # External imports
7 | import logging
8 | from pprint import pprint
9 | from pprint import pformat
10 | from docopt import docopt
11 |
12 | # Local imports
13 | from oie_readers.oieReader import OieReader
14 | from oie_readers.extraction import Extraction
15 | import ipdb
16 | #=-----
17 |
18 | class ClausieReader(OieReader):
19 |
20 | def __init__(self):
21 | self.name = 'ClausIE'
22 |
23 | def read(self, fn):
24 | d = {}
25 | with open(fn, encoding="utf-8") as fin:
26 | for line in fin:
27 | data = line.strip().split('\t')
28 | if len(data) == 1:
29 | text = data[0]
30 | elif len(data) == 5:
31 | arg1, rel, arg2 = [s[1:-1] for s in data[1:4]]
32 | confidence = data[4]
33 |
34 | curExtraction = Extraction(pred = rel,
35 | head_pred_index = -1,
36 | sent = text,
37 | confidence = float(confidence))
38 |
39 | curExtraction.addArg(arg1)
40 | curExtraction.addArg(arg2)
41 | d[text] = d.get(text, []) + [curExtraction]
42 | self.oie = d
43 | # self.normalizeConfidence()
44 |
45 |         # # remove extractions below the confidence threshold
46 | # if type(self.threshold) != type(None):
47 | # new_d = {}
48 | # for sent in self.oie:
49 | # for extraction in self.oie[sent]:
50 | # if extraction.confidence < self.threshold:
51 | # continue
52 | # else:
53 | # new_d[sent] = new_d.get(sent, []) + [extraction]
54 | # self.oie = new_d
55 |
56 |
57 |
58 | def normalizeConfidence(self):
59 | ''' Normalize confidence to resemble probabilities '''
60 | EPSILON = 1e-3
61 |
62 | confidences = [extraction.confidence for sent in self.oie for extraction in self.oie[sent]]
63 | maxConfidence = max(confidences)
64 | minConfidence = min(confidences)
65 |
66 | denom = maxConfidence - minConfidence + (2*EPSILON)
67 |
68 | for sent, extractions in self.oie.items():
69 | for extraction in extractions:
70 | extraction.confidence = ( (extraction.confidence - minConfidence) + EPSILON) / denom
71 |
72 |
73 |
74 | if __name__ == "__main__":
75 | # Parse command line arguments
76 | args = docopt(__doc__)
77 | inp_fn = args["--in"]
78 | out_fn = args["--out"]
79 | debug = args["--debug"]
80 | if debug:
81 | logging.basicConfig(level = logging.DEBUG)
82 | else:
83 | logging.basicConfig(level = logging.INFO)
84 |
85 |
86 | oie = ClausieReader()
87 | oie.read(inp_fn)
88 | oie.output_tabbed(out_fn)
89 |
90 | logging.info("DONE")
91 |
--------------------------------------------------------------------------------
/scripts/tasks/ROBUST/src/utils/CaRB/oie_readers/goldReader.py:
--------------------------------------------------------------------------------
1 | from oie_readers.oieReader import OieReader
2 | from oie_readers.extraction import Extraction
3 | from _collections import defaultdict
4 | import ipdb
5 |
6 | class GoldReader(OieReader):
7 |
8 | # Path relative to repo root folder
9 | default_filename = './oie_corpus/all.oie'
10 |
11 | def __init__(self):
12 | self.name = 'Gold'
13 |
14 | def read(self, fn):
15 | d = defaultdict(lambda: [])
16 | with open(fn) as fin:
17 | for line_ind, line in enumerate(fin):
18 | # print line
19 | data = line.strip().split('\t')
20 | text, rel = data[:2]
21 | args = data[2:]
22 | confidence = 1
23 |
24 | curExtraction = Extraction(pred = rel.strip(),
25 | head_pred_index = None,
26 | sent = text.strip(),
27 | confidence = float(confidence),
28 | index = line_ind)
29 | for arg in args:
30 | if "C: " in arg:
31 | continue
32 | curExtraction.addArg(arg.strip())
33 |
34 | d[text.strip()].append(curExtraction)
35 | self.oie = d
36 |
37 |
38 | if __name__ == '__main__':
39 |     g = GoldReader()
40 |     g.read('../oie_corpus/all.oie')  # read() takes only the file name
41 |     d = g.oie
42 |     e = list(d.items())[0]  # dict views are not indexable in Python 3
43 |     print(e[1][0].bow())
44 |     print(g.count())
45 |
--------------------------------------------------------------------------------
/scripts/tasks/ROBUST/src/utils/CaRB/oie_readers/oieReader.py:
--------------------------------------------------------------------------------
1 | class OieReader:
2 |
3 | def read(self, fn, includeNominal):
4 | ''' should set oie as a class member
5 | as a dictionary of extractions by sentence'''
6 | raise Exception("Don't run me")
7 |
8 | def count(self):
9 | ''' number of extractions '''
10 | return sum([len(extractions) for _, extractions in self.oie.items()])
11 |
12 | def split_to_corpus(self, corpus_fn, out_fn):
13 | """
14 | Given a corpus file name, containing a list of sentences
15 | print only the extractions pertaining to it to out_fn in a tab separated format:
16 | sent, prob, pred, arg1, arg2, ...
17 | """
18 | raw_sents = [line.strip() for line in open(corpus_fn)]
19 | with open(out_fn, 'w') as fout:
20 | for line in self.get_tabbed().split('\n'):
21 | data = line.split('\t')
22 | sent = data[0]
23 | if sent in raw_sents:
24 | fout.write(line + '\n')
25 |
26 | def output_tabbed(self, out_fn):
27 | """
28 |         Write a tabbed representation of this corpus.
29 | """
30 | with open(out_fn, 'w') as fout:
31 | fout.write(self.get_tabbed())
32 |
33 | def get_tabbed(self):
34 | """
35 | Get a tabbed format representation of this corpus (assumes that input was
36 | already read).
37 | """
38 |         return "\n".join(['\t'.join(map(str,
39 |                                          [ex.sent,
40 |                                           ex.confidence,
41 |                                           ex.pred,
42 |                                           '\t'.join(ex.args)]))
43 |                           for (sent, exs) in self.oie.items()
44 |                           for ex in exs])
45 |
46 |
--------------------------------------------------------------------------------
/scripts/tasks/ROBUST/src/utils/CaRB/oie_readers/ollieReader.py:
--------------------------------------------------------------------------------
1 | from oie_readers.oieReader import OieReader
2 | from oie_readers.extraction import Extraction
3 |
4 | class OllieReader(OieReader):
5 |
6 | def __init__(self):
7 | self.name = 'OLLIE'
8 |
9 | def read(self, fn):
10 | d = {}
11 | with open(fn) as fin:
12 | fin.readline() #remove header
13 | for line in fin:
14 | data = line.strip().split('\t')
15 | confidence, arg1, rel, arg2, enabler, attribution, text = data[:7]
16 | curExtraction = Extraction(pred = rel, head_pred_index = -1, sent = text, confidence = float(confidence))
17 | curExtraction.addArg(arg1)
18 | curExtraction.addArg(arg2)
19 | d[text] = d.get(text, []) + [curExtraction]
20 | self.oie = d
21 |
22 |
23 |
--------------------------------------------------------------------------------
/scripts/tasks/ROBUST/src/utils/CaRB/oie_readers/openieFiveReader.py:
--------------------------------------------------------------------------------
1 | from oie_readers.oieReader import OieReader
2 | from oie_readers.extraction import Extraction
3 |
4 | class OpenieFiveReader(OieReader):
5 |
6 | def __init__(self):
7 | self.name = 'OpenIE-5'
8 |
9 | def read(self, fn):
10 | d = {}
11 | with open(fn) as fin:
12 | for line in fin:
13 | data = line.strip().split('\t')
14 | confidence = data[0]
15 |
16 | if not all(data[2:5]):
17 | continue
18 | arg1, rel = [s[s.index('(') + 1:s.index(',List(')] for s in data[2:4]]
19 | #args = data[4].strip().split(');')
20 | #print arg2s
21 | args = [s[s.index('(') + 1:s.index(',List(')] for s in data[4].strip().split(');')]
22 | # if arg1 == "the younger La Flesche":
23 | # print len(args)
24 | text = data[5]
25 | if data[1]:
26 | #print arg1, rel
27 | s = data[1]
28 | if not (arg1 + ' ' + rel).startswith(s[s.index('(') + 1:s.index(',List(')]):
29 | #print "##########Not adding context"
30 | arg1 = s[s.index('(') + 1:s.index(',List(')] + ' ' + arg1
31 | #print arg1 + rel, ",,,,, ", s[s.index('(') + 1:s.index(',List(')]
32 | #curExtraction = Extraction(pred = rel, sent = text, confidence = float(confidence))
33 | curExtraction = Extraction(pred = rel, head_pred_index = -1, sent = text, confidence = float(confidence))
34 | curExtraction.addArg(arg1)
35 | for arg in args:
36 | curExtraction.addArg(arg)
37 | d[text] = d.get(text, []) + [curExtraction]
38 | self.oie = d
39 |
--------------------------------------------------------------------------------
/scripts/tasks/ROBUST/src/utils/CaRB/oie_readers/openieFourReader.py:
--------------------------------------------------------------------------------
1 | """ Usage:
2 | --in=INPUT_FILE --out=OUTPUT_FILE [--debug]
3 |
4 | Convert to tabbed format
5 | """
6 | # External imports
7 | import logging
8 | from pprint import pprint
9 | from pprint import pformat
10 | from docopt import docopt
11 |
12 | # Local imports
13 | from oie_readers.oieReader import OieReader
14 | from oie_readers.extraction import Extraction
15 | import ipdb
16 |
17 | #=-----
18 |
19 | class OpenieFourReader(OieReader):
20 |
21 | def __init__(self):
22 | self.name = 'OpenIE-4'
23 |
24 | def read(self, fn):
25 | d = {}
26 | with open(fn) as fin:
27 | for line in fin:
28 | data = line.strip().split('\t')
29 | confidence = data[0]
30 | if not all(data[2:5]):
31 | logging.debug("Skipped line: {}".format(line))
32 | continue
33 | arg1, rel, arg2 = [s[s.index('(') + 1:s.index(',List(')] for s in data[2:5]]
34 | text = data[5]
35 | curExtraction = Extraction(pred = rel, head_pred_index = -1, sent = text, confidence = float(confidence))
36 | curExtraction.addArg(arg1)
37 | curExtraction.addArg(arg2)
38 | d[text] = d.get(text, []) + [curExtraction]
39 | self.oie = d
40 |
41 |
42 |
43 | if __name__ == "__main__":
44 | # Parse command line arguments
45 | args = docopt(__doc__)
46 | inp_fn = args["--in"]
47 | out_fn = args["--out"]
48 | debug = args["--debug"]
49 | if debug:
50 | logging.basicConfig(level = logging.DEBUG)
51 | else:
52 | logging.basicConfig(level = logging.INFO)
53 |
54 |
55 | oie = OpenieFourReader()
56 | oie.read(inp_fn)
57 | oie.output_tabbed(out_fn)
58 |
59 | logging.info("DONE")
60 |
--------------------------------------------------------------------------------
/scripts/tasks/ROBUST/src/utils/CaRB/oie_readers/propsReader.py:
--------------------------------------------------------------------------------
1 | from oie_readers.oieReader import OieReader
2 | from oie_readers.extraction import Extraction
3 |
4 |
5 | class PropSReader(OieReader):
6 |
7 | def __init__(self):
8 | self.name = 'PropS'
9 |
10 | def read(self, fn):
11 | d = {}
12 | with open(fn) as fin:
13 | for line in fin:
14 | if not line.strip():
15 | continue
16 | data = line.strip().split('\t')
17 | confidence, text, rel = data[:3]
18 | curExtraction = Extraction(pred = rel, sent = text, confidence = float(confidence), head_pred_index=-1)
  19 |                 # arguments sit in every second field from index 4 onward; the fields between them are skipped
  20 |                 for arg in data[4::2]:
21 | curExtraction.addArg(arg)
22 |
23 | d[text] = d.get(text, []) + [curExtraction]
24 | self.oie = d
25 | # self.normalizeConfidence()
26 |
27 |
28 | def normalizeConfidence(self):
29 | ''' Normalize confidence to resemble probabilities '''
30 | EPSILON = 1e-3
31 |
32 | self.confidences = [extraction.confidence for sent in self.oie for extraction in self.oie[sent]]
33 | maxConfidence = max(self.confidences)
34 | minConfidence = min(self.confidences)
35 |
36 | denom = maxConfidence - minConfidence + (2*EPSILON)
37 |
38 | for sent, extractions in self.oie.items():
39 | for extraction in extractions:
40 | extraction.confidence = ( (extraction.confidence - minConfidence) + EPSILON) / denom
41 |
42 |
43 |
44 |
45 |
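As a quick check on the arithmetic in normalizeConfidence above, here is the same min-max-plus-epsilon scaling applied to two toy confidence values (numbers invented, not real PropS scores):

```python
EPSILON = 1e-3
confidences = [0.2, 0.8]                 # toy scores for illustration
minC, maxC = min(confidences), max(confidences)
denom = maxC - minC + (2 * EPSILON)      # 0.602
print([((c - minC) + EPSILON) / denom for c in confidences])
# [0.00166..., 0.99833...] -- strictly inside (0, 1)
```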
--------------------------------------------------------------------------------
/scripts/tasks/ROBUST/src/utils/CaRB/oie_readers/reVerbReader.py:
--------------------------------------------------------------------------------
1 | from oie_readers.oieReader import OieReader
2 | from oie_readers.extraction import Extraction
3 |
4 | class ReVerbReader(OieReader):
5 |
6 | def __init__(self):
7 | self.inputSents = [sent.strip() for sent in open(ReVerbReader.RAW_SENTS_FILE).readlines()]
8 | self.name = 'ReVerb'
9 |
10 | def read(self, fn):
11 | d = {}
12 | with open(fn) as fin:
13 | for line in fin:
14 | data = line.strip().split('\t')
15 | arg1, rel, arg2 = data[2:5]
16 | confidence = data[11]
17 | text = self.inputSents[int(data[1])-1]
18 |
19 | curExtraction = Extraction(pred = rel, sent = text, confidence = float(confidence))
20 | curExtraction.addArg(arg1)
21 | curExtraction.addArg(arg2)
22 | d[text] = d.get(text, []) + [curExtraction]
23 | self.oie = d
24 |
  25 |     # ReVerb requires a separate file from which to read the input sentences
26 | # Relative to repo root folder
27 | RAW_SENTS_FILE = './raw_sentences/all.txt'
28 |
29 |
30 |
--------------------------------------------------------------------------------
/scripts/tasks/ROBUST/src/utils/CaRB/oie_readers/split_corpus.py:
--------------------------------------------------------------------------------
1 | """ Usage:
2 | split_corpus --corpus=CORPUS_FN --reader=READER --in=INPUT_FN --out=OUTPUT_FN
3 |
4 | Split OIE extractions according to raw sentences.
5 | This is used in order to split a large file into train, dev and test.
6 |
   7 | READER - selects which OIE reader to use (see available_readers below for valid options)
8 | """
9 | from clausieReader import ClausieReader
10 | from ollieReader import OllieReader
11 | from openieFourReader import OpenieFourReader
12 | from propsReader import PropSReader
13 | from reVerbReader import ReVerbReader
14 | from stanfordReader import StanfordReader
15 | from docopt import docopt
16 | import logging
17 | logging.basicConfig(level = logging.INFO)
18 |
19 | available_readers = {
20 | "clausie": ClausieReader,
21 | "ollie": OllieReader,
22 | "openie4": OpenieFourReader,
23 | "props": PropSReader,
24 | "reverb": ReVerbReader,
25 | "stanford": StanfordReader
26 | }
27 |
28 |
29 | if __name__ == "__main__":
30 | args = docopt(__doc__)
31 | inp = args["--in"]
32 | out = args["--out"]
33 | corpus = args["--corpus"]
34 | reader = available_readers[args["--reader"]]()
35 | reader.read(inp)
36 | reader.split_to_corpus(corpus,
37 | out)
38 |
--------------------------------------------------------------------------------
/scripts/tasks/ROBUST/src/utils/CaRB/oie_readers/stanfordReader.py:
--------------------------------------------------------------------------------
1 | from oie_readers.oieReader import OieReader
2 | from oie_readers.extraction import Extraction
3 |
4 | class StanfordReader(OieReader):
5 |
6 | def __init__(self):
7 | self.name = 'Stanford'
8 |
9 | def read(self, fn):
10 | d = {}
11 | with open(fn) as fin:
12 | for line in fin:
13 | data = line.strip().split('\t')
14 | arg1, rel, arg2 = data[2:5]
15 | confidence = data[11]
16 | text = data[12]
17 |
18 | curExtraction = Extraction(pred = rel, head_pred_index = -1, sent = text, confidence = float(confidence))
19 | curExtraction.addArg(arg1)
20 | curExtraction.addArg(arg2)
21 | d[text] = d.get(text, []) + [curExtraction]
22 | self.oie = d
23 |
--------------------------------------------------------------------------------
/scripts/tasks/ROBUST/src/utils/CaRB/oie_readers/tabReader.py:
--------------------------------------------------------------------------------
1 | """ Usage:
2 | tabReader --in=INPUT_FILE
3 |
4 | Read a tab-formatted file.
5 | Each line consists of:
6 | sent, prob, pred, arg1, arg2, ...
7 |
8 | """
9 |
10 | from oie_readers.oieReader import OieReader
11 | from oie_readers.extraction import Extraction
12 | from docopt import docopt
13 | import logging
  14 |
15 |
16 | logging.basicConfig(level = logging.DEBUG)
17 |
18 | class TabReader(OieReader):
19 |
20 | def __init__(self):
21 | self.name = 'TabReader'
22 |
23 | def read(self, fn):
24 | """
25 | Read a tabbed format line
26 | Each line consists of:
27 | sent, prob, pred, arg1, arg2, ...
28 | """
29 | d = {}
30 | ex_index = 0
31 | with open(fn) as fin:
32 | for line in fin:
33 | if not line.strip():
34 | continue
35 | data = line.strip().split('\t')
36 | text, confidence, rel = data[:3]
37 | curExtraction = Extraction(pred = rel,
38 | head_pred_index = None,
39 | sent = text,
40 | confidence = float(confidence),
41 | question_dist = "./question_distributions/dist_wh_sbj_obj1.json",
42 | index = ex_index)
43 | ex_index += 1
44 |
45 | for arg in data[3:]:
46 | curExtraction.addArg(arg)
47 |
48 | d[text] = d.get(text, []) + [curExtraction]
49 | self.oie = d
50 |
51 |
52 | if __name__ == "__main__":
53 | args = docopt(__doc__)
54 | input_fn = args["--in"]
55 | tr = TabReader()
56 | tr.read(input_fn)
57 |
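A minimal sketch of the tabbed format described in the docstring above (sent, prob, pred, arg1, arg2, ...); the line is invented:

```python
# One fabricated tabbed line: sent \t prob \t pred \t arg1 \t arg2 ...
line = "John met Mary in Paris\t0.9\tmet\tJohn\tMary\tin Paris"
data = line.strip().split("\t")
text, confidence, rel = data[:3]
print(text, "|", confidence, "|", rel, "|", data[3:])
# John met Mary in Paris | 0.9 | met | ['John', 'Mary', 'in Paris']
```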
--------------------------------------------------------------------------------
/scripts/tasks/ROBUST/src/utils/CaRB/pr_plot.py:
--------------------------------------------------------------------------------
1 | """ Usage:
2 | pr_plot --in=DIR_NAME --out=OUTPUT_FILENAME
3 |
4 | Options:
5 | --in=DIR_NAME Folder in which to search for *.dat files, all of which should be in a P/R column format (outputs from benchmark.py)
6 | --out=OUTPUT_FILENAME Output filename, filetype will determine the format. Possible formats: pdf, pgf, png
7 |
8 |
9 | """
10 |
11 | import os
12 | import ntpath
13 | import numpy as np
14 | from glob import glob
15 | from docopt import docopt
16 | import matplotlib.pyplot as plt
17 | import logging
18 | import ipdb
19 | logging.basicConfig(level = logging.INFO)
20 |
21 | plt.rcParams.update({'font.size': 14})
22 |
23 | def trend_name(path):
24 | ''' return a system trend name from dat file path '''
25 | head, tail = ntpath.split(path)
26 | ret = tail or ntpath.basename(head)
27 | return ret.split('.')[0]
28 |
29 | def get_pr(path):
30 | ''' get PR curve from file '''
31 | with open(path) as fin:
32 | # remove header line
33 | fin.readline()
34 | prc = list(zip(*[[float(x) for x in line.strip().split('\t')] for line in fin]))
35 | p = prc[0]
36 | r = prc[1]
37 | return p, r
38 |
39 | if __name__ == '__main__':
40 | args = docopt(__doc__)
41 | input_folder = args['--in']
42 | output_file = args['--out']
43 |
44 | # plot graphs for all *.dat files in input path
45 | files = glob(os.path.join(input_folder, '*.dat'))
46 | for _file in files:
47 | p, r = get_pr(_file)
48 | name = trend_name(_file)
49 | plt.plot(r, p, label = name)
50 |
51 | # Set figure properties and save
52 | logging.info("Plotting P/R graph to {}".format(output_file))
53 | plt.ylim([0.0, 1.05])
54 | plt.xlim([0.0, 0.8])
55 | plt.xlabel('Recall')
56 | plt.ylabel('Precision')
57 | plt.legend(loc="lower right")
58 | plt.savefig(output_file)
59 |
--------------------------------------------------------------------------------
/scripts/tasks/ROBUST/src/utils/CaRB/requirements.txt:
--------------------------------------------------------------------------------
1 | ipdb
2 | docopt==0.6.2
3 | nltk==3.2.1
4 | numpy==1.11.2
5 | pandas==0.19.0
6 | scikit-learn==0.18
7 | scipy==0.18.1
8 |
--------------------------------------------------------------------------------
/scripts/tasks/ROBUST/src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/scripts/tasks/ROBUST/src/utils/__init__.py
--------------------------------------------------------------------------------
/scripts/tasks/ROBUST/src/utils/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/scripts/tasks/ROBUST/src/utils/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/scripts/tasks/ace2005-eae/__pycache__/desc.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/scripts/tasks/ace2005-eae/__pycache__/desc.cpython-310.pyc
--------------------------------------------------------------------------------
/scripts/tasks/ace2005-eae/desc.py:
--------------------------------------------------------------------------------
1 | JSON_BASE = "Please give the answer in json format."
2 |
3 | OUTPUT_BASE = [
4 | (
5 | 'Please give the answer in the form "[Answer]: {word}: {role}; ".',
6 | "{word}: {type}; ",
7 | ),
8 | (
9 | 'Please give the answer in the form "[Answer]: (word: {word}, role: {type}); ".',
10 | "(word: {word}, role: {type}); ",
11 | ),
12 | (
13 | "Please give the answer in natural language.",
14 | 'the event role "{type}" is "{word}"; ',
15 | ),
16 | (
17 | "Please give the answer in natural language.",
18 | '"{word}" is the role of "{type}"; ',
19 | ),
20 | (
21 | 'Please give the answer in the form "[Answer]: ({word}, {role}); ".',
22 | "({word}, {type}); ",
23 | ),
24 | (
25 | "What is the role of each word in the described event?",
26 | 'Role of "{word}" is "{type}". ',
27 | ),
28 | (
29 | "Identify the roles and words associated with the event.",
30 | '"{word}" plays the role of "{type}". ',
31 | ),
32 | (
33 | "Can you extract the event arguments and their roles?",
34 | '"{word}" is identified as "{type}". ',
35 | ),
36 | (
37 | "Highlight the key elements and their roles within the {event} context.",
38 | "Element: {word}, Role: {type}; ",
39 | ),
40 | (
41 | "Can you dissect the {event} and label each component with its respective function?",
42 | "Component: {word}, Function: {type}; ",
43 | ),
44 | (
45 | "Identify and describe the roles of different elements in the {event}.",
46 | "Element Identified: {word}, Described Role: {type}; ",
47 | ),
48 | (
49 | "Break down the {event} into its essential parts and explain their significance.",
50 | "Essential Part: {word}, Significance: {type}; ",
51 | ),
52 | (
53 | "Identify the key elements and their functions within the given event.",
54 | "Element: {word}, Function: {type}. ",
55 | ),
56 | (
57 | "What are the components and their categories in this event?",
58 | "Component: {word} falls under the category: {type}. ",
59 | ),
60 | (
61 | "Break down the event into its constituents and describe their roles.",
62 | "Constituent: {word} plays the role of: {type}. ",
63 | ),
64 | ]
  65 | # The EXPLAN templates read awkwardly, so output_explain is not used for the EAE part
66 | OUTPUT_EXPLAN = [
67 | 'In the given context, the event "{event}" is associated with the event type "{etype}".',
68 | 'Within the provided context, the event "{event}" is linked to the event type "{etype}".',
69 | 'In the given context, the event "{event}" is connected to the event type "{etype}".',
70 | 'According to the context, the event "{event}" is correlated with the event type "{etype}".',
71 | 'In the provided context, the event "{event}" is aligned with the event type "{etype}".',
72 | 'In the scenario described, the event "{event}" falls under the category of "{etype}".',
73 | 'The narrative provided outlines an event, specifically "{event}", which pertains to the event type "{etype}".',
74 | 'Reflecting on the details, it\'s evident that the event "{event}" is linked to the event type "{etype}".',
75 | 'The analysis reveals that the "{event}" pertains to the "{etype}" category.',
76 | 'Within the context of "{event}", it is evident that it falls under the "{etype}" classification.',
77 | 'Upon dissecting the "{event}", it becomes clear that its nature is best described as "{etype}".',
78 | 'Upon examining the "{event}", it becomes clear that this event is categorized under the "{etype}" type.',
79 | 'Analyzing the "{event}" reveals that it falls under the "{etype}" category, showcasing its diverse elements and their roles.',
80 | 'Delving into the "{event}", it\'s apparent that this event is a manifestation of the "{etype}" type, with each part holding specific significance.',
81 | 'Upon examining the "{event}", it becomes clear that this scenario typifies the "{etype}" category, highlighting the roles and significance of its components.',
82 | 'The detailed analysis of "{event}" reveals its classification as a "{etype}", emphasizing the categorization of its integral parts.',
83 | 'Scrutinizing the event "{event}", it is evident that it exemplifies the "{etype}" genre, delineating the roles of its various constituents.',
84 | ]
85 |
86 | RANDOM_SYMBOLS = [
87 | "A",
88 | "B",
89 | "C",
90 | "D",
91 | "E",
92 | "F",
93 | "G",
94 | "LABEL_1",
95 | "LABEL_2",
96 | "LABEL_3",
97 | "LABEL_4",
98 | "LABEL_5",
99 | "LABEL_6",
100 | ]
101 |
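These (instruction, answer-template) pairs are consumed by the data-construction scripts; as a rough illustration of how a template renders (the word and role below are invented), the second OUTPUT_BASE entry produces:

```python
# Illustrative rendering of one OUTPUT_BASE answer template (values invented)
template = "(word: {word}, role: {type}); "
print("[Answer]: " + template.format(word="convicted", type="Defendant"))
# [Answer]: (word: convicted, role: Defendant);
```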
--------------------------------------------------------------------------------
/scripts/tasks/ace2005-ed/__pycache__/desc.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/scripts/tasks/ace2005-ed/__pycache__/desc.cpython-310.pyc
--------------------------------------------------------------------------------
/scripts/tasks/ace2005-ed/__pycache__/desc.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/scripts/tasks/ace2005-ed/__pycache__/desc.cpython-311.pyc
--------------------------------------------------------------------------------
/scripts/tasks/ace2005-ner/__pycache__/desc.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/scripts/tasks/ace2005-ner/__pycache__/desc.cpython-310.pyc
--------------------------------------------------------------------------------
/scripts/tasks/ace2005-ner/__pycache__/desc.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/scripts/tasks/ace2005-ner/__pycache__/desc.cpython-311.pyc
--------------------------------------------------------------------------------
/scripts/tasks/conll-2003/__pycache__/desc.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/scripts/tasks/conll-2003/__pycache__/desc.cpython-310.pyc
--------------------------------------------------------------------------------
/scripts/tasks/fewnerd/open_evaluate.py:
--------------------------------------------------------------------------------
1 | # ner task evaluate
2 |
3 | from sys import argv
4 | import json
5 | import argparse
6 |
   7 | # turn the output string into a dictionary so gold and predicted answers can be compared
8 | OPTIONS = "art-broadcastprogram, art-film, art-music, art-other, art-painting, art-writtenart, building-airport, building-hospital, building-hotel, building-library, building-other, building-restaurant, building-sportsfacility, building-theater, event-attack/battle/war/militaryconflict, event-disaster, event-election, event-other, event-protest, event-sportsevent, location-bodiesofwater, location-GPE, location-island, location-mountain, location-other, location-park, location-road/railway/highway/transit, O, organization-company, organization-education, organization-government/governmentagency, organization-media/newspaper, organization-other, organization-politicalparty, organization-religion, organization-showorganization, organization-sportsleague, organization-sportsteam, other-astronomything, other-award, other-biologything, other-chemicalthing, other-currency, other-disease, other-educationaldegree, other-god, other-language, other-law, other-livingthing, other-medical, person-actor, person-artist/author, person-athlete, person-director, person-other, person-politician, person-scholar, person-soldier, product-airplane, product-car, product-food, product-game, product-other, product-ship, product-software, product-train, product-weapon"
9 |
  10 | # entity type options, with '-' replaced by '_' to match the answer format
  11 | Env_type_convert = OPTIONS.split(", ")
  12 | Env_type_convert = [vert.replace("-", "_") for vert in Env_type_convert]
13 |
14 |
15 | def turn_into_dict(ans: str):
16 | res_dict = {}
17 | for pair in ans.split(";"):
18 | pair_split = pair.split(":")
19 | if len(pair_split) == 2:
20 | entity, type_ent = pair_split[0].strip().lower(), pair_split[
21 | 1
22 | ].strip().removesuffix(".")
23 | if type_ent in Env_type_convert:
24 | res_dict[entity] = type_ent
25 | return res_dict
26 |
27 |
28 | def safe_division(a, b) -> float:
29 | return a / b if b != 0 else 0
30 |
31 |
32 | def evaluate(input_path: str):
  33 |     # the input file is JSON-lines; it is read line by line below
  34 |     golden, predict, correct = 0, 0, 0
  35 |     # precision/recall/F1 are computed at the entity level
36 |
37 | with open(input_path) as f:
38 | for line in f.readlines():
39 | data = json.loads(line.strip())
40 |
41 | ref_ans = str(data["messages"][-1]["content"]).strip().lower()
42 | mod_ans = ""
43 | if "[Answer]:" in data["output"]:
44 | mod_ans = (
45 | str(data["output"]).strip().split("[Answer]:")[1].strip().lower()
46 | )
47 | else:
48 | mod_ans = data["output"].lower()
49 |
50 | ref_dict, mod_dict = turn_into_dict(ref_ans), turn_into_dict(mod_ans)
51 | golden += len(ref_dict)
52 | predict += len(mod_dict)
53 | for entity in ref_dict:
54 | if entity in mod_dict and ref_dict[entity] == mod_dict[entity]:
55 | correct += 1
56 | precision = safe_division(correct, predict)
57 | recall = safe_division(correct, golden)
58 | f1_score = safe_division(2 * precision * recall, precision + recall)
59 | return {
60 | "golden": golden,
61 | "predict": predict,
62 | "correct": correct,
63 | "precision": precision,
64 | "recall": recall,
65 | "f1_score": f1_score,
66 | }
67 |
68 |
69 | if __name__ == "__main__":
70 | # assert(len(argv) == 3)
  71 |     parser = argparse.ArgumentParser(description="Evaluate Few-NERD NER predictions")
72 | # IO
73 | parser.add_argument(
74 | "--input_dir",
75 | type=str,
76 | default="/SSD_DATA/qyj/open-instruct-main/saves/llama_v2_7B/full/no_output_desc/predict/zeroshot/RichERE-ed.jsonl",
77 | )
78 | parser.add_argument("--output_dir", type=str, default="result.json")
79 |
80 | args = parser.parse_args()
81 |
82 | input_path = args.input_dir
83 | result = evaluate(input_path)
84 | print(result)
85 | # json.dump(result, fp=open(output_path, "w", encoding='utf-8'), indent=4)
86 |
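To make the matching above concrete, here is a toy run of turn_into_dict (defined above) together with the micro precision/recall bookkeeping; all strings are invented, and a type only counts if it appears in Env_type_convert:

```python
ref = turn_into_dict("barack obama: person_politician; hawaii: location_island.")
out = turn_into_dict("barack obama: person_politician; honolulu: location_island.")
golden, predict = len(ref), len(out)                   # 2, 2
correct = sum(1 for e in ref if out.get(e) == ref[e])  # 1 ("barack obama" matches)
p, r = correct / predict, correct / golden
print(p, r, 2 * p * r / (p + r))                       # 0.5 0.5 0.5
```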
--------------------------------------------------------------------------------
/scripts/tasks/fewrel/__pycache__/desc.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/scripts/tasks/fewrel/__pycache__/desc.cpython-310.pyc
--------------------------------------------------------------------------------
/scripts/tasks/matres/open_eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import numpy as np
4 | import re
5 | import random
6 |
7 |
   8 | parser = argparse.ArgumentParser(description="Evaluate MATRES temporal relation predictions")
9 | parser.add_argument(
10 | "--output_file",
11 | type=str,
12 | default="/SSD_DATA/qyj/open-instruct-main/saves/llama_v2_7B/full/no_output_desc/predict/zeroshot/MATRES.jsonl",
13 | ) # fewshot_test_history_3
14 | args = parser.parse_args()
15 |
16 |
  17 | # For the 4 relations: num_ans counts gold labels, num_out counts model outputs, num_cor counts correct outputs
18 | num_ans, num_out, num_cor = [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]
19 | with open(args.output_file) as file:
20 | lines = file.readlines()
21 | rel_class = {"before": 0, "after": 1, "equal": 2, "vague": 3}
22 |
23 |
24 | def parse_triple(text):
  25 |     pattern = re.compile(r"\((.*); (.*); (.*)\)")
26 | # print("Text:",text)
27 | triple = re.findall(pattern, text)
28 | if len(triple) == 0:
29 | return None
30 | # print("text:",text)
31 | # print("Triple:",triple[0])
32 | return triple[0][1]
33 |
34 |
35 | for line in lines:
36 | result = json.loads(line)
37 | label = result["messages"][-1]["content"]
38 | output = result["output"]
39 | if label != "none" and label != "":
40 | num_ans[rel_class[label]] += 1
41 | try:
  42 |         search_obj = re.search(r"\[Answer\](.*?): (.*)", output)
43 | out_label = re.findall(r"([a-zA-Z-]+)", search_obj.group(2))[0]
  44 |     except (AttributeError, IndexError):
45 | continue
46 | la = output.lower()
47 | # if la in rel_class.keys():
48 | # num_out[rel_class[la]] += 1
49 | # if la in label:
50 | # num_cor[rel_class[la]] += 1
51 | # else:
52 | # num_out[random.randint(0, 3)] += 1
53 |
54 | pre_triple = parse_triple(la)
55 | # print("pre_triple:",pre_triple)
  56 |     if pre_triple is not None:
57 | for k in rel_class.keys():
58 | if k in pre_triple:
59 | num_out[rel_class[k]] += 1
60 | if k in label:
61 | num_cor[rel_class[k]] += 1
62 |
63 | print("ANS:", num_ans)
64 | print("OUT:", num_out)
65 | print("COR:", num_cor)
66 | num_ans.append(np.sum(num_ans))
67 | num_out.append(np.sum(num_out))
68 | num_cor.append(np.sum(num_cor))
69 | num_ans = np.array(num_ans)
70 | num_out = np.array(num_out)
71 | num_cor = np.array(num_cor)
72 | p = num_cor / num_out
73 | r = num_cor / num_ans
74 | f1 = 2 * p * r / (p + r)
75 | print(" Before | After | Equal | Vague | Total")
76 | print(f"precision: {p[0]:.4f} | {p[1]:.4f}| {p[2]:.4f}| {p[3]:.4f}| {p[4]:.4f}")
77 | print(f"recall: {r[0]:.4f} | {r[1]:.4f}| {r[2]:.4f}| {r[3]:.4f}| {r[4]:.4f}")
78 | print(f"f1-score: {f1[0]:.4f} | {f1[1]:.4f}| {f1[2]:.4f}| {f1[3]:.4f}| {f1[4]:.4f}")
79 |
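For clarity, parse_triple expects model output shaped like "(event1; relation; event2)" and returns the middle field; a self-contained toy call (strings invented):

```python
import re

def parse_triple(text):
    # same pattern as above: capture "(<e1>; <rel>; <e2>)" and return <rel>
    triple = re.findall(r"\((.*); (.*); (.*)\)", text)
    return triple[0][1] if triple else None

print(parse_triple("(ate breakfast; before; left home)"))  # before
print(parse_triple("no triple here"))                      # None
```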
--------------------------------------------------------------------------------
/scripts/tasks/maven-arg/__pycache__/desc.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/scripts/tasks/maven-arg/__pycache__/desc.cpython-310.pyc
--------------------------------------------------------------------------------
/scripts/tasks/maven-arg/desc.py:
--------------------------------------------------------------------------------
1 | JSON_BASE = "Please give the answer in json format."
2 |
3 | OUTPUT_BASE = [
4 | (
5 | 'Please give the answer in the form "[Answer]: {content}: {role}; ".',
6 | "{word}: {type}; ",
7 | ),
8 | (
9 | 'Please give the answer in the form "[Answer]: (content: {word}, role: {type}); ".',
10 | "(content: {word}, role: {type}); ",
11 | ),
12 | (
13 | "Please give the answer in natural language.",
14 | 'the event role "{type}" is "{word}"; ',
15 | ),
16 | (
17 | "Please give the answer in natural language.",
18 | '"{word}" is the role of "{type}"; ',
19 | ),
20 | (
21 | 'Please give the answer in the form "[Answer]: ({content}, {role}); ".',
22 | "({word}, {type}); ",
23 | ),
24 | (
25 | "What is the role of each word in the described event?",
26 | 'Role of "{word}" is "{type}". ',
27 | ),
28 | (
29 | "Identify the roles and words associated with the event.",
30 | '"{word}" plays the role of "{type}". ',
31 | ),
32 | (
33 | "Can you extract the event arguments and their roles?",
34 | '"{word}" is identified as "{type}". ',
35 | ),
36 | (
37 | "Highlight the key elements and their roles within the {event} context.",
38 | "Element: {word}, Role: {type}; ",
39 | ),
40 | (
41 | "Can you dissect the {event} and label each component with its respective function?",
42 | "Component: {word}, Function: {type}; ",
43 | ),
44 | (
45 | "Identify and describe the roles of different elements in the {event}.",
46 | "Element Identified: {word}, Described Role: {type}; ",
47 | ),
48 | (
49 | "Break down the {event} into its essential parts and explain their significance.",
50 | "Essential Part: {word}, Significance: {type}; ",
51 | ),
52 | (
53 | "Identify the key elements and their functions within the given event.",
54 | "Element: {word}, Function: {type}. ",
55 | ),
56 | (
57 | "What are the components and their categories in this event?",
58 | "Component: {word} falls under the category: {type}. ",
59 | ),
60 | (
61 | "Break down the event into its constituents and describe their roles.",
62 | "Constituent: {word} plays the role of: {type}. ",
63 | ),
64 | ]
  65 | # The EXPLAN templates read awkwardly, so output_explain is not used for the EAE part
66 | OUTPUT_EXPLAN = [
67 | 'In the given context, the event "{event}" is associated with the event type "{etype}".',
68 | 'Within the provided context, the event "{event}" is linked to the event type "{etype}".',
69 | 'In the given context, the event "{event}" is connected to the event type "{etype}".',
70 | 'According to the context, the event "{event}" is correlated with the event type "{etype}".',
71 | 'In the provided context, the event "{event}" is aligned with the event type "{etype}".',
72 | 'In the scenario described, the event "{event}" falls under the category of "{etype}".',
73 | 'The narrative provided outlines an event, specifically "{event}", which pertains to the event type "{etype}".',
74 | 'Reflecting on the details, it\'s evident that the event "{event}" is linked to the event type "{etype}".',
75 | 'The analysis reveals that the "{event}" pertains to the "{etype}" category.',
76 | 'Within the context of "{event}", it is evident that it falls under the "{etype}" classification.',
77 | 'Upon dissecting the "{event}", it becomes clear that its nature is best described as "{etype}".',
78 | 'Upon examining the "{event}", it becomes clear that this event is categorized under the "{etype}" type.',
79 | 'Analyzing the "{event}" reveals that it falls under the "{etype}" category, showcasing its diverse elements and their roles.',
80 | 'Delving into the "{event}", it\'s apparent that this event is a manifestation of the "{etype}" type, with each part holding specific significance.',
81 | 'Upon examining the "{event}", it becomes clear that this scenario typifies the "{etype}" category, highlighting the roles and significance of its components.',
82 | 'The detailed analysis of "{event}" reveals its classification as a "{etype}", emphasizing the categorization of its integral parts.',
83 | 'Scrutinizing the event "{event}", it is evident that it exemplifies the "{etype}" genre, delineating the roles of its various constituents.',
84 | ]
85 |
86 | RANDOM_SYMBOLS = [
87 | "A",
88 | "B",
89 | "C",
90 | "D",
91 | "E",
92 | "F",
93 | "G",
94 | "LABEL_1",
95 | "LABEL_2",
96 | "LABEL_3",
97 | "LABEL_4",
98 | "LABEL_5",
99 | "LABEL_6",
100 | ]
101 |
--------------------------------------------------------------------------------
/scripts/tasks/maven-ed/__pycache__/desc.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/scripts/tasks/maven-ed/__pycache__/desc.cpython-310.pyc
--------------------------------------------------------------------------------
/scripts/tasks/maven-ed/desc.py:
--------------------------------------------------------------------------------
1 | JSON_BASE = "Please give the answer in json format."
2 |
3 | OUTPUT_BASE = [
4 | (
5 | 'Please give the answer in the form "[Answer]: {event}: {class}; ".',
6 | "{word}: {type}; ",
7 | ),
8 | (
9 | 'Please give the answer in the form "[Answer]: ({event}, {class}); ".',
10 | "({word}, {type}); ",
11 | ),
12 | (
13 | 'Please give the answer in the form "[Answer]: (event trigger: {event}, class: {class}); ".',
14 | "(event trigger: {word}, class: {type}); ",
15 | ),
16 | (
17 | "Please give the answer in natural language.",
18 | '"{word}" is linked to the "{type}" event. ',
19 | ),
20 | (
21 | "Please give the answer in natural language.",
22 | '"{word}" triggers an event identified as "{type}". ',
23 | ),
24 | (
25 | "Identify the event and its type from the text.",
26 | 'The event "{word}" falls under the category "{type}".',
27 | ),
28 | (
29 | "What event is described, and how can it be classified?",
30 | 'Event: "{word}", Type: "{type}".',
31 | ),
32 | (
33 | "Extract the key event and its corresponding type from the provided text.",
34 | '"{word}" is an event that is categorized as "{type}".',
35 | ),
36 | (
37 | "From the given text, identify the event trigger and its type.",
38 | 'Trigger: "{word}", Classified as: "{type}".',
39 | ),
40 | (
41 | "Identify the event and its category from the given text.",
42 | 'Event identified: "{word}", Category: "{type}".',
43 | ),
44 | (
45 | "Describe the event and determine its classification.",
46 | 'Described Event: "{word}", Classification: "{type}".',
47 | ),
48 | (
49 | "What type of event does this scenario depict?",
50 | 'Depicted Event: "{word}", Event Type: "{type}".',
51 | ),
52 | (
53 | "Identify the event and its category from the description.",
54 | 'Identified Event: "{word}", Category: "{type}".',
55 | ),
56 | (
57 | "Highlight the key event and its corresponding classification from the text.",
58 | 'Key Event: "{word}", Classification: "{type}".',
59 | ),
60 | (
61 | "From the given narrative, extract the event and its type.",
62 | 'Extracted Event: "{word}", Type: "{type}".',
63 | ),
64 | ]
  65 | # (prefix, to_type, summary): the prompt is assembled as prefix + (possibly multiple) to_type segments + summary + "[Answer]: " + an OUTPUT_BASE answer
  66 | # available placeholders: {word}, {btype}, {stype}, {type}
67 | OUTPUT_EXPLAN = [
68 | (
69 | "Based on the given predefined event type and text: ",
70 | '"{word}" is an event trigger word, which triggers an event of type "{type}". ',
71 | "To sum up, ",
72 | ),
73 | (
74 | "In consideration of the provided predefined event type and text, ",
75 | '"{word}" is specifically linked to the category "{type}". ',
76 | "In brief, ",
77 | ),
78 | (
79 | "Given the predefined event type and text, ",
80 | '"{word}" triggers an event classified as "{type}". ',
81 | "To summarize, ",
82 | ),
83 | (
84 | "According to the provided event type and text, ",
85 | '"{word}" serves as an event trigger word, instigating an event classified under "{type}". ',
86 | "In conclusion, ",
87 | ),
88 | (
89 | "Based on the given predefined event type and text, ",
90 | '"{word}" operates as an event trigger word, initiating an event categorized as "{type}". ',
91 | "Hence, ",
92 | ),
93 | [
94 | "Upon analyzing the context, ",
95 | '"{word}" is pinpointed as the catalyst for an event, which is best described by the category "{type}". ',
96 | "In conclusion, ",
97 | ],
98 | [
99 | "After a thorough examination of the context, ",
100 | 'it\'s evident that "{word}" serves as the trigger for an event, which can be classified under "{type}". ',
101 | "Therefore, ",
102 | ],
103 | [
104 | "Following a detailed analysis, ",
105 | 'the term "{word}" emerges as a significant event trigger, falling into the "{type}" category. ',
106 | "As a result, ",
107 | ],
108 | [
109 | "Upon close inspection of the text, ",
110 | '"{word}" is identified as the trigger for an event, which is classified under the type "{type}". ',
111 | "Summarily, ",
112 | ],
113 | [
114 | "Upon examining the context, ",
115 | 'it becomes clear that "{word}" is identified as the pivotal event, which is best categorized under "{type}". ',
116 | "This leads to the conclusion that ",
117 | ],
118 | [
119 | "After a thorough review of the context, ",
120 | 'the descriptor "{word}" is pinpointed as an event, which is aptly classified under "{type}". ',
121 | "This analysis brings us to understand that ",
122 | ],
123 | [
124 | "Through the lens of the scenario provided, ",
125 | '"{word}" is highlighted as the event in question, with its type being "{type}". ',
126 | "This delineation makes it clear that ",
127 | ],
128 | [
129 | "Upon analyzing the description, ",
130 | '"{word}" emerges as the pivotal event, falling under the category of "{type}". ',
131 | "Therefore, ",
132 | ],
133 | [
134 | "In dissecting the text, ",
135 | '"{word}" is pinpointed as the key event, which is classified under "{type}". ',
136 | "In essence, ",
137 | ],
138 | [
139 | "Delving into the narrative provided, ",
140 | '"{word}" is extracted as the significant event, with its type being "{type}". ',
141 | "Conclusively, ",
142 | ],
143 | ]
144 |
145 | RANDOM_SYMBOLS = [
146 | "A",
147 | "B",
148 | "C",
149 | "D",
150 | "E",
151 | "F",
152 | "G",
153 | "LABEL_1",
154 | "LABEL_2",
155 | "LABEL_3",
156 | "LABEL_4",
157 | "LABEL_5",
158 | "LABEL_6",
159 | ]
160 |
--------------------------------------------------------------------------------
/scripts/tasks/maven-ere/__pycache__/desc.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/scripts/tasks/maven-ere/__pycache__/desc.cpython-310.pyc
--------------------------------------------------------------------------------
/scripts/tasks/ondemandie/evaluation/rougel_for_content.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import json
3 | import evaluate
4 | from tqdm import tqdm
5 |
6 | def load_json(path):
7 | with open(path) as f:
8 | data = json.loads(f.read())
9 | return data
10 |
11 | def extract_markdown(s, output=False):
12 | first_index = s.find("|") # the index for the first |
13 | # last_index = s.rfind("|") # the index for the last |
14 | last_index = len(s) - 1
15 | if first_index != -1 and last_index != -1:
16 | table = s[first_index: last_index + 1]
17 | else:
18 | return ""
19 | # post processing
20 | table = table.replace("| |", "| N/A |")
21 | table = table.replace("| |", "| N/A |")
22 | table = table.replace("| |", "| N/A |")
23 | table = table.replace("|-|", "| N/A |")
24 | table = table.replace("| - |", "| N/A |")
25 | table = table.replace("| Not specified |", "| N/A |")
26 | table = table.replace("| not specified |", "| N/A |")
27 | table = table.replace("| Not Specified |", "| N/A |")
28 | table = table.replace("| None |", "| N/A |")
29 | table = table.replace("| none |", "| N/A |")
30 |
31 | return table
32 |
33 |
34 | def get_score(pred, gold, metric):
35 | return 100 * metric.compute(predictions=pred,references=gold)['rougeLsum']
36 |
37 | def group_by_tag(data, score, tag):
38 | dic, cnt = {}, {}
39 | for i in range(len(data)):
40 | tag_value = data[i][tag]
41 | if tag_value not in dic:
42 | dic[tag_value] = score[i]
43 | cnt[tag_value] = 1
44 | else:
45 | dic[tag_value] += score[i]
46 | cnt[tag_value] += 1
47 | print('\n' + tag)
48 | for key in dic:
49 | print(f"{key}: {dic[key] / cnt[key]} ({cnt[key]})")
50 |
51 | def main(path):
52 | # if len(sys.argv) < 2:
53 | # path = 'model_output/ODIE-7b-filter.json'
54 | # else:
55 | # path = sys.argv[1]
56 | print(f"model output file: {path}")
57 |
58 | data = load_json(path)
59 | metric = evaluate.load('rouge')
60 | score = []
61 |
62 | for i in tqdm(range(len(data))):
63 | pred = [extract_markdown(data[i]['output'], output=True)]
64 | gold = [extract_markdown(data[i]['gold'])]
65 | cur_score = get_score(pred, gold, metric)
66 | score.append(cur_score)
67 |
68 | print(f'Overall:\n{sum(score) / len(score)}')
69 | group_by_tag(data, score, "difficulty")
70 | group_by_tag(data, score, "category")
71 | group_by_tag(data, score, "source_type")
72 |
73 | if __name__ == "__main__":
74 | main("test_data.json")
75 |
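A usage sketch for extract_markdown above, on an invented model output; only the markdown table is kept and passed on to the ROUGE-L scoring:

```python
s = "Sure, here is the table:\n| Name | Year |\n| --- | --- |\n| Alice | 1999 |"
print(extract_markdown(s))
# | Name | Year |
# | --- | --- |
# | Alice | 1999 |
```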
--------------------------------------------------------------------------------
/scripts/tasks/ondemandie/evaluation/sim_for_header.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import json
3 | import torch
4 | # import evaluate
5 | from tqdm import tqdm
6 | from sentence_transformers import SentenceTransformer, util
7 |
8 | def load_json(path):
9 | with open(path) as f:
10 | data = json.loads(f.read())
11 | return data
12 |
13 | def extract_header(s):
14 | first_index = s.find("|") # the index for the first |
15 | last_index = s.rfind("|") # the index for the last |
16 |
17 | if first_index != -1 and last_index != -1:
18 | table = s[first_index: last_index + 1]
19 | else:
20 | return ""
21 | # post processing
22 | table = table.replace("| |", "| N/A |")
23 | table = table.replace("| |", "| N/A |")
24 | table = table.replace("| |", "| N/A |")
25 | table = table.replace("|-|", "| N/A |")
26 | table = table.replace("| - |", "| N/A |")
27 | table = table.replace("| not specified |", "| N/A |")
28 | table = table.replace("| none |", "| N/A |")
29 |
30 | # get header
31 | header = table.split('\n')[0]
32 | header = header.strip('|').strip()
33 | header = header.split(' | ')
34 |
35 | return header
36 |
37 | def soft_match(query, value, sim_model):
  38 |     device = torch.device("cuda:0") # run the similarity model on the first GPU
  39 |     sim_model = sim_model.to(device) # move model to the GPU
  40 |     embedding_1 = sim_model.encode(query, convert_to_tensor=True, device=device) # encode on the GPU
  41 |     embedding_2 = sim_model.encode(value, convert_to_tensor=True, device=device) # encode on the GPU
42 | sim_matrix = util.pytorch_cos_sim(embedding_1, embedding_2)
43 | return sim_matrix
44 |
45 | def header_soft_score(pred_list, gold_list):
46 | device = torch.device("cpu")
47 | gold_n, pred_n, pred_in_gold_n, gold_in_pred_n = 0, 0, 0, 0
48 | sim_model = SentenceTransformer('/home/qyj/.cache/huggingface/hub/models--sentence-transformers--all-MiniLM-L6-v2/snapshots/44eb4044493a3c34bc6d7faae1a71ec76665ebc6')
49 | sim_model = sim_model.to(device)
50 | for gold, pred in zip(gold_list, pred_list):
51 | gold_n += len(gold)
52 | pred_n += len(pred)
53 | scores = soft_match(gold, pred, sim_model)
54 | max_gold_score = torch.max(scores, dim=0).values
55 | pred_in_gold_n += torch.sum(max_gold_score)
56 | max_pred_score = torch.max(scores, dim=1).values
57 | gold_in_pred_n += torch.sum(max_pred_score)
58 | try:
59 | pre, rec, f1 = 0, 0, 0
60 | pre = 100.0 * pred_in_gold_n / pred_n
61 | rec = 100.0 * gold_in_pred_n / gold_n
62 | f1 = 2 * pre * rec / (pre + rec)
  63 |     except ZeroDivisionError:
64 | pre = rec = f1 = 0
65 |
66 | return pre, rec, f1
67 |
68 | def calculate_by_tag(data, tag, metrics):
69 | dic = {}
70 | for i in range(len(data)):
71 | cur_tag = data[i][tag]
72 | if cur_tag not in dic:
73 | dic[cur_tag] = 1
74 | print(tag)
75 | for key in dic:
76 | pred_list, gold_list = [], []
77 | for i in range(len(data)):
78 | if data[i][tag] == key:
79 | pred_header = extract_header(data[i]['output'].lower())
80 | gold_header = extract_header(data[i]['gold'].lower())
81 | pred_list.append(pred_header)
82 | gold_list.append(gold_header)
83 | P, R, F = metrics(pred_list, gold_list)
84 | print(f"{key}: {F}")
85 | print()
86 |
87 | def main(path):
88 | # if len(sys.argv) < 2:
89 | # path = 'model_output/ODIE-7b-filter.json'
90 | # else:
91 | # path = sys.argv[1]
92 | print(f"model output file: {path}")
93 |
94 | metrics = header_soft_score
95 |
96 | data = load_json(path)
97 |
98 | pred_list, gold_list = [], []
99 |
100 | for i in tqdm(range(len(data))):
101 | pred_header = extract_header(data[i]['output'].lower())
102 | gold_header = extract_header(data[i]['gold'].lower())
103 |
104 | pred_list.append(pred_header)
105 | gold_list.append(gold_header)
106 |
107 |
108 | P, R, F = metrics(pred_list, gold_list)
109 | print(f"Overall\n{F}")
110 |
111 | calculate_by_tag(data, 'source_type', metrics)
112 | calculate_by_tag(data, 'category', metrics)
113 | calculate_by_tag(data, 'difficulty', metrics)
114 |
115 | if __name__ == "__main__":
116 | main('test_data.json')
117 |
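A usage sketch for extract_header above (table invented); only the first row of the table is kept and split into column names:

```python
s = "| name | role |\n| alice | ceo |"
print(extract_header(s))  # ['name', 'role']
```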
--------------------------------------------------------------------------------
/scripts/tasks/ondemandie/exact_match/README.md:
--------------------------------------------------------------------------------
1 | ---
2 | title: Exact Match
3 | emoji: 🤗
4 | colorFrom: blue
5 | colorTo: green
6 | sdk: gradio
7 | sdk_version: 3.0.2
8 | app_file: app.py
9 | pinned: false
10 | tags:
11 | - evaluate
12 | - comparison
13 | description: >-
14 | Returns the rate at which the predictions of one model exactly match those of another model.
15 | ---
16 |
17 |
18 | # Comparison Card for Exact Match
19 |
20 | ## Comparison description
21 |
  22 | Given two sets of model predictions, the exact match score is 1 for an example if the two predictions are identical and 0 otherwise. The overall exact match score is the average over all examples.
  23 |
  24 | - **Example 1**: The exact match score is 1.0 if prediction 1 is [0, 1] and prediction 2 is [0, 1].
  25 | - **Example 2**: The exact match score is 0.0 if prediction 1 is [0, 1] and prediction 2 is [1, 0].
  26 | - **Example 3**: The exact match score is 0.5 if prediction 1 is [0, 1] and prediction 2 is [1, 1].
27 |
28 | ## How to use
29 |
  30 | At minimum, this comparison takes two lists of predictions as input:
31 | ```python
32 | >>> exact_match = evaluate.load("exact_match", module_type="comparison")
33 | >>> results = exact_match.compute(predictions1=[0, 1, 1], predictions2=[1, 1, 1])
34 | >>> print(results)
  35 | {'exact_match': 0.6666666666666666}
36 | ```
37 |
38 | ## Output values
39 |
40 | Returns a float between 0.0 and 1.0 inclusive.
41 |
42 | ## Examples
43 |
44 | ```python
45 | >>> exact_match = evaluate.load("exact_match", module_type="comparison")
46 | >>> results = exact_match.compute(predictions1=[0, 0, 0], predictions2=[1, 1, 1])
47 | >>> print(results)
  48 | {'exact_match': 0.0}
49 | ```
50 |
51 | ```python
52 | >>> exact_match = evaluate.load("exact_match", module_type="comparison")
53 | >>> results = exact_match.compute(predictions1=[0, 1, 1], predictions2=[1, 1, 1])
54 | >>> print(results)
  55 | {'exact_match': 0.6666666666666666}
56 | ```
57 |
58 |
59 | ## Limitations and bias
60 |
61 | ## Citations
62 |
--------------------------------------------------------------------------------
/scripts/tasks/ondemandie/exact_match/app.py:
--------------------------------------------------------------------------------
1 | import evaluate
2 | from evaluate.utils import launch_gradio_widget
3 |
4 |
5 | module = evaluate.load("exact_match", module_type="comparison")
6 | launch_gradio_widget(module)
7 |
--------------------------------------------------------------------------------
/scripts/tasks/ondemandie/exact_match/exact_match.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The HuggingFace Evaluate Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Exact match test for model comparison."""
15 |
16 | import datasets
17 | import numpy as np
18 |
19 | import evaluate
20 |
21 |
22 | _DESCRIPTION = """
23 | Returns the rate at which the predictions of one model exactly match those of another model.
24 | """
25 |
26 |
27 | _KWARGS_DESCRIPTION = """
28 | Args:
29 | predictions1 (`list` of `int`): Predicted labels for model 1.
30 | predictions2 (`list` of `int`): Predicted labels for model 2.
31 |
32 | Returns:
33 | exact_match (`float`): Dictionary containing exact_match rate. Possible values are between 0.0 and 1.0, inclusive.
34 |
35 | Examples:
36 | >>> exact_match = evaluate.load("exact_match", module_type="comparison")
37 | >>> results = exact_match.compute(predictions1=[1, 1, 1], predictions2=[1, 1, 1])
38 | >>> print(results)
39 | {'exact_match': 1.0}
40 | """
41 |
42 |
43 | _CITATION = """
44 | """
45 |
46 |
47 | @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
48 | class ExactMatch(evaluate.Comparison):
49 | def _info(self):
50 | return evaluate.ComparisonInfo(
51 | module_type="comparison",
52 | description=_DESCRIPTION,
53 | citation=_CITATION,
54 | inputs_description=_KWARGS_DESCRIPTION,
55 | features=datasets.Features(
56 | {
57 | "predictions1": datasets.Value("int64"),
58 | "predictions2": datasets.Value("int64"),
59 | }
60 | ),
61 | )
62 |
63 | def _compute(self, predictions1, predictions2):
64 | score_list = [p1 == p2 for p1, p2 in zip(predictions1, predictions2)]
65 | return {"exact_match": np.mean(score_list)}
66 |
--------------------------------------------------------------------------------
/scripts/tasks/ondemandie/exact_match/requirements.txt:
--------------------------------------------------------------------------------
1 | git+https://github.com/huggingface/evaluate@{COMMIT_PLACEHOLDER}
2 | scipy
--------------------------------------------------------------------------------
/scripts/tasks/ondemandie/load_data_fs.py:
--------------------------------------------------------------------------------
1 | # from ph_unified_data:
2 | import os
3 | import json
4 | import random
5 | import copy
6 | from pathlib import Path
7 |
8 | EXPLAN_QUERY = [
9 | "First explain your thoughts and then give the answer. ",
10 | "Please give the analysis first. ",
11 | "Please explain first. ",
12 | "Make analysis according to the sentence and give the answer. ",
13 | ]
14 |
15 |
16 | def construct_response(input_folder, output_folder, split, config):
17 |
18 | input_path = os.path.join(input_folder, split + ".json")
19 | with open(input_path, "r", encoding="utf-8") as reader:
20 | init_dataset = json.load(reader)
21 | reader.close()
22 |
  23 |     # load the example (few-shot demonstration) data
24 | if config["IsCoT"] == False:
25 | tinput_path = os.path.join(input_folder, "training_data.json")
26 | else:
27 | tinput_path = os.path.join(input_folder, "training_data_cot.json")
28 | example_dataset = []
29 | with open(tinput_path, "r", encoding="utf-8") as reader:
30 | example_dataset = json.load(reader)
31 | reader.close()
32 |
  33 |     # process the raw data
34 | unified_dataset = []
35 | fewshot_num = 0
36 | for i, vert in enumerate(init_dataset):
37 | unified_instance = {
38 | "instruction": vert["instruction"] + "\n",
39 | "input": "Text: " + vert["text"],
40 | "output": "[Answer]: " + vert["table"],
41 | "system": "You are a helpful assistant. Follow the user instruction to extract information from the given text into a concise markdown table.",
42 | "history": [],
43 | }
44 | if config["IsCoT"] == True:
45 | ex_q = random.sample(EXPLAN_QUERY, 1)[0]
46 | unified_instance["instruction"] += ex_q
47 | unified_instance["output"] = (
48 | "[Explanation]: " + vert["explanation"] + "\n[Answer]: " + vert["table"]
49 | )
50 |
51 | # examples
52 | HistoryData = random.sample(example_dataset, config["NUM_FEWSHOT_Limit"])
53 | origin_word = unified_instance["instruction"].split(" ") + unified_instance[
54 | "input"
55 | ].split(" ")
56 | total_word = len(origin_word)
57 | for hd in HistoryData:
58 | if hd == vert:
59 | continue
60 | if total_word > config["WORD_Limit"]:
61 | break
62 |
63 | input = hd["instruction"] + "\n" + "Text: " + hd["text"]
64 | if config["IsCoT"]:
65 | output = (
66 | "[Explanation]: " + hd["explanation"] + "\n[Answer]: " + hd["table"]
67 | )
68 | else:
69 | output = "[Answer]: " + hd["table"]
70 |
71 | word = input.split(" ") + output.split(" ")
72 | total_word += len(word)
73 | if total_word > config["WORD_Limit"]:
74 | break
75 | fewshot_num += 1
76 | unified_instance["history"].append([input, output])
77 | unified_dataset.append(unified_instance)
78 |
79 | fewshot_num /= len(unified_dataset)
80 | print("fewshot_avg:", fewshot_num)
81 | if "train" in split:
82 | if config["IsCoT"] == False:
83 | out_name = split + "_" + str(config["NUM_FEWSHOT_Limit"]) + "shot.jsonl"
84 | else:
85 | out_name = split + "_" + str(config["NUM_FEWSHOT_Limit"]) + "shotCoT.jsonl"
86 | out_file = open(os.path.join(output_folder, out_name), "w")
87 | for vert in unified_dataset:
88 | out_file.write(json.dumps(vert) + "\n")
89 | out_file.close()
90 | else:
91 | out_name = "ondemand.json"
92 | out_file = open(os.path.join(output_folder, out_name), "w")
93 | json.dump(unified_dataset, out_file, indent=2)
94 | out_file.close()
95 |
96 | print("total:", len(unified_dataset))
97 |
98 |
99 | if __name__ == "__main__":
100 | input_folder = Path("../data/ondemandIE")
101 | output_folder = Path("../unified_data/ondemandIE")
102 | # args
103 | config = {
104 | # COT
105 | "IsCoT": False,
 106 |         # few-shot settings
107 | "NUM_FEWSHOT_Limit": 0,
108 | "WORD_Limit": 2000,
109 | "EXAM_NA_RATE": 0.0,
110 | }
111 | output_folder.mkdir(exist_ok=True, parents=True)
112 | construct_response(input_folder, output_folder, "training_data", config)
113 | config["IsCoT"] = True
114 | construct_response(input_folder, output_folder, "training_data_cot", config)
115 | config["IsCoT"] = False
116 | construct_response(input_folder, output_folder, "test_data", config)
117 |
--------------------------------------------------------------------------------
/scripts/tasks/ondemandie/o_generate_evaluatefile.py:
--------------------------------------------------------------------------------
1 | # from ph_unified_data:
2 | import os
3 | import json
4 | import random
5 | import copy
6 | import jsonlines
7 | from pathlib import Path
8 | import argparse
9 |
10 |
11 | def construct_response(pre_file, gold_file, output_file):
12 | print("gold_file:", gold_file)
13 | with open(gold_file, "r", encoding="utf-8") as reader:
14 | init_dataset = json.load(reader)
15 | reader.close()
16 |
17 | pred_dataset = []
18 | with open(pre_file, "r", encoding="utf-8") as reader:
19 | for item in jsonlines.Reader(reader):
20 | pred_dataset.append(item)
21 | reader.close()
22 |
23 | assert len(init_dataset) == len(
24 | pred_dataset
25 | ), "number of predictions and targets are not the same."
26 |
27 | for gold, pred in zip(init_dataset, pred_dataset):
28 |
29 | gold["gold"] = pred["messages"][-1]["content"].strip().lower()
30 | gold["output"] = pred["output"]
31 |
32 | print("output_file:", output_file)
33 | out_file = open(output_file, "w")
34 | json.dump(init_dataset, out_file, indent=2)
35 | out_file.close()
36 |
37 |
38 | if __name__ == "__main__":
  39 |     parser = argparse.ArgumentParser(description="Build the on-demand IE evaluation file")
40 | # IO
41 | parser.add_argument(
42 | "--input_dir",
43 | type=str,
44 | default="/SSD_DATA/qyj/open-instruct-main/saves/llama_v2_7B/full/rate_0.2_plus_2/predict/zeroshot/ondemand.jsonl",
45 | )
46 | args = parser.parse_args()
47 |
48 | input_file1 = args.input_dir
49 | input_folder2 = Path("/data1/qyj/ADELIE/data/ondemandIE")
50 | input_file2 = "test_data.json"
  51 |     output_folder = Path("/data1/qyj/ADELIE/scripts/tasks/ondemandie/") # adjust this path as needed
52 | output_folder.mkdir(exist_ok=True, parents=True)
53 |
54 | construct_response(
55 | input_file1,
56 | os.path.join(input_folder2, input_file2),
57 | os.path.join(output_folder, input_file2),
58 | )
59 |
--------------------------------------------------------------------------------
/scripts/tasks/ondemandie/rouge/app.py:
--------------------------------------------------------------------------------
1 | import evaluate
2 | from evaluate.utils import launch_gradio_widget
3 |
4 |
5 | module = evaluate.load("rouge")
6 | launch_gradio_widget(module)
7 |
--------------------------------------------------------------------------------
/scripts/tasks/ondemandie/rouge/requirements.txt:
--------------------------------------------------------------------------------
1 | git+https://github.com/huggingface/evaluate@{COMMIT_PLACEHOLDER}
2 | absl-py
3 | nltk
4 | rouge_score>=0.1.2
--------------------------------------------------------------------------------
/scripts/tasks/ontonote5/__pycache__/desc.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/scripts/tasks/ontonote5/__pycache__/desc.cpython-310.pyc
--------------------------------------------------------------------------------
/scripts/tasks/ontonote5/__pycache__/label_encoding.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/scripts/tasks/ontonote5/__pycache__/label_encoding.cpython-310.pyc
--------------------------------------------------------------------------------
/scripts/tasks/ontonote5/label_encoding.py:
--------------------------------------------------------------------------------
1 | # This code modified from: https://github.com/hitz-zentroa/GoLLIE/tree/main/src/tasks/ontonotes
2 |
3 | from typing import List
4 |
5 |
6 | def to_iob_encoding(tags: List[str]) -> List[str]:
7 | # From IOB2 or BILOU
8 | prev_tag_b: str = "O"
9 | prev_tag_t: str = ""
10 |
11 | for i in range(len(tags)):
12 | tag = tags[i]
13 | if tag == "O":
14 | prev_tag_b = "O"
15 | prev_tag_t = ""
16 | else:
17 | try:
18 | b, t = tag.split("-")
19 | except ValueError:
20 | raise ValueError(
21 | f"Error tag {tag}, unable to split the tag in 2 fields."
22 | )
23 | if (b == "B" or b == "U") and prev_tag_b != "O" and prev_tag_t == t:
24 | tags[i] = f"B-{t}"
25 | else:
26 | tags[i] = f"I-{t}"
27 |
28 | prev_tag_b = b
29 | prev_tag_t = t
30 |
31 | return tags
32 |
33 |
34 | def to_iob2_encoding(tags: List[str]) -> List[str]:
35 | # From IOB or BILOU
36 | prev_tag_b: str = "O"
37 | prev_tag_t: str = ""
38 | for i in range(len(tags)):
39 | tag = tags[i]
40 | if tag == "O":
41 | prev_tag_b = "O"
42 | prev_tag_t = ""
43 | else:
44 | try:
45 | b, t = tag.split("-")
46 | except ValueError:
47 | raise ValueError(
48 | f"Error tag {tag}, unable to split the tag in 2 fields."
49 | )
50 |
51 | if (b == "B" or b == "U") or (
52 | (prev_tag_b == "O") or (prev_tag_t != "" and prev_tag_t != t)
53 | ):
54 | tags[i] = f"B-{t}"
55 | else:
56 | tags[i] = f"I-{t}"
57 |
58 | prev_tag_b = b
59 | prev_tag_t = t
60 |
61 | return tags
62 |
63 |
64 | def to_bilou_encoding(tags: List[str]) -> List[str]:
65 | # From IOB or IOB2
66 |
67 | prev_word_tag_tmp: str = ""
68 | for i in range(len(tags)):
69 | tag = tags[i]
70 | if tag == "O":
71 | if prev_word_tag_tmp != "":
72 | try:
73 | prev_b, prev_t = prev_word_tag_tmp.split("-")
74 | except ValueError:
75 | raise ValueError(
76 | f"Error in tag {prev_word_tag_tmp}, unable to split the tag in 2 fields."
77 | )
78 |
79 | if prev_b == "B":
80 | tags[i - 1] = f"U-{prev_t}"
81 | else:
82 | tags[i - 1] = f"L-{prev_t}"
83 |
84 | prev_word_tag_tmp: str = ""
85 |
86 | else:
87 | try:
88 | b, t = tag.split("-")
89 | except ValueError:
90 | raise ValueError(
91 | f"Error in tag {prev_word_tag_tmp}, unable to split the tag in 2 fields."
92 | )
93 |
94 | if prev_word_tag_tmp == "":
95 | if b == "U":
96 | prev_word_tag_tmp = ""
97 | else:
98 | prev_word_tag_tmp = f"B-{t}"
99 |
100 | else:
101 | try:
102 | prev_b, prev_t = prev_word_tag_tmp.split("-")
103 | except ValueError:
104 | raise ValueError(
105 | f"Error in tag {prev_word_tag_tmp}, unable to split the tag in 2 fields."
106 | )
107 |
108 | if b == "U":
109 | if prev_b == "B":
110 | tags[i - 1] = f"U-{prev_t}"
111 | else:
112 | tags[i - 1] = f"L-{prev_t}"
113 |
114 | prev_word_tag_tmp = ""
115 |
116 | elif b == "B":
117 | if prev_b == "B":
118 | tags[i - 1] = f"U-{prev_t}"
119 | else:
120 | tags[i - 1] = f"L-{prev_t}"
121 |
122 | prev_word_tag_tmp = f"B-{t}"
123 |
124 | else:
125 | if prev_t != t:
126 | if prev_b == "B":
127 | tags[i - 1] = f"U-{prev_t}"
128 | else:
129 | tags[i - 1] = f"L-{prev_t}"
130 |
131 | prev_word_tag_tmp = f"B-{t}"
132 | else:
133 | if prev_b == "B":
134 | tags[i - 1] = f"B-{prev_t}"
135 | else:
136 | tags[i - 1] = f"I-{prev_t}"
137 |
138 | prev_word_tag_tmp = f"I-{t}"
139 |
140 | if prev_word_tag_tmp != "":
141 | try:
142 | prev_b, prev_t = prev_word_tag_tmp.split("-")
143 | except ValueError:
144 | raise ValueError(
145 | f"Error in tag {prev_word_tag_tmp}, unable to split the tag in 2 fields."
146 | )
147 |
148 | if prev_b == "B":
149 | tags[-1] = f"U-{prev_t}"
150 | else:
151 | tags[-1] = f"L-{prev_t}"
152 |
153 | return tags
154 |
155 |
156 | def rewrite_labels(labels: List[str], encoding: str = "iob2") -> List[str]:
157 | if encoding.lower() == "iob":
158 | return to_iob_encoding(labels)
159 | elif encoding.lower() == "iob2":
160 | return to_iob2_encoding(labels)
161 | elif encoding.lower() == "bilou":
162 | return to_bilou_encoding(labels)
163 | else:
164 | raise NotImplementedError(
165 | f"Encoding {encoding} not supported. Supported encodings [IOB,IOB2,BILOU]"
166 | )
167 |
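A minimal usage sketch (an editorial addition, not part of the file), assuming `rewrite_labels` from above is in scope. Note that the functions mutate their argument in place, so a copy is passed for each call:

```python
# IOB2 tags: two adjacent PER entities followed by a single-token LOC entity.
tags = ["B-PER", "I-PER", "B-PER", "O", "B-LOC"]

print(rewrite_labels(list(tags), encoding="bilou"))
# ['B-PER', 'L-PER', 'U-PER', 'O', 'U-LOC']

print(rewrite_labels(list(tags), encoding="iob"))
# ['I-PER', 'I-PER', 'B-PER', 'O', 'I-LOC']
```

In IOB, `B-` only marks the boundary between two adjacent entities of the same type, which is why most tags collapse to `I-`.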
--------------------------------------------------------------------------------
/scripts/tasks/openie4/__pycache__/desc.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/scripts/tasks/openie4/__pycache__/desc.cpython-310.pyc
--------------------------------------------------------------------------------
/scripts/tasks/openie4/__pycache__/desc.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/scripts/tasks/openie4/__pycache__/desc.cpython-38.pyc
--------------------------------------------------------------------------------
/scripts/tasks/rams/__pycache__/desc.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/scripts/tasks/rams/__pycache__/desc.cpython-310.pyc
--------------------------------------------------------------------------------
/scripts/tasks/rams/desc.py:
--------------------------------------------------------------------------------
1 | JSON_BASE = "Please give the answer in json format."
2 |
3 | OUTPUT_BASE = [
4 | (
5 | 'Please give the answer in the form "[Answer]: {content}: {role}; ".',
6 | "{word}: {type}; ",
7 | ),
8 | (
9 | 'Please give the answer in the form "[Answer]: (content: {word}, role: {type}); ".',
10 | "(content: {word}, role: {type}); ",
11 | ),
12 | (
13 | "Please give the answer in natural language.",
14 | 'the event role "{type}" is "{word}"; ',
15 | ),
16 | (
17 | "Please give the answer in natural language.",
18 | '"{word}" is the role of "{type}"; ',
19 | ),
20 | (
21 | 'Please give the answer in the form "[Answer]: ({content}, {role}); ".',
22 | "({word}, {type}); ",
23 | ),
24 | (
25 | "What is the role of each word in the described event?",
26 | 'Role of "{word}" is "{type}". ',
27 | ),
28 | (
29 | "Identify the roles and words associated with the event.",
30 | '"{word}" plays the role of "{type}". ',
31 | ),
32 | (
33 | "Can you extract the event arguments and their roles?",
34 | '"{word}" is identified as "{type}". ',
35 | ),
36 | (
37 | "Highlight the key elements and their roles within the {event} context.",
38 | "Element: {word}, Role: {type}; ",
39 | ),
40 | (
41 | "Can you dissect the {event} and label each component with its respective function?",
42 | "Component: {word}, Function: {type}; ",
43 | ),
44 | (
45 | "Identify and describe the roles of different elements in the {event}.",
46 | "Element Identified: {word}, Described Role: {type}; ",
47 | ),
48 | (
49 | "Break down the {event} into its essential parts and explain their significance.",
50 | "Essential Part: {word}, Significance: {type}; ",
51 | ),
52 | (
53 | "Identify the key elements and their functions within the given event.",
54 | "Element: {word}, Function: {type}. ",
55 | ),
56 | (
57 | "What are the components and their categories in this event?",
58 | "Component: {word} falls under the category: {type}. ",
59 | ),
60 | (
61 | "Break down the event into its constituents and describe their roles.",
62 | "Constituent: {word} plays the role of: {type}. ",
63 | ),
64 | ]
65 | # The EXPLAN templates read oddly here, so the output_explain part is removed for EAE.
66 | OUTPUT_EXPLAN = [
67 | 'In the given context, the event "{event}" is associated with the event type "{etype}".',
68 | 'Within the provided context, the event "{event}" is linked to the event type "{etype}".',
69 | 'In the given context, the event "{event}" is connected to the event type "{etype}".',
70 | 'According to the context, the event "{event}" is correlated with the event type "{etype}".',
71 | 'In the provided context, the event "{event}" is aligned with the event type "{etype}".',
72 | 'In the scenario described, the event "{event}" falls under the category of "{etype}".',
73 | 'The narrative provided outlines an event, specifically "{event}", which pertains to the event type "{etype}".',
74 | 'Reflecting on the details, it\'s evident that the event "{event}" is linked to the event type "{etype}".',
75 | 'The analysis reveals that the "{event}" pertains to the "{etype}" category.',
76 | 'Within the context of "{event}", it is evident that it falls under the "{etype}" classification.',
77 | 'Upon dissecting the "{event}", it becomes clear that its nature is best described as "{etype}".',
78 | 'Upon examining the "{event}", it becomes clear that this event is categorized under the "{etype}" type.',
79 | 'Analyzing the "{event}" reveals that it falls under the "{etype}" category, showcasing its diverse elements and their roles.',
80 | 'Delving into the "{event}", it\'s apparent that this event is a manifestation of the "{etype}" type, with each part holding specific significance.',
81 | 'Upon examining the "{event}", it becomes clear that this scenario typifies the "{etype}" category, highlighting the roles and significance of its components.',
82 | 'The detailed analysis of "{event}" reveals its classification as a "{etype}", emphasizing the categorization of its integral parts.',
83 | 'Scrutinizing the event "{event}", it is evident that it exemplifies the "{etype}" genre, delineating the roles of its various constituents.',
84 | ]
85 |
86 | RANDOM_SYMBOLS = [
87 | "A",
88 | "B",
89 | "C",
90 | "D",
91 | "E",
92 | "F",
93 | "G",
94 | "LABEL_1",
95 | "LABEL_2",
96 | "LABEL_3",
97 | "LABEL_4",
98 | "LABEL_5",
99 | "LABEL_6",
100 | ]
101 |
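For illustration (an editorial sketch; the argument span and role are invented): each `OUTPUT_BASE` entry pairs a format instruction shown to the model with an answer template used to render the gold output, e.g.:

```python
query, answer_fmt = OUTPUT_BASE[1]
# query == 'Please give the answer in the form "[Answer]: (content: {word}, role: {type}); ".'
print("[Answer]: " + answer_fmt.format(word="the board", type="recipient"))
# [Answer]: (content: the board, role: recipient);
```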
--------------------------------------------------------------------------------
/scripts/tasks/readme.md:
--------------------------------------------------------------------------------
1 | ## Running
2 |
3 | ### Data Preparation
4 |
5 | sh generate_unified_data.sh
6 |
7 | sh generate_mixtural_train_data.sh
8 |
9 | ### Training
10 |
11 | See README_ours.md in Train4Llama for details.
12 |
13 | ### Evaluation Data Generation
14 |
15 | sh generate_test_set.sh
16 |
17 | ## Details
18 |
19 | Dataset processing parameters (the diversification settings):
20 |
21 | ```json
22 | config={
23 |     # sample: in the NER/ED/EAE tasks, examples given for each option, e.g. 'Person': such as 'Bob', 'Amy'...
24 |     "SAMPLE_RATE":0.5, # a SAMPLE_RATE fraction of the generated data carries samples in its instruction
25 |     "EACH_SAMPLE_NUM":5, # each option carries EACH_SAMPLE_NUM samples
26 |     "LIMIT_SAMPLE":8, # due to the context-length limit, at most LIMIT_SAMPLE options carry samples
27 |
28 |     # output_json: the output format is JSON
29 |     "JSON_RATE":0.1,
30 |
31 |     # desc: for datasets such as tacred/fewrel...
32 |     "DESC_RATE":0.3, # a DESC_RATE fraction of the generated data carries option descriptions in its instruction
33 |     "LIMIT_DESC":8,
34 |     "ISALL":0.67, # for a (SAMPLE_RATE+DESC_RATE)*ISALL fraction of all data, every option carries samples/descriptions
35 |
36 |     # b-s type: for datasets such as ACE with coarse- and fine-grained classes, the probability of using the coarse class
37 |     "BIG_TYPE":0.2,
38 |
39 |     # args
40 |     "CLASS_DROPOUT_RATE":0.1,
41 |     "UNCOMPELETE_OPTION":True, # diversification of the option strings themselves
42 |
43 |     # explanation-related
44 |     "EXPLAIN_RATE":0.5,
45 |
46 |     # few-shot-related
47 |     "NUM_FEWSHOT_Limit":8,
48 |     "WORD_Limit":1200,
49 |     "EXAM_NA_RATE":0.0
50 | }
51 | ```
52 |
53 | Data format produced by generate_unified_data.sh
54 | (i.e., the format produced after load_data_fs)
55 |
56 | ```json
57 | {
58 | "id": int,
59 | "instruction": "",
60 |     "query": [bool, ""], # [0]: whether an explanation is included; [1]: the required answer format
61 |     "examples": [
62 |         ["",""], # [input, output]: input is the text, output is the answer in the required format
63 |         ...
64 |     ],
65 |     "input": "", # query_text
66 |     "reference": "", # the answer in the canonical format, used for later evaluation
67 |     "output": "" # the answer in the required format
68 | }
69 | ```
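For concreteness, a hypothetical NER record in this format (all field values invented for illustration) could look like:

```json
{
  "id": 17,
  "instruction": "Please identify the entities in the text and classify their types. ...",
  "query": [0, "Please give the answer in the form \"[Answer]: {entity}: {type}; \"."],
  "examples": [
    ["Text: 'EU rejects German call.'", "[Answer]: EU: organization; German: miscellaneous; "]
  ],
  "input": "Japan began the defence of their title.",
  "reference": "Japan: location; ",
  "output": "[Answer]: Japan: location; "
}
```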
70 |
71 | Data format produced by generate_mixtural_train_data.sh
72 |
73 | ```json
74 | [
75 | {
76 | "instruction": "",
77 | "input": "",
78 |     "system": "", # on-demand IE
79 | "history":[]
80 | }
81 | ]
82 | ```
--------------------------------------------------------------------------------
/scripts/tasks/richere-eae/open_evaluation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import copy
4 | import argparse
5 |
6 |
7 | rich_role = {
8 | "NA": 0,
9 | "entity": 1,
10 | "person": 2,
11 | "position": 3,
12 | "place": 4,
13 | "giver": 5,
14 | "recipient": 6,
15 | "money": 7,
16 | "time": 8,
17 | "attacker": 9,
18 | "target": 10,
19 | "victim": 11,
20 | "defendant": 12,
21 | "crime": 13,
22 | "agent": 14,
23 | "sentence": 15,
24 | "thing": 16,
25 | "artifact": 17,
26 | "origin": 18,
27 | "audience": 19,
28 | "prosecutor": 20,
29 | "plaintiff": 21,
30 | "destination": 22,
31 | "instrument": 23,
32 | "adjudicator": 24,
33 | "org": 25,
34 | "beneficiary": 26,
35 | }
36 |
37 |
38 | def comput_f1(input_file):
39 | tp = 0
40 | n_gold = 0
41 | n_pred = 0
42 | # input_file = os.path.join(input_file, 'test.json')
43 | with open(input_file) as f:
44 | for line in f.readlines():
45 | instance = json.loads(line.strip())
46 | # gold triple
47 | gold_text = instance["messages"][-1]["content"].strip().lower()
48 | gold_text = "".join(gold_text.split())
49 | gold_label = set()
50 | for label in gold_text.split(";"):
51 | if label:
52 | gold_label.add(label)
53 | # pred triple
54 | pred_text = ""
55 | if "[Answer]:" in instance["output"]:
56 | pred_text = instance["output"].split("[Answer]:")[1].strip().lower()
57 | else:
58 | pred_text = instance["output"].lower()
59 | pred_text = "".join(pred_text.split())
60 |
61 | pred_label = set()
62 | for label in pred_text.split(";"):
63 | label = label.strip()
64 |                 if label != "" and "na" not in label and ":" in label:  # pred_text was lower-cased above, so check "na"
65 | word = label.split(":")[0].strip()
66 | role = label.split(":")[1].strip()
67 | if role in rich_role:
68 | pred_label.add(":".join([l.strip() for l in label.split(":")]))
69 |
70 | label_stack = []
71 | for vert in gold_label:
72 | label_stack.append(vert)
73 | # print('gold', gold_label)
74 | # print('pred', pred_label)
75 | for label in pred_label:
76 | if label in label_stack:
77 | tp += 1
78 | label_stack.remove(label)
79 | n_gold += len(gold_label)
80 | n_pred += len(pred_label)
81 | precision = tp / (n_pred + 1e-10)
82 | recall = tp / (n_gold + 1e-10)
83 | f1 = 2 * precision * recall / (precision + recall + 1e-10)
84 | return {"precision": precision, "recall": recall, "f1": f1}
85 |
86 |
87 | if __name__ == "__main__":
88 | parser = argparse.ArgumentParser(description="Evaluate Event Detection")
89 | # IO
90 | parser.add_argument(
91 | "--input_dir",
92 | type=str,
93 | default="//SSD_DATA/qyj/open-instruct-main/saves/llama_v2_7B/full/no_output_desc/predict/zeroshot/RichERE-eae.jsonl",
94 | )
95 | # parser.add_argument("--output_dir", type=str, default='/home/qijy/workspace/Alignment_on_IE_tasks/Train_code/LLaMA-Factory-main/saves/LLaMA2-7B-Chat/lora/predict/ace2005-eae/generated_predictions.jsonl')
96 |
97 | args = parser.parse_args()
98 | # args.input_dir = '/Users/chenjianhui/Code/Academics/LLM-on-Complex-Tasks/unified_data/TAC-KBP/test.json'
99 |
100 | result = comput_f1(args.input_dir)
101 | # with open(args.output_dir, 'w') as f:
102 | # json.dump(result, f, indent=4)
103 | print(result)
104 |
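For reference (an editorial sketch; the field contents are invented), each line of the input `.jsonl` is expected to carry the gold answer as the last message and the model prediction in `output`. The scorer lower-cases both sides, strips all whitespace, and counts a true positive for every matching `word:role` pair whose role appears in `rich_role`:

```json
{
  "messages": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "the soldiers: attacker; the village: place;"}
  ],
  "output": "[Answer]: the soldiers: attacker; the town: place;"
}
```

Here only `thesoldiers:attacker` matches after normalization, so this line contributes tp = 1 with two gold and two predicted labels.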
--------------------------------------------------------------------------------
/scripts/tasks/richere-ed/open_eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import copy
4 | import argparse
5 |
6 | rich_ed = {
7 | "NA": 0,
8 | "endposition": 1,
9 | "transfermoney": 2,
10 | "fine": 3,
11 | "attack": 4,
12 | "die": 5,
13 | "startposition": 6,
14 | "convict": 7,
15 | "chargeindict": 8,
16 | "demonstrate": 9,
17 | "transaction": 10,
18 | "sentence": 11,
19 | "pardon": 12,
20 | "releaseparole": 13,
21 | "transferownership": 14,
22 | "transportartifact": 15,
23 | "arrestjail": 16,
24 | "broadcast": 17,
25 | "contact": 18,
26 | "divorce": 19,
27 | "execute": 20,
28 | "trialhearing": 21,
29 | "meet": 22,
30 | "sue": 23,
31 | "transportperson": 24,
32 | "beborn": 25,
33 | "marry": 26,
34 | "injure": 27,
35 | "correspondence": 28,
36 | "elect": 29,
37 | "nominate": 30,
38 | "acquit": 31,
39 | "startorg": 32,
40 | "extradite": 33,
41 | "endorg": 34,
42 | "appeal": 35,
43 | "declarebankruptcy": 36,
44 | "mergeorg": 37,
45 | "artifact": 38,
46 | }
47 |
48 |
49 | def comput_f1(input_file):
50 | tp = 0
51 | n_gold = 0
52 | n_pred = 0
53 | with open(input_file) as f:
54 | for line in f.readlines():
55 | instance = json.loads(line.strip())
56 |
57 | # gold triple
58 | gold_text = instance["messages"][-1]["content"].strip()
59 | gold_text = "".join(gold_text.split())
60 | gold_label = set()
61 | for label in gold_text.split(";"):
62 | if label:
63 | gold_label.add(label)
64 | # pred triple
65 | pred_text = instance["output"]
66 | if "[Answer]:" in instance["output"]:
67 | pred_text = instance["output"].split("[Answer]:")[1].strip()
68 | else:
69 | pred_text = instance["output"]
70 | pred_text = "".join(pred_text.split())
71 |
72 | pred_label = set()
73 | for label in pred_text.split(";"):
74 | label = label.strip()
75 | if label != "" and "NA" not in label and ":" in label:
76 | word = label.split(":")[0].strip()
77 | role = label.split(":")[1].strip()
78 | if role in rich_ed:
79 | pred_label.add(":".join([l.strip() for l in label.split(":")]))
80 |
81 | label_stack = []
82 | for vert in gold_label:
83 | label_stack.append(vert)
84 | # print('gold', gold_label)
85 | # print('pred', pred_label)
86 |
87 | for label in pred_label:
88 | for vert in label_stack:
89 | if label == vert:
90 | tp += 1
91 | label_stack.remove(vert)
92 | break
93 |
94 | n_gold += len(gold_label)
95 | n_pred += len(pred_label)
96 | precision = tp / (n_pred + 1e-10)
97 | recall = tp / (n_gold + 1e-10)
98 | f1 = 2 * precision * recall / (precision + recall + 1e-10)
99 | print("gold:", n_gold)
100 | print("pred:", n_pred)
101 | print("tp:", tp)
102 | return {"precision": precision, "recall": recall, "f1": f1}
103 |
104 |
105 | if __name__ == "__main__":
106 | parser = argparse.ArgumentParser(description="Evaluate Event Detection")
107 | # IO
108 | parser.add_argument(
109 | "--input_dir",
110 | type=str,
111 | default="/SSD_DATA/qyj/open-instruct-main/saves/llama_v2_7B/full/no_output_desc/predict/zeroshot/RichERE-ed.jsonl",
112 | )
113 | parser.add_argument("--output_dir", type=str, default="result.json")
114 |
115 | args = parser.parse_args()
116 | # args.input_dir = '/Users/chenjianhui/Code/Academics/LLM-on-Complex-Tasks/unified_data/TAC-KBP/test.json'
117 |
118 | result = comput_f1(args.input_dir)
119 | print(result)
120 | # with open(args.output_dir, 'w') as f:
121 | # json.dump(result, f, indent=4)
122 | # print(result)
123 |
--------------------------------------------------------------------------------
/scripts/tasks/semeval/open_evaluate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import re
4 | import argparse
5 |
6 | OPTIONS = "Component-Whole(e2,e1), Instrument-Agency(e2,e1), Member-Collection(e1,e2), Cause-Effect(e2,e1), Entity-Destination(e1,e2), Content-Container(e1,e2), Message-Topic(e1,e2), Product-Producer(e2,e1), Member-Collection(e2,e1), Entity-Origin(e1,e2), Cause-Effect(e1,e2), Component-Whole(e1,e2), Message-Topic(e2,e1), Product-Producer(e1,e2), Entity-Origin(e2,e1), Content-Container(e2,e1), Instrument-Agency(e1,e2), Entity-Destination(e2,e1)"
7 |
8 | STR = " and "
9 | NewF = False
10 |
11 | semeval_rel = OPTIONS.lower()
12 | semeval_rel = semeval_rel.split(", ")
13 | for i, item in enumerate(semeval_rel):
14 | e1 = item.split("-")[0]
15 | e2 = item.split("-")[1].split("(")[0]
16 | if NewF == True:
17 | if "(e1,e2)" in item:
18 | semeval_rel[i] = "is_" + e1 + "_and_" + e2 + "_is"
19 | else:
20 | semeval_rel[i] = "is_" + e2 + "_and_" + e1 + "_is"
21 | else:
22 | if "(e1,e2)" in item:
23 | semeval_rel[i] = e1 + STR + e2
24 | else:
25 | semeval_rel[i] = e2 + STR + e1
26 | print(semeval_rel)
27 |
28 |
29 | def parse_triple(text):
30 |     pattern = re.compile(r"\((.*); (.*); (.*)\)")
31 | # print("text:",text)
32 | triple = re.findall(pattern, text)
33 |
34 | if len(triple) == 0:
35 | return None
36 |
37 | t = set()
38 | for vert in triple:
39 | t.add(vert)
40 | tlist = []
41 | for vert in t:
42 | tlist.append(list(vert))
43 | return tlist
44 |
45 |
46 | def comput_f1(input_file):
47 | tp = 0
48 | n_gold = 0
49 | n_pred = 0
50 | id = -1
51 | with open(input_file) as f:
52 | for line in f.readlines():
53 | id += 1
54 | instance = json.loads(line.strip())
55 | # gold triple
56 | gold_triples = []
57 | gold_text = instance["messages"][-1]["content"].strip().lower()
58 | if gold_text != "" and gold_text != "[answer]: na":
59 | gold_triple = parse_triple(gold_text)
60 | # print(gold_text)
61 | assert gold_triple is not None
62 | gold_triples = gold_triple
63 |
64 | pred_triples = []
65 | pred_text = ""
66 | if "[Answer]:" in instance["output"]:
67 | pred_text = instance["output"].split("[Answer]:")[1].strip().lower()
68 | else:
69 | pred_text = instance["output"].strip().lower()
70 |
71 | pred_triple = parse_triple(pred_text)
72 | if pred_triple is not None:
73 | pred_triples = pred_triple
74 |
75 | for triple in gold_triples:
76 | if triple in pred_triples:
77 | tp += 1
78 | else:
79 | # triple=list(triple)[0]
80 | # print("triple:",triple)
81 | e1 = triple[0]
82 | e2 = triple[2]
83 | if NewF == True:
84 | r1 = triple[1].split("_")[1]
85 | r2 = triple[1].split("_")[3]
86 | new_triple = [e2, "is_" + r2 + "_and_" + r1 + "_is", e1]
87 | else:
88 | r1 = triple[1].split(STR)[0]
89 | r2 = triple[1].split(STR)[1]
90 | new_triple = [e2, r2 + STR + r1, e1]
91 |
92 | if new_triple in pred_triples:
93 | tp += 1
94 | # print("pred:",pred_triples)
95 | # print("gold:",gold_triples)
96 |
97 | PFlag = False
98 | for triple in pred_triples:
99 | if triple[1] in semeval_rel:
100 | if len(gold_triples) > 0:
101 | if (
102 | triple[0] == gold_triples[0][0]
103 | and triple[2] == gold_triples[0][2]
104 | ):
105 | n_pred += 1
106 | PFlag = True
107 | if (
108 | triple[2] == gold_triples[0][0]
109 | and triple[0] == gold_triples[0][2]
110 | ):
111 | n_pred += 1
112 | PFlag = True
113 | else:
114 | n_pred += 1
115 | PFlag = True
116 |
117 | if PFlag == True or len(pred_triples) == 0:
118 | n_gold += len(gold_triples)
119 |
120 | print("tp:", tp)
121 | print("n_pred:", n_pred)
122 | print("n_gold:", n_gold)
123 | # precision = tp / n_pred
124 | # recall = tp / n_gold
125 | # if precision+recall==0:
126 | # f1=0
127 | # else:
128 | # f1 = 2 * precision * recall / (precision + recall)
129 | precision = tp / (n_pred + 1e-10)
130 | recall = tp / (n_gold + 1e-10)
131 | f1 = 2 * precision * recall / (precision + recall + 1e-10)
132 |
133 | return {"precision": precision, "recall": recall, "f1": f1}
134 |
135 |
136 | if __name__ == "__main__":
137 | parser = argparse.ArgumentParser(description="Evaluate Event Detection")
138 | # IO
139 | parser.add_argument(
140 | "--input_dir",
141 | type=str,
142 | default="/SSD_DATA/qyj/open-instruct-main/saves/llama_v2_7B/full/no_output_desc/predict/zeroshot/RichERE-ed.jsonl",
143 | )
144 |
145 | args = parser.parse_args()
146 |
147 | result = comput_f1(args.input_dir)
148 | print(result)
149 |
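A quick sketch of what `parse_triple` extracts (an editorial addition; the entities are invented, and the relation string is one of the `semeval_rel` names built above):

```python
text = "[answer]: (the cup; content and container; the box);".split("[answer]:")[1]
print(parse_triple(text))
# [['the cup', 'content and container', 'the box']]
```

The mirrored form, swapping the arguments and reversing the relation name (`(the box; container and content; the cup)`), is what the `new_triple` check above accepts as an equivalent answer.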
--------------------------------------------------------------------------------
/scripts/tasks/tacred/__pycache__/desc.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/scripts/tasks/tacred/__pycache__/desc.cpython-310.pyc
--------------------------------------------------------------------------------
/scripts/utils/DPO/__pycache__/ref_query.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/scripts/utils/DPO/__pycache__/ref_query.cpython-310.pyc
--------------------------------------------------------------------------------
/scripts/utils/DPO/compute_metric_4OpenInstruct.py:
--------------------------------------------------------------------------------
1 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
2 | import json
3 | import os
4 | import jsonlines
5 | from tqdm import tqdm
6 | import random
7 | from pathlib import Path
8 | import evaluate
9 | import argparse
10 |
11 |
12 | def get_score(pred, gold, metric):
13 | return metric.compute(predictions=pred, references=gold)["rougeLsum"]
14 |
15 |
16 | def GenerateDPOunifiedData_4OpenInstructs(path, metric):
17 | trainInf_file = path + ".jsonl"
18 | out_bleu_file = path + "_" + metric + ".json"
19 | out_dpo_file = path + "_" + metric + "_4train.json"
20 | Limit_DPO_data = 50000
21 | if metric == "rough-l":
22 | metric_evaluate = evaluate.load("rouge")
23 |
24 | infdata = []
25 | with open(trainInf_file, "r", encoding="utf-8") as reader:
26 | for item in jsonlines.Reader(reader):
27 | infdata.append(item)
28 | reader.close()
29 |
30 | bleudata = []
31 | sum_bleu = 0
32 | for i, instance in enumerate(infdata):
33 | gold_text = instance["messages"][-1]["content"].split("[Answer]: ")[1]
34 | infdata[i]["metric"] = metric
35 | try:
36 | pre_text = instance["output"].split("[Answer]: ")[1]
37 | if metric == "bleu":
38 | bleu_score = sentence_bleu(
39 | [list(gold_text)],
40 | list(pre_text),
41 | smoothing_function=SmoothingFunction().method3,
42 | )
43 | infdata[i]["score"] = bleu_score
44 | elif metric == "rough-l":
45 | print("================")
46 | print("pre_text:", pre_text)
47 | print("gold_text:", gold_text)
48 | cur_score = get_score([pre_text], [gold_text], metric_evaluate)
49 | infdata[i]["score"] = cur_score
50 | print("cur_score:", cur_score)
51 | else:
52 | print("No metric")
53 | sum_bleu += infdata[i]["score"]
54 | except:
55 |             # Generation sometimes derails and produces heavily repeated output; possibly related to the cutoff length.
56 | # pre_text=instance["output"]
57 | # print(pre_text)
58 | # continue
59 | infdata[i]["score"] = -1
60 |
61 | bleudata.append(instance)
62 |
63 | out_file = open(path + "_" + metric + ".json", "w", encoding="utf-8")
64 | json.dump(bleudata, out_file, indent=2)
65 | out_file.close()
66 |
67 |     bleudata.sort(key=lambda x: x["score"])  # sort in ascending order
68 |
69 | out_file = open(path + "_" + metric + "_sorted.json", "w", encoding="utf-8")
70 | json.dump(bleudata, out_file, indent=2)
71 | out_file.close()
72 |
73 | sum_bleu = sum_bleu / len(bleudata)
74 | print("avg_score:", sum_bleu)
75 |
76 | count = {}
77 | unified_data = []
78 | cnt = -1
79 | for i, instance in enumerate(bleudata):
80 | if instance["score"] == -1:
81 | continue
82 | cnt += 1
83 | if cnt >= Limit_DPO_data:
84 | break
85 | # if instance["dataset"]=="ondemand":
86 | # continue
87 | if instance["score"] >= 1:
88 | print("Now is score >= 1.0:", i)
89 | break
90 | instance["chosen"] = instance["messages"]
91 | instance["rejected"] = instance["messages"][:-1]
92 | reject_output = {"role": "assistant", "content": instance["output"]}
93 | instance["rejected"].append(reject_output)
94 | if instance["dataset"] not in count:
95 | count[instance["dataset"]] = 0
96 | count[instance["dataset"] + "_avg_delta"] = 0
97 | count[instance["dataset"] + "_avg_delta"] = (
98 | (count[instance["dataset"] + "_avg_delta"] * count[instance["dataset"]])
99 | + (1 - instance["score"])
100 | ) / (count[instance["dataset"]] + 1)
101 | count[instance["dataset"]] += 1
102 | unified_data.append(instance)
103 | print(count)
104 |
105 | out_file = open(out_dpo_file, "w", encoding="utf-8") ###########################
106 | json.dump(unified_data, out_file, indent=2)
107 | out_file.close()
108 |
109 | count["total"] = len(unified_data)
110 | out_file = open(
111 | "info_t_1.0.txt", "w", encoding="utf-8"
112 | ) ###########################
113 | json.dump(count, out_file, indent=2)
114 | out_file.close()
115 |
116 |
117 | if __name__ == "__main__":
118 | parser = argparse.ArgumentParser(description="mixture")
119 | parser.add_argument("--input_path", type=str)
120 | args = parser.parse_args()
121 | path = args.input_path
122 | GenerateDPOunifiedData_4OpenInstructs(path, "bleu") # bleu rough-l
123 | # test()
124 |
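For illustration (an editorial sketch; texts and score are invented), the chosen/rejected pair built for each scored instance keeps the gold conversation as `chosen` and swaps the model's sampled output into the last turn of `rejected`:

```python
instance = {
    "dataset": "conll-2003",
    "messages": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "[Answer]: Japan: location; "},
    ],
    "output": "[Answer]: Japan: organization; ",
    "score": 0.62,  # BLEU / ROUGE-L of output against the gold answer
}
instance["chosen"] = instance["messages"]
instance["rejected"] = instance["messages"][:-1] + [
    {"role": "assistant", "content": instance["output"]}
]
```

Instances with `score == -1` (unparseable outputs) are skipped, and iteration stops once a sample reaches a score of 1.0, since a perfect output yields no preference signal.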
--------------------------------------------------------------------------------
/scripts/utils/DPO/merge.py:
--------------------------------------------------------------------------------
1 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
2 | import json
3 | import os
4 | import jsonlines
5 | from tqdm import tqdm
6 | import random
7 | from pathlib import Path
8 | import copy
9 |
10 |
11 | files = [
12 | "../unified_data/train_mixture/sample4dpo_results/ADELIE-SFT/T_1.0_1/mix_vDPO_bleu.json",
13 | "../unified_data/train_mixture/sample4dpo_results/ADELIE-SFT/T_1.0_2/mix_vDPO_bleu.json",
14 | "../unified_data/train_mixture/sample4dpo_results/ADELIE-SFT/T_1.0_3/mix_vDPO_bleu.json",
15 | "../unified_data/train_mixture/sample4dpo_results/ADELIE-SFT/T_1.0_4/mix_vDPO_bleu.json",
16 | "../unified_data/train_mixture/sample4dpo_results/ADELIE-SFT/T_1.0_5/mix_vDPO_bleu.json",
17 | ]
18 | gold_rate = 0.7
19 | No_Gold = False
20 | Limit_DPO_data = 10000
21 | Limit_delta = 0.1
22 | out_dpo_file = "../unified_data/train_mixture/IEFeedback.json"
23 | info_file = "../unified_data/train_mixture/info.IEFeedback.txt"
24 |
25 |
26 | def GenerateDPOunifiedData_Online_4OpenInstructs():
27 |
28 | infdata = []
29 | length = 0
30 | for f in files:
31 | with open(f, "r", encoding="utf-8") as reader:
32 | d = json.load(reader)
33 | reader.close()
34 | if length == 0:
35 | length = len(d)
36 | temp = len(d)
37 | assert length == temp, "error match length"
38 | infdata.append(d)
39 |
40 | count = {}
41 | unified_data = []
42 | cnt = -1
43 | samescore = 0
44 | gold_count = 0
45 |
46 | # gold
47 | gold_num = int(Limit_DPO_data * gold_rate)
48 | print(gold_num)
49 |
50 | not_in = []
51 | for i, instance in enumerate(infdata[0]):
52 | c_num = -1.0
53 | r_num = 2.0
54 | id = instance["id"]
55 | for j in range(0, len(files)):
56 | assert infdata[j][i]["id"] == id, "error match id"
57 | if infdata[j][i]["score"] > c_num:
58 | chosen_output = {
59 | "role": "assistant",
60 | "content": infdata[j][i]["output"],
61 | }
62 | c_num = infdata[j][i]["score"]
63 | if infdata[j][i]["score"] < r_num:
64 | reject_output = {
65 | "role": "assistant",
66 | "content": infdata[j][i]["output"],
67 | }
68 | r_num = infdata[j][i]["score"]
69 |
70 | delta = abs(c_num - r_num)
71 |
72 | if delta <= Limit_delta:
73 | not_in.append(i)
74 | continue
75 |
76 | cnt += 1
77 |
78 | if cnt >= Limit_DPO_data - gold_num:
79 | not_in.append(i)
80 | continue
81 |
82 | instance = {
83 | "dataset": instance["dataset"],
84 | "id": instance["id"],
85 | "messages": instance["messages"],
86 | "chosen": instance["messages"][:-1],
87 | "rejected": instance["messages"][:-1],
88 | }
89 |
90 | instance["rejected"].append(reject_output)
91 | instance["chosen"].append(chosen_output)
92 |
93 | if instance["dataset"] not in count:
94 | count[instance["dataset"]] = 0
95 | count[instance["dataset"] + "_avg_delta"] = 0
96 | count[instance["dataset"] + "_avg_delta"] = (
97 | (count[instance["dataset"] + "_avg_delta"] * count[instance["dataset"]])
98 | + delta
99 | ) / (count[instance["dataset"]] + 1)
100 | count[instance["dataset"]] += 1
101 | unified_data.append(instance)
102 | samescore += delta
103 |
104 | gold_idx = random.sample(not_in, gold_num)
105 |
106 | for i, instance in enumerate(infdata[0]):
107 | if i not in gold_idx:
108 | continue
109 | c_num = 1.0
110 | r_num = 2.0
111 | id = instance["id"]
112 | for j in range(0, len(files)):
113 | assert infdata[j][i]["id"] == id, "error match id"
114 | if infdata[j][i]["score"] < r_num:
115 | reject_output = {
116 | "role": "assistant",
117 | "content": infdata[j][i]["output"],
118 | }
119 | r_num = infdata[j][i]["score"]
120 | chosen_output = infdata[j][i]["messages"][-1]
121 |
122 | delta = abs(c_num - r_num)
123 |
124 | if delta <= Limit_delta:
125 | continue
126 |
127 | gold_count += 1
128 | instance = {
129 | "dataset": instance["dataset"],
130 | "id": instance["id"],
131 | "messages": instance["messages"],
132 | "chosen": instance["messages"][:-1],
133 | "rejected": instance["messages"][:-1],
134 | }
135 |
136 | instance["rejected"].append(reject_output)
137 | instance["chosen"].append(chosen_output)
138 |
139 | if instance["dataset"] not in count:
140 | count[instance["dataset"]] = 0
141 | count[instance["dataset"] + "_avg_delta"] = 0
142 | count[instance["dataset"] + "_avg_delta"] = (
143 | (count[instance["dataset"] + "_avg_delta"] * count[instance["dataset"]])
144 | + delta
145 | ) / (count[instance["dataset"]] + 1)
146 | count[instance["dataset"]] += 1
147 | unified_data.append(instance)
148 | samescore += delta
149 |
150 | out_file = open(out_dpo_file, "w", encoding="utf-8") ###########################
151 | json.dump(unified_data, out_file, indent=2)
152 | out_file.close()
153 |
154 | count["total"] = len(unified_data)
155 | print(count)
156 | print(gold_count, count["total"])
157 | print("gold_rate: ", gold_count / count["total"])
158 | print("avg_delta: ", samescore / count["total"])
159 | out_file = open(info_file, "w", encoding="utf-8") ###########################
160 | json.dump(count, out_file, indent=2)
161 | out_file.close()
162 |
163 |
164 | if __name__ == "__main__":
165 | GenerateDPOunifiedData_Online_4OpenInstructs()
166 | # test()
167 |
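As a worked example of the filtering above (numbers invented): if the five samplings of one instance score [0.15, 0.40, 0.80, 0.40, 0.25], the 0.80 output becomes `chosen`, the 0.15 output becomes `rejected`, and delta = 0.65 > Limit_delta = 0.1, so the pair is kept. An instance whose five scores all lie within 0.1 of each other lands in `not_in` instead, where it may be sampled as a gold pair: the reference answer (treated as score 1.0) becomes `chosen` and the lowest-scoring sample becomes `rejected`.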
--------------------------------------------------------------------------------
/scripts/utils/DPO/ref_query.py:
--------------------------------------------------------------------------------
1 | QUERY = {
2 | # NER
3 | "conll-2003": (
4 | 'Please give the answer in the form "[Answer]: {entity}: {type}; ".',
5 | "{entity}: {type}; ",
6 | ),
7 | "ontonote5": (
8 | 'Please give the answer in the form "[Answer]: {entity}: {type}; ".',
9 | "{entity}: {type}; ",
10 | ),
11 | "ace2005-ner": (
12 | 'Please give the answer in the form "[Answer]: {entity}: {type}; ".',
13 | "{entity}: {type}; ",
14 | ),
15 | # RC
16 | "tacred": (
17 | 'Please give the answer in the tuple form "[Answer]: ({subject}; {relation}; {object})\n".',
18 | "({head}; {type}; {tail})\n",
19 | ),
20 | "fewrel": (
21 | 'Please give the answer in the tuple form "[Answer]: ({subject}; {relation}; {object})\n".',
22 | "({head}; {type}; {tail})\n",
23 | ),
24 | # ED
25 | "ace2005-ed": (
26 | 'Please give the answer in the form "[Answer]: {event}: {class}; ".',
27 | "{word}: {type}; ",
28 | ),
29 | "maven-ed": (
30 | 'Please give the answer in the form "[Answer]: {event}: {class}; ".',
31 | "{word}: {type}; ",
32 | ),
33 | # EAE
34 | "ace2005-eae": (
35 | 'Please give the answer in the form "[Answer]: {word}: {role}; ".',
36 | "{word}: {type}; ",
37 | ),
38 | "maven-eae": (
39 | 'Please give the answer in the form "[Answer]: {word}: {role}; ".',
40 | "{word}: {type}; ",
41 | ),
42 | "RAMS-eae": (
43 | 'Please give the answer in the form "[Answer]: {word}: {role}; ".',
44 | "{word}: {type}; ",
45 | ),
46 |     # ERE -- needs further processing
47 | "MAVEN-ERE": (
48 | 'Please give the answer in the list form "[Answer]: [{relation}]".',
49 | "[{type}]",
50 | ),
51 |     # OPENIE -- needs further processing
52 | "openie4": (
53 | 'Please give the answer in the tuple form "[Answer]: ({predicate}; {subject}; {object}; {time}; {location})\n". If one or more of the last three elements does not exist, it can be omitted.',
54 | "",
55 | ),
56 |     # The OndemandIE dataset is not processed here
57 | }
58 |
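For illustration (an editorial sketch; the triple is invented), the second element of each tuple renders a reference answer, e.g. for tacred:

```python
query, answer_fmt = QUERY["tacred"]
print(answer_fmt.format(head="Apple", type="org:founded_by", tail="Steve Jobs"), end="")
# (Apple; org:founded_by; Steve Jobs)
```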
--------------------------------------------------------------------------------
/scripts/utils/filter_train_NA_data.py:
--------------------------------------------------------------------------------
1 | # Downsample the NA instances in each dataset so that their proportion is <= 0.2
2 |
3 | import random
4 | import json
5 | import jsonlines
6 | import os
7 | from pathlib import Path
8 |
9 | NA_rate = 0.2
10 |
11 |
12 | def sample_train():
13 | path = "../unified_data"
14 | path_list = os.listdir(path)
15 | # path_list.remove("tuluv2")
16 | # path_list.remove("ondemandIE")
17 | path_list.remove("train_mixture")
18 | path_list.remove("test_format")
19 | # path_list.remove("re")
20 | # path_list.remove("ee")
21 | # path_list.remove("other_rc")
22 | # path_list.remove("other_ner")
23 |
24 | valid_data = []
25 | NA_data = []
26 | for dirname in path_list:
27 | file_list = os.listdir(os.path.join(path, dirname))
28 | filename = ""
29 | for f in file_list:
30 | if "train" in f and "after_filter" not in f:
31 | filename = f
32 | break
33 | if filename == "":
34 | continue
35 | if filename[-6:] != ".jsonl":
36 | continue
37 |
38 | # if dirname != "MAVEN-ERE-MoreDoc":
39 | # continue
40 |
41 | if dirname == "ee" or dirname == "re" or dirname == "other_rc":
42 | continue
43 |
44 | sum_valid_data = 0
45 | valid_data.clear()
46 | NA_data.clear()
47 | with open(
48 | os.path.join(path, dirname, filename), "r", encoding="utf-8"
49 | ) as reader:
50 | for item in jsonlines.Reader(reader):
51 | if (
52 | item["reference"] == "NA"
53 | or item["reference"] == "[]"
54 | or item["reference"] == ""
55 | or item["reference"] == []
56 | or item["reference"] == "none"
57 | ):
58 | NA_data.append(item)
59 | else:
60 | valid_data.append(item)
61 | reader.close()
62 | sum_valid_data = len(valid_data)
63 | num_NA_data = int((sum_valid_data / (1.0 - NA_rate)) * NA_rate)
64 | print("-------------------------")
65 | print("filter_dataset:", dirname)
66 | print("sum_valid_data | sum_NA_data | num_resume_NA_data")
67 | print(sum_valid_data, " | ", len(NA_data), " | ", num_NA_data)
68 | if num_NA_data < len(NA_data):
69 | NA_data = random.sample(NA_data, num_NA_data)
70 | print("sample NA data:", len(NA_data))
71 | valid_data = valid_data + NA_data
72 | random.shuffle(valid_data)
73 | print("After filter total:", len(valid_data))
74 |
75 | out_file = open(
76 | os.path.join(path, dirname, "after_filter_train_new.jsonl"),
77 | "w",
78 | encoding="utf-8",
79 | )
80 | for vert in valid_data:
81 | out_file.write(json.dumps(vert) + "\n")
82 | out_file.close()
83 |
84 |
85 | if __name__ == "__main__":
86 | sample_train()
87 |
--------------------------------------------------------------------------------
/scripts/utils/gpts/GenerateIdx.py:
--------------------------------------------------------------------------------
1 | import json
2 | import jsonlines
3 | import argparse
4 |
5 | if __name__ == '__main__':
6 | idx=set()
7 | parser = argparse.ArgumentParser(description="generate")
8 | # I/O
9 | parser.add_argument("--input_file", type=str, default="../../../unified_data/RAMS-eae/gpt4cot.jsonl")
10 | parser.add_argument("--output_file", type=str, default="../../../unified_data/RAMS-eae/CotIdx.txt")
11 | args = parser.parse_args()
12 |
13 |
14 | with open(args.input_file, 'r') as f:
15 | data = json.load(f)
16 | f.close()
17 |
18 | for vert in data["request_states"]:
19 | idx.add(vert["instance"]["id"])
20 |
21 | with open(args.output_file, 'w') as f:
22 | for vert in idx:
23 | f.write(str(vert)+"\n")
24 | f.close()
--------------------------------------------------------------------------------
/scripts/utils/gpts/GenerateInstance4GPT.py:
--------------------------------------------------------------------------------
1 | import jsonlines
2 | import json
3 | import argparse
4 | import random
5 | from Prompt import Prompt,Content,Demo
6 |
7 |
8 | def get_input(d,rate=0.7):
9 | instruction=d["instruction"]
10 | instruction+=d["query"][1]
11 |
12 | text_Flag=False
13 |     if "" in instruction:  # NOTE: a placeholder token (e.g. an HTML-like tag) appears to have been stripped from these string literals during export; as written, this condition is always True
14 | id=(int)(random.random()<=rate)
15 | if id:
16 | instruction=instruction.replace("","Text: '"+d["input"]+"'")
17 | text_Flag=True
18 |
19 | if text_Flag==False:
20 | input="Text: '"+d["input"]+"'"
21 | return instruction+" "+input
22 | else:
23 | input=""
24 | return instruction
25 |
26 | def get_output(d):
27 | output=d["output"]
28 | output=output.split("[Answer]: ")[1]
29 | return output
30 |
31 |
32 | if __name__ == '__main__':
33 | parser = argparse.ArgumentParser(description="generate")
34 | # I/O
35 | parser.add_argument("--input_file", type=str, default="../../../unified_data/maven-ed/after_filter_train.jsonl")
36 | parser.add_argument("--output_file", type=str, default="../../../unified_data/maven-ed/gpt4explanation.jsonl")
37 |     parser.add_argument("--k", type=int, default=1000)
38 |
39 | args = parser.parse_args()
40 |
41 | filename=args.input_file
42 | data={
43 | "prompt": {
44 | "instructions": Prompt,
45 | "input_prefix": "",
46 | "input_suffix": "",
47 | "output_prefix": "",
48 | "output_suffix": "",
49 | "demonstrations": Demo,
50 | },
51 | "request_states": []
52 | }
53 | ex_data=[]
54 | with open(filename, "r", encoding='utf-8') as reader:
55 | for item in jsonlines.Reader(reader):
56 | if item["query"][0]==1:
57 | ex_data.append(item)
58 | reader.close()
59 |
60 | print("num of all query=1 data:",len(ex_data))
61 | k=args.k
62 | print("num of select:",k)
63 | ex_data=random.sample(ex_data,k)
64 |
65 | for vert in ex_data:
66 | unified_data={
67 | "ori_instance": vert,
68 | "instance": {
69 | "input": {
70 | "text": Content.format(input=get_input(vert),output=get_output(vert))
71 | },
72 | "question_type": 1,
73 | "id": vert["id"]
74 | },
75 | "request": {
76 | "result": {
77 | "success": False,
78 | "completions": [
79 | {
80 | "text": ""
81 | }
82 | ]
83 | },
84 | "request_time": 0,
85 | "request_datetime": 1
86 | }
87 | }
88 | data["request_states"].append(unified_data)
89 |
90 | json.dump(data, open(args.output_file, "w"), indent=4)
--------------------------------------------------------------------------------
/scripts/utils/gpts/GenerateInstance4GPT_Cot.py:
--------------------------------------------------------------------------------
1 | import jsonlines
2 | import json
3 | import argparse
4 | import random
5 | from PromptCot import Prompt,Content,Demo
6 |
7 |
8 | def get_input(d,rate=0.7):
9 | instruction=d["instruction"]
10 | instruction+=d["query"][1]
11 |
12 | text_Flag=False
13 |     if "" in instruction:  # NOTE: a placeholder token (e.g. an HTML-like tag) appears to have been stripped from these string literals during export; as written, this condition is always True
14 | id=(int)(random.random()<=rate)
15 | if id:
16 | instruction=instruction.replace("","Text: '"+d["input"]+"'")
17 | text_Flag=True
18 |
19 | if text_Flag==False:
20 | input="Text: '"+d["input"]+"'"
21 | return instruction+" "+input
22 | else:
23 | input=""
24 | return instruction
25 |
26 | def get_output(d):
27 | output=d["output"]
28 | output=output.split("[Answer]: ")[1]
29 | return output
30 |
31 |
32 | if __name__ == '__main__':
33 | parser = argparse.ArgumentParser(description="generate")
34 | # I/O
35 | parser.add_argument("--input_file", type=str, default="../../../unified_data/other_rc/train.jsonl")
36 | parser.add_argument("--except_file", type=str, default="../../../unified_data/other_rc/ExplanationIdx.txt")
37 | parser.add_argument("--output_file", type=str, default="../../../unified_data/other_rc/gpt4cot.jsonl")
38 |     parser.add_argument("--k", type=int, default=500)
39 |
40 | args = parser.parse_args()
41 |
42 | filename=args.input_file
43 | words_num=random.randint(70,170)
44 | data={
45 | "prompt": {
46 | "instructions": Prompt.format(words_number=words_num),
47 | "input_prefix": "",
48 | "input_suffix": "",
49 | "output_prefix": "",
50 | "output_suffix": "",
51 | "demonstrations": Demo,
52 | },
53 | "request_states": []
54 | }
55 | print(data)
56 |
57 | with open(args.except_file, "r", encoding='utf-8') as reader:
58 | except_idx=reader.readlines()
59 | reader.close()
60 |
61 | for i,vert in enumerate(except_idx):
62 | except_idx[i]=vert[:-1]
63 |
64 | # print(except_idx)
65 |
66 | ex_data=[]
67 | with open(filename, "r", encoding='utf-8') as reader:
68 | for item in jsonlines.Reader(reader):
69 | if item["id"] not in except_idx and str(item["id"]) not in except_idx:
70 | if item["query"][0]==1:
71 | ex_data.append(item)
72 | # else:
73 | # print("in",item["id"])
74 | reader.close()
75 |
76 | print("num of all query=1 data:",len(ex_data))
77 | k=args.k
78 | print("num of select:",k)
79 | ex_data=random.sample(ex_data,k)
80 |
81 |
82 | for vert in ex_data:
83 | unified_data={
84 | "ori_instance": vert,
85 | "instance": {
86 | "input": {
87 | "text": Content.format(input=get_input(vert),output=get_output(vert))
88 | },
89 | "question_type": 1,
90 | "id": vert["id"]
91 | },
92 | "request": {
93 | "result": {
94 | "success": False,
95 | "completions": [
96 | {
97 | "text": ""
98 | }
99 | ]
100 | },
101 | "request_time": 0,
102 | "request_datetime": 1
103 | }
104 | }
105 | data["request_states"].append(unified_data)
106 |
107 | json.dump(data, open(args.output_file, "w"), indent=4)
--------------------------------------------------------------------------------
/scripts/utils/gpts/Postprocessing.py:
--------------------------------------------------------------------------------
1 | import json
2 | import jsonlines
3 | import argparse
4 | import re
5 |
6 |
7 | def get_output(d):
8 | output = d["output"]
9 | output = output.split("[Answer]: ")[1]
10 | return output
11 |
12 |
13 | if __name__ == "__main__":
14 | parser = argparse.ArgumentParser(description="generate")
15 | # I/O
16 | parser.add_argument(
17 | "--input_file",
18 | type=str,
19 | default="../../../unified_data/openie4/gpt4cot.jsonl",
20 | )
21 | parser.add_argument(
22 | "--output_file",
23 | type=str,
24 | default="../../../unified_data/openie4/gpt4cot_post.jsonl",
25 | )
26 |
27 | args = parser.parse_args()
28 |
29 | with open(args.input_file, "r") as f:
30 | data = json.load(f)
31 | f.close()
32 |
33 | out_file = open(args.output_file, "w", encoding="utf-8")
34 | for vert in data["request_states"]:
35 | ori = vert["ori_instance"]
36 | output = get_output(ori)
37 | result = vert["request"]["result"]["completions"][0]["text"]
38 | if "[Step-by-Step Explanation]:" in result:
39 | result = result.split("[Step-by-Step Explanation]:")[1].strip()
40 | if 'Explanation":' in result:
41 | print("================= need postprocess =================")
42 | print(result)
43 | try:
44 | pattern = r"{[^{}]*}"
45 | matches = re.findall(pattern, result)
46 | d = eval(matches[0])
47 | result = str(d["Step-by-Step Explanation"])
48 | except:
49 | print("==== check =====")
50 | result = result.split('Explanation":')[1].strip()
51 | # result=result["Explanation"]
52 | print(ori["id"])
53 | print(result)
54 | ori["output"] = "[Step-by-Step Explanation]: " + result + " [Answer]: " + output
55 | ori["query"][0] = 2
56 | out_file.write(json.dumps(ori) + "\n")
57 | out_file.close()
58 |
--------------------------------------------------------------------------------
/scripts/utils/gpts/Prompt.py:
--------------------------------------------------------------------------------
1 | Prompt = """Please provide an explanation of the [Answer] based on [Question].
2 | The generated explanation should make use of the content in the [Question] as much as possible, and must be consistent with the [Answer].
3 | It will eventually be provided at the front of the answer.
4 | No more than {words_number} words."""
5 |
6 | Content = """
7 | [Question]: {input}
8 |
9 | [Answer]: {output}
10 | [Explanation]:
11 | """
12 |
13 | Demo = []
14 |
--------------------------------------------------------------------------------
/scripts/utils/gpts/PromptCot.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | Prompt='''Please generate a step-by-step explanation for [Answer] based on [Question], and give reasons for each step.
4 | The generated explanation should make use of the content in the [Question] as much as possible, and must be consistent with the [Answer].
5 | It will eventually be provided at the front of the answer.
6 | No more than {words_number} words.'''
7 |
8 | Content='''
9 | [Question]: {input}
10 |
11 | [Answer]: {output}
12 | [Step-by-Step Explanation]:
13 | '''
14 |
15 | Demo=[]
16 |
--------------------------------------------------------------------------------
/scripts/utils/gpts/__pycache__/Prompt.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/scripts/utils/gpts/__pycache__/Prompt.cpython-310.pyc
--------------------------------------------------------------------------------
/scripts/utils/gpts/__pycache__/Prompt.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/scripts/utils/gpts/__pycache__/Prompt.cpython-38.pyc
--------------------------------------------------------------------------------
/scripts/utils/gpts/__pycache__/PromptCot.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/scripts/utils/gpts/__pycache__/PromptCot.cpython-310.pyc
--------------------------------------------------------------------------------
/scripts/utils/gpts/__pycache__/PromptCot.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/scripts/utils/gpts/__pycache__/PromptCot.cpython-38.pyc
--------------------------------------------------------------------------------
/scripts/utils/gpts/__pycache__/desc.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/scripts/utils/gpts/__pycache__/desc.cpython-38.pyc
--------------------------------------------------------------------------------
/scripts/utils/gpts/run.sh:
--------------------------------------------------------------------------------
1 | python gpt-4.py \
2 | --test_file ../../../unified_data/ace2005-eae/gpt4cot.jsonl
3 |
4 | python gpt-4.py \
5 | --test_file ../../../unified_data/RAMS-eae/gpt4cot.jsonl
6 |
7 | python gpt-4.py \
8 | --test_file ../../../unified_data/ace2005-ed/gpt4cot.jsonl
9 |
10 | python gpt-4.py \
11 | --test_file ../../../unified_data/ace2005-ner/gpt4cot.jsonl
12 |
13 | python gpt-4.py \
14 | --test_file ../../../unified_data/conll-2003/gpt4cot.jsonl
15 |
16 | python gpt-4.py \
17 | --test_file ../../../unified_data/ontonote5/gpt4cot.jsonl
18 |
19 | python gpt-4.py \
20 | --test_file ../../../unified_data/ee/gpt4cot.jsonl
21 |
22 | python gpt-4.py \
23 | --test_file ../../../unified_data/re/gpt4cot.jsonl
24 |
25 | python gpt-4.py \
26 | --test_file ../../../unified_data/fewrel/gpt4cot.jsonl
27 |
28 | python gpt-4.py \
29 | --test_file ../../../unified_data/maven-ed/gpt4cot.jsonl
30 |
31 | python gpt-4.py \
32 | --test_file ../../../unified_data/maven-eae/gpt4cot.jsonl
33 |
34 | python gpt-4.py \
35 | --test_file ../../../unified_data/MAVEN-ERE/gpt4cot.jsonl
36 |
37 | python gpt-4.py \
38 | --test_file ../../../unified_data/openie4/gpt4cot.jsonl
39 |
40 | python gpt-4.py \
41 | --test_file ../../../unified_data/other_ner/gpt4cot.jsonl
42 |
43 | python gpt-4.py \
44 | --test_file ../../../unified_data/other_rc/gpt4cot.jsonl
45 |
46 | python gpt-4.py \
47 | --test_file ../../../unified_data/tacred/gpt4cot.jsonl
48 |
49 |
--------------------------------------------------------------------------------
/scripts/utils/gpts/template/__pycache__/desc4openie.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/scripts/utils/gpts/template/__pycache__/desc4openie.cpython-310.pyc
--------------------------------------------------------------------------------
/scripts/utils/gpts/template/__pycache__/desc4openie.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/scripts/utils/gpts/template/__pycache__/desc4openie.cpython-38.pyc
--------------------------------------------------------------------------------
/scripts/utils/gpts/template/desc.py:
--------------------------------------------------------------------------------
1 | DATASETS=["ace2005-eae","maven-arg","rams","ace2005-ed","maven-ed","ace2005-ner","conll-2003","ontonotes5","fewrel","tacred","maven-ere"]
2 | # Updated: task and elements
3 | PROMPT='''
4 | You need to follow the template list to come up with a set of diverse templates.
5 | The task indicated by this template is the "Event Relation Extraction" task. We need to write the input format and corresponding output format template for it.
6 | The output contains two parts, one is "explanation" and the other is "answer".
7 | The explanation template content should include the following strings to facilitate subsequent replacement of the content: {e1}, {e2}, {entitytype1}, {entitytype2}.
8 | The answer template content should include the following strings to facilitate subsequent replacement of the content: {head}, {type}, {tail}.
9 |
10 | Here are the requirements:
11 | 1. Try not to repeat the verb for each template to maximize diversity.
12 | 2. The language used for the template also should be diverse. For example, use interrogative sentences, imperative sentences, etc.
13 | 3. Do not repeat the format of the answer template, nor repeat the examples given.
14 | 4. The templates should be in English.
15 |
16 | '''
17 | OUTPUT_BASE=[
18 | ('Please give the answer in the tuple form "[Answer]: ({first event}; {relation}; {second event});".','({head}; {type}; {tail}); '),
19 | ('Please tell me what is the relationship between the two events? ','first event is "{head}", second event is "{tail}", the relation is "{type}". '),
20 | ('Please give the answer in the tuple form "[Answer]: (first event: {head}; relation: {type}; second event: {tail}); ".','(first event: {head}; relation: {type}; second event: {tail}); '),
21 | ('Please give the answer in the tuple form "[Answer]: (head: {head}; relation: {type}; tail: {tail}); ".','(head: {head}; relation: {type}; tail: {tail}); '),
22 | ('Please give the answer in natural language.','the head event is "{head}", the tail event is "{tail}", and the relation between them is "{type}".'),
23 | ('Identify the relationship between the following events and present your findings in the format "[Answer]: {head} is related to {tail} by {type}."', '{head} is related to {tail} by {type}.'),
24 | ('How do these events connect? Please format your response as "[Answer]: The connection between {head} and {tail} is described as {type}."', 'The connection between {head} and {tail} is described as {type}.'),
25 | ('Can you determine the link between these occurrences? Respond in the structure "[Answer]: Linking {head} to {tail} through {type}."', 'Linking {head} to {tail} through {type}.'),
26 | ('What\'s the relation between the given events? Use the format "[Answer]: Between {head} and {tail}, the relation is {type}." for your answer.', 'Between {head} and {tail}, the relation is {type}.'),
27 | ('What is the connection between these two events? Please format your response as "[Answer]: The bond between {head} and {tail} is {type}."', 'The bond between {head} and {tail} is {type}.'),
28 | ('How do these events relate to each other? Frame your answer like this: "[Answer]: {head} and {tail} are intertwined by {type}."', '{head} and {tail} are intertwined by {type}.'),
29 | ('Explain the relationship between the given events in the format "[Answer]: A {type} relationship exists between {head} and {tail}."', 'A {type} relationship exists between {head} and {tail}.'),
30 | ('Identify the relationship between the following events.', 'Between the events, {head} acts as the initiator and {tail} as the receiver, establishing a {type} relationship.'),
31 | ('What binds these two events together? Elaborate in your response.', 'The bond that unites {head} with {tail} is fundamentally a {type} connection.'),
32 | ('Can you determine the type of interaction between these events?', 'The interaction type between {head} and {tail} is classified as {type}.')
33 | ]
34 | OUTPUT_EXPLAN=[
35 | 'The two events are "{e1}" (classified as "{entitytype1}") and "{e2}" (classified as "{entitytype2}"). ',
36 | 'Two events, "{e1}" (classified as "{entitytype1}") and "{e2}" (classified as "{entitytype2}"), are recognized. ',
37 | 'The identified events are "{e1}" (classified as "{entitytype1}") and "{e2}" (classified as "{entitytype2}"). ',
38 | 'In the context of event relation extraction, the two events are identified as "{e1}" (classified as "{entitytype1}") and "{e2}" (classified as "{entitytype2}"). ',
39 | 'The events "{e1}" (classified as "{entitytype1}") and "{e2}" (classified as "{entitytype2}") have been distinguished. ',
40 | 'Upon analysis, it\'s found that the events "{e1}" (categorized under "{entitytype1}") and "{e2}" (categorized under "{entitytype2}") are interconnected. ',
41 | 'The connection is between "{e1}", which is an instance of "{entitytype1}", and "{e2}", which falls under the category of "{entitytype2}". ',
42 | 'The linkage is discerned between the occurrences "{e1}" (identified as "{entitytype1}") and "{e2}" (identified as "{entitytype2}"). ',
43 | 'The relationship involves "{e1}", a "{entitytype1}" event, and "{e2}", a "{entitytype2}" event. ',
44 | 'The investigation reveals a connection between the events "{e1}" (identified as "{entitytype1}") and "{e2}" (identified as "{entitytype2}"). ',
45 | 'The analysis indicates that the events "{e1}" (belonging to the category "{entitytype1}") and "{e2}" (belonging to the category "{entitytype2}") have a relationship. ',
46 | 'After thorough examination, it is evident that there is a "{type}" relationship between "{e1}", which is a type of "{entitytype1}", and "{e2}", which is a type of "{entitytype2}". ',
47 | 'An analysis indicates a relationship where "{e1}" (categorized under "{entitytype1}") initiates an action affecting "{e2}" (categorized under "{entitytype2}"). ',
48 | 'Upon examination, the bond linking "{e1}" (termed as "{entitytype1}") with "{e2}" (termed as "{entitytype2}") is uncovered. ',
49 | 'The interaction between the events "{e1}" (designated as "{entitytype1}") and "{e2}" (designated as "{entitytype2}") is scrutinized. '
50 | ]
--------------------------------------------------------------------------------
/scripts/utils/gpts/template/template_generate.py:
--------------------------------------------------------------------------------
1 | from openai import OpenAI
2 | import os
3 | import json
4 | import time
5 | import math
6 | import requests
7 | from pathlib import Path
8 | from tqdm import tqdm
9 | import argparse
10 | from multiprocessing import Pool, Lock
11 |
12 | # from desc import OUTPUT_BASE,OUTPUT_EXPLAN,PROMPT
13 | from desc4openie import *
14 | import random
15 | import re
16 | import pprint
17 |
18 | api_key_pool = [ # your api pool
19 | ""
20 |     # "sk-..." (example key redacted)
21 | ]
22 |
23 |
24 | def query_openai_api_per_example(
25 | api_key, args, model, sleep_second, max_tokens, demonstrations=None
26 | ):
27 | # print(api_key)
28 | client = OpenAI(
29 | # This is the default and can be omitted
30 | api_key=api_key,
31 | )
32 |
33 | Prompt = PROMPT
34 | demo_answer = random.sample(OUTPUT_BASE, 3)
35 | demo_explain = random.sample(OUTPUT_EXPLAN, 3)
36 | demo_inst = random.sample(INST["OPENIE"], 3)
37 | demos = []
38 | for vert in demo_answer:
39 | # d={
40 | # "input template":vert[0],
41 | # "answer template":vert[1],
42 | # "explanation template":""
43 | # }
44 | d = {
45 | "instruction": "",
46 | "fail output": "",
47 | "input template": vert[0],
48 | "answer template": vert[1],
49 | "explanation template": "",
50 | }
51 | demos.append(d)
52 |
53 | for i, vert in enumerate(demo_explain):
54 | demos[i]["explanation template"] = vert
55 | for i, vert in enumerate(demo_inst):
56 | demos[i]["instruction"] = vert[0]
57 | demos[i]["fail output"] = vert[1]
58 |
59 | for i, vert in enumerate(demos):
60 | Prompt += f"Template {i}: " + json.dumps(vert) + "\n\n"
61 | Prompt += "Please follow the format given in the example to generate 3 templates."
62 |
63 | print("\033[0;42;40m\tINPUT:\033[0m")
64 | print(Prompt)
65 |
66 | s_time = time.time()
67 | success = False
68 | if model in [
69 | "gpt-4",
70 | "gpt-4-1106",
71 | "gpt-4-1106-preview",
72 | "gpt-3.5-turbo-1106",
73 | "gpt-4-0125-preview",
74 | ]:
75 | messages = [
76 | {
77 | "role": "system",
78 | "content": "You are a helpful, pattern-following assistant.",
79 | }
80 | ]
81 | messages.append({"role": "user", "content": Prompt})
82 | while not success:
83 | try:
84 | chat_completion = client.chat.completions.create(
85 | messages=messages, model=model, max_tokens=max_tokens, temperature=0
86 | )
87 | except Exception as e:
88 | if args.debug:
89 | import pdb
90 |
91 | pdb.set_trace()
92 | print(e)
93 | time.sleep(sleep_second)
94 | else:
95 | success = True
96 | # import pdb; pdb.set_trace()
97 | # result = chat_completion['choices'][0]['message']['content']
98 | result = chat_completion.choices[0].message.content
99 |
100 | print("\033[0;31;40m\tPREDICT:\033[0m")
101 | print(result)
102 |
103 | pattern = r'{"instruction": "(.*?)", "fail output": "(.*?)", "input template": "(.*?)", "answer template": "(.*?)", "explanation template"'
104 | # pattern = r'{"input template": "(.*?)", "answer template": "(.*?)", "explanation template": "(.*?)"}'
105 | # pattern = r'{"input template": "(.*?)", "answer template": "(.*?)", "explanation template"'
106 | matches = re.findall(pattern, result, re.DOTALL)
107 | for match in matches:
108 | inst = match[0]
109 | fo = match[1]
110 | input = match[2]
111 | answer = match[3]
112 | print("---")
113 | print("inst:", inst)
114 | print("fo:", fo)
115 | INST["OPENIE"].append([inst, fo])
116 |
117 | print("input template:", input)
118 | print("answer template:", answer)
119 |
120 | OUTPUT_BASE.append((input, answer))
121 | # explain = match[2]
122 |         # print("explanation template:", explain)
123 | # OUTPUT_EXPLAN.append(explain)
124 | print("=============OUTPUT INST=============", len(INST["OPENIE"]))
125 | pprint.pprint(INST["OPENIE"], width=400, indent=4)
126 | print("=============OUTPUT_BASE=============", len(OUTPUT_BASE))
127 | pprint.pprint(OUTPUT_BASE, width=400, indent=4)
128 | print("=============OUTPUT_EXPLAN=============", len(OUTPUT_EXPLAN))
129 | pprint.pprint(OUTPUT_EXPLAN, width=400, indent=4)
130 |
131 |
132 | def main(args):
133 | # args.n_threads = min(len(api_key_pool), args.n_threads)
134 | # demonstrations
135 | for i in range(0, 1):
136 | query_openai_api_per_example(
137 | api_key_pool[0], args, args.model, args.sleep_second, args.max_tokens
138 | )
139 |
140 |
141 | if __name__ == "__main__":
142 | parser = argparse.ArgumentParser(description="Query OpenAI")
143 | # multiple processing
144 | parser.add_argument("--n_threads", type=int, default=16)
145 | # I/O
146 | # parser.add_argument("--input_dir", type=str, default="prompts")
147 | parser.add_argument("--test_file", type=str, default=None)
148 |
149 | # model & parameters
150 | parser.add_argument("--model", type=str, default="gpt-4-0125-preview")
151 | parser.add_argument("--max_tokens", type=int, default=1024)
152 | parser.add_argument("--sleep_second", type=float, default=10.0)
153 | parser.add_argument("--debug", action="store_true")
154 |
155 | args = parser.parse_args()
156 | main(args)
157 |
--------------------------------------------------------------------------------
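The extraction step in template_generate.py assumes the model echoes each generated template as a flat, single-line JSON object. A minimal sketch of that round trip, reusing the script's own regex on a hypothetical model reply:

    import re

    # hypothetical single-template reply from the model
    result = '{"instruction": "Extract all relational triples.", "fail output": "[]", "input template": "Text: {text}", "answer template": "{answer}", "explanation template": ""}'

    pattern = r'{"instruction": "(.*?)", "fail output": "(.*?)", "input template": "(.*?)", "answer template": "(.*?)", "explanation template"'
    matches = re.findall(pattern, result, re.DOTALL)
    # matches == [('Extract all relational triples.', '[]', 'Text: {text}', '{answer}')]
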
/scripts/utils/gpts/test.py:
--------------------------------------------------------------------------------
1 | import re
2 | import pprint
3 | import json
4 |
5 | result = """
6 | ```json
7 | {
8 | "Step-by-Step Explanation": [
9 | "Step 1: Identify the 'head event' and 'tail event' from the question. The 'head event' is the first event or 'Timex' mentioned, which is '1952'. The 'tail event' is the second event or 'Timex' mentioned, which is 'promising'.",
10 | "Step 2: Understand the context of both events within the text. '1952' refers to the year Dwight D. Eisenhower was elected U.S. President, and 'promising' refers to Eisenhower's campaign promise to take a harder line against communism.",
11 | "Step 3: Determine the relationship between the two events based on their context in the text. Since Eisenhower's campaign promise occurred within the year 1952, the 'head event' (1952) contains the 'tail event' (promising).",
12 | "Step 4: Choose the correct relation from the given options [before, overlap, contains, simultaneous, begins-on, ends-on] that best describes the temporal relationship between the 'head event' and 'tail event'. The correct relation is 'contains' because the event of Eisenhower promising to take a harder line against communism is encompassed within the year 1952.",
13 | "Step 5: Formulate the answer in the required format, confirming that the 'head event' ('1952') contains the 'tail event' ('promising'), consistent with the context provided in the text."
14 | ]
15 | }
16 | ```
17 | """
18 |
19 | pattern = r"{[^{}]*}"
20 | matches = re.findall(pattern, result)
21 | d = json.loads(matches[0])  # json.loads is safer than eval() on model output
22 | print(type(d))
23 | print(d)
24 | # str= result.split(matches[0])[1]
25 | # print(str)
26 | # json_string = result.replace('\"', '\\"')
27 | # print(json_string)
28 | # d=json.loads(json_string)
29 | # print(d)
30 | # for match in matches:
31 | # input = match[0]
32 | # answer = match[1]
33 | # print("---")
34 | # print("input template:", input)
35 | # print("answer templat:", answer)
36 | # # print("explanation templat:", explain)
37 | # OUTPUT_BASE.append((input,answer))
38 | # # OUTPUT_EXPLAN.append(explain)
39 |
40 |
41 | # print("=============OUTPUT_BASE=============")
42 | # pprint.pprint(OUTPUT_BASE,width=400)
43 | # print("=============OUTPUT_EXPLAN=============")
44 | # pprint.pprint(OUTPUT_EXPLAN,width=400)
45 |
--------------------------------------------------------------------------------
/scripts/utils/reformat_tuluv2.py:
--------------------------------------------------------------------------------
1 | import pyarrow as pa
2 | import pyarrow.parquet as pq
3 | import os
4 | import json
5 | import random
6 | from pathlib import Path
7 |
8 | def construct_response(input_folder, output_folder, split):
9 |     path_list = os.listdir(input_folder)
10 |     df = []
11 |     for filename in path_list:
12 |         df.append(pq.read_table(os.path.join(input_folder, filename)).to_pandas())
13 | 
14 |     out_file = open(os.path.join(output_folder, split + ".jsonl"), "w")
15 |     count = 0
16 |     for data in df:
17 |         l = len(data)
18 |         for i in range(0, l):
19 |             count += 1
20 |             lst = data.iloc[i]["messages"]
21 |             # unified_instance = {
22 |             #     "instruction": lst[0]["content"],
23 |             #     "input": "",
24 |             #     "output": lst[1]["content"],
25 |             #     "history": []
26 |             # }
27 |             unified_instance = {
28 |                 "dataset": data.iloc[i]["dataset"],
29 |                 "id": data.iloc[i]["id"],
30 |                 "messages": list(data.iloc[i]["messages"])
31 |             }
32 |             # print(unified_instance)
33 |             out_file.write(json.dumps(unified_instance) + "\n")
34 |     print("total number of instances:", count)
35 |     out_file.close()
36 | 
37 | 
38 | if __name__ == "__main__":
39 |     input_folder = Path("../data/General/tuluv2")
40 |     output_folder = Path("../unified_data/tuluv2")
41 |     output_folder.mkdir(exist_ok=True, parents=True)
42 |     construct_response(input_folder, output_folder, "train")
--------------------------------------------------------------------------------
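For reference, each line that construct_response above writes to train.jsonl keeps the original tulu-v2 record shape; an illustrative line (field values invented):

    {"dataset": "flan_v2", "id": "flan_v2_0", "messages": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
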
/train4llama/ds_configs/stage2.conf:
--------------------------------------------------------------------------------
1 | {
2 | "train_batch_size": "auto",
3 | "train_micro_batch_size_per_gpu": "auto",
4 | "gradient_accumulation_steps": "auto",
5 | "gradient_clipping": "auto",
6 | "zero_allow_untested_optimizer": true,
7 | "fp16": {
8 | "enabled": "auto",
9 | "loss_scale": 0,
10 | "loss_scale_window": 1000,
11 | "initial_scale_power": 16,
12 | "hysteresis": 2,
13 | "min_loss_scale": 1
14 | },
15 | "bf16": {
16 | "enabled": "auto"
17 | },
18 | "zero_optimization": {
19 | "stage": 2,
20 | "allgather_partitions": true,
21 | "allgather_bucket_size": 5e8,
22 | "overlap_comm": true,
23 | "reduce_scatter": true,
24 | "reduce_bucket_size": 5e8,
25 | "contiguous_gradients": true,
26 | "round_robin_gradients": true
27 | }
28 | }
--------------------------------------------------------------------------------
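The "auto" fields in this config are placeholders, not literal values: when the file is passed to accelerate launch, the HuggingFace DeepSpeed integration resolves them from the command-line arguments at startup. A sketch of that wiring, mirroring the launch scripts under /train4llama/scripts/:

    accelerate launch \
        --mixed_precision bf16 \
        --num_processes 4 \
        --use_deepspeed \
        --deepspeed_config_file ds_configs/stage2.conf \
        open_instruct/finetune.py \
        --per_device_train_batch_size 2 \
        --gradient_accumulation_steps 16
    # --per_device_train_batch_size -> train_micro_batch_size_per_gpu
    # --gradient_accumulation_steps -> gradient_accumulation_steps
    # --mixed_precision bf16        -> bf16.enabled
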
/train4llama/ds_configs/stage3_no_offloading.conf:
--------------------------------------------------------------------------------
1 | {
2 | "bf16": {
3 | "enabled": "auto"
4 | },
5 | "optimizer": {
6 | "type": "AdamW",
7 | "params": {
8 | "lr": "auto",
9 | "betas": "auto",
10 | "eps": "auto",
11 | "weight_decay": "auto"
12 | }
13 | },
14 | "scheduler": {
15 | "type": "WarmupDecayLR",
16 | "params": {
17 | "total_num_steps": "auto",
18 | "warmup_min_lr": "auto",
19 | "warmup_max_lr": "auto",
20 | "warmup_num_steps": "auto"
21 | }
22 | },
23 | "zero_optimization": {
24 | "stage": 3,
25 | "overlap_comm": true,
26 | "contiguous_gradients": true,
27 | "sub_group_size": 1e9,
28 | "reduce_bucket_size": "auto",
29 | "stage3_prefetch_bucket_size": "auto",
30 | "stage3_param_persistence_threshold": "auto",
31 | "stage3_max_live_parameters": 1e9,
32 | "stage3_max_reuse_distance": 1e9,
33 | "stage3_gather_16bit_weights_on_model_save": true
34 | },
35 | "gradient_accumulation_steps": "auto",
36 | "gradient_clipping": "auto",
37 | "steps_per_print": 1e5,
38 | "train_batch_size": "auto",
39 | "train_micro_batch_size_per_gpu": "auto",
40 | "wall_clock_breakdown": false
41 | }
--------------------------------------------------------------------------------
/train4llama/ds_configs/stage3_no_offloading_accelerate.conf:
--------------------------------------------------------------------------------
1 | {
2 | "bf16": {
3 | "enabled": "auto"
4 | },
5 | "zero_optimization": {
6 | "stage": 3,
7 | "overlap_comm": true,
8 | "contiguous_gradients": true,
9 | "sub_group_size": 1e9,
10 | "reduce_bucket_size": "auto",
11 | "stage3_prefetch_bucket_size": "auto",
12 | "stage3_param_persistence_threshold": "auto",
13 | "stage3_max_live_parameters": 1e9,
14 | "stage3_max_reuse_distance": 1e9,
15 | "stage3_gather_16bit_weights_on_model_save": true
16 | },
17 | "gradient_accumulation_steps": "auto",
18 | "gradient_clipping": "auto",
19 | "steps_per_print": 1e5,
20 | "train_batch_size": "auto",
21 | "train_micro_batch_size_per_gpu": "auto",
22 | "wall_clock_breakdown": false
23 | }
--------------------------------------------------------------------------------
/train4llama/ds_configs/stage3_offloading.conf:
--------------------------------------------------------------------------------
1 | {
2 | "bf16": {
3 | "enabled": "auto"
4 | },
5 | "optimizer": {
6 | "type": "AdamW",
7 | "params": {
8 | "lr": "auto",
9 | "betas": "auto",
10 | "eps": "auto",
11 | "weight_decay": "auto"
12 | }
13 | },
14 | "scheduler": {
15 | "type": "WarmupDecayLR",
16 | "params": {
17 | "total_num_steps": "auto",
18 | "warmup_min_lr": "auto",
19 | "warmup_max_lr": "auto",
20 | "warmup_num_steps": "auto"
21 | }
22 | },
23 | "zero_optimization": {
24 | "stage": 3,
25 | "offload_optimizer": {
26 | "device": "cpu",
27 | "pin_memory": true
28 | },
29 | "offload_param": {
30 | "device": "cpu",
31 | "pin_memory": true
32 | },
33 | "overlap_comm": true,
34 | "contiguous_gradients": true,
35 | "sub_group_size": 1e9,
36 | "reduce_bucket_size": "auto",
37 | "stage3_prefetch_bucket_size": "auto",
38 | "stage3_param_persistence_threshold": "auto",
39 | "stage3_max_live_parameters": 1e9,
40 | "stage3_max_reuse_distance": 1e9,
41 | "stage3_gather_16bit_weights_on_model_save": true
42 | },
43 | "gradient_accumulation_steps": "auto",
44 | "gradient_clipping": "auto",
45 | "steps_per_print": 1e5,
46 | "train_batch_size": "auto",
47 | "train_micro_batch_size_per_gpu": "auto",
48 | "wall_clock_breakdown": false
49 | }
--------------------------------------------------------------------------------
/train4llama/ds_configs/stage3_offloading_accelerate.conf:
--------------------------------------------------------------------------------
1 | {
2 | "bf16": {
3 | "enabled": "auto"
4 | },
5 | "zero_optimization": {
6 | "stage": 3,
7 | "offload_optimizer": {
8 | "device": "cpu",
9 | "pin_memory": true
10 | },
11 | "offload_param": {
12 | "device": "cpu",
13 | "pin_memory": true
14 | },
15 | "overlap_comm": true,
16 | "contiguous_gradients": true,
17 | "sub_group_size": 1e9,
18 | "reduce_bucket_size": "auto",
19 | "stage3_prefetch_bucket_size": "auto",
20 | "stage3_param_persistence_threshold": "auto",
21 | "stage3_max_live_parameters": 1e9,
22 | "stage3_max_reuse_distance": 1e9,
23 | "stage3_gather_16bit_weights_on_model_save": true
24 | },
25 | "gradient_accumulation_steps": "auto",
26 | "gradient_clipping": "auto",
27 | "steps_per_print": 1e5,
28 | "train_batch_size": "auto",
29 | "train_micro_batch_size_per_gpu": "auto",
30 | "wall_clock_breakdown": false
31 | }
--------------------------------------------------------------------------------
/train4llama/eval/__pycache__/dispatch_openai_requests.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/train4llama/eval/__pycache__/dispatch_openai_requests.cpython-310.pyc
--------------------------------------------------------------------------------
/train4llama/eval/__pycache__/templates.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/train4llama/eval/__pycache__/templates.cpython-310.pyc
--------------------------------------------------------------------------------
/train4llama/eval/__pycache__/utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/train4llama/eval/__pycache__/utils.cpython-310.pyc
--------------------------------------------------------------------------------
/train4llama/eval/dispatch_openai_requests.py:
--------------------------------------------------------------------------------
1 | '''
2 | This file is copied and modified from https://gist.github.com/neubig/80de662fb3e225c18172ec218be4917a.
3 | Thanks to Graham Neubig for sharing the original code. Note that this module uses the legacy pre-v1.0 openai client API (openai.ChatCompletion / openai.Completion), so it requires openai<1.0 to run.
4 | '''
5 |
6 | import openai
7 | import asyncio
8 | from typing import Any, List, Dict
9 |
10 | async def dispatch_openai_chat_requesets(
11 | messages_list: List[List[Dict[str,Any]]],
12 | model: str,
13 | **completion_kwargs: Any,
14 | ) -> List[str]:
15 | """Dispatches requests to OpenAI chat completion API asynchronously.
16 |
17 | Args:
18 | messages_list: List of messages to be sent to OpenAI chat completion API.
19 | model: OpenAI model to use.
20 | completion_kwargs: Keyword arguments to be passed to OpenAI ChatCompletion API. See https://platform.openai.com/docs/api-reference/chat for details.
21 | Returns:
22 | List of responses from OpenAI API.
23 | """
24 | async_responses = [
25 | openai.ChatCompletion.acreate(
26 | model=model,
27 | messages=x,
28 | **completion_kwargs,
29 | )
30 | for x in messages_list
31 | ]
32 | return await asyncio.gather(*async_responses)
33 |
34 |
35 | async def dispatch_openai_prompt_requesets(
36 | prompt_list: List[str],
37 | model: str,
38 | **completion_kwargs: Any,
39 | ) -> List[str]:
40 | """Dispatches requests to OpenAI text completion API asynchronously.
41 |
42 | Args:
43 | prompt_list: List of prompts to be sent to OpenAI text completion API.
44 | model: OpenAI model to use.
45 | completion_kwargs: Keyword arguments to be passed to OpenAI text completion API. See https://platform.openai.com/docs/api-reference/completions for details.
46 | Returns:
47 | List of responses from OpenAI API.
48 | """
49 | async_responses = [
50 | openai.Completion.acreate(
51 | model=model,
52 | prompt=x,
53 | **completion_kwargs,
54 | )
55 | for x in prompt_list
56 | ]
57 | return await asyncio.gather(*async_responses)
58 |
59 |
60 | if __name__ == "__main__":
61 | chat_completion_responses = asyncio.run(
62 | dispatch_openai_chat_requesets(
63 | messages_list=[
64 | [{"role": "user", "content": "Write a poem about asynchronous execution."}],
65 | [{"role": "user", "content": "Write a poem about asynchronous pirates."}],
66 | ],
67 | model="gpt-3.5-turbo",
68 | temperature=0.3,
69 | max_tokens=200,
70 | top_p=1.0,
71 |
72 | )
73 | )
74 |
75 | for i, x in enumerate(chat_completion_responses):
76 | print(f"Chat completion response {i}:\n{x['choices'][0]['message']['content']}\n\n")
77 |
78 |
79 | prompt_completion_responses = asyncio.run(
80 | dispatch_openai_prompt_requesets(
81 | prompt_list=[
82 | "Write a poem about asynchronous execution.\n",
83 | "Write a poem about asynchronous pirates.\n",
84 | ],
85 | model="text-davinci-003",
86 | temperature=0.3,
87 | max_tokens=200,
88 | top_p=1.0,
89 | )
90 | )
91 |
92 | for i, x in enumerate(prompt_completion_responses):
93 | print(f"Prompt completion response {i}:\n{x['choices'][0]['text']}\n\n")
--------------------------------------------------------------------------------
/train4llama/open_instruct/__pycache__/dpo_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ADELIE/70f01ff02fa394b8250e5d47ba9572bace27ae4a/train4llama/open_instruct/__pycache__/dpo_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/train4llama/open_instruct/gradio_demo.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | import torch
3 | import sys
4 | from transformers import AutoTokenizer, AutoModelForCausalLM
5 |
6 | if len(sys.argv) > 1:
7 | model_name_or_path = sys.argv[1]
8 | else:
9 | raise ValueError("Please provide a model name or path as the first argument")
10 |
11 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
12 | model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
13 |
14 | model.half().cuda()
15 |
16 | def instruct(instruction):
17 | with torch.inference_mode():
18 | input_text = instruction
19 | input_ids = tokenizer.encode(input_text, return_tensors='pt').cuda()
20 | output_ids = model.generate(input_ids, max_length=1024)[0]
21 | output_str = tokenizer.decode(output_ids[input_ids.shape[-1]:])
22 | return output_str.strip()
23 |
24 | demo = gr.Interface(
25 | fn=instruct,
26 | inputs=gr.Textbox(lines=10, placeholder="Enter your instruction here..."),
27 | outputs="text",
28 | title="Demo for Open-Instruct",
29 | description="Model name or path: " + model_name_or_path
30 | )
31 |
32 | demo.launch(share=True, server_port=7860)
--------------------------------------------------------------------------------
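The demo above takes the model path as its only positional argument (sys.argv[1]); a typical invocation, with a hypothetical checkpoint path:

    python open_instruct/gradio_demo.py ../../models/ADELIE-SFT
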
/train4llama/open_instruct/gradio_demo_chat.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | import torch
3 | import sys
4 | import html
5 | from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
6 | from threading import Thread
7 |
8 | if len(sys.argv) > 1:
9 | model_name_or_path = sys.argv[1]
10 | else:
11 | raise ValueError("Please provide a model name or path as the first argument")
12 |
13 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
14 | model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
15 |
16 | model.half().cuda()
17 |
18 | def convert_message(message):
19 | message_text = ""
20 | if message["content"] is None and message["role"] == "assistant":
21 | message_text += "<|assistant|>\n" # final msg
22 | elif message["role"] == "system":
23 | message_text += "<|system|>\n" + message["content"].strip() + "\n"
24 | elif message["role"] == "user":
25 | message_text += "<|user|>\n" + message["content"].strip() + "\n"
26 | elif message["role"] == "assistant":
27 | message_text += "<|assistant|>\n" + message["content"].strip() + "\n"
28 | else:
29 | raise ValueError("Invalid role: {}".format(message["role"]))
30 | # gradio cleaning - it converts stuff to html entities
31 | # we would need special handling for where we want to keep the html...
32 | message_text = html.unescape(message_text)
33 | # it also converts newlines to
, undo this.
34 | message_text = message_text.replace("
", "\n")
35 | return message_text
36 |
37 | def convert_history(chat_history, max_input_length=1024):
38 | history_text = ""
39 | idx = len(chat_history) - 1
40 | # add messages in reverse order until we hit max_input_length
41 | while len(tokenizer(history_text).input_ids) < max_input_length and idx >= 0:
42 | user_message, chatbot_message = chat_history[idx]
43 | user_message = convert_message({"role": "user", "content": user_message})
44 | chatbot_message = convert_message({"role": "assistant", "content": chatbot_message})
45 | history_text = user_message + chatbot_message + history_text
46 | idx = idx - 1
47 | # if nothing was added, add <|assistant|> to start generation.
48 | if history_text == "":
49 | history_text = "<|assistant|>\n"
50 | return history_text
51 |
52 | @torch.inference_mode()
53 | def instruct(instruction, max_token_output=1024):
54 | input_text = instruction
55 | streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
56 | input_ids = tokenizer(input_text, return_tensors='pt', truncation=False)
57 | input_ids["input_ids"] = input_ids["input_ids"].cuda()
58 | input_ids["attention_mask"] = input_ids["attention_mask"].cuda()
59 | generation_kwargs = dict(input_ids, streamer=streamer, max_new_tokens=max_token_output)
60 | thread = Thread(target=model.generate, kwargs=generation_kwargs)
61 | thread.start()
62 | return streamer
63 |
64 |
65 | with gr.Blocks() as demo:
66 | # recreating the original qa demo in blocks
67 | with gr.Tab("QA Demo"):
68 | with gr.Row():
69 | instruction = gr.Textbox(label="Input")
70 | output = gr.Textbox(label="Output")
71 | greet_btn = gr.Button("Submit")
72 | def yield_instruct(instruction):
73 | # quick prompt hack:
74 | instruction = "<|user|>\n" + instruction + "\n<|assistant|>\n"
75 | output = ""
76 | for token in instruct(instruction):
77 | output += token
78 | yield output
79 | greet_btn.click(fn=yield_instruct, inputs=[instruction], outputs=output, api_name="greet")
80 | # chatbot-style model
81 | with gr.Tab("Chatbot"):
82 | chatbot = gr.Chatbot([], elem_id="chatbot")
83 | msg = gr.Textbox()
84 | clear = gr.Button("Clear")
85 | # fn to add user message to history
86 | def user(user_message, history):
87 | return "", history + [[user_message, None]]
88 |
89 | def bot(history):
90 | prompt = convert_history(history)
91 | streaming_out = instruct(prompt)
92 | history[-1][1] = ""
93 | for new_token in streaming_out:
94 | history[-1][1] += new_token
95 | yield history
96 |
97 | msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
98 | bot, chatbot, chatbot
99 | )
100 |
101 | clear.click(lambda: None, None, chatbot, queue=False)
102 |
103 | if __name__ == "__main__":
104 | demo.queue().launch(share=True)
105 |
--------------------------------------------------------------------------------
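convert_message and convert_history above serialize the chat into the tulu-style tag format, ending with an empty assistant tag to start generation; a two-turn history with a pending assistant turn (illustrative content) renders as:

    <|user|>
    What is event extraction?
    <|assistant|>
    It identifies event triggers and their arguments in text.
    <|user|>
    Give one example.
    <|assistant|>
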
/train4llama/open_instruct/instruction_encode_templates.py:
--------------------------------------------------------------------------------
1 |
2 | import random
3 |
4 | encoding_templates_w_input = [
5 | # input encoding template, output encoding template, weight
6 | ("{instruction}\n\n{input}\n\n", "{output}", 0.2),
7 | ("{instruction}\n{input}\n\n", "{output}", 0.1),
8 | ("{instruction}\n{input}\n", "{output}", 0.1),
9 | ("{instruction}\n\nInput: {input}\n\nOutput:", "{output}", 0.05),
10 | ("{instruction}\nInput: {input}\nOutput:", "{output}", 0.05),
11 | ("{instruction}\n{input}\n\nResponse:", "{output}", 0.05),
12 | ("{instruction}\n\nAdditional Context:\n{input}\n\nAnswer:", "{output}", 0.05),
13 | ("Task: {instruction}\nInput: {input}\nOutput:", "{output}", 0.05),
14 | ("Task: {instruction}\n\n{input}\n\n", "{output}", 0.05),
15 | ("Task: {instruction}\n\n{input}\n\nAnswer:", "{output}", 0.05),
16 | ("You need to complete the following task:\n\n{instruction}\n\n{input}\n\nAnswer:", "{output}", 0.05),
17 | ("{instruction}\n\nNow complete the following instance -\nInput: {input}\nOutput:", "{output}", 0.05),
18 | ("Instruction:{instruction}\n\nInput: {input}\n\n", "{output}", 0.05),
19 | ("Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
20 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:", "{output}", 0.1), # alpaca template
21 | ]
22 |
23 | encoding_templates_wo_input = [
24 | ("{instruction}\n\n", "{output}", 0.2),
25 | ("{instruction}\n", "{output}", 0.1),
26 | ("{instruction}", "\n{output}", 0.1),
27 | ("{instruction} Output:", "{output}", 0.05),
28 | ("{instruction}\nResponse:", "{output}", 0.05),
29 | ("{instruction}\n\nAnswer:", "{output}", 0.05),
30 | ("Task: {instruction}\n\n", "{output}", 0.05),
31 | ("Instruction: {instruction}\n", "{output}", 0.05),
32 | ("Instruction: {instruction}\nOutput:", "{output}", 0.05),
33 | ("You need to complete the following task:\n\n{instruction}\n\n", "{output}", 0.05),
34 | ("Can you help with this?\n\n{instruction}\n", "{output}", 0.05),
35 | ("Plase answer the following request: {instruction}\nAnswer:", "{output}", 0.05),
36 | ("Tell me how would you respond to the following request.\n{instruction}\n", "{output}", 0.05),
37 | ("Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:", "{output}", 0.1), # alpaca template
38 | ]
39 |
40 |
41 | def encode_instruction_example(instruction, input, output, random_template=True, eos_token=None):
42 | if random_template:
43 | if input is not None and input.strip() != "":
44 | # randomly choose a template with input
45 | prompt_template, completion_template, _ = random.choices(
46 | encoding_templates_w_input, weights=[w for _, _, w in encoding_templates_w_input]
47 | )[0]
48 | prompt = prompt_template.format(instruction=instruction.strip(), input=input.strip())
49 | completion = completion_template.format(output=output.strip())
50 | else:
51 | # randomly choose a template without input
52 | prompt_template, completion_template, _ = random.choices(
53 | encoding_templates_wo_input, weights=[w for _, _, w in encoding_templates_wo_input]
54 | )[0]
55 | prompt = prompt_template.format(instruction=instruction.strip())
56 | completion = completion_template.format(output=output.strip())
57 | else:
58 | if input is not None and input.strip() != "":
59 | prompt = instruction.strip() + "\n\n" + input.strip() + "\n\n"
60 | completion = output.strip()
61 | else:
62 | prompt = instruction.strip() + "\n\n"
63 | completion = output.strip()
64 |
65 | data = {
66 | "prompt": prompt,
67 | "completion": completion + eos_token if eos_token else completion,
68 | }
69 | return data
70 |
71 |
72 | def encode_few_shot_example(instruction, examplars, input, output, eos_token=None):
73 | prompt = instruction.strip() + "\n\n"
74 | for examplar in examplars:
75 | prompt += "Input:\n" + examplar["input"].strip() + "\n"
76 | prompt += "Output:\n" + examplar["output"].strip() + "\n\n"
77 |
78 | prompt += "Input:\n" + input.strip() + "\n"
79 | prompt += "Output:\n"
80 |
81 | data = {
82 | "prompt": prompt,
83 | "completion": output.strip() + eos_token if eos_token else output.strip(),
84 | }
85 | return data
86 |
87 |
--------------------------------------------------------------------------------
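A quick illustration of the fixed (non-random) encoding path above, with made-up strings, assuming the module is importable from the working directory:

    from instruction_encode_templates import encode_instruction_example

    example = encode_instruction_example(
        instruction="List the named entities in the sentence.",
        input="Barack Obama visited Berlin.",
        output='["Barack Obama", "Berlin"]',
        random_template=False,
        eos_token="</s>",
    )
    # example["prompt"]     == "List the named entities in the sentence.\n\nBarack Obama visited Berlin.\n\n"
    # example["completion"] == '["Barack Obama", "Berlin"]</s>'
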
/train4llama/open_instruct/merge_lora.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | from peft import PeftConfig, PeftModel
4 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
5 | import bitsandbytes as bnb
6 | import os
7 | import copy
8 | from bitsandbytes.functional import dequantize_4bit
9 | from peft.utils import _get_submodules
10 |
11 |
12 | def dequantize_model(model, dtype=torch.bfloat16, device="cuda"):
13 | """
14 | 'model': the peftmodel you loaded with qlora.
15 | 'dtype': dtype that the model was trained using
16 | 'device': device to load the model to
17 | """
18 | cls = bnb.nn.Linear4bit
19 | with torch.no_grad():
20 | for name, module in model.named_modules():
21 | if isinstance(module, cls):
22 | print(f"Dequantizing `{name}`...")
23 | quant_state = copy.deepcopy(module.weight.quant_state)
24 |
25 | # quant_state changed from a list in newer version of bitsandbytes (0.41.3 onwards)
26 | if isinstance(quant_state, list):
27 | quant_state[2] = dtype
28 | else:
29 | quant_state.dtype = dtype
30 |
31 | weights = dequantize_4bit(module.weight.data, quant_state=quant_state, quant_type="nf4").to(dtype)
32 |
33 | new_module = torch.nn.Linear(module.in_features, module.out_features, bias=None, dtype=dtype)
34 | new_module.weight = torch.nn.Parameter(weights)
35 | new_module.to(device=device, dtype=dtype)
36 |
37 | parent, target, target_name = _get_submodules(model, name)
38 | setattr(parent, target_name, new_module)
39 | # to save model, you have to unset this attribute
40 | model.is_loaded_in_4bit = False
41 |
42 | return model
43 |
44 | def parse_args():
45 | parser = argparse.ArgumentParser()
46 | parser.add_argument("--lora_model_name_or_path", type=str, required=True)
47 | parser.add_argument("--base_model_name_or_path", type=str, required=False)
48 | parser.add_argument("--tokenizer_name_or_path", type=str, required=False)
49 | parser.add_argument("--output_dir", type=str, required=False)
50 | parser.add_argument("--qlora", action="store_true") # qlora requires special treatment.
51 | parser.add_argument("--save_tokenizer", action="store_true")
52 | parser.add_argument("--use_fast_tokenizer", action="store_true")
53 | return parser.parse_args()
54 |
55 |
56 | if __name__ == "__main__":
57 | args = parse_args()
58 | peft_config = PeftConfig.from_pretrained(args.lora_model_name_or_path)
59 | print("Loading the base model...")
60 | if args.qlora:
61 | quantization_config=BitsAndBytesConfig(
62 | load_in_4bit=True,
63 | bnb_4bit_compute_dtype=torch.bfloat16,
64 | bnb_4bit_use_double_quant=True,
65 | bnb_4bit_quant_type="nf4",
66 | )
67 | base_model = AutoModelForCausalLM.from_pretrained(
68 | args.base_model_name_or_path if args.base_model_name_or_path else peft_config.base_model_name_or_path,
69 | load_in_4bit=True,
70 | torch_dtype=torch.bfloat16,
71 | quantization_config=quantization_config,
72 | device_map={"": 0} if torch.cuda.is_available() else None,
73 | )
74 | # base_model = dequantize_model(base_model, device=base_model.device)
75 | base_model = dequantize_model(base_model, device="cpu")
76 | else:
77 | base_model = AutoModelForCausalLM.from_pretrained(
78 | args.base_model_name_or_path if args.base_model_name_or_path else peft_config.base_model_name_or_path,
79 | )
80 | print("Loading the lora model...")
81 | lora_model = PeftModel.from_pretrained(base_model, args.lora_model_name_or_path)
82 | print("Merging the lora modules...")
83 | merged_model = lora_model.merge_and_unload()
84 |
85 | output_dir = args.output_dir if args.output_dir else args.lora_model_name_or_path
86 | os.makedirs(output_dir, exist_ok=True)
87 |
88 | # If tokenizer is specified, use it. Otherwise, use the tokenizer in the lora model folder or the base model folder.
89 | if args.tokenizer_name_or_path:
90 | print(f"Loading the tokenizer from {args.tokenizer_name_or_path}...")
91 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path, use_fast=args.use_fast_tokenizer)
92 | else:
93 | try:
94 | print("Trying to load the tokenizer in the lora model folder...")
95 | tokenizer = AutoTokenizer.from_pretrained(args.lora_model_name_or_path, use_fast=args.use_fast_tokenizer)
96 |         except Exception:
97 | print("No tokenizer found in the lora model folder. Using the tokenizer in the base model folder...")
98 | tokenizer = AutoTokenizer.from_pretrained(args.base_model_name_or_path, use_fast=args.use_fast_tokenizer)
99 |
100 | embedding_size = merged_model.get_input_embeddings().weight.shape[0]
101 | if len(tokenizer) > embedding_size:
102 |         print(f"The tokenizer's vocabulary contains {len(tokenizer)-embedding_size} more tokens than the base model.")
103 | print("Resizing the token embeddings of the merged model...")
104 | merged_model.resize_token_embeddings(len(tokenizer))
105 |
106 | print(f"Saving merged model to {output_dir}...")
107 | merged_model.save_pretrained(output_dir)
108 |
109 | if args.save_tokenizer:
110 | print(f"Saving the tokenizer to {output_dir}...")
111 | tokenizer.save_pretrained(output_dir)
--------------------------------------------------------------------------------
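A typical merge invocation for the script above (paths are hypothetical; --qlora is only needed for adapters trained in 4-bit):

    python open_instruct/merge_lora.py \
        --lora_model_name_or_path ../../models/ADELIE-SFT-lora \
        --base_model_name_or_path /data2/MODELS/llama-2-7b \
        --output_dir ../../models/ADELIE-SFT-merged \
        --save_tokenizer
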
/train4llama/open_instruct/safe_save_trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from packaging import version
4 | from transformers import Trainer, is_torch_tpu_available
5 | from transformers.deepspeed import is_deepspeed_zero3_enabled
6 | from transformers.utils import is_sagemaker_mp_enabled, WEIGHTS_NAME, logging
7 | from transformers.trainer_utils import ShardedDDPOption
8 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType, FullStateDictConfig
9 | from typing import Optional
10 |
11 | if is_sagemaker_mp_enabled():
12 | import smdistributed.modelparallel.torch as smp
13 | from smdistributed.modelparallel import __version__ as SMP_VERSION
14 |
15 | IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
16 |
17 | from transformers.trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
18 | else:
19 | IS_SAGEMAKER_MP_POST_1_10 = False
20 |
21 | logger = logging.get_logger(__name__)
22 |
23 | class SafeSaveTrainer(Trainer):
24 | def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
25 | """
26 | Will save the model, so you can reload it using `from_pretrained()`.
27 | Will only save from the main process.
28 | """
29 |
30 | if output_dir is None:
31 | output_dir = self.args.output_dir
32 |
33 | if is_torch_tpu_available():
34 | self._save_tpu(output_dir)
35 | elif is_sagemaker_mp_enabled():
36 | # Calling the state_dict needs to be done on the wrapped model and on all processes.
37 | os.makedirs(output_dir, exist_ok=True)
38 | state_dict = self.model_wrapped.state_dict()
39 | if self.args.should_save:
40 | self._save(output_dir, state_dict=state_dict)
41 | if IS_SAGEMAKER_MP_POST_1_10:
42 | # 'user_content.pt' indicates model state_dict saved with smp >= 1.10
43 | Path(os.path.join(output_dir, "user_content.pt")).touch()
44 | elif (
45 | ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
46 | or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
47 | or self.fsdp is not None
48 | ):
49 | full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
50 | with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT, full_state_dict_config):
51 | state_dict = self.model.state_dict()
52 |
53 | if self.args.should_save:
54 | self._save(output_dir, state_dict=state_dict)
55 | elif self.deepspeed:
56 | # this takes care of everything as long as we aren't under zero3
57 | if self.args.should_save:
58 | self._save(output_dir)
59 |
60 | if is_deepspeed_zero3_enabled():
61 |             # It's too complicated to try to override all the different places where the weights dump gets
62 |             # saved, and under zero3 that file is bogus anyway, so simply delete it. The user should
63 |             # either use a deepspeed checkpoint to resume, or recover the full weights with the
64 |             # zero_to_fp32.py script stored in the checkpoint.
65 | if self.args.should_save:
66 | file = os.path.join(output_dir, WEIGHTS_NAME)
67 | if os.path.isfile(file):
68 | # logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights")
69 | os.remove(file)
70 |
71 | # now save the real model if stage3_gather_16bit_weights_on_model_save=True
72 | # if false it will not be saved.
73 | # This must be called on all ranks
74 | if not self.deepspeed.save_16bit_model(output_dir, WEIGHTS_NAME):
75 | logger.warning(
76 | "deepspeed.save_16bit_model didn't save the model, since"
77 | " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
78 | " zero_to_fp32.py to recover weights"
79 | )
80 | self.deepspeed.save_checkpoint(output_dir)
81 |
82 | elif self.args.should_save:
83 | self._save(output_dir)
84 |
85 | # Push to the Hub when `save_model` is called by the user.
86 | if self.args.push_to_hub and not _internal_call:
87 | self.push_to_hub(commit_message="Model save")
--------------------------------------------------------------------------------
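Since SafeSaveTrainer only overrides save_model, it is a drop-in replacement for transformers.Trainer; a minimal sketch (model, training_args, and train_dataset are assumed to be built elsewhere):

    from safe_save_trainer import SafeSaveTrainer

    trainer = SafeSaveTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
    )
    trainer.train()
    trainer.save_model()  # uses the DeepSpeed/FSDP-aware logic above
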
/train4llama/requirements.txt:
--------------------------------------------------------------------------------
1 | torch<=2.0.1
2 | scipy
3 | packaging
4 | sentencepiece
5 | datasets
6 | deepspeed>=0.10.0
7 | accelerate>=0.21.0,<0.23.0 # 0.23.0 will cause an incorrect learning rate schedule when using deepspeed, which is likely caused by https://github.com/huggingface/accelerate/commit/727d624322c67db66a43c559d8c86414d5ffb537
8 | peft>=0.4.0
9 | bitsandbytes>=0.41.1
10 | evaluate>=0.4.0
11 | tokenizers>=0.13.3
12 | protobuf
13 | # The Transformers library (v4.34.0) still has a bug with left padding,
14 | # which significantly affects inference and thus our evaluation performance (e.g., MMLU and TruthfulQA).
15 | # The following PR is a temporary fix for it but has not been merged yet.
16 | # See https://github.com/huggingface/transformers/pull/25284
17 | # Since that PR is not compatible with the latest version of the Transformers library (v4.34.0),
18 | # we forked the Transformers library and made changes to keep it compatible with the latest version.
19 | git+https://github.com/yizhongw/transformers.git@left_padding
20 | openai>=1.5.0
21 | tiktoken
22 | rouge_score
23 | tensorboard
24 | wandb
25 | gradio==3.50.2
26 | termcolor
27 | jsonlines
28 | unidic-lite
29 | einops
30 | flash-attn==2.2.2
31 | auto-gptq
32 | fire
33 | alpaca-eval==0.5.3
34 | # for human eval web app
35 | flask
36 | vllm
37 | openpyxl
--------------------------------------------------------------------------------
/train4llama/scripts/dpo_train_with_accelerate.sh:
--------------------------------------------------------------------------------
1 | # you need 8 GPUs for full finetuning
2 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
3 |
4 | NUM_GPUS=8
5 | BATCH_SIZE_PER_GPU=1
6 | TOTAL_BATCH_SIZE=32
7 | GRADIENT_ACC_STEPS=$(($TOTAL_BATCH_SIZE/$NUM_GPUS/$BATCH_SIZE_PER_GPU))
8 | echo "Training model using $NUM_GPUS GPUs, $BATCH_SIZE_PER_GPU batch size per GPU, $GRADIENT_ACC_STEPS gradient accumulation steps"
9 |
10 | accelerate launch \
11 | --mixed_precision bf16 \
12 | --num_machines 1 \
13 | --num_processes $NUM_GPUS \
14 | --use_deepspeed \
15 | --deepspeed_config_file ds_configs/stage2.conf \
16 | --main_process_port=8888 \
17 | open_instruct/dpo_tune.py \
18 | --model_name_or_path ../../models/ADELIE-SFT \
19 | --use_flash_attn \
20 | --gradient_checkpointing \
21 | --tokenizer_name ../../models/ADELIE-SFT \
22 | --use_slow_tokenizer \
23 | --train_file ../../unified_data/train_mixture/IEFeedback.json \
24 | --max_seq_length 2048 \
25 | --preprocessing_num_workers 16 \
26 | --per_device_train_batch_size $BATCH_SIZE_PER_GPU \
27 | --gradient_accumulation_steps $GRADIENT_ACC_STEPS \
28 | --learning_rate 5e-7 \
29 | --lr_scheduler_type linear \
30 | --warmup_ratio 0.1 \
31 | --weight_decay 0. \
32 | --num_train_epochs 3 \
33 | --output_dir ../../models/ADELIE-DPO \
34 | --with_tracking \
35 | --report_to tensorboard \
36 | --logging_steps 1
--------------------------------------------------------------------------------
/train4llama/scripts/eval.sh:
--------------------------------------------------------------------------------
1 | # Model_Path="../models/ADELIE-SFT"
2 | Model_Path="/data1/qyj/ADELIE-DPO"
3 | SCRIPTS_Path="/data1/qyj/ADELIE/scripts"
4 | Output_DIR="saves/llama_v2_7B/full/ADELIE-DPO"
5 | TEST_DIR="../unified_data/test_format/"
6 | CUDA=7
7 |
8 |
9 | CUDA_VISIBLE_DEVICES=${CUDA} python eval/predict.py \
10 | --model_name_or_path ${Model_Path} \
11 | --input_files ${TEST_DIR}/fewshot_test_history/few-nerd-supervised.jsonl \
12 | --output_file ${Output_DIR}/predict/fewshot/few-nerd-supervised.jsonl \
13 | --batch_size 4 \
14 | --use_chat_format \
15 | --chat_formatting_function eval.templates.create_prompt_with_tulu_chat_format \
16 | --temperature 0.01 \
17 | --do_sample
18 |
19 | python ${SCRIPTS_Path}/tasks/fewnerd/open_evaluate.py \
20 | --input_dir ${Output_DIR}/predict/fewshot/few-nerd-supervised.jsonl \
21 | > ${Output_DIR}/predict/few-nerd-supervised_result.txt
22 |
23 |
24 | CUDA_VISIBLE_DEVICES=${CUDA} python eval/predict.py \
25 | --model_name_or_path ${Model_Path} \
26 | --input_files ${TEST_DIR}/fewshot_test_history/semeval.jsonl \
27 | --output_file ${Output_DIR}/predict/fewshot/semeval.jsonl \
28 | --batch_size 4 \
29 | --use_chat_format \
30 | --chat_formatting_function eval.templates.create_prompt_with_tulu_chat_format \
31 | --temperature 0.01 \
32 | --do_sample
33 |
34 | python ${SCRIPTS_Path}/tasks/semeval/open_evaluate.py \
35 | --input_dir ${Output_DIR}/predict/fewshot/semeval.jsonl \
36 | > ${Output_DIR}/predict/semeval_result.txt
37 |
38 |
39 | CUDA_VISIBLE_DEVICES=${CUDA} python eval/predict.py \
40 | --model_name_or_path ${Model_Path} \
41 | --input_files ${TEST_DIR}/fewshot_test_history/RichERE-eae.jsonl \
42 | --output_file ${Output_DIR}/predict/fewshot/RichERE-eae.jsonl \
43 | --batch_size 4 \
44 | --use_chat_format \
45 | --chat_formatting_function eval.templates.create_prompt_with_tulu_chat_format \
46 | --temperature 0.01 \
47 | --do_sample
48 |
49 | python ${SCRIPTS_Path}/tasks/richere-eae/open_evaluation.py \
50 | --input_dir ${Output_DIR}/predict/fewshot/RichERE-eae.jsonl \
51 | > ${Output_DIR}/predict/RichERE-eae_result.txt
52 |
53 | CUDA_VISIBLE_DEVICES=${CUDA} python eval/predict.py \
54 | --model_name_or_path ${Model_Path} \
55 | --input_files ${TEST_DIR}/fewshot_test_history/RichERE-ed.jsonl \
56 | --output_file ${Output_DIR}/predict/fewshot/RichERE-ed.jsonl \
57 | --batch_size 4 \
58 | --use_chat_format \
59 | --chat_formatting_function eval.templates.create_prompt_with_tulu_chat_format \
60 | --temperature 0.01 \
61 | --do_sample
62 |
63 | python ${SCRIPTS_Path}/tasks/richere-ed/open_eval.py \
64 | --input_dir ${Output_DIR}/predict/fewshot/RichERE-ed.jsonl \
65 | > ${Output_DIR}/predict/RichERE-ed_result.txt
66 |
67 |
68 | CUDA_VISIBLE_DEVICES=${CUDA} python eval/predict.py \
69 | --model_name_or_path ${Model_Path} \
70 | --input_files ${TEST_DIR}/fewshot_test_history/MATRES.jsonl \
71 | --output_file ${Output_DIR}/predict/fewshot/MATRES.jsonl \
72 | --batch_size 4 \
73 | --use_chat_format \
74 | --chat_formatting_function eval.templates.create_prompt_with_tulu_chat_format \
75 | --temperature 0.01 \
76 | --do_sample
77 |
78 | python ${SCRIPTS_Path}/tasks/matres/open_eval.py \
79 | --output_file ${Output_DIR}/predict/fewshot/MATRES.jsonl \
80 | > ${Output_DIR}/predict/MATRES_result.txt
81 |
82 | #================= OPEN IE =====================
83 |
84 | CUDA_VISIBLE_DEVICES=${CUDA} python eval/predict.py \
85 | --model_name_or_path ${Model_Path} \
86 | --input_files ${TEST_DIR}/fewshot_test_history/ROBUST.jsonl \
87 | --output_file ${Output_DIR}/predict/fewshot/ROBUST.jsonl \
88 | --batch_size 4 \
89 | --use_chat_format \
90 | --chat_formatting_function eval.templates.create_prompt_with_tulu_chat_format \
91 | --temperature 0.01 \
92 | --do_sample
93 |
94 | python ${SCRIPTS_Path}/tasks/ROBUST/reformat_results_open.py \
95 | --input_dir ${Output_DIR}/predict/fewshot/ROBUST.jsonl
96 |
97 | python ${SCRIPTS_Path}/tasks/ROBUST/src/robust_scorer.py \
98 | --pred_file ${SCRIPTS_Path}/tasks/ROBUST/result.json \
99 | > ${Output_DIR}/predict/ROBUST_result.txt
100 |
101 |
102 | #================= ondemand IE =====================
103 | CUDA_VISIBLE_DEVICES=${CUDA} python eval/predict.py \
104 | --model_name_or_path ${Model_Path} \
105 | --input_files ${TEST_DIR}/zeroshot/ondemand.jsonl \
106 | --output_file ${Output_DIR}/predict/zeroshot/ondemand.jsonl \
107 | --batch_size 4 \
108 | --use_chat_format \
109 | --chat_formatting_function eval.templates.create_prompt_with_tulu_chat_format \
110 | --temperature 0.01 \
111 | --do_sample
112 |
113 | python ${SCRIPTS_Path}/tasks/ondemandie/o_generate_evaluatefile.py \
114 | --input_dir ${Output_DIR}/predict/zeroshot/ondemand.jsonl
115 |
116 | python ${SCRIPTS_Path}/tasks/ondemandie/evaluation/rougel_for_content.py > ${Output_DIR}/predict/ondemand_content.txt
117 |
118 | python ${SCRIPTS_Path}/tasks/ondemandie/evaluation/sim_for_header.py > ${Output_DIR}/predict/ondemand_header.txt
--------------------------------------------------------------------------------
/train4llama/scripts/finetune_with_accelerate.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=4,5,6,7
2 |
3 | MODEL_SIZE=7B
4 | NUM_GPUS=4
5 | BATCH_SIZE_PER_GPU=2
6 | TOTAL_BATCH_SIZE=128
7 | GRADIENT_ACC_STEPS=$(($TOTAL_BATCH_SIZE/$NUM_GPUS/$BATCH_SIZE_PER_GPU))
8 | echo "Training llama model ${MODEL_SIZE} using $NUM_GPUS GPUs, $BATCH_SIZE_PER_GPU batch size per GPU, $GRADIENT_ACC_STEPS gradient accumulation steps"
9 |
10 |
11 | accelerate launch \
12 | --mixed_precision bf16 \
13 | --num_machines 1 \
14 | --num_processes $NUM_GPUS \
15 | --use_deepspeed \
16 | --deepspeed_config_file ds_configs/stage2.conf \
17 | open_instruct/finetune.py \
18 | --model_name_or_path /data2/MODELS/llama-2-7b \
19 | --use_flash_attn \
20 | --tokenizer_name /data2/MODELS/llama-2-7b \
21 | --use_slow_tokenizer \
22 | --train_file ../../unified_data/train_mixture/IEInstruct.jsonl \
23 | --max_seq_length 2048 \
24 | --preprocessing_num_workers 16 \
25 | --per_device_train_batch_size $BATCH_SIZE_PER_GPU \
26 | --gradient_accumulation_steps $GRADIENT_ACC_STEPS \
27 | --learning_rate 2e-5 \
28 | --lr_scheduler_type cosine \
29 | --warmup_ratio 0.03 \
30 | --weight_decay 0. \
31 | --num_train_epochs 2 \
32 | --output_dir ../../models/ADELIE-SFT \
33 | --with_tracking \
34 | --report_to wandb \
35 | --logging_steps 10
--------------------------------------------------------------------------------
/train4llama/scripts/predict.sh:
--------------------------------------------------------------------------------
1 | for p in {1..5}
2 | do
3 | CUDA_VISIBLE_DEVICES=0 python ../eval/predict.py \
4 | --model_name_or_path ../../models/ADELIE-SFT \
5 | --input_files ../../unified_data/train_mixture/mix_vDPO.jsonl \
6 | --output_file ../../unified_data/train_mixture/sample4dpo_results/ADELIE-SFT/T_1.0_$p/mix_vDPO.jsonl \
7 | --batch_size 4 \
8 | --use_vllm \
9 | --use_chat_format \
10 | --chat_formatting_function eval.templates.create_prompt_with_tulu_chat_format \
11 | --temperature 1.0 \
12 | --do_sample
13 | done
--------------------------------------------------------------------------------
/train4llama/weight-diff-requirements.txt:
--------------------------------------------------------------------------------
1 | fire
2 | torch
3 | tqdm
4 | transformers
5 | accelerate
6 | sentencepiece
7 | protobuf==3.20.0
8 |
--------------------------------------------------------------------------------