├── factkg
│   ├── openai_api_key.txt
│   ├── rewrite.py
│   ├── data
│   │   └── preprocess.py
│   ├── pretrain_LM_encoder.py
│   ├── test.py
│   └── make_training_set.py
├── metaQA
│   ├── openai_api_key.txt
│   ├── rewrite.py
│   ├── meta_1hop_prompts
│   │   ├── verify_claim_no_evidence.txt
│   │   └── verify_claim_with_evidence.txt
│   ├── meta_2hop_prompts
│   │   ├── verify_claim_no_evidence.txt
│   │   └── verify_claim_with_evidence.txt
│   ├── data
│   │   └── preprocess.py
│   ├── pretrain_LM_encoder.py
│   ├── make_training_set.py
│   ├── test-1hop.py
│   └── test-2hop.py
├── KELP.png
└── README.md

--------------------------------------------------------------------------------
/factkg/openai_api_key.txt:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/metaQA/openai_api_key.txt:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/KELP.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HaochenLiu2000/KELP/HEAD/KELP.png
--------------------------------------------------------------------------------
/factkg/rewrite.py:
--------------------------------------------------------------------------------
import json

def clean_json_lines(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        lines = file.readlines()

    cleaned_lines = []
    current_json = ""

    # Merge JSON objects that were pretty-printed across several lines back
    # into single-line strings; pass every other line through unchanged.
    for line in lines:
        line = line.strip()

        if line.startswith('{') and current_json == "":
            current_json = line
        elif line.endswith('}') and current_json != "":
            current_json += line
            cleaned_lines.append(current_json)
            current_json = ""
        elif current_json != "":
            current_json += line
        else:
            cleaned_lines.append(line)

    return cleaned_lines

def write_jsonl(file_path, data):  # helper for re-serializing parsed objects; unused below
    with open(file_path, 'w', encoding='utf-8') as file:
        for obj in data:
            json_line = json.dumps(obj, ensure_ascii=False)
            file.write(json_line + '\n')

input_file_path = 'output.txt'
output_file_path = 'output.jsonl'

cleaned_lines = clean_json_lines(input_file_path)

with open(output_file_path, 'w', encoding='utf-8') as output_file:
    output_file.write('\n'.join(cleaned_lines))
--------------------------------------------------------------------------------
/metaQA/rewrite.py:
--------------------------------------------------------------------------------
import json

# Same cleanup as factkg/rewrite.py, applied to the MetaQA one-hop output.
def clean_json_lines(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        lines = file.readlines()

    cleaned_lines = []
    current_json = ""

    for line in lines:
        line = line.strip()

        if line.startswith('{') and current_json == "":
            current_json = line
        elif line.endswith('}') and current_json != "":
            current_json += line
            cleaned_lines.append(current_json)
            current_json = ""
        elif current_json != "":
            current_json += line
        else:
            cleaned_lines.append(line)

    return cleaned_lines

def write_jsonl(file_path, data):  # helper for re-serializing parsed objects; unused below
    with open(file_path, 'w', encoding='utf-8') as file:
        for obj in data:
            json_line = json.dumps(obj, ensure_ascii=False)
            file.write(json_line + '\n')

input_file_path = 'one-hop.txt'
output_file_path = 'one-hop.jsonl'

cleaned_lines = clean_json_lines(input_file_path)

with open(output_file_path, 'w', encoding='utf-8') as output_file:
    output_file.write('\n'.join(cleaned_lines))
--------------------------------------------------------------------------------
/metaQA/meta_1hop_prompts/verify_claim_no_evidence.txt:
--------------------------------------------------------------------------------
Answer the questions based on evidence.
Each evidence is in the form of [head, relation, tail] and it means "head's relation is tail.".
If you think a question can have multiple answers, you must choose one and answer it.

Examples)

Claim A: what words describe [Coming Home]?
Answer: 'hal ashby'


Claim B: what films does [Faye Wong] appear in?
Answer: 'Chungking Express'


Claim C: what films are about [haneke]?
Answer: 'The Piano Teacher'


Claim D: who acted in the movie [Inescapable]?
Answer: 'Marisa Tomei'


Claim E: can you name a film directed by [William Cameron Menzies]?
Answer: 'Things to Come'


Claim F: what sort of movie is [Witness for the Prosecution]?
Answer: 'Drama'


Claim G: what type of film is [The Mouse That Roared]?
Answer: 'Comedy'


Claim H: what is the primary language in the film [Blackboards]?
Answer: 'Kurdish'


Claim I: who is the creator of the film script for [The Truth of Lie]?
Answer: 'Roland Reber'


Claim J: what was the release year of the film [The Return of Doctor X]?
Answer: '1939'


Claim K: which topics is movie [Topper] about?
Answer: 'ghosts'


Claim L: describe the movie [The Mouse on the Moon] in a few words?
Answer: 'bd-r'


Now let's verify the Claim based on the Evidence set. Please do not say there is no evidence; you must answer with the single most related entity from the evidence set.
Claim: <<<>>>
Answer:
--------------------------------------------------------------------------------
/metaQA/meta_2hop_prompts/verify_claim_no_evidence.txt:
--------------------------------------------------------------------------------
Answer the questions.
If you think a question can have multiple answers, you must choose one and answer it.

Examples)

Claim A: which person wrote the films directed by [Yuriy Norshteyn]?
Answer: 'Sergei Kozlov'


Claim B: who are the writers of the movies directed by [Kresten Vestbjerg Andersen]?
Answer: 'Philip Einstein Lipski'


Claim C: the movies directed by [David Atkins] were in which genres?
Answer: 'Comedy'


Claim D: the films written by [Scott Lobdell] were released in which years?
Answer: '1995'


Claim E: what are the languages spoken in the films starred by [Terence Hill]?
Answer: 'Italian'


Claim F: who is listed as director of [Oliver Cooper] acted films?
Answer: 'Nima Nourizadeh'


Claim G: who co-starred with [Stephen Furst]?
Answer: 'Peter Boyle'


Claim H: what types are the films written by [Polaris Banks]?
Answer: 'Short'


Claim I: when were the movies written by [Phillip Borsos] released?
Answer: '1985'


Claim J: which movies share the screenwriter with [Jesus Henry Christ]?
Answer: 'Fireflies in the Garden'


Claim K: which directors co-directed movies with [Ridley Scott]?
Answer: 'Rowdy Herrington'


Claim L: the scriptwriter of [First Monday in October] also wrote movies?
Answer: 'Inherit the Wind'


Now let's verify the Claim based on the Evidence set. Please do not say there is no evidence; you must answer with the single most related entity from the evidence set.
Claim: <<<>>>
Answer:
--------------------------------------------------------------------------------
/factkg/data/preprocess.py:
--------------------------------------------------------------------------------
import jsonlines
import pickle
import argparse

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Parsing input arguments.")
    parser.add_argument('--factkg_train', type=str, required=True, help='Path for the FactKG train set.')
    parser.add_argument('--factkg_dev', type=str, required=True, help='Path for the FactKG dev set.')
    parser.add_argument('--factkg_test', type=str, required=True, help='Path for the FactKG test set.')

    args = parser.parse_args()

    with open(args.factkg_train, 'rb') as f:
        train_set = pickle.load(f)
    claims_train = list(train_set)

    with open(args.factkg_dev, 'rb') as f:
        dev_set = pickle.load(f)
    claims_dev = list(dev_set)

    with open(args.factkg_test, 'rb') as f:
        test_set = pickle.load(f)
    claims_test = list(test_set)

    # Each claim's metadata must come from its own split: looking train/dev
    # claims up in test_set fails for any claim that is not also in the test split.
    splits = [
        ('./extracted_train_set.jsonl', claims_train, train_set),
        ('./extracted_dev_set.jsonl', claims_dev, dev_set),
        ('./extracted_test_set.jsonl', claims_test, test_set),
    ]
    for out_path, claims, split in splits:
        with jsonlines.open(out_path, mode='w') as w:
            for i, sample in enumerate(claims):
                new_sample = {}
                new_sample["question_id"] = i + 1
                new_sample["question"] = sample
                new_sample["types"] = split[sample]["types"]
                new_sample["entity_set"] = split[sample]["Entity_set"]
                new_sample["Label"] = split[sample]["Label"]
                w.write(new_sample)
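
For reference, a quick sanity check on the extracted splits — a minimal sketch, not part of the original repo. The five field names are exactly what preprocess.py writes above; the path assumes you run it next to the outputs:

```
import jsonlines

# Peek at the first extracted record; every record carries the same five fields.
with jsonlines.open('extracted_test_set.jsonl') as reader:
    first = next(iter(reader))
print(sorted(first.keys()))
# -> ['Label', 'entity_set', 'question', 'question_id', 'types']
```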
--------------------------------------------------------------------------------
/metaQA/meta_1hop_prompts/verify_claim_with_evidence.txt:
--------------------------------------------------------------------------------
Answer the questions based on evidence.
Each evidence is in the form of [head, relation, tail] and it means "head's relation is tail.".
If you think a question can have multiple answers, you must choose one and answer it.

Examples)

Claim A: what words describe [Coming Home]?
Evidence set: [['Coming Home', 'has_genre', 'Drama'], ['Coming Home', 'has_genre', 'War'], ['Coming Home', 'has_tags', 'vietnam'], ['Coming Home', 'has_tags', 'hal ashby']]
Answer: 'hal ashby'


Claim B: what films does [Faye Wong] appear in?
Evidence set: [['Chungking Express', 'starred_actors', 'Faye Wong'], ['Chinese Odyssey 2002', 'starred_actors', 'Faye Wong']]
Answer: 'Chungking Express'


Claim C: what films are about [haneke]?
Evidence set: [['Code Unknown', 'has_tag', 'haneke'], ['The Piano Teacher', 'has_tag', 'haneke'], ['Funny Games', 'has_tag', 'haneke'], ['Time of the Wolf', 'has_tag', 'haneke']]
Answer: 'The Piano Teacher'


Claim D: who acted in the movie [Inescapable]?
Evidence set: [['Inescapable', 'directed_by', 'Ruba Nadda'], ['Inescapable', 'written_by', 'Ruba Nadda'], ['Inescapable', 'starred_actors', 'Marisa Tomei'], ['Inescapable', 'starred_actors', 'Joshua Jackson'], ['Inescapable', 'starred_actors', 'Alexander Siddig']]
Answer: 'Marisa Tomei'


Claim E: can you name a film directed by [William Cameron Menzies]?
Evidence set: [['Things to Come', 'directed_by', 'William Cameron Menzies']]
Answer: 'Things to Come'


Claim F: what sort of movie is [Witness for the Prosecution]?
Evidence set: [['Witness for the Prosecution', 'has_tags', 'bd-r'], ['Witness for the Prosecution', 'has_genre', 'Drama'], ['Witness for the Prosecution', 'has_tags', 'courtroom']]
Answer: 'Drama'


Claim G: what type of film is [The Mouse That Roared]?
Evidence set: [['The Mouse That Roared', 'has_genre', 'Comedy'], ['The Mouse That Roared', 'has_tags', 'satirical'], ['The Mouse That Roared', 'has_tags', 'peter sellers']]
Answer: 'Comedy'


Claim H: what is the primary language in the film [Blackboards]?
Evidence set: [['Blackboards', 'in_language', 'Kurdish'], ['Blackboards', 'has_genre', 'War'], ['Blackboards', 'has_tags', 'samira makhmalbaf']]
Answer: 'Kurdish'


Claim I: who is the creator of the film script for [The Truth of Lie]?
Evidence set: [['The Truth of Lie', 'written_by', 'Roland Reber'], ['The Truth of Lie', 'directed_by', 'Roland Reber'], ['The Truth of Lie', 'has_genre', 'Thriller']]
Answer: 'Roland Reber'


Claim J: what was the release year of the film [The Return of Doctor X]?
Evidence set: [['The Return of Doctor X', 'written_by', 'William J. Makin'], ['The Return of Doctor X', 'release_year', '1939'], ['The Return of Doctor X', 'has_tags', 'humphrey bogart']]
Answer: '1939'


Claim K: which topics is movie [Topper] about?
Evidence set: [['Topper', 'has_tags', 'ghosts'], ['Topper', 'has_tags', 'norman z. mcleod'], ['Topper', 'has_genre', 'Comedy']]
Answer: 'ghosts'


Claim L: describe the movie [The Mouse on the Moon] in a few words?
Evidence set: [['The Mouse on the Moon', 'has_tags', 'bd-r'], ['The Mouse on the Moon', 'has_genre', 'Comedy'], ['The Mouse on the Moon', 'written_by', 'Leonard Wibberley']]
Answer: 'bd-r'


Now let's verify the Claim based on the Evidence set. Please do not say there is no evidence; you must answer with the single most related entity from the evidence set.
Claim: <<<>>>
Evidence set: <<<>>>
Answer:
--------------------------------------------------------------------------------
/metaQA/data/preprocess.py:
--------------------------------------------------------------------------------
import pickle
import jsonlines
import argparse
import re

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Parsing input arguments.")
    parser.add_argument('--setting', type=str, required=True, choices=['train', 'dev', 'test'])
    parser.add_argument('--kb', type=str, required=True, help='Path for the MetaQA kb.')

    args = parser.parse_args()

    setting = args.setting
    qa_1hop_path = f'1-hop/vanilla/qa_{setting}.txt'
    qa_2hop_path = f'2-hop/vanilla/qa_{setting}.txt'
    kb = args.kb

    # Build the KG as {head: {relation: [tails]}}; every triple also gets a
    # '~'-prefixed inverse relation so paths can be walked in both directions.
    KG_construct = {}
    with open(kb, 'r') as f:
        for line in f:
            head, relation, tail = line.strip().split('|')
            KG_construct.setdefault(head, {}).setdefault(relation, []).append(tail)
            KG_construct.setdefault(tail, {}).setdefault('~' + relation, []).append(head)

    # make_training_set.py and the test scripts load this pickle.
    with open('data/metaqa_kg.pickle', 'wb') as f:
        pickle.dump(KG_construct, f)

    onehop = {}
    with open(qa_1hop_path, 'r') as f:
        for line in f:
            separated = line.strip().split('\t')
            entities = re.findall(r'\[(.*?)\]', separated[0])
            labels = separated[1].split('|')
            onehop[separated[0] + '?'] = {'entity_set': [entities[0]], 'Label': labels}

    with jsonlines.open(f'data/onehop_{setting}_set.jsonl', mode='w') as w:
        for i, sample in enumerate(onehop):
            w.write({
                "question_id": i + 1,
                "question": sample,
                "entity_set": onehop[sample]["entity_set"],
                "Label": onehop[sample]["Label"],
            })

    twohop = {}
    with open(qa_2hop_path, 'r') as f:
        for line in f:
            separated = line.strip().split('\t')
            entities = re.findall(r'\[(.*?)\]', separated[0])
            labels = separated[1].split('|')
            twohop[separated[0] + '?'] = {'entity_set': [entities[0]], 'Label': labels}

    with jsonlines.open(f'data/twohop_{setting}_set.jsonl', mode='w') as w:
        for i, sample in enumerate(twohop):
            w.write({
                "question_id": i + 1,
                "question": sample,
                "entity_set": twohop[sample]["entity_set"],
                "Label": twohop[sample]["Label"],
            })
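
For reference, a minimal sketch of the structure this builds, assuming a single `kb.txt` line `Inescapable|starred_actors|Marisa Tomei` (the entities are illustrative):

```
KG_construct = {
    'Inescapable':  {'starred_actors':  ['Marisa Tomei']},
    'Marisa Tomei': {'~starred_actors': ['Inescapable']},
}
# KG_construct['Marisa Tomei']['~starred_actors'] walks the same edge backwards,
# which is what lets the hop-expansion code extend paths in either direction.
```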
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Knowledge Graph-Enhanced Large Language Models via Path Selection

The code in this repository accompanies the following paper:

>**Knowledge Graph-Enhanced Large Language Models via Path Selection,** [PDF](https://arxiv.org/pdf/2406.13862)

>Haochen Liu, Song Wang, Yaochen Zhu, Yushun Dong, Jundong Li,
>Findings of the Association for Computational Linguistics (ACL), 2024.


![Overview of KELP](KELP.png)

*Overview of KELP.*

## 1. Datasets

The datasets, requirements, and data preparation follow the setup of [KG-GPT](https://github.com/jiho283/KG-GPT/).

Download [FactKG](https://github.com/jiho283/FactKG) and [MetaQA](https://github.com/yuyuz/MetaQA) from their repositories.

Place the files `dbpedia_2015_undirected_light.pickle`, `factkg_train.pickle`, `factkg_dev.pickle`, and `factkg_test.pickle` under `./factkg`.

Place the file `kb.txt` and the folders `1-hop/vanilla` and `2-hop/vanilla` under `./metaQA`.

For data preprocessing, run:

    cd factkg
    python data/preprocess.py --factkg_train factkg_train.pickle --factkg_dev factkg_dev.pickle --factkg_test factkg_test.pickle
    cd ..

    cd metaQA
    python data/preprocess.py --setting train --kb kb.txt
    python data/preprocess.py --setting dev --kb kb.txt
    python data/preprocess.py --setting test --kb kb.txt
    cd ..

## 2. OpenAI API Key

Write your OpenAI API key into `factkg/openai_api_key.txt` and `metaQA/openai_api_key.txt` and save both files.

## 3. Building the Training Data

To build the training data from the original datasets, run:

    cd factkg
    python make_training_set.py --setting train
    python make_training_set.py --setting dev
    cd ..

    cd metaQA
    python make_training_set.py --setting train --hop 1
    python make_training_set.py --setting dev --hop 1
    python make_training_set.py --setting train --hop 2
    python make_training_set.py --setting dev --hop 2
    cd ..

## 4. Training

To train the LM encoders, run:

    cd factkg
    python pretrain_LM_encoder.py
    cd ..

    cd metaQA
    python pretrain_LM_encoder.py --hop 1
    python pretrain_LM_encoder.py --hop 2
    cd ..

## 5. Evaluation

To test the trained models, run:

    cd factkg
    python test.py --question_model <question model checkpoint> --question_model_relation_only <relation-only model checkpoint>
    cd ..

    cd metaQA
    python test-1hop.py --question_model <question model checkpoint>
    python test-2hop.py --question_model <question model checkpoint>
    cd ..

The `<...>` placeholders are the `.pth` checkpoints written to `model/` during training (e.g., `model/question_model_epoch59.pth`).

## 6. Acknowledgment

The datasets, requirements, and data preparation follow the setup of [KG-GPT](https://github.com/jiho283/KG-GPT/).

Thanks to the authors and developers!

## 7. Citation

If you find this work helpful to your research, please consider citing our paper:
```
@inproceedings{liu-etal-2024-knowledge-graph,
    title = "Knowledge Graph-Enhanced Large Language Models via Path Selection",
    author = "Liu, Haochen and Wang, Song and Zhu, Yaochen and Dong, Yushun and Li, Jundong",
    editor = "Ku, Lun-Wei and Martins, Andre and Srikumar, Vivek",
    booktitle = "Findings of the Association for Computational Linguistics: ACL 2024",
    month = aug,
    year = "2024",
    address = "Bangkok, Thailand",
    publisher = "Association for Computational Linguistics",
    pages = "6311--6321",
}
```
**Thanks for your interest in our work!**
--------------------------------------------------------------------------------
/metaQA/meta_2hop_prompts/verify_claim_with_evidence.txt:
--------------------------------------------------------------------------------
1 | Answer the questions based on evidence.
2 | Each evidence is in the form of [head, relation, tail] and it means "head's relation is tail.".
3 | If you think a question can have multiple answers, you must choose one and answer it.
4 | 5 | Examples) 6 | 7 | Claim A: which person wrote the films directed by [Yuriy Norshteyn]? 8 | Evidence set: [['Tale of Tales', 'directed_by', 'Yuriy Norshteyn'], ['Tale of Tales', 'written_by', 'Yuriy Norshteyn'], ['Hedgehog in the Fog', 'written_by', 'Sergei Kozlov'], ['Hedgehog in the Fog', 'directed_by', 'Yuriy Norshteyn']] 9 | Answer: 'Sergei Kozlov' 10 | 11 | 12 | Claim B: who are the writers of the movies directed by [Kresten Vestbjerg Andersen]? 13 | Evidence set: [['Ronal the Barbarian', 'written_by', 'Philip Einstein Lipski'], ['Ronal the Barbarian', 'directed_by', 'Kresten Vestbjerg Andersen'], ['Ronal the Barbarian', 'written_by', 'Kresten Vestbjerg Andersen'], ['Ronal the Barbarian', 'written_by', 'Thorbjørn Christoffersen']] 14 | Answer: 'Philip Einstein Lipski' 15 | 16 | 17 | Claim C: the movies directed by [David Atkins] were in which genres? 18 | Evidence set: [['Novocaine', 'has_genre', 'Comedy'], ['Novocaine', 'written_by', 'David Atkins'], ['Novocaine', 'directed_by', 'David Atkins']] 19 | Answer: 'Comedy' 20 | 21 | 22 | Claim D: the films written by [Scott Lobdell] were released in which years? 23 | Evidence set: [['Man of the House', 'release_year', '2005'], ['Man of the House', 'written_by', 'Scott Lobdell'], ['Man of the House', 'release_year', '1995'], ['Man of the House', 'has_genre', 'Comedy']] 24 | Answer: '1995' 25 | 26 | 27 | Claim E: what are the languages spoken in the films starred by [Terence Hill]? 28 | Evidence set: [['Terence Hill', 'starred_actors', 'They Call Me Trinity'], ['Terence Hill', 'starred_actors', 'They Call Me Renegade'], ['Terence Hill', 'starred_actors', 'Go for It'], ['They Call Me Renegade', 'in_language', 'Italian']] 29 | Answer: 'Italian' 30 | 31 | 32 | Claim F: who is listed as director of [Oliver Cooper] acted films? 33 | Evidence set: [['Project X', 'directed_by', 'Jonathan Kaplan'], ['Project X', 'starred_actors', 'Jonathan Daniel Brown'], ['Project X', 'starred_actors', 'Oliver Cooper'], ['Project X', 'directed_by', 'Nima Nourizadeh'], ['Project X', 'written_by', 'Matt Drake']] 34 | Answer: 'Nima Nourizadeh' 35 | 36 | 37 | Claim G: who co-starred with [Stephen Furst]? 38 | Evidence set: [['The Dream Team', 'starred_actors', 'Michael Keaton'], ['The Dream Team', 'starred_actors', 'Peter Boyle'], ['The Dream Team', 'starred_actors', 'Christopher Lloyd'], ['The Dream Team', 'directed_by', 'Howard Zieff'], ['The Dream Team', 'written_by', 'David Loucka'], ['The Dream Team', 'starred_actors', 'Stephen Furst']] 39 | Answer: 'Peter Boyle' 40 | 41 | 42 | Claim H: what types are the films written by [Polaris Banks]? 43 | Evidence set: [['Casey Jones', 'written_by', 'Polaris Banks'], ['Casey Jones', 'directed_by', 'Polaris Banks'], ['Casey Jones', 'has_genre', 'Short'], ['Casey Jones', 'release_year', '2011']] 44 | Answer: 'Short' 45 | 46 | 47 | Claim I: when were the movies written by [Phillip Borsos] released? 48 | Evidence set: [['The Grey Fox', 'directed_by', 'Phillip Borsos'], ['The Grey Fox', 'release_year', '1982'], ['One Magic Christmas', 'written_by', 'Phillip Borsos'], ['One Magic Christmas', 'release_year', '1985']] 49 | Answer: '1985' 50 | 51 | 52 | Claim J: which movies share the screenwriter with [Jesus Henry Christ]? 
53 | Evidence set: [['Jesus Henry Christ', 'directed_by', 'Dennis Lee'], ['Jesus Henry Christ', 'has_genre', 'Comedy'], ['Jesus Henry Christ', 'written_by', 'Dennis Lee'], ['Fireflies in the Garden', 'written_by', 'Dennis Lee']] 54 | Answer: 'Fireflies in the Garden' 55 | 56 | 57 | Claim K: which directors co-directed movies with [Ridley Scott]? 58 | Evidence set: [['The Counselor', 'directed_by', 'Ridley Scott'], ['Legend', 'directed_by', 'Ridley Scott'], ['Body of Lies', 'directed_by', 'Ridley Scott'], ['Blade Runner', 'directed_by', 'Ridley Scott'], ['Someone to Watch Over Me', 'directed_by', 'Ridley Scott'], ['Gladiator', 'directed_by', 'Ridley Scott'], ['Black Hawk Down', 'directed_by', 'Ridley Scott'], ['Black Rain', 'directed_by', 'Ridley Scott'], ['Robin Hood', 'directed_by', 'Ridley Scott'], ['Gladiator', 'directed_by', 'Rowdy Herrington']] 59 | Answer: 'Rowdy Herrington' 60 | 61 | 62 | Claim L: the scriptwriter of [First Monday in October] also wrote movies? 63 | Evidence set: [['First Monday in October', 'written_by', 'Robert E. Lee'], ['First Monday in October', 'written_by', 'Jerome Lawrence'], ['Inherit the Wind', 'written_by', 'Jerome Lawrence']] 64 | Answer: 'Inherit the Wind' 65 | 66 | 67 | Now let's verify the Claim based on the Evidence set. Please do not say there is no evdience, you must say the most related one entity from the evidence set. 68 | Claim: <<<>>> 69 | Evidence set: <<<>>> 70 | Answer: -------------------------------------------------------------------------------- /metaQA/pretrain_LM_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.data import DataLoader, Dataset 4 | from sentence_transformers import SentenceTransformer, losses 5 | from transformers import DistilBertModel, DistilBertTokenizer 6 | from tqdm import tqdm 7 | import json 8 | import jsonlines 9 | from sklearn.metrics.pairwise import cosine_similarity 10 | import os 11 | import argparse 12 | parser = argparse.ArgumentParser(description="Parsing input arguments.") 13 | parser.add_argument('--hop', type=int, required=True) 14 | args = parser.parse_args() 15 | hop=args.hop 16 | 17 | 18 | model_name = "distilbert-base-uncased" 19 | tokenizer = DistilBertTokenizer.from_pretrained(model_name) 20 | question_model = DistilBertModel.from_pretrained(model_name) 21 | 22 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 23 | question_model.to(device) 24 | 25 | class CustomDataset(Dataset): 26 | def __init__(self, data_path): 27 | self.data = self.load_data(data_path) 28 | 29 | def __len__(self): 30 | return len(self.data) 31 | 32 | def __getitem__(self, idx): 33 | sample = self.data[idx] 34 | if len(sample['pos_triplet'])==3: 35 | if sample['pos_triplet'][1][0]=='~': 36 | pos=sample['pos_triplet'][2]+' '+sample['pos_triplet'][1][1:]+' '+sample['pos_triplet'][0]+'.' 37 | else: 38 | pos=sample['pos_triplet'][0]+' '+sample['pos_triplet'][1]+' '+sample['pos_triplet'][2]+'.' 39 | elif len(sample['pos_triplet'])==5: 40 | if sample['pos_triplet'][1][0]=='~': 41 | pos=sample['pos_triplet'][2]+' '+sample['pos_triplet'][1][1:]+' '+sample['pos_triplet'][0]+', ' 42 | else: 43 | pos=sample['pos_triplet'][0]+' '+sample['pos_triplet'][1]+' '+sample['pos_triplet'][2]+', ' 44 | if sample['pos_triplet'][3][0]=='~': 45 | pos+=sample['pos_triplet'][4]+' '+sample['pos_triplet'][3][1:]+' '+sample['pos_triplet'][2]+'.' 
46 | else: 47 | pos+=sample['pos_triplet'][2]+' '+sample['pos_triplet'][3]+' '+sample['pos_triplet'][4]+'.' 48 | 49 | 50 | if len(sample['neg_triplet'])==3: 51 | if sample['neg_triplet'][1][0]=='~': 52 | neg=sample['neg_triplet'][2]+' '+sample['neg_triplet'][1][1:]+' '+sample['neg_triplet'][0]+'.' 53 | else: 54 | neg=sample['neg_triplet'][0]+' '+sample['neg_triplet'][1]+' '+sample['neg_triplet'][2]+'.' 55 | elif len(sample['neg_triplet'])==5: 56 | if sample['neg_triplet'][1][0]=='~': 57 | neg=sample['neg_triplet'][2]+' '+sample['neg_triplet'][1][1:]+' '+sample['neg_triplet'][0]+', ' 58 | else: 59 | neg=sample['neg_triplet'][0]+' '+sample['neg_triplet'][1]+' '+sample['neg_triplet'][2]+', ' 60 | if sample['neg_triplet'][3][0]=='~': 61 | neg+=sample['neg_triplet'][4]+' '+sample['neg_triplet'][3][1:]+' '+sample['neg_triplet'][2]+'.' 62 | else: 63 | neg+=sample['neg_triplet'][2]+' '+sample['neg_triplet'][3]+' '+sample['neg_triplet'][4]+'.' 64 | 65 | return { 66 | "question": sample["question"], 67 | "positive_question": pos, 68 | "negative_question": neg 69 | } 70 | 71 | def load_data(self, data_path): 72 | data = [] 73 | with jsonlines.open(data_path, 'r') as reader: 74 | for obj in reader: 75 | data.append(obj) 76 | return data 77 | 78 | 79 | if hop==1: 80 | data_path = "onehop.jsonl" 81 | data_path_dev = "onehop-dev.jsonl" 82 | if hop==2: 83 | data_path = "twohop.jsonl" 84 | data_path_dev = "twohop-dev.jsonl" 85 | 86 | dataset = CustomDataset(data_path) 87 | dataloader = DataLoader(dataset, batch_size=20, shuffle=True) 88 | optimizer_question = torch.optim.AdamW(question_model.parameters(), lr=2e-6) 89 | 90 | 91 | 92 | 93 | dataset_dev = CustomDataset(data_path_dev) 94 | dataloader_dev=DataLoader(dataset_dev, batch_size=5, shuffle=True) 95 | criterion=nn.CosineSimilarity() 96 | margin = 0.5 97 | 98 | num_epochs = 60 99 | for epoch in range(num_epochs): 100 | question_model.train() 101 | total_loss = 0 102 | total_loss_dev = 0 103 | for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}"): 104 | 105 | question_input = tokenizer(batch["question"], return_tensors="pt", padding=True, truncation=True).to(device) 106 | positive_input = tokenizer(batch["positive_question"], return_tensors="pt", padding=True, truncation=True).to(device) 107 | negative_input = tokenizer(batch["negative_question"], return_tensors="pt", padding=True, truncation=True).to(device) 108 | 109 | question_embedding = question_model(**question_input).last_hidden_state.mean(dim=1) 110 | positive_embedding = question_model(**positive_input).last_hidden_state.mean(dim=1) 111 | negative_embedding = question_model(**negative_input).last_hidden_state.mean(dim=1) 112 | similarity_scores_pos=criterion(question_embedding, positive_embedding).mean() 113 | similarity_scores_neg=criterion(question_embedding, negative_embedding).mean() 114 | 115 | loss = torch.mean(torch.relu(1 - similarity_scores_pos) + torch.relu(1 + similarity_scores_neg)) 116 | optimizer_question.zero_grad() 117 | 118 | loss.backward() 119 | optimizer_question.step() 120 | total_loss += loss.item() 121 | 122 | for batch in tqdm(dataloader_dev, desc=f"Epoch {epoch + 1}/{num_epochs}"): 123 | 124 | question_input = tokenizer(batch["question"], return_tensors="pt", padding=True, truncation=True).to(device) 125 | positive_input = tokenizer(batch["positive_question"], return_tensors="pt", padding=True, truncation=True).to(device) 126 | negative_input = tokenizer(batch["negative_question"], return_tensors="pt", padding=True, truncation=True).to(device) 127 | 128 
| question_embedding = question_model(**question_input).last_hidden_state.mean(dim=1) 129 | positive_embedding = question_model(**positive_input).last_hidden_state.mean(dim=1) 130 | negative_embedding = question_model(**negative_input).last_hidden_state.mean(dim=1) 131 | 132 | 133 | similarity_scores_pos=criterion(question_embedding, positive_embedding).mean() 134 | similarity_scores_neg=criterion(question_embedding, negative_embedding).mean() 135 | 136 | loss = torch.mean(torch.relu(1 - similarity_scores_pos) + torch.relu(1 + similarity_scores_neg)) 137 | total_loss_dev += loss.item() 138 | 139 | print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / len(dataloader)}, Loss_dev: {total_loss_dev / len(dataloader_dev)}") 140 | os.makedirs('./model', exist_ok=True) 141 | torch.save(question_model.state_dict(), 'model/question_model_epoch'+str(epoch)+'.pth') 142 | 143 | tokenizer.save_vocabulary('distilbert_tokenizer') 144 | -------------------------------------------------------------------------------- /factkg/pretrain_LM_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.data import DataLoader, Dataset 4 | from sentence_transformers import SentenceTransformer, losses 5 | from transformers import DistilBertModel, DistilBertTokenizer 6 | from tqdm import tqdm 7 | import json 8 | import jsonlines 9 | from sklearn.metrics.pairwise import cosine_similarity 10 | import os 11 | model_name = "distilbert-base-uncased" 12 | tokenizer = DistilBertTokenizer.from_pretrained(model_name) 13 | question_model = DistilBertModel.from_pretrained(model_name) 14 | question_model2 = DistilBertModel.from_pretrained(model_name) 15 | 16 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 17 | question_model.to(device) 18 | question_model2.to(device) 19 | 20 | class CustomDataset(Dataset): 21 | def __init__(self, data_path): 22 | self.data = self.load_data(data_path) 23 | 24 | def __len__(self): 25 | return len(self.data) 26 | 27 | def __getitem__(self, idx): 28 | sample = self.data[idx] 29 | if len(sample['pos_triplet'])==3: 30 | pos=sample['pos_triplet'][0]+' '+sample['pos_triplet'][1]+' '+sample['pos_triplet'][2]+'.' 31 | elif len(sample['pos_triplet'])==5: 32 | pos=sample['pos_triplet'][0]+' '+sample['pos_triplet'][1]+' '+sample['pos_triplet'][2]+', ' 33 | pos+=sample['pos_triplet'][2]+' '+sample['pos_triplet'][3]+' '+sample['pos_triplet'][4]+'.' 34 | if len(sample['neg_triplet'])==3: 35 | neg=sample['neg_triplet'][0]+' '+sample['neg_triplet'][1]+' '+sample['neg_triplet'][2]+'.' 36 | elif len(sample['neg_triplet'])==5: 37 | neg=sample['neg_triplet'][0]+' '+sample['neg_triplet'][1]+' '+sample['neg_triplet'][2]+', ' 38 | neg+=sample['neg_triplet'][2]+' '+sample['neg_triplet'][3]+' '+sample['neg_triplet'][4]+'.' 39 | 40 | if len(sample['pos_triplet'])==3: 41 | pos2=sample['pos_triplet'][1]+'.' 42 | elif len(sample['pos_triplet'])==5: 43 | pos2=sample['pos_triplet'][1]+', ' 44 | pos2+=sample['pos_triplet'][3]+'.' 45 | if len(sample['neg_triplet'])==3: 46 | neg2=sample['neg_triplet'][1]+'.' 47 | elif len(sample['neg_triplet'])==5: 48 | neg2=sample['neg_triplet'][1]+', ' 49 | neg2+=sample['neg_triplet'][3]+'.' 
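        # A minimal sketch of the verbalization built inline above
        # (`verbalize_path` is a hypothetical helper, not part of this file):
        #
        #     def verbalize_path(path):
        #         # (h, r, t) -> "h r t."    (h, r1, m, r2, t) -> "h r1 m, m r2 t."
        #         if len(path) == 3:
        #             return f"{path[0]} {path[1]} {path[2]}."
        #         h, r1, m, r2, t = path
        #         return f"{h} {r1} {m}, {m} {r2} {t}."
        #
        #     verbalize_path(('UVA', 'city', 'Charlottesville'))  # -> "UVA city Charlottesville."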
50 | 
51 |         return {
52 |             "question": sample["question"],
53 |             "positive_question": pos,
54 |             "negative_question": neg,
55 |             "positive_relation": pos2,
56 |             "negative_relation": neg2
57 |         }
58 | 
59 |     def load_data(self, data_path):
60 |         data = []
61 |         with jsonlines.open(data_path, 'r') as reader:
62 |             for obj in reader:
63 |                 if ((len(obj['pos_triplet'])==3 or len(obj['pos_triplet'])==5) and (len(obj['neg_triplet'])==3 or len(obj['neg_triplet'])==5)):
64 |                     data.append(obj)
65 |         return data
66 | 
67 | data_path = "output.jsonl"
68 | dataset = CustomDataset(data_path)
69 | dataloader = DataLoader(dataset, batch_size=40, shuffle=True)
70 | optimizer_question = torch.optim.AdamW(question_model.parameters(), lr=2e-6)
71 | optimizer_question2 = torch.optim.AdamW(question_model2.parameters(), lr=2e-6)  # the second optimizer must update the relation encoder (question_model2), not question_model again
72 | 
73 | 
74 | data_path_dev = "output_dev.jsonl"
75 | dataset_dev = CustomDataset(data_path_dev)
76 | dataloader_dev = DataLoader(dataset_dev, batch_size=5, shuffle=True)
77 | criterion = nn.CosineSimilarity()
78 | margin = 0.5  # unused
79 | 
80 | num_epochs = 60
81 | for epoch in range(num_epochs):
82 |     question_model.train()
83 |     total_loss = 0
84 |     total_loss_dev = 0
85 |     for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
86 | 
87 |         question_input = tokenizer(batch["question"], return_tensors="pt", padding=True, truncation=True).to(device)
88 |         positive_input = tokenizer(batch["positive_question"], return_tensors="pt", padding=True, truncation=True).to(device)
89 |         negative_input = tokenizer(batch["negative_question"], return_tensors="pt", padding=True, truncation=True).to(device)
90 |         positive_input2 = tokenizer(batch["positive_relation"], return_tensors="pt", padding=True, truncation=True).to(device)
91 |         negative_input2 = tokenizer(batch["negative_relation"], return_tensors="pt", padding=True, truncation=True).to(device)
92 | 
93 |         question_embedding = question_model(**question_input).last_hidden_state.mean(dim=1)
94 |         positive_embedding = question_model(**positive_input).last_hidden_state.mean(dim=1)
95 |         negative_embedding = question_model(**negative_input).last_hidden_state.mean(dim=1)
96 |         question_embedding2 = question_model2(**question_input).last_hidden_state.mean(dim=1)
97 |         positive_embedding2 = question_model2(**positive_input2).last_hidden_state.mean(dim=1)
98 |         negative_embedding2 = question_model2(**negative_input2).last_hidden_state.mean(dim=1)
99 | 
100 |         similarity_scores_pos = criterion(question_embedding, positive_embedding).mean()
101 |         similarity_scores_neg = criterion(question_embedding, negative_embedding).mean()
102 |         similarity_scores_pos2 = criterion(question_embedding2, positive_embedding2).mean()
103 |         similarity_scores_neg2 = criterion(question_embedding2, negative_embedding2).mean()
104 | 
105 |         loss = torch.mean(torch.relu(1 - similarity_scores_pos) + torch.relu(1 + similarity_scores_neg) + torch.relu(1 - similarity_scores_pos2) + torch.relu(1 + similarity_scores_neg2))
106 |         optimizer_question.zero_grad()
107 |         optimizer_question2.zero_grad()
108 |         loss.backward()
109 |         optimizer_question.step()
110 |         optimizer_question2.step()
111 |         total_loss += loss.item()
112 | 
113 |     for batch in tqdm(dataloader_dev, desc=f"Epoch {epoch + 1}/{num_epochs}"):  # dev pass; ideally run in eval mode under torch.no_grad()
114 | 
115 |         question_input = tokenizer(batch["question"], return_tensors="pt", padding=True, truncation=True).to(device)
116 |         positive_input = tokenizer(batch["positive_question"], return_tensors="pt", padding=True, truncation=True).to(device)
117 |         negative_input = tokenizer(batch["negative_question"], return_tensors="pt", padding=True, truncation=True).to(device)
118 |         positive_input2 = tokenizer(batch["positive_relation"], return_tensors="pt", padding=True, truncation=True).to(device)
119 |         negative_input2 = tokenizer(batch["negative_relation"], return_tensors="pt", padding=True, truncation=True).to(device)
120 | 
121 |         question_embedding = question_model(**question_input).last_hidden_state.mean(dim=1)
122 |         positive_embedding = question_model(**positive_input).last_hidden_state.mean(dim=1)
123 |         negative_embedding = question_model(**negative_input).last_hidden_state.mean(dim=1)
124 |         question_embedding2 = question_model2(**question_input).last_hidden_state.mean(dim=1)
125 |         positive_embedding2 = question_model2(**positive_input2).last_hidden_state.mean(dim=1)
126 |         negative_embedding2 = question_model2(**negative_input2).last_hidden_state.mean(dim=1)
127 | 
128 |         similarity_scores_pos = criterion(question_embedding, positive_embedding).mean()
129 |         similarity_scores_neg = criterion(question_embedding, negative_embedding).mean()
130 |         similarity_scores_pos2 = criterion(question_embedding2, positive_embedding2).mean()
131 |         similarity_scores_neg2 = criterion(question_embedding2, negative_embedding2).mean()
132 | 
133 |         loss = torch.mean(torch.relu(1 - similarity_scores_pos) + torch.relu(1 + similarity_scores_neg) + torch.relu(1 - similarity_scores_pos2) + torch.relu(1 + similarity_scores_neg2))
134 |         total_loss_dev += loss.item()
135 | 
136 |     print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / len(dataloader)}, Loss_dev: {total_loss_dev / len(dataloader_dev)}")
137 |     os.makedirs('./model', exist_ok=True)
138 |     torch.save(question_model.state_dict(), 'model/question_model_epoch'+str(epoch)+'.pth')  # checkpoint consumed by test.py --question_model
139 |     torch.save(question_model2.state_dict(), 'model/question_model2_epoch'+str(epoch)+'.pth')  # checkpoint consumed by test.py --question_model_relation_only
140 | 
141 | tokenizer.save_vocabulary('distilbert_tokenizer')
142 | 
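
For reference: both `pretrain_LM_encoder.py` scripts optimize the same hinge-style objective, pushing the cosine similarity between a question and its positive path toward +1 and its negative path toward -1. A minimal self-contained sketch of that loss (the batch and hidden sizes below are illustrative):

```
import torch
import torch.nn as nn

cos = nn.CosineSimilarity(dim=1)

def hinge_path_loss(q_emb, pos_emb, neg_emb):
    # q_emb, pos_emb, neg_emb: (batch, hidden) mean-pooled sentence embeddings
    sim_pos = cos(q_emb, pos_emb).mean()
    sim_neg = cos(q_emb, neg_emb).mean()
    # zero only when sim_pos reaches +1 and sim_neg reaches -1
    return torch.relu(1 - sim_pos) + torch.relu(1 + sim_neg)

q, p, n = torch.randn(4, 768), torch.randn(4, 768), torch.randn(4, 768)
print(hinge_path_loss(q, p, n))
```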
--------------------------------------------------------------------------------
/metaQA/make_training_set.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import json
3 | import random
4 | import openai
5 | import os
6 | from tqdm import tqdm
7 | import time
8 | import argparse
9 | parser = argparse.ArgumentParser(description="Parsing input arguments.")
10 | parser.add_argument('--setting', type=str, required=True, choices=['train', 'dev'])
11 | parser.add_argument('--hop', type=int, required=True, choices=[1, 2])
12 | args = parser.parse_args()
13 | setting = args.setting
14 | hop = args.hop
15 | 
16 | if setting == 'train':
17 |     if hop == 1:
18 |         question_data = f'data/onehop_{setting}_set.jsonl'
19 |         save_file = 'onehop.jsonl'
20 |     if hop == 2:
21 |         question_data = f'data/twohop_{setting}_set.jsonl'
22 |         save_file = 'twohop.jsonl'
23 | elif setting == 'dev':
24 |     if hop == 1:
25 |         question_data = f'data/onehop_{setting}_set.jsonl'
26 |         save_file = 'onehop-dev.jsonl'
27 |     if hop == 2:
28 |         question_data = f'data/twohop_{setting}_set.jsonl'
29 |         save_file = 'twohop-dev.jsonl'
30 | 
31 | openai.api_key = open('./openai_api_key.txt').read().strip()  # key file described in the README; the test scripts read the same file
32 | with open('data/metaqa_kg.pickle', 'rb') as f:
33 |     kg = pickle.load(f)
34 | questions_dict = {}
35 | entity_set_dict = {}
36 | label_set_dict = {}
37 | 
38 | 
39 | 
40 | with open(question_data, 'r') as f:
41 |     for line in f:
42 |         if not line:
43 |             continue
44 |         dataset = json.loads(line)
45 |         questions_dict[dataset["question_id"]] = dataset["question"]
46 |         entity_set_dict[dataset["question_id"]] = dataset["entity_set"]
47 |         label_set_dict[dataset["question_id"]] = dataset["Label"]
48 | 
49 | model_name = 'gpt-3.5-turbo-0613'
50 | max_tokens = 400
51 | temperature = 0.2
52 | top_p = 0.1
53 | 
54 | 
55 | def llm(prompt, max_tokens=max_tokens):
56 |     for _ in range(3):
57 |         try:
58 |             response = openai.ChatCompletion.create(
59 |                 model=model_name,
60 |                 messages=[
61 |                     {"role": "system", "content": "You are a helpful assistant."},
62 |                     {
63 |                         "role": "user",
64 |                         "content": prompt,
65 |                     },
66 |                 ],
67 |                 max_tokens=max_tokens,
68 |                 temperature=temperature,
69 |                 top_p=top_p,
70 |                 timeout=30
71 |             )
72 |             generated_text = response["choices"][0]["message"]["content"]
73 |             return generated_text
74 |         except Exception as e:
75 |             if _ == 2:
76 |                 print("[ERROR]", e)
77 |             time.sleep(5)
78 | 
79 | 
80 | def build_subgraph_rels_ranking(question, entity_set, knowledge_graph):
81 |     one_hop_neighbors = get_one_hop_neighbors_rels_ranking(question, entity_set, knowledge_graph)
82 |     if one_hop_neighbors == "No":
83 |         return "No"
84 |     two_hop_neighbors = get_two_hop_neighbors_rels_ranking(question, one_hop_neighbors, knowledge_graph)
85 |     if two_hop_neighbors == "No":
86 |         return "No"
87 |     if hop == 1:
88 |         return one_hop_neighbors
89 |     if hop == 2:
90 |         # required for --hop 2; otherwise the function falls through and returns None
91 |         return two_hop_neighbors
92 | 
93 | def get_one_hop_neighbors_rels_ranking(question, entity_set, knowledge_graph):
94 |     neighbors_rel = set()
95 |     for entity in entity_set:
96 |         if entity in knowledge_graph:
97 |             for relation, object in knowledge_graph[entity].items():
98 |                 if (str(entity), str(relation)) not in neighbors_rel and (',' not in str(entity) and (',' not in str(relation))):
99 |                     neighbors_rel.add((str(entity), str(relation)))
100 |     top_neighbors_rel = get_top_related_triplets_rels_ranking(question, neighbors_rel, 1)
101 |     if top_neighbors_rel == "No":
102 |         return "No"
103 |     neighbors_ent = [set() for _ in range(len(top_neighbors_rel))]
104 |     num_ent = 0
105 |     for i in range(len(top_neighbors_rel)):
106 |         if len(top_neighbors_rel[i]) != 2:
107 |             continue
108 |         if top_neighbors_rel[i][0] in knowledge_graph and top_neighbors_rel[i][1] in knowledge_graph[top_neighbors_rel[i][0]]:
109 |             for obj in knowledge_graph[top_neighbors_rel[i][0]][top_neighbors_rel[i][1]]:
110 |                 if (str(top_neighbors_rel[i][0]), str(top_neighbors_rel[i][1]), str(obj)) not in neighbors_ent[i] and (',' not in str(top_neighbors_rel[i][0])) and (',' not in str(top_neighbors_rel[i][1])) and (',' not in str(obj)):  # check this path's own head/relation, not stale loop variables
111 |                     neighbors_ent[i].add((str(top_neighbors_rel[i][0]), str(top_neighbors_rel[i][1]), str(obj)))
112 |                     num_ent += 1
113 |     top_neighbors_ent = get_top_related_triplets_ents_ranking(question, neighbors_ent, 1, num_ent)
114 |     if top_neighbors_ent == "No":
115 |         return "No"
116 |     if not top_neighbors_ent:
117 |         return "No"
118 | 
119 |     #if top_neighbors_ent:
120 |     #    a=random.choice(list(top_neighbors_ent))
121 |     #    if len(a)==3:
122 |     #        h0,r0,h_n=a
123 |     #        if h_n in knowledge_graph:
124 |     #            r_n=random.choice(list(knowledge_graph[h_n]))
125 |     #            t_n=random.choice(list(knowledge_graph[h_n][r_n]))
126 |     #            top_neighbors_ent.append((str(h_n), str(r_n), str(t_n)))
127 | 
128 | 
129 |     return top_neighbors_ent
130 | 
131 | def get_two_hop_neighbors_rels_ranking(question, top_1hop_triplets, knowledge_graph):
132 |     neighbors_rel = set()
133 |     for triplet in top_1hop_triplets:
134 |         try:
135 |             h, r, t = triplet
136 |         except Exception as e:
137 |             return "No"
138 |         if t in knowledge_graph:
139 |             for relation, object in knowledge_graph[t].items():
140 |                 if relation != '~'+r and r != '~'+relation:
141 |                     if (str(t), str(relation)) not in neighbors_rel and (',' not in str(t)) and (',' not in str(relation)):
142 |                         neighbors_rel.add((str(t), str(relation)))
143 |     top_neighbors_rel = get_top_related_triplets_rels_ranking(question, neighbors_rel, 2)
144 |     if top_neighbors_rel == "No":
145 |         return "No"
146 |     neighbors_ent = [set() for _ in range(len(top_neighbors_rel))]
147 |     num_ent = 0
148 |     for i in range(len(top_neighbors_rel)):
149 |         if top_neighbors_rel[i][0] in knowledge_graph and top_neighbors_rel[i][1] in knowledge_graph[top_neighbors_rel[i][0]]:
150 |             for obj in knowledge_graph[top_neighbors_rel[i][0]][top_neighbors_rel[i][1]]:
151 |                 if (str(top_neighbors_rel[i][0]), str(top_neighbors_rel[i][1]), str(obj)) not in neighbors_ent[i] and (',' not in str(top_neighbors_rel[i][0])) and (',' not in str(top_neighbors_rel[i][1])) and (',' not in str(obj)):
152 |                     neighbors_ent[i].add((str(top_neighbors_rel[i][0]), str(top_neighbors_rel[i][1]), str(obj)))
153 |                     num_ent += 1
154 |     top_neighbors_ent = get_top_related_triplets_ents_ranking(question, neighbors_ent, 1, num_ent)
155 |     if top_neighbors_ent == "No":
156 |         return "No"
157 |     for i in range(len(top_neighbors_ent)):
158 |         for triplet in top_1hop_triplets:
159 |             h, r, t = triplet
160 |             if t == top_neighbors_ent[i][0]:
161 |                 top_neighbors_ent[i] = (str(h), str(r), str(top_neighbors_ent[i][0]), str(top_neighbors_ent[i][1]), str(top_neighbors_ent[i][2]))
162 |                 break
163 |     if top_1hop_triplets:
164 |         h0, r0, h_n = random.choice(list(top_1hop_triplets))
165 |         if h_n in knowledge_graph:
166 |             r_n = random.choice(list(knowledge_graph[h_n]))
167 |             t_n = random.choice(list(knowledge_graph[h_n][r_n]))
168 |             top_neighbors_ent.append((str(h0), str(r0), str(h_n), str(r_n), str(t_n)))
169 | 
170 |     return top_neighbors_ent
171 | 
172 | def get_top_related_triplets_rels_ranking(question, triplets, hop):
173 |     if len(triplets) <= 5:
174 |         ranked_triplets = [triplet for triplet in triplets]
175 |         return ranked_triplets
176 | 
177 |     prompt = f"Each of these word sets shows an entity and one of its relations. Select the top 5 word sets which are most semantically related to a given question. You should list the selected word sets from rank 1 to rank 5. Your answer should be in the form of '(XXX,XXX);(XXX,XXX);(XXX,XXX);(XXX,XXX);(XXX,XXX)'. Question: {question}\nWord sets: "
178 |     for triplet in triplets:
179 |         prompt += f"{triplet};"
180 | 
181 |     response = llm(prompt, len(triplets)*50)
182 |     if response is None:
183 |         return "No"
184 |     try:
185 |         ranked_triplets = [tuple(word.strip('\'" ') for word in triplet.split('(')[1].split(')')[0].split(',')) for triplet in response.split(';') if triplet]
186 |     except Exception as e:
187 |         return "No"
188 |     return ranked_triplets
189 | 
190 | def get_top_related_triplets_ents_ranking(question, neighbors_ent, hop, num_ent):
191 |     ranked_triplets = []
192 |     if num_ent <= 5:
193 |         for triplets in neighbors_ent:
194 |             ranked_triplets += [triplet for triplet in triplets]
195 |         return ranked_triplets
196 |     num_sel = 0
197 |     i = 0
198 |     while(num_sel < 5 and i < len(neighbors_ent)):
199 |         if len(neighbors_ent[i]) == 0:
200 |             i += 1
201 |             continue
202 |         # (lines 199-203 were garbled in this copy; reconstructed from the surrounding logic)
203 |         if len(neighbors_ent[i]) <= 2 and 5-num_sel > 1:
204 |             ranked_triplets += [triplet for triplet in neighbors_ent[i]]
205 |             num_sel += 2
206 |             i += 1
207 |         else:
208 |             if 5-num_sel > 1:
209 |                 prompt = f"These word sets show the relations of some entities. Select the top 2 word sets which are most semantically related to a given question. You should list the selected word sets from rank 1 to rank 2. Your answer should be in the form of '(XXX,XXX,XXX);(XXX,XXX,XXX)'. Question: {question}\nWord sets: "
210 |             else:
211 |                 prompt = f"These word sets show the relations of some entities. Select the best word set which is most semantically related to a given question. 
Your answer should be in the form of '(XXX,XXX,XXX)'. Question: {question}\nWord sets: " 212 | 213 | for triplet in neighbors_ent[i]: 214 | prompt += f"{triplet};" 215 | 216 | response = llm(prompt,30+len(neighbors_ent[i])*50) 217 | if response is None: 218 | return "No" 219 | try: 220 | add_info=[tuple(word.strip('\'" ') for word in triplet.split('(')[1].split(')')[0].split(',')) for triplet in response.split(';') if triplet] 221 | except Exception as e: 222 | return "No" 223 | ranked_triplets += add_info 224 | num_sel+=len(add_info) 225 | i+=1 226 | return ranked_triplets 227 | 228 | 229 | def open_file(filepath): 230 | with open(filepath, 'r', encoding='utf-8') as infile: 231 | return infile.read() 232 | 233 | original_prompt=open_file('./meta_'+str(hop)+'hop_prompts/verify_claim_no_evidence.txt') 234 | context_prompt=open_file('./meta_'+str(hop)+'hop_prompts/verify_claim_with_evidence.txt') 235 | 236 | def original_query(question,ground_truth): 237 | prompt = original_prompt.replace('<<<>>>', question) 238 | 239 | result_original = llm(prompt) 240 | original_correct=False 241 | 242 | for lab in ground_truth: 243 | if lab.lower() in result_original.lower(): 244 | original_correct=True 245 | break 246 | 247 | return original_correct 248 | 249 | def context_query(question,ground_truth,triplet,already_pos,already_neg, pos_sam, neg_sam): 250 | context='[' 251 | if len(triplet)==3: 252 | if triplet[1][0]=='~': 253 | context+='['+triplet[2]+', '+triplet[1][1:]+', '+triplet[0]+']' 254 | else: 255 | context+='['+triplet[0]+', '+triplet[1]+', '+triplet[2]+']' 256 | elif len(triplet)==5: 257 | if triplet[1][0]=='~': 258 | context+='['+triplet[2]+', '+triplet[1][1:]+', '+triplet[0]+'], ' 259 | else: 260 | context+='['+triplet[0]+', '+triplet[1]+', '+triplet[2]+'], ' 261 | if triplet[3][0]=='~': 262 | context+='['+triplet[4]+', '+triplet[3][1:]+', '+triplet[2]+']' 263 | else: 264 | context+='['+triplet[2]+', '+triplet[3]+', '+triplet[4]+']' 265 | else: 266 | return 'triplet error', already_pos, already_neg, pos_sam, neg_sam 267 | context+=']' 268 | prompt = context_prompt.replace('<<<>>>', question).replace('<<<>>>', context) 269 | result = llm(prompt) 270 | context_answer="No correct answer" 271 | context_correct=False 272 | 273 | for lab in ground_truth: 274 | if lab.lower() in result.lower(): 275 | context_answer = lab.lower() 276 | context_correct=True 277 | break 278 | 279 | if already_pos==False and context_correct==True: 280 | already_pos=True 281 | pos_sam=triplet 282 | if already_neg==False and context_correct==False: 283 | already_neg=True 284 | neg_sam=triplet 285 | 286 | return context_correct, already_pos, already_neg, pos_sam, neg_sam 287 | 288 | 289 | 290 | dataset_len=len(questions_dict) 291 | a=1 292 | data_num=range(a,dataset_len+1,1) 293 | 294 | 295 | for ii in tqdm(data_num): 296 | if ii%5!=0: 297 | continue 298 | question = questions_dict[ii] 299 | entity_set = entity_set_dict[ii] 300 | ground_truth =label_set_dict[ii] 301 | 302 | original_correct=original_query(question,ground_truth) 303 | if original_correct==True: 304 | continue 305 | 306 | subgraph = build_subgraph_rels_ranking(question, entity_set, kg) 307 | if subgraph=="No": 308 | continue 309 | 310 | num_triplets_to_test = len(subgraph) 311 | if num_triplets_to_test==0: 312 | continue 313 | 314 | 315 | 316 | 317 | already_pos,already_neg=False, False 318 | pos_sam,neg_sam=False, False 319 | tl=list(range(num_triplets_to_test)) 320 | random.shuffle(tl) 321 | for i in tl: 322 | context_correct, already_pos, already_neg, 
pos_sam, neg_sam=context_query(question,ground_truth,subgraph[i],already_pos,already_neg, pos_sam, neg_sam) 323 | if context_correct=='triplet error': 324 | continue 325 | if pos_sam==True and neg_sam==True: 326 | break 327 | if already_pos == True and already_neg == True: 328 | result_dict = { 329 | 'question_id': ii, 330 | 'question': question, 331 | 'entity_set': entity_set, 332 | 'ground_truth': ground_truth, 333 | 'pos_triplet': pos_sam, 334 | 'neg_triplet': neg_sam 335 | } 336 | with open(save_file, 'a') as f: 337 | json_line = json.dumps(result_dict) 338 | f.write(json_line + '\n') -------------------------------------------------------------------------------- /metaQA/test-1hop.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import json 3 | import random 4 | import openai 5 | import os 6 | from tqdm import tqdm 7 | import time 8 | import torch 9 | import torch.nn as nn 10 | from torch.utils.data import DataLoader, Dataset 11 | from sentence_transformers import SentenceTransformer, losses 12 | from transformers import DistilBertModel, DistilBertTokenizer 13 | from tqdm import tqdm 14 | import json 15 | import jsonlines 16 | import torch 17 | import argparse 18 | parser = argparse.ArgumentParser(description="Parsing input arguments.") 19 | parser.add_argument('--question_model', type=str, required=True) 20 | args = parser.parse_args() 21 | question_model_path = args.question_model 22 | 23 | question_model = DistilBertModel.from_pretrained('distilbert-base-uncased') 24 | question_model.load_state_dict(torch.load(question_model_path)) 25 | tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased') 26 | 27 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 28 | question_model.to(device) 29 | 30 | def open_file(filepath): 31 | with open(filepath, 'r', encoding='utf-8') as infile: 32 | return infile.read() 33 | 34 | openai.api_key = open_file('./openai_api_key.txt') 35 | 36 | with open('data/metaqa_kg.pickle', 'rb') as f: 37 | kg = pickle.load(f) 38 | hop=1 39 | 40 | questions_dict = {} 41 | entity_set_dict = {} 42 | label_set_dict = {} 43 | if hop==1: 44 | question_data='data/onehop_test_set.jsonl' 45 | note='note.txt' 46 | error_file='one-hop-errors.json' 47 | if hop==2: 48 | question_data='data/twohop_test_set.jsonl' 49 | note='note.txt' 50 | error_file='two-hop-errors.json' 51 | with open(question_data, 'r') as f: 52 | for line in f: 53 | if not line: 54 | continue 55 | dataset = json.loads(line) 56 | questions_dict[dataset["question_id"]] = dataset["question"] 57 | entity_set_dict[dataset["question_id"]] = dataset["entity_set"] 58 | label_set_dict[dataset["question_id"]] = dataset["Label"] 59 | 60 | 61 | model_name = 'gpt-3.5-turbo-0613' 62 | max_tokens = 400 63 | temperature = 0.2 64 | top_p = 0.1 65 | 66 | 67 | def llm(prompt,max_tokens=max_tokens): 68 | for _ in range(3): 69 | try: 70 | response = openai.ChatCompletion.create( 71 | model=model_name, 72 | messages=[ 73 | {"role": "system", "content": "You are a helpful assistant."}, 74 | { 75 | "role": "user", 76 | "content": prompt, 77 | }, 78 | ], 79 | max_tokens=max_tokens, 80 | temperature=0.2, 81 | top_p = 0.1, 82 | timeout=30 83 | ) 84 | generated_text = response["choices"][0]["message"]["content"] 85 | return generated_text 86 | except Exception as e: 87 | if _==2: 88 | print("[ERROR]", e) 89 | time.sleep(5) 90 | 91 | 92 | def build_subgraph(entity_set, knowledge_graph): 93 | subgraph = set() 94 | for entity in entity_set: 95 | if 
entity in knowledge_graph: 96 | for relation, object in knowledge_graph[entity].items(): 97 | for obj in knowledge_graph[entity][relation]: 98 | subgraph.add((str(entity), str(relation), str(obj))) 99 | #for relation2, object2 in knowledge_graph[obj].items(): 100 | # for obj2 in knowledge_graph[obj][relation2]: 101 | # subgraph.add((str(entity), str(relation), str(obj), str(relation2), str(obj2))) 102 | 103 | return subgraph 104 | 105 | def open_file(filepath): 106 | with open(filepath, 'r', encoding='utf-8') as infile: 107 | return infile.read() 108 | 109 | def context_query(question_list,ground_truth_list,context_texts): 110 | 111 | for _ in range(3): 112 | prompt="""Answer the following questions.\n The context is the evidence of triplets may help your verifying.\n 113 | Each context contains triplets in the form of [head, relation, tail] and it means "head's relation is tail.". 114 | If you think a question can have multiple answers, you must choose one and answer it. Enter when you start answering the next question. Examples:\n 115 | """ 116 | prompt+=""" 117 | Context 1: [['Coming Home', 'has_genre', 'Drama'], ['Coming Home', 'has_genre', 'War'], ['Coming Home', 'has_tags', 'vietnam'], ['Coming Home', 'has_tags', 'hal ashby'], ] Question 1: what words describe [Coming Home]? 118 | Answer 1: 'hal ashby' 119 | Context 2: [['Chungking Express', 'starred_actors', 'Faye Wong'], ['Chinese Odyssey 2002', 'starred_actors', 'Faye Wong'], ] Question 2: what films does [Faye Wong] appear in? 120 | Answer 2: 'Chungking Express' 121 | Context 3: [['Code Unknown', 'has_tag', 'haneke'], ['The Piano Teacher', 'has_tag', 'haneke'], ['Funny Games', 'has_tag', 'haneke'], ['Time of the Wolf', 'has_tag', 'haneke'], ] Question 3: what films are about [haneke]? 122 | Answer 3: 'The Piano Teacher' 123 | Context 4: [['Inescapable', 'directed_by', 'Ruba Nadda'], ['Inescapable', 'written_by', 'Ruba Nadda'], ['Inescapable', 'starred_actors', 'Marisa Tomei'], ['Inescapable', 'starred_actors', 'Joshua Jackson'], ['Inescapable', 'starred_actors', 'Alexander Siddig'], ] Question 4: who acted in the movie [Inescapable]? 124 | Answer 4: 'Marisa Tomei' 125 | Context 5: [['Things to Come', 'directed_by', 'William Cameron Menzies'], ] Question 5: can you name a film directed by [William Cameron Menzies]? 126 | Answer 5: 'Things to Come' 127 | Context 6: [['Witness for the Prosecution', 'has_tags', 'bd-r'], ['Witness for the Prosecution', 'has_genre', 'Drama'], ['Witness for the Prosecution', 'has_tags', 'courtroom'], ] Question 6: what sort of movie is [Witness for the Prosecution]? 128 | Answer 6: 'Drama' 129 | Context 7: [['The Mouse That Roared', 'has_genre', 'Comedy'], ['The Mouse That Roared', 'has_tags', 'satirical'], ['The Mouse That Roared', 'has_tags', 'peter sellers'], ] Question 7: what type of film is [The Mouse That Roared]? 130 | Answer 7: 'Comedy' 131 | Context 8: [['Blackboards', 'in_language', 'Kurdish'], ['Blackboards', 'has_genre', 'War'], ['Blackboards', 'has_tags', 'samira makhmalbaf'], ] Question 8: what is the primary language in the film [Blackboards]? 132 | Answer 8: 'Kurdish' 133 | Context 9: [['The Truth of Lie', 'written_by', 'Roland Reber'], ['The Truth of Lie', 'directed_by', 'Roland Reber'], ['The Truth of Lie, 'has_genre', 'Thriller'], ] Question 9: who is the creator of the film script for [The Truth of Lie]? 134 | Answer 9: 'Roland Reber' 135 | Context 10: [['The Return of Doctor X', 'written_by', 'William J. 
Makin'], ['The Return of Doctor X', 'release_year', '1939'], ['The Return of Doctor X', 'has_tags', 'humphrey bogart'], ] Question 10: what was the release year of the film [The Return of Doctor X]? 136 | Answer 10: '1939' 137 | Context 11: [['Topper', 'has_tags', 'ghosts'], ['Topper', 'has_tags', 'norman z. mcleod'], ['Topper', 'has_genre', 'Comedy'], ] Question 11: which topics is movie [Topper] about? 138 | Answer 11: 'ghosts' 139 | Context 12: [['The Mouse on the Moon', 'has_tags', 'bd-r'], ['The Mouse on the Moon', 'has_genre', 'Comedy'], ['The Mouse on the Moon', 'written_by', 'Leonard Wibberley'], ] Question 12: describe the movie [The Mouse on the Moon] in a few words? 140 | Answer 12: 'bd-r'\n""" 141 | 142 | 143 | 144 | prompt+='Now answer the following '+str(len(question_list))+' questions in the same way of these examples.\n' 145 | j=0 146 | for question in question_list: 147 | j+=1 148 | prompt+='Context '+str(j)+f': {context_texts[0]}'+f' Question '+str(j)+f': {question}'+'\n' 149 | prompt+='Answer '+str(j)+': ' 150 | result = llm(prompt) 151 | context_answer_list=len(question_list)*["No correct answer"] 152 | context_correct_list=len(question_list)*[False] 153 | answer_list=result.split('\n') 154 | answer_list = [item for item in answer_list if item != ""] 155 | if len(answer_list)==len(question_list): 156 | break 157 | if len(answer_list)!=len(question_list): 158 | return "answer length error" 159 | for j in range(len(question_list)): 160 | for lab in ground_truth_list[j]: 161 | if lab.lower() in answer_list[j].lower(): 162 | context_answer_list[j] = lab.lower() 163 | context_correct_list[j]=True 164 | break 165 | return context_correct_list 166 | 167 | 168 | def find_top_k_elements(lst, k): 169 | indexed_lst = list(enumerate(lst)) 170 | sorted_lst = sorted(indexed_lst, key=lambda x: x[1], reverse=True) 171 | top_k_elements = sorted_lst[:k] 172 | top_k_values = [value for index, value in top_k_elements] 173 | top_k_indices = [index for index, value in top_k_elements] 174 | return top_k_values, top_k_indices 175 | 176 | 177 | 178 | 179 | 180 | 181 | criterion=nn.CosineSimilarity() 182 | 183 | 184 | dataset_len=len(questions_dict) 185 | a=1 186 | k1=5 187 | #k2=4 188 | data_num=range(a,dataset_len+1,1) 189 | total_correct=0 190 | 191 | question_id_list=[] 192 | question_list=[] 193 | entity_set_list=[] 194 | ground_truth_list=[] 195 | contexts_list=[] 196 | 197 | for ii in tqdm(data_num): 198 | question = questions_dict[ii] 199 | entity_set = entity_set_dict[ii] 200 | ground_truth =label_set_dict[ii] 201 | 202 | subgraph=list(build_subgraph(entity_set,kg)) 203 | triplets=[] 204 | cossim=[] 205 | i=0 206 | pos_list=[] 207 | question_input = tokenizer([question], return_tensors="pt", padding=True, truncation=True).to(device) 208 | question_embedding = question_model(**question_input).last_hidden_state.mean(dim=1) 209 | for triplet in subgraph: 210 | if len(triplet)==3: 211 | if triplet[1][0]=='~': 212 | pos=triplet[2]+' '+triplet[1][1:]+' '+triplet[0]+'.' 213 | else: 214 | pos=triplet[0]+' '+triplet[1]+' '+triplet[2]+'.' 215 | elif len(triplet)==5: 216 | if triplet[1][0]=='~': 217 | pos=triplet[2]+' '+triplet[1][1:]+' '+triplet[0]+', ' 218 | else: 219 | pos=triplet[0]+' '+triplet[1]+' '+triplet[2]+', ' 220 | if triplet[3][0]=='~': 221 | pos+=triplet[4]+' '+triplet[3][1:]+' '+triplet[2]+'.' 222 | else: 223 | pos+=triplet[2]+' '+triplet[3]+' '+triplet[4]+'.' 
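# Note on the verbalization above: relations prefixed with '~' are stored as inverse
# edges in the MetaQA KG pickle, so the code strips the '~' and swaps head and tail.
# A hypothetical triplet ('Gladiator', '~starred_actors', 'Russell Crowe') would be
# rendered as "Russell Crowe starred_actors Gladiator."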
224 | pos_list.append(pos) 225 | 226 | i+=1 227 | if i>=300: 228 | 229 | positive_input = tokenizer(pos_list, return_tensors="pt", padding=True, truncation=True).to(device) 230 | positive_embedding = question_model(**positive_input).last_hidden_state.mean(dim=1) 231 | similarity_scores_pos=criterion(question_embedding, positive_embedding).tolist() 232 | 233 | cossim+=similarity_scores_pos 234 | 235 | pos_list=[] 236 | i=0 237 | if len(pos_list)>0: 238 | positive_input = tokenizer(pos_list, return_tensors="pt", padding=True, truncation=True).to(device) 239 | positive_embedding = question_model(**positive_input).last_hidden_state.mean(dim=1) 240 | similarity_scores_pos=criterion(question_embedding, positive_embedding).tolist() 241 | 242 | cossim+=similarity_scores_pos 243 | indexed_lst = list(enumerate(cossim)) 244 | sorted_lst = sorted(indexed_lst, key=lambda x: x[1], reverse=True) 245 | already_got_list=[] 246 | values=[] 247 | 248 | context_texts="[" 249 | for index, value in sorted_lst: 250 | already_got_list.append((subgraph[index][1],subgraph[index][2])) 251 | triplets.append(subgraph[index]) 252 | values.append(value) 253 | if len(subgraph[index])==3: 254 | if subgraph[index][1][0]=='~': 255 | context_texts+='['+subgraph[index][2]+', '+subgraph[index][1][1:]+', '+subgraph[index][0]+'], ' 256 | else: 257 | context_texts+='['+subgraph[index][0]+', '+subgraph[index][1]+', '+subgraph[index][2]+'], ' 258 | elif len(subgraph[index])==5: 259 | if subgraph[index][1][0]=='~': 260 | context_texts+='['+subgraph[index][2]+', '+subgraph[index][1][1:]+', '+subgraph[index][0]+'], ' 261 | else: 262 | context_texts+='['+subgraph[index][0]+', '+subgraph[index][1]+', '+subgraph[index][2]+'], ' 263 | if subgraph[index][3][0]=='~': 264 | context_texts+='['+subgraph[index][4]+', '+subgraph[index][3][1:]+', '+subgraph[index][2]+'], ' 265 | else: 266 | context_texts+='['+subgraph[index][2]+', '+subgraph[index][3]+', '+subgraph[index][4]+'], ' 267 | 268 | if len(already_got_list)>=k1: 269 | break 270 | context_texts+="]" 271 | question_id_list.append(ii) 272 | question_list.append(question) 273 | entity_set_list.append(entity_set) 274 | ground_truth_list.append(ground_truth) 275 | contexts_list.append(context_texts) 276 | if len(question_id_list)>=1: 277 | cor_list=context_query(question_list,ground_truth_list,contexts_list) 278 | if cor_list=="answer length error": 279 | for i in range(len(question_list)): 280 | result_dict={} 281 | result_dict['question_id']=question_id_list[i] 282 | result_dict['question']=question_list[i] 283 | result_dict['entity_set']=entity_set_list[i] 284 | result_dict['ground_truth']=ground_truth_list[i] 285 | with open(error_file, 'a') as f: 286 | json.dump(result_dict, f, indent=3) 287 | f.write('\n') 288 | question_id_list=[] 289 | question_list=[] 290 | entity_set_list=[] 291 | ground_truth_list=[] 292 | contexts_list=[] 293 | else: 294 | count_true = cor_list.count(True) 295 | total_correct+=count_true 296 | with open(note, 'w') as file: 297 | file.write(str(total_correct)+'/'+str(ii)+' ') 298 | file.write(str(total_correct/ii)) 299 | file.write('\n') 300 | question_id_list=[] 301 | question_list=[] 302 | entity_set_list=[] 303 | ground_truth_list=[] 304 | contexts_list=[] 305 | 306 | if len(question_id_list)>=1: 307 | cor_list=context_query(question_list,ground_truth_list,contexts_list) 308 | if cor_list=="answer length error": 309 | for i in range(len(question_list)): 310 | result_dict={} 311 | result_dict['question_id']=question_id_list[i] 312 | 
result_dict['question']=question_list[i] 313 | result_dict['entity_set']=entity_set_list[i] 314 | result_dict['ground_truth']=ground_truth_list[i] 315 | with open(error_file, 'a') as f: 316 | json.dump(result_dict, f, indent=3) 317 | f.write('\n') 318 | question_id_list=[] 319 | question_list=[] 320 | entity_set_list=[] 321 | ground_truth_list=[] 322 | contexts_list=[] 323 | else: 324 | count_true = cor_list.count(True) 325 | total_correct+=count_true 326 | with open(note, 'w') as file: 327 | file.write(str(total_correct)+'/'+str(ii)+' ') 328 | file.write(str(total_correct/ii)) 329 | file.write('\n') 330 | question_id_list=[] 331 | question_list=[] 332 | entity_set_list=[] 333 | ground_truth_list=[] 334 | contexts_list=[] 335 | 336 | print('Acc: ',total_correct,'/',ii,'=',total_correct/ii) -------------------------------------------------------------------------------- /metaQA/test-2hop.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import json 3 | import random 4 | import openai 5 | import os 6 | from tqdm import tqdm 7 | import time 8 | import torch 9 | import torch.nn as nn 10 | from torch.utils.data import DataLoader, Dataset 11 | from sentence_transformers import SentenceTransformer, losses 12 | from transformers import DistilBertModel, DistilBertTokenizer 13 | from tqdm import tqdm 14 | import json 15 | import jsonlines 16 | import torch 17 | import argparse 18 | parser = argparse.ArgumentParser(description="Parsing input arguments.") 19 | parser.add_argument('--question_model', type=str, required=True) 20 | args = parser.parse_args() 21 | question_model_path = args.question_model 22 | 23 | question_model = DistilBertModel.from_pretrained('distilbert-base-uncased') 24 | question_model.load_state_dict(torch.load(question_model_path)) 25 | tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased') 26 | 27 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 28 | question_model.to(device) 29 | def open_file(filepath): 30 | with open(filepath, 'r', encoding='utf-8') as infile: 31 | return infile.read() 32 | 33 | openai.api_key = open_file('./openai_api_key.txt') 34 | with open('data/metaqa_kg.pickle', 'rb') as f: 35 | kg = pickle.load(f) 36 | hop=2 37 | 38 | questions_dict = {} 39 | entity_set_dict = {} 40 | label_set_dict = {} 41 | if hop==1: 42 | question_data='data/onehop_test_set.jsonl' 43 | note='note.txt' 44 | error_file='one-hop-errors.json' 45 | if hop==2: 46 | question_data='data/twohop_test_set.jsonl' 47 | note='note.txt' 48 | error_file='two-hop-errors.json' 49 | with open(question_data, 'r') as f: 50 | for line in f: 51 | if not line: 52 | continue 53 | dataset = json.loads(line) 54 | questions_dict[dataset["question_id"]] = dataset["question"] 55 | entity_set_dict[dataset["question_id"]] = dataset["entity_set"] 56 | label_set_dict[dataset["question_id"]] = dataset["Label"] 57 | 58 | 59 | model_name = 'gpt-3.5-turbo-0613' 60 | max_tokens = 400 61 | temperature = 0.2 62 | top_p = 0.1 63 | 64 | 65 | def llm(prompt,max_tokens=max_tokens): 66 | for _ in range(3): 67 | try: 68 | response = openai.ChatCompletion.create( 69 | model=model_name, 70 | messages=[ 71 | {"role": "system", "content": "You are a helpful assistant."}, 72 | { 73 | "role": "user", 74 | "content": prompt, 75 | }, 76 | ], 77 | max_tokens=max_tokens, 78 | temperature=0.2, 79 | top_p = 0.1, 80 | timeout=30 81 | ) 82 | generated_text = response["choices"][0]["message"]["content"] 83 | return generated_text 84 | except 
Exception as e: 85 | if _==2: 86 | print("[ERROR]", e) 87 | time.sleep(5) 88 | 89 | def build_subgraph(entity_set, knowledge_graph): 90 | subgraph = set() 91 | for entity in entity_set: 92 | if entity in knowledge_graph: 93 | for relation, object in knowledge_graph[entity].items(): 94 | for obj in knowledge_graph[entity][relation]: 95 | subgraph.add((str(entity), str(relation), str(obj))) 96 | for relation2, object2 in knowledge_graph[obj].items(): 97 | for obj2 in knowledge_graph[obj][relation2]: 98 | subgraph.add((str(entity), str(relation), str(obj), str(relation2), str(obj2))) 99 | 100 | return subgraph 101 | 102 | def open_file(filepath): 103 | with open(filepath, 'r', encoding='utf-8') as infile: 104 | return infile.read() 105 | 106 | def context_query(question_list,ground_truth_list,context_texts): 107 | #print(question_list) 108 | #print(ground_truth_list) 109 | #print(context_texts) 110 | 111 | #prompt = open_file('./meta_2hop_prompts/verify_claim_with_evidence.txt').replace('<<<>>>', question_list[0]).replace('<<<>>>', context_texts[0]) 112 | 113 | #print(prompt) 114 | for _ in range(3): 115 | prompt="""Answer the following questions.\n The context is the evidence of triplets may help your verifying.\n 116 | Each context contains triplets in the form of [head, relation, tail] and it means "head's relation is tail.". 117 | If you think a question can have multiple answers, you must choose one and answer it. Enter when you start answering the next question. Examples:\n 118 | """ 119 | prompt+=""" 120 | Context 1: [['The Grey Fox', 'directed_by', 'Phillip Borsos'], ['The Grey Fox', 'release_year', '1982'], ['One Magic Christmas', 'written_by', 'Phillip Borsos'], ['One Magic Christmas', 'release_year', '1985'], ] Question 1: when were the movies written by [Phillip Borsos] released? 121 | Answer 1: '1985' 122 | Context 2: [['Jesus Henry Christ', 'directed_by', 'Dennis Lee'], ['Jesus Henry Christ', 'has_genre', 'Comedy'], ['Jesus Henry Christ', 'written_by', 'Dennis Lee'], ['Fireflies in the Garden', 'written_by', 'Dennis Lee'], ] Question 2: which movies share the screenwriter with [Jesus Henry Christ]? 123 | Answer 2: 'Fireflies in the Garden' 124 | Context 3: [['The Counselor', 'directed_by', 'Ridley Scott'], ['Legend', 'directed_by', 'Ridley Scott'], ['Body of Lies', 'directed_by', 'Ridley Scott'], ['Blade Runner', 'directed_by', 'Ridley Scott'], ['Someone to Watch Over Me', 'directed_by', 'Ridley Scott'], ['Gladiator', 'directed_by', 'Ridley Scott'], ['Black Hawk Down', 'directed_by', 'Ridley Scott'], ['Black Rain', 'directed_by', 'Ridley Scott'], ['Robin Hood', 'directed_by', 'Ridley Scott'], ['Gladiator', 'directed_by', 'Rowdy Herrington'], ] Question 3: which directors co-directed movies with [Ridley Scott]? 125 | Answer 3: 'Rowdy Herrington' 126 | Context 4: [['First Monday in October', 'written_by', 'Robert E. Lee'], ['First Monday in October', 'written_by', 'Jerome Lawrence'], ['Inherit the Wind', 'written_by', 'Jerome Lawrence'], ] Question 4: the scriptwriter of [First Monday in October] also wrote movies? 127 | Answer 4: 'Inherit the Wind' 128 | Context 5: [['Tale of Tales', 'directed_by', 'Yuriy Norshteyn'], ['Tale of Tales', 'written_by', 'Yuriy Norshteyn'], ['Hedgehog in the Fog', 'written_by', 'Sergei Kozlov'], ['Hedgehog in the Fog', 'directed_by', 'Yuriy Norshteyn'], ] Question 5: which person wrote the films directed by [Yuriy Norshteyn]? 
129 | Answer 5: 'Sergei Kozlov' 130 | Context 6: [['Ronal the Barbarian', 'written_by', 'Philip Einstein Lipski'], ['Ronal the Barbarian', 'directed_by', 'Kresten Vestbjerg Andersen'], ['Ronal the Barbarian', 'written_by', 'Kresten Vestbjerg Andersen'], ['Ronal the Barbarian', 'written_by', 'Thorbjørn Christoffersen'], ] Question 6: who are the writers of the movies directed by [Kresten Vestbjerg Andersen]? 131 | Answer 6: 'Philip Einstein Lipski' 132 | Context 7: [['Novocaine', 'has_genre', 'Comedy'], ['Novocaine', 'written_by', 'David Atkins'], ['Novocaine', 'directed_by', 'David Atkins'], ] Question 7: the movies directed by [David Atkins] were in which genres? 133 | Answer 7: 'Comedy' 134 | Context 8: [['Man of the House', 'release_year', '2005'], ['Man of the House', 'written_by', 'Scott Lobdell'], ['Man of the House', 'release_year', '1995'], ['Man of the House', 'has_genre', 'Comedy'], ] Question 8: the films written by [Scott Lobdell] were released in which years? 135 | Answer 8: '1995' 136 | Context 9: [['Terence Hill', 'starred_actors', 'They Call Me Trinity'], ['Terence Hill', 'starred_actors', 'They Call Me Renegade'], ['Terence Hill', 'starred_actors', 'Go for It'], ['They Call Me Renegade', 'in_language', 'Italian'], ] Question 9: what are the languages spoken in the films starred by [Terence Hill]? 137 | Answer 9: 'Italian' 138 | Context 10: [['Project X', 'directed_by', 'Jonathan Kaplan'], ['Project X', 'starred_actors', 'Jonathan Daniel Brown'], ['Project X', 'starred_actors', 'Oliver Cooper'], ['Project X', 'directed_by', 'Nima Nourizadeh'], ['Project X', 'written_by', 'Matt Drake'], ] Question 10: who is listed as director of [Oliver Cooper] acted films? 139 | Answer 10: 'Nima Nourizadeh' 140 | Context 11: [['The Dream Team', 'starred_actors', 'Michael Keaton'], ['The Dream Team', 'starred_actors', 'Peter Boyle'], ['The Dream Team', 'starred_actors', 'Christopher Lloyd'], ['The Dream Team', 'directed_by', 'Howard Zieff'], ['The Dream Team', 'written_by', 'David Loucka'], ['The Dream Team', 'starred_actors', 'Stephen Furst'], ] Question 11: who co-starred with [Stephen Furst]? 141 | Answer 11: 'Peter Boyle' 142 | Context 12: [['Casey Jones', 'written_by', 'Polaris Banks'], ['Casey Jones', 'directed_by', 'Polaris Banks'], ['Casey Jones', 'has_genre', 'Short'], ['Casey Jones', 'release_year', '2011'], ] Question 12: what types are the films written by [Polaris Banks]? 
143 | Answer 12: 'Short'\n""" 144 | 145 | 146 | 147 | prompt+='Now answer the following '+str(len(question_list))+' questions in the same way of these examples.\n' 148 | j=0 149 | for question in question_list: 150 | j+=1 151 | prompt+='Context '+str(j)+f': {context_texts[0]}'+f' Question '+str(j)+f': {question}'+'\n' 152 | prompt+='Answer '+str(j)+': ' 153 | result = llm(prompt) 154 | context_answer_list=len(question_list)*["No correct answer"] 155 | context_correct_list=len(question_list)*[False] 156 | answer_list=result.split('\n') 157 | answer_list = [item for item in answer_list if item != ""] 158 | if len(answer_list)==len(question_list): 159 | break 160 | if len(answer_list)!=len(question_list): 161 | return "answer length error" 162 | for j in range(len(question_list)): 163 | for lab in ground_truth_list[j]: 164 | if lab.lower() in answer_list[j].lower(): 165 | context_answer_list[j] = lab.lower() 166 | context_correct_list[j]=True 167 | break 168 | return context_correct_list 169 | 170 | 171 | def find_top_k_elements(lst, k): 172 | indexed_lst = list(enumerate(lst)) 173 | sorted_lst = sorted(indexed_lst, key=lambda x: x[1], reverse=True) 174 | top_k_elements = sorted_lst[:k] 175 | top_k_values = [value for index, value in top_k_elements] 176 | top_k_indices = [index for index, value in top_k_elements] 177 | return top_k_values, top_k_indices 178 | 179 | 180 | 181 | 182 | 183 | 184 | criterion=nn.CosineSimilarity() 185 | 186 | 187 | dataset_len=len(questions_dict) 188 | a=1 189 | k1=4 190 | k2=4 191 | data_num=range(a,dataset_len+1,1) 192 | total_correct=0 193 | 194 | question_id_list=[] 195 | question_list=[] 196 | entity_set_list=[] 197 | ground_truth_list=[] 198 | contexts_list=[] 199 | 200 | for ii in tqdm(data_num): 201 | question = questions_dict[ii] 202 | entity_set = entity_set_dict[ii] 203 | ground_truth =label_set_dict[ii] 204 | 205 | subgraph=list(build_subgraph(entity_set,kg)) 206 | 207 | triplets=[] 208 | cossim=[] 209 | i=0 210 | pos_list=[] 211 | question_input = tokenizer([question], return_tensors="pt", padding=True, truncation=True).to(device) 212 | question_embedding = question_model(**question_input).last_hidden_state.mean(dim=1) 213 | for triplet in subgraph: 214 | if len(triplet)==3: 215 | if triplet[1][0]=='~': 216 | pos=triplet[2]+' '+triplet[1][1:]+' '+triplet[0]+'.' 217 | else: 218 | pos=triplet[0]+' '+triplet[1]+' '+triplet[2]+'.' 219 | elif len(triplet)==5: 220 | if triplet[1][0]=='~': 221 | pos=triplet[2]+' '+triplet[1][1:]+' '+triplet[0]+', ' 222 | else: 223 | pos=triplet[0]+' '+triplet[1]+' '+triplet[2]+', ' 224 | if triplet[3][0]=='~': 225 | pos+=triplet[4]+' '+triplet[3][1:]+' '+triplet[2]+'.' 226 | else: 227 | pos+=triplet[2]+' '+triplet[3]+' '+triplet[4]+'.' 
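# The verbalized 1-hop and 2-hop paths collected here are scored against the question
# by cosine similarity below; candidates are embedded in chunks of 300 so that tokenizer
# padding and (when a GPU is used) memory stay bounded, with a final partial chunk
# flushed after the loop.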
228 | pos_list.append(pos) 229 | 230 | i+=1 231 | if i>=300: 232 | 233 | positive_input = tokenizer(pos_list, return_tensors="pt", padding=True, truncation=True).to(device) 234 | positive_embedding = question_model(**positive_input).last_hidden_state.mean(dim=1) 235 | similarity_scores_pos=criterion(question_embedding, positive_embedding).tolist() 236 | 237 | cossim+=similarity_scores_pos 238 | 239 | pos_list=[] 240 | i=0 241 | if len(pos_list)>0: 242 | positive_input = tokenizer(pos_list, return_tensors="pt", padding=True, truncation=True).to(device) 243 | positive_embedding = question_model(**positive_input).last_hidden_state.mean(dim=1) 244 | similarity_scores_pos=criterion(question_embedding, positive_embedding).tolist() 245 | 246 | cossim+=similarity_scores_pos 247 | indexed_lst = list(enumerate(cossim)) 248 | sorted_lst = sorted(indexed_lst, key=lambda x: x[1], reverse=True) 249 | already_got=[0]*int(k1-1) 250 | already_got_list=[] 251 | values=[] 252 | 253 | context_texts="[" 254 | for index, value in sorted_lst: 255 | if (subgraph[index][1],subgraph[index][2]) not in already_got_list: 256 | already_got_list.append((subgraph[index][1],subgraph[index][2])) 257 | triplets.append(subgraph[index]) 258 | values.append(value) 259 | if len(subgraph[index])==3: 260 | if subgraph[index][1][0]=='~': 261 | context_texts+='['+subgraph[index][2]+', '+subgraph[index][1][1:]+', '+subgraph[index][0]+'], ' 262 | else: 263 | context_texts+='['+subgraph[index][0]+', '+subgraph[index][1]+', '+subgraph[index][2]+'], ' 264 | elif len(subgraph[index])==5: 265 | if subgraph[index][1][0]=='~': 266 | context_texts+='['+subgraph[index][2]+', '+subgraph[index][1][1:]+', '+subgraph[index][0]+'], ' 267 | else: 268 | context_texts+='['+subgraph[index][0]+', '+subgraph[index][1]+', '+subgraph[index][2]+'], ' 269 | if subgraph[index][3][0]=='~': 270 | context_texts+='['+subgraph[index][4]+', '+subgraph[index][3][1:]+', '+subgraph[index][2]+'], ' 271 | else: 272 | context_texts+='['+subgraph[index][2]+', '+subgraph[index][3]+', '+subgraph[index][4]+'], ' 273 | else: 274 | for i in range(len(already_got_list)): 275 | if (len(subgraph[index])==5) and (already_got_list[i]==(subgraph[index][1],subgraph[index][2])) and (already_got[i]<k2): 276 | already_got[i]+=1 277 | triplets.append(subgraph[index]) 278 | values.append(value) 279 | if subgraph[index][3][0]=='~': 280 | context_texts+='['+subgraph[index][4]+', '+subgraph[index][3][1:]+', '+subgraph[index][2]+'], ' 281 | else: 282 | context_texts+='['+subgraph[index][2]+', '+subgraph[index][3]+', '+subgraph[index][4]+'], ' 283 | if len(already_got_list)>=k1: 284 | break 285 | context_texts+="]" 286 | question_id_list.append(ii) 287 | question_list.append(question) 288 | entity_set_list.append(entity_set) 289 | ground_truth_list.append(ground_truth) 290 | contexts_list.append(context_texts) 291 | if len(question_id_list)>=1: 292 | cor_list=context_query(question_list,ground_truth_list,contexts_list) 293 | if cor_list=="answer length error": 294 | for i in range(len(question_list)): 295 | result_dict={} 296 | result_dict['question_id']=question_id_list[i] 297 | result_dict['question']=question_list[i] 298 | result_dict['entity_set']=entity_set_list[i] 299 | result_dict['ground_truth']=ground_truth_list[i] 300 | with open(error_file, 'a') as f: 301 | json.dump(result_dict, f, indent=3) 302 | f.write('\n') 303 | question_id_list=[] 304 | question_list=[] 305 | entity_set_list=[] 306 | ground_truth_list=[] 307 | contexts_list=[] 308 | else: 309 | count_true = cor_list.count(True) 310 | total_correct+=count_true 311 | with open(note, 'w') as file: 312 | file.write(str(total_correct)+'/'+str(ii)+' ') 313 | file.write(str(total_correct/ii)) 314 | file.write('\n') 315 | question_id_list=[] 316 | question_list=[] 317 | entity_set_list=[] 318 | ground_truth_list=[] 319 | contexts_list=[] 320 | 321 | if len(question_id_list)>=1: 322 |
cor_list=context_query(question_list,ground_truth_list,contexts_list) 323 | if cor_list=="answer length error": 324 | for i in range(len(question_list)): 325 | result_dict={} 326 | result_dict['question_id']=question_id_list[i] 327 | result_dict['question']=question_list[i] 328 | result_dict['entity_set']=entity_set_list[i] 329 | result_dict['ground_truth']=ground_truth_list[i] 330 | with open(error_file, 'a') as f: 331 | json.dump(result_dict, f, indent=3) 332 | f.write('\n') 333 | question_id_list=[] 334 | question_list=[] 335 | entity_set_list=[] 336 | ground_truth_list=[] 337 | contexts_list=[] 338 | else: 339 | count_true = cor_list.count(True) 340 | total_correct+=count_true 341 | with open(note, 'w') as file: 342 | file.write(str(total_correct)+'/'+str(ii)+' ') 343 | file.write(str(total_correct/ii)) 344 | file.write('\n') 345 | question_id_list=[] 346 | question_list=[] 347 | entity_set_list=[] 348 | ground_truth_list=[] 349 | contexts_list=[] 350 | 351 | print('Acc: ',total_correct,'/',ii,'=',total_correct/ii) -------------------------------------------------------------------------------- /factkg/test.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import json 3 | import random 4 | import openai 5 | import os 6 | from tqdm import tqdm 7 | import time 8 | import torch 9 | import torch.nn as nn 10 | from torch.utils.data import DataLoader, Dataset 11 | from sentence_transformers import SentenceTransformer, losses 12 | from transformers import DistilBertModel, DistilBertTokenizer 13 | from tqdm import tqdm 14 | import json 15 | import jsonlines 16 | import torch 17 | import argparse 18 | parser = argparse.ArgumentParser(description="Parsing input arguments.") 19 | parser.add_argument('--question_model', type=str, required=True) 20 | parser.add_argument('--question_model_relation_only', type=str, required=True) 21 | args = parser.parse_args() 22 | question_model_path = args.question_model 23 | question_model_path_relation_only = args.question_model_relation_only 24 | 25 | tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased') 26 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 27 | question_model = DistilBertModel.from_pretrained('distilbert-base-uncased') 28 | question_model.load_state_dict(torch.load(question_model_path)) 29 | question_model.to(device) 30 | question_model2 = DistilBertModel.from_pretrained('distilbert-base-uncased') 31 | question_model2.load_state_dict(torch.load(question_model_path_relation_only)) 32 | question_model2.to(device) 33 | def open_file(filepath): 34 | with open(filepath, 'r', encoding='utf-8') as infile: 35 | return infile.read() 36 | 37 | openai.api_key = open_file('./openai_api_key.txt') 38 | 39 | with open('dbpedia_2015_undirected_light.pickle', 'rb') as f: 40 | kg = pickle.load(f) 41 | 42 | 43 | questions_dict = {} 44 | entity_set_dict = {} 45 | label_set_dict = {} 46 | 47 | question_data='extracted_test_set.jsonl' 48 | note='note.txt' 49 | error_file='error.json' 50 | with open(question_data, 'r') as f: 51 | for line in f: 52 | if not line: 53 | continue 54 | dataset = json.loads(line) 55 | questions_dict[dataset["question_id"]] = dataset["question"] 56 | entity_set_dict[dataset["question_id"]] = dataset["entity_set"] 57 | label_set_dict[dataset["question_id"]] = dataset["Label"] 58 | 59 | 60 | 61 | model_name = 'gpt-3.5-turbo-0613' 62 | max_tokens = 400 63 | temperature = 0.2 64 | top_p = 0.1 65 | 66 | 67 | def llm(prompt,max_tokens=max_tokens): 
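# Retry wrapper around the OpenAI ChatCompletion endpoint: up to three attempts with a
# 5-second sleep between failures; the exception is printed only on the last attempt and
# the function implicitly returns None if every attempt fails, so callers should
# tolerate a None result. Illustrative use: answer = llm('Say hello')  # -> text or None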
68 | for _ in range(3): 69 | try: 70 | response = openai.ChatCompletion.create( 71 | model=model_name, 72 | messages=[ 73 | {"role": "system", "content": "You are a helpful assistant."}, 74 | { 75 | "role": "user", 76 | "content": prompt, 77 | }, 78 | ], 79 | max_tokens=max_tokens, 80 | temperature=0.2, 81 | top_p = 0.1, 82 | timeout=30 83 | ) 84 | generated_text = response["choices"][0]["message"]["content"] 85 | return generated_text 86 | except Exception as e: 87 | if _==2: 88 | print("[ERROR]", e) 89 | time.sleep(5) 90 | 91 | 92 | 93 | def open_file(filepath): 94 | with open(filepath, 'r', encoding='utf-8') as infile: 95 | return infile.read() 96 | 97 | def context_query(question_list,ground_truth_list,context_texts): 98 | #print(question_list) 99 | #print(ground_truth_list) 100 | #print(context_texts) 101 | 102 | 103 | 104 | prompt="""Verify the following claims.\n The context is the evidence of triplets may help your verifying.\n 105 | Each context contains triplets in the form of [head, relation, tail] and it means "head's relation is tail.". 106 | Choose one of {True, False}, and give me the one-sentence evidence. Examples:\n 107 | """ 108 | 109 | prompt+=""" 110 | Context 1: [['Ahamad_Kadhim', 'clubs', "Al-Zawra'a SC"], ] Claim 1: Ahmad Kadhim Assad's club is Al-Zawra'a SC. 111 | Answer 1: True, based on the evidence set, Ahmad Kadhim Assad's club is Al-Zawra'a SC. 112 | Context 2: [['Bananaman', 'firstAired', '"1983-10-03"'], ['Bananaman', 'starring', 'Tim_Brooke-Taylor'], ] Claim 2: Yeah! I know that a TV show, which starred Tim Brooke-Taylor, first aired on 3rd October 1983! 113 | Answer 2: True, the claim is supported by the evidence since Bananaman refers to the TV show. 114 | Context 3: [['Jamie_Lawrence', 'composer', 'Death_on_a_Factory_Farm'], ['Death_on_a_Factory_Farm', 'director', 'Sarah_Teale'], ] Claim 3: Really? Jamie Lawrence is the music composer of the 83 minute 'Death on a Factory Farm' film, directed by Sarah Teale! 115 | Answer 3: False, there is no evidence for the 83 minute length. 116 | Context 4: [[], ] Claim 4: Do you know Milan Hodža? he had a religion. 117 | Answer 4: False, there is no evidence that Milan had a religion. 118 | Context 5: [[], ] Claim 5: No, but the leader of the United States is not Olena Serdiuk. 119 | Answer 5: True, based on the evidence set, there is no information that the leader of the United States is Olena Serdiuk. 120 | Context 6: [['Brandon_Carter', 'almaMater', 'University_of_Cambridge'], ['Brandon_Carter', 'birthPlace', 'England'], ['University_of_Cambridge', 'viceChancellor', 'Leszek_Borysiewicz'], ] Claim 6: Brandon Carter was born in England and graduated from the University of Cambridge where the current Chancellor is Leszek Borysiewicz. 121 | Answer 6: True, everything of the claim is supported by the evidence set. 122 | Context 7: [['Unpublished_Story', 'director', 'Harold_French'], ['Unpublished_Story', 'cinematography', 'Bernard_Knowles'], ] Claim 7: 'A film' was produced by Anatole de Grunwald, directed by Harold French, with cinematography done by Bernard Knowles. 123 | Answer 7: False, there is no information about the producer of 'Unpublished_Story'. 124 | Context 8: [['200_Public_Square', 'location', 'Cleveland'], ['200_Public_Square', 'floorCount', '"45"'], ['Cleveland', 'country', 'United_States'], ] Claim 8: Yes, with a floor count of 45, 200 Public Square is located in Cleveland in the United States. 
125 | Answer 8: True, everything of the claim is supported by the evidence set.\n""" 126 | #Context 9: [['Bananaman', 'starring', 'Bill_Oddie'], ['Bananaman', 'network', 'Broadcasting_House'], ['Bananaman', 'locationCity', 'Broadcasting_House'], ] Claim 9: Bananaman the TV series starred by a person was shown on the company and the company headquarters is called Broadcasting House. 127 | #Answer 9: True, everything of the claim is supported by the evidence set. 128 | #Context 10: [['Azerbaijan', 'leaderName', 'Artur_Rasizade'], ["Baku_Turkish_Martyrs'_Memorial", 'designer', '"Hüseyin Bütüner and Hilmi Güner"'], ["Baku_Turkish_Martyrs'_Memorial", 'location', 'Azerbaijan'], ] Claim 10: The place, designed by Huseyin Butuner and Hilmi Guner, is located in a country, where the leader is Artur Rasizade. 129 | #Answer 10: True, everything of the claim is supported by the evidence set. 130 | #Context 11: [['AIDAstella', 'shipBuilder', 'Meyer_Werft'], ['AIDAstella', 'shipOperator', 'AIDA_Cruises'], ] Claim 11: AIDA Cruise line operated the ship which was built by Meyer Werft in Townsend, Poulshot, Wiltshire. 131 | #Answer 11: False, there is no evidence for Townsend, Poulshot, Wiltshire. 132 | #Context 12: [[], ] Claim 12: An academic journal with code IJPHDE is also Acta Math. Hungar. 133 | #Answer 12: False, there is no evidence that the academic journal is also Acta Math. Hungar. 134 | 135 | 136 | prompt+='Now verify the following '+str(len(question_list))+' claims in the same way of these examples.\n' 137 | j=0 138 | for question in question_list: 139 | j+=1 140 | prompt+='Context '+str(j)+f': {context_texts[0]}'+f' Claim '+str(j)+f': {question}'+'\n' 141 | prompt+='Answer '+str(j)+': ' 142 | result = llm(prompt) 143 | context_correct_list=len(question_list)*[False] 144 | 145 | if 'false' in result.lower(): 146 | prompt="""Verify the claim. Is this claim True? or False? 147 | Choose one of {True, False}. If you are unsure, please choose the option you think is most likely.""" 148 | 149 | prompt+=""" 150 | Claim: Ahmad Kadhim Assad's club is Al-Zawra'a SC. 151 | Answer: True. 152 | Claim: Yeah! I know that a TV show, which starred Tim Brooke-Taylor, first aired on 3rd October 1983! 153 | Answer: True. 154 | Claim: Really? Jamie Lawrence is the music composer of the 83 minute 'Death on a Factory Farm' film, directed by Sarah Teale! 155 | Answer: False. 156 | Claim: Do you know Milan Hodža? he had a religion. 157 | Answer: False. 158 | Claim: No, but the leader of the United States is not Olena Serdiuk. 159 | Answer: True. 160 | Claim: Brandon Carter was born in England and graduated from the University of Cambridge where the current Chancellor is Leszek Borysiewicz. 161 | Answer: True. 162 | Claim: 'A film' was produced by Anatole de Grunwald, directed by Harold French, with cinematography done by Bernard Knowles. 163 | Answer: False. 164 | Claim: Yes, with a floor count of 45, 200 Public Square is located in Cleveland in the United States. 165 | Answer: True. 166 | Claim: Bananaman the TV series starred by a person was shown on the company and the company headquarters is called Broadcasting House. 167 | Answer: True. 168 | Claim: The place, designed by Huseyin Butuner and Hilmi Guner, is located in a country, where the leader is Artur Rasizade. 169 | Answer: True. 170 | Claim: AIDA Cruise line operated the ship which was built by Meyer Werft in Townsend, Poulshot, Wiltshire. 171 | Answer: False. 172 | Claim: An academic journal with code IJPHDE is also Acta Math. Hungar. 173 | Answer: False. 
174 | Claim: """ 175 | prompt+=question_list[0]+'\n'+'Answer: ' 176 | result = llm(prompt) 177 | 178 | for j in range(len(question_list)): 179 | for lab in ground_truth_list[j]: 180 | if (lab and ('true' in result.lower())) or ((not lab) and ('false' in result.lower())): 181 | context_correct_list[j]=True 182 | break 183 | return context_correct_list 184 | 185 | 186 | def find_top_k_elements(lst, k): 187 | indexed_lst = list(enumerate(lst)) 188 | sorted_lst = sorted(indexed_lst, key=lambda x: x[1], reverse=True) 189 | top_k_elements = sorted_lst[:k] 190 | top_k_values = [value for index, value in top_k_elements] 191 | top_k_indices = [index for index, value in top_k_elements] 192 | return top_k_values, top_k_indices 193 | 194 | 195 | 196 | 197 | 198 | 199 | criterion=nn.CosineSimilarity() 200 | 201 | 202 | dataset_len=len(questions_dict) 203 | a=1 204 | k1=4 205 | k2=4 206 | data_num=range(a,dataset_len+1,1) 207 | total_correct=0 208 | 209 | question_id_list=[] 210 | question_list=[] 211 | entity_set_list=[] 212 | ground_truth_list=[] 213 | contexts_list=[] 214 | 215 | for ii in tqdm(data_num): 216 | question = questions_dict[ii] 217 | entity_set = entity_set_dict[ii] 218 | ground_truth =label_set_dict[ii] 219 | 220 | triplets=[] 221 | cossim=[] 222 | i=0 223 | pos_list=[] 224 | question_input = tokenizer([question], return_tensors="pt", padding=True, truncation=True).to(device) 225 | question_embedding = question_model(**question_input).last_hidden_state.mean(dim=1) 226 | question_embedding2 = question_model2(**question_input).last_hidden_state.mean(dim=1) 227 | 228 | subgraph = [] 229 | for entity in entity_set: 230 | if entity in kg: 231 | for relation, object in kg[entity].items(): 232 | if [str(relation)] not in subgraph: 233 | subgraph.append([str(relation)]) 234 | m=0 235 | for obj in kg[entity][relation]: 236 | for relation2, object2 in kg[obj].items(): 237 | if [str(relation),str(relation2)] not in subgraph: 238 | subgraph.append([str(relation),str(relation2)]) 239 | m+=1 240 | for triplet in subgraph: 241 | if len(triplet)==1: 242 | pos=triplet[0]+'.' 243 | elif len(triplet)==2: 244 | pos=triplet[0]+', '+triplet[1]+'.' 245 | pos_list.append(pos) 246 | 247 | i+=1 248 | if i>=300: 249 | 250 | positive_input = tokenizer(pos_list, return_tensors="pt", padding=True, truncation=True).to(device) 251 | positive_embedding = question_model(**positive_input).last_hidden_state.mean(dim=1) 252 | similarity_scores_pos=criterion(question_embedding, positive_embedding).tolist() 253 | 254 | cossim+=similarity_scores_pos 255 | 256 | pos_list=[] 257 | i=0 258 | if len(pos_list)>0: 259 | positive_input = tokenizer(pos_list, return_tensors="pt", padding=True, truncation=True).to(device) 260 | positive_embedding = question_model(**positive_input).last_hidden_state.mean(dim=1) 261 | similarity_scores_pos=criterion(question_embedding, positive_embedding).tolist() 262 | 263 | cossim+=similarity_scores_pos 264 | pos_list=[] 265 | indexed_lst = list(enumerate(cossim)) 266 | context_texts="[" 267 | top_k_values, top_k_indices=find_top_k_elements(indexed_lst,k1) 268 | for i in top_k_indices: 269 | if len(subgraph[i])==1: 270 | for entity in entity_set: 271 | if entity in kg: 272 | if subgraph[i][0] in kg[entity]: 273 | relation=subgraph[i][0] 274 | cossim=[] 275 | objlist=list(kg[entity][relation]) 276 | for obj in objlist: 277 | pos=str(entity)+' '+str(relation)+' '+str(obj)+'.' 
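# Second retrieval stage: the first encoder scored bare relation paths; here the second
# encoder (loaded from --question_model_relation_only) re-scores the concrete
# (entity, relation, object) instantiations of each top-ranked path, again embedding in
# chunks of 300, and only the top-k2 instantiations per path join the prompt context.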
278 | pos_list.append(pos) 279 | if len(pos_list)>=300: 280 | positive_input = tokenizer(pos_list, return_tensors="pt", padding=True, truncation=True).to(device) 281 | positive_embedding = question_model2(**positive_input).last_hidden_state.mean(dim=1) 282 | similarity_scores_pos=criterion(question_embedding2, positive_embedding).tolist() 283 | cossim+=similarity_scores_pos 284 | pos_list=[] 285 | if len(pos_list)>0: 286 | positive_input = tokenizer(pos_list, return_tensors="pt", padding=True, truncation=True).to(device) 287 | positive_embedding = question_model2(**positive_input).last_hidden_state.mean(dim=1) 288 | similarity_scores_pos=criterion(question_embedding2, positive_embedding).tolist() 289 | cossim+=similarity_scores_pos 290 | pos_list=[] 291 | indexed_lst2 = list(enumerate(cossim)) 292 | if len(indexed_lst2)<=k2: 293 | top_k_indices2=list(range(len(indexed_lst2))) 294 | else: 295 | top_k_values2, top_k_indices2=find_top_k_elements(indexed_lst2, k2) 296 | for j in top_k_indices2: 297 | obj=objlist[j] 298 | if str(relation)[0]!='~': 299 | context_texts+='['+str(entity)+', '+str(relation)+', '+str(obj)+'], ' 300 | else: 301 | context_texts+='['+str(obj)+', '+str(relation)[1:]+', '+str(entity)+'], ' 302 | elif len(subgraph[i])==2: 303 | save_triplets=[] 304 | for entity in entity_set: 305 | if entity in kg: 306 | if (subgraph[i][0] in kg[entity]): 307 | relation=subgraph[i][0] 308 | relation2=subgraph[i][1] 309 | cossim=[] 310 | objlist1=list(kg[entity][relation]) 311 | for obj in objlist1: 312 | if relation2 in kg[obj]: 313 | for obj2 in kg[obj][relation2]: 314 | pos=str(entity)+' '+str(relation)+' '+str(obj)+', '+str(obj)+' '+str(relation2)+' '+str(obj2)+'.' 315 | save_triplets+=[(str(entity),str(relation),str(obj),str(relation2),str(obj2))] 316 | pos_list.append(pos) 317 | if len(pos_list)>=300: 318 | positive_input = tokenizer(pos_list, return_tensors="pt", padding=True, truncation=True).to(device) 319 | positive_embedding = question_model2(**positive_input).last_hidden_state.mean(dim=1) 320 | similarity_scores_pos=criterion(question_embedding2, positive_embedding).tolist() 321 | cossim+=similarity_scores_pos 322 | pos_list=[] 323 | if len(pos_list)>0: 324 | positive_input = tokenizer(pos_list, return_tensors="pt", padding=True, truncation=True).to(device) 325 | positive_embedding = question_model2(**positive_input).last_hidden_state.mean(dim=1) 326 | similarity_scores_pos=criterion(question_embedding2, positive_embedding).tolist() 327 | cossim+=similarity_scores_pos 328 | pos_list=[] 329 | indexed_lst2 = list(enumerate(cossim)) 330 | if len(indexed_lst2)<=k2: 331 | top_k_indices2=list(range(len(indexed_lst2))) 332 | else: 333 | top_k_values2, top_k_indices2=find_top_k_elements(indexed_lst2, k2) 334 | for j in top_k_indices2: 335 | triplet=save_triplets[j] 336 | if str(triplet[1])[0]!='~': 337 | context_texts+='['+str(triplet[0])+', '+str(triplet[1])+', '+str(triplet[2])+'], ' 338 | else: 339 | context_texts+='['+str(triplet[2])+', '+str(triplet[1])[1:]+', '+str(triplet[0])+'], ' 340 | if str(triplet[3])[0]!='~': 341 | context_texts+='['+str(triplet[2])+', '+str(triplet[3])+', '+str(triplet[4])+'], ' 342 | else: 343 | context_texts+='['+str(triplet[4])+', '+str(triplet[3])[1:]+', '+str(triplet[2])+'], ' 344 | 345 | 346 | 347 | context_texts+="]" 348 | question_id_list.append(ii) 349 | question_list.append(question) 350 | entity_set_list.append(entity_set) 351 | ground_truth_list.append(ground_truth) 352 | contexts_list.append(context_texts) 353 | if len(question_id_list)>=1: 
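# With the flush threshold at 1, each claim is verified in its own LLM call; the
# list-based batching scaffolding is kept so the threshold could be raised to pack
# several claims into a single prompt.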
354 | cor_list=context_query(question_list,ground_truth_list,contexts_list) 355 | if cor_list=="answer length error": 356 | for i in range(len(question_list)): 357 | result_dict={} 358 | result_dict['question_id']=question_id_list[i] 359 | result_dict['question']=question_list[i] 360 | result_dict['entity_set']=entity_set_list[i] 361 | result_dict['ground_truth']=ground_truth_list[i] 362 | with open(error_file, 'a') as f: 363 | json.dump(result_dict, f, indent=3) 364 | f.write('\n') 365 | question_id_list=[] 366 | question_list=[] 367 | entity_set_list=[] 368 | ground_truth_list=[] 369 | contexts_list=[] 370 | else: 371 | count_true = cor_list.count(True) 372 | total_correct+=count_true 373 | with open(note, 'w') as file: 374 | file.write(str(total_correct)+'/'+str(ii)+' ') 375 | file.write(str(total_correct/ii)) 376 | file.write('\n') 377 | question_id_list=[] 378 | question_list=[] 379 | entity_set_list=[] 380 | ground_truth_list=[] 381 | contexts_list=[] 382 | 383 | if len(question_id_list)>=1: 384 | cor_list=context_query(question_list,ground_truth_list,contexts_list) 385 | if cor_list=="answer length error": 386 | for i in range(len(question_list)): 387 | result_dict={} 388 | result_dict['question_id']=question_id_list[i] 389 | result_dict['question']=question_list[i] 390 | result_dict['entity_set']=entity_set_list[i] 391 | result_dict['ground_truth']=ground_truth_list[i] 392 | with open(error_file, 'a') as f: 393 | json.dump(result_dict, f, indent=3) 394 | f.write('\n') 395 | question_id_list=[] 396 | question_list=[] 397 | entity_set_list=[] 398 | ground_truth_list=[] 399 | contexts_list=[] 400 | else: 401 | count_true = cor_list.count(True) 402 | total_correct+=count_true 403 | with open(note, 'w') as file: 404 | file.write(str(total_correct)+'/'+str(ii)+' ') 405 | file.write(str(total_correct/ii)) 406 | file.write('\n') 407 | question_id_list=[] 408 | question_list=[] 409 | entity_set_list=[] 410 | ground_truth_list=[] 411 | contexts_list=[] 412 | 413 | print('Acc: ',total_correct,'/',ii,'=',total_correct/ii) -------------------------------------------------------------------------------- /factkg/make_training_set.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import json 3 | import random 4 | import openai 5 | import os 6 | from tqdm import tqdm 7 | import time 8 | import argparse 9 | parser = argparse.ArgumentParser(description="Parsing input arguments.") 10 | parser.add_argument('--setting', type=str, required=True) 11 | args = parser.parse_args() 12 | setting = args.setting 13 | 14 | if setting=='train': 15 | extracted='extracted_train_set.jsonl' 16 | output_file='output.jsonl' 17 | elif setting=='dev': 18 | extracted='extracted_dev_set.jsonl' 19 | output_file='output_dev.jsonl' 20 | 21 | def open_file(filepath): 22 | with open(filepath, 'r', encoding='utf-8') as infile: 23 | return infile.read() 24 | 25 | openai.api_key = open_file('./openai_api_key.txt') 26 | 27 | with open('dbpedia_2015_undirected_light.pickle', 'rb') as f: 28 | kg = pickle.load(f) 29 | 30 | questions_dict = {} 31 | entity_set_dict = {} 32 | label_set_dict = {} 33 | with open(extracted, 'r') as f: 34 | for line in f: 35 | if not line: 36 | continue 37 | dataset = json.loads(line) 38 | questions_dict[dataset["question_id"]] = dataset["question"] 39 | entity_set_dict[dataset["question_id"]] = dataset["entity_set"] 40 | label_set_dict[dataset["question_id"]] = dataset["Label"] 41 | 42 | model_name = 'gpt-3.5-turbo-0613' 43 | max_tokens = 400 44 | 
temperature = 0.2 45 | top_p = 0.1 46 | 47 | 48 | def llm(prompt,max_tokens=max_tokens): 49 | for _ in range(3): 50 | try: 51 | response = openai.ChatCompletion.create( 52 | model=model_name, 53 | messages=[ 54 | {"role": "system", "content": "You are a helpful assistant."}, 55 | { 56 | "role": "user", 57 | "content": prompt, 58 | }, 59 | ], 60 | max_tokens=max_tokens, 61 | temperature=0.2, 62 | top_p = 0.1, 63 | timeout=30 64 | ) 65 | generated_text = response["choices"][0]["message"]["content"] 66 | return generated_text 67 | except Exception as e: 68 | if _==2: 69 | print("[ERROR]", e) 70 | time.sleep(5) 71 | 72 | 73 | def build_subgraph_rels_ranking(question, entity_set, knowledge_graph): 74 | one_hop_neighbors = get_one_hop_neighbors_rels_ranking(question, entity_set, knowledge_graph) 75 | if one_hop_neighbors=="No": 76 | return "No" 77 | two_hop_neighbors = get_two_hop_neighbors_rels_ranking(question, one_hop_neighbors, knowledge_graph) 78 | if two_hop_neighbors=="No": 79 | return "No" 80 | 81 | return one_hop_neighbors+two_hop_neighbors 82 | 83 | def get_one_hop_neighbors_rels_ranking(question, entity_set, knowledge_graph): 84 | neighbors_rel = set() 85 | for entity in entity_set: 86 | if entity in knowledge_graph: 87 | for relation, object in knowledge_graph[entity].items(): 88 | if (str(entity), str(relation)) not in neighbors_rel and (',' not in str(entity) and (',' not in str(relation))): 89 | neighbors_rel.add((str(entity), str(relation))) 90 | if len(neighbors_rel)>40: 91 | neighbors_rel=random.sample(neighbors_rel, 40) 92 | top_neighbors_rel=get_top_related_triplets_rels_ranking(question,neighbors_rel,1) 93 | if top_neighbors_rel=="No": 94 | return "No" 95 | if not top_neighbors_rel: 96 | return "No" 97 | neighbors_ent=[set() for _ in range(len(top_neighbors_rel))] 98 | num_ent=0 99 | for i in range(len(top_neighbors_rel)): 100 | if len(top_neighbors_rel[i])!=2: 101 | continue 102 | if top_neighbors_rel[i][0] in knowledge_graph and top_neighbors_rel[i][1] in knowledge_graph[top_neighbors_rel[i][0]]: 103 | for obj in knowledge_graph[top_neighbors_rel[i][0]][top_neighbors_rel[i][1]]: 104 | if (str(top_neighbors_rel[i][0]), str(top_neighbors_rel[i][1]), str(obj)) not in neighbors_ent[i] and (',' not in str(entity) and (',' not in str(relation))) and (',' not in str(obj)): 105 | neighbors_ent[i].add((str(top_neighbors_rel[i][0]), str(top_neighbors_rel[i][1]), str(obj))) 106 | num_ent+=1 107 | for i in range(len(top_neighbors_rel)): 108 | if len(neighbors_ent[i])>20: 109 | neighbors_ent[i]=random.sample(neighbors_ent[i],20) 110 | top_neighbors_ent=get_top_related_triplets_ents_ranking(question,neighbors_ent,1,num_ent) 111 | if top_neighbors_ent=="No": 112 | return "No" 113 | if not top_neighbors_ent: 114 | return "No" 115 | return top_neighbors_ent 116 | 117 | def get_two_hop_neighbors_rels_ranking(question, top_1hop_triplets, knowledge_graph): 118 | neighbors_rel = set() 119 | for triplet in top_1hop_triplets: 120 | try: 121 | h, r, t = triplet 122 | except Exception as e: 123 | return "No" 124 | if t in knowledge_graph: 125 | for relation, object in knowledge_graph[t].items(): 126 | if relation!='~'+r and r!='~'+relation: 127 | if (str(t), str(relation)) not in neighbors_rel and (',' not in str(t)) and (',' not in str(relation)): 128 | neighbors_rel.add((str(t), str(relation))) 129 | if len(neighbors_rel)>40: 130 | neighbors_rel=random.sample(neighbors_rel, 40) 131 | top_neighbors_rel=get_top_related_triplets_rels_ranking(question,neighbors_rel,2) 132 | if 
top_neighbors_rel=="No": 133 | return "No" 134 | neighbors_ent=[set() for _ in range(len(top_neighbors_rel))] 135 | num_ent=0 136 | 137 | for i in range(len(top_neighbors_rel)): 138 | #print(i) 139 | if len(top_neighbors_rel[i])!=2: 140 | continue 141 | if top_neighbors_rel[i][0] in knowledge_graph and top_neighbors_rel[i][1] in knowledge_graph[top_neighbors_rel[i][0]]: 142 | for obj in knowledge_graph[top_neighbors_rel[i][0]][top_neighbors_rel[i][1]]: 143 | if (str(top_neighbors_rel[i][0]), str(top_neighbors_rel[i][1]), str(obj)) not in neighbors_ent[i] and (',' not in str(top_neighbors_rel[i][0])) and (',' not in str(top_neighbors_rel[i][1])) and (',' not in str(obj)): 144 | neighbors_ent[i].add((str(top_neighbors_rel[i][0]), str(top_neighbors_rel[i][1]), str(obj))) 145 | num_ent+=1 146 | for i in range(len(top_neighbors_rel)): 147 | if len(neighbors_ent[i])>20: 148 | neighbors_ent[i]=random.sample(neighbors_ent[i],20) 149 | top_neighbors_ent=get_top_related_triplets_ents_ranking(question,neighbors_ent,1,num_ent) 150 | if top_neighbors_ent=="No": 151 | return "No" 152 | for i in range(len(top_neighbors_ent)): 153 | for triplet in top_1hop_triplets: 154 | h, r, t = triplet 155 | if len(top_neighbors_ent[i])==3 and t==top_neighbors_ent[i][0]: 156 | top_neighbors_ent[i]=(str(h),str(r),str(top_neighbors_ent[i][0]),str(top_neighbors_ent[i][1]),str(top_neighbors_ent[i][2])) 157 | break 158 | if top_1hop_triplets: 159 | h0,r0,h_n=random.choice(list(top_1hop_triplets)) 160 | if h_n in knowledge_graph: 161 | r_n=random.choice(list(knowledge_graph[h_n])) 162 | t_n=random.choice(list(knowledge_graph[h_n][r_n])) 163 | top_neighbors_ent.append((str(h0),str(r0),str(h_n), str(r_n), str(t_n))) 164 | if h0 in knowledge_graph: 165 | r_n1=random.choice(list(knowledge_graph[h0])) 166 | t_n1=random.choice(list(knowledge_graph[h0][r_n1])) 167 | top_neighbors_ent.append((str(h0),str(r_n1),str(t_n1))) 168 | return top_neighbors_ent 169 | 170 | def get_top_related_triplets_rels_ranking(question, triplets, hop): 171 | if len(triplets)<=5: 172 | ranked_triplets=[triplet for triplet in triplets] 173 | return ranked_triplets 174 | 175 | prompt = f"Each of these word sets shows an entity and one of its corresponding relation. Select the 5-top word sets which are most semantically related to the given sentence. You should list the selected word sets from rank 1 to rank 5. Your answer should be in the form of '(XXX,XXX);(XXX,XXX);(XXX,XXX);(XXX,XXX);(XXX,XXX)'. 
Sentence: {question}\nWord sets: " 176 | triplets_ran=triplets 177 | for _ in range(3): 178 | try: 179 | triplets_ran=random.sample(triplets_ran,len(triplets)-5) 180 | prompt1=prompt 181 | for triplet in triplets_ran: 182 | prompt1 += f"{triplet};" 183 | response = llm(prompt1,len(triplets_ran)*50) 184 | except Exception as e: 185 | continue 186 | if response is None: 187 | return "No" 188 | try: 189 | ranked_triplets = [tuple(word.strip('\'" ') for word in triplet.split('(')[1].split(')')[0].split(',')) for triplet in response.split(';') if triplet] 190 | except Exception as e: 191 | return "No" 192 | return ranked_triplets 193 | 194 | def get_top_related_triplets_ents_ranking(question, neighbors_ent, hop,num_ent): 195 | ranked_triplets=[] 196 | if num_ent<=5: 197 | for triplets in neighbors_ent: 198 | ranked_triplets+=[triplet for triplet in triplets] 199 | return ranked_triplets 200 | num_sel=0 201 | i=0 202 | while(num_sel<5 and i<len(neighbors_ent)): 203 | if len(neighbors_ent[i])==0: 204 | i+=1 205 | continue 206 | 207 | if len(neighbors_ent[i])<=2 and 5-num_sel>1: 208 | ranked_triplets+=[triplet for triplet in neighbors_ent[i]] 209 | num_sel+=2 210 | i+=1 211 | else: 212 | if 5-num_sel>1: 213 | prompt = f"These word sets show the relations of some entities. Select the top-2 word sets which are most semantically related to the given sentence. You should list the selected word sets from rank 1 to rank 2. Your answer should be in the form of '(XXX,XXX,XXX);(XXX,XXX,XXX)'. Sentence: {question}\nWord sets: " 214 | else: 215 | prompt = f"These word sets show the relations of some entities. Select the best word set which is most semantically related to the given sentence. Your answer should be in the form of '(XXX,XXX,XXX)'. Sentence: {question}\nWord sets: " 216 | 217 | for triplet in neighbors_ent[i]: 218 | prompt += f"{triplet};" 219 | 220 | response = llm(prompt,30+len(neighbors_ent[i])*50) 221 | if response is None: 222 | return "No" 223 | try: 224 | add_info=[tuple(word.strip('\'" ') for word in triplet.split('(')[1].split(')')[0].split(',')) for triplet in response.split(';') if triplet] 225 | except Exception as e: 226 | return "No" 227 | ranked_triplets += add_info 228 | num_sel+=len(add_info) 229 | i+=1 230 | return ranked_triplets 231 | 232 | 233 | 234 | def original_query(question_list,ground_truth_list): 235 | for _ in range(3): 236 | prompt="""Verify these claims. Is each claim True or False?\n 237 | Choose one of {True, False}. Start a new line for each answer. Examples:\n 238 | Claim: """ 239 | prompt+=""" 240 | Claim 1: Ahmad Kadhim Assad's club is Al-Zawra'a SC. 241 | Claim 2: Yeah! I know that a TV show, which starred Tim Brooke-Taylor, first aired on 3rd October 1983! 242 | Claim 3: Really? Jamie Lawrence is the music composer of the 83 minute 'Death on a Factory Farm' film, directed by Sarah Teale! 243 | Claim 4: 'A film' was produced by Anatole de Grunwald, directed by Harold French, with cinematography done by Bernard Knowles. 244 | Answer 1: True 245 | Answer 2: True 246 | Answer 3: False 247 | Answer 4: False 248 | Now answer the following questions. Do not repeat the question again and just respond like the examples.
249 | 250 | """ 251 | 252 | j=0 253 | for question in question_list: 254 | j+=1 255 | prompt+=f'Claim '+str(j)+f': {question}\n' 256 | 257 | result_original = llm(prompt) 258 | original_answer_list=len(question_list)*["No correct answer"] 259 | original_correct_list=len(question_list)*[False] 260 | answer_list=result_original.split('\n') 261 | if len(answer_list)==len(question_list): 262 | break 263 | if len(answer_list)!=len(question_list): 264 | return "answer length error","answer length error" 265 | for i in range(len(question_list)): 266 | for lab in ground_truth_list[i]: 267 | if (lab and ('true' in answer_list[i].lower())) or ((not lab) and ('false' in answer_list[i].lower())): 268 | original_answer_list[i] = lab 269 | original_correct_list[i]=True 270 | break 271 | q_choose_list = [i for i, value in enumerate(original_correct_list) if value==False] 272 | 273 | return q_choose_list, original_correct_list 274 | 275 | def context_query(question_list,ground_truth_list,subgraph_list,i,q_choose_list, already_pos_list, already_neg_list): 276 | question_list2=[question_list[j] for j in q_choose_list] 277 | ground_truth_list2=[ground_truth_list[j] for j in q_choose_list] 278 | subgraph_list2=[subgraph_list[j] for j in q_choose_list] 279 | context_texts=[] 280 | for subgraph in subgraph_list2: 281 | triplet=subgraph[i] 282 | context='[' 283 | if len(triplet)==3: 284 | if str(triplet[1])[0]!='~': 285 | context+='['+str(triplet[0])+', '+str(triplet[1])+', '+str(triplet[2])+'], ' 286 | else: 287 | context+='['+str(triplet[2])+', '+str(triplet[1])[1:]+', '+str(triplet[0])+'], ' 288 | elif len(triplet)==5: 289 | if str(triplet[1])[0]!='~': 290 | context+='['+str(triplet[0])+', '+str(triplet[1])+', '+str(triplet[2])+'], ' 291 | else: 292 | context+='['+str(triplet[2])+', '+str(triplet[1])[1:]+', '+str(triplet[0])+'], ' 293 | if str(triplet[3])[0]!='~': 294 | context+='['+str(triplet[2])+', '+str(triplet[3])+', '+str(triplet[4])+'], ' 295 | else: 296 | context+='['+str(triplet[4])+', '+str(triplet[3])[1:]+', '+str(triplet[2])+'], ' 297 | context+=']' 298 | context_texts.append(context) 299 | for _ in range(3): 300 | 301 | prompt="""Verify the following claims.\n The context is the evidence of triplets may help your verifying.\n 302 | Each context contains triplets in the form of [head, relation, tail] and it means "head's relation is tail.". 303 | Choose one of {True, False}, and give me the one-sentence evidence. Examples:\n 304 | """ 305 | 306 | prompt+=""" 307 | Context 1: [['Ahamad_Kadhim', 'clubs', "Al-Zawra'a SC"], ] Claim 1: Ahmad Kadhim Assad's club is Al-Zawra'a SC. 308 | Answer 1: True, based on the evidence set, Ahmad Kadhim Assad's club is Al-Zawra'a SC. 309 | Context 2: [['Bananaman', 'firstAired', '"1983-10-03"'], ['Bananaman', 'starring', 'Tim_Brooke-Taylor'], ] Claim 2: Yeah! I know that a TV show, which starred Tim Brooke-Taylor, first aired on 3rd October 1983! 310 | Answer 2: True, the claim is supported by the evidence since Bananaman refers to the TV show. 311 | Context 3: [['Jamie_Lawrence', 'composer', 'Death_on_a_Factory_Farm'], ['Death_on_a_Factory_Farm', 'director', 'Sarah_Teale'], ] Claim 3: Really? Jamie Lawrence is the music composer of the 83 minute 'Death on a Factory Farm' film, directed by Sarah Teale! 312 | Answer 3: False, there is no evidence for the 83 minute length. 313 | Context 4: [[], ] Claim 4: Do you know Milan Hodža? he had a religion. 314 | Answer 4: False, there is no evidence that Milan had a religion. 
315 | Context 5: [[], ] Claim 5: No, but the leader of the United States is not Olena Serdiuk. 316 | Answer 5: True, based on the evidence set, there is no information that the leader of the United States is Olena Serdiuk. 317 | Context 6: [['Brandon_Carter', 'almaMater', 'University_of_Cambridge'], ['Brandon_Carter', 'birthPlace', 'England'], ['University_of_Cambridge', 'viceChancellor', 'Leszek_Borysiewicz'], ] Claim 6: Brandon Carter was born in England and graduated from the University of Cambridge where the current Chancellor is Leszek Borysiewicz. 318 | Answer 6: True, everything of the claim is supported by the evidence set. 319 | Context 7: [['Unpublished_Story', 'director', 'Harold_French'], ['Unpublished_Story', 'cinematography', 'Bernard_Knowles'], ] Claim 7: 'A film' was produced by Anatole de Grunwald, directed by Harold French, with cinematography done by Bernard Knowles. 320 | Answer 7: False, there is no information about the producer of 'Unpublished_Story'. 321 | Context 8: [['200_Public_Square', 'location', 'Cleveland'], ['200_Public_Square', 'floorCount', '"45"'], ['Cleveland', 'country', 'United_States'], ] Claim 8: Yes, with a floor count of 45, 200 Public Square is located in Cleveland in the United States. 322 | Answer 8: True, everything of the claim is supported by the evidence set. 323 | Context 9: [['Bananaman', 'starring', 'Bill_Oddie'], ['Bananaman', 'network', 'Broadcasting_House'], ['Bananaman', 'locationCity', 'Broadcasting_House'], ] Claim 9: Bananaman the TV series starred by a person was shown on the company and the company headquarters is called Broadcasting House. 324 | Answer 9: True, everything of the claim is supported by the evidence set. 325 | Context 10: [['Azerbaijan', 'leaderName', 'Artur_Rasizade'], ["Baku_Turkish_Martyrs'_Memorial", 'designer', '"Hüseyin Bütüner and Hilmi Güner"'], ["Baku_Turkish_Martyrs'_Memorial", 'location', 'Azerbaijan'], ] Claim 10: The place, designed by Huseyin Butuner and Hilmi Guner, is located in a country, where the leader is Artur Rasizade. 326 | Answer 10: True, everything of the claim is supported by the evidence set. 327 | Context 11: [['AIDAstella', 'shipBuilder', 'Meyer_Werft'], ['AIDAstella', 'shipOperator', 'AIDA_Cruises'], ] Claim 11: AIDA Cruise line operated the ship which was built by Meyer Werft in Townsend, Poulshot, Wiltshire. 328 | Answer 11: False, there is no evidence for Townsend, Poulshot, Wiltshire. 329 | Context 12: [[], ] Claim 12: An academic journal with code IJPHDE is also Acta Math. Hungar. 330 | Answer 12: False, there is no evidence that the academic journal is also Acta Math. 
Hungar.\n""" 331 | 332 | 333 | prompt+='Now verify the following '+str(len(question_list2))+' claims in the same way as these examples.\n' 334 | j=0 335 | for question in question_list2: 336 | j+=1 337 | prompt+='Context '+str(j)+f': {context_texts[0]}'+f' Claim '+str(j)+f': {question}'+'\n' 338 | prompt+='Answer '+str(j)+': ' 339 | 340 | 341 | result = llm(prompt) 342 | context_answer_list=len(question_list2)*["No correct answer"] 343 | context_correct_list=len(question_list2)*[False] 344 | answer_list=result.split('\n') 345 | if len(answer_list)==len(question_list2): 346 | break 347 | if len(answer_list)!=len(question_list2): 348 | return "answer length error","answer length error", already_pos_list, already_neg_list, pos_sam_list, neg_sam_list 349 | for j in range(len(question_list2)): 350 | for lab in ground_truth_list2[j]: 351 | if (lab and ('true' in answer_list[j].lower())) or ((not lab) and ('false' in answer_list[j].lower())): 352 | context_answer_list[j] = lab 353 | context_correct_list[j]=True 354 | break 355 | for j in range(len(question_list2)): 356 | if already_pos_list[q_choose_list[j]]==False and context_correct_list[j]==True: 357 | already_pos_list[q_choose_list[j]]=True 358 | pos_sam_list[q_choose_list[j]]=[subgraph_list[q_choose_list[j]][i], context_texts[j]] 359 | for j in range(len(question_list2)): 360 | if already_neg_list[q_choose_list[j]]==False and context_correct_list[j]==False: 361 | already_neg_list[q_choose_list[j]]=True 362 | neg_sam_list[q_choose_list[j]]=[subgraph_list[q_choose_list[j]][i], context_texts[j]] 363 | q_choose_list = [v for j, v in enumerate(q_choose_list) if (already_pos_list[v] and already_neg_list[v])==False and i+1<len(subgraph_list[v])] 364 | return q_choose_list, context_correct_list, already_pos_list, already_neg_list, pos_sam_list, neg_sam_list 365 | 366 | 367 | dataset_len=len(questions_dict) 368 | a=1 369 | data_num=range(a,dataset_len+1,1) 370 | 371 | question_id_list=[] 372 | question_list=[] 373 | entity_set_list=[] 374 | ground_truth_list=[] 375 | subgraph_list=[] 376 | num_triplets_to_test_list=[] 377 | questions_num_in_list=0 378 | 379 | for ii in tqdm(data_num): 380 | question = questions_dict[ii] 381 | entity_set = entity_set_dict[ii] 382 | ground_truth =label_set_dict[ii] 383 | 384 | subgraph=build_subgraph_rels_ranking(question,entity_set,kg) 385 | if subgraph=="No": 386 | continue 387 | 388 | question_id_list.append(ii) 389 | question_list.append(question) 390 | entity_set_list.append(entity_set) 391 | ground_truth_list.append(ground_truth) 392 | subgraph_list.append(subgraph) 393 | num_triplets_to_test_list.append(len(subgraph)) 394 | questions_num_in_list+=1 395 | 396 | 397 | 398 | 399 | 400 | if len(question_id_list)>=1: 401 | q_choose_list, original_correct_list=original_query(question_list,ground_truth_list) 402 | if q_choose_list=="answer length error": 403 | question_id_list=[] 404 | question_list=[] 405 | entity_set_list=[] 406 | ground_truth_list=[] 407 | subgraph_list=[] 408 | num_triplets_to_test_list=[] 409 | questions_num_in_list=0 410 | continue 411 | 412 | already_pos_list,already_neg_list=[False]*len(question_list), [False]*len(question_list) 413 | pos_sam_list,neg_sam_list=[False]*len(question_list), [False]*len(question_list) 414 | for i in range(max(num_triplets_to_test_list)): 415 | if len(q_choose_list)>0: 416 | q_choose_list, context_correct_list, already_pos_list, already_neg_list, pos_sam_list, neg_sam_list=context_query(question_list,ground_truth_list,subgraph_list,i,q_choose_list,already_pos_list,already_neg_list) 417 | if q_choose_list=="answer length error": 418 | break 419 | with open(output_file, 'a') as f: 420 | for i in range(len(question_list)): 421 | if already_pos_list[i] and already_neg_list[i]: 422 | result_dict = { 423 | 'question_id': question_id_list[i], 424 | 'question': question_list[i], 425 | 'entity_set': entity_set_list[i], 426 | 'ground_truth': ground_truth_list[i], 427 | 'pos_triplet': pos_sam_list[i][0], 428 | 'pos_context': pos_sam_list[i][1], 429 | 'neg_triplet': neg_sam_list[i][0], 430 | 'neg_context': neg_sam_list[i][1] 431 | } 432 | json_line = json.dumps(result_dict) 433 | f.write(json_line + '\n') 434 | 435 | question_id_list=[] 436 | question_list=[] 437 | entity_set_list=[] 438 | ground_truth_list=[] 439 | subgraph_list=[] 440 | num_triplets_to_test_list=[] 441 | questions_num_in_list=0 --------------------------------------------------------------------------------