├── CoT
│   ├── README.md
│   ├── cot_io.py
│   ├── prompt_list.py
│   └── utils.py
├── Freebase
│   └── README.md
├── README.md
├── ToG
│   ├── README.md
│   ├── client.py
│   ├── freebase_func.py
│   ├── main_freebase.py
│   ├── main_wiki.py
│   ├── prompt_list.py
│   ├── server_urls.txt
│   ├── utils.py
│   └── wiki_func.py
├── Wikidata
│   ├── .gitignore
│   ├── README.md
│   ├── requirements.txt
│   ├── scripts
│   │   ├── build_index.sh
│   │   └── start_server.sh
│   └── simple_wikidata_db
│       ├── db_deploy
│       │   ├── __init__.py
│       │   ├── build_index.py
│       │   ├── client.py
│       │   ├── server.py
│       │   └── utils.py
│       ├── preprocess_dump.py
│       ├── preprocess_utils
│       │   ├── reader_process.py
│       │   ├── worker_process.py
│       │   └── writer_process.py
│       └── utils.py
├── assets
│   ├── application.png
│   ├── demo.png
│   ├── experiments.png
│   └── methods.png
├── data
│   ├── README.md
│   ├── SimpleQA.json
│   ├── T-REX.json
│   ├── WebQSP.json
│   ├── WebQuestions.json
│   ├── Zero_Shot_RE.json
│   ├── creak.json
│   ├── cwq.json
│   ├── grailqa.json
│   └── qald_10-en.json
├── eval
│   ├── README.md
│   ├── eval.py
│   └── utils.py
├── requirements.txt
└── tools
    ├── README.md
    ├── de_duplicate.py
    ├── jsonl2json.py
    └── split_dataset.py
/CoT/README.md:
--------------------------------------------------------------------------------
1 | # CoT
2 | 
3 | This folder contains the experiments corresponding to the CoT and IO prompts in the main experiment table.
4 | 
5 | Make sure you have installed all the requirements:
6 | ```sh
7 | tqdm
8 | openai
9 | ```
10 | 
11 | If you want to use a non-OpenAI model such as LLaMA, install [vllm](https://github.com/vllm-project/vllm) and start its OpenAI-compatible API service with the following command:
12 | 
13 | ```sh
14 | python -m vllm.entrypoints.openai.api_server \
15 |     --model meta-llama/Llama-2-70b-chat-hf \
16 |     --tensor-parallel-size 8 \
17 |     --max-num-batched-tokens 4096
18 | ```
19 | 
20 | For `Llama-2-70b-chat-hf`, we recommend running on 8 A100-40G GPUs.
21 | 
22 | ### How to run
23 | If you have already configured all the requirements, you can execute the following command:
24 | ```sh
25 | python cot_io.py \
26 | --dataset cwq \ # dataset you want to test, see ToG/data/README.md
27 | --prompt_methods cot \ # cot or io prompt
28 | --max_length 256 \
29 | --temperature 0 \ # We recommend a temperature of 0 for reproducible results.
30 | --LLM_type gpt-3.5-turbo \ # the LLM you choose
31 | --opeani_api_keys sk-xxxx \ # your own API key; if LLM_type == llama, this parameter has no effect.
32 | ```
33 | 
34 | ### How to eval
35 | After the run finishes and the result file is generated (such as `cot_cwq.jsonl`), proceed to `eval/README.md`.
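36 | 
37 | For reference, `cot_io.py` appends one JSON object per question to `cot_<dataset>.jsonl`, with the question and a `cot_result`/`io_result` field. Below is a minimal, illustrative sketch for peeking at the first few records; the file name assumes the `cwq` run shown above:
38 | 
39 | ```python
40 | import json
41 | 
42 | # Print the first three CoT results written by cot_io.py (the path is an assumption).
43 | with open("cot_cwq.jsonl", encoding="utf-8") as f:
44 |     for _, line in zip(range(3), f):
45 |         record = json.loads(line)
46 |         print(record["question"], "->", record["cot_result"])
47 | ```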
--------------------------------------------------------------------------------
/CoT/cot_io.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | from tqdm import tqdm
4 | from utils import *
5 | from prompt_list import *
6 | 
7 | if __name__ == '__main__':
8 |     parser = argparse.ArgumentParser()
9 |     parser.add_argument("--dataset", type=str,
10 |                         default="cwq", help="choose the dataset.")
11 |     parser.add_argument("--prompt_methods", type=str,
12 |                         default="cot", help="cot or io.")
13 |     parser.add_argument("--max_length", type=int,
14 |                         default=256, help="the max length of LLMs output.")
15 |     parser.add_argument("--temperature", type=float,
16 |                         default=0, help="the sampling temperature.")
17 |     parser.add_argument("--LLM_type", type=str,
18 |                         default="gpt-3.5-turbo", help="base LLM model.")
19 |     parser.add_argument("--opeani_api_keys", type=str,
20 |                         default="", help="if the LLM_type is gpt-3.5-turbo or gpt-4, you need to add your own openai api keys.")
21 |     args = parser.parse_args()
22 | 
23 |     with open("cot_{}.jsonl".format(args.dataset), 'a+', encoding="UTF-8") as out:
24 |         datas, question_string = prepare_dataset(args.dataset)
25 |         for i in tqdm(datas, total=len(datas)):
26 |             if args.prompt_methods == "cot":
27 |                 prompt = cot_prompt + "\n\nQ: " + i[question_string] + "\nA: "
28 |             else:
29 |                 prompt = io_prompt + "\n\nQ: " + i[question_string] + "\nA: "
30 |             results = run_llm(prompt, args.temperature, args.max_length, args.opeani_api_keys, args.LLM_type)
31 |             out.write(json.dumps({"question": i[question_string], "{}_result".format(args.prompt_methods): results})+'\n')
--------------------------------------------------------------------------------
/CoT/prompt_list.py:
--------------------------------------------------------------------------------
1 | cot_prompt = """Q: What state is home to the university that is represented in sports by George Washington Colonials men's basketball?
2 | A: First, the education institution has a sports team named George Washington Colonials men's basketball in is George Washington University , Second, George Washington University is in Washington D.C. The answer is {Washington, D.C.}.
3 | 
4 | Q: Who lists Pramatha Chaudhuri as an influence and wrote Jana Gana Mana?
5 | A: First, Bharoto Bhagyo Bidhata wrote Jana Gana Mana. Second, Bharoto Bhagyo Bidhata lists Pramatha Chaudhuri as an influence. The answer is {Bharoto Bhagyo Bidhata}.
6 | 
7 | Q: Who was the artist nominated for an award for You Drive Me Crazy?
8 | A: First, the artist nominated for an award for You Drive Me Crazy is Britney Spears. The answer is {Jason Allen Alexander}.
9 | 
10 | Q: What person born in Siegen influenced the work of Vincent Van Gogh?
11 | A: First, Peter Paul Rubens, Claude Monet and etc. influenced the work of Vincent Van Gogh. Second, Peter Paul Rubens born in Siegen. The answer is {Peter Paul Rubens}.
12 | 
13 | Q: What is the country close to Russia where Mikheil Saakashvii holds a government position?
14 | A: First, China, Norway, Finland, Estonia and Georgia is close to Russia. Second, Mikheil Saakashvii holds a government position at Georgia. The answer is {Georgia}.
15 | 
16 | Q: What drug did the actor who portrayed the character Urethane Wheels Guy overdosed on?
17 | A: First, Mitchell Lee Hedberg portrayed character Urethane Wheels Guy. Second, Mitchell Lee Hedberg overdose Heroin. The answer is {Heroin}."""
18 | 
19 | io_prompt = """Q: What state is home to the university that is represented in sports by George Washington Colonials men's basketball?
20 | A: {Washington, D.C.}.
21 | 
22 | Q: Who lists Pramatha Chaudhuri as an influence and wrote Jana Gana Mana?
23 | A: {Bharoto Bhagyo Bidhata}.
24 | 
25 | Q: Who was the artist nominated for an award for You Drive Me Crazy?
26 | A: {Jason Allen Alexander}.
27 | 
28 | Q: What person born in Siegen influenced the work of Vincent Van Gogh?
29 | A: {Peter Paul Rubens}.
30 | 
31 | Q: What is the country close to Russia where Mikheil Saakashvii holds a government position?
32 | A: {Georgia}.
33 | 
34 | Q: What drug did the actor who portrayed the character Urethane Wheels Guy overdosed on?
35 | A: {Heroin}."""
--------------------------------------------------------------------------------
/CoT/utils.py:
--------------------------------------------------------------------------------
1 | import openai
2 | import time
3 | import json
4 | 
5 | def run_llm(prompt, temperature, max_tokens, opeani_api_keys, engine="gpt-3.5-turbo"):
6 |     if "llama" in engine.lower():
7 |         openai.api_key = "EMPTY"
8 |         openai.api_base = "http://localhost:8000/v1"  # your local llama server port
9 |         engine = openai.Model.list()["data"][0]["id"]
10 |     else:
11 |         openai.api_key = opeani_api_keys
12 | 
13 |     messages = [{"role":"system","content":"You are an AI assistant that helps people find information."}]
14 |     message_prompt = {"role":"user","content":prompt}
15 |     messages.append(message_prompt)
16 |     print("start openai")
17 |     while True:
18 |         try:
19 |             response = openai.ChatCompletion.create(
20 |                 model=engine,
21 |                 messages=messages,
22 |                 temperature=temperature,
23 |                 max_tokens=max_tokens,
24 |                 frequency_penalty=0,
25 |                 presence_penalty=0)
26 |             result = response["choices"][0]['message']['content']
27 |             break
28 |         except:
29 |             print("openai error, retry")
30 |             time.sleep(2)
31 |     print("end openai")
32 |     return result
33 | 
34 | def prepare_dataset(dataset_name):
35 |     if dataset_name == 'cwq':
36 |         with open('../data/cwq.json', encoding='utf-8') as f:
37 |             datas = json.load(f)
38 |         question_string = 'question'
39 |     elif dataset_name == 'webqsp':
40 |         with open('../data/WebQSP.json', encoding='utf-8') as f:
41 |             datas = json.load(f)
42 |         question_string = 'RawQuestion'
43 |     elif dataset_name == 'grailqa':
44 |         with open('../data/grailqa.json', encoding='utf-8') as f:
45 |             datas = json.load(f)
46 |         question_string = 'question'
47 |     elif dataset_name == 'simpleqa':
48 |         with open('../data/SimpleQA.json', encoding='utf-8') as f:
49 |             datas = json.load(f)
50 |         question_string = 'question'
51 |     elif dataset_name == 'qald':
52 |         with open('../data/qald_10-en.json', encoding='utf-8') as f:
53 |             datas = json.load(f)
54 |         question_string = 'question'
55 |     elif dataset_name == 'webquestions':
56 |         with open('../data/WebQuestions.json', encoding='utf-8') as f:
57 |             datas = json.load(f)
58 |         question_string = 'question'
59 |     elif dataset_name == 'trex':
60 |         with open('../data/T-REX.json', encoding='utf-8') as f:
61 |             datas = json.load(f)
62 |         question_string = 'input'
63 |     elif dataset_name == 'zeroshotre':
64 |         with open('../data/Zero_Shot_RE.json', encoding='utf-8') as f:
65 |             datas = json.load(f)
66 |         question_string = 'input'
67 |     elif dataset_name == 'creak':
68 |         with open('../data/creak.json', encoding='utf-8') as f:
69 |             datas = json.load(f)
70 |         question_string = 'sentence'
71 |     else:
72 |         print("dataset not found")
73 |         exit(-1)
74 |     return datas, question_string
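75 | 
76 | # Illustrative usage (an editor's sketch, not part of the original file): a quick
77 | # smoke test of the two helpers above. The dataset name and the placeholder API key
78 | # below are assumptions; adjust them to your environment before running.
79 | if __name__ == '__main__':
80 |     datas, question_string = prepare_dataset("cwq")
81 |     demo_prompt = "Q: " + datas[0][question_string] + "\nA: "
82 |     print(run_llm(demo_prompt, 0, 256, "sk-xxxx", "gpt-3.5-turbo"))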
--------------------------------------------------------------------------------
/Freebase/README.md:
--------------------------------------------------------------------------------
1 | # Freebase Setup
2 | 
3 | ## Requirements
4 | 
5 | - OpenLink Virtuoso 7.2.5 (download from this public [link](https://sourceforge.net/projects/virtuoso/files/virtuoso/))
6 | - Python 3
7 | - Freebase dump from this public [link](https://developers.google.com/freebase?hl=en)
8 | 
9 | ## Setup
10 | 
11 | ### Data Preprocessing
12 | 
13 | We use this Python script (public [link](https://github.com/lanyunshi/Multi-hopComplexKBQA/blob/master/code/FreebaseTool/FilterEnglishTriplets.py)) to clean the data and remove non-English and non-digit triplets:
14 | 
15 | ```shell
16 | gunzip -c freebase-rdf-latest.gz > freebase # data size: 400G
17 | nohup python -u FilterEnglishTriplets.py 0<freebase 1>FilterFreebase 2>log_err & # data size: 125G
18 | ```
19 | 
20 | ## Import data
21 | 
22 | We import the cleaned data into Virtuoso:
23 | 
24 | ```shell
25 | tar xvpfz virtuoso-opensource.x86_64-generic_glibc25-linux-gnu.tar.gz
26 | cd virtuoso-opensource/database/
27 | mv virtuoso.ini.sample virtuoso.ini
28 | 
29 | # ../bin/virtuoso-t -df # start the service in the shell
30 | ../bin/virtuoso-t # start the service in the background.
31 | ../bin/isql 1111 dba dba # connect to the database
32 | 
33 | # 1. unzip the data and use rdf_loader to import it
34 | SQL>
35 | ld_dir('.', 'FilterFreebase', 'http://freebase.com');
36 | rdf_loader_run();
37 | ```
38 | 
39 | The import takes a long time; once it finishes, the database is ready to use.
40 | 
41 | ## Mapping data to Wikidata
42 | 
43 | Because the data in the Freebase dump is partially incomplete, some entities with missing relations need to be mapped to Wikidata. The RDF mapping data can be downloaded from this public [link](https://developers.google.com/freebase?hl=en#freebase-wikidata-mappings).
44 | 
45 | It can be imported into Virtuoso with the same method as above.
46 | 
47 | ## Test example
48 | 
49 | ```python
50 | import json
51 | from SPARQLWrapper import SPARQLWrapper, JSON
52 | 
53 | SPARQLPATH = "http://localhost:8890/sparql"
54 | 
55 | def test():
56 |     try:
57 |         sparql = SPARQLWrapper(SPARQLPATH)
58 |         sparql_txt = """PREFIX ns: <http://rdf.freebase.com/ns/>
59 |         SELECT distinct ?name3
60 |         WHERE {
61 |         ns:m.0k2kfpc ns:award.award_nominated_work.award_nominations ?e1.
62 |         ?e1 ns:award.award_nomination.award_nominee ns:m.02pbp9.
63 |         ns:m.02pbp9 ns:people.person.spouse_s ?e2.
64 |         ?e2 ns:people.marriage.spouse ?e3.
65 |         ?e2 ns:people.marriage.from ?e4.
66 |         ?e3 ns:type.object.name ?name3
67 |         MINUS{?e2 ns:type.object.name ?name2}
68 |         }
69 |         """
70 |         #print(sparql_txt)
71 |         sparql.setQuery(sparql_txt)
72 |         sparql.setReturnFormat(JSON)
73 |         results = sparql.query().convert()
74 |         print(results)
75 |     except:
76 |         print('Your database is not installed properly !!!')
77 | 
78 | test()
79 | 
80 | ```
81 | 
82 | 
83 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ToG
2 | 
3 | ## News!
4 | Our paper has been accepted by ICLR 2024! 🥳🥳🥳
5 | 
6 | ## ToG has moved to a new repo: [ToG](https://github.com/IDEA-FinAI/ToG).
7 | 
8 | 
9 | The code for the paper ["Think-on-Graph: Deep and Responsible Reasoning of Large Language Model on Knowledge Graph"](https://arxiv.org/pdf/2307.07697.pdf).
10 | 
11 | 
12 | 
13 | ## Here is the illustration of ToG:
14 | 
15 | ![image](https://github.com/GasolSun36/ToG/blob/main/assets/demo.png)
16 | 
17 | ## The pipeline of ToG:
18 | 
19 | ![image](https://github.com/GasolSun36/ToG/blob/main/assets/methods.png)
20 | 
21 | ## Project Structure
22 | - `requirements.txt`: Pip environment file.
23 | - `data/`: Evaluation datasets. See `data/README.md` for details.
24 | - `CoT/`: CoT methods. See `CoT/README.md` for details.
25 | - `eval/`: Evaluation script. See `eval/README.md` for details.
26 | - `Freebase/`: Freebase environment setting. See `Freebase/README.md` for details.
27 | - `Wikidata/`: Wikidata environment setting. See `Wikidata/README.md` for details.
28 | - `tools/`: Common tools used in ToG. See `tools/README.md` for details.
29 | - `ToG/`: Source code.
30 |   - `client.py`: Pre-defined Wikidata APIs, copied from `Wikidata/`.
31 |   - `server_urls.txt`: Wikidata server URLs, copied from `Wikidata/`.
32 |   - `main_freebase.py`: The main file of ToG with Freebase as the KG source. See `README.md` for details.
33 |   - `main_wiki.py`: Same as above, but with Wikidata as the KG source. See `README.md` for details.
34 |   - `prompt_list.py`: The prompts used by ToG for pruning, reasoning and generation.
35 |   - `freebase_func.py`: All the functions used in `main_freebase.py`.
36 |   - `wiki_func.py`: All the functions used in `main_wiki.py`.
37 |   - `utils.py`: All the functions used in ToG.
38 | 
39 | ## Get started
40 | Before running ToG, please ensure that you have successfully installed either **Freebase** or **Wikidata** on your local machine. The comprehensive installation instructions and necessary configuration details can be found in the `README.md` file located within the respective folder.
41 | 
42 | The required libraries for running ToG can be found in `requirements.txt`.
43 | 
44 | When using the Wikidata service, copy the `client.py` and `server_urls.txt` files from the `Wikidata` directory into the `ToG` folder.
45 | 
46 | 
47 | # How to run
48 | See `ToG/README.md`.
49 | 
50 | # How to eval
51 | Upon obtaining the result file, such as `ToG_cwq.jsonl`, you should use the `jsonl2json.py` script from the `tools` directory to convert `ToG_cwq.jsonl` to `ToG_cwq.json`. Then, evaluate it with the script in the `eval` folder (see `README.md` in the `eval` folder).
52 | 
53 | 
54 | # How to cite
55 | If you are interested in or inspired by this work, you can cite us with:
56 | ```bibtex
57 | @misc{sun2023thinkongraph,
58 |       title={Think-on-Graph: Deep and Responsible Reasoning of Large Language Model with Knowledge Graph}, 
59 |       author={Jiashuo Sun and Chengjin Xu and Lumingyuan Tang and Saizhuo Wang and Chen Lin and Yeyun Gong and Heung-Yeung Shum and Jian Guo},
60 |       year={2023},
61 |       eprint={2307.07697},
62 |       archivePrefix={arXiv},
63 |       primaryClass={cs.CL}
64 | }
65 | ```
66 | 
67 | # Experiment:
68 | 
69 | ![image](https://github.com/GasolSun36/ToG/blob/main/assets/experiments.png)
70 | 
71 | 
72 | # Application:
73 | 
74 | ![image](https://github.com/GasolSun36/ToG/blob/main/assets/application.png)
75 | 
76 | 
77 | 
78 | # Claims
79 | This project is released under the Apache 2.0 license. The project assumes no legal responsibility for any of the model's output and will not be held liable for any damages that may result from the use of the resources and output.
80 | 
--------------------------------------------------------------------------------
/ToG/README.md:
--------------------------------------------------------------------------------
1 | # ToG
2 | 
3 | Once all the necessary components are installed and configured, you can run ToG directly with the following command:
4 | 
5 | ```sh
6 | python main_freebase.py \  # if you want to use Wikidata as the KG source, run main_wiki.py
7 | --dataset cwq \ # dataset you want to test, see ToG/data/README.md
8 | --max_length 256 \
9 | --temperature_exploration 0.4 \ # the temperature in exploration stage.
10 | --temperature_reasoning 0 \ # the temperature in reasoning stage.
11 | --width 3 \ # choose the search width of ToG, 3 is the default setting.
12 | --depth 3 \ # choose the search depth of ToG, 3 is the default setting.
13 | --remove_unnecessary_rel True \ # whether to remove unnecessary relations.
14 | --LLM_type gpt-3.5-turbo \ # the LLM you choose
15 | --opeani_api_keys sk-xxxx \ # your own API key; if LLM_type == llama, this parameter has no effect.
16 | --num_retain_entity 5 \ # number of entities retained during entity search.
17 | --prune_tools llm \ # prune tools for ToG, can be llm (same as LLM_type), bm25 or sentencebert.
18 | ```
19 | 
20 | All the pruning and reasoning prompts used in the experiments are in the `prompt_list.py` file.
21 | 
22 | For evaluation, please see the `eval/README.md` file.
--------------------------------------------------------------------------------
/ToG/client.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import xmlrpc.client
3 | import typing as tp
4 | from concurrent.futures import ThreadPoolExecutor
5 | 
6 | class WikidataQueryClient:
7 |     def __init__(self, url: str):
8 |         self.url = url
9 |         self.server = xmlrpc.client.ServerProxy(url)
10 | 
11 |     def label2qid(self, label: str) -> str:
12 |         return self.server.label2qid(label)
13 | 
14 |     def label2pid(self, label: str) -> str:
15 |         return self.server.label2pid(label)
16 | 
17 |     def pid2label(self, pid: str) -> str:
18 |         return self.server.pid2label(pid)
19 | 
20 |     def qid2label(self, qid: str) -> str:
21 |         return self.server.qid2label(qid)
22 | 
23 |     def get_all_relations_of_an_entity(
24 |         self, entity_qid: str
25 |     ) -> tp.Dict[str, tp.List]:
26 |         return self.server.get_all_relations_of_an_entity(entity_qid)
27 | 
28 |     def get_tail_entities_given_head_and_relation(
29 |         self, head_qid: str, relation_pid: str
30 |     ) -> tp.Dict[str, tp.List]:
31 |         return self.server.get_tail_entities_given_head_and_relation(
32 |             head_qid, relation_pid
33 |         )
34 | 
35 |     def get_tail_values_given_head_and_relation(
36 |         self, head_qid: str, relation_pid: str
37 |     ) -> tp.List[str]:
38 |         return self.server.get_tail_values_given_head_and_relation(
39 |             head_qid, relation_pid
40 |         )
41 | 
42 |     def get_external_id_given_head_and_relation(
43 |         self, head_qid: str, relation_pid: str
44 |     ) -> tp.List[str]:
45 |         return self.server.get_external_id_given_head_and_relation(
46 |             head_qid, relation_pid
47 |         )
48 | 
49 |     def mid2qid(self, mid: str) -> str:
50 |         return self.server.mid2qid(mid)
51 | 
52 | 
53 | import time
54 | import typing as tp
55 | from concurrent.futures import ThreadPoolExecutor
56 | 
57 | 
58 | class MultiServerWikidataQueryClient:
59 |     def __init__(self, urls: tp.List[str]):
60 |         self.clients = [WikidataQueryClient(url) for url in urls]
61 |         self.executor = ThreadPoolExecutor(max_workers=len(urls))
62 |         # test connections
63 |         start_time = 
time.perf_counter() 64 | self.test_connections() 65 | end_time = time.perf_counter() 66 | print(f"Connection testing took {end_time - start_time} seconds") 67 | 68 | def test_connections(self): 69 | def test_url(client): 70 | try: 71 | # Check if server provides the system.listMethods function. 72 | client.server.system.listMethods() 73 | return True 74 | except Exception as e: 75 | print(f"Failed to connect to {client.url}. Error: {str(e)}") 76 | return False 77 | 78 | start_time = time.perf_counter() 79 | futures = [ 80 | self.executor.submit(test_url, client) for client in self.clients 81 | ] 82 | results = [f.result() for f in futures] 83 | end_time = time.perf_counter() 84 | # print(f"Testing connections took {end_time - start_time} seconds") 85 | # Remove clients that failed to connect 86 | self.clients = [ 87 | client for client, result in zip(self.clients, results) if result 88 | ] 89 | if not self.clients: 90 | raise Exception("Failed to connect to all URLs") 91 | 92 | def query_all(self, method, *args): 93 | start_time = time.perf_counter() 94 | futures = [ 95 | self.executor.submit(getattr(client, method), *args) 96 | for client in self.clients 97 | ] 98 | # Retrieve results and filter out 'Not Found!' 99 | is_dict_return = method in [ 100 | "get_all_relations_of_an_entity", 101 | "get_tail_entities_given_head_and_relation", 102 | ] 103 | results = [f.result() for f in futures] 104 | end_time = time.perf_counter() 105 | # print(f"HTTP Queries took {end_time - start_time} seconds") 106 | 107 | start_time = time.perf_counter() 108 | real_results = set() if not is_dict_return else {"head": [], "tail": []} 109 | for res in results: 110 | if isinstance(res, str) and res == "Not Found!": 111 | continue 112 | elif isinstance(res, tp.List): 113 | if len(res) == 0: 114 | continue 115 | if isinstance(res[0], tp.List): 116 | res_flattened = itertools.chain(*res) 117 | real_results.update(res_flattened) 118 | continue 119 | real_results.update(res) 120 | elif is_dict_return: 121 | real_results["head"].extend(res["head"]) 122 | real_results["tail"].extend(res["tail"]) 123 | else: 124 | real_results.add(res) 125 | end_time = time.perf_counter() 126 | # print(f"Querying all took {end_time - start_time} seconds") 127 | 128 | return real_results if len(real_results) > 0 else "Not Found!" 
129 | 130 | 131 | if __name__ == "__main__": 132 | import argparse 133 | 134 | parser = argparse.ArgumentParser() 135 | parser.add_argument( 136 | "--addr_list", 137 | type=str, 138 | required=True, 139 | help="path to server address list", 140 | ) 141 | args = parser.parse_args() 142 | 143 | with open(args.addr_list, "r") as f: 144 | server_addrs = f.readlines() 145 | server_addrs = [addr.strip() for addr in server_addrs] 146 | print(f"Server addresses: {server_addrs}") 147 | client = MultiServerWikidataQueryClient(server_addrs) 148 | print( 149 | f'MSFT\'s ticker code is {client.query_all("get_tail_values_given_head_and_relation","Q2283","P249",)}' 150 | ) -------------------------------------------------------------------------------- /ToG/freebase_func.py: -------------------------------------------------------------------------------- 1 | from SPARQLWrapper import SPARQLWrapper, JSON 2 | SPARQLPATH = "http://xxx.xxx.xxx.xxx/sparql" # depend on your own internal address and port, shown in Freebase folder's readme.md 3 | 4 | # pre-defined sparqls 5 | sparql_head_relations = """\nPREFIX ns: \nSELECT ?relation\nWHERE {\n ns:%s ?relation ?x .\n}""" 6 | sparql_tail_relations = """\nPREFIX ns: \nSELECT ?relation\nWHERE {\n ?x ?relation ns:%s .\n}""" 7 | sparql_tail_entities_extract = """PREFIX ns: \nSELECT ?tailEntity\nWHERE {\nns:%s ns:%s ?tailEntity .\n}""" 8 | sparql_head_entities_extract = """PREFIX ns: \nSELECT ?tailEntity\nWHERE {\n?tailEntity ns:%s ns:%s .\n}""" 9 | sparql_id = """PREFIX ns: \nSELECT DISTINCT ?tailEntity\nWHERE {\n {\n ?entity ns:type.object.name ?tailEntity .\n FILTER(?entity = ns:%s)\n }\n UNION\n {\n ?entity ?tailEntity .\n FILTER(?entity = ns:%s)\n }\n}""" 10 | 11 | def check_end_word(s): 12 | words = [" ID", " code", " number", "instance of", "website", "URL", "inception", "image", " rate", " count"] 13 | return any(s.endswith(word) for word in words) 14 | 15 | def abandon_rels(relation): 16 | if relation == "type.object.type" or relation == "type.object.name" or relation.startswith("common.") or relation.startswith("freebase.") or "sameAs" in relation: 17 | return True 18 | 19 | 20 | def execurte_sparql(sparql_txt): 21 | sparql = SPARQLWrapper(SPARQLPATH) 22 | sparql.setQuery(sparql_txt) 23 | sparql.setReturnFormat(JSON) 24 | results = sparql.query().convert() 25 | return results["results"]["bindings"] 26 | 27 | 28 | def replace_relation_prefix(relations): 29 | return [relation['relation']['value'].replace("http://rdf.freebase.com/ns/","") for relation in relations] 30 | 31 | def replace_entities_prefix(entities): 32 | return [entity['tailEntity']['value'].replace("http://rdf.freebase.com/ns/","") for entity in entities] 33 | 34 | 35 | def id2entity_name_or_type(entity_id): 36 | sparql = sparql_id % (entity_id, entity_id) 37 | sparql = SPARQLWrapper(SPARQLPATH) 38 | sparql.setQuery(sparql) 39 | sparql.setReturnFormat(JSON) 40 | results = sparql.query().convert() 41 | if len(results["results"]["bindings"])==0: 42 | return "UnName_Entity" 43 | else: 44 | return results["results"]["bindings"][0]['tailEntity']['value'] 45 | -------------------------------------------------------------------------------- /ToG/main_freebase.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import argparse 3 | from utils import * 4 | import random 5 | from client import * 6 | 7 | 8 | if __name__ == '__main__': 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--dataset", type=str, 11 | default="cwq", 
help="choose the dataset.") 12 | parser.add_argument("--max_length", type=int, 13 | default=256, help="the max length of LLMs output.") 14 | parser.add_argument("--temperature_exploration", type=float, 15 | default=0.4, help="the temperature in exploration stage.") 16 | parser.add_argument("--temperature_reasoning", type=float, 17 | default=0, help="the temperature in reasoning stage.") 18 | parser.add_argument("--width", type=int, 19 | default=3, help="choose the search width of ToG.") 20 | parser.add_argument("--depth", type=int, 21 | default=3, help="choose the search depth of ToG.") 22 | parser.add_argument("--remove_unnecessary_rel", type=bool, 23 | default=True, help="whether removing unnecessary relations.") 24 | parser.add_argument("--LLM_type", type=str, 25 | default="gpt-3.5-turbo", help="base LLM model.") 26 | parser.add_argument("--opeani_api_keys", type=str, 27 | default="", help="if the LLM_type is gpt-3.5-turbo or gpt-4, you need add your own openai api keys.") 28 | parser.add_argument("--num_retain_entity", type=int, 29 | default=5, help="Number of entities retained during entities search.") 30 | parser.add_argument("--prune_tools", type=str, 31 | default="llm", help="prune tools for ToG, can be llm (same as LLM_type), bm25 or sentencebert.") 32 | args = parser.parse_args() 33 | 34 | datas, question_string = prepare_dataset(args.dataset) 35 | 36 | for data in tqdm(datas): 37 | question = data[question_string] 38 | topic_entity = data['topic_entity'] 39 | cluster_chain_of_entities = [] 40 | pre_relations = [], 41 | pre_heads= [-1] * len(topic_entity) 42 | flag_printed = False 43 | for depth in range(1, args.depth+1): 44 | current_entity_relations_list = [] 45 | i=0 46 | for entity in topic_entity: 47 | if entity!="[FINISH_ID]": 48 | retrieve_relations_with_scores = relation_search_prune(entity, topic_entity[entity], pre_relations, pre_heads[i], question, args) # best entity triplet, entitiy_id 49 | current_entity_relations_list.extend(retrieve_relations_with_scores) 50 | i+=1 51 | total_candidates = [] 52 | total_scores = [] 53 | total_relations = [] 54 | total_entities_id = [] 55 | total_topic_entities = [] 56 | total_head = [] 57 | 58 | for entity in current_entity_relations_list: 59 | if entity['head']: 60 | entity_candidates_id = entity_search(entity['entity'], entity['relation'], True) 61 | else: 62 | entity_candidates_id = entity_search(entity['entity'], entity['relation'], False) 63 | 64 | if len(entity_candidates_id) >=20: 65 | entity_candidates_id = random.sample(entity_candidates_id, args.num_retain_entity) 66 | 67 | if len(entity_candidates_id) ==0: 68 | continue 69 | 70 | scores, entity_candidates, entity_candidates_id = entity_score(question, entity_candidates_id, entity['score'], entity['relation'], args) 71 | 72 | total_candidates, total_scores, total_relations, total_entities_id, total_topic_entities, total_head = update_history(entity_candidates, entity, scores, entity_candidates_id, total_candidates, total_scores, total_relations, total_entities_id, total_topic_entities, total_head) 73 | 74 | if len(total_candidates) ==0: 75 | half_stop(question, cluster_chain_of_entities, args) 76 | break 77 | 78 | flag, chain_of_entities, entities_id, pre_relations, pre_heads = entity_prune(total_entities_id, total_relations, total_candidates, total_topic_entities, total_head, total_scores, args) 79 | cluster_chain_of_entities.append(chain_of_entities) 80 | if flag: 81 | stop, results = reasoning(question, cluster_chain_of_entities, args) 82 | if stop: 83 | print("ToG 
stoped at depth %d." % depth) 84 | save_2_jsonl(question, results, cluster_chain_of_entities, file_name=args.dataset) 85 | flag_printed = True 86 | else: 87 | print("depth %d still not find the answer." % depth) 88 | topic_entity = {entity: id2entity_name_or_type[entity] for entity in entities_id} 89 | continue 90 | else: 91 | half_stop(question, cluster_chain_of_entities, args) 92 | 93 | if not flag_printed: 94 | results = generate_without_explored_paths(question, args) 95 | save_2_jsonl(question, results, [], file_name=args.dataset) 96 | -------------------------------------------------------------------------------- /ToG/main_wiki.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import argparse 3 | import random 4 | from wiki_func import * 5 | from client import * 6 | 7 | 8 | if __name__ == '__main__': 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--dataset", type=str, 11 | default="cwq", help="choose the dataset.") 12 | parser.add_argument("--max_length", type=int, 13 | default=256, help="the max length of LLMs output.") 14 | parser.add_argument("--temperature_exploration", type=float, 15 | default=0.4, help="the temperature in exploration stage.") 16 | parser.add_argument("--temperature_reasoning", type=float, 17 | default=0, help="the temperature in reasoning stage.") 18 | parser.add_argument("--width", type=int, 19 | default=3, help="choose the search width of ToG.") 20 | parser.add_argument("--depth", type=int, 21 | default=3, help="choose the search depth of ToG.") 22 | parser.add_argument("--remove_unnecessary_rel", type=bool, 23 | default=True, help="whether removing unnecessary relations.") 24 | parser.add_argument("--LLM_type", type=str, 25 | default="gpt-3.5-turbo", help="base LLM model.") 26 | parser.add_argument("--opeani_api_keys", type=str, 27 | default="", help="if the LLM_type is gpt-3.5-turbo or gpt-4, you need add your own openai api keys.") 28 | parser.add_argument("--num_retain_entity", type=int, 29 | default=5, help="Number of entities retained during entities search.") 30 | parser.add_argument("--prune_tools", type=str, 31 | default="llm", help="prune tools for ToG, can be llm (same as LLM_type), bm25 or sentencebert.") 32 | parser.add_argument("--addr_list", type=int, 33 | default="server_urls.txt", help="The address of the Wikidata service.") 34 | args = parser.parse_args() 35 | 36 | datas, question_string = prepare_dataset(args.dataset) 37 | 38 | for data in tqdm(datas): 39 | question = data[question_string] 40 | topic_entity = data['topic_entity'] 41 | cluster_chain_of_entities = [] 42 | pre_relations = [], 43 | pre_heads= [-1] * len(topic_entity) 44 | flag_printed = False 45 | addr_list = 'ToG/ToG-E/server_urls.txt' 46 | with open(addr_list, "r") as f: 47 | server_addrs = f.readlines() 48 | server_addrs = [addr.strip() for addr in server_addrs] 49 | print(f"Server addresses: {server_addrs}") 50 | wiki_client = MultiServerWikidataQueryClient(server_addrs) 51 | for depth in range(1, args.depth+1): 52 | current_entity_relations_list = [] 53 | i=0 54 | for entity in topic_entity: 55 | if entity!="[FINISH_ID]": 56 | retrieve_relations_with_scores = relation_search_prune(entity, topic_entity[entity], pre_relations, pre_heads[i], question, args, wiki_client) # best entity triplet, entitiy_id 57 | current_entity_relations_list.extend(retrieve_relations_with_scores) 58 | i+=1 59 | total_candidates = [] 60 | total_scores = [] 61 | total_relations = [] 62 | total_entities_id = [] 63 | 
total_topic_entities = [] 64 | total_head = [] 65 | 66 | for entity in current_entity_relations_list: 67 | value_flag=False 68 | if entity['head']: 69 | entity_candidates_id, entity_candidates_name = entity_search(entity['entity'], entity['relation'], True) 70 | else: 71 | entity_candidates_id, entity_candidates_name = entity_search(entity['entity'], entity['relation'], False) 72 | 73 | if len(entity_candidates_id) ==0: # values 74 | value_flag=True 75 | if len(entity_candidates_name) >=20: 76 | entity_candidates_name = random.sample(entity_candidates_name, 10) 77 | entity_candidates_id = ["[FINISH_ID]"] * len(entity_candidates_name) 78 | else: # ids 79 | entity_candidates_id, entity_candidates_name = del_all_unknown_entity(entity_candidates_id, entity_candidates_name) 80 | if len(entity_candidates_id) >=20: 81 | indices = random.sample(range(len(entity_candidates_name)), 10) 82 | entity_candidates_id = [entity_candidates_id[i] for i in indices] 83 | entity_candidates_name = [entity_candidates_name[i] for i in indices] 84 | 85 | if len(entity_candidates_id) ==0: 86 | continue 87 | 88 | scores, entity_candidates, entity_candidates_id = entity_score(question, entity_candidates_id, entity_candidates_name, entity['score'], entity['relation'], args) 89 | 90 | total_candidates, total_scores, total_relations, total_entities_id, total_topic_entities, total_head = update_history(entity_candidates, entity, scores, entity_candidates_id, total_candidates, total_scores, total_relations, total_entities_id, total_topic_entities, total_head, value_flag) 91 | 92 | if len(total_candidates) ==0: 93 | half_stop(question, cluster_chain_of_entities, args) 94 | break 95 | 96 | flag, chain_of_entities, entities_id, pre_relations, pre_heads = entity_prune(total_entities_id, total_relations, total_candidates, total_topic_entities, total_head, total_scores, args, wiki_client) 97 | cluster_chain_of_entities.append(chain_of_entities) 98 | if flag: 99 | stop, results = reasoning(question, cluster_chain_of_entities, args) 100 | if stop: 101 | print("ToG stoped at depth %d." % depth) 102 | save_2_jsonl(question, results, cluster_chain_of_entities, file_name=args.dataset) 103 | flag_printed = True 104 | else: 105 | print("depth %d still not find the answer." % depth) 106 | topic_entity = {entity: wiki_client.query_all("qid2label", entity) for entity in entities_id} 107 | continue 108 | else: 109 | half_stop(question, cluster_chain_of_entities, args) 110 | 111 | if not flag_printed: 112 | results = generate_without_explored_paths(question, args) 113 | save_2_jsonl(question, results, [], file_name=args.dataset) 114 | -------------------------------------------------------------------------------- /ToG/prompt_list.py: -------------------------------------------------------------------------------- 1 | extract_relation_prompt = """Please retrieve %s relations (separated by semicolon) that contribute to the question and rate their contribution on a scale from 0 to 1 (the sum of the scores of %s relations is 1). 2 | Q: Name the president of the country whose main spoken language was Brahui in 1980? 
3 | Topic Entity: Brahui Language 4 | Relations: language.human_language.main_country; language.human_language.language_family; language.human_language.iso_639_3_code; base.rosetta.languoid.parent; language.human_language.writing_system; base.rosetta.languoid.languoid_class; language.human_language.countries_spoken_in; kg.object_profile.prominent_type; base.rosetta.languoid.document; base.ontologies.ontology_instance.equivalent_instances; base.rosetta.languoid.local_name; language.human_language.region 5 | A: 1. {language.human_language.main_country (Score: 0.4))}: This relation is highly relevant as it directly relates to the country whose president is being asked for, and the main country where Brahui language is spoken in 1980. 6 | 2. {language.human_language.countries_spoken_in (Score: 0.3)}: This relation is also relevant as it provides information on the countries where Brahui language is spoken, which could help narrow down the search for the president. 7 | 3. {base.rosetta.languoid.parent (Score: 0.2)}: This relation is less relevant but still provides some context on the language family to which Brahui belongs, which could be useful in understanding the linguistic and cultural background of the country in question. 8 | 9 | Q: """ 10 | 11 | score_entity_candidates_prompt = """Please score the entities' contribution to the question on a scale from 0 to 1 (the sum of the scores of all entities is 1). 12 | Q: The movie featured Miley Cyrus and was produced by Tobin Armbrust? 13 | Relation: film.producer.film 14 | Entites: The Resident; So Undercover; Let Me In; Begin Again; The Quiet Ones; A Walk Among the Tombstones 15 | Score: 0.0, 1.0, 0.0, 0.0, 0.0, 0.0 16 | The movie that matches the given criteria is "So Undercover" with Miley Cyrus and produced by Tobin Armbrust. Therefore, the score for "So Undercover" would be 1, and the scores for all other entities would be 0. 17 | 18 | Q: {} 19 | Relation: {} 20 | Entites: """ 21 | 22 | answer_prompt = """Given a question and the associated retrieved knowledge graph triplets (entity, relation, entity), you are asked to answer the question with these triplets and your knowledge. 23 | Q: Find the person who said \"Taste cannot be controlled by law\", what did this person die from? 24 | Knowledge Triplets: Taste cannot be controlled by law., media_common.quotation.author, Thomas Jefferson 25 | A: Based on the given knowledge triplets, it's not sufficient to answer the entire question. The triplets only provide information about the person who said "Taste cannot be controlled by law," which is Thomas Jefferson. To answer the second part of the question, it's necessary to have additional knowledge about where Thomas Jefferson's dead. 26 | 27 | Q: The artist nominated for The Long Winter lived where? 28 | Knowledge Triplets: The Long Winter, book.written_work.author, Laura Ingalls Wilder 29 | Laura Ingalls Wilder, people.person.places_lived, Unknown-Entity 30 | Unknown-Entity, people.place_lived.location, De Smet 31 | A: Based on the given knowledge triplets, the author of The Long Winter, Laura Ingalls Wilder, lived in De Smet. Therefore, the answer to the question is {De Smet}. 32 | 33 | Q: Who is the coach of the team owned by Steve Bisciotti? 
34 | Knowledge Triplets: Steve Bisciotti, sports.professional_sports_team.owner_s, Baltimore Ravens 35 | Steve Bisciotti, sports.sports_team_owner.teams_owned, Baltimore Ravens 36 | Steve Bisciotti, organization.organization_founder.organizations_founded, Allegis Group 37 | A: Based on the given knowledge triplets, the coach of the team owned by Steve Bisciotti is not explicitly mentioned. However, it can be inferred that the team owned by Steve Bisciotti is the Baltimore Ravens, a professional sports team. Therefore, additional knowledge about the current coach of the Baltimore Ravens can be used to answer the question. 38 | 39 | Q: Rift Valley Province is located in a nation that uses which form of currency? 40 | Knowledge Triplets: Rift Valley Province, location.administrative_division.country, Kenya 41 | Rift Valley Province, location.location.geolocation, UnName_Entity 42 | Rift Valley Province, location.mailing_address.state_province_region, UnName_Entity 43 | Kenya, location.country.currency_used, Kenyan shilling 44 | A: Based on the given knowledge triplets, Rift Valley Province is located in Kenya, which uses the Kenyan shilling as its currency. Therefore, the answer to the question is {Kenyan shilling}. 45 | 46 | Q: The country with the National Anthem of Bolivia borders which nations? 47 | Knowledge Triplets: National Anthem of Bolivia, government.national_anthem_of_a_country.anthem, UnName_Entity 48 | National Anthem of Bolivia, music.composition.composer, Leopoldo Benedetto Vincenti 49 | National Anthem of Bolivia, music.composition.lyricist, José Ignacio de Sanjinés 50 | UnName_Entity, government.national_anthem_of_a_country.country, Bolivia 51 | Bolivia, location.country.national_anthem, UnName_Entity 52 | A: Based on the given knowledge triplets, we can infer that the National Anthem of Bolivia is the anthem of Bolivia. Therefore, the country with the National Anthem of Bolivia is Bolivia itself. However, the given knowledge triplets do not provide information about which nations border Bolivia. To answer this question, we need additional knowledge about the geography of Bolivia and its neighboring countries. 53 | 54 | Q: {} 55 | """ 56 | 57 | prompt_evaluate="""Given a question and the associated retrieved knowledge graph triplets (entity, relation, entity), you are asked to answer whether it's sufficient for you to answer the question with these triplets and your knowledge (Yes or No). 58 | Q: Find the person who said \"Taste cannot be controlled by law\", what did this person die from? 59 | Knowledge Triplets: Taste cannot be controlled by law., media_common.quotation.author, Thomas Jefferson 60 | A: {No}. Based on the given knowledge triplets, it's not sufficient to answer the entire question. The triplets only provide information about the person who said "Taste cannot be controlled by law," which is Thomas Jefferson. To answer the second part of the question, it's necessary to have additional knowledge about where Thomas Jefferson's dead. 61 | 62 | Q: The artist nominated for The Long Winter lived where? 63 | Knowledge Triplets: The Long Winter, book.written_work.author, Laura Ingalls Wilder 64 | Laura Ingalls Wilder, people.person.places_lived, Unknown-Entity 65 | Unknown-Entity, people.place_lived.location, De Smet 66 | A: {Yes}. Based on the given knowledge triplets, the author of The Long Winter, Laura Ingalls Wilder, lived in De Smet. Therefore, the answer to the question is {De Smet}. 67 | 68 | Q: Who is the coach of the team owned by Steve Bisciotti? 
69 | Knowledge Triplets: Steve Bisciotti, sports.professional_sports_team.owner_s, Baltimore Ravens 70 | Steve Bisciotti, sports.sports_team_owner.teams_owned, Baltimore Ravens 71 | Steve Bisciotti, organization.organization_founder.organizations_founded, Allegis Group 72 | A: {No}. Based on the given knowledge triplets, the coach of the team owned by Steve Bisciotti is not explicitly mentioned. However, it can be inferred that the team owned by Steve Bisciotti is the Baltimore Ravens, a professional sports team. Therefore, additional knowledge about the current coach of the Baltimore Ravens can be used to answer the question. 73 | 74 | Q: Rift Valley Province is located in a nation that uses which form of currency? 75 | Knowledge Triplets: Rift Valley Province, location.administrative_division.country, Kenya 76 | Rift Valley Province, location.location.geolocation, UnName_Entity 77 | Rift Valley Province, location.mailing_address.state_province_region, UnName_Entity 78 | Kenya, location.country.currency_used, Kenyan shilling 79 | A: {Yes}. Based on the given knowledge triplets, Rift Valley Province is located in Kenya, which uses the Kenyan shilling as its currency. Therefore, the answer to the question is {Kenyan shilling}. 80 | 81 | Q: The country with the National Anthem of Bolivia borders which nations? 82 | Knowledge Triplets: National Anthem of Bolivia, government.national_anthem_of_a_country.anthem, UnName_Entity 83 | National Anthem of Bolivia, music.composition.composer, Leopoldo Benedetto Vincenti 84 | National Anthem of Bolivia, music.composition.lyricist, José Ignacio de Sanjinés 85 | UnName_Entity, government.national_anthem_of_a_country.country, Bolivia 86 | Bolivia, location.country.national_anthem, UnName_Entity 87 | A: {No}. Based on the given knowledge triplets, we can infer that the National Anthem of Bolivia is the anthem of Bolivia. Therefore, the country with the National Anthem of Bolivia is Bolivia itself. However, the given knowledge triplets do not provide information about which nations border Bolivia. To answer this question, we need additional knowledge about the geography of Bolivia and its neighboring countries. 88 | 89 | """ 90 | 91 | generate_directly = """Q: What state is home to the university that is represented in sports by George Washington Colonials men's basketball? 92 | A: First, the education institution has a sports team named George Washington Colonials men's basketball in is George Washington University , Second, George Washington University is in Washington D.C. The answer is {Washington, D.C.}. 93 | 94 | Q: Who lists Pramatha Chaudhuri as an influence and wrote Jana Gana Mana? 95 | A: First, Bharoto Bhagyo Bidhata wrote Jana Gana Mana. Second, Bharoto Bhagyo Bidhata lists Pramatha Chaudhuri as an influence. The answer is {Bharoto Bhagyo Bidhata}. 96 | 97 | Q: Who was the artist nominated for an award for You Drive Me Crazy? 98 | A: First, the artist nominated for an award for You Drive Me Crazy is Britney Spears. The answer is {Jason Allen Alexander}. 99 | 100 | Q: What person born in Siegen influenced the work of Vincent Van Gogh? 101 | A: First, Peter Paul Rubens, Claude Monet and etc. influenced the work of Vincent Van Gogh. Second, Peter Paul Rubens born in Siegen. The answer is {Peter Paul Rubens}. 102 | 103 | Q: What is the country close to Russia where Mikheil Saakashvii holds a government position? 104 | A: First, China, Norway, Finland, Estonia and Georgia is close to Russia. 
Second, Mikheil Saakashvii holds a government position at Georgia. The answer is {Georgia}. 105 | 106 | Q: What drug did the actor who portrayed the character Urethane Wheels Guy overdosed on? 107 | A: First, Mitchell Lee Hedberg portrayed character Urethane Wheels Guy. Second, Mitchell Lee Hedberg overdose Heroin. The answer is {Heroin}.""" 108 | 109 | score_entity_candidates_prompt_wiki = """Please score the entities' contribution to the question on a scale from 0 to 1 (the sum of the scores of all entities is 1). 110 | Q: Staten Island Summer, starred what actress who was a cast member of "Saturday Night Live"? 111 | Relation: cast member 112 | Entites: Ashley Greene; Bobby Moynihan; Camille Saviola; Cecily Strong; Colin Jost; Fred Armisen; Gina Gershon; Graham Phillips; Hassan Johnson; Jackson Nicoll; Jim Gaffigan; John DeLuca; Kate Walsh; Mary Birdsong 113 | Score: 0.0, 0.0, 0.0, 0.4, 0.0, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4, 0.0 114 | To score the entities\' contribution to the question, we need to determine which entities are relevant to the question and have a higher likelihood of being the correct answer. 115 | In this case, we are looking for an actress who was a cast member of "Saturday Night Live" and starred in the movie "Staten Island Summer." Based on this information, we can eliminate entities that are not actresses or were not cast members of "Saturday Night Live." 116 | The relevant entities that meet these criteria are:\n- Ashley Greene\n- Cecily Strong\n- Fred Armisen\n- Gina Gershon\n- Kate Walsh\n\nTo distribute the scores, we can assign a higher score to entities that are more likely to be the correct answer. In this case, the most likely answer would be an actress who was a cast member of "Saturday Night Live" around the time the movie was released. 117 | Based on this reasoning, the scores could be assigned as follows:\n- Ashley Greene: 0\n- Cecily Strong: 0.4\n- Fred Armisen: 0.2\n- Gina Gershon: 0\n- Kate Walsh: 0.4 118 | 119 | Q: {} 120 | Relation: {} 121 | Entites: """ 122 | 123 | prompt_evaluate_wiki="""Given a question and the associated retrieved knowledge graph triplets (entity, relation, entity), you are asked to answer whether it's sufficient for you to answer the question with these triplets and your knowledge (Yes or No). 124 | Q: Viscount Yamaji Motoharu was a general in the early Imperial Japanese Army which belonged to which Empire? 125 | Knowledge Triplets: Imperial Japanese Army, allegiance, Emperor of Japan 126 | Yamaji Motoharu, allegiance, Emperor of Japan 127 | Yamaji Motoharu, military rank, general 128 | A: {Yes}. Based on the given knowledge triplets and my knowledge, Viscount Yamaji Motoharu, who was a general in the early Imperial Japanese Army, belonged to the Empire of Japan. Therefore, the answer to the question is {Empire of Japan}. 129 | 130 | Q: Who is the coach of the team owned by Steve Bisciotti? 131 | Knowledge Triplets: psilocybin, described by source, Opium Law, 132 | psilocybin, found in taxon, Gymnopilus purpuratus, 133 | psilocybin, found in taxon, Gymnopilus spectabilis, 134 | Opium Law, part of, norcodeine (stereochemistry defined), 135 | Gymnopilus purpuratus, edibility, psychoactive mushroom, 136 | Gymnopilus spectabilis, parent taxon, Gymnopilus 137 | A: {No}. Based on the given knowledge triplets and my knowledge, the specific psychedelic compound found in the Psilocybin genus mushroom that is converted to psilocin by the body is not explicitly mentioned. 
Therefore, additional knowledge about the specific compounds and their conversion to psilocin is required to answer the question. 138 | 139 | Q: Which tennis player is younger, John Newcombe or Květa Peschke? 140 | Knowledge Triplets: Květa Peschke, date of birth, +1975-07-09T00:00:00Z, 141 | John Newcombe, date of birth, +1944-05-23T00:00:00Z, 142 | John Newcombe, country of citizenship, Australia 143 | A: {Yes}. Based on the given knowledge triplets and my knowledge, John Newcombe was born on May 23, 1944, and Květa Peschke was born on July 9, 1975. Therefore, {Květa Peschke} is younger than John Newcombe. 144 | 145 | Q: At what stadium did Mychal George Thompson play home games with the San Antonio Spurs? 146 | Knowledge Triplets: San Antonio Spurs, home venue, AT&T Center 147 | San Antonio Spurs, home venue, Alamodome 148 | San Antonio Spurs, home venue, Fort Worth Convention Center 149 | AT&T Center, occupant, San Antonio Spurs 150 | Fort Worth Convention Center, located in the administrative territorial entity, Texas 151 | Fort Worth Convention Center, occupant, San Antonio Spurs 152 | A: {Yes}. Based on the given knowledge triplets and my knowledge, Mychal George Thompson played home games with the San Antonio Spurs at the AT&T Center. Therefore, the answer to the question is {AT&T Center}. 153 | 154 | """ 155 | extract_relation_prompt_wiki = """Please retrieve %s relations (separated by semicolon) that contribute to the question and rate their contribution on a scale from 0 to 1 (the sum of the scores of %s relations is 1). 156 | Q: Mesih Pasha's uncle became emperor in what year? 157 | Topic Entity: Mesih Pasha 158 | Relations: 159 | 1. wiki.relation.child 160 | 2. wiki.relation.country_of_citizenship 161 | 3. wiki.relation.date_of_birth 162 | 4. wiki.relation.family 163 | 5. wiki.relation.father 164 | 6. wiki.relation.languages_spoken, written_or_signed 165 | 7. wiki.relation.military_rank 166 | 8. wiki.relation.occupation 167 | 9. wiki.relation.place_of_death 168 | 10. wiki.relation.position_held 169 | 11. wiki.relation.religion_or_worldview 170 | 12. wiki.relation.sex_or_gender 171 | 13. wiki.relation.sibling 172 | 14. wiki.relation.significant_event 173 | A: 1. {wiki.relation.family (Score: 0.5)}: This relation is highly relevant as it can provide information about the family background of Mesih Pasha, including his uncle who became emperor. 174 | 2. {wiki.relation.father (Score: 0.4)}: Uncle is father's brother, so father might provide some information as well. 175 | 3. {wiki.relation.position held (Score: 0.1)}: This relation is moderately relevant as it can provide information about any significant positions held by Mesih Pasha or his uncle that could be related to becoming an emperor. 176 | 177 | Q: Van Andel Institute was founded in part by what American businessman, who was best known as co-founder of the Amway Corporation? 178 | Topic Entity: Van Andel Institute 179 | Relations: 180 | 1. wiki.relation.affiliation 181 | 2. wiki.relation.country 182 | 3. wiki.relation.donations 183 | 4. wiki.relation.educated_at 184 | 5. wiki.relation.employer 185 | 6. wiki.relation.headquarters_location 186 | 7. wiki.relation.legal_form 187 | 8. wiki.relation.located_in_the_administrative_territorial_entity 188 | 9. wiki.relation.total_revenue 189 | A: 1. 
{wiki.relation.affiliation (Score: 0.4)}: This relation is relevant because it can provide information about the individuals or organizations associated with the Van Andel Institute, including the American businessman who co-founded the Amway Corporation. 190 | 2. {wiki.relation.donations (Score: 0.3)}: This relation is relevant because it can provide information about the financial contributions made to the Van Andel Institute, which may include donations from the American businessman in question. 191 | 3. {wiki.relation.educated_at (Score: 0.3)}: This relation is relevant because it can provide information about the educational background of the American businessman, which may have influenced his involvement in founding the Van Andel Institute. 192 | 193 | Q: """ 194 | 195 | answer_prompt_wiki = """Given a question and the associated retrieved knowledge graph triplets (entity, relation, entity), you are asked to answer the question with these triplets and your own knowledge. 196 | Q: Viscount Yamaji Motoharu was a general in the early Imperial Japanese Army which belonged to which Empire? 197 | Knowledge Triplets: Imperial Japanese Army, allegiance, Emperor of Japan 198 | Yamaji Motoharu, allegiance, Emperor of Japan 199 | Yamaji Motoharu, military rank, general 200 | A: Based on the given knowledge triplets and my knowledge, Viscount Yamaji Motoharu, who was a general in the early Imperial Japanese Army, belonged to the Empire of Japan. Therefore, the answer to the question is {Empire of Japan}. 201 | 202 | Q: Who is the coach of the team owned by Steve Bisciotti? 203 | Knowledge Triplets: psilocybin, described by source, Opium Law, 204 | psilocybin, found in taxon, Gymnopilus purpuratus, 205 | psilocybin, found in taxon, Gymnopilus spectabilis, 206 | Opium Law, part of, norcodeine (stereochemistry defined), 207 | Gymnopilus purpuratus, edibility, psychoactive mushroom, 208 | Gymnopilus spectabilis, parent taxon, Gymnopilus 209 | A: Based on the given knowledge triplets and my knowledge, the specific psychedelic compound found in the Psilocybin genus mushroom that is converted to psilocin by the body is not explicitly mentioned. Therefore, additional knowledge about the specific compounds and their conversion to psilocin is required to answer the question. 210 | 211 | Q: Which tennis player is younger, John Newcombe or Květa Peschke? 212 | Knowledge Triplets: Květa Peschke, date of birth, +1975-07-09T00:00:00Z, 213 | John Newcombe, date of birth, +1944-05-23T00:00:00Z, 214 | John Newcombe, country of citizenship, Australia 215 | A: Based on the given knowledge triplets and my knowledge, John Newcombe was born on May 23, 1944, and Květa Peschke was born on July 9, 1975. Therefore, {Květa Peschke} is younger than John Newcombe. 216 | 217 | Q: At what stadium did Mychal George Thompson play home games with the San Antonio Spurs? 218 | Knowledge Triplets: San Antonio Spurs, home venue, AT&T Center 219 | San Antonio Spurs, home venue, Alamodome 220 | San Antonio Spurs, home venue, Fort Worth Convention Center 221 | AT&T Center, occupant, San Antonio Spurs 222 | Fort Worth Convention Center, located in the administrative territorial entity, Texas 223 | Fort Worth Convention Center, occupant, San Antonio Spurs 224 | A: Based on the given knowledge triplets and my knowledge, Mychal George Thompson played home games with the San Antonio Spurs at the AT&T Center. Therefore, the answer to the question is {AT&T Center}. 
225 | 
226 | Q: {}
227 | """
--------------------------------------------------------------------------------
/ToG/server_urls.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GasolSun36/ToG/934064c4a8a391c7f351339676f29ed764e40054/ToG/server_urls.txt
--------------------------------------------------------------------------------
/ToG/utils.py:
--------------------------------------------------------------------------------
1 | from freebase_func import *
2 | from prompt_list import *
3 | import json
4 | import re
5 | import time
6 | import openai
7 | from rank_bm25 import BM25Okapi
8 | from sentence_transformers import util
9 | from sentence_transformers import SentenceTransformer
10 | 
11 | def retrieve_top_docs(query, docs, model, width=3):
12 |     """
13 |     Retrieve the topn most relevant documents for the given query.
14 | 
15 |     Parameters:
16 |     - query (str): The input query.
17 |     - docs (list of str): The list of documents to search from.
18 |     - model: The SentenceTransformer model used to embed the query and documents.
19 |     - width (int): The number of top documents to return.
20 | 
21 |     Returns:
22 |     - list of float: A list of scores for the topn documents.
23 |     - list of str: A list of the topn documents.
24 |     """
25 | 
26 |     query_emb = model.encode(query)
27 |     doc_emb = model.encode(docs)
28 | 
29 |     scores = util.dot_score(query_emb, doc_emb)[0].cpu().tolist()
30 | 
31 |     doc_score_pairs = sorted(list(zip(docs, scores)), key=lambda x: x[1], reverse=True)
32 | 
33 |     top_docs = [pair[0] for pair in doc_score_pairs[:width]]
34 |     top_scores = [pair[1] for pair in doc_score_pairs[:width]]
35 | 
36 |     return top_docs, top_scores
37 | 
38 | 
39 | def compute_bm25_similarity(query, corpus, width=3):
40 |     """
41 |     Computes the BM25 similarity between a question and a list of relations,
42 |     and returns the topn relations with the highest similarity along with their scores.
43 | 
44 |     Args:
45 |     - query (str): Input question.
46 |     - corpus (list): List of relations.
47 |     - width (int): Number of top relations to return.
48 | 
49 |     Returns:
50 |     - list, list: topn relations with the highest similarity and their respective scores.
51 |     """
52 | 
53 |     tokenized_corpus = [doc.split(" ") for doc in corpus]
54 |     bm25 = BM25Okapi(tokenized_corpus)
55 |     tokenized_query = query.split(" ")
56 | 
57 |     doc_scores = bm25.get_scores(tokenized_query)
58 | 
59 |     relations = bm25.get_top_n(tokenized_query, corpus, n=width)
60 |     doc_scores = sorted(doc_scores, reverse=True)[:width]
61 | 
62 |     return relations, doc_scores
63 | 
64 | 
65 | def clean_relations(string, entity_id, head_relations):
66 |     pattern = r"{\s*(?P<relation>[^()]+)\s+\(Score:\s+(?P<score>[0-9.]+)\)}"
67 |     relations=[]
68 |     for match in re.finditer(pattern, string):
69 |         relation = match.group("relation").strip()
70 |         if ';' in relation:
71 |             continue
72 |         score = match.group("score")
73 |         if not relation or not score:
74 |             return False, "output uncompleted.."
75 |         try:
76 |             score = float(score)
77 |         except ValueError:
78 |             return False, "Invalid score"
79 |         if relation in head_relations:
80 |             relations.append({"entity": entity_id, "relation": relation, "score": score, "head": True})
81 |         else:
82 |             relations.append({"entity": entity_id, "relation": relation, "score": score, "head": False})
83 |     if not relations:
84 |         return False, "No relations found"
85 |     return True, relations
86 | 
87 | 
88 | def if_all_zero(topn_scores):
89 |     return all(score == 0 for score in topn_scores)
90 | 
91 | 
92 | def clean_relations_bm25_sent(topn_relations, topn_scores, entity_id, head_relations):
93 |     relations = []
94 |     if if_all_zero(topn_scores):
95 |         topn_scores = [float(1/len(topn_scores))] * len(topn_scores)
96 |     for i, relation in enumerate(topn_relations):
97 |         if relation in head_relations:
98 |             relations.append({"entity": entity_id, "relation": relation, "score": topn_scores[i], "head": True})
99 |         else:
100 |             relations.append({"entity": entity_id, "relation": relation, "score": topn_scores[i], "head": False})
101 |     return True, relations
102 | 
103 | 
104 | def run_llm(prompt, temperature, max_tokens, opeani_api_keys, engine="gpt-3.5-turbo"):
105 |     if "llama" in engine.lower():
106 |         openai.api_key = "EMPTY"
107 |         openai.api_base = "http://localhost:8000/v1" # your local llama server port
108 |         engine = openai.Model.list()["data"][0]["id"]
109 |     else:
110 |         openai.api_key = opeani_api_keys
111 | 
112 |     messages = [{"role":"system","content":"You are an AI assistant that helps people find information."}]
113 |     message_prompt = {"role":"user","content":prompt}
114 |     messages.append(message_prompt)
115 |     print("start openai")
116 |     while True: # retry until the API call succeeds
117 |         try:
118 |             response = openai.ChatCompletion.create(
119 |                 model=engine,
120 |                 messages = messages,
121 |                 temperature=temperature,
122 |                 max_tokens=max_tokens,
123 |                 frequency_penalty=0,
124 |                 presence_penalty=0)
125 |             result = response["choices"][0]['message']['content']
126 |             break
127 |         except:
128 |             print("openai error, retry")
129 |             time.sleep(2)
130 |     print("end openai")
131 |     return result
132 | 
133 | def construct_relation_prune_prompt(question, entity_name, total_relations, args):
134 |     return extract_relation_prompt % (args.width, args.width) + question + '\nTopic Entity: ' + entity_name + '\nRelations: '+ '; '.join(total_relations) + "\nA: "
135 | 
136 | 
137 | def construct_entity_score_prompt(question, relation, entity_candidates):
138 |     return score_entity_candidates_prompt.format(question, relation) + "; ".join(entity_candidates) + '\nScore: '
139 | 
140 | def relation_search_prune(entity_id, entity_name, pre_relations, pre_head, question, args):
141 |     sparql_relations_extract_head = sparql_head_relations % (entity_id)
142 |     head_relations = execurte_sparql(sparql_relations_extract_head)
143 |     head_relations = replace_relation_prefix(head_relations)
144 | 
145 |     sparql_relations_extract_tail= sparql_tail_relations % (entity_id)
146 |     tail_relations = execurte_sparql(sparql_relations_extract_tail)
147 |     tail_relations = replace_relation_prefix(tail_relations)
148 | 
149 |     if args.remove_unnecessary_rel:
150 |         head_relations = [relation for relation in head_relations if not abandon_rels(relation)]
151 |         tail_relations = [relation for relation in tail_relations if not abandon_rels(relation)]
152 | 
153 | 
154 |     if len(pre_relations) != 0 and pre_head != -1:
155 |         tail_relations = [rel for rel in tail_relations if not pre_head or rel not in pre_relations]
156 |         head_relations = [rel for rel in head_relations if pre_head or rel not in pre_relations]
157 | 
158 |     head_relations = list(set(head_relations))
159 |     tail_relations = list(set(tail_relations))
160 |     total_relations = head_relations+tail_relations
161 |     total_relations.sort() # make sure the order in prompt is always equal
162 | 
163 |     if args.prune_tools == "llm":
164 |         prompt = construct_relation_prune_prompt(question, entity_name, total_relations, args)
165 | 
166 |         result = run_llm(prompt, args.temperature_exploration, args.max_length, args.opeani_api_keys, args.LLM_type)
167 |         flag, retrieve_relations_with_scores = clean_relations(result, entity_id, head_relations)
168 | 
169 |     elif args.prune_tools == "bm25":
170 |         topn_relations, topn_scores = compute_bm25_similarity(question, total_relations, args.width)
171 |         flag, retrieve_relations_with_scores = clean_relations_bm25_sent(topn_relations, topn_scores, entity_id, head_relations)
172 |     else:
173 |         model = SentenceTransformer('sentence-transformers/msmarco-distilbert-base-tas-b')
174 |         topn_relations, topn_scores = retrieve_top_docs(question, total_relations, model, args.width)
175 |         flag, retrieve_relations_with_scores = clean_relations_bm25_sent(topn_relations, topn_scores, entity_id, head_relations)
176 | 
177 |     if flag:
178 |         return retrieve_relations_with_scores
179 |     else:
180 |         return [] # format error or too small max_length
181 | 
182 | 
183 | def entity_search(entity, relation, head=True):
184 |     if head:
185 |         tail_entities_extract = sparql_tail_entities_extract% (entity, relation)
186 |         entities = execurte_sparql(tail_entities_extract)
187 |     else:
188 |         head_entities_extract = sparql_head_entities_extract% (entity, relation)
189 |         entities = execurte_sparql(head_entities_extract)
190 | 
191 | 
192 |     entity_ids = replace_entities_prefix(entities)
193 |     new_entity = [entity for entity in entity_ids if entity.startswith("m.")]
194 |     return new_entity
195 | 
196 | 
197 | def entity_score(question, entity_candidates_id, score, relation, args):
198 |     entity_candidates = [id2entity_name_or_type(entity_id) for entity_id in entity_candidates_id]
199 |     if all_unknown_entity(entity_candidates):
200 |         return [1/len(entity_candidates) * score] * len(entity_candidates), entity_candidates, entity_candidates_id
201 |     entity_candidates = del_unknown_entity(entity_candidates)
202 |     if len(entity_candidates) == 1:
203 |         return [score], entity_candidates, entity_candidates_id
204 |     if len(entity_candidates) == 0:
205 |         return [0.0], entity_candidates, entity_candidates_id
206 | 
207 |     # make sure the id and entity are in the same order
208 |     zipped_lists = sorted(zip(entity_candidates, entity_candidates_id))
209 |     entity_candidates, entity_candidates_id = zip(*zipped_lists)
210 |     entity_candidates = list(entity_candidates)
211 |     entity_candidates_id = list(entity_candidates_id)
212 |     if args.prune_tools == "llm":
213 |         prompt = construct_entity_score_prompt(question, relation, entity_candidates)
214 | 
215 |         result = run_llm(prompt, args.temperature_exploration, args.max_length, args.opeani_api_keys, args.LLM_type)
216 |         return [float(x) * score for x in clean_scores(result, entity_candidates)], entity_candidates, entity_candidates_id
217 | 
218 |     elif args.prune_tools == "bm25":
219 |         topn_entities, topn_scores = compute_bm25_similarity(question, entity_candidates, args.width)
220 |     else:
221 |         model = SentenceTransformer('sentence-transformers/msmarco-distilbert-base-tas-b')
222 |         topn_entities, topn_scores = retrieve_top_docs(question, entity_candidates, model, args.width)
223 |     if
if_all_zero(topn_scores): 224 | topn_scores = [float(1/len(topn_scores))] * len(topn_scores) 225 | return [float(x) * score for x in topn_scores], topn_entities, entity_candidates_id 226 | 227 | 228 | def all_unknown_entity(entity_candidates): 229 | return all(candidate == "UnName_Entity" for candidate in entity_candidates) 230 | 231 | def del_unknown_entity(entity_candidates): 232 | if len(entity_candidates)==1 and entity_candidates[0]=="UnName_Entity": 233 | return entity_candidates 234 | entity_candidates = [candidate for candidate in entity_candidates if candidate != "UnName_Entity"] 235 | return entity_candidates 236 | 237 | def clean_scores(string, entity_candidates): 238 | scores = re.findall(r'\d+\.\d+', string) 239 | scores = [float(number) for number in scores] 240 | if len(scores) == len(entity_candidates): 241 | return scores 242 | else: 243 | print("All entities are created equal.") 244 | return [1/len(entity_candidates)] * len(entity_candidates) 245 | 246 | def update_history(entity_candidates, entity, scores, entity_candidates_id, total_candidates, total_scores, total_relations, total_entities_id, total_topic_entities, total_head): 247 | if len(entity_candidates) == 0: 248 | entity_candidates.append("[FINISH]") 249 | entity_candidates_id = ["[FINISH_ID]"] 250 | candidates_relation = [entity['relation']] * len(entity_candidates) 251 | topic_entities = [entity['entity']] * len(entity_candidates) 252 | head_num = [entity['head']] * len(entity_candidates) 253 | total_candidates.extend(entity_candidates) 254 | total_scores.extend(scores) 255 | total_relations.extend(candidates_relation) 256 | total_entities_id.extend(entity_candidates_id) 257 | total_topic_entities.extend(topic_entities) 258 | total_head.extend(head_num) 259 | return total_candidates, total_scores, total_relations, total_entities_id, total_topic_entities, total_head 260 | 261 | 262 | def generate_answer(question, cluster_chain_of_entities, args): 263 | prompt = answer_prompt + question + '\n' 264 | chain_prompt = '\n'.join([', '.join([str(x) for x in chain]) for sublist in cluster_chain_of_entities for chain in sublist]) 265 | prompt += "\nKnowledge Triplets: " + chain_prompt + 'A: ' 266 | result = run_llm(prompt, args.temperature_reasoning, args.max_length, args.opeani_api_keys, args.LLM_type) 267 | return result 268 | 269 | 270 | def save_2_jsonl(question, answer, cluster_chain_of_entities, file_name): 271 | dict = {"question":question, "results": answer, "reasoning_chains": cluster_chain_of_entities} 272 | with open("ToG_{}.jsonl".format(file_name), "a") as outfile: 273 | json_str = json.dumps(dict) 274 | outfile.write(json_str + "\n") 275 | 276 | 277 | def entity_prune(total_entities_id, total_relations, total_candidates, total_topic_entities, total_head, total_scores, args): 278 | zipped = list(zip(total_entities_id, total_relations, total_candidates, total_topic_entities, total_head, total_scores)) 279 | sorted_zipped = sorted(zipped, key=lambda x: x[5], reverse=True) 280 | sorted_entities_id, sorted_relations, sorted_candidates, sorted_topic_entities, sorted_head, sorted_scores = [x[0] for x in sorted_zipped], [x[1] for x in sorted_zipped], [x[2] for x in sorted_zipped], [x[3] for x in sorted_zipped], [x[4] for x in sorted_zipped], [x[5] for x in sorted_zipped] 281 | 282 | entities_id, relations, candidates, topics, heads, scores = sorted_entities_id[:args.width], sorted_relations[:args.width], sorted_candidates[:args.width], sorted_topic_entities[:args.width], sorted_head[:args.width], 
sorted_scores[:args.width] 283 | merged_list = list(zip(entities_id, relations, candidates, topics, heads, scores)) 284 | filtered_list = [(id, rel, ent, top, hea, score) for id, rel, ent, top, hea, score in merged_list if score != 0] 285 | if len(filtered_list) ==0: 286 | return False, [], [], [], [] 287 | entities_id, relations, candidates, tops, heads, scores = map(list, zip(*filtered_list)) 288 | 289 | tops = [id2entity_name_or_type(entity_id) for entity_id in tops] 290 | cluster_chain_of_entities = [[(tops[i], relations[i], candidates[i]) for i in range(len(candidates))]] 291 | return True, cluster_chain_of_entities, entities_id, relations, heads 292 | 293 | 294 | def reasoning(question, cluster_chain_of_entities, args): 295 | prompt = prompt_evaluate + question 296 | chain_prompt = '\n'.join([', '.join([str(x) for x in chain]) for sublist in cluster_chain_of_entities for chain in sublist]) 297 | prompt += "\nKnowledge Triplets: " + chain_prompt + 'A: ' 298 | 299 | response = run_llm(prompt, args.temperature_reasoning, args.max_length, args.opeani_api_keys, args.LLM_type) 300 | 301 | result = extract_answer(response) 302 | if if_true(result): 303 | return True, response 304 | else: 305 | return False, response 306 | 307 | def extract_answer(text): 308 | start_index = text.find("{") 309 | end_index = text.find("}") 310 | if start_index != -1 and end_index != -1: 311 | return text[start_index+1:end_index].strip() 312 | else: 313 | return "" 314 | 315 | def if_true(prompt): 316 | if prompt.lower().strip().replace(" ","")=="yes": 317 | return True 318 | return False 319 | 320 | def half_stop(question, cluster_chain_of_entities, args): 321 | print("No new knowledge added during search depth %d, stop searching." % args.depth) 322 | answer = generate_answer(question, cluster_chain_of_entities, args) 323 | save_2_jsonl(question, answer, cluster_chain_of_entities, file_name=args.dataset) 324 | 325 | 326 | def generate_without_explored_paths(question, args): 327 | prompt = generate_directly + "\n\nQ: " + question + "\nA:" 328 | response = run_llm(prompt, args.temperature_reasoning, args.max_length, args.opeani_api_keys, args.LLM_type) 329 | return response 330 | 331 | def prepare_dataset(dataset_name): 332 | if dataset_name == 'cwq': 333 | with open('../data/cwq.json',encoding='utf-8') as f: 334 | datas = json.load(f) 335 | question_string = 'question' 336 | elif dataset_name == 'webqsp': 337 | with open('../data/WebQSP.json',encoding='utf-8') as f: 338 | datas = json.load(f) 339 | question_string = 'RawQuestion' 340 | elif dataset_name == 'grailqa': 341 | with open('../data/grailqa.json',encoding='utf-8') as f: 342 | datas = json.load(f) 343 | question_string = 'question' 344 | elif dataset_name == 'simpleqa': 345 | with open('../data/SimpleQA.json',encoding='utf-8') as f: 346 | datas = json.load(f) 347 | question_string = 'question' 348 | elif dataset_name == 'qald': 349 | with open('../data/qald_10-en.json',encoding='utf-8') as f: 350 | datas = json.load(f) 351 | question_string = 'question' 352 | elif dataset_name == 'webquestions': 353 | with open('../data/WebQuestions.json',encoding='utf-8') as f: 354 | datas = json.load(f) 355 | question_string = 'question' 356 | elif dataset_name == 'trex': 357 | with open('../data/T-REX.json',encoding='utf-8') as f: 358 | datas = json.load(f) 359 | question_string = 'input' 360 | elif dataset_name == 'zeroshotre': 361 | with open('../data/Zero_Shot_RE.json',encoding='utf-8') as f: 362 | datas = json.load(f) 363 | question_string = 'input' 364 | elif 
dataset_name == 'creak':
365 |         with open('../data/creak.json',encoding='utf-8') as f:
366 |             datas = json.load(f)
367 |         question_string = 'sentence'
368 |     else:
369 |         print("dataset not found, you should pick from {cwq, webqsp, grailqa, simpleqa, qald, webquestions, trex, zeroshotre, creak}.")
370 |         exit(-1)
371 |     return datas, question_string
--------------------------------------------------------------------------------
/ToG/wiki_func.py:
--------------------------------------------------------------------------------
1 | from prompt_list import *
2 | import json
3 | import openai
4 | import re
5 | import time
6 | 
7 | def clean_relations(string, entity_id, head_relations):
8 |     pattern = r"{\s*(?P<relation>[^()]+)\s+\(Score:\s+(?P<score>[0-9.]+)\)}"
9 |     relations=[]
10 |     for match in re.finditer(pattern, string):
11 |         relation = match.group("relation").strip()
12 |         if ';' in relation:
13 |             continue
14 |         score = match.group("score")
15 |         if not relation or not score:
16 |             return False, "output uncompleted.."
17 |         try:
18 |             score = float(score)
19 |         except ValueError:
20 |             return False, "Invalid score"
21 |         if relation in head_relations:
22 |             relations.append({"entity": entity_id, "relation": relation, "score": score, "head": True})
23 |         else:
24 |             relations.append({"entity": entity_id, "relation": relation, "score": score, "head": False})
25 |     if not relations:
26 |         return False, "No relations found"
27 |     return True, relations
28 | 
29 | 
30 | 
31 | def run_llm(prompt, temperature, max_tokens, opeani_api_keys, engine="gpt-3.5-turbo"):
32 |     if "llama" in engine.lower():
33 |         openai.api_key = "EMPTY"
34 |         openai.api_base = "http://localhost:8000/v1" # your local llama server port
35 |         engine = openai.Model.list()["data"][0]["id"]
36 |     else:
37 |         openai.api_key = opeani_api_keys
38 | 
39 |     messages = [{"role":"system","content":"You are an AI assistant that helps people find information."}]
40 |     message_prompt = {"role":"user","content":prompt}
41 |     messages.append(message_prompt)
42 |     print("start openai")
43 |     while True: # retry until the API call succeeds
44 |         try:
45 |             response = openai.ChatCompletion.create(
46 |                 model=engine,
47 |                 messages = messages,
48 |                 temperature=temperature,
49 |                 max_tokens=max_tokens,
50 |                 frequency_penalty=0,
51 |                 presence_penalty=0)
52 |             result = response["choices"][0]['message']['content']
53 |             break
54 |         except:
55 |             print("openai error, retry")
56 |             time.sleep(2)
57 |     print("end openai")
58 |     return result
59 | 
60 | def construct_relation_prune_prompt(question, entity_name, total_relations, args):
61 |     return extract_relation_prompt_wiki % (args.width, args.width)+question+'\nTopic Entity: '+entity_name+ '\nRelations:\n'+'\n'.join([f"{i}. {item}" for i, item in enumerate(total_relations, start=1)])+'\nA: '
62 | 
63 | 
64 | def check_end_word(s):
65 |     words = [" ID", " code", " number", "instance of", "website", "URL", "inception", "image", " rate", " count"]
66 |     return any(s.endswith(word) for word in words)
67 | 
68 | def abandon_rels(relation):
69 |     useless_relation_list = ["category's main topic", "topic\'s main category", "stack exchange site", 'main subject', 'country of citizenship', "commons category", "commons gallery", "country of origin", "country", "nationality"]
70 |     if check_end_word(relation) or 'wikidata' in relation.lower() or 'wikimedia' in relation.lower() or relation.lower() in useless_relation_list:
71 |         return True
72 |     return False
73 | 
74 | def construct_entity_score_prompt(question, relation, entity_candidates):
75 |     return score_entity_candidates_prompt_wiki.format(question, relation) + "; ".join(entity_candidates) + '\nScore: '
76 | 
77 | def relation_search_prune(entity_id, entity_name, pre_relations, pre_head, question, args, wiki_client):
78 |     relations = wiki_client.query_all("get_all_relations_of_an_entity", entity_id)
79 |     head_relations = relations['head']
80 |     tail_relations = relations['tail']
81 | 
82 |     if args.remove_unnecessary_rel:
83 |         head_relations = [relation for relation in head_relations if not abandon_rels(relation)]
84 |         tail_relations = [relation for relation in tail_relations if not abandon_rels(relation)]
85 | 
86 |     if len(pre_relations)!=0 and pre_head !=-1:
87 |         tail_relations = [rel for rel in tail_relations if not pre_head or rel not in pre_relations]
88 |         head_relations = [rel for rel in head_relations if pre_head or rel not in pre_relations]
89 | 
90 |     head_relations = list(set(head_relations))
91 |     tail_relations = list(set(tail_relations))
92 |     total_relations = head_relations+tail_relations
93 |     total_relations.sort() # make sure the order in prompt is always equal
94 | 
95 |     prompt = construct_relation_prune_prompt(question, entity_name, total_relations, args)
96 | 
97 |     result = run_llm(prompt, args.temperature_exploration, args.max_length, args.opeani_api_keys, args.LLM_type)
98 |     flag, retrieve_relations_with_scores = clean_relations(result, entity_id, head_relations)
99 | 
100 |     if flag:
101 |         return retrieve_relations_with_scores
102 |     else:
103 |         return [] # format error or too small max_length
104 | 
105 | def del_all_unknown_entity(entity_candidates_id, entity_candidates_name):
106 |     if len(entity_candidates_name) == 1 and entity_candidates_name[0] == "N/A":
107 |         return entity_candidates_id, entity_candidates_name
108 | 
109 |     new_candidates_id = []
110 |     new_candidates_name = []
111 |     for i, candidate in enumerate(entity_candidates_name):
112 |         if candidate != "N/A":
113 |             new_candidates_id.append(entity_candidates_id[i])
114 |             new_candidates_name.append(candidate)
115 | 
116 |     return new_candidates_id, new_candidates_name
117 | 
118 | def all_zero(topn_scores):
119 |     return all(score == 0 for score in topn_scores)
120 | 
121 | def entity_search(entity, relation, wiki_client, head):
122 | 
123 |     rid = wiki_client.query_all("label2pid", relation)
124 |     if not rid or rid == "Not Found!":
125 |         return [], []
126 | 
127 |     rid_str = rid.pop()
128 | 
129 |     entities = wiki_client.query_all("get_tail_entities_given_head_and_relation", entity, rid_str)
130 | 
131 |     if head:
132 |         entities_set = entities['tail']
133 |     else:
134 |         entities_set = entities['head']
135 | 
136 |     if not entities_set:
137 |         values = wiki_client.query_all("get_tail_values_given_head_and_relation", entity, rid_str)
138 |         return [], list(values)
139 | 
140 |     id_list = [item['qid'] for item in entities_set]
141 |     name_list = [item['label'] if item['label'] != "N/A" else "Unname_Entity" for item in entities_set]
142 | 
143 |     return id_list, name_list
144 | 
145 | def clean_scores(string, entity_candidates):
146 |     scores = re.findall(r'\d+\.\d+', string)
147 |     scores = [float(number) for number in scores]
148 |     if len(scores) == len(entity_candidates):
149 |         return scores
150 |     else:
151 |         print("All entities are created equal.")
152 |         return [1/len(entity_candidates)] * len(entity_candidates)
153 | 
154 | def entity_score(question, entity_candidates_id, entity_candidates, score, relation, args):
155 |     if len(entity_candidates) == 1:
156 |         return [score], entity_candidates, entity_candidates_id
157 |     if len(entity_candidates) == 0:
158 |         return [0.0], entity_candidates, entity_candidates_id
159 | 
160 |     # make sure the id and entity are in the same order
161 |     zipped_lists = sorted(zip(entity_candidates, entity_candidates_id))
162 |     entity_candidates, entity_candidates_id = zip(*zipped_lists)
163 |     entity_candidates = list(entity_candidates)
164 |     entity_candidates_id = list(entity_candidates_id)
165 | 
166 |     prompt = construct_entity_score_prompt(question, relation, entity_candidates)
167 | 
168 |     result = run_llm(prompt, args.temperature_exploration, args.max_length, args.opeani_api_keys, args.LLM_type)
169 |     entity_scores = clean_scores(result, entity_candidates)
170 |     if all_zero(entity_scores):
171 |         return [1/len(entity_candidates) * score] * len(entity_candidates), entity_candidates, entity_candidates_id
172 |     else:
173 |         return [float(x) * score for x in entity_scores], entity_candidates, entity_candidates_id
174 | 
175 | 
176 | def all_unknown_entity(entity_candidates):
177 |     return all(candidate == "UnName_Entity" for candidate in entity_candidates)
178 | 
179 | def del_unknown_entity(entity_candidates):
180 |     if len(entity_candidates)==1 and entity_candidates[0]=="UnName_Entity":
181 |         return entity_candidates
182 |     entity_candidates = [candidate for candidate in entity_candidates if candidate != "UnName_Entity"]
183 |     return entity_candidates
184 | 
185 | 
186 | def update_history(entity_candidates, entity, scores, entity_candidates_id, total_candidates, total_scores, total_relations, total_entities_id, total_topic_entities, total_head, value_flag):
187 |     if value_flag:
188 |         scores = [1/len(entity_candidates) * entity['score']] * len(entity_candidates)
189 |     candidates_relation = [entity['relation']] * len(entity_candidates)
190 |     topic_entities = [entity['entity']] * len(entity_candidates)
191 |     head_num = [entity['head']] * len(entity_candidates)
192 |     total_candidates.extend(entity_candidates)
193 |     total_scores.extend(scores)
194 |     total_relations.extend(candidates_relation)
195 |     total_entities_id.extend(entity_candidates_id)
196 |     total_topic_entities.extend(topic_entities)
197 |     total_head.extend(head_num)
198 | 
199 | 
200 |     return total_candidates, total_scores, total_relations, total_entities_id, total_topic_entities, total_head
201 | 
202 | 
203 | def generate_answer(question, cluster_chain_of_entities, args):
204 |     prompt = answer_prompt_wiki + question + '\n'
205 |     chain_prompt = '\n'.join([', '.join([str(x) for x in chain]) for sublist in cluster_chain_of_entities for chain in sublist])
206 |     prompt += "\nKnowledge Triplets: " + chain_prompt + 'A: '
207 |     result = run_llm(prompt, args.temperature_reasoning, args.max_length, args.opeani_api_keys, args.LLM_type)
208 |     return result
209 | 
210 | 
211 | def
save_2_jsonl(question, answer, cluster_chain_of_entities, file_name): 212 | dict = {"question":question, "turbo_results": answer, "chains": cluster_chain_of_entities} 213 | with open("ToG_{}.jsonl".format(file_name), "a") as outfile: 214 | json_str = json.dumps(dict) 215 | outfile.write(json_str + "\n") 216 | 217 | 218 | def entity_prune(total_entities_id, total_relations, total_candidates, total_topic_entities, total_head, total_scores, args, wiki_client): 219 | zipped = list(zip(total_entities_id, total_relations, total_candidates, total_topic_entities, total_head, total_scores)) 220 | sorted_zipped = sorted(zipped, key=lambda x: x[5], reverse=True) 221 | sorted_entities_id, sorted_relations, sorted_candidates, sorted_topic_entities, sorted_head, sorted_scores = [x[0] for x in sorted_zipped], [x[1] for x in sorted_zipped], [x[2] for x in sorted_zipped], [x[3] for x in sorted_zipped], [x[4] for x in sorted_zipped], [x[5] for x in sorted_zipped] 222 | 223 | entities_id, relations, candidates, topics, heads, scores = sorted_entities_id[:args.width], sorted_relations[:args.width], sorted_candidates[:args.width], sorted_topic_entities[:args.width], sorted_head[:args.width], sorted_scores[:args.width] 224 | merged_list = list(zip(entities_id, relations, candidates, topics, heads, scores)) 225 | filtered_list = [(id, rel, ent, top, hea, score) for id, rel, ent, top, hea, score in merged_list if score != 0] 226 | if len(filtered_list) ==0: 227 | return False, [], [], [], [] 228 | entities_id, relations, candidates, tops, heads, scores = map(list, zip(*filtered_list)) 229 | tops = [wiki_client.query_all("qid2label", entity_id).pop() if (entity_name := wiki_client.query_all("qid2label", entity_id)) != "Not Found!" else "Unname_Entity" for entity_id in tops] 230 | cluster_chain_of_entities = [[(tops[i], relations[i], candidates[i]) for i in range(len(candidates))]] 231 | return True, cluster_chain_of_entities, entities_id, relations, heads 232 | 233 | def reasoning(question, cluster_chain_of_entities, args): 234 | prompt = prompt_evaluate_wiki + question 235 | chain_prompt = '\n'.join([', '.join([str(x) for x in chain]) for sublist in cluster_chain_of_entities for chain in sublist]) 236 | prompt += "\nKnowledge Triplets: " + chain_prompt + 'A: ' 237 | 238 | response = run_llm(prompt, args.temperature_reasoning, args.max_length, args.opeani_api_keys, args.LLM_type) 239 | 240 | result = extract_answer(response) 241 | if if_true(result): 242 | return True, response 243 | else: 244 | return False, response 245 | 246 | def extract_answer(text): 247 | start_index = text.find("{") 248 | end_index = text.find("}") 249 | if start_index != -1 and end_index != -1: 250 | return text[start_index+1:end_index].strip() 251 | else: 252 | return "" 253 | 254 | def if_true(prompt): 255 | if prompt.lower().strip().replace(" ","")=="yes": 256 | return True 257 | return False 258 | 259 | def half_stop(question, cluster_chain_of_entities, args): 260 | print("No new knowledge added during search depth %d, stop searching." 
% args.depth) 261 | answer = generate_answer(question, cluster_chain_of_entities, args) 262 | save_2_jsonl(question, answer, cluster_chain_of_entities, file_name=args.dataset) 263 | 264 | 265 | def generate_without_explored_paths(question, args): 266 | prompt = generate_directly + "\n\nQ: " + question + "\nA:" 267 | response = run_llm(prompt, args.temperature_reasoning, args.max_length, args.opeani_api_keys, args.LLM_type) 268 | return response 269 | 270 | def prepare_dataset(dataset_name): 271 | if dataset_name == 'cwq': 272 | with open('../data/cwq.json',encoding='utf-8') as f: 273 | datas = json.load(f) 274 | question_string = 'question' 275 | elif dataset_name == 'webqsp': 276 | with open('../data/WebQSP.json',encoding='utf-8') as f: 277 | datas = json.load(f) 278 | question_string = 'RawQuestion' 279 | elif dataset_name == 'grailqa': 280 | with open('../data/grailqa.json',encoding='utf-8') as f: 281 | datas = json.load(f) 282 | question_string = 'question' 283 | elif dataset_name == 'simpleqa': 284 | with open('../data/SimpleQA.json',encoding='utf-8') as f: 285 | datas = json.load(f) 286 | question_string = 'question' 287 | elif dataset_name == 'qald': 288 | with open('../data/qald_10-en.json',encoding='utf-8') as f: 289 | datas = json.load(f) 290 | question_string = 'question' 291 | elif dataset_name == 'webquestions': 292 | with open('../data/WebQuestions.json',encoding='utf-8') as f: 293 | datas = json.load(f) 294 | question_string = 'question' 295 | elif dataset_name == 'trex': 296 | with open('../data/T-REX.json',encoding='utf-8') as f: 297 | datas = json.load(f) 298 | question_string = 'input' 299 | elif dataset_name == 'zeroshotre': 300 | with open('../data/Zero_Shot_RE.json',encoding='utf-8') as f: 301 | datas = json.load(f) 302 | question_string = 'input' 303 | elif dataset_name == 'creak': 304 | with open('../data/creak.json',encoding='utf-8') as f: 305 | datas = json.load(f) 306 | question_string = 'sentence' 307 | else: 308 | print("dataset not found") 309 | exit(-1) 310 | return datas, question_string -------------------------------------------------------------------------------- /Wikidata/.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | __pycache__/ 3 | .vscode/ 4 | .pytest_cache/ 5 | *.egg-info/ 6 | *.out 7 | 8 | server_log** 9 | **.log 10 | 11 | mlflow/ -------------------------------------------------------------------------------- /Wikidata/README.md: -------------------------------------------------------------------------------- 1 | # simple-wikidata-db 2 | 3 | This library provides a set of scripts to download the Wikidata dump, sort it into staging files, and query the data in these staged files in a distributed manner. The staging is optimized for (1) querying time, and (2) simplicity. 4 | 5 | This library is helpful if you'd like to issue queries like: 6 | 7 | - Fetch all QIDs which are related to [Q38257](https://www.wikidata.org/wiki/Q38257) 8 | - Fetch all triples corresponding to the relation (e.g. [P35](https://www.wikidata.org/wiki/Property:P35)) 9 | - Fetch all aliases for a QID 10 | 11 | ## Downloading the dump 12 | 13 | A full list of available dumps is available [here](https://dumps.wikimedia.org/wikidatawiki/entities/). 
To fetch the most recent dump, run:
14 | 
15 | ```
16 | wget https://dumps.wikimedia.org/wikidatawiki/entities/latest-all.json.gz
17 | ```
18 | 
19 | or, if aria2c is installed, run:
20 | 
21 | ```
22 | aria2c --max-connection-per-server 16 https://dumps.wikimedia.org/wikidatawiki/entities/latest-all.json.gz
23 | ```
24 | 
25 | Downloading takes about 2-5 hours (depending on bandwidth).
26 | 
27 | ## Processing the dump
28 | 
29 | The original downloaded wikidata dump is a single file and combines different types of information (alias names, properties, relations, etc). We preprocess the dump by iterating over the compressed file, and saving information to different subdirectories. For more information, see the [Data Format](#data-format). To preprocess the dump, run:
30 | 
31 | ```bash
32 | python3 preprocess_dump.py \
33 | --input_file $PATH_TO_COMPRESSED_WIKI_JSON \
34 | --out_dir $DIR_TO_SAVE_DATA_TO \
35 | --batch_size $BATCH_SIZE \
36 | --language_id $LANG
37 | ```
38 | 
39 | These arguments are:
40 | 
41 | - `input_file`: path to the compressed JSON Wikidata dump file
42 | - `out_dir`: path to directory where tables will be written. Subdirectories will be created under this directory for each table.
43 | - `num_lines_read`: number of lines to read. Useful for debugging.
44 | - `num_lines_in_dump`: specifies the total number of lines in the uncompressed json file. This is used by a tqdm bar to track progress. As of January 2022, there are 95,980,335 lines in latest-all.json. It takes about 21 minutes to run `wc -l latest-all.json`.
45 | - `batch_size`: The number of triples to write into each batch file that is saved under a table directory.
46 | - `language_id`: The language to use when extracting entity labels, aliases, descriptions, and wikipedia links.
47 | 
48 | Additionally, running with the flag `--test` will terminate after processing an initial chunk, allowing you to verify results.
49 | 
50 | It takes ~5 hours to process the dump when running with 90 processes on a 1024GB machine with 56 cores. A tqdm progress bar should provide a more accurate estimate while data is being processed.
51 | 
52 | ## Data Format
53 | 
54 | The Wikidata dump is made available as a single, unwieldy JSON file. To make querying/filtering easier, we split the information contained in this JSON file into multiple **tables**, where each table contains a certain type of information. The tables we create are described below:
55 | 
56 | | Table name | Table description | Table schema|
57 | | --------------- |:--------------------| :-----|
58 | | labels | Holds the labels for different entities | qid: the QID of the entity <br> label: the entity's label ('name') |
59 | | descriptions | Holds the descriptions for different entities | qid: the QID of the entity <br> description: the entity's description (short summary at the top of the page) |
60 | | aliases | Holds the aliases for different entities | qid: the QID of the entity <br> alias: an alias for the entity |
61 | | entity_rels | Holds statements where the value of the statement is another wikidata entity | claim_id: the ID for the statement <br> qid: the ID for wikidata entity <br> property_id: the ID for the property <br> value: the qid for the value wikidata entity |
62 | | external_ids | Holds statements where the value of the statement is an identifier to an external database (e.g. Musicbrainz, Freebase, etc) | claim_id: the ID for the statement <br> qid: the ID for wikidata entity <br> property_id: the ID for the property <br> value: the identifier for the external ID |
63 | | entity_values | Holds statements where the value of the statement is a string/quantity | claim_id: the ID for the statement <br> qid: the ID for wikidata entity <br> property_id: the ID for the property <br> value: the value for this property |
64 | | qualifiers | Holds qualifiers for statements | qualifier_id: the ID for the qualifier <br> claim_id: the ID for the claim being qualified <br> property_id: the ID for the property <br> value: the value of the qualifier |
65 | | wikipedia_links | Holds links to Wikipedia items | qid: the QID of the entity <br> wiki_title: link to corresponding wikipedia entity |
66 | | plabels | Holds PIDs and their corresponding labels | pid: the PID of the property <br> label: the label for the property |
67 | ----
68 | 
69 | <br> <br>
70 | Each table is stored in a directory, where the content of the table is written to multiple jsonl files stored inside the directory (each file contains a subset of the rows in the table). Each line in the file corresponds to a different triple. Partitioning the table's contents into multiple files improves querying speed--we can process each file in parallel.
71 | 
72 | ## Querying scripts
73 | 
74 | Two scripts are provided as examples of how to write parallelized queries over the data once it's been preprocessed:
75 | 
76 | - `fatching/fetch_with_name.py`: fetches all QIDs which are associated with a particular name. For example: all entities associated with the name 'Victoria', which would include entities like Victoria Beckham or Victoria (Australia).
77 | - `fatching/fetch_with_rel_and_value.py`: fetches all QIDs which have a relationship with a specific value. For example: all triples where the relation is P413 and the object of the relation is Q622747.
78 | 
79 | # Instructions for deploying a query service locally
80 | 
81 | ## Making index
82 | 
83 | Use `simple_wikidata_db/db_deploy/build_index` to build a dict-based index for fast in-memory queries:
84 | 
85 | ```bash
86 | python simple_wikidata_db/db_deploy/build_index.py \
87 | --input_dir $PREPROCESS_DATA_DIR \
88 | --output_dir $INDEX_FILE_DIR \
89 | --num_chunks $NUM_CHUNKS \
90 | --num_workers $NUM_WORKERS \
91 | --chunk_idx $CHUNK_IDX
92 | ```
93 | 
94 | - `input_dir`: The preprocessed wikidata dump dir. It should be the output dir of the preprocessing job described above.
95 | - `output_dir`: The dir where the generated index is stored. It is usually a subfolder of `input_dir`; in this case it is `input_dir`/indices.
96 | - `num_chunks`: The number of chunks to split the data into. This is used to split the data into multiple files, which can be queried in parallel.
97 | - `num_workers`: number of subprocesses in this job.
98 | - `chunk_idx`: Which chunk of the whole index to build. By default it's -1, where all chunks are built sequentially. If you want to build a specific chunk, set it to the index of the chunk.
99 | 
100 | Note that the index is deeply coupled with the query interfaces, so if you have new requirements for querying the data, you may need to modify the index building script `build_index.py` yourself. Construction of index chunks can be parallelized or distributed.
101 | 
102 | Please also note that index building is a memory-intensive task. A chunk of 1/10 the total size of the data requires ~200GB of memory, so you may need to adjust the chunk size according to your machine's memory. For a 1/10 chunk index, construction takes ~30 minutes with `num_workers=400`.
103 | 
104 | ## Deploying the database
105 | 
106 | Use `simple_wikidata_db/db_deploy/server` to start a server that loads a chunk of data and listens on a port:
107 | 
108 | ```bash
109 | python simple_wikidata_db/db_deploy/server.py \
110 | --data_dir $PREPROCESS_DATA_DIR \
111 | --chunk_number $CHUNK_NUMBER
112 | ```
113 | 
114 | - `data_dir`: The dir of the processed data. Its `indices` subfolder should contain the index files. Usually this should be the same as `input_dir` in the index building step.
115 | - `chunk_number`: The chunk number of the data to be served. This should be the same as the `chunk_idx` in the index building step. A single process can only serve one chunk of data. If you want to serve multiple chunks, you need to start multiple processes.
116 | 
117 | The service is implemented via XML-RPC. Each server process listens on the port given by `--port` (23546 by default), also requires `--host_ip`, and appends its address as `http://[host_ip]:[port]` to `server_urls_new.txt` so that clients can find it. All queries are implemented via Python's built-in `xmlrpc` support, and the code was written with the help of ChatGPT.
118 | 
119 | Similar to index construction, this service is deployed in a distributed manner. Specifically, each server process reads 1 chunk of data, which takes ~200GB of memory for a chunk of 1/10 the total size. So you may need to adjust the chunk size according to your machine's memory. Loading the index is also time-consuming: for a 1/10 chunk index, it takes ~20 minutes to load it into memory.
120 | 
121 | ## Querying the database
122 | 
123 | An example client is provided in `db_deploy/client.py`. It can be used to query the database:
124 | 
125 | ```bash
126 | python simple_wikidata_db/db_deploy/client.py --addr_list server_urls.txt
127 | ```
128 | 
129 | For a single query, the client sends the query to all server nodes, gets the results, and aggregates them locally.
130 | 
--------------------------------------------------------------------------------
/Wikidata/requirements.txt:
--------------------------------------------------------------------------------
1 | ujson==5.1.0
2 | pathlib==1.0.1
--------------------------------------------------------------------------------
/Wikidata/scripts/build_index.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | for i in {0..9}; do
4 | python -u simple_wikidata_db/db_deploy/build_index.py --input_dir /dev/shm/wikidump_inmem/wikidump_20230116 --num_chunks 10 --chunk_idx $i --output_dir /dev/shm/wikidump_inmem/wikidump_20230116/indices > logs/build_index_${i}.log 2>&1 &
5 | done
6 | 
7 | wait
8 | 
--------------------------------------------------------------------------------
/Wikidata/scripts/start_server.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | rm server_urls.txt
4 | 
5 | for i in {0..9}; do
6 | python -u simple_wikidata_db/db_deploy/server.py --data_dir /dev/shm/wikidump_inmem/wikidump_20230116 --chunk_number $i --port 2315$i > logs/server_log_$i.log 2>&1 &
7 | done
8 | 
9 | wait
10 | 
--------------------------------------------------------------------------------
/Wikidata/simple_wikidata_db/db_deploy/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GasolSun36/ToG/934064c4a8a391c7f351339676f29ed764e40054/Wikidata/simple_wikidata_db/db_deploy/__init__.py
--------------------------------------------------------------------------------
/Wikidata/simple_wikidata_db/db_deploy/build_index.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import os
3 | import pickle
4 | from collections import defaultdict
5 | from multiprocessing import Pool
6 | from numpy import require
7 | from tqdm import tqdm
8 | import math
9 | from dataclasses import dataclass
10 | import ujson as json
11 | from simple_wikidata_db.db_deploy.utils import (
12 |     a_factory,
13 |     Entity,
14 |     Relation,
15 |     get_batch_files,
16 |     jsonl_generator,
17 |     read_relation_label,
18 |     read_entity_label,
19 | )
20 | import typing as tp
21 | 
22 | 
23 | def read_relation_entities(filename):
24 |     relation_entities = []
25 |     for item in jsonl_generator(filename):
26 |         relation_entities.append(
27 |             {
28 |                 "head_qid": item["qid"],
29 |                 "pid": item["property_id"],
30 |
"tail_qid": item["value"], 31 | } 32 | ) 33 | return relation_entities 34 | 35 | 36 | def read_tail_values(filename): 37 | relation_entities = [] 38 | for item in jsonl_generator(filename): 39 | relation_entities.append( 40 | { 41 | "head_qid": item["qid"], 42 | "pid": item["property_id"], 43 | "tail_value": item["value"], 44 | } 45 | ) 46 | return relation_entities 47 | 48 | 49 | def read_external_ids(filename): 50 | relation_entities = [] 51 | for item in jsonl_generator(filename): 52 | relation_entities.append( 53 | { 54 | "qid": item["qid"], 55 | "pid": item["property_id"], 56 | "value": item["value"], 57 | } 58 | ) 59 | return relation_entities 60 | 61 | 62 | from collections import defaultdict 63 | from typing import DefaultDict 64 | 65 | 66 | def merge_defaultdicts( 67 | dd1: DefaultDict[str, list], dd2: DefaultDict[str, list] 68 | ) -> DefaultDict[str, list]: 69 | # Create a new defaultdict to hold the merged results 70 | merged_dict = defaultdict(list, dd1) 71 | 72 | # Merge dd1 and dd2 73 | for key, val in dd2.items(): 74 | merged_dict[key].extend(val) 75 | 76 | return merged_dict 77 | 78 | 79 | def filter_value( 80 | dict_list: tp.List[tp.Dict], 81 | key: str, 82 | ) -> tp.List[tp.Dict]: 83 | ret_list = [] 84 | for dict_item in tqdm(dict_list, desc='filter_value iter over dict_list'): 85 | if key in dict_item: 86 | ret_list.append(dict_item[key]) 87 | # Flatten the list 88 | ret_list = [item for sublist in ret_list for item in sublist] 89 | return key, ret_list 90 | 91 | 92 | def main(args): 93 | os.makedirs(args.output_dir, exist_ok=True) 94 | data_dir = args.input_dir 95 | num_chunks = args.num_chunks # adjust as needed 96 | pool = Pool(processes=args.num_workers) # adjust as needed 97 | 98 | files_index = { 99 | "labels": get_batch_files(os.path.join(data_dir, "labels")), 100 | "descriptions": get_batch_files(os.path.join(data_dir, "descriptions")), 101 | "aliases": get_batch_files(os.path.join(data_dir, "aliases")), 102 | "entity_rels": get_batch_files(os.path.join(data_dir, "entity_rels")), 103 | "external_ids": get_batch_files(os.path.join(data_dir, "external_ids")), 104 | "entity_values": get_batch_files( 105 | os.path.join(data_dir, "entity_values") 106 | ), 107 | "qualifiers": get_batch_files(os.path.join(data_dir, "qualifiers")), 108 | "wikipedia_links": get_batch_files( 109 | os.path.join(data_dir, "wikipedia_links") 110 | ), 111 | "plabels": get_batch_files(os.path.join(data_dir, "plabels")), 112 | } 113 | chunk_size_entity_rels = math.ceil( 114 | len(files_index["entity_rels"]) / num_chunks 115 | ) 116 | chunk_size_entity_values = math.ceil( 117 | len(files_index["entity_values"]) / num_chunks 118 | ) 119 | chunk_size_external_ids = math.ceil( 120 | len(files_index["external_ids"]) / num_chunks 121 | ) 122 | 123 | # QID/PID <=> Name mapping 124 | qid_to_name = {} 125 | name_to_qid = {} 126 | name_to_qid_list = [] 127 | pid_to_name = {} 128 | name_to_pid = {} 129 | name_to_pid_list = [] 130 | print(f"args.chunk_idx: {args.chunk_idx}") 131 | 132 | # Step 1: Read Entity label <=> QID mapping 133 | print("Reading entity labels ...") 134 | for output in tqdm( 135 | pool.imap_unordered( 136 | read_entity_label, files_index["labels"], chunksize=1 137 | ), 138 | ): 139 | qid_to_name.update(output[0]) 140 | # name_to_qid_list.append(output[1]) 141 | 142 | # all_entity_names = set() 143 | # for d in name_to_qid_list: 144 | # all_entity_names.update(d.keys()) 145 | # counter = 0 146 | # for name, qids in tqdm( 147 | # pool.imap_unordered( 148 | # partial(filter_value, 
dict_list=name_to_qid_list), 149 | # all_entity_names, 150 | # chunksize=1, 151 | # ), 152 | # ): 153 | # name_to_qid[name] = qids 154 | # if counter < 5: 155 | # print(f"{name}: {qids}") 156 | # counter += 1 157 | 158 | # Step 2: Read Relation label <=> PID mapping 159 | print("Reading relation labels ...") 160 | for output in tqdm( 161 | pool.imap_unordered( 162 | read_relation_label, files_index["plabels"], chunksize=1 163 | ), 164 | ): 165 | pid_to_name.update(output[0]) 166 | # name_to_pid_list.append(output[1]) 167 | 168 | # all_relation_names = set.intersection(*map(set, name_to_pid_list)) 169 | # counter = 0 170 | # for name, pids in tqdm( 171 | # pool.imap_unordered( 172 | # partial(filter_value, dict_list=name_to_pid_list), 173 | # all_relation_names, 174 | # chunksize=1, 175 | # ), 176 | # ): 177 | # name_to_pid[name] = pids 178 | # if counter < 5: 179 | # print(f"{name}: {pids}") 180 | # counter += 1 181 | 182 | # missing_qids = [] 183 | # missing_pids = [] 184 | 185 | # Step 3: Read entity_rels, entity_values, and external_ids 186 | for i in range(num_chunks): 187 | if args.chunk_idx != -1 and i != args.chunk_idx: 188 | continue 189 | start = i * chunk_size_entity_rels 190 | end = start + chunk_size_entity_rels 191 | chunk_files = files_index["entity_rels"][start:end] 192 | 193 | relations_linked_to_entities = defaultdict(a_factory) 194 | entities_related_to_relent_pair = defaultdict(a_factory) 195 | tail_values = defaultdict(list) 196 | 197 | print(f"Processing `entity_rels` of chunk {i+1} ...") 198 | for output in tqdm( 199 | pool.imap_unordered( 200 | read_relation_entities, 201 | chunk_files, 202 | chunksize=1, 203 | ) 204 | ): 205 | for item in output: 206 | # if item["pid"] not in pid_to_name: 207 | # missing_pids.append(item["pid"]) 208 | # if item["tail_qid"] not in qid_to_name: 209 | # missing_qids.append(item["tail_qid"]) 210 | rel = Relation( 211 | pid=item["pid"], 212 | label=pid_to_name.get(item["pid"], "N/A"), 213 | ) 214 | relations_linked_to_entities[item["head_qid"]]["head"].append( 215 | rel 216 | ) 217 | relations_linked_to_entities[item["tail_qid"]]["tail"].append( 218 | rel 219 | ) 220 | 221 | entities_related_to_relent_pair[ 222 | f'{item["head_qid"]}@{item["pid"]}' 223 | ]["tail"].append( 224 | Entity( 225 | qid=item["tail_qid"], 226 | label=qid_to_name.get(item["tail_qid"], "N/A"), 227 | ) 228 | ) 229 | entities_related_to_relent_pair[ 230 | f'{item["tail_qid"]}@{item["pid"]}' 231 | ]["head"].append( 232 | Entity( 233 | qid=item["head_qid"], 234 | label=qid_to_name.get(item["head_qid"], "N/A"), 235 | ) 236 | ) 237 | 238 | print(f"Processing `entity_values` of chunk {i+1} ...") 239 | start = i * chunk_size_entity_values 240 | end = start + chunk_size_entity_values 241 | chunk_files = files_index["entity_values"][start:end] 242 | for output in tqdm( 243 | pool.imap_unordered( 244 | read_tail_values, 245 | chunk_files, 246 | chunksize=1, 247 | ) 248 | ): 249 | for item in output: 250 | # if item["pid"] not in pid_to_name: 251 | # missing_pids.append(item["pid"]) 252 | relations_linked_to_entities[item["head_qid"]]["head"].append( 253 | Relation( 254 | pid=item["pid"], 255 | label=pid_to_name.get(item["pid"], "N/A"), 256 | ) 257 | ) 258 | tail_values[f'{item["head_qid"]}@{item["pid"]}'].append( 259 | item["tail_value"] 260 | ) 261 | 262 | external_ids = defaultdict(list) 263 | mid_to_qid = defaultdict(list) 264 | print(f"Processing `external_ids` of chunk {i+1} ...") 265 | start = i * chunk_size_external_ids 266 | end = start + chunk_size_external_ids 267 | 
chunk_files = files_index["external_ids"][start:end] 268 | for output in tqdm( 269 | pool.imap_unordered(read_external_ids, chunk_files, chunksize=1) 270 | ): 271 | for item in output: 272 | external_ids[f'{item["qid"]}@{item["pid"]}'].append( 273 | item["value"] 274 | ) 275 | mid_to_qid[f'{item["value"]}'].append(item["qid"]) 276 | 277 | # Dump 3 index files 278 | with open( 279 | f"{args.output_dir}/relation_entities_chunk_{i+1}.pickle", "wb" 280 | ) as handle: 281 | pickle.dump( 282 | relations_linked_to_entities, 283 | handle, 284 | protocol=pickle.HIGHEST_PROTOCOL, 285 | ) 286 | with open( 287 | f"{args.output_dir}/tail_entities_chunk_{i+1}.pickle", "wb" 288 | ) as handle: 289 | pickle.dump( 290 | entities_related_to_relent_pair, 291 | handle, 292 | protocol=pickle.HIGHEST_PROTOCOL, 293 | ) 294 | 295 | with open( 296 | f"{args.output_dir}/tail_values_chunk_{i+1}.pickle", "wb" 297 | ) as handle: 298 | pickle.dump(tail_values, handle, protocol=pickle.HIGHEST_PROTOCOL) 299 | 300 | with open( 301 | f"{args.output_dir}/external_ids_chunk_{i+1}.pickle", "wb" 302 | ) as handle: 303 | pickle.dump(external_ids, handle, protocol=pickle.HIGHEST_PROTOCOL) 304 | 305 | with open( 306 | f"{args.output_dir}/mid_to_qid_chunk_{i+1}.pickle", "wb" 307 | ) as handle: 308 | pickle.dump(mid_to_qid, handle, protocol=pickle.HIGHEST_PROTOCOL) 309 | 310 | # print( 311 | # f"Missing QIDs: {len(missing_qids)}, total: {len(qid_to_name)}, ratio: {len(missing_qids)/len(qid_to_name)}" 312 | # ) 313 | # print( 314 | # f"Missing PIDs: {len(missing_pids)}, total: {len(pid_to_name)}, ratio: {len(missing_pids)/len(pid_to_name)}" 315 | # ) 316 | 317 | 318 | if __name__ == "__main__": 319 | import argparse 320 | 321 | parser = argparse.ArgumentParser() 322 | parser.add_argument( 323 | "--input_dir", 324 | type=str, 325 | required=True, 326 | help="Preprocessed Wikidata dumpfile directory", 327 | ) 328 | parser.add_argument( 329 | "--output_dir", 330 | type=str, 331 | required=True, 332 | help="Output directory", 333 | ) 334 | parser.add_argument("--num_chunks", type=int, default=5) 335 | parser.add_argument("--num_workers", type=int, default=400) 336 | parser.add_argument("--chunk_idx", type=int, default=-1) 337 | 338 | args = parser.parse_args() 339 | main(args) 340 | -------------------------------------------------------------------------------- /Wikidata/simple_wikidata_db/db_deploy/client.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import xmlrpc.client 3 | import typing as tp 4 | from dataclasses import dataclass 5 | from concurrent.futures import ThreadPoolExecutor 6 | from tqdm import tqdm 7 | from simple_wikidata_db.db_deploy.utils import Entity, Relation, a_factory 8 | import requests 9 | 10 | 11 | class WikidataQueryClient: 12 | def __init__(self, url: str): 13 | self.url = url 14 | self.server = xmlrpc.client.ServerProxy(url) 15 | 16 | def label2qid(self, label: str) -> str: 17 | return self.server.label2qid(label) 18 | 19 | def label2pid(self, label: str) -> str: 20 | return self.server.label2pid(label) 21 | 22 | def pid2label(self, pid: str) -> str: 23 | return self.server.pid2label(pid) 24 | 25 | def qid2label(self, qid: str) -> str: 26 | return self.server.qid2label(qid) 27 | 28 | def get_all_relations_of_an_entity( 29 | self, entity_qid: str 30 | ) -> tp.Dict[str, tp.List]: 31 | return self.server.get_all_relations_of_an_entity(entity_qid) 32 | 33 | def get_tail_entities_given_head_and_relation( 34 | self, head_qid: str, relation_pid: str 35 | 
) -> tp.Dict[str, tp.List]: 36 | return self.server.get_tail_entities_given_head_and_relation( 37 | head_qid, relation_pid 38 | ) 39 | 40 | def get_tail_values_given_head_and_relation( 41 | self, head_qid: str, relation_pid: str 42 | ) -> tp.List[str]: 43 | return self.server.get_tail_values_given_head_and_relation( 44 | head_qid, relation_pid 45 | ) 46 | 47 | def get_external_id_given_head_and_relation( 48 | self, head_qid: str, relation_pid: str 49 | ) -> tp.List[str]: 50 | return self.server.get_external_id_given_head_and_relation( 51 | head_qid, relation_pid 52 | ) 53 | 54 | def mid2qid(self, mid: str) -> str: 55 | return self.server.mid2qid(mid) 56 | 57 | 58 | import time 59 | import typing as tp 60 | from concurrent.futures import ThreadPoolExecutor 61 | 62 | 63 | class MultiServerWikidataQueryClient: 64 | def __init__(self, urls: tp.List[str]): 65 | self.clients = [WikidataQueryClient(url) for url in urls] 66 | self.executor = ThreadPoolExecutor(max_workers=len(urls)) 67 | # test connections 68 | start_time = time.perf_counter() 69 | self.test_connections() 70 | end_time = time.perf_counter() 71 | print(f"Connection testing took {end_time - start_time} seconds") 72 | 73 | def test_connections(self): 74 | def test_url(client): 75 | try: 76 | # Check if server provides the system.listMethods function. 77 | client.server.system.listMethods() 78 | return True 79 | except Exception as e: 80 | print(f"Failed to connect to {client.url}. Error: {str(e)}") 81 | return False 82 | 83 | start_time = time.perf_counter() 84 | futures = [ 85 | self.executor.submit(test_url, client) for client in self.clients 86 | ] 87 | results = [f.result() for f in futures] 88 | end_time = time.perf_counter() 89 | # print(f"Testing connections took {end_time - start_time} seconds") 90 | # Remove clients that failed to connect 91 | self.clients = [ 92 | client for client, result in zip(self.clients, results) if result 93 | ] 94 | if not self.clients: 95 | raise Exception("Failed to connect to all URLs") 96 | 97 | def query_all(self, method, *args): 98 | start_time = time.perf_counter() 99 | futures = [ 100 | self.executor.submit(getattr(client, method), *args) 101 | for client in self.clients 102 | ] 103 | # Retrieve results and filter out 'Not Found!' 104 | is_dict_return = method in [ 105 | "get_all_relations_of_an_entity", 106 | "get_tail_entities_given_head_and_relation", 107 | ] 108 | results = [f.result() for f in futures] 109 | end_time = time.perf_counter() 110 | # print(f"HTTP Queries took {end_time - start_time} seconds") 111 | 112 | start_time = time.perf_counter() 113 | real_results = set() if not is_dict_return else {"head": [], "tail": []} 114 | for res in results: 115 | if isinstance(res, str) and res == "Not Found!": 116 | continue 117 | elif isinstance(res, tp.List): 118 | if len(res) == 0: 119 | continue 120 | if isinstance(res[0], tp.List): 121 | res_flattened = itertools.chain(*res) 122 | real_results.update(res_flattened) 123 | continue 124 | real_results.update(res) 125 | elif is_dict_return: 126 | real_results["head"].extend(res["head"]) 127 | real_results["tail"].extend(res["tail"]) 128 | else: 129 | real_results.add(res) 130 | end_time = time.perf_counter() 131 | # print(f"Querying all took {end_time - start_time} seconds") 132 | 133 | return real_results if len(real_results) > 0 else "Not Found!" 
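# Usage sketch (illustrative only; the server URLs and QID below are hypothetical --
# in practice the addresses come from the file passed via --addr_list):
#
#   client = MultiServerWikidataQueryClient(
#       ["http://127.0.0.1:23150", "http://127.0.0.1:23151"]
#   )
#   # query_all fans the same call out to every chunk server, merges the per-server
#   # results, and returns "Not Found!" only when no server has the key.
#   relations = client.query_all("get_all_relations_of_an_entity", "Q312")
#   label = client.query_all("qid2label", "Q312")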
134 | 135 | 136 | if __name__ == "__main__": 137 | import argparse 138 | 139 | parser = argparse.ArgumentParser() 140 | parser.add_argument( 141 | "--addr_list", 142 | type=str, 143 | required=True, 144 | help="path to server address list", 145 | ) 146 | args = parser.parse_args() 147 | 148 | with open(args.addr_list, "r") as f: 149 | server_addrs = f.readlines() 150 | server_addrs = [addr.strip() for addr in server_addrs] 151 | print(f"Server addresses: {server_addrs}") 152 | client = MultiServerWikidataQueryClient(server_addrs) 153 | print( 154 | f'MSFT\'s ticker code is {client.query_all("get_tail_values_given_head_and_relation","Q2283","P249",)}' 155 | ) -------------------------------------------------------------------------------- /Wikidata/simple_wikidata_db/db_deploy/server.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import typing as tp 4 | from collections import defaultdict 5 | from dataclasses import dataclass 6 | from functools import partial 7 | from multiprocessing import Pool 8 | from xmlrpc.server import SimpleXMLRPCRequestHandler, SimpleXMLRPCServer 9 | from numpy import require 10 | from sqlalchemy import true 11 | from simple_wikidata_db.db_deploy.utils import ( 12 | Entity, 13 | Relation, 14 | a_factory, 15 | jsonl_generator, 16 | get_batch_files, 17 | read_entity_label, 18 | read_relation_label, 19 | ) 20 | import ujson as json 21 | from tqdm import tqdm 22 | import itertools 23 | 24 | 25 | def merge_list_of_list(dd1, dd2): 26 | """ 27 | Optimized function to merge two defaultdict(list) instances. 28 | For common keys, lists will be concatenated. 29 | """ 30 | merged_dd = dd1 31 | 32 | # Using dictionary comprehension to merge 33 | for key in dd2.keys(): 34 | merged_dd[key].append(dd2[key]) 35 | 36 | return merged_dd 37 | 38 | 39 | class WikidataQueryServer: 40 | def __init__( 41 | self, 42 | chunk_number: int, 43 | data_dir: str, 44 | num_workers: int = 400, 45 | ): 46 | self.num_workers = num_workers 47 | self.pool = Pool(processes=self.num_workers) 48 | 49 | self.files_index = { 50 | "labels": get_batch_files(os.path.join(data_dir, "labels")), 51 | "plabels": get_batch_files(os.path.join(data_dir, "plabels")), 52 | } 53 | 54 | self.qid_to_name = {} 55 | self.name_to_qid = defaultdict(list) 56 | self.pid_to_name = {} 57 | self.name_to_pid = defaultdict(list) 58 | print("Reading relation labels ...") 59 | for output in tqdm( 60 | self.pool.imap_unordered( 61 | read_relation_label, self.files_index["plabels"], chunksize=1 62 | ) 63 | ): 64 | self.pid_to_name.update(output[0]) 65 | self.name_to_pid = merge_list_of_list(self.name_to_pid, output[1]) 66 | for k, v in self.name_to_pid.items(): 67 | self.name_to_pid[k] = list(itertools.chain(*v)) 68 | 69 | print("Reading entity labels ...") 70 | for output in tqdm( 71 | self.pool.imap_unordered( 72 | read_entity_label, self.files_index["labels"], chunksize=1 73 | ) 74 | ): 75 | self.qid_to_name.update(output[0]) 76 | self.name_to_qid = merge_list_of_list(self.name_to_qid, output[1]) 77 | 78 | for k, v in self.name_to_qid.items(): 79 | self.name_to_qid[k] = list(itertools.chain(*v)) 80 | 81 | print("Reading links ...") 82 | chunk_number = chunk_number + 1 83 | print( 84 | f"Reading {args.data_dir}/indices/relation_entities_chunk_{chunk_number}.pickle" 85 | ) 86 | with open( 87 | f"{args.data_dir}/indices/relation_entities_chunk_{chunk_number}.pickle", 88 | "rb", 89 | ) as handle: 90 | self.relation_entities = pickle.load(handle) 91 | print( 92 | 
f"Reading {args.data_dir}/indices/tail_entities_chunk_{chunk_number}.pickle" 93 | ) 94 | with open( 95 | f"{args.data_dir}/indices/tail_entities_chunk_{chunk_number}.pickle", 96 | "rb", 97 | ) as handle: 98 | self.tail_entities = pickle.load(handle) 99 | print( 100 | f"Reading {args.data_dir}/indices/tail_values_chunk_{chunk_number}.pickle" 101 | ) 102 | with open( 103 | f"{args.data_dir}/indices/tail_values_chunk_{chunk_number}.pickle", 104 | "rb", 105 | ) as handle: 106 | self.tail_values = pickle.load(handle) 107 | print( 108 | f"Reading {args.data_dir}/indices/external_ids_chunk_{chunk_number}.pickle" 109 | ) 110 | with open( 111 | f"{args.data_dir}/indices/external_ids_chunk_{chunk_number}.pickle", 112 | "rb", 113 | ) as handle: 114 | self.external_ids = pickle.load(handle) 115 | with open( 116 | f"{args.data_dir}/indices/mid_to_qid_chunk_{chunk_number}.pickle", 117 | "rb", 118 | ) as handle: 119 | self.mid_to_qid = pickle.load(handle) 120 | 121 | # See the number of conflict names by making differences in length 122 | dup_entity_names = len(self.qid_to_name) - len(self.name_to_qid) 123 | print( 124 | f"Total entities = {len(self.qid_to_name)}, duplicate names = {dup_entity_names}" 125 | ) 126 | 127 | def label2qid(self, label: str) -> tp.List[Entity]: 128 | return self.name_to_qid.get(label, "Not Found!") 129 | 130 | def label2pid(self, label: str) -> tp.List[Relation]: 131 | return self.name_to_pid.get(label, "Not Found!") 132 | 133 | def qid2label(self, qid: str) -> tp.List[Entity]: 134 | return self.qid_to_name.get(qid, "Not Found!") 135 | 136 | def pid2label(self, pid: str) -> tp.List[Relation]: 137 | return self.pid_to_name.get(pid, "Not Found!") 138 | 139 | def mid2qid(self, mid: str) -> tp.List[str]: 140 | return self.mid_to_qid.get(mid, "Not Found!") 141 | 142 | def get_all_relations_of_an_entity( 143 | self, entity_qid: str 144 | ) -> tp.Dict[str, tp.List[Relation]]: 145 | try: 146 | return self.relation_entities[entity_qid] 147 | except KeyError: 148 | return "Not Found!" 149 | 150 | def get_tail_entities_given_head_and_relation( 151 | self, head_qid: str, relation_pid: str 152 | ) -> tp.Dict[str, tp.List[Entity]]: 153 | try: 154 | return self.tail_entities[f"{head_qid}@{relation_pid}"] 155 | except KeyError: 156 | return "Not Found!" 157 | 158 | def get_tail_values_given_head_and_relation( 159 | self, head_qid: str, relation_pid: str 160 | ) -> tp.List[str]: 161 | try: 162 | return self.tail_values[f"{head_qid}@{relation_pid}"] 163 | except KeyError: 164 | return "Not Found!" 165 | 166 | def get_external_id_given_head_and_relation( 167 | self, head_qid: str, relation_pid: str 168 | ) -> tp.List[str]: 169 | try: 170 | return self.external_ids[f"{head_qid}@{relation_pid}"] 171 | except KeyError: 172 | return "Not Found!" 
173 | 174 | 175 | class RequestHandler(SimpleXMLRPCRequestHandler): 176 | rpc_paths = ("/RPC2",) 177 | 178 | 179 | class XMLRPCWikidataQueryServer(WikidataQueryServer): 180 | def __init__(self, addr, server_args, requestHandler=RequestHandler): 181 | super().__init__( 182 | chunk_number=server_args.chunk_number, data_dir=server_args.data_dir 183 | ) 184 | self.server = SimpleXMLRPCServer(addr, requestHandler=requestHandler) 185 | self.server.register_introspection_functions() 186 | self.server.register_function(self.get_all_relations_of_an_entity) 187 | self.server.register_function( 188 | self.get_tail_entities_given_head_and_relation 189 | ) 190 | self.server.register_function(self.label2pid) 191 | self.server.register_function(self.label2qid) 192 | self.server.register_function(self.pid2label) 193 | self.server.register_function(self.qid2label) 194 | self.server.register_function( 195 | self.get_tail_values_given_head_and_relation 196 | ) 197 | self.server.register_function( 198 | self.get_external_id_given_head_and_relation 199 | ) 200 | self.server.register_function(self.mid2qid) 201 | 202 | def serve_forever(self): 203 | self.server.serve_forever() 204 | 205 | 206 | if __name__ == "__main__": 207 | import argparse 208 | 209 | parser = argparse.ArgumentParser() 210 | parser.add_argument( 211 | "--data_dir", type=str, required=True, help="Path to the data directory" 212 | ) 213 | parser.add_argument( 214 | "--chunk_number", type=int, required=True, help="Chunk number" 215 | ) 216 | parser.add_argument("--port", type=int, default=23546, help="Port number") 217 | parser.add_argument("--host_ip", type=str, required=True, help="Host IP") 218 | args = parser.parse_args() 219 | print("Start with my program now!!!") 220 | server = XMLRPCWikidataQueryServer( 221 | addr=("0.0.0.0", args.port), server_args=args 222 | ) 223 | with open("server_urls_new.txt", "a") as f: 224 | f.write(f"http://{args.host_ip}:{args.port}\n") 225 | print(f"XMLRPC WDQS server ready and listening on 0.0.0.0:{args.port}") 226 | server.serve_forever() 227 | -------------------------------------------------------------------------------- /Wikidata/simple_wikidata_db/db_deploy/utils.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from dataclasses import dataclass 3 | from traitlets import default 4 | import ujson as json 5 | import os 6 | 7 | 8 | @dataclass 9 | class Entity: 10 | qid: str 11 | label: str 12 | 13 | 14 | @dataclass 15 | class Relation: 16 | pid: str 17 | label: str 18 | 19 | 20 | def a_factory(): 21 | return {"head": [], "tail": []} 22 | 23 | 24 | def jsonl_generator(fname): 25 | """Returns generator for jsonl file.""" 26 | for line in open(fname, "r"): 27 | line = line.strip() 28 | if len(line) < 3: 29 | d = {} 30 | elif line[len(line) - 1] == ",": 31 | d = json.loads(line[: len(line) - 1]) 32 | else: 33 | d = json.loads(line) 34 | yield d 35 | 36 | 37 | def get_batch_files(fdir): 38 | """Returns paths to files in fdir.""" 39 | filenames = os.listdir(fdir) 40 | filenames = [os.path.join(fdir, f) for f in filenames] 41 | print(f"Fetched {len(filenames)} files from {fdir}") 42 | return filenames 43 | 44 | 45 | # Build these 4 dictionaries 46 | def read_entity_label(filename): 47 | qid_to_name = {} 48 | name_to_qid = defaultdict(list) 49 | for item in jsonl_generator(filename): 50 | qid_to_name[item["qid"]] = item["label"] 51 | name_to_qid[item["label"]].append(item["qid"]) 52 | return qid_to_name, name_to_qid 53 | 54 | 55 | def 
read_relation_label(filename): 56 | pid_to_name = {} 57 | name_to_pid = defaultdict(list) 58 | for item in jsonl_generator(filename): 59 | pid_to_name[item["pid"]] = item["label"] 60 | name_to_pid[item["label"]].append(item["pid"]) 61 | return pid_to_name, name_to_pid 62 | -------------------------------------------------------------------------------- /Wikidata/simple_wikidata_db/preprocess_dump.py: -------------------------------------------------------------------------------- 1 | """ Wikidata Dump Processor 2 | 3 | This script preprocesses the raw Wikidata dump (in JSON format) and sorts triples into 8 "tables": labels, descriptions, aliases, entity_rels, external_ids, entity_values, qualifiers, and wikipedia_links. See the README for more information on each table. 4 | 5 | Example command: 6 | 7 | python3 preprocess_dump.py \ 8 | --input_file latest-all.json.gz \ 9 | --out_dir data/processed 10 | 11 | """ 12 | import argparse 13 | import multiprocessing 14 | from multiprocessing import Queue, Process 15 | from pathlib import Path 16 | import time 17 | 18 | from simple_wikidata_db.preprocess_utils.reader_process import count_lines, read_data 19 | from simple_wikidata_db.preprocess_utils.worker_process import process_data 20 | from simple_wikidata_db.preprocess_utils.writer_process import write_data 21 | 22 | 23 | def get_arg_parser(): 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--input_file', type=str, required=True, help='path to gz wikidata json dump') 26 | parser.add_argument('--out_dir', type=str, required=True, help='path to output directory') 27 | parser.add_argument('--language_id', type=str, default='en', help='language identifier') 28 | parser.add_argument('--processes', type=int, default=90, help="number of concurrent processes to spin off. ") 29 | parser.add_argument('--batch_size', type=int, default=10000) 30 | parser.add_argument('--num_lines_read', type=int, default=-1, 31 | help='Terminate after num_lines_read lines are read. Useful for debugging.') 32 | parser.add_argument('--num_lines_in_dump', type=int, default=-1, help='Number of lines in dump. If -1, we will count the number of lines.') 33 | return parser 34 | 35 | 36 | def main(): 37 | start = time.time() 38 | args = get_arg_parser().parse_args() 39 | print(f"ARGS: {args}") 40 | 41 | out_dir = Path(args.out_dir) 42 | out_dir.mkdir(exist_ok=True, parents=True) 43 | 44 | input_file = Path(args.input_file) 45 | assert input_file.exists(), f"Input file {input_file} does not exist" 46 | 47 | 48 | max_lines_to_read = args.num_lines_read 49 | 50 | print("Starting processes") 51 | maxsize = 10 * args.processes 52 | 53 | # Queues for inputs/outputs 54 | output_queue = Queue(maxsize=maxsize) 55 | work_queue = Queue(maxsize=maxsize) 56 | 57 | # Processes for reading/processing/writing 58 | num_lines_read = multiprocessing.Value("i", 0) 59 | read_process = Process( 60 | target=read_data, 61 | args=(input_file, num_lines_read, max_lines_to_read, work_queue) 62 | ) 63 | 64 | read_process.start() 65 | 66 | write_process = Process( 67 | target=write_data, 68 | args=(out_dir, args.batch_size, output_queue) 69 | ) 70 | write_process.start() 71 | 72 | work_processes = [] 73 | for _ in range(max(1, args.processes-2)): 74 | work_process = Process( 75 | target=process_data, 76 | args=(args.language_id, work_queue, output_queue) 77 | ) 78 | work_process.daemon = True 79 | work_process.start() 80 | work_processes.append(work_process) 81 | 82 | read_process.join() 83 | print(f"Done! 
Read {num_lines_read.value} lines") 84 | # Cause all worker process to quit 85 | for work_process in work_processes: 86 | work_queue.put(None) 87 | # Now join the work processes 88 | for work_process in work_processes: 89 | work_process.join() 90 | output_queue.put(None) 91 | write_process.join() 92 | 93 | print(f"Finished processing {num_lines_read.value} in {time.time() - start}s") 94 | 95 | 96 | if __name__ == "__main__": 97 | main() 98 | -------------------------------------------------------------------------------- /Wikidata/simple_wikidata_db/preprocess_utils/reader_process.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Queue, Value 2 | from pathlib import Path 3 | import gzip 4 | from tqdm import tqdm 5 | 6 | def count_lines(input_file: Path, max_lines_to_read: int): 7 | cnt = 0 8 | with gzip.open(input_file, 'rb') as f: 9 | for _ in tqdm(f): 10 | cnt += 1 11 | if max_lines_to_read > 0 and cnt >= max_lines_to_read: 12 | break 13 | return cnt 14 | 15 | def read_data(input_file: Path, num_lines_read: Value, max_lines_to_read: int, work_queue: Queue): 16 | """ 17 | Reads the data from the input file and pushes it to the output queue. 18 | :param input_file: Path to the input file. 19 | :param num_lines_read: Value to store the number of lines in the input file. 20 | :param max_lines_to_read: Maximum number of lines to read from the input file (for testing). 21 | :param work_queue: Queue to push the data to. 22 | """ 23 | with gzip.GzipFile(input_file, "r") as f: 24 | num_lines = 0 25 | for ln in f: 26 | if ln == b"[\n" or ln == b"]\n": 27 | continue 28 | if ln.endswith(b",\n"): # all but the last element 29 | obj = ln[:-2] 30 | else: 31 | obj = ln 32 | num_lines += 1 33 | work_queue.put(obj) 34 | if 0 < max_lines_to_read <= num_lines: 35 | break 36 | num_lines_read.value = num_lines 37 | return 38 | -------------------------------------------------------------------------------- /Wikidata/simple_wikidata_db/preprocess_utils/worker_process.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from multiprocessing import Queue 3 | 4 | # properties which encode some alias/name 5 | import ujson 6 | 7 | ALIAS_PROPERTIES = { 8 | "P138", 9 | "P734", 10 | "P735", 11 | "P742", 12 | "P1448", 13 | "P1449", 14 | "P1477", 15 | "P1533", 16 | "P1549", 17 | "P1559", 18 | "P1560", 19 | "P1635", 20 | "P1705", 21 | "P1782", 22 | "P1785", 23 | "P1786", 24 | "P1787", 25 | "P1810", 26 | "P1813", 27 | "P1814", 28 | "P1888", 29 | "P1950", 30 | "P2358", 31 | "P2359", 32 | "PP2365", 33 | "P2366", 34 | "P2521", 35 | "P2562", 36 | "P2976", 37 | "PP3321", 38 | "P4239", 39 | "P4284", 40 | "P4970", 41 | "P5056", 42 | "P5278", 43 | "PP6978", 44 | "P7383", 45 | } 46 | 47 | # data types in wikidata dump which we ignore 48 | IGNORE = { 49 | "wikibase-lexeme", 50 | "musical-notation", 51 | "globe-coordinate", 52 | "commonsMedia", 53 | "geo-shape", 54 | "wikibase-sense", 55 | "wikibase-property", 56 | "math", 57 | "tabular-data", 58 | } 59 | 60 | 61 | def process_mainsnak(data, language_id): 62 | datatype = data["datatype"] 63 | if datatype == "string": 64 | return data["datavalue"]["value"] 65 | elif datatype == "monolingualtext": 66 | if data["datavalue"]["value"]["language"] == language_id: 67 | return data["datavalue"]["value"]["text"] 68 | elif datatype == "quantity": 69 | return data["datavalue"]["value"]["amount"] 70 | elif datatype == "time": 71 | return 
data["datavalue"]["value"]["time"] 72 | elif datatype == "wikibase-item": 73 | return data["datavalue"]["value"]["id"] 74 | elif datatype == "external-id": 75 | return data["datavalue"]["value"] 76 | elif datatype == "url": 77 | return data["datavalue"]["value"] 78 | 79 | # Ignore all other triples 80 | elif datatype in IGNORE: 81 | return None 82 | else: 83 | return None 84 | return None 85 | 86 | 87 | def process_json(obj, language_id="en"): 88 | out_data = defaultdict(list) 89 | id = obj["id"] # The canonical ID of the entity. 90 | # skip properties 91 | if obj["type"] == "property": 92 | out_data["plabels"].append( 93 | {"pid": id, "label": obj["labels"][language_id]["value"]} 94 | ) 95 | return dict(out_data) 96 | # extract labels 97 | if language_id in obj["labels"]: 98 | label = obj["labels"][language_id]["value"] 99 | out_data["labels"].append({"qid": id, "label": label}) 100 | out_data["aliases"].append({"qid": id, "alias": label}) 101 | 102 | # extract description 103 | if language_id in obj["descriptions"]: 104 | description = obj["descriptions"][language_id]["value"] 105 | out_data["descriptions"].append( 106 | { 107 | "qid": id, 108 | "description": description, 109 | } 110 | ) 111 | 112 | # extract aliases 113 | if language_id in obj["aliases"]: 114 | for alias in obj["aliases"][language_id]: 115 | out_data["aliases"].append( 116 | { 117 | "qid": id, 118 | "alias": alias["value"], 119 | } 120 | ) 121 | 122 | # extract english wikipedia sitelink -- we just add this to the external links table 123 | if f"{language_id}wiki" in obj["sitelinks"]: 124 | sitelink = obj["sitelinks"][f"{language_id}wiki"]["title"] 125 | out_data["wikipedia_links"].append({"qid": id, "wiki_title": sitelink}) 126 | 127 | # extract claims and qualifiers 128 | for property_id in obj["claims"]: 129 | for claim in obj["claims"][property_id]: 130 | if not claim["mainsnak"]["snaktype"] == "value": 131 | continue 132 | claim_id = claim["id"] 133 | datatype = claim["mainsnak"]["datatype"] 134 | value = process_mainsnak(claim["mainsnak"], language_id) 135 | 136 | if value is None: 137 | continue 138 | 139 | if datatype == "wikibase-item": 140 | out_data["entity_rels"].append( 141 | { 142 | "claim_id": claim_id, 143 | "qid": id, 144 | "property_id": property_id, 145 | "value": value, 146 | } 147 | ) 148 | elif datatype == "external-id": 149 | out_data["external_ids"].append( 150 | { 151 | "claim_id": claim_id, 152 | "qid": id, 153 | "property_id": property_id, 154 | "value": value, 155 | } 156 | ) 157 | else: 158 | out_data["entity_values"].append( 159 | { 160 | "claim_id": claim_id, 161 | "qid": id, 162 | "property_id": property_id, 163 | "value": value, 164 | } 165 | ) 166 | if property_id in ALIAS_PROPERTIES: 167 | out_data["aliases"].append( 168 | { 169 | "qid": id, 170 | "alias": value, 171 | } 172 | ) 173 | 174 | # get qualifiers 175 | if "qualifiers" in claim: 176 | for qualifier_property in claim["qualifiers"]: 177 | for qualifier in claim["qualifiers"][qualifier_property]: 178 | if not qualifier["snaktype"] == "value": 179 | continue 180 | qualifier_id = qualifier["hash"] 181 | value = process_mainsnak(qualifier, language_id) 182 | if value is None: 183 | continue 184 | out_data["qualifiers"].append( 185 | { 186 | "qualifier_id": qualifier_id, 187 | "claim_id": claim_id, 188 | "property_id": qualifier_property, 189 | "value": value, 190 | } 191 | ) 192 | 193 | return dict(out_data) 194 | 195 | 196 | def process_data(language_id: str, work_queue: Queue, out_queue: Queue): 197 | while True: 198 | json_obj 
= work_queue.get() 199 | if json_obj is None: 200 | break 201 | if len(json_obj) == 0: 202 | continue 203 | out_queue.put(process_json(ujson.loads(json_obj), language_id)) 204 | return 205 | -------------------------------------------------------------------------------- /Wikidata/simple_wikidata_db/preprocess_utils/writer_process.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from multiprocessing import Queue 3 | from pathlib import Path 4 | from typing import Dict, Any, List 5 | import time 6 | import ujson 7 | 8 | TABLE_NAMES = [ 9 | "labels", 10 | "descriptions", 11 | "aliases", 12 | "external_ids", 13 | "entity_values", 14 | "qualifiers", 15 | "wikipedia_links", 16 | "entity_rels", 17 | "ticker_symbols", 18 | 'plabels', 19 | ] 20 | 21 | 22 | class Table: 23 | def __init__(self, path: Path, batch_size: int, table_name: str): 24 | self.table_dir = path / table_name 25 | if self.table_dir.exists(): 26 | shutil.rmtree(self.table_dir) 27 | self.table_dir.mkdir(parents=True, exist_ok=False) 28 | 29 | self.index = 0 30 | self.cur_num_lines = 0 31 | self.batch_size = batch_size 32 | self.cur_file = self.table_dir / f"{self.index:d}.jsonl" 33 | self.cur_file_writer = None 34 | 35 | def write(self, json_value: List[Dict[str, Any]]): 36 | if self.cur_file_writer is None: 37 | self.cur_file_writer = open(self.cur_file, "w") 38 | for json_obj in json_value: 39 | self.cur_file_writer.write( 40 | ujson.dumps(json_obj, ensure_ascii=False) + "\n" 41 | ) 42 | self.cur_num_lines += 1 43 | if self.cur_num_lines >= self.batch_size: 44 | self.cur_file_writer.close() 45 | self.cur_num_lines = 0 46 | self.index += 1 47 | self.cur_file = self.table_dir / f"{self.index:d}.jsonl" 48 | self.cur_file_writer = None 49 | 50 | def close(self): 51 | self.cur_file_writer.close() 52 | 53 | 54 | class Writer: 55 | def __init__(self, path: Path, batch_size: int): 56 | self.cur_num_lines = 0 57 | # self.total_num_lines = total_num_lines 58 | self.start_time = time.time() 59 | self.output_tables = { 60 | table_name: Table(path, batch_size, table_name) 61 | for table_name in TABLE_NAMES 62 | } 63 | 64 | def write(self, json_object: Dict[str, Any]): 65 | self.cur_num_lines += 1 66 | for key, value in json_object.items(): 67 | if len(value) > 0: 68 | self.output_tables[key].write(value) 69 | if self.cur_num_lines % 200000 == 0: 70 | time_elapsed = time.time() - self.start_time 71 | # estimated_time = time_elapsed * (self.total_num_lines - self.cur_num_lines) / (200000*3600) 72 | print( 73 | f"{self.cur_num_lines} lines written in {time_elapsed:.2f}s. 
" 74 | ) 75 | self.start_time = time.time() 76 | 77 | def close(self): 78 | for v in self.output_tables.values(): 79 | v.close() 80 | 81 | 82 | def write_data(path: Path, batch_size: int, outout_queue: Queue): 83 | writer = Writer(path, batch_size) 84 | while True: 85 | json_object = outout_queue.get() 86 | if json_object is None: 87 | break 88 | writer.write(json_object) 89 | writer.close() 90 | -------------------------------------------------------------------------------- /Wikidata/simple_wikidata_db/utils.py: -------------------------------------------------------------------------------- 1 | """Assortment of useful utility functions 2 | """ 3 | 4 | import os 5 | import ujson as json 6 | import multiprocessing as mp 7 | 8 | def jsonl_generator(fname): 9 | """ Returns generator for jsonl file """ 10 | for line in open(fname, 'r'): 11 | line = line.strip() 12 | if len(line) < 3: 13 | d = {} 14 | elif line[len(line)-1] == ',': 15 | d= json.loads(line[:len(line)-1]) 16 | else: 17 | d= json.loads(line) 18 | yield d 19 | 20 | def batch_line_generator(fname, batch_size): 21 | """ Returns generator for jsonl file with batched lines """ 22 | res = [] 23 | batch_id = 0 24 | for line in open(fname, 'r'): 25 | line = line.strip() 26 | if len(line) < 3: 27 | d = '' 28 | elif line[len(line) - 1] == ',': 29 | d = line[:len(line) - 1] 30 | else: 31 | d = line 32 | res.append(d) 33 | if len(res) >= batch_size: 34 | yield batch_id, res 35 | batch_id += 1 36 | res = [] 37 | yield batch_id, res 38 | 39 | def append_to_jsonl_file(data, file): 40 | """ Appends json dictionary as new line to file """ 41 | with open(file, 'a+') as out_file: 42 | for x in data: 43 | out_file.write(json.dumps(x, ensure_ascii=False)+"\n") 44 | 45 | 46 | def get_batch_files(fdir): 47 | """ Returns paths to files in fdir """ 48 | filenames = os.listdir(fdir) 49 | filenames = [os.path.join(fdir, f) for f in filenames] 50 | print(f"Fetched {len(filenames)} files from {fdir}") 51 | return filenames 52 | 53 | def create_dir(out_dir): 54 | """ Creates new directory if it doesn't already exist """ 55 | if not os.path.exists(out_dir): 56 | print(f"Creating {out_dir}") 57 | os.makedirs(out_dir) 58 | -------------------------------------------------------------------------------- /assets/application.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GasolSun36/ToG/934064c4a8a391c7f351339676f29ed764e40054/assets/application.png -------------------------------------------------------------------------------- /assets/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GasolSun36/ToG/934064c4a8a391c7f351339676f29ed764e40054/assets/demo.png -------------------------------------------------------------------------------- /assets/experiments.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GasolSun36/ToG/934064c4a8a391c7f351339676f29ed764e40054/assets/experiments.png -------------------------------------------------------------------------------- /assets/methods.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GasolSun36/ToG/934064c4a8a391c7f351339676f29ed764e40054/assets/methods.png -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # 
Datasets 2 | 3 | The current folder holds all the datasets we used; the statistics of the datasets used in the paper are shown in the table below: 4 | 5 | | Dataset | Answer Format | Train | Test | License | 6 | |------------------------|---------------|-------------|-------|------------| 7 | | ComplexWebQuestions | Entity | 27,734 | 3,531 | - | 8 | | WebQSP | Number | 3,098 | 1,639 | CC License | 9 | | GrailQA* | Entity/Number | 44,337 | 1,000 | - | 10 | | QALD-10 | Entity/Number | - | 333 | MIT License| 11 | | Simple Question* | Number | 14,894 | 1,000 | CC License | 12 | | WebQuestions | Entity/Number | 3,778 | 2,032 | - | 13 | | T-REx | Entity | 2,284,168 | 5,000 | MIT License| 14 | | Zero-Shot RE | Entity | 147,909 | 3,724 | MIT License| 15 | | Creak | Bool | 10,176 | 1,371 | MIT License| 16 | 17 | where * denotes that, owing to the abundance of test samples, we randomly selected 1,000 samples from the GrailQA and Simple Questions test sets to constitute our test sets. 18 | 19 | If the user wants to search with a different KG source, check out the `mid2qid` and `qid2mid` APIs of the simple-wikidata-db folder. -------------------------------------------------------------------------------- /eval/README.md: -------------------------------------------------------------------------------- 1 | # Eval 2 | 3 | We use **Exact Match** as our evaluation metric. 4 | 5 | After getting the final result file, use the following command to evaluate the results: 6 | 7 | ```sh 8 | python eval.py \ 9 | --dataset cwq \ # the dataset you want to evaluate, see ToG/data/README.md 10 | --output_file ToG_cwq.json \ 11 | --constraints_refuse True 12 | ``` 13 | 14 | After that, you will get a result json file that contains: 15 | 16 | ```sh 17 | { 18 | 'dataset': 19 | 'method': 20 | 'Exact Match': 21 | 'Right Samples': 22 | 'Error Samples': 23 | } 24 | ``` -------------------------------------------------------------------------------- /eval/eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from utils import * 3 | 4 | if __name__ == '__main__': 5 |     parser = argparse.ArgumentParser() 6 |     parser.add_argument("--dataset", type=str, 7 |                         default="cwq", help="choose the dataset.") 8 |     parser.add_argument("--output_file", type=str, 9 |                         default="ToG_cwq.json", help="the output file name.") 10 |     parser.add_argument("--constraints_refuse", type=bool, 11 |                         default=True, help="the LLM may refuse to answer; enable this option to skip such samples.") 12 |     args = parser.parse_args() 13 | 14 |     ground_truth_datas, question_string, output_datas = prepare_dataset_for_eval(args.dataset, args.output_file) 15 | 16 |     num_right = 0 17 |     num_error = 0 18 |     for data in output_datas: 19 |         answers = align(args.dataset, question_string, data, ground_truth_datas) 20 |         results = data['results'] 21 |         if check_string(results): 22 |             response = clean_results(results) 23 |             if response=="NULL": 24 |                 response = results 25 |             else: 26 |                 if exact_match(response, answers): 27 |                     num_right+=1 28 |                 else: 29 |                     num_error+=1 30 |         else: 31 |             response = results 32 |             if args.constraints_refuse and check_string(response): 33 |                 continue 34 |             if exact_match(response, answers): 35 |                 num_right+=1 36 |             else: 37 |                 num_error+=1 38 | 39 |     print("Exact Match: {}".format(float(num_right/len(output_datas)))) 40 |     print("right: {}, error: {}".format(num_right, num_error)) 41 | 42 |     save_result2json(args.dataset, num_right, num_error, len(output_datas)) 43 | 
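# A small worked example of the matching helpers imported from utils.py (the
# strings below are made-up illustrations, not outputs from a real run):
#
#   clean_results("The answer is {Washington, D.C.}.")    # -> "Washington, D.C."
#   exact_match("Washington, D.C.", ["washington, d.c."]) # -> True; matching is
#   # case- and whitespace-insensitive, and substring containment also counts.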
-------------------------------------------------------------------------------- /eval/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | 4 | 5 | def prepare_dataset_for_eval(dataset_name, output_file): 6 | if dataset_name == 'cwq': 7 | with open('../data/cwq.json',encoding='utf-8') as f: 8 | datas = json.load(f) 9 | question_string = 'question' 10 | elif dataset_name == 'webqsp': 11 | with open('../data/WebQSP.json',encoding='utf-8') as f: 12 | datas = json.load(f) 13 | question_string = 'RawQuestion' 14 | elif dataset_name == 'grailqa': 15 | with open('../data/grailqa.json',encoding='utf-8') as f: 16 | datas = json.load(f) 17 | question_string = 'question' 18 | elif dataset_name == 'simpleqa': 19 | with open('../data/SimpleQA.json',encoding='utf-8') as f: 20 | datas = json.load(f) 21 | question_string = 'question' 22 | elif dataset_name == 'qald': 23 | with open('../data/qald_10-en.json',encoding='utf-8') as f: 24 | datas = json.load(f) 25 | question_string = 'question' 26 | elif dataset_name == 'webquestions': 27 | with open('../data/WebQuestions.json',encoding='utf-8') as f: 28 | datas = json.load(f) 29 | question_string = 'question' 30 | elif dataset_name == 'trex': 31 | with open('../data/T-REX.json',encoding='utf-8') as f: 32 | datas = json.load(f) 33 | question_string = 'input' 34 | elif dataset_name == 'zeroshotre': 35 | with open('../data/Zero_Shot_RE.json',encoding='utf-8') as f: 36 | datas = json.load(f) 37 | question_string = 'input' 38 | elif dataset_name == 'creak': 39 | with open('../data/creak.json',encoding='utf-8') as f: 40 | datas = json.load(f) 41 | question_string = 'sentence' 42 | else: 43 | print("dataset not found, you should pick from {cwq, webqsp, grailqa, simpleqa, qald, webquestions, trex, zeroshotre, creak}.") 44 | exit(-1) 45 | with open(output_file, encoding='utf-8') as f: 46 | output_datas = json.load(f) 47 | return datas, question_string, output_datas 48 | 49 | 50 | def align(dataset_name, question_string, data, ground_truth_datas): 51 | answer_list= [] 52 | origin_data = [j for j in ground_truth_datas if j[question_string] == data[question_string]][0] 53 | if dataset_name == 'cwq': 54 | if 'answers' in origin_data: 55 | answers = origin_data["answers"] 56 | else: 57 | answers = origin_data["answer"] 58 | for answer in answers: 59 | alias = answer['aliases'] 60 | ans = answer['answer'] 61 | alias.append(ans) 62 | answer_list.extend(alias) 63 | 64 | elif dataset_name == 'webqsp': 65 | answers = origin_data["Parses"] 66 | for answer in answers: 67 | for name in answer['Answers']: 68 | if name['EntityName'] == None: 69 | answer_list.append(name['AnswerArgument']) 70 | else: 71 | answer_list.append(name['EntityName']) 72 | 73 | elif dataset_name == 'grailqa': 74 | answers = origin_data["answer"] 75 | for answer in answers: 76 | if "entity_name" in answer: 77 | answer_list.append(answer['entity_name']) 78 | else: 79 | answer_list.append(answer['answer_argument']) 80 | 81 | elif dataset_name == 'simpleqa': 82 | answers = origin_data["answer"] 83 | answer_list.append(answers) 84 | 85 | elif dataset_name == 'qald': 86 | answers = origin_data["answer"] 87 | for answer in answers: 88 | answer_list.append(answers[answer]) 89 | 90 | elif dataset_name == 'webquestions': 91 | answer_list = origin_data["answers"] 92 | 93 | elif dataset_name == 'trex' or dataset_name == 'zeroshotre': 94 | answers = origin_data["answer"] 95 | answer_list.append(answers) 96 | 97 | elif dataset_name == 'creak': 98 | 
answer = origin_data['label'] 99 | answer_list.append(answer) 100 | 101 | return list(set(answer_list)) 102 | 103 | def check_string(string): 104 |     return "{" in string 105 | 106 | def clean_results(string): 107 |     if "{" in string: 108 |         start = string.find("{") + 1 109 |         end = string.find("}") 110 |         content = string[start:end] 111 |         return content 112 |     else: 113 |         return "NULL" 114 | 115 | 116 | def check_refuse(string): 117 |     refuse_words = ["however", "sorry"] 118 |     return any(word in string.lower() for word in refuse_words) 119 | 120 | 121 | def exact_match(response, answers): 122 |     clean_result = response.strip().replace(" ","").lower() 123 |     for answer in answers: 124 |         clean_answer = answer.strip().replace(" ","").lower() 125 |         if clean_result == clean_answer or clean_result in clean_answer or clean_answer in clean_result: 126 |             return True 127 |     return False 128 | 129 | def save_result2json(dataset_name, num_right, num_error, total_nums, method="ToG"):  # "ToG" default matches the result file naming below; eval.py omits this argument 130 |     results_data = { 131 |         'dataset': dataset_name, 132 |         'method': method, 133 |         'Exact Match': float(num_right/total_nums), 134 |         'Right Samples': num_right, 135 |         'Error Samples': num_error 136 |     } 137 |     with open('ToG_{}_results.json'.format(dataset_name), 'w', encoding='utf-8') as f: 138 |         json.dump(results_data, f, ensure_ascii=False, indent=4) 139 | 140 | def extract_content(s): 141 |     matches = re.findall(r'\{(.*?)\}', s) 142 |     if len(matches) >= 2 and matches[0].lower() == 'yes': 143 |         return matches[1] 144 |     elif len(matches) >= 1: 145 |         return matches[0] 146 |     else: 147 |         return 'NULL' -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jsonlines 2 | openai 3 | SPARQLWrapper 4 | tqdm 5 | argparse 6 | 7 | # only needed if you use BM25 or SentenceBERT as pruning tools. 8 | rank_bm25 9 | sentence_transformers 10 | -------------------------------------------------------------------------------- /tools/README.md: -------------------------------------------------------------------------------- 1 | # ToG-Tools 2 | 3 | ## Tool functions that may be useful are kept here, including: 4 | 5 | 1. Convert jsonl files to json files with `jsonl2json.py` 6 | 2. Remove duplicate elements from a json file based on the 'question' key with `de_duplicate.py` 7 | 3. Randomly sample n examples from a json file and save them to a new json file with `split_dataset.py` 8 | 9 | 10 | This folder will be updated frequently, and users can define new functions according to their needs. A typical post-processing flow is sketched below. 
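For example, convert the generated `.jsonl` predictions to `.json` and then run the evaluation (the file names below follow the naming used elsewhere in this repo; adjust the hard-coded paths inside the scripts to your own setup):

```sh
# convert ToG_cwq.jsonl (the input name is hard-coded in jsonl2json.py) to ToG_cwq.json
python jsonl2json.py
# move the converted file next to the eval script and evaluate it (see eval/README.md)
mv ToG_cwq.json ../eval/ && cd ../eval
python eval.py --dataset cwq --output_file ToG_cwq.json --constraints_refuse True
```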
-------------------------------------------------------------------------------- /tools/de_duplicate.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import OrderedDict 3 | 4 | with open("A.json", "r", encoding="utf-8") as file: 5 |     data = json.load(file) 6 | 7 | result = list(OrderedDict((item['question'], item) for item in data).values()) 8 | 9 | with open("B.json", "w", encoding="utf-8") as file: 10 |     json.dump(result, file, ensure_ascii=False) 11 | -------------------------------------------------------------------------------- /tools/jsonl2json.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | def jsonl_to_json(jsonl_file, json_file): 4 |     with open(jsonl_file, 'r') as infile: 5 |         with open(json_file, 'w') as outfile: 6 |             json_lines = infile.readlines() 7 |             json_list = [json.loads(line) for line in json_lines] 8 |             json.dump(json_list, outfile, indent=4) 9 | 10 | # Usage example 11 | jsonl_to_json('ToG_cwq.jsonl', 'ToG_cwq.json') -------------------------------------------------------------------------------- /tools/split_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | 4 | with open('A.json', 'r', encoding='utf-8') as f: 5 |     original_data = json.load(f) 6 | 7 | random.shuffle(original_data) 8 | new_data = original_data[:1000] 9 | 10 | with open('B.json', 'w', encoding='utf-8') as f: 11 |     json.dump(new_data, f) --------------------------------------------------------------------------------