├── LICENSE
├── README.md
├── download_hotpotqa_wikipedia.sh
├── figures
│   └── method.png
└── src
    ├── 2wiki
    │   ├── 0_generate_tree.sh
    │   ├── 1_conduct_reasoning.sh
    │   ├── RoHT
    │   │   ├── 1_build_tree.py
    │   │   ├── 2_run.py
    │   │   ├── 3_get_f1.py
    │   │   ├── aggregate
    │   │   │   └── prompt.txt
    │   │   ├── cb
    │   │   │   └── prompt.txt
    │   │   ├── count.py
    │   │   ├── evaluate.py
    │   │   ├── ob
    │   │   │   ├── multihop_prompt.txt
    │   │   │   └── singlehop_prompt.txt
    │   │   ├── openai_req.py
    │   │   ├── parallel.py
    │   │   ├── question_answering.py
    │   │   └── results
    │   │       └── released.json
    │   └── Tree_Generation
    │       ├── 0_get_prompt.py
    │       ├── 1_query.py
    │       ├── 2_postprocess.py
    │       ├── 3_postprocess_tree.py
    │       ├── combine.py
    │       ├── openai_req.py
    │       ├── prompt.txt
    │       ├── question_decompositions.json
    │       └── tree.json
    ├── hotpotqa
    │   ├── 0_generate_tree.sh
    │   ├── 1_conduct_reasoning.sh
    │   ├── RoHT
    │   │   ├── 1_build_tree.py
    │   │   ├── 2_run.py
    │   │   ├── 3_get_f1.py
    │   │   ├── aggregate
    │   │   │   └── prompt.txt
    │   │   ├── cb
    │   │   │   └── prompt.txt
    │   │   ├── count.py
    │   │   ├── evaluate.py
    │   │   ├── ob
    │   │   │   ├── multihop_prompt.txt
    │   │   │   └── singlehop_prompt.txt
    │   │   ├── openai_req.py
    │   │   ├── parallel.py
    │   │   ├── question_answering.py
    │   │   ├── results
    │   │   │   └── released.json
    │   │   └── search
    │   │       ├── __init__.py
    │   │       ├── serpapi.py
    │   │       └── wikipedia.py
    │   └── Tree_Generation
    │       ├── 0_get_prompt.py
    │       ├── 1_query.py
    │       ├── 2_postprocess.py
    │       ├── 3_postprocess_tree.py
    │       ├── combine.py
    │       ├── openai_req.py
    │       ├── prompt.txt
    │       ├── question_decompositions.json
    │       └── tree.json
    ├── musique
    │   ├── 0_generate_tree.sh
    │   ├── 1_conduct_reasoning.sh
    │   ├── RoHT
    │   │   ├── 1_build_tree.py
    │   │   ├── 2_run.py
    │   │   ├── 3_get_f1.py
    │   │   ├── aggregate
    │   │   │   └── prompt.txt
    │   │   ├── cb
    │   │   │   └── prompt.txt
    │   │   ├── count.py
    │   │   ├── evaluate.py
    │   │   ├── ob
    │   │   │   ├── get_para.py
    │   │   │   ├── multihop_prompt.txt
    │   │   │   └── singlehop_prompt.txt
    │   │   ├── openai_req.py
    │   │   ├── parallel.py
    │   │   ├── question_answering.py
    │   │   └── results
    │   │       └── released.json
    │   └── Tree_Generation
    │       ├── 0_get_prompt.py
    │       ├── 1_query.py
    │       ├── 2_postprocess.py
    │       ├── 3_postprocess_tree.py
    │       ├── combine.py
    │       ├── openai_req.py
    │       ├── prompt.txt
    │       ├── question_decompositions.json
    │       └── tree.json
    └── service
        ├── es
        │   ├── index_2wiki_wiki.py
        │   ├── index_hotpotqa_wiki.py
        │   ├── index_musique_wiki.py
        │   ├── run_2wiki_index.py
        │   ├── run_hotpotqa_index.py
        │   └── run_musique_indx.py
        └── openai
            └── openai_service.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Neo-Zhangjiajie
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ProbTree
2 | Source code for the Findings of EMNLP 2023 paper "Probabilistic Tree-of-thought Reasoning for Answering Knowledge-intensive Complex Questions".
3 |
4 | [](https://github.com/Neo-Zhangjiajie/ProbTree/issues)
5 | [](https://opensource.org/licenses/MIT)
6 | [](https://www.python.org/)
7 | [](https://arxiv.org/pdf/2311.13982.pdf)
8 |
9 | > In this paper, we propose a novel approach: Probabilistic Tree-of-thought Reasoning (ProbTree). First, LLMs translate a complex question into a query tree, in which each non-root node denotes a sub-question of its parent node. Then, probabilistic reasoning is conducted over the tree, by solving questions from leaf to root considering the confidence of both question decomposing and answering. During reasoning, for leaf nodes, LLMs choose a more confident answer from Closed-book QA that employs parametric knowledge and Open-book QA that employs retrieved external knowledge, thus eliminating the negative retrieval problem. For non-leaf nodes, with the hierarchical structure, LLMs have broader sights and are able to globally reason with the information from child nodes, thus recovering from local errors. The experiments on three Complex QA datasets under the open-domain setting show that our approach outperforms SOTA methods significantly, demonstrating the effect of probabilistic tree-of-thought reasoning.
10 |
11 |
12 | ![](figures/method.png)
13 |
14 |
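The probabilistic reasoning over the tree is implemented in `src/*/RoHT/question_answering.py`. As a rough, condensed sketch of the selection rule described above (the helper name `pick` is illustrative, not the repository's exact API):

```
# Condensed sketch of ProbTree's answer selection (see src/*/RoHT/question_answering.py
# for the full version). Each candidate answer is a tuple
# (answer, confidence, chain_of_thought), where confidence is the average token
# log-probability of the model's reasoning chain.

def pick(*candidates):
    """Keep the most confident candidate; failed or 'Unknown' answers get a floor score."""
    cleaned = []
    for ans, score, cot in candidates:
        if "ERROR" in ans or "Unknown" in ans:
            ans, score = "", -100
        cleaned.append((ans, score, cot))
    return max(cleaned, key=lambda x: x[1])

# Leaf node: choose between closed-book (CB) and open-book (OB) QA.
#     answer = pick(cb_answer, ob_answer)
# Non-leaf node: additionally consider the answer aggregated from the child nodes,
# whose confidence combines the decomposition log-prob and the children's scores:
#     child_score = (cot_logprob + qd_logprob + sum(child_scores)) / (len(child_scores) + 2)
#     answer = pick(cb_answer, ob_answer, (child_answer, child_score, child_cot))
```
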
15 | ## File Structure
16 | ```
17 | ProbTree/
18 | ├─ data/:
19 | │ ├─ 2wiki: original 2WikiMQA dataset
20 | │ ├─ musique: original MuSiQue dataset
21 | │ └── enwiki-20171001-pages-meta-current-withlinks-abstracts: Wikipedia dump for HotpotQA
22 | ├─ released_data/: released test samples by IRCoT
23 | ├─ src/:
24 | │  ├─ 2wiki: experiment code for 2WikiMQA
25 | │  │  ├─ RoHT: code for probabilistic reasoning
26 | │  │  ├─ Tree_Generation: code for generating question decomposition trees
27 | │  │  ├─ 0_generate_tree.sh: script for generating question decomposition trees
28 | │  │  └── 1_conduct_reasoning.sh: script for conducting probabilistic reasoning
29 | │  ├─ hotpotqa: experiment code for HotpotQA
30 | │  ├─ musique: experiment code for MuSiQue
31 | │ └── service:
32 | │ ├─ es: Elasticsearch services
33 | │ └── openai: OpenAI Service
34 | └── download_hotpotqa_wikipedia.sh: script for downloading the Wikipedia dump for HotpotQA
35 | ```
36 |
37 | ## Download Data
38 | 1. Download the original [2WikiMQA](https://github.com/Alab-NII/2wikimultihop) and [MuSiQue-ans](https://github.com/stonybrooknlp/musique) datasets, then put their train, dev and test sets under `./data/2wiki` and `./data/musique`, respectively.
39 |
40 | 2. Download the Wikipedia dump for [HotpotQA](https://hotpotqa.github.io/) and put it under `./data`:
41 | ```
42 | bash download_hotpotqa_wikipedia.sh
43 | ```
44 |
45 | 3. Download the [released test samples](https://drive.google.com/drive/folders/1UAlz8NIwTSR2CVXlWlKWh-oadjCAtJfA?usp=sharing) by [IRCoT](https://github.com/StonyBrookNLP/ircot/tree/main) and put them under `./released_data`.
46 |
47 | ## Prepare Services
48 |
49 | ### 1. Elasticsearch
50 | Install Elasticsearch following the [official documentation](https://www.elastic.co/guide/en/elasticsearch/reference/current/targz.html) to enable BM25 retrieval. We use version 8.1.2.
51 |
52 | Run Elasticsearch in a tmux window:
53 | ```
54 | cd elasticsearch-8.1.2/bin # replace this with your installation path
55 | ./elasticsearch
56 | ```
57 |
58 | Build the corpus indices and start the BM25 retrieval services for 2WikiMQA, MuSiQue, and HotpotQA, each in its own tmux window:
59 | ```
60 | cd src/service/es
61 |
62 | # 2WikiMQA
63 | python index_2wiki_wiki.py
64 | python run_2wiki_index.py
65 |
66 | # MuSiQue
67 | python index_musique_wiki.py
68 | python run_musique_index.py
69 |
70 | # HotpotQA
71 | python index_hotpotqa_wiki.py
72 | python run_hotpotqa_index.py
73 | ```
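
Each `run_*_index.py` script serves BM25 retrieval over HTTP. As a quick sanity check that a service is up, you can query it the same way the RoHT code does; `src/2wiki/RoHT/question_answering.py` expects port 1440, and the port used by the other datasets should be checked in the corresponding `run_*_index.py`:

```
import requests

# Query the local BM25 service started by run_2wiki_index.py.
# Port 1440 is what src/2wiki/RoHT/question_answering.py expects; check the
# corresponding run_*_index.py for the port used by the other datasets.
r = requests.get("http://127.0.0.1:1440",
                 json={"query": "Who is the director of film Hypocrite (Film)?", "k": 5})
for hit in r.json():
    print(hit["title"], "-", hit["paragraph_text"][:80])
```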
74 |
75 |
76 | ### 2. OpenAI Service
77 | Put your OpenAI keys in `src/service/openai/openai_service.py`, then run the OpenAI service in a tmux window:
78 | ```
79 | cd src/service/openai
80 | python openai_service.py
81 | ```
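
The experiment code never calls the OpenAI API directly; every request goes through this local service via `openai_req.py`. A quick way to check that it is running, mirroring the payload that `OpenaiReq.req2openai` sends (this assumes your key has access to `text-davinci-003`, the default model in `openai_req.py`):

```
import requests

# Same payload that src/*/openai_req.py posts to the local OpenAI service.
r = requests.post("http://127.0.0.1:10001/api/openai/completion", json={
    "model": "text-davinci-003",
    "prompt": "Q: Who wrote Hamlet?\nA:",
    "temperature": 0,
    "max_tokens": 32,
    "stop": "Q:",
    "logprobs": 1,
})
print(r.json()["choices"][0]["text"])
```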
82 |
83 | ### 3. Google Search API
84 | Put your SerpAPI key in `src/hotpotqa/RoHT/question_answering.py` so that you can use the Google Search API.
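
The key goes into the `serp_api_key` variable at the top of that file, which is exported as the `SERP_API_KEY` environment variable (presumably read by the `search` module), e.g.:

```
# Top of src/hotpotqa/RoHT/question_answering.py -- fill in your own key.
import os

serp_api_key = "<your SerpAPI key>"  # shipped as an empty string
os.environ["SERP_API_KEY"] = serp_api_key
```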
85 |
86 | ## Run Experiments
87 | First generate the question decomposition trees, then conduct probabilistic reasoning over them.
88 | ```
89 | cd src/{dataset_name} # 2wiki, musique, hotpotqa
90 | bash 0_generate_tree.sh
91 | bash 1_conduct_reasoning.sh
92 | ```
93 |
94 | We have released our generated question decomposition trees so that you can directly run the probabilistic reasoning.
95 | ```
96 | src/{dataset_name}/Tree_Generation/tree.json # 2wiki, musique, hotpotqa
97 | ```
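
Each entry of `tree.json` maps a (sub-)question to a pair `[sub_questions, decomposition_logprob]`; questions treated as atomic map to `[[], null]`. This is the format that `RoHT/1_build_tree.py` reads. A minimal reading example (the sample question mirrors the few-shot prompt in `Tree_Generation/prompt.txt` and may not appear verbatim in the released file):

```
import json

# tree.json maps each question to [sub_questions, decomposition_logprob];
# atomic questions map to [[], None], as read by RoHT/1_build_tree.py.
trees = json.load(open("src/2wiki/Tree_Generation/tree.json"))
q = "When did the director of film Hypocrite (Film) die?"
sub_questions, qd_logprob = trees.get(q, [[], None])
print(sub_questions)  # e.g. ["Who is the director of film Hypocrite (Film)?", "When did #1 die?"]
print(qd_logprob)     # average token log-probability of the generated decomposition
```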
98 |
99 | We also released our prediction results corresponding to the reported scores in
100 | ```
101 | src/{dataset_name}/RoHT/results/released.json # 2wiki, musique, hotpotqa
102 | ```
103 |
--------------------------------------------------------------------------------
/download_hotpotqa_wikipedia.sh:
--------------------------------------------------------------------------------
1 | # download the wiki dump file
2 | mkdir -p data
3 | wget https://nlp.stanford.edu/projects/hotpotqa/enwiki-20171001-pages-meta-current-withlinks-abstracts.tar.bz2 -O data/enwiki-20171001-pages-meta-current-withlinks-abstracts.tar.bz2
4 | # verify that we have the whole thing
5 | unameOut="$(uname -s)"
6 | case "${unameOut}" in
7 | Darwin*) MD5SUM="md5 -r";;
8 | *) MD5SUM=md5sum
9 | esac
10 | if [ `$MD5SUM data/enwiki-20171001-pages-meta-current-withlinks-abstracts.tar.bz2 | awk '{print $1}'` == "01edf64cd120ecc03a2745352779514c" ]; then
11 | echo "Downloaded the processed Wikipedia dump from the HotpotQA website. Everything's looking good, so let's extract it!"
12 | else
13 | echo "The md5 doesn't seem to match what we expected, try again?"
14 | exit 1
15 | fi
16 | cd data
17 | tar -xjvf enwiki-20171001-pages-meta-current-withlinks-abstracts.tar.bz2
18 | # clean up
19 | rm enwiki-20171001-pages-meta-current-withlinks-abstracts.tar.bz2
20 | echo 'Done!'
--------------------------------------------------------------------------------
/figures/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ProbTree/ce17f5f239c47389ab53920bfef817f8a18e3841/figures/method.png
--------------------------------------------------------------------------------
/src/2wiki/0_generate_tree.sh:
--------------------------------------------------------------------------------
1 | cd ./Tree_Generation
2 | python 0_get_prompt.py
3 | python 1_query.py
4 | python combine.py
5 | python 2_postprocess.py
6 | python 3_postprocess_tree.py
--------------------------------------------------------------------------------
/src/2wiki/1_conduct_reasoning.sh:
--------------------------------------------------------------------------------
1 | cd ./RoHT
2 | python 1_build_tree.py
3 | python 2_run.py
4 | python 3_get_f1.py
--------------------------------------------------------------------------------
/src/2wiki/RoHT/1_build_tree.py:
--------------------------------------------------------------------------------
1 | import json
2 | from collections import defaultdict
3 |
4 | raw_data = [json.loads(line.strip()) for line in open('../../../released_data/2wikimultihopqa__v2_test_random_500.jsonl')]
5 | q2sub_q = json.load(open("../Tree_Generation/tree.json"))
6 |
7 | trees = []
8 |
9 | def dfs(q, tree):
10 | sons = []
11 | print(q)
12 | for sub_q in q2sub_q.get(q, [[]])[0]:
13 | son_idx = dfs(sub_q, tree)
14 | sons.append(son_idx)
15 | idx = len(tree)
16 | tree.append({
17 | "idx": idx,
18 | "question_text": q,
19 | "sons": sons,
20 | "qd_logprob": q2sub_q.get(q, [[], None])[1]
21 | })
22 | for son_idx in sons:
23 | tree[son_idx]["fa"] = idx
24 | return idx
25 |
26 | for item in raw_data:
27 | question = item['question_text'].strip()
28 | assert question in q2sub_q
29 | tree = []
30 | dfs(question, tree)
31 | trees.append(tree)
32 |
33 | json.dump(trees, open("trees.json", "w"), indent=2)
34 |
35 |
36 |
37 |
38 |
--------------------------------------------------------------------------------
/src/2wiki/RoHT/2_run.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | import os
4 | from question_answering import *
5 | from tqdm import tqdm
6 | from parallel import parallel_process_data
7 |
8 | PROC_NUM = 50
9 | cnt = 0
10 |
11 | def solve(tree):
12 | global cnt
13 | cnt += 1
14 | print(cnt)
15 | #print(tree[-1])
16 | try:
17 | for node in tree:
18 | #print(node)
19 | question = node["question_text"].strip()
20 | ref_tokens = re.findall(r"#\d+", question)
21 | topic_entities = []
22 | for ref_token in ref_tokens:
23 | if "fa" in node and int(ref_token[1:]) <= len(tree[node["fa"]]["sons"]):
24 | ref_idx = tree[node["fa"]]["sons"][int(ref_token[1:])-1]
25 | if "answer" in tree[ref_idx]:
26 | question = question.replace(ref_token, tree[ref_idx]["answer"][0])
27 | topic_entities.append(tree[ref_idx]["answer"][0])
28 | node["question"] = question
29 | node["cb_answer"] = get_cb_answer(question)
30 | #print(node["cb_answer"])
31 | if len(node["sons"]) == 0:
32 | node["ob_answer"] = get_singlehop_ob_answer(question, topic_entities)
33 | #print(node["ob_answer"])
34 | node["answer"] = aggregate_singlehop_answer(node["cb_answer"], node["ob_answer"])
35 | else:
36 | node["ob_answer"] = get_multihop_ob_answer(node, tree)
37 | #print(node["ob_answer"])
38 | node["child_answer"], node["answer"] = aggregate_multihop_answer(node, tree)
39 | except Exception as e:
40 | print("ERROR CASE")
41 | print(tree[-1])
42 | raise e
43 |
44 |
45 | #trees = [x for x in json.load(open("trees.json", "r")) if x[-1]["question_text"] == "Who is the paternal grandfather of Princess Yasmin Aga Khan?"]
46 | trees = json.load(open("trees.json", "r"))
47 | print("Total: %d | Start Processing..."%len(trees))
48 | parallel_process_data(trees, solve, PROC_NUM)
49 |
50 |
51 | print("END")
52 | os.makedirs("results", exist_ok=True)
53 | json.dump(trees, open("results/test.json", "w"), indent=2)
--------------------------------------------------------------------------------
/src/2wiki/RoHT/3_get_f1.py:
--------------------------------------------------------------------------------
1 | import json
2 | from tqdm import tqdm
3 | from termcolor import colored
4 | from evaluate import update_answer
5 | import math
6 |
7 |
8 | q2a = {}
9 | id2type = {}
10 | for item in json.load(open("../../../data/2wiki/dev.json")):
11 | id2type[item["_id"]] = item["type"]
12 | raw_data = [json.loads(line.strip()) for line in open('../../../released_data/2wikimultihopqa__v2_test_random_500.jsonl')]
13 | q2gold = {}
14 | for item in raw_data:
15 | question = item['question_text'].strip()
16 | gold = item['answers_objects'][0]['spans'][0]
17 | q_type = id2type[item["question_id"]]
18 | q2gold[question] = (gold, q_type)
19 |
20 | trees = json.load(open("./results/test.json", "r"))
21 | metrics = {}
22 | for q_type in ["all", "compositional", "inference", "comparison", "bridge_comparison"]:
23 | metrics[q_type] = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0, 'N': 0}
24 |
25 | print(len(trees))
26 | for i, tree in enumerate(trees):
27 | node = tree[-1]
28 | question, answer = node["question"], node["answer"][0]
29 | q2a[question] = answer
30 | gold, q_type = q2gold[question]
31 | em, f1, prec, recall = update_answer(metrics["all"], answer, gold)
32 | update_answer(metrics[q_type], answer, gold)
33 |
34 | for q_type in ["all", "compositional", "inference", "comparison", "bridge_comparison"]:
35 | print(q_type)
36 | print(metrics[q_type]['N'])
37 |
38 | for k in metrics[q_type].keys():
39 | metrics[q_type][k] /= metrics[q_type]['N']
40 | print(metrics[q_type])
41 |
42 |
43 | json.dump(q2a, open("q2a.json", "w"), indent=2)
--------------------------------------------------------------------------------
/src/2wiki/RoHT/aggregate/prompt.txt:
--------------------------------------------------------------------------------
1 | Given a question and a context, answer the question and explain why.
2 |
3 | #
4 | Context:
5 | Who is the director of film Hypocrite? Miguel Morayta.
6 | When did Miguel Morayta die? 19 June 2013.
7 |
8 | Question:
9 | When did the director of film Hypocrite (Film) die?
10 |
11 | Answer:
12 | The film Hypocrite was directed by Miguel Morayta. Miguel Morayta died on 19 June 2013. So the answer is: 19 June 2013.
13 | #
14 | Context:
15 | When was the director of film Two Weeks With Pay born? November 28, 1919.
16 | When was the director of film Chhailla Babu born? 24 February 1939.
17 |
18 | Question:
19 | Which film has the director born first, Two Weeks With Pay or Chhailla Babu?
20 |
21 | Answer:
22 | The director of Two Weeks With Pay was born on November 28, 1919. The director of Chhailla Babu was born on 24 February 1939. Thus, Two Weeks With Pay has the director born first. So the answer is: Two Weeks With Pay.
23 | #
24 | Context:
25 | Where is Phu Luong located? Vietnam.
26 | Which country is Vietnam in? Southeast Asia.
27 |
28 | Question:
29 | Which country contains Phu Luong?
30 |
31 | Answer:
32 | Phu Luong is located in the country Vietnam. So the answer is: Vietnam.
33 | #
34 | Context:
35 | Who is the mother of Abraham Lincoln? Nancy Hanks Lincoln.
36 | Who is the father of Nancy Hanks Lincoln? James Hanks.
37 | Who is the father of James Hanks? Joseph Hanks.
38 |
39 | Question:
40 | Who is the maternal grandfather of Abraham Lincoln?
41 |
42 | Answer:
43 | The mother of Abraham Lincoln is Nancy Hanks Lincoln. The father of Nancy Hanks Lincoln is James Hanks. Thus, the maternal grandfather of Abraham Lincoln is James Hanks. So the answer is: James Hanks.
44 | #
--------------------------------------------------------------------------------
/src/2wiki/RoHT/cb/prompt.txt:
--------------------------------------------------------------------------------
1 | Please answer the question by thinking step-by-step.
2 | Q: When did the director of film Hypocrite (Film) die?
3 | A: The film Hypocrite was directed by Miguel Morayta. Miguel Morayta died on 19 June 2013. So the answer is: 19 June 2013.
4 | Q: Do director of film Coolie No. 1 (1995 Film) and director of film The Sensational Trial have the same nationality?
5 | A: Coolie No. 1 (1995 film) was directed by David Dhawan. The Sensational Trial was directed by Karl Freund. David Dhawan's nationality is India. Karl Freund's nationality is Germany. Thus, they do not have the same nationality. So the answer is: no.
6 | Q: Are both Kurram Garhi and Trojkrsti located in the same country?
7 | A: Kurram Garhi is located in the country of Pakistan. Trojkrsti is located in the country of Republic of Macedonia. Thus, they are not in the same country. So the answer is: no.
8 | Q: Who was born first out of Martin Hodge and Ivania Martinich?
9 | A: Martin Hodge was born on 4 February 1959. Ivania Martinich was born on 25 July 1995. Thus, Martin Hodge was born first. So the answer is: Martin Hodge.
10 | Q: Which film came out first, The Night Of Tricks or The Genealogy?
11 | A: The Night of Tricks was published in the year 1939. The Genealogy was published in the year 1979. Thus, The Night of Tricks came out first. So the answer is: The Night Of Tricks.
12 | Q: When did the director of film Laughter In Hell die?
13 | A: The film Laughter In Hell was directed by Edward L. Cahn. Edward L. Cahn died on August 25, 1963. So the answer is: August 25, 1963.
14 | Q: Which film has the director died later, The Gal Who Took the West or Twenty Plus Two?
15 | A: The film Twenty Plus Two was directed by Joseph M. Newman. The Gal Who Took the West was directed by Frederick de Cordova. Joseph M. Newman died on January 23, 2006. Fred de Cordova died on September 15, 2001. Thus, the person to die later from the two is Twenty Plus Two. So the answer is: Twenty Plus Two.
16 | Q: Who is Boraqchin (Wife Of Ögedei)'s father-in-law?
17 | A: Boraqchin is married to Ögedei Khan. Ögedei Khan's father is Genghis Khan. Thus, Boraqchin's father-in-law is Genghis Khan. So the answer is: Genghis Khan.
18 | Q: What is the cause of death of Grand Duke Alexei Alexandrovich Of Russia's mother?
19 | A: The mother of Grand Duke Alexei Alexandrovich of Russia is Maria Alexandrovna. Maria Alexandrovna died from tuberculosis. So the answer is: tuberculosis.
20 | Q: Which film has the director died earlier, When The Mad Aunts Arrive or The Miracle Worker (1962 Film)?
21 | A: When The Mad Aunts Arrive was directed by Franz Josef Gottlieb. The Miracle Worker (1962 film) was directed by Arthur Penn. Franz Josef Gottlieb died on 23 July 2006. Arthur Penn died on September 28, 2010. Thus, of the two, the director to die earlier is Franz Josef Gottlieb, who directed When The Mad Aunts Arrive. So the answer is: When The Mad Aunts Arrive.
22 | Q: Which album was released earlier, What'S Inside or Cassandra'S Dream (Album)?
23 | A: What's Inside was released in the year 1995. Cassandra's Dream (album) was released in the year 2008. Thus, of the two, the album to release earlier is What's Inside. So the answer is: What's Inside.
24 | Q: Are both mountains, Serre Mourene and Monte Galbiga, located in the same country?
25 | A: Serre Mourene is located in Spain. Monte Galbiga is located in Italy. Thus, the two countries are not located in the same country. So the answer is: no.
26 | Q: What is the date of birth of the director of film Best Friends (1982 Film)?
27 | A: The film Best Friends was directed by Norman Jewison. Norman Jewison was born on July 21, 1926. So the answer is: July 21, 1926.
28 | Q: Which film has the director born first, Two Weeks With Pay or Chhailla Babu?
29 | A: Two Weeks with Pay was directed by Maurice Campbell. Chhailla Babu was directed by Joy Mukherjee. Maurice Campbell was born on November 28, 1919. Joy Mukherjee was born on 24 February 1939. Thus, of the two directors, Maurice Campbell was born first, and he directed Two Weeks With Pay. So the answer is: Two Weeks With Pay.
30 | Q: Who is the grandchild of Krishna Shah (Nepalese Royal)?
31 | A: Krishna Shah has a child named Rudra Shah. Rudra Shah has a child named Prithvipati Shah. Thus, Krishna Shah has a grandchild named Prithvipati Shah. So the answer is: Prithvipati Shah.
32 | Q: When was the director of film P.S. Jerusalem born?
33 | A: P.S. Jerusalem was directed by Danae Elon. Danae Elon was born on December 23, 1970. So the answer is: December 23, 1970.
34 | Q: Which album was released more recently, If I Have to Stand Alone or Answering Machine Music?
35 | A: If I Have to Stand Alone was published in the year 1991. Answering Machine Music was released in the year 1999. Thus, of the two, the album to release more recently is Answering Machine Music. So the answer is: Answering Machine Music.
36 | Q: Where did the director of film Maddalena (1954 Film) die?
37 | A: The film Maddalena is directed by Augusto Genina. Augusto Genina died in Rome. So the answer is: Rome.
38 | Q: When did the director of film The Boy And The Fog die?
39 | A: The director of The Boy and the Fog is Roberto Gavaldón. Roberto Gavaldón died on September 4, 1986. So the answer is: September 4, 1986.
40 | Q: Are the directors of films The Sun of the Sleepless and Nevada (1927 film) both from the same country?
41 | A: The director of The Sun of the Sleepless is Temur Babluani. The director of Nevada (1927 film) is John Waters. John Waters is from the country of America. Temur Babluani is from the country of Georgia. Thus, John Waters and Temur Babluani are not from the same country. So the answer is: no.
42 | Q: Who is the director of film Hypocrite (Film)?
43 | A: The film Hypocrite was directed by Miguel Morayta. So the answer is: Miguel Morayta.
44 | Q: When did Franz Josef Gottlieb die?
45 | A: Franz Josef Gottlieb died on 23 July 2006. So the answer is: 23 July 2006.
46 | Q: When was the album What'S Inside released?
47 | A: What's Inside was released in the year 1995. So the answer is: 1995.
48 | Q: Which country was the mountain Serre Mourene located in?
49 | A: Serre Mourene is located in Spain. So the answer is: Spain.
--------------------------------------------------------------------------------
/src/2wiki/RoHT/count.py:
--------------------------------------------------------------------------------
1 | import json
2 | from collections import defaultdict
3 | trees = json.load(open("./results/test.json", "r"))
4 | cnt = defaultdict(int)
5 | total = 0
6 | for tree in trees:
7 | for node in tree:
8 | if "child_answer" in node:
9 | if node["answer"][1] == node["cb_answer"][1]:
10 | cnt["non_leaf_cb"] += 1
11 | elif node["answer"][1] == node["ob_answer"][1]:
12 | cnt["non_leaf_ob"] += 1
13 | else:
14 | cnt["non_leaf_ca"] += 1
15 | else:
16 | if node["answer"][1] == node["cb_answer"][1]:
17 | cnt["leaf_cb"] += 1
18 | else:
19 | cnt["leaf_ob"] += 1
20 | total += 1
21 |
22 | print(cnt)
23 | keys = ["leaf_ob", "leaf_cb"]
24 | print("leaf_cb: ", cnt["leaf_cb"], cnt["leaf_cb"] / (cnt["leaf_ob"] + cnt["leaf_cb"]))
25 | print("leaf_ob: ", cnt["leaf_ob"], cnt["leaf_ob"] / (cnt["leaf_ob"] + cnt["leaf_cb"]))
26 |
27 | print("non_leaf_cb:", cnt["non_leaf_cb"], cnt["non_leaf_cb"] / (cnt["non_leaf_ob"] + cnt["non_leaf_cb"] + cnt["non_leaf_ca"]))
28 | print("non_leaf_ob:", cnt["non_leaf_ob"], cnt["non_leaf_ob"] / (cnt["non_leaf_ob"] + cnt["non_leaf_cb"] + cnt["non_leaf_ca"]))
29 | print("non_leaf_ca:", cnt["non_leaf_ca"], cnt["non_leaf_ca"] / (cnt["non_leaf_ob"] + cnt["non_leaf_cb"] + cnt["non_leaf_ca"]))
30 |
--------------------------------------------------------------------------------
/src/2wiki/RoHT/evaluate.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import ujson as json
3 | import re
4 | import string
5 | from collections import Counter
6 | import pickle
7 |
8 | def normalize_answer(s):
9 |
10 | def remove_articles(text):
11 | return re.sub(r'\b(a|an|the)\b', ' ', text)
12 |
13 | def white_space_fix(text):
14 | return ' '.join(text.split())
15 |
16 | def remove_punc(text):
17 | exclude = set(string.punctuation)
18 | return ''.join(ch for ch in text if ch not in exclude)
19 |
20 | def lower(text):
21 | return text.lower()
22 |
23 | return white_space_fix(remove_articles(remove_punc(lower(s))))
24 |
25 |
26 | def f1_score(prediction, ground_truth):
27 | normalized_prediction = normalize_answer(prediction)
28 | normalized_ground_truth = normalize_answer(ground_truth)
29 |
30 | ZERO_METRIC = (0, 0, 0)
31 |
32 | if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
33 | return ZERO_METRIC
34 | if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
35 | return ZERO_METRIC
36 |
37 | prediction_tokens = normalized_prediction.split()
38 | ground_truth_tokens = normalized_ground_truth.split()
39 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
40 | num_same = sum(common.values())
41 | if num_same == 0:
42 | return ZERO_METRIC
43 | precision = 1.0 * num_same / len(prediction_tokens)
44 | recall = 1.0 * num_same / len(ground_truth_tokens)
45 | f1 = (2 * precision * recall) / (precision + recall)
46 | return f1, precision, recall
47 |
48 |
49 | def exact_match_score(prediction, ground_truth):
50 | return (normalize_answer(prediction) == normalize_answer(ground_truth))
51 |
52 | def update_answer(metrics, prediction, gold):
53 | em = exact_match_score(prediction, gold)
54 | f1, prec, recall = f1_score(prediction, gold)
55 | metrics['em'] += float(em)
56 | metrics['f1'] += f1
57 | metrics['prec'] += prec
58 | metrics['recall'] += recall
59 | metrics['N'] += 1
60 | return em, f1, prec, recall
61 |
62 | def eval():
63 |     gold = [json.loads(line.strip()) for line in open('../../../released_data/2wikimultihopqa__v2_test_random_500.jsonl')]
64 | metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0, 'N': 0}
65 | for dp in gold:
66 | print(dp)
67 | print(dp['answers_objects'][0])
68 | answer = dp['answers_objects'][0]['spans'][0]
69 |
70 | # cur_id = dp['_id']
71 | # if cur_id not in prediction['answer']:
72 | # print('missing answer {}'.format(cur_id))
73 | # else:
74 | em, f1, prec, recall = update_answer(
75 | metrics, answer, answer)
76 |
77 | N = len(gold)
78 | for k in metrics.keys():
79 | metrics[k] /= N
80 |
81 | print(metrics)
82 |
83 | if __name__ == '__main__':
84 | eval()
--------------------------------------------------------------------------------
/src/2wiki/RoHT/ob/multihop_prompt.txt:
--------------------------------------------------------------------------------
1 | Given a question and the relevant Wikipedia text, answer the question and explain why. If you are unsure, answer Unknown.
2 |
3 | #1 Wikipedia Title: Hypocrite (film)
4 | Text: Hypocrite (Spanish: Hipócrita..!) is a 1949 Mexican thriller film directed by Miguel Morayta and starring Antonio Badú, Leticia Palma, Carmen Molina and Luis Beristáin. The film included the song "Hipócrita". The film's sets were designed by Francisco Marco Chillet.
5 | #2 Wikipedia Title: When the Legends Die
6 | Text: When The Legends Die is a 1963 novel, by Hal Borland, and a DeLuxe Color film released in 1972 by Twentieth Century-Fox.
7 | #3 Wikipedia Title: Who Is the Guilty?
8 | Text: Who is the Guilty? ( sometimes" Who is to Blame?") is a 1925 Georgian silent film directed by Alexandre Tsutsunava
9 | #4 Wikipedia Title: Miguel Morayta
10 | Text: Miguel Morayta( 15 August 1907 – 19 June 2013) was a Spanish film director and screenwriter. He directed 74 films between 1944 and 1978. At the outbreak of the Spanish Civil War, Morayta was a Spanish artillery officer, who joined the Republican side. After Francisco Franco's victory, he left Spain for France and Africa, finally arriving in Mexico in 1941, where he started his career. He was living in Mexico when he died aged 105.
11 | #5 Wikipedia Title: Joselito vagabundo
12 | Text: Joselito vagabundo(" Joselito Vagabond") is a 1966 Mexican film. It stars Sara García and is directed by Miguel Morayta.
13 | Q: When did the director of film Hypocrite (Film) die?
14 | A: The film Hypocrite was directed by Miguel Morayta. Miguel Morayta died on 19 June 2013. So the answer is: 19 June 2013.
15 |
16 | #1 Wikipedia Title: Kurram Garhi
17 | Text: Kurram Garhi is a small village located near the city of Bannu, which is the part of Khyber Pakhtunkhwa province of Pakistan. Its population is approximately 35000. Barren hills are near this village. This village is on the border of Kurram Agency. Other nearby villages are Peppal, Surwangi and Amandi Kala.
18 | #2 Wikipedia Title: Kurram Garhi Hydropower Plant
19 | Text: Kurram Garhi Hydropower Plant( KGHPP) is a small, low- head, run- of- the- river hydroelectric power generation station of 4.0 megawatt generation capacity( four units of 1.0 MW each), located at Kurram Garhi, a small town in Bannu KPK province of Pakistan on the flows of Kuchkot Canal from Kurram River. It is a small hydel power generating plant constructed and put in commercial operation on February 1958 with the Average Annual generating capacity of 17 million units( GWh) of least expensive electricity.
20 | #3 Wikipedia Title: Trojkrsti
21 | Text: Trojkrsti is a village in Municipality of Prilep, Republic of Macedonia.
22 | #4 Wikipedia Title: All Men Are the Same
23 | Text: All Men Are the Same is a 1994 Spanish comedy film directed by Manuel Gómez Pereira.
24 | #5 Wikipedia Title: The Both
25 | Text: The Both is an American musical duo consisting of Aimee Mann and Ted Leo, both of whom had longstanding musical careers before beginning a collaboration in 2013. Their first album, self- titled" The Both", was released in April 2014.
26 | Q: Are both Kurram Garhi and Trojkrsti located in the same country?
27 | A: Kurram Garhi is located in the country of Pakistan. Trojkrsti is located in the country of Republic of Macedonia. Thus, they are not in the same country. So the answer is: no.
28 |
29 | #1 Wikipedia Title: Krishna Shah (Nepalese royal)
30 | Text: Krishna Shah (?–1661) was the king of the Gorkha Kingdom in the Indian subcontinent, present-day Nepal. He was the father of Rudra Shah.
31 | #2 Wikipedia Title: Neer Shah
32 | Text: Neer Bikram Shah, also known as Nir Shah, is a Nepalese movie actor, a poet, lyricist, movie director, and businessman. He is related to the Royal family of Nepal.
33 | #3 Wikipedia Title: Gajraj Mishra
34 | Text: Rajguru Gajraj Mishra also spelled Gajaraj Mishra was a Nepalese politician, ambassador, diplomat and a royal priest of Shah dynasty. He was always inclined to his disciple Prince Regent Bahadur Shah of Nepal. Gajraj Mishra was disfavoured by his disciple King Pratap Singh Shah due to his support to Prince Bahadur Shah. He was also disfavoured by Pratap Singh's son Rana Bahadur Shah.
35 | #4 Wikipedia Title: Princess Helen Shah of Nepal
36 | Text: Princess Helen Shah of Nepal( September 21, 1926 – September 12, 2008) was a member of the former Nepalese royal family. She was the wife of Prince Basundhara of Nepal, a son of King Tribhuvan of Nepal and his second wife, Queen Ishwari.
37 | #5 Wikipedia Title: Rudra Shah
38 | Text: Rudra Shah (?–1673) was the king of the Gorkha Kingdom in the Indian subcontinent, present-day Nepal. He was the father of Prithvipati Shah.
39 | Q: Who is the grandchild of Krishna Shah (Nepalese Royal)?
40 | A: Krishna Shah has a child named Rudra Shah. Rudra Shah has a child named Prithvipati Shah. Thus, Krishna Shah has a grandchild named Prithvipati Shah. So the answer is: Prithvipati Shah.
--------------------------------------------------------------------------------
/src/2wiki/RoHT/ob/singlehop_prompt.txt:
--------------------------------------------------------------------------------
1 | Given a question and the relevant Wikipedia text, answer the question and explain why. If you are unsure, answer Unknown.
2 |
3 | #1 Wikipedia Title: Hypocrite (film)
4 | Text: Hypocrite (Spanish: Hipócrita..!) is a 1949 Mexican thriller film directed by Miguel Morayta and starring Antonio Badú, Leticia Palma, Carmen Molina and Luis Beristáin. The film included the song "Hipócrita". The film's sets were designed by Francisco Marco Chillet.
5 | #2 Wikipedia Title: Who? (film)
6 | Text: Who? is a 1974 film based on the 1958 novel of the same name by Algis Budrys. It was directed by Jack Gold and stars Elliott Gould, Trevor Howard, and Joseph Bova. Some video releases were retitled "The Man in the Steel Mask" or "Roboman".
7 | #3 Wikipedia Title: Who Is the Guilty?
8 | Text: Who is the Guilty? ( sometimes" Who is to Blame?") is a 1925 Georgian silent film directed by Alexandre Tsutsunava
9 | #4 Wikipedia Title: Who Is the Man?
10 | Text: Who Is The Man?( 1924) is a British silent film drama directed by Walter Summers. The film was based on the successful French play" Daniel" by Louis Verneuil and is notable as the first screen appearance of John Gielgud.
11 | #5 Wikipedia Title: Deceit (1923 film)
12 | Text: Deceit( sometimes referred to as The Deceit) is a 1923 American silent black- and- white film. It is a conventional melodrama directed by Oscar Micheaux. Like many of Micheaux's films," Deceit" casts clerics in a negative light. Although the film was shot in 1921, it was not released until 1923. It is not known whether the film currently survives, which suggests that it is a lost film. The 1922 film" The Hypocrite" was shown within" Deceit" as a film within a film.
13 | Q: Who is the director of film Hypocrite (Film)?
14 | A: The film Hypocrite is directed by Miguel Morayta. So the answer is: Miguel Morayta.
15 |
16 | #1 Wikipedia Title: Kurram Garhi
17 | Text: Kurram Garhi is a small village located near the city of Bannu, which is the part of Khyber Pakhtunkhwa province of Pakistan. Its population is approximately 35000. Barren hills are near this village. This village is on the border of Kurram Agency. Other nearby villages are Peppal, Surwangi and Amandi Kala.
18 | #2 Wikipedia Title: Kurram Garhi Hydropower Plant
19 | Text: Kurram Garhi Hydropower Plant( KGHPP) is a small, low- head, run- of- the- river hydroelectric power generation station of 4.0 megawatt generation capacity( four units of 1.0 MW each), located at Kurram Garhi, a small town in Bannu KPK province of Pakistan on the flows of Kuchkot Canal from Kurram River. It is a small hydel power generating plant constructed and put in commercial operation on February 1958 with the Average Annual generating capacity of 17 million units( GWh) of least expensive electricity.
20 | #3 Wikipedia Title: Country Is
21 | Text: " Country Is" is a song written and recorded by American country music artist Tom T. Hall. It was released in September 1974 as the second and final single from the album of the same name," Country Is". The song was Hall's fifth number one on the country chart. The single went to number one for a single week and spent a total of eleven weeks on the country chart.
22 | #4 Wikipedia Title: Which Way Is Up?
23 | Text: Which Way is Up? is a 1977 American comedy film starring Richard Pryor and directed by Michael Schultz. It is a remake of the 1972 Italian comedy film" The Seduction of Mimi". Richard Pryor plays three roles: an orange picker who has two women at the same time, the orange picker's father, and a reverend who gets the orange picker's wife pregnant.
24 | #5 Wikipedia Title: In Country
25 | Text: In Country is a 1989 American drama film produced and directed by Norman Jewison, starring Bruce Willis and Emily Lloyd. The screenplay by Frank Pierson and Cynthia Cidre was based on the novel by Bobbie Ann Mason. The original music score was composed by James Horner. Willis earned a best supporting actor Golden Globe nomination for his role.
26 | Q: Which country is Kurram Garhi located in?
27 | A: Kurram Garhi is located in the country of Pakistan. So the answer is: Pakistan.
28 |
29 | #1 Wikipedia Title: Neer Shah
30 | Text: Neer Bikram Shah, also known as Nir Shah, is a Nepalese movie actor, a poet, lyricist, movie director, and businessman. He is related to the Royal family of Nepal.
31 | #2 Wikipedia Title: Gajraj Mishra
32 | Text: Rajguru Gajraj Mishra also spelled Gajaraj Mishra was a Nepalese politician, ambassador, diplomat and a royal priest of Shah dynasty. He was always inclined to his disciple Prince Regent Bahadur Shah of Nepal. Gajraj Mishra was disfavoured by his disciple King Pratap Singh Shah due to his support to Prince Bahadur Shah. He was also disfavoured by Pratap Singh's son Rana Bahadur Shah.
33 | #3 Wikipedia Title: Princess Helen Shah of Nepal
34 | Text: Princess Helen Shah of Nepal( September 21, 1926 – September 12, 2008) was a member of the former Nepalese royal family. She was the wife of Prince Basundhara of Nepal, a son of King Tribhuvan of Nepal and his second wife, Queen Ishwari.
35 | #4 Wikipedia Title: Krishna Shah (Nepalese royal)
36 | Text: Krishna Shah (?–1661) was the king of the Gorkha Kingdom in the Indian subcontinent, present-day Nepal. He was the father of Rudra Shah.
37 | #5 Wikipedia Title: Ajaya Pratap Shah
38 | Text: Ajay Pratap Shah( died September 12,? in Lucknow, India) was a Nepalese politician, belonging to the Rastriya Prajatantra Party. In 1999 parliamentary election he was elected from the Kapilvastu- 4 constituency, with 14091 votes. After the royal coup d'état in February 2005, Shah went into exile in India. After his death, RPP nominated his son, Abhisek Pratap Shah, to take his parliamentary seat in January 2008.
39 | Q: Who is the child of Krishna Shah (Nepalese Royal)?
40 | A: Krishna Shah was the father of Rudra Shah. So the answer is: Rudra Shah.
--------------------------------------------------------------------------------
/src/2wiki/RoHT/openai_req.py:
--------------------------------------------------------------------------------
1 | import openai
2 | import requests
3 | import time
4 | import os
5 | import json, jsonlines
6 |
7 | class OpenaiReq():
8 | def __init__(self):
9 | self.url = "http://127.0.0.1:10001/api/openai/completion"
10 | self.cache = {}
11 | self.cache_path = "./cache.jsonl"
12 | if os.path.exists(self.cache_path):
13 | with open(self.cache_path, "r") as f:
14 | for i, line in enumerate(f):
15 | #print(i+1)
16 | datum = json.loads(line.strip())
17 | self.cache[tuple(datum["input"])] = datum["response"]
18 | f.close()
19 |
20 | def req2openai(self, prompt, model="text-davinci-003", temperature=0, max_tokens=128, stop=None, logprobs=1, use_cache=True):
21 | assert isinstance(prompt, str)
22 | input = (prompt, model, max_tokens, stop, logprobs)
23 | if use_cache and temperature == 0 and input in self.cache:
24 | return self.cache[input], True
25 | for i in range(3):
26 | try:
27 | response = requests.post(self.url, json = {
28 | "model": model,
29 | "prompt": prompt,
30 | "temperature": temperature,
31 | "max_tokens": max_tokens,
32 | "stop": stop,
33 | "logprobs": logprobs,
34 | })
35 | if response.status_code != 200:
36 | raise Exception(response.text)
37 | break
38 | except Exception as e:
39 | err_msg = str(e)
40 | print(e)
41 |                 if "reduce your prompt" in err_msg:  # the input prompt was too long
42 | return ['too long'], False
43 | try:
44 | response = response.json()['choices']
45 | except:
46 | return ['openai error'], False
47 | if temperature == 0:
48 | input = (prompt, model, max_tokens, stop, logprobs)
49 | res = response[0]
50 | if input not in self.cache:
51 | self.cache[input] = [res]
52 | with open(self.cache_path, "a") as f:
53 | f.write("%s\n"%json.dumps({"input": input, "response": [res]}))
54 | f.close()
55 | return response, True
56 |
57 | if __name__ == "__main__":
58 | caller = OpenaiReq()
59 | res = caller.req2openai("你好", use_cache=True)
60 | print(res)
61 |
62 |
--------------------------------------------------------------------------------
/src/2wiki/RoHT/parallel.py:
--------------------------------------------------------------------------------
1 | import concurrent.futures, random, time
2 |
3 | def handle_item(data):
4 | waiting = random.random() * 3 + 1
5 | print("Thread %d, Waiting %.2f ..."%(data, waiting))
6 | time.sleep(waiting)
7 | if random.random() < 0.5:
8 | raise Exception()
9 | print("Thread %d, OK."%(data))
10 |
11 | def parallel_process_data(data, handle_item, workers=20, callback=None):
12 | with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
13 | futures = []
14 | for item in data:
15 | future = executor.submit(handle_item, item)
16 | futures.append(future)
17 | for future in concurrent.futures.as_completed(futures):
18 | result = future.result()
19 | if callback:
20 | callback(result)
21 |
22 | if __name__ == "__main__":
23 | parallel_process_data([i for i in range(20)], handle_item)
24 | print("end")
--------------------------------------------------------------------------------
/src/2wiki/RoHT/question_answering.py:
--------------------------------------------------------------------------------
1 | from openai_req import OpenaiReq
2 | import requests
3 | import os
4 |
5 | serp_api_key = ""#@param {type:"string"}
6 | os.environ["SERP_API_KEY"] = serp_api_key
7 |
8 | openai_caller = OpenaiReq()
9 |
10 | def bm25_search(question, k):
11 | web = "http://127.0.0.1:1440"
12 | data = {
13 | "query": question,
14 | "k": k
15 | }
16 | for i in range(3):
17 | try:
18 | r = requests.get(web, json=data)
19 | if r.status_code != 200:
20 | raise Exception(r.text)
21 | return r.json()
22 | except Exception as e:
23 | print(e)
24 |
25 | def postprocess(response):
26 | response = response[0]
27 | if response == 'too long' or response['finish_reason'] != 'stop':
28 | return 'ERROR: prompt too long', -100, ""
29 | tokens = response['logprobs']['tokens']
30 | token_logprobs = response['logprobs']['token_logprobs']
31 | cot = response['text'].strip()
32 | if len(token_logprobs) == 0:
33 | return 'ERROR: empty output', -100, cot
34 | pos = 0
35 | for idx, token in enumerate(tokens):
36 |         if token.strip() == 'So' and idx + 4 < len(tokens) and tokens[idx + 1].strip() == 'the' and tokens[idx + 2].strip() == 'answer' and tokens[idx + 3].strip() == 'is' and tokens[idx + 4].strip() == ':':
37 | pos = idx
38 | break
39 | if tokens[-1] == '.':
40 | answer_logprobs = token_logprobs[pos+5:-1]
41 | answer = cot.split('So the answer is: ')[-1][:-1]
42 | else:
43 | answer_logprobs = token_logprobs[pos+5:]
44 | answer = cot.split('So the answer is: ')[-1]
45 | cot_process = cot.split('So the answer is: ')[0].strip()
46 | cot_process_logprobs = token_logprobs[:pos]
47 | if len(cot_process_logprobs) == 0:
48 | cot_process_logprob = -100
49 | else:
50 | cot_process_logprob = sum(cot_process_logprobs) / len(cot_process_logprobs)
51 | return answer, cot_process_logprob, cot
52 |
53 | def get_cb_answer(question):
54 | instruction = '\n'.join([_.strip() for _ in open('cb/prompt.txt').readlines()])
55 | prompt = instruction + '\nQ: ' + question + '\nA:'
56 | response, tag = openai_caller.req2openai(prompt=prompt, max_tokens=256, stop='Q:', use_cache=True)
57 | return postprocess(response)
58 |
59 | def get_singlehop_ob_answer(question, topic_entities):
60 | instruction = '\n'.join([_.strip() for _ in open('ob/singlehop_prompt.txt').readlines()])
61 | k = 5
62 | contexts = []
63 | hist = set()
64 | r = bm25_search(question, k)
65 | for datum in r:
66 | title, text = datum["title"], datum["paragraph_text"]
67 | stamp = title + text
68 | if not stamp in hist:
69 | hist.add(stamp)
70 | contexts.append([title, text])
71 |
72 | prompt = instruction + '\n'
73 | for idx, (title, text) in enumerate(contexts):
74 | prompt += '\n#' + str(idx + 1) + ' Wikipedia Title: ' + title + '\nText: ' + text
75 | prompt += '\nQ: ' + question + '\nA:'
76 | response, tag = openai_caller.req2openai(prompt=prompt, max_tokens=256, stop='\n\n', use_cache=True)
77 | return postprocess(response)
78 |
79 | def aggregate_singlehop_answer(cb_answer, ob_answer):
80 | cb_ans, cb_score, cb_cot = cb_answer
81 | ob_ans, ob_score, ob_cot = ob_answer
82 | if "ERROR" in cb_ans or 'Unknown' in cb_ans:
83 | cb_ans, cb_score = "", -100
84 | if "ERROR" in ob_ans or 'Unknown' in ob_ans:
85 | ob_ans, ob_score = "", -100
86 | return max([(cb_ans, cb_score, cb_cot), (ob_ans, ob_score, ob_cot)], key=lambda x:x[1])
87 |
88 | def get_multihop_ob_answer(node, tree):
89 |
90 | def is_descendant(a, b):
91 | while "fa" in tree[a]:
92 | a = tree[a]["fa"]
93 | if a == b:
94 | return True
95 | return False
96 |
97 | question = node["question"]
98 | instruction = '\n'.join([_.strip() for _ in open('ob/multihop_prompt.txt').readlines()])
99 | k = 5
100 | contexts = []
101 | hist = set()
102 | r = bm25_search(question, k)
103 | for datum in r:
104 | title, text = datum["title"], datum["paragraph_text"]
105 | stamp = title + text
106 | if stamp not in hist:
107 | hist.add(stamp)
108 | contexts.append([title, text])
109 |
110 | for idx in range(node["idx"]):
111 | if is_descendant(idx, node["idx"]):
112 | sub_question = tree[idx]["question"]
113 | r = bm25_search(sub_question, 3)
114 | for datum in r:
115 | title, text = datum["title"], datum["paragraph_text"]
116 | stamp = title + text
117 | if stamp not in hist:
118 | hist.add(stamp)
119 | contexts.append([title, text])
120 |
121 | prompt = instruction + '\n'
122 | for idx, (title, text) in enumerate(contexts):
123 | prompt += '\n#' + str(idx + 1) + ' Wikipedia Title: ' + title + '\nText: ' + text
124 | prompt += '\nQ: ' + question + '\nA:'
125 | response, tag = openai_caller.req2openai(prompt=prompt, max_tokens=256, stop='\n\n', use_cache=True)
126 | return postprocess(response)
127 |
128 | def calculate_score1(cot_process_logprob, qd_score, sub_answer_scores):
129 | return cot_process_logprob + qd_score + sum(sub_answer_scores)
130 |
131 | def calculate_score2(cot_process_logprob, qd_score, sub_answer_scores):
132 | return (cot_process_logprob + qd_score + sum(sub_answer_scores)) / (len(sub_answer_scores) + 2)
133 |
134 | def aggregate_multihop_answer(node, tree):
135 | instruction = '\n'.join([_.strip() for _ in open('aggregate/prompt.txt').readlines()])
136 | question = node["question"]
137 | qd_score = node["qd_logprob"]
138 | context = ''
139 | sub_answer_scores = []
140 | for son_idx in node["sons"]:
141 | sub_question = tree[son_idx]["question"]
142 | sub_answer = tree[son_idx]["answer"][0]
143 | sub_answer_scores.append(tree[son_idx]["answer"][1])
144 | context += '\n' + sub_question + ' ' + sub_answer
145 | prompt = instruction + '\nContext:\n{}\n\nQuestion:\n{}\n\nAnswer:'.format(context, question)
146 | response, tag = openai_caller.req2openai(prompt=prompt, max_tokens=256, stop='\n\n\n', use_cache=True)
147 | child_answer, cot_process_logprob, child_cot = postprocess(response)
148 |
149 | child_ans = child_answer
150 | child_score = calculate_score2(cot_process_logprob, qd_score, sub_answer_scores)
151 | res1 = (child_ans, child_score, child_cot)
152 | cb_ans, cb_score, cb_cot = node["cb_answer"]
153 | ob_ans, ob_score, ob_cot = node["ob_answer"]
154 | if "ERROR" in cb_ans or 'Unknown' in cb_ans:
155 | cb_ans, cb_score = "", -100
156 | if "ERROR" in ob_ans or 'Unknown' in ob_ans:
157 | ob_ans, ob_score = "", -100
158 |     if "ERROR" in child_ans or "Unknown" in child_ans:
159 | child_ans, child_score = "", -100
160 | res2 = max([(cb_ans, cb_score, cb_cot), (ob_ans, ob_score, ob_cot), (child_ans, child_score, child_cot)], key=lambda x:x[1])
161 | return res1, res2
162 |
163 |
164 |
165 | if __name__ == "__main__":
166 | question = "Who is Joan Of Savoy's father?"
167 | r = bm25_search(question, k=5)
168 | for x in r:
169 | print(x["title"])
170 | print(x["paragraph_text"])
171 | print()
172 |
173 |
174 |
175 |
176 |
--------------------------------------------------------------------------------
/src/2wiki/Tree_Generation/0_get_prompt.py:
--------------------------------------------------------------------------------
1 | import json, jsonlines
2 |
3 | instruction = '\n'.join([_.strip() for _ in open('prompt.txt').readlines()])
4 |
5 | raw_data = jsonlines.open("../../../released_data/2wikimultihopqa__v2_test_random_500.jsonl", "r")
6 |
7 | prompts = []
8 | for item in raw_data:
9 | question = item["question_text"]
10 | prompt = instruction + '\nQ: ' + question + '\nA: '
11 | prompts.append(prompt)
12 |
13 | json.dump(prompts, open('prompts.json', 'w'), indent = 2)
14 | print(len(prompts))
--------------------------------------------------------------------------------
/src/2wiki/Tree_Generation/1_query.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | from openai_req import OpenaiReq
4 | import random
5 | from tqdm import tqdm
6 | import os
7 | from multiprocessing import Pool
8 | from termcolor import colored
9 | random.seed(42)
10 |
11 | MAX_SPLIT = 64
12 | STEP = 4
13 |
14 | def query(rank, prompts):
15 | print('Process rank {} PID {} begin...'.format(rank, os.getpid()))
16 | reqor = OpenaiReq()
17 | queries = prompts[int(len(prompts) * rank / MAX_SPLIT) : int(len(prompts) * (rank + 1) / MAX_SPLIT)]
18 | try:
19 | fout = open('outputs/rank_{}.json'.format(rank), 'w')
20 | if rank == 0:
21 | bar = tqdm(range(len(queries) // STEP + 1))
22 | else:
23 | bar = range(len(queries) // STEP + 1)
24 | for idx in bar:
25 | inputs = queries[idx * STEP : (idx + 1) * STEP]
26 | if len(inputs) == 0:
27 | break
28 | gpt_results = []
29 | for prompt in inputs:
30 | result, tag = reqor.req2openai(prompt, max_tokens = 512, stop = '\n\n')
31 | gpt_results.append(result[0])
32 | for prompt, res in zip(inputs, gpt_results):
33 | # print(res)
34 | fout.write(json.dumps({'prompt': prompt, 'response': res}) + '\n')
35 | fout.flush()
36 | fout.close()
37 | except Exception as err:
38 | print(Exception, err)
39 |
40 | if __name__=='__main__':
41 | prompts = json.load(open('prompts.json'))
42 | os.makedirs("outputs", exist_ok=False)
43 | print("number of prompts: {}".format(len(prompts)))
44 | print('Parent process %s.' % os.getpid())
45 | p = Pool(MAX_SPLIT)
46 | for i in range(MAX_SPLIT):
47 | p.apply_async(query, args=(i, prompts))
48 | print('Waiting for all subprocesses done...')
49 | p.close()
50 | p.join()
51 | print('All subprocesses done.')
--------------------------------------------------------------------------------
/src/2wiki/Tree_Generation/2_postprocess.py:
--------------------------------------------------------------------------------
1 | import json
2 | from tqdm import tqdm
3 | from termcolor import colored
4 | import os
5 |
6 | raw_data = json.load(open('outputs/predictions.json'))
7 |
8 | data = {}
9 | for item in tqdm(raw_data):
10 | prompt = item['prompt']
11 | question = prompt.split('\n')[-2][len('Q: '):].strip()
12 | print(colored(question, 'red'))
13 | try:
14 | qds = item['response']['text'].strip()
15 | if qds.endswith('.'):
16 | qds = qds[:-1]
17 | hqdt = json.loads(qds)
18 | except:
19 | hqdt = None
20 |     # skip items whose decomposition could not be parsed as JSON
21 |     if hqdt is None:
22 |         data[question] = None
23 |         continue
24 | tokens = item['response']['logprobs']['tokens']
25 | token_logprobs = item['response']['logprobs']['token_logprobs']
26 | if len(token_logprobs) == 0:
27 | continue
28 |
29 | if tokens[-1] == '.':
30 | token_logprobs = token_logprobs[:-1]
31 |
32 | st, ed = 0, 0
33 | pos = 0
34 | qds = {}
35 | for sub_question, qd in hqdt.items():
36 | while pos < len(tokens):
37 | if "[" in tokens[pos] and ": [\"" in "".join(tokens[max(pos-1, 0): min(pos+2, len(tokens))]):
38 | st = pos
39 | break
40 | pos += 1
41 | while pos < len(tokens):
42 | if "]" in tokens[pos] and "\"]" in "".join(tokens[max(pos-1, 0): min(pos+2, len(tokens))]):
43 | ed = pos
44 | break
45 | pos += 1
46 | assert pos < len(tokens), sub_question
47 | qd_score = sum(token_logprobs[st:ed+1]) / len(token_logprobs[st:ed+1])
48 | qds[sub_question] = (qd, qd_score)
49 | print(colored(sub_question, 'blue'))
50 | print("".join(tokens[st:ed+1]))
51 |
52 |
53 | data[question] = qds
54 | json.dump(data, open('question_decompositions.json', 'w'), indent = 2)
--------------------------------------------------------------------------------
/src/2wiki/Tree_Generation/3_postprocess_tree.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | raw_data = json.load(open('question_decompositions.json'))
4 |
5 | def check(question):
6 | if '#1' in question or '#2' in question or '#3' in question or '#4' in question:
7 | return True
8 | tree = {}
9 | for father in raw_data:
10 | if check(father):
11 | continue
12 | qds = raw_data[father]
13 | if qds is None:
14 | continue
15 | tree[father] = {}
16 | for question in qds:
17 | if check(question):
18 | continue
19 | if any([x == question for x in qds[question][0]]):
20 | tree[father][question] = [[], None]
21 | else:
22 | tree[father][question] = qds[question]
23 |
24 | question_decompositions = {}
25 | for father in tree:
26 | qds = tree[father]
27 | for q in qds:
28 | if q not in question_decompositions:
29 | question_decompositions[q] = qds[q]
30 | else:
31 | if question_decompositions[q] != qds[q]:
32 | print(question_decompositions[q])
33 | print(qds[q])
34 | else:
35 | print('haha')
36 |
37 | json.dump(question_decompositions, open('tree.json', 'w'), indent = 2)
38 |
39 | print(len(tree))
--------------------------------------------------------------------------------
/src/2wiki/Tree_Generation/combine.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | def findAllFile(base):
5 | for root, ds, fs in os.walk(base):
6 | for f in fs:
7 | yield f
8 | base = './outputs'
9 | data = []
10 | for file_name in findAllFile(base):
11 | data += [json.loads(line.strip()) for line in open(os.path.join(base, file_name))]
12 | # data.update(json.load(open(os.path.join(base, file_name))))
13 | print(len(data))
14 | json.dump(data, open(os.path.join(base, 'predictions.json'), 'w'), indent = 2)
--------------------------------------------------------------------------------
/src/2wiki/Tree_Generation/openai_req.py:
--------------------------------------------------------------------------------
1 | import openai
2 | import requests
3 | import time
4 | import os
5 | import json, jsonlines
6 |
7 | class OpenaiReq():
8 | def __init__(self):
9 | self.url = "http://127.0.0.1:10001/api/openai/completion"
10 | self.cache = {}
11 | self.cache_path = "./cache.jsonl"
12 | if os.path.exists(self.cache_path):
13 | with open(self.cache_path, "r") as f:
14 | for i, line in enumerate(f):
15 | #print(i+1)
16 | datum = json.loads(line.strip())
17 | self.cache[tuple(datum["input"])] = datum["response"]
18 | f.close()
19 |
20 | def req2openai(self, prompt, model="text-davinci-003", temperature=0, max_tokens=128, stop=None, logprobs=1, use_cache=True):
21 | assert isinstance(prompt, str)
22 | input = (prompt, model, max_tokens, stop, logprobs)
23 | if use_cache and temperature == 0 and input in self.cache:
24 | return self.cache[input], True
25 | for i in range(3):
26 | try:
27 | response = requests.post(self.url, json = {
28 | "model": model,
29 | "prompt": prompt,
30 | "temperature": temperature,
31 | "max_tokens": max_tokens,
32 | "stop": stop,
33 | "logprobs": logprobs,
34 | })
35 | if response.status_code != 200:
36 | raise Exception(response.text)
37 | break
38 | except Exception as e:
39 | err_msg = str(e)
40 | print(e)
41 |                 if "reduce your prompt" in err_msg: # the input prompt is too long
42 | return ['too long'], False
43 | try:
44 | response = response.json()['choices']
45 | except:
46 | return ['openai error'], False
47 | if temperature == 0:
48 | input = (prompt, model, max_tokens, stop, logprobs)
49 | res = response[0]
50 | if input not in self.cache:
51 | self.cache[input] = [res]
52 | with open(self.cache_path, "a") as f:
53 | f.write("%s\n"%json.dumps({"input": input, "response": [res]}))
54 | f.close()
55 | return response, True
56 |
57 | if __name__ == "__main__":
58 | caller = OpenaiReq()
59 | res = caller.req2openai("你好", use_cache=True)
60 | print(res)
61 |
62 |
--------------------------------------------------------------------------------
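For reference, the cache that OpenaiReq maintains is an append-only cache.jsonl: one JSON object per line with an "input" field (the (prompt, model, max_tokens, stop, logprobs) tuple) and a "response" field holding a single OpenAI choice; only temperature-0 calls are cached. A minimal sketch for inspecting it, assuming cache.jsonl already exists in the working directory (illustrative only, not part of the pipeline):

    import json

    # Print a short summary of each cached completion (field names as written by OpenaiReq above).
    with open("cache.jsonl") as f:
        for line in f:
            datum = json.loads(line)
            prompt, model, max_tokens, stop, logprobs = datum["input"]
            choice = datum["response"][0]
            print(model, repr(prompt[:40]), repr(choice.get("text", "")[:40]))
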
/src/2wiki/Tree_Generation/prompt.txt:
--------------------------------------------------------------------------------
1 | Please generate a hierarchical question decomposition tree (HQDT) with json format for a given question. In this tree, the root node is the original complex question, and each non-root node is a sub-question of its parent. The leaf nodes are atomic questions that cannot be further decomposed.
2 | Q: When did the director of film Hypocrite (Film) die?
3 | A: {"When did the director of film Hypocrite (Film) die?": ["Who is the director of film Hypocrite (Film)?", "When did #1 die?"]}.
4 | Q: Do director of film Coolie No. 1 (1995 Film) and director of film The Sensational Trial have the same nationality?
5 | A: {"Do director of film Coolie No. 1 (1995 Film) and director of film The Sensational Trial have the same nationality?": ["What is the nationality of the director of film Coolie No. 1 (1995 Film)?", "What is the nationality of the director of film The Sensational Trial?"], "What is the nationality of the director of film Coolie No. 1 (1995 Film)?": ["Who is the director of film Coolie No. 1 (1995 Film)?", "What is the nationality of #1?"], "What is the nationality of the director of film The Sensational Trial?": ["Who is the director of film The Sensational Trial?", "What is the nationality of #1?"]}.
6 | Q: Are both Kurram Garhi and Trojkrsti located in the same country?
7 | A: {"Are both Kurram Garhi and Trojkrsti located in the same country?": ["Which country is Kurram Garhi located in?", "Which country is Trojkrsti located in?"]}.
8 | Q: Who was born first out of Martin Hodge and Ivania Martinich?
9 | A: {"Who was born first out of Martin Hodge and Ivania Martinich?": ["When was Martin Hodge born?", "When was Ivania Martinich born?"]}.
10 | Q: Which film came out first, The Night Of Tricks or The Genealogy?
11 | A: {"Which film came out first, The Night Of Tricks or The Genealogy?": ["When was the film The Night Of Tricks published?", "When was the film The Genealogy published?"]}.
12 | Q: When did the director of film Laughter In Hell die?
13 | A: {"When did the director of film Laughter In Hell die?": ["Who is the director of film Laughter In Hell?", "When did #1 die?"]}.
14 | Q: Which film has the director died later, The Gal Who Took the West or Twenty Plus Two?
15 | A: {"Which film has the director died later, The Gal Who Took the West or Twenty Plus Two?": ["When did the director of film The Gal Who Took the West die?", "When did the director of film Twenty Plus Two die?"], "When did the director of film The Gal Who Took the West die?": ["Who is the director of film The Gal Who Took the West?", "When did #1 die?"], "When did the director of film Twenty Plus Two die?": ["Who is the director of film Twenty Plus Two?", "When did #1 die?"]}.
16 | Q: Who is Boraqchin (Wife Of ÃUgedei)'s father−in−law?
17 | A: {"Who is Boraqchin (Wife Of ÃUgedei)'s father−in−law?": ["Who is Boraqchin married to?", "Who is the father of #1?"]}.
18 | Q: What is the cause of death of Grand Duke Alexei Alexandrovich Of Russia's mother?
19 | A: {"What is the cause of death of Grand Duke Alexei Alexandrovich Of Russia's mother?": ["Who is the mother of Grand Duke Alexei Alexandrovich Of Russia?", "What is the cause of death of #1?"]}.
20 | Q: Which film has the director died earlier, When The Mad Aunts Arrive or The Miracle Worker (1962 Film)?
21 | A: {"Which film has the director died earlier, When The Mad Aunts Arrive or The Miracle Worker (1962 Film)?": ["When did the director of film When The Mad Aunts Arrive die?", "When did the director of film The Miracle Worker (1962 Film) die?"], "When did the director of film When The Mad Aunts Arrive die?": ["Who is the director of film When The Mad Aunts Arrive?", "When did #1 die?"], "When did the director of film The Miracle Worker (1962 Film) die?": ["Who is the director of film The Miracle Worker (1962 Film)?", "When did #1 die?"]}.
22 | Q: Which album was released earlier, What'S Inside or Cassandra'S Dream (Album)?
23 | A: {"Which album was released earlier, What'S Inside or Cassandra'S Dream (Album)?": ["When was the album What'S Inside released?", "When was the album Cassandra'S Dream (Album) released?"]}.
24 | Q: Are both mountains, Serre Mourene and Monte Galbiga, located in the same country?
25 | A: {"Are both mountains, Serre Mourene and Monte Galbiga, located in the same country?": ["Which country was the mountain Serre Mourene located in?", "Which country was the mountain Monte Galbiga located in?"]}.
26 | Q: What is the date of birth of the director of film Best Friends (1982 Film)?
27 | A: {"What is the date of birth of the director of film Best Friends (1982 Film)?": ["Who is the director of film Best Friends (1982 Film)?", "What is the date of birth of #1?"]}.
28 | Q: Which film has the director born first, Two Weeks With Pay or Chhailla Babu?
29 | A: {"Which film has the director born first, Two Weeks With Pay or Chhailla Babu?": ["When was the director of film Two Weeks With Pay born?", "When was the director of film Chhailla Babu born?"], "When was the director of film Two Weeks With Pay born?": ["Who is the director of film Two Weeks With Pay?", "When was #1 born?"], "When was the director of film Chhailla Babu born?": ["Who is the director of film Chhailla Babu?", "When was #1 born?"]}.
30 | Q: Who is the grandchild of Krishna Shah (Nepalese Royal)?
31 | A: {"Who is the grandchild of Krishna Shah (Nepalese Royal)?": ["Who is the child of Krishna Shah (Nepalese Royal)?", "Who is the child of #1?"]}.
32 | Q: When was the director of film P.S. Jerusalem born?
33 | A: {"When was the director of film P.S. Jerusalem born?": ["Who is the director of film P.S. Jerusalem?", "When was #1 born?"]}.
34 | Q: Which album was released more recently, If I Have to Stand Alone or Answering Machine Music?
35 | A: {"Which album was released more recently, If I Have to Stand Alone or Answering Machine Music?": ["When was the album If I Have to Stand Alone released?", "When was the album Answering Machine Music released?"]}.
36 | Q: Where did the director of film Maddalena (1954 Film) die?
37 | A: {"Where did the director of film Maddalena (1954 Film) die?": ["Who is the director of film Maddalena (1954 Film)?", "Where did #1 die?"]}.
38 | Q: When did the director of film The Boy And The Fog die?
39 | A: {"When did the director of film The Boy And The Fog die?": ["Who is the director of film The Boy And The Fog?", "When did #1 die?"]}.
40 | Q: Are the directors of films The Sun of the Sleepless and Nevada (1927 film) both from the same country?
41 | A: {"Are the directors of films The Sun of the Sleepless and Nevada (1927 film) both from the same country?": ["Which country is the director of film The Sun of the Sleepless from?", "Which country is the director of film Nevada (1927 film) from?"], "Which country is the director of film The Sun of the Sleepless from?": ["Who is the director of film The Sun of the Sleepless?", "Which country is #1 from?"], "Which country is the director of film Nevada (1927 film) from?": ["Who is the director of film Nevada (1927 film)?", "Which country is #1 from?"]}.
--------------------------------------------------------------------------------
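The expected completion is a flat JSON object: every question that can be decomposed maps to its ordered list of sub-questions, and a placeholder such as #1 stands for the answer of the first sub-question. A minimal illustrative sketch (not part of the pipeline) that parses one such completion, taken verbatim from the first example above, and prints it as an indented tree:

    import json

    answer_text = '{"When did the director of film Hypocrite (Film) die?": ["Who is the director of film Hypocrite (Film)?", "When did #1 die?"]}'
    hqdt = json.loads(answer_text)

    def show(question, depth=0):
        print("  " * depth + question)
        for sub_q in hqdt.get(question, []):   # leaves (including "#k" references) have no entry
            show(sub_q, depth + 1)

    root = next(iter(hqdt))                    # the original complex question is the first key
    show(root)
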
/src/hotpotqa/0_generate_tree.sh:
--------------------------------------------------------------------------------
1 | cd ./Tree_Generation
2 | python 0_get_prompt.py
3 | python 1_query.py
4 | python combine.py
5 | python 2_postprocess.py
6 | python 3_postprocess_tree.py
--------------------------------------------------------------------------------
/src/hotpotqa/1_conduct_reasoning.sh:
--------------------------------------------------------------------------------
1 | cd ./RoHT
2 | python 1_build_tree.py
3 | python 2_run.py
4 | python 3_get_f1.py
--------------------------------------------------------------------------------
/src/hotpotqa/RoHT/1_build_tree.py:
--------------------------------------------------------------------------------
1 | import json
2 | from collections import defaultdict
3 |
4 | raw_data = [json.loads(line.strip()) for line in open('../../../released_data/hotpotqa__v2_test_random_500.jsonl')]
5 | q2sub_q = json.load(open("../Tree_Generation/tree.json"))
6 | q2dq = json.load(open("../Tree_Generation/question_decompositions.json"))
7 | trees = []
8 |
9 | def dfs(q, tree):  # post-order: children are appended before their parent; returns the index of q's node
10 | sons = []
11 | for sub_q in q2sub_q.get(q, [[]])[0]:
12 | son_idx = dfs(sub_q, tree)
13 | sons.append(son_idx)
14 | idx = len(tree)
15 | tree.append({
16 | "idx": idx,
17 | "question_text": q,
18 | "sons": sons,
19 | "qd_logprob": q2sub_q.get(q, [[], None])[1]
20 | })
21 | for son_idx in sons:
22 | tree[son_idx]["fa"] = idx
23 | return idx
24 |
25 | for item in raw_data:
26 | question = item['question_text'].strip()
27 | question = list(q2dq[question].keys())[0]
28 | assert question in q2sub_q, question
29 | tree = []
30 | dfs(question, tree)
31 | trees.append(tree)
32 |
33 | json.dump(trees, open("trees.json", "w"), indent=2)
34 |
35 |
36 |
37 |
38 |
--------------------------------------------------------------------------------
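Because dfs appends children before their parent, each tree in trees.json is stored in post-order: the root is tree[-1], child indices sit in "sons", and each child carries a parent back-pointer "fa". A quick sketch for inspecting the first tree, assuming trees.json has already been produced by the script above:

    import json

    trees = json.load(open("trees.json"))
    tree = trees[0]
    root = tree[-1]
    print("root:", root["question_text"])
    for son_idx in root["sons"]:
        son = tree[son_idx]
        print("  sub-question:", son["question_text"], "| parent idx:", son["fa"])
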
/src/hotpotqa/RoHT/2_run.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | import os
4 | from question_answering import *
5 | from tqdm import tqdm
6 | from parallel import parallel_process_data
7 |
8 |
9 | PROC_NUM = 50
10 | cnt = 0
11 |
12 | def solve(tree):
13 | global cnt
14 | cnt += 1
15 | print(cnt)
16 | #print(tree[-1])
17 | try:
18 | for node in tree:
19 | question = node["question_text"].strip()
20 | ref_tokens = re.findall(r"<\d+>", question)
21 | topic_entities = []
22 | for ref_token in ref_tokens:
23 | if "fa" in node and int(ref_token[1:-1]) <= len(tree[node["fa"]]["sons"]):
24 | ref_idx = tree[node["fa"]]["sons"][int(ref_token[1:-1])-1]
25 | if "answer" in tree[ref_idx]:
26 | question = question.replace(ref_token, tree[ref_idx]["answer"][0])
27 | topic_entities.append(tree[ref_idx]["answer"][0])
28 | node["question"] = question
29 | node["cb_answer"] = get_cb_answer(question)
30 | #print(node["cb_answer"])
31 | if len(node["sons"]) == 0:
32 | node["ob_answer"] = get_singlehop_ob_answer(question, topic_entities)
33 | #print(node["ob_answer"])
34 | node["answer"] = aggregate_singlehop_answer(node["cb_answer"], node["ob_answer"])
35 | else:
36 | node["ob_answer"] = get_multihop_ob_answer(node, tree)
37 | #print(node["ob_answer"])
38 | node["child_answer"], node["answer"] = aggregate_multihop_answer(node, tree)
39 | except Exception as e:
40 | print("ERROR CASE")
41 | print(tree[-1])
42 | raise e
43 |
44 |
45 | trees = json.load(open("trees.json", "r"))
46 | print("Total: %d | Start Processing..."%len(trees))
47 | parallel_process_data(trees, solve, PROC_NUM)
48 |
49 |
50 | print("END")
51 | os.makedirs("results", exist_ok=True)
52 | json.dump(trees, open("results/test.json", "w"), indent=2)
--------------------------------------------------------------------------------
/src/hotpotqa/RoHT/3_get_f1.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 |
4 | from tqdm import tqdm
5 | from termcolor import colored
6 | from evaluate import update_answer
7 | import math
8 |
9 |
10 | q2a = {}
11 | raw_data = [json.loads(line.strip()) for line in open('../../../released_data/hotpotqa__v2_test_random_500.jsonl')]
12 | q2dq = json.load(open("../Tree_Generation/question_decompositions.json"))
13 | q2gold = {}
14 | for item in raw_data:
15 | question = item['question_text'].strip()
16 | question = list(q2dq[question].keys())[0]
17 | gold = item['answers_objects'][0]['spans'][0]
18 | q_type = item["type"]
19 | q2gold[question] = (gold, q_type)
20 |
21 | trees = json.load(open("results/test.json", "r"))
22 | metrics = {}
23 | for q_type in ["all", "bridge", "comparison"]:
24 | metrics[q_type] = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0, 'N': 0}
25 |
26 | print(len(trees))
27 | for i, tree in enumerate(trees):
28 | node = tree[-1]
29 | question, answer = node["question"], node["answer"][0]
30 | gold, q_type = q2gold[question]
31 | q2a[question] = (i, answer, gold)
32 | em, f1, prec, recall = update_answer(metrics["all"], answer, gold)
33 | update_answer(metrics[q_type], answer, gold)
34 |
35 | for q_type in ["all", "bridge", "comparison"]:
36 | print(q_type)
37 | print(metrics[q_type]['N'])
38 |
39 | for k in metrics[q_type].keys():
40 | metrics[q_type][k] /= metrics[q_type]['N']
41 | print(metrics[q_type])
42 |
43 |
44 | json.dump(q2a, open("q2a.json", "w"), indent=2)
--------------------------------------------------------------------------------
/src/hotpotqa/RoHT/aggregate/prompt.txt:
--------------------------------------------------------------------------------
1 | Given a question and a context, answer the question and explain why.
2 |
3 | #
4 | Context:
5 | Which famous fashion show Stella Maxwell has been a model for? Victoria's Secret.
6 | Since when Victoria's Secret? 1977.
7 |
8 | Question:
9 | Stella Maxwell has been a model for a famous fashion show since when?
10 |
11 | Answer:
12 | Stella Maxwell has been a model for a famous fashion show, Victoria's Secret since 2015. So the answer is: since 2015.
13 | #
14 | Context:
15 | Who is the American retired professional basketball player who is current president of basketball operations for the Los Angeles Lakers? Devean George.
16 | William Novac co-wrote the memoir of Devean George? no.
17 |
18 | Question:
19 | William Novac co-wrote the memoir of what American retired professional basketball player who is current president of basketball operations for the Los Angeles Lakers?
20 |
21 | Answer:
22 | William Novac co-wrote the memoir of Magic Johnson, an American retired professional basketball player who is current president of basketball operations for the Los Angeles Lakers. So the answer is: Magic Johnson.
23 | #
24 | Context:
25 | Which athlete rode 400 miles across his country to bring attention to the plight of the disabled in the country? Emmanuel Ofosu Yeboah.
26 | What is the title of the documentary narrated by Oprah Winfrey about Emmanuel Ofosu Yeboah? Emmanuel's Gift.
27 |
28 | Question:
29 | Oprah Winfrey narrated a documentary about this athlete who rode 400 miles across his country to bring attention to the plight of the disabled in the country?
30 |
31 | Answer:
32 | Oprah Winfrey narrated a documentary about the athlete Emmanuel Ofosu Yeboah, who rode 400 miles across his country to bring attention to the plight of the disabled in the country. So the answer is: Emmanuel Ofosu Yeboah.
33 | #
--------------------------------------------------------------------------------
/src/hotpotqa/RoHT/cb/prompt.txt:
--------------------------------------------------------------------------------
1 | Please answer the question by thinking step-by-step.
2 | Q: Jeremy Theobald and Christopher Nolan share what profession?
3 | A: Jeremy Theobald is an actor and producer. Christopher Nolan is a director, producer, and screenwriter. Therefore, they both share the profession of being a producer. So the answer is: producer.
4 | Q: How many episodes were in the South Korean television series in which Ryu Hye−young played Bo−ra?
5 | A: The South Korean television series in which Ryu Hye−young played Bo−ra is Reply 1988. The number of episodes Reply 1988 has is 20. So the answer is: 20.
6 | Q: Vertical Limit stars which actor who also played astronaut Alan Shepard in "The Right Stuff"?
7 | A: The movie Vertical Limit starred actors including Chris O'Donnell, Robin Tunney, Scott Glenn, etc. The actor who played astronaut Alan Shepard in "The Right Stuff" is Scott Glenn. So the actor who stars in Vertical Limit and played astronaut Alan Shepard in "The Right Stuff" is Scott Glenn. So the answer is: Scott Glenn.
8 | Q: What was the 2014 population of the city where Lake Wales Medical Center is located?
9 | A: Lake Wales Medical Center is located in the city of Polk County, Florida. The population of Polk County in 2014 was 15,140. So the answer is: 15,140.
10 | Q: Who was born first? Jan de Bont or Raoul Walsh?
11 | A: Jan de Bont was born on 22 October 1943. Raoul Walsh was born on March 11, 1887. Thus, Raoul Walsh was born the first. So the answer is: Raoul Walsh.
12 | Q: In what country was Lost Gravity manufactured?
13 | A: The Lost Gravity (roller coaster) was manufactured by Mack Rides. Mack Rides is a German company. So the answer is: Germany.
14 | Q: Which of the following had a debut album entitled "We Have an Emergency": Hot Hot Heat or The Operation M.D.?
15 | A: The debut album of the band "Hot Hot Heat" was "Make Up the Breakdown". The debut album of the band "The Operation M.D." was "We Have an Emergency". So the answer is: The Operation M.D..
16 | Q: Was Lonny (magazine) was founded in 2009?
17 | A: Lonny (magazine) was founded in 2009. So the answer is: yes.
18 | Q: In which country did this Australian who was detained in Guantanamo Bay detention camp and published "Guantanamo: My Journey" receive para−military training?
19 | A: The Australian who was detained in Guantanamo Bay detention camp and published "Guantanamo: My Journey" is David Hicks. David Hicks received his para−military training in Afghanistan. So the answer is: Afghanistan.
20 | Q: Does The Border Surrender or Unsane have more members?
21 | A: The Border Surrender band has following members: Keith Austin, Simon Shields, Johnny Manning and Mark Austin. That is, it has 4 members. Unsane has following members: Chris Spencer, Cooper, and Jon Syverson. That is, it has 3 members. Thus, The Border Surrender has more members. So the answer is: The Border Surrender.
22 | Q: James Paris Lee is best known for inventing the Lee−Metford rifle and another rifle often referred to by what acronym?
23 | A: James Paris Lee is best known for inventing the Lee−Metford rifle and Lee–Enfield series of rifles. Lee–Enfield is often referred to by the acronym of SMLE. So the answer is: SMLE.
24 | Q: Was Lonny (magazine) was founded in 2008?
25 | A: Lonny (magazine) was founded in 2009. So the answer is: no.
26 | Q: What year did Edburga of Minster−in−Thanet's father die?
27 | A: The father of Edburga of Minster−in−Thanet is King Centwine. Centwine died after 685. So the answer is: after 685.
28 | Q: Were Lonny and Allure both founded in the 1990s?
29 | A: Lonny (magazine) was founded in 2009. Allure (magazine) was founded in 1991. Thus, of the two, only Allure was founded in 1990s. So the answer is: no.
30 | Q: The actor that stars as Joe Proctor on the series "Power" also played a character on "Entourage" that has what last name?
31 | A: The actor that stars as Joe Proctor on the series "Power" is Jerry Ferrara. Jerry Ferrara also played a character on Entourage named Turtle Assante. Turtle Assante's last name is Assante. So the answer is: Assante.
32 | Q: When was Jan de Bont born?
33 | A: Jan de Bont was born on 22 October 1943. So the answer is: 22 October 1943.
34 | Q: Nobody Loves You was written by John Lennon and released on what album that was issued by Apple Records, and was written, recorded, and released during his 18 month separation from Yoko Ono?
35 | A: Nobody Loves You was written by John Lennon and released on the album Walls and Bridges. The album issued by Apple Records, and written, recorded, and released during John Lennon's 18 month separation from Yoko Ono is Walls and Bridges. So the answer is: Walls and Bridges.
36 | Q: How many awards did the "A Girl Like Me" singer win at the American Music Awards of 2012?
37 | A: The singer of "A Girl Like Me" is Rihanna. In the American Music Awards of 2012, Rihana won one award. So the answer is: one.
38 | Q: Are both Bruce Chatwin and O. Henry writers?
39 | A: Bruce Chatwin was an English travel writer, novelist, and journalist. O. Henry was an American writer. So both Bruce Chatwin and O. Henry are writers. So the answer is: yes.
40 | Q: Which city is Lake Wales Medical Center located?
41 | A: Lake Wales Medical Center is located in the city of Polk County, Florida. So the answer is: Polk County, Florida.
42 | Q: Dadi Denis studied at a Maryland college whose name was changed in 1890 to honor what man?
43 | A: Dadi Denis studied at the Maryland college Morgan State University. In 1890, the university's name was changed to honor Reverend Lyttleton Morgan. So the answer is: Reverend Lyttleton Morgan.
44 | Q: William Orman Beerman was born in a city in northeastern Kansas that is the county seat of what county?
45 | A: William Orman Beerman was born in Manhattan, Kansas. Manhattan, Kansas is the county seat of Riley County. So the answer is: Riley County.
--------------------------------------------------------------------------------
/src/hotpotqa/RoHT/count.py:
--------------------------------------------------------------------------------
1 | import json
2 | from collections import defaultdict
3 | trees = json.load(open("./results/test_k=5_singlehop_serpapi_multiobprompt_oner_best.json", "r"))
4 | cnt = defaultdict(int)
5 | total = 0
6 | for tree in trees:
7 | for node in tree:
8 | if "child_answer" in node:
9 | if node["answer"][1] == node["cb_answer"][1]:
10 | cnt["non_leaf_cb"] += 1
11 | elif node["answer"][1] == node["ob_answer"][1]:
12 | cnt["non_leaf_ob"] += 1
13 | else:
14 | cnt["non_leaf_ca"] += 1
15 | else:
16 | if node["answer"][1] == node["cb_answer"][1]:
17 | cnt["leaf_cb"] += 1
18 | else:
19 | cnt["leaf_ob"] += 1
20 | total += 1
21 |
22 | print(cnt)
23 | keys = ["leaf_ob", "leaf_cb"]
24 | print("leaf_cb: ", cnt["leaf_cb"], cnt["leaf_cb"] / (cnt["leaf_ob"] + cnt["leaf_cb"]))
25 | print("leaf_ob: ", cnt["leaf_ob"], cnt["leaf_ob"] / (cnt["leaf_ob"] + cnt["leaf_cb"]))
26 |
27 | print("non_leaf_cb:", cnt["non_leaf_cb"], cnt["non_leaf_cb"] / (cnt["non_leaf_ob"] + cnt["non_leaf_cb"] + cnt["non_leaf_ca"]))
28 | print("non_leaf_ob:", cnt["non_leaf_ob"], cnt["non_leaf_ob"] / (cnt["non_leaf_ob"] + cnt["non_leaf_cb"] + cnt["non_leaf_ca"]))
29 | print("non_leaf_ca:", cnt["non_leaf_ca"], cnt["non_leaf_ca"] / (cnt["non_leaf_ob"] + cnt["non_leaf_cb"] + cnt["non_leaf_ca"]))
30 |
--------------------------------------------------------------------------------
/src/hotpotqa/RoHT/evaluate.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import ujson as json
3 | import re
4 | import string
5 | from collections import Counter
6 | import pickle
7 |
8 | def normalize_answer(s):
9 |
10 | def remove_articles(text):
11 | return re.sub(r'\b(a|an|the)\b', ' ', text)
12 |
13 | def white_space_fix(text):
14 | return ' '.join(text.split())
15 |
16 | def remove_punc(text):
17 | exclude = set(string.punctuation)
18 | return ''.join(ch for ch in text if ch not in exclude)
19 |
20 | def lower(text):
21 | return text.lower()
22 |
23 | return white_space_fix(remove_articles(remove_punc(lower(s))))
24 |
25 |
26 | def f1_score(prediction, ground_truth):
27 | normalized_prediction = normalize_answer(prediction)
28 | normalized_ground_truth = normalize_answer(ground_truth)
29 |
30 | ZERO_METRIC = (0, 0, 0)
31 |
32 | if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
33 | return ZERO_METRIC
34 | if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
35 | return ZERO_METRIC
36 |
37 | prediction_tokens = normalized_prediction.split()
38 | ground_truth_tokens = normalized_ground_truth.split()
39 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
40 | num_same = sum(common.values())
41 | if num_same == 0:
42 | return ZERO_METRIC
43 | precision = 1.0 * num_same / len(prediction_tokens)
44 | recall = 1.0 * num_same / len(ground_truth_tokens)
45 | f1 = (2 * precision * recall) / (precision + recall)
46 | return f1, precision, recall
47 |
48 |
49 | def exact_match_score(prediction, ground_truth):
50 | return (normalize_answer(prediction) == normalize_answer(ground_truth))
51 |
52 | def update_answer(metrics, prediction, gold):
53 | em = exact_match_score(prediction, gold)
54 | f1, prec, recall = f1_score(prediction, gold)
55 | metrics['em'] += float(em)
56 | metrics['f1'] += f1
57 | metrics['prec'] += prec
58 | metrics['recall'] += recall
59 | metrics['N'] += 1
60 | return em, f1, prec, recall
61 |
62 | def eval():
63 | gold = [json.loads(line.strip()) for line in open('/data/csl/exp/LLMReasoning/released_data/hotpotqa__v2_test_random_500.jsonl')]
64 | metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0, 'N': 0}
65 | for dp in gold:
66 | print(dp)
67 | print(dp['answers_objects'][0])
68 | answer = dp['answers_objects'][0]['spans'][0]
69 |
70 | # cur_id = dp['_id']
71 | # if cur_id not in prediction['answer']:
72 | # print('missing answer {}'.format(cur_id))
73 | # else:
74 | em, f1, prec, recall = update_answer(
75 | metrics, answer, answer)
76 |
77 | N = len(gold)
78 | for k in metrics.keys():
79 | metrics[k] /= N
80 |
81 | print(metrics)
82 |
83 | if __name__ == '__main__':
84 | eval()
--------------------------------------------------------------------------------
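A minimal usage sketch for the helpers above; the prediction and gold strings are made up and show that normalize_answer strips articles, punctuation, and case before matching:

    from evaluate import update_answer

    metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0, 'N': 0}
    # "the" and capitalization are normalized away, so this counts as an exact match.
    em, f1, prec, recall = update_answer(metrics, "the Mount Panorama Circuit", "Mount Panorama Circuit")
    print(em, f1, metrics['N'])   # True 1.0 1
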
/src/hotpotqa/RoHT/ob/multihop_prompt.txt:
--------------------------------------------------------------------------------
1 | Please answer the question and explain why. Output no more than 5 words after "So the answer is".
2 |
3 | #1 Wikipedia Title: First (magazine)
4 | Text: FiRST is a Singaporean movie magazine formerly published monthly, now running as a weekly newspaper insert.
5 | #2 Wikipedia Title: Arthur's Magazine
6 | Text: Arthur's Magazine (1844–1846) was an American literary periodical published in Philadelphia in the 19th century. Edited by T.S. Arthur, it featured work by Edgar A. Poe, J.H. Ingraham, Sarah Josepha Hale, Thomas G. Spear, and others. In May 1846 it was merged into "Godey's Lady's Book".
7 | #3 Wikipedia Title: First for Women
8 | Text: First for Women is a woman's magazine published by Bauer Media Group in the USA. The magazine was started in 1989. It is based in Englewood Cliffs, New Jersey. In 2011 the circulation of the magazine was 1,310,696 copies.
9 | #4 Wikipedia Title: First Eleven (magazine)
10 | Text: First Eleven is a British specialist magazine for parents of children at independent schools.
11 | #5 Wikipedia Title: Earth First! (magazine)
12 | Text: Earth First!, the radical environmental journal, is the official publication of the Earth First! movement. First published as a newsletter in 1980, it has existed alongside the movement as a way to spread commonly held beliefs in "Earth First!" culture, such as biocentrism, deep ecology, and direct action. The magazine is also commonly known as the "Earth First! Journal".
13 | Q: Which magazine was started first Arthur's Magazine or First for Women?
14 | A: Arthur's Magazine was started in 1844. First for Women was started in 1989. So Arthur's Magazine was started first. So the answer is: Arthur's Magazine.
15 |
16 | #1 Wikipedia Title: The Oberoi Group
17 | Text: The Oberoi Group is a hotel company with its head office in Delhi. Founded in 1934, the company owns and/or operates 30+ luxury hotels and two river cruise ships in six countries, primarily under its Oberoi Hotels & Resorts and Trident Hotels brands.
18 | #2 Wikipedia Title: The Body Has a Head
19 | Text: The Body Has a Head is an album by King Missile frontman John S. Hall, released exclusively in Germany in 1996. Though billed as a Hall "solo album," the collection features considerable input from multi-instrumentalists Sasha Forte, Bradford Reed, and Jane Scarpantoni, all of whom would become members of the next incarnation of King Missile ("King Missile III") and contribute to that group's "debut" album, 1998's "Failure."
20 | #3 Wikipedia Title: Oberoi family
21 | Text: The Oberoi family is an Indian family that is famous for its involvement in hotels, namely through The Oberoi Group.
22 | #4 Wikipedia Title: Has-a
23 | Text: In database design, object-oriented programming and design (see object oriented program architecture), has-a (has_a or has a) is a composition relationship where one object (often called the constituted object, or part/constituent/member object) "belongs to" (is part or member of) another object (called the composite type), and behaves according to the rules of ownership. In simple words, has-a relationship in an object is called a member field of an object. Multiple has-a relationships will combine to form a possessive hierarchy.
24 | #5 Wikipedia Title: Oberoi Realty
25 | Text: Oberoi Realty is a real estate developer based in Mumbai, Maharashtra. It is led by Mr. Vikas Oberoi, CMD. The company has developed over 39 projects at locations across Mumbai. Its main interest is in Residential, Office Space, Retail, Hospitality and Social Infrastructure properties in Mumbai.
26 | Q: The Oberoi family is part of a hotel company that has a head office in what city?
27 | A: The Oberoi family is part of a hotel company The Oberoi Group. The Oberoi Group has a head office in Delhi. So the answer is: Delhi.
28 |
29 | #1 Wikipedia Title: 2014 Liqui Moly Bathurst 12 Hour
30 | Text: The 2014 Liqui Moly Bathurst 12 Hour was an endurance race for a variety of GT and touring car classes, including: GT3 cars, GT4 cars and Group 3E Series Production Cars. The event, which was staged at the Mount Panorama Circuit, near Bathurst, in New South Wales, Australia on 9 February 2014, was the twelfth running of the Bathurst 12 Hour.
31 | #2 Wikipedia Title: 2015 Liqui Moly Bathurst 12 Hour
32 | Text: The 2015 Liqui Moly Bathurst 12 Hour was an endurance race for a variety of GT and touring car classes, including: GT3 cars, GT4 cars and Group 3E Series Production Cars. The event, which was staged at the Mount Panorama Circuit, near Bathurst, in New South Wales, Australia on 8 February 2015, was the thirteenth running of the Bathurst 12 Hour.
33 | #3 Wikipedia Title: 2013 Liqui Moly Bathurst 12 Hour
34 | Text: The 2013 Liqui Moly Bathurst 12 Hour was an endurance race for a variety of GT and touring car classes, including: GT3 cars, GT4 cars, Group 3E Series Production Cars and Dubai 24 Hour cars. The event, which was staged at the Mount Panorama Circuit, near Bathurst, in New South Wales, Australia on 10 February 2013, was the eleventh running of the Bathurst 12 Hour. The race also incorporated the opening round of the 2013 Australian GT Championship. The Australian GT Championship was to compete as the first hour only and cars were permitted to enter for only that hour or to cross-enter for both the first hour and continue for the endurance race.
35 | #4 Wikipedia Title: Mount Panorama Circuit
36 | Text: Mount Panorama Circuit is a motor racing track located in Bathurst, New South Wales, Australia. It is situated on a hill with the dual official names of Mount Panorama and Wahluu and is best known as the home of the Bathurst 1000 motor race held each October, and the Bathurst 12 Hour event held each February. The 6.213 km long track is technically a street circuit, and is a public road, with normal speed restrictions, when no racing events are being run, and there are many residences which can only be accessed from the circuit.
37 | #5 Wikipedia Title: List of Mount Panorama races
38 | Text: This is a list of significant car races that have been held at the Mount Panorama Circuit near Bathurst, New South Wales, Australia. As Australia's most famous motor racing circuit, Mount Panorama has had a significant influence on the history and industry of Australian motor racing.
39 | Q: What is the length of the track where the 2013 Liqui Moly Bathurst 12 Hour was staged?
40 | A: The 2013 Liqui Moly Bathurst 12 Hour was staged at the Mount Panorama Circuit. Mount Panorama Circuit is 6.213 km long. So the answer is: 6.213 km long.
--------------------------------------------------------------------------------
/src/hotpotqa/RoHT/ob/singlehop_prompt.txt:
--------------------------------------------------------------------------------
1 | Given a question and the relevant Wikipedia text, answer the question and explain why. If you are unsure, answer Unknown.
2 |
3 | #1 Wikipedia Title: 2014 Liqui Moly Bathurst 12 Hour
4 | Text: The 2014 Liqui Moly Bathurst 12 Hour was an endurance race for a variety of GT and touring car classes, including: GT3 cars, GT4 cars and Group 3E Series Production Cars. The event, which was staged at the Mount Panorama Circuit, near Bathurst, in New South Wales, Australia on 9 February 2014, was the twelfth running of the Bathurst 12 Hour.
5 | #2 Wikipedia Title: 2015 Liqui Moly Bathurst 12 Hour
6 | Text: The 2015 Liqui Moly Bathurst 12 Hour was an endurance race for a variety of GT and touring car classes, including: GT3 cars, GT4 cars and Group 3E Series Production Cars. The event, which was staged at the Mount Panorama Circuit, near Bathurst, in New South Wales, Australia on 8 February 2015, was the thirteenth running of the Bathurst 12 Hour.
7 | #3 Wikipedia Title: 2013 Liqui Moly Bathurst 12 Hour
8 | Text: The 2013 Liqui Moly Bathurst 12 Hour was an endurance race for a variety of GT and touring car classes, including: GT3 cars, GT4 cars, Group 3E Series Production Cars and Dubai 24 Hour cars. The event, which was staged at the Mount Panorama Circuit, near Bathurst, in New South Wales, Australia on 10 February 2013, was the eleventh running of the Bathurst 12 Hour. The race also incorporated the opening round of the 2013 Australian GT Championship. The Australian GT Championship was to compete as the first hour only and cars were permitted to enter for only that hour or to cross-enter for both the first hour and continue for the endurance race.
9 | Q: Which track was the 2013 Liqui Moly Bathurst 12 Hour was staged?
10 | A: The 2013 Liqui Moly Bathurst 12 Hour was staged at the Mount Panorama Circuit. So the answer is: Mount Panorama Circuit.
11 |
12 | #1 Wikipedia Title: So Long, See You Tomorrow (album)
13 | Text: So Long, See You Tomorrow is the fourth album by the London indie rock band Bombay Bicycle Club, released on 3 February 2014. The album is named after the novel of the same name by William Maxwell.
14 | #2 Wikipedia Title: Hallelujah I Love Her So
15 | Text: ``Hallelujah I Love Her So ''Single by Ray Charles from the album Ray Charles (or, Hallelujah I Love Her So) B - side`` What Would I Do Without You'' Released 1956 Format 7 ''45rpm Recorded 1956 Genre soul rhythm and blues Length 2: 35 Label Atlantic Songwriter (s) Ray Charles Producer (s) Jerry Wexler Ray Charles singles chronology ``A Fool for You'' (1955)`` Hallelujah I Love Her So ''(1956) ``Mary Ann'' (1956)`` A Fool for You ''(1955) ``Hallelujah I Love Her So'' (1956)`` Mary Ann ''(1956)
16 | #3 Wikipedia Title: The First Time Ever I Saw Your Face
17 | Text: ``The First Time Ever I Saw Your Face ''Single by Roberta Flack from the album First Take Released March 7, 1972 (1972 - 03 - 07) Recorded 1969 Genre Soul vocal jazz Length 5: 22 4: 15 (1972 radio edit) Label Atlantic 2864 Songwriter (s) Ewan MacColl Producer (s) Joel Dorn Roberta Flack singles chronology`` Will You Still Love Me Tomorrow'' (1972) ``The First Time Ever I Saw Your Face ''(1972)`` Where Is the Love'' (1972) ``Will You Still Love Me Tomorrow ''(1972)`` The First Time Ever I Saw Your Face'' (1972) ``Where Is the Love ''(1972)
18 | Q: Is the performer of So Long, See You Tomorrow Bombay Bicycle Club?
19 | A: The performer of So Long, See You Tomorrow is Bombay Bicycle Club. So the answer is: yes.
20 |
21 | #1 Wikipedia Title: Oberoi family
22 | Text: The Oberoi family is an Indian family that is famous for its involvement in hotels, namely through The Oberoi Group.
23 | #2 Wikipedia Title: The Oberoi Group
24 | Text: The Oberoi Group is a hotel company with its head office in Delhi. Founded in 1934, the company owns and/or operates 30+ luxury hotels and two river cruise ships in six countries, primarily under its Oberoi Hotels & Resorts and Trident Hotels brands.
25 | #3 Wikipedia Title: Mohan Singh Oberoi
26 | Text: Rai Bahadur Mohan Singh Oberoi (15 August 1898 – 3 May 2002) was an Indian hotelier, the founder and chairman of Oberoi Hotels & Resorts, India's second-largest hotel company, with 35 hotels in India, Sri Lanka, Nepal, Egypt, Australia and Hungary.
27 | Q: The Oberoi family is part of which hotel company?
28 | A: The Oberoi family is part of the hotel company The Oberoi Group. So the answer is: The Oberoi Group.
--------------------------------------------------------------------------------
/src/hotpotqa/RoHT/openai_req.py:
--------------------------------------------------------------------------------
1 | import openai
2 | import requests
3 | import time
4 | import os
5 | import json, jsonlines
6 |
7 | class OpenaiReq():
8 | def __init__(self):
9 | self.url = "http://127.0.0.1:10001/api/openai/completion"
10 | self.cache = {}
11 | self.cache_path = "./cache.jsonl"
12 | if os.path.exists(self.cache_path):
13 | with open(self.cache_path, "r") as f:
14 | for i, line in enumerate(f):
15 | #print(i+1)
16 | datum = json.loads(line.strip())
17 | self.cache[tuple(datum["input"])] = datum["response"]
18 | f.close()
19 |
20 | def req2openai(self, prompt, model="text-davinci-003", temperature=0, max_tokens=128, stop=None, logprobs=1, use_cache=True):
21 | assert isinstance(prompt, str)
22 | input = (prompt, model, max_tokens, stop, logprobs)
23 | if use_cache and temperature == 0 and input in self.cache:
24 | return self.cache[input], True
25 | for i in range(3):
26 | try:
27 | response = requests.post(self.url, json = {
28 | "model": model,
29 | "prompt": prompt,
30 | "temperature": temperature,
31 | "max_tokens": max_tokens,
32 | "stop": stop,
33 | "logprobs": logprobs,
34 | })
35 | if response.status_code != 200:
36 | raise Exception(response.text)
37 | break
38 | except Exception as e:
39 | err_msg = str(e)
40 | print(e)
41 |                 if "reduce your prompt" in err_msg: # the input prompt is too long
42 | return ['too long'], False
43 | try:
44 | response = response.json()['choices']
45 | except:
46 | return ['openai error'], False
47 | if temperature == 0:
48 | input = (prompt, model, max_tokens, stop, logprobs)
49 | res = response[0]
50 | if input not in self.cache:
51 | self.cache[input] = [res]
52 | with open(self.cache_path, "a") as f:
53 | f.write("%s\n"%json.dumps({"input": input, "response": [res]}))
54 | f.close()
55 | return response, True
56 |
57 | if __name__ == "__main__":
58 | caller = OpenaiReq()
59 | res = caller.req2openai("你好", use_cache=True)
60 | print(res)
61 |
62 |
--------------------------------------------------------------------------------
/src/hotpotqa/RoHT/parallel.py:
--------------------------------------------------------------------------------
1 | import concurrent.futures, random, time
2 |
3 | def handle_item(data):
4 | waiting = random.random() * 3 + 1
5 | print("Thread %d, Waiting %.2f ..."%(data, waiting))
6 | time.sleep(waiting)
7 | if random.random() < 0.5:
8 | raise Exception()
9 | print("Thread %d, OK."%(data))
10 |
11 | def parallel_process_data(data, handle_item, workers=20, callback=None):
12 | with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
13 | futures = []
14 | for item in data:
15 | future = executor.submit(handle_item, item)
16 | futures.append(future)
17 | for future in concurrent.futures.as_completed(futures):
18 | result = future.result()
19 | if callback:
20 | callback(result)
21 |
22 | if __name__ == "__main__":
23 | parallel_process_data([i for i in range(20)], handle_item)
24 | print("end")
--------------------------------------------------------------------------------
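A usage sketch for parallel_process_data with the optional callback (the worker lambda and inputs are illustrative); results are collected as futures complete, so ordering is not guaranteed:

    from parallel import parallel_process_data

    results = []
    parallel_process_data([1, 2, 3], lambda x: x * x, workers=2, callback=results.append)
    print(sorted(results))   # [1, 4, 9]
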
/src/hotpotqa/RoHT/question_answering.py:
--------------------------------------------------------------------------------
1 | from openai_req import OpenaiReq
2 | import requests
3 | from search.serpapi import get_question_wiki_snippet
4 | import os
5 | from transformers import AutoTokenizer
6 | import random
7 |
8 | serp_api_key = "" # put your SerpAPI key here
9 | os.environ["SERP_API_KEY"] = serp_api_key
10 |
11 | openai_caller = OpenaiReq()
12 |
13 | tokenizer = AutoTokenizer.from_pretrained("gpt2")
14 | random.seed(666)
15 |
16 | def bm25_search(question, k, use_serpapi=False):
17 | web = "http://localhost:1439"
18 | data = {
19 | "query": question,
20 | "k": k
21 | }
22 | for i in range(3):
23 | try:
24 | r = requests.get(web, json=data)
25 | if r.status_code != 200:
26 | raise Exception(r.text)
27 | contexts = r.json()
28 | if use_serpapi:
29 | context = get_question_wiki_snippet(question, cache=True)
30 | title = context.split(': ')[0]
31 | text = ' '.join(context.split(': ')[1:])
32 | contexts.append({"title": title, "text": text})
33 | return contexts
34 | except Exception as e:
35 | print(e)
36 |
37 | def postprocess(response):
38 | response = response[0]
39 | if response == 'too long' or response['finish_reason'] != 'stop':
40 | return 'ERROR: prompt too long', -100, ""
41 | tokens = response['logprobs']['tokens']
42 | token_logprobs = response['logprobs']['token_logprobs']
43 | cot = response['text'].strip()
44 | if len(token_logprobs) == 0:
45 | return 'ERROR: empty output', -100, cot
46 | # if "Unknown" in cot:
47 | # return "Unknow", sum(token_logprobs) / len(token_logprobs), cot
48 | pos = 0
49 | for idx, token in enumerate(tokens):
50 |         if token.strip() == 'So' and idx + 4 < len(tokens) and tokens[idx + 1].strip() == 'the' and tokens[idx + 2].strip() == 'answer' and tokens[idx + 3].strip() == 'is' and tokens[idx + 4].strip() == ':':  # find "So the answer is:"
51 | pos = idx
52 | break
53 | if tokens[-1] == '.':
54 | answer_logprobs = token_logprobs[pos+5:-1]
55 | answer = cot.split('So the answer is: ')[-1][:-1]
56 | else:
57 | answer_logprobs = token_logprobs[pos+5:]
58 | answer = cot.split('So the answer is: ')[-1]
59 | cot_process = cot.split('So the answer is: ')[0].strip()
60 | cot_process_logprobs = token_logprobs[:pos]
61 | if len(cot_process_logprobs) == 0:
62 | cot_process_logprob = -100
63 | else:
64 | cot_process_logprob = sum(cot_process_logprobs) / len(cot_process_logprobs)
65 | return answer, cot_process_logprob, cot
66 |
67 | def get_cb_answer(question):
68 | #return "Unknow", -100
69 | instruction = '\n'.join([_.strip() for _ in open('cb/prompt.txt').readlines()])
70 | prompt = instruction + '\nQ: ' + question + '\nA:'
71 | response, tag = openai_caller.req2openai(prompt=prompt, max_tokens=256, stop='\n\n', use_cache=True)
72 | return postprocess(response)
73 |
74 | def get_singlehop_ob_answer(question, topic_entities):
75 | #return "Unknow", -100
76 | instruction = '\n'.join([_.strip() for _ in open('ob/singlehop_prompt.txt').readlines()])
77 | for k in range(5, 0, -1):
78 | contexts = []
79 | hist = set()
80 | r = bm25_search(question, k, use_serpapi=True)
81 | for datum in r:
82 | title, text = datum["title"], datum["text"]
83 | stamp = title + text
84 | if not stamp in hist:
85 | hist.add(stamp)
86 | contexts.append([title, text])
87 | for e in topic_entities:
88 | r = bm25_search(e, k, use_serpapi=False)
89 | for datum in r:
90 | title, text = datum["title"], datum["text"]
91 | stamp = title + text
92 | if stamp not in hist:
93 | contexts.append([title, text])
94 | hist.add(stamp)
95 |
96 |
97 | prompt = instruction + '\n'
98 | for idx, (title, text) in enumerate(contexts):
99 | prompt += '\n#' + str(idx + 1) + ' Wikipedia Title: ' + title + '\nText: ' + text
100 | prompt += '\nQ: ' + question + '\nA:'
101 | if len(tokenizer(prompt).input_ids) + 256 <= 4097:
102 | break
103 |
104 | response, tag = openai_caller.req2openai(prompt=prompt, max_tokens=256, stop='\n\n\n', use_cache=True)
105 | return postprocess(response)
106 |
107 | def aggregate_singlehop_answer(cb_answer, ob_answer):
108 | cb_ans, cb_score, cb_cot = cb_answer
109 | ob_ans, ob_score, ob_cot = ob_answer
110 | if "ERROR" in cb_ans or 'Unknown' in cb_ans:
111 | cb_ans, cb_score = "", -100
112 | if "ERROR" in ob_ans or 'Unknown' in ob_ans:
113 | ob_ans, ob_score = "", -100
114 | return max([(cb_ans, cb_score, cb_cot), (ob_ans, ob_score, ob_cot)], key=lambda x:x[1])
115 | #return random.choice([(cb_ans, cb_score), (ob_ans, ob_score)])
116 |
117 | def get_multihop_ob_answer(node, tree):
118 | #return "Unknow", -100
119 | question = node["question"]
120 | instruction = '\n'.join([_.strip() for _ in open('ob/multihop_prompt.txt').readlines()])
121 | k = 5
122 | for sub_k in range(3, 0, -1):
123 | contexts = []
124 | hist = set()
125 | r = bm25_search(question, k, use_serpapi=False)
126 | for datum in r:
127 | title, text = datum["title"], datum["text"]
128 | stamp = title + text
129 | if stamp not in hist:
130 | hist.add(stamp)
131 | contexts.append([title, text])
132 |
133 | # for son_idx in node["sons"]:
134 | # sub_question = tree[son_idx]["question"]
135 | # r = bm25_search(sub_question, sub_k, use_serpapi=True)
136 | # for datum in r:
137 | # title, text = datum["title"], datum["text"]
138 | # stamp = title + text
139 | # if stamp not in hist:
140 | # hist.add(stamp)
141 | # contexts.append([title, text])
142 |
143 | # for son_idx in node["sons"][:-1]:
144 | # sub_answer = tree[son_idx]["answer"][0]
145 | # r = bm25_search(sub_answer, sub_k, use_serpapi=False)
146 | # for datum in r:
147 | # title, text = datum["title"], datum["text"]
148 | # stamp = title + text
149 | # if stamp not in hist:
150 | # hist.add(stamp)
151 | # contexts.append([title, text])
152 |
153 | prompt = instruction + '\n'
154 | for idx, (title, text) in enumerate(contexts):
155 | prompt += '\n#' + str(idx + 1) + ' Wikipedia Title: ' + title + '\nText: ' + text
156 | prompt += '\nQ: ' + question + '\nA:'
157 | if len(tokenizer(prompt).input_ids) + 256 <= 4097:
158 | break
159 | response, tag = openai_caller.req2openai(prompt=prompt, max_tokens=256, stop='\n\n\n', use_cache=True)
160 | return postprocess(response)
161 |
162 | def calculate_score1(cot_process_logprob, qd_score, sub_answer_scores):
163 | return cot_process_logprob + qd_score + sum(sub_answer_scores)
164 |
165 | def calculate_score2(cot_process_logprob, qd_score, sub_answer_scores):
166 | return (cot_process_logprob + qd_score + sum(sub_answer_scores)) / (len(sub_answer_scores) + 2)
167 |
168 | def calculate_score3(cot_process_logprob, qd_score, sub_answer_scores):
169 | return (cot_process_logprob + sum(sub_answer_scores)) / (len(sub_answer_scores) + 1)
170 |
171 | def aggregate_multihop_answer(node, tree):
172 | instruction = '\n'.join([_.strip() for _ in open('aggregate/prompt.txt').readlines()])
173 | question = node["question"]
174 | qd_score = node["qd_logprob"]
175 | context = ''
176 | sub_answer_scores = []
177 | for son_idx in node["sons"]:
178 | sub_question = tree[son_idx]["question"]
179 | sub_answer = tree[son_idx]["answer"][0]
180 | sub_answer_scores.append(tree[son_idx]["answer"][1])
181 | context += '\n' + sub_question + ' ' + sub_answer
182 | prompt = instruction + '\nContext:\n{}\n\nQuestion:\n{}\n\nAnswer:'.format(context, question)
183 | response, tag = openai_caller.req2openai(prompt=prompt, max_tokens=256, stop='\n\n\n', use_cache=True)
184 | child_answer, cot_process_logprob, child_cot = postprocess(response)
185 |
186 | child_ans = child_answer
187 | child_score = calculate_score2(cot_process_logprob, qd_score, sub_answer_scores)
188 | res1 = (child_ans, child_score, child_cot)
189 | cb_ans, cb_score, cb_cot = node["cb_answer"]
190 | ob_ans, ob_score, ob_cot = node["ob_answer"]
191 | if "ERROR" in cb_ans or 'Unknown' in cb_ans:
192 | cb_ans, cb_score = "", -100
193 | if "ERROR" in ob_ans or 'Unknown' in ob_ans:
194 | ob_ans, ob_score = "", -100
195 |     if "ERROR" in child_ans or "Unknown" in child_ans:
196 | child_ans, child_score = "", -100
197 | res2 = max([(cb_ans, cb_score, cb_cot), (ob_ans, ob_score, ob_cot), (child_ans, child_score, child_cot)], key=lambda x:x[1])
198 | #res2 = random.choice([(cb_ans, cb_score), (ob_ans, ob_score), (child_ans, child_score)])
199 | return res1, res2
200 |
201 |
202 |
203 | if __name__ == "__main__":
204 | question = "毛泽东"
205 | snippet = get_question_wiki_snippet(question, cache=True)
206 | print(snippet)
207 |
208 |
209 |
210 |
211 |
--------------------------------------------------------------------------------
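To make the scoring concrete: every candidate answer (closed-book, open-book, and the one aggregated from child answers) carries an average-logprob confidence, and aggregate_multihop_answer keeps the highest-scoring non-error candidate. A worked example with made-up numbers, following calculate_score2 above:

    # Worked example (hypothetical numbers) of how a non-leaf answer is scored and chosen.
    cot_process_logprob = -0.20           # avg logprob of the aggregation chain-of-thought
    qd_score = -0.10                      # avg logprob of the question decomposition
    sub_answer_scores = [-0.30, -0.25]    # scores already attached to the two children

    child_score = (cot_process_logprob + qd_score + sum(sub_answer_scores)) / (len(sub_answer_scores) + 2)
    print(round(child_score, 4))          # -0.2125

    candidates = [("cb answer", -0.40), ("ob answer", -0.35), ("child answer", child_score)]
    print(max(candidates, key=lambda x: x[1]))   # ('child answer', -0.2125)
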
/src/hotpotqa/RoHT/search/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-KEG/ProbTree/ce17f5f239c47389ab53920bfef817f8a18e3841/src/hotpotqa/RoHT/search/__init__.py
--------------------------------------------------------------------------------
/src/hotpotqa/RoHT/search/serpapi.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import json
4 | from typing import Dict
5 |
6 |
7 | from IPython.utils import io
8 | from serpapi import GoogleSearch
9 |
10 | from search.wikipedia import get_wikipedia_text
11 |
12 |
13 | READ_CACHE = True
14 | CACHE_DIR = './serpapi_cache'
15 |
16 | def google(question):
17 | #print(f"Asking google: {question}")
18 |
19 | params = {
20 | "api_key": os.getenv("SERP_API_KEY"),
21 | "engine": "google",
22 | "q": question,
23 | "google_domain": "google.com",
24 | "gl": "us",
25 | "hl": "en",
26 | }
27 |
28 | with io.capture_output() as captured: # disables prints from GoogleSearch
29 |
30 | search = GoogleSearch(params)
31 | res = search.get_dict()
32 |
33 | answer = None
34 | snippet = None
35 | title = None
36 |
37 | if "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
38 | answer = res["answer_box"]["answer"]
39 | if "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
40 | snippet = res["answer_box"]["snippet"]
41 | title = res["answer_box"]["title"]
42 | # elif 'answer_box' in res.keys() and 'snippet_highlighted_words' in res['answer_box'].keys():
43 | # toret = res['answer_box']["snippet_highlighted_words"][0]
44 | elif (
45 | "answer_box" in res.keys()
46 | and "contents" in res["answer_box"].keys()
47 | and "table" in res["answer_box"]["contents"].keys()
48 | ):
49 | snippet = res["answer_box"]["contents"]["table"]
50 | title = res["answer_box"]["title"]
51 | elif "answer_box" in res.keys() and "list" in res["answer_box"].keys():
52 | snippet = res["answer_box"]["list"]
53 | title = res["answer_box"]["title"]
54 | elif "organic_results" in res and "snippet" in res["organic_results"][0].keys():
55 | snippet = res["organic_results"][0]["snippet"]
56 | title = res["organic_results"][0]["title"]
57 | elif (
58 | "organic_results" in res
59 | and "rich_snippet_table" in res["organic_results"][0].keys()
60 | ):
61 | snippet = res["organic_results"][0]["rich_snippet_table"]
62 | title = res["organic_results"][0]["title"]
63 | else:
64 | snippet = None
65 | if snippet is not None:
66 | title = title.replace("- Wikipedia", "").strip()
67 | toret = f"{title}: {snippet}"
68 | toret = f"{toret} So the answer is {answer}." if answer is not None else toret
69 | else:
70 | toret = ""
71 | return [toret, res]
72 |
73 |
74 | def get_sentences(text, max_num_sentences, reverse=None):
75 | if text == "":
76 | return text
77 | sentences = text.split(". ")
78 | ret_sentences = ""
79 | actual_num_sentences = min(max_num_sentences, len(sentences))
80 | if reverse:
81 | for i in reversed(range(actual_num_sentences)):
82 | ret_sentences += f"{sentences[i]}. "
83 | else:
84 | for i in range(actual_num_sentences):
85 | ret_sentences += f"{sentences[i]}. "
86 | return ret_sentences.strip()
87 |
88 |
89 | def get_first_sentences(text, max_num_sentences):
90 | return get_sentences(text, max_num_sentences)
91 |
92 |
93 | def get_last_sentences(text, max_num_sentences):
94 | return get_sentences(text, max_num_sentences, reverse=True)
95 |
96 |
97 | def get_snippet_wiki_paragraph(wikipage_title, snippet):
98 | full_wikipage_text = get_wikipedia_text(wikipage_title)
99 | print(f"Wikipedia title: {wikipage_title}")
100 | print(f"Google snippet: {snippet}")
101 | try:
102 | assert snippet in full_wikipage_text
103 | text_before_snippet, text_after_snippet = full_wikipage_text.split(snippet)
104 | prev_sentences = get_last_sentences(text_before_snippet, 5).strip()
105 | next_sentences = get_first_sentences(text_after_snippet, 5).strip()
106 | return f"{prev_sentences} {snippet} {next_sentences}"
107 | except AssertionError:
108 | print("* Unable to find snippet in Wikipedia text, return original snippet.")
109 | return snippet
110 |
111 |
112 | def get_question_wiki_snippet(question, cache=None):
113 | # google_wikipedia_query = f"site:en.wikipedia.org '{question}'"
114 | try:
115 | cached_query_results = read_google_res_from_cache(query=question)
116 | snippet = cached_query_results["snippet"]
117 | if "error" in cached_query_results["full_results"] and "Your searches for the month are exhausted" in cached_query_results["full_results"]["error"]:
118 | raise IOError()
119 | # print(f"Read from cache for query: {question}, cached snippet: {snippet}")
120 | except IOError:
121 | #return ""
122 | print(question)
123 | google_wikipedia_query = (
124 | f"en.wikipedia.org {question}" # same as Ori's query format
125 | )
126 | snippet, full_results = google(google_wikipedia_query)
127 | # print(f"full_results: {full_results}")
128 | # print(f"snippet: {snippet}")
129 | if cache:
130 | print(f"Caching snippet: {snippet}")
131 | cache_google_res(
132 | question, {"snippet": snippet, "full_results": full_results}
133 | )
134 | clean_snippet = snippet.replace("...", "").strip()
135 | return clean_snippet
136 |
137 |
138 | def get_question_google_snippet(question, cache=None):
139 | # google_wikipedia_query = f"site:en.wikipedia.org '{question}'"
140 | try:
141 | cached_query_results = read_google_res_from_cache(query=question)
142 | snippet = cached_query_results["snippet"]
143 | #print(f"Read from cache for query: {question}, cached snippet: {snippet}")
144 | except IOError:
145 | google_wikipedia_query = f"{question}" # same as Ori's query format
146 | snippet, full_results = google(google_wikipedia_query)
147 | # print(f"full_results: {full_results}")
148 | # print(f"snippet: {snippet}")
149 | if cache:
150 | #print(f"Caching snippet: {snippet}")
151 | cache_google_res(
152 | question, {"snippet": snippet, "full_results": full_results}
153 | )
154 | clean_snippet = snippet.replace("...", "").strip()
155 | return clean_snippet
156 |
157 |
158 | def get_string_hash(query: str) -> str:
159 | return hashlib.md5(query.encode()).hexdigest()
160 |
161 |
162 | def cache_google_res(query: str, res: Dict) -> None:
163 | """"""
164 | filename = get_string_hash(query)
165 | retriever_cache_dir = CACHE_DIR
166 | with open(f"{retriever_cache_dir}/{filename}.json", "w") as json_file:
167 | json.dump(res, json_file)
168 | # with open(f"strategy_qa/google_results_2/{filename}.json", "w") as json_file:
169 | # json.dump(res, json_file)
170 |
171 |
172 | def read_google_res_from_cache(query: str) -> Dict:
173 | filename = get_string_hash(query)
174 | retriever_cache_dir = CACHE_DIR
175 | with open(f"{retriever_cache_dir}/{filename}.json", "r") as f:
176 | data = json.load(f)
177 | # with open(f"strategy_qa/google_results_2/{filename}.json", "r") as f:
178 | # data = json.load(f)
179 | return data
180 |
--------------------------------------------------------------------------------
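The SerpAPI cache above is keyed by the md5 of the plain question: each query is stored as serpapi_cache/<md5>.json with "snippet" and "full_results" fields. A small sketch for looking up one cached entry (the question string is hypothetical):

    import hashlib, json, os

    question = "Who is the director of film Hypocrite (Film)?"   # hypothetical query
    path = os.path.join("./serpapi_cache", hashlib.md5(question.encode()).hexdigest() + ".json")
    if os.path.exists(path):
        print(json.load(open(path))["snippet"])
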
/src/hotpotqa/RoHT/search/wikipedia.py:
--------------------------------------------------------------------------------
1 | import wikipedia
2 | import re
3 |
4 |
5 | def get_wikipedia_text(page_name):
6 | def clean(text):
7 | text = re.sub(r"==.*?==+", "", text)
8 | text = text.replace("\n", "")
9 | return text
10 |
11 | def remove_parentheses_from_first_sent(text):
12 | """remove first sentence parentheses as the info contains redundant spelling info"""
13 |         # NOTE: stub is not implemented and currently unused
14 | # Specify the title of the Wikipedia page
15 | wiki = wikipedia.page(page_name)
16 | # Extract the plain text content of the page
17 | page_text = wiki.content
18 | return clean(page_text)
19 |
--------------------------------------------------------------------------------
/src/hotpotqa/Tree_Generation/0_get_prompt.py:
--------------------------------------------------------------------------------
1 | import json, jsonlines
2 |
3 | instruction = '\n'.join([_.strip() for _ in open('prompt.txt').readlines()])
4 |
5 | raw_data = jsonlines.open("/data/zjj/LLMReasoning/released_data/hotpotqa__v2_test_random_500.jsonl", "r")
6 |
7 | prompts = []
8 | for item in raw_data:
9 | question = item["question_text"].strip()
10 | prompt = instruction + '\nQ: ' + question + '\nA:'
11 | prompts.append(prompt)
12 | # print(prompt)
13 | # break
14 |
15 | json.dump(prompts, open('prompts.json', 'w'), indent = 2)
16 | print(len(prompts))
--------------------------------------------------------------------------------
/src/hotpotqa/Tree_Generation/1_query.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | from openai_req import OpenaiReq
4 | import random
5 | from tqdm import tqdm
6 | import os
7 | from multiprocessing import Pool
8 | from termcolor import colored
9 | random.seed(42)
10 |
11 | MAX_SPLIT = 64
12 | STEP = 4
13 |
14 | def query(rank, prompts):
15 | print('Process rank {} PID {} begin...'.format(rank, os.getpid()))
16 | reqor = OpenaiReq()
17 | queries = prompts[int(len(prompts) * rank / MAX_SPLIT) : int(len(prompts) * (rank + 1) / MAX_SPLIT)]
18 | try:
19 | fout = open('outputs/rank_{}.json'.format(rank), 'w')
20 | if rank == 0:
21 | bar = tqdm(range(len(queries) // STEP + 1))
22 | else:
23 | bar = range(len(queries) // STEP + 1)
24 | for idx in bar:
25 | inputs = queries[idx * STEP : (idx + 1) * STEP]
26 | if len(inputs) == 0:
27 | break
28 | gpt_results = []
29 | for prompt in inputs:
30 | result, tag = reqor.req2openai(prompt, max_tokens = 512, stop = '\n\n')
31 | gpt_results.append(result[0])
32 | for prompt, res in zip(inputs, gpt_results):
33 | # print(res)
34 | fout.write(json.dumps({'prompt': prompt, 'response': res}) + '\n')
35 | fout.flush()
36 | fout.close()
37 | except Exception as err:
38 | print(Exception, err)
39 |
40 | if __name__=='__main__':
41 | prompts = json.load(open('prompts.json'))
42 | os.makedirs("outputs", exist_ok=False)
43 | print("number of prompts: {}".format(len(prompts)))
44 | print('Parent process %s.' % os.getpid())
45 | p = Pool(MAX_SPLIT)
46 | for i in range(MAX_SPLIT):
47 | p.apply_async(query, args=(i, prompts))
48 | print('Waiting for all subprocesses done...')
49 | p.close()
50 | p.join()
51 | print('All subprocesses done.')
--------------------------------------------------------------------------------
/src/hotpotqa/Tree_Generation/2_postprocess.py:
--------------------------------------------------------------------------------
1 | import json
2 | from tqdm import tqdm
3 | from termcolor import colored
4 | import os
5 |
6 | def findAllFile(base):
7 | for root, ds, fs in os.walk(base):
8 | for f in fs:
9 | yield f
10 | base = './outputs'
11 | data = []
12 | for file_name in findAllFile(base):
13 | data += [json.loads(line.strip()) for line in open(os.path.join(base, file_name))]
14 | # data.update(json.load(open(os.path.join(base, file_name))))
15 | print(len(data))
16 | json.dump(data, open(os.path.join(base, 'predictions.json'), 'w'), indent = 2, ensure_ascii=False)
17 |
18 | raw_data = json.load(open('outputs/predictions.json'))
19 |
20 | data = {}
21 | for item in tqdm(raw_data):
22 | prompt = item['prompt']
23 | question = prompt.split('\n')[-2][len('Q: '):].strip()
24 | print(colored(question, 'red'))
25 | # print(item['response']['text'])
26 | try:
27 | qds = item['response']['text'].strip()
28 | if qds.endswith('.'):
29 | qds = qds[:-1]
30 | # print(qds)
31 | # if question.startswith('Who is the actress who plays the role of the Queen of Eng'):
32 | # continue
33 | hqdt = json.loads(qds)
34 | except:
35 | hqdt = None
36 |
37 |     if hqdt is None:
38 |         # decomposition could not be parsed as JSON; skip this question
39 |         continue
40 | tokens = item['response']['logprobs']['tokens']
41 | token_logprobs = item['response']['logprobs']['token_logprobs']
42 | if len(token_logprobs) == 0:
43 | continue
44 |
45 | if tokens[-1] == '.':
46 | token_logprobs = token_logprobs[:-1]
47 | # print(answer_logprobs)
48 | # else:
49 | # answer_logprobs = token_logprobs[pos+6:]
50 |
51 | # print(tokens[pos+6:-1])
52 |
53 | st, ed = 0, 0
54 | pos = 0
55 | qds = {}
56 | for sub_question, qd in hqdt.items():
57 | while pos < len(tokens):
58 | #print("".join(tokens[max(pos-1, 0): min(pos+2, len(tokens))]))
59 | if "[" in tokens[pos] and ": [\"" in "".join(tokens[max(pos-1, 0): min(pos+2, len(tokens))]):
60 | st = pos
61 | break
62 | pos += 1
63 | while pos < len(tokens):
64 | if "]" in tokens[pos] and "\"]" in "".join(tokens[max(pos-1, 0): min(pos+2, len(tokens))]):
65 | ed = pos
66 | break
67 | pos += 1
68 | assert pos < len(tokens), question + ' | ' + str(st) + " | " + str(ed)
69 | qd_score = sum(token_logprobs[st:ed+1]) / len(token_logprobs[st:ed+1])
70 | if any([x == sub_question for x in qd]):
71 | qd, qd_score = [], None
72 | qds[sub_question] = (qd, qd_score)
73 | print(colored(sub_question, 'blue'))
74 | print("".join(tokens[st:ed+1]))
75 |
76 |
77 | # answer_logprob = sum(token_logprobs) / len(token_logprobs)
78 | # data[question] = [hqdt, answer_logprob]
79 | data[question] = qds
80 | json.dump(data, open('question_decompositions.json', 'w'), indent = 2)
--------------------------------------------------------------------------------
/src/hotpotqa/Tree_Generation/3_postprocess_tree.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | raw_data = json.load(open('question_decompositions.json'))
4 |
5 | def check(question):
6 |     # a question is invalid if it still contains an unresolved reference placeholder
7 |     return '<1>' in question or '<2>' in question or '<3>' in question or '<4>' in question
8 | tree = {}
9 | for father in raw_data:
10 | if check(father):
11 | print(father)
12 | continue
13 | qds = raw_data[father]
14 | if qds is None:
15 | continue
16 | tree[father] = {}
17 | for question in qds:
18 | if check(question):
19 | continue
20 | if any([x == question for x in qds[question][0]]):
21 | tree[father][question] = [[], None]
22 | else:
23 | tree[father][question] = qds[question]
24 | # if len(qds[question]) > 3:
25 | # print(father)
26 | # print(qds[question])
27 | # print('haha')
28 |
29 | # json.dump(tree, open('valid_tree.json', 'w'), indent = 2)
30 | print(len(tree))
31 | question_decompositions = {}
32 | for father in tree:
33 | qds = tree[father]
34 | for q in qds:
35 | if q not in question_decompositions:
36 | question_decompositions[q] = qds[q]
37 | else:
38 | if question_decompositions[q] != qds[q]:
39 | print(question_decompositions[q])
40 | print(qds[q])
41 | else:
42 | print('haha')
43 |
44 | json.dump(question_decompositions, open('tree.json', 'w'), indent = 2)
45 |
46 | print(len(tree))
--------------------------------------------------------------------------------
/src/hotpotqa/Tree_Generation/combine.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | def findAllFile(base):
5 | for root, ds, fs in os.walk(base):
6 | for f in fs:
7 | yield f
8 | base = './outputs'
9 | data = []
10 | for file_name in findAllFile(base):
11 | data += [json.loads(line.strip()) for line in open(os.path.join(base, file_name))]
12 | # data.update(json.load(open(os.path.join(base, file_name))))
13 | print(len(data))
14 | json.dump(data, open(os.path.join(base, 'predictions.json'), 'w'), indent = 2)
--------------------------------------------------------------------------------
/src/hotpotqa/Tree_Generation/openai_req.py:
--------------------------------------------------------------------------------
1 | import openai
2 | import requests
3 | import time
4 | import os
5 | import json, jsonlines
6 |
7 | class OpenaiReq():
8 | def __init__(self):
9 | self.url = "http://127.0.0.1:10001/api/openai/completion"
10 | self.cache = {}
11 | self.cache_path = "./cache.jsonl"
12 | if os.path.exists(self.cache_path):
13 | with open(self.cache_path, "r") as f:
14 | for i, line in enumerate(f):
15 | #print(i+1)
16 | datum = json.loads(line.strip())
17 | self.cache[tuple(datum["input"])] = datum["response"]
18 | f.close()
19 |
20 | def req2openai(self, prompt, model="text-davinci-003", temperature=0, max_tokens=128, stop=None, logprobs=1, use_cache=True):
21 | assert isinstance(prompt, str)
22 | input = (prompt, model, max_tokens, stop, logprobs)
23 | if use_cache and temperature == 0 and input in self.cache:
24 | return self.cache[input], True
25 | for i in range(3):
26 | try:
27 | response = requests.post(self.url, json = {
28 | "model": model,
29 | "prompt": prompt,
30 | "temperature": temperature,
31 | "max_tokens": max_tokens,
32 | "stop": stop,
33 | "logprobs": logprobs,
34 | })
35 | if response.status_code != 200:
36 | raise Exception(response.text)
37 | break
38 | except Exception as e:
39 | err_msg = str(e)
40 | print(e)
41 | if "reduce your prompt" in err_msg: # this is because the input string too long
42 | return ['too long'], False
43 | try:
44 | response = response.json()['choices']
45 | except:
46 | return ['openai error'], False
47 | if temperature == 0:
48 | input = (prompt, model, max_tokens, stop, logprobs)
49 | res = response[0]
50 | if input not in self.cache:
51 | self.cache[input] = [res]
52 | with open(self.cache_path, "a") as f:
53 | f.write("%s\n"%json.dumps({"input": input, "response": [res]}))
54 | f.close()
55 | return response, True
56 |
57 | if __name__ == "__main__":
58 | caller = OpenaiReq()
59 | res = caller.req2openai("你好", use_cache=True)
60 | print(res)
61 |
62 |
--------------------------------------------------------------------------------
/src/hotpotqa/Tree_Generation/prompt.txt:
--------------------------------------------------------------------------------
1 | Please generate a hierarchical question decomposition tree (HQDT) in JSON format for a given question. In this tree, the root node is the original complex question, and each non-root node is a sub-question of its parent. The leaf nodes are atomic questions that cannot be further decomposed.
2 | Q: Jeremy Theobald and Christopher Nolan share what profession?
3 | A: {"Jeremy Theobald and Christopher Nolan share what profession?": ["What is Jeremy Theobald's profession?", "What is Christopher Nolan's profession?"]}.
4 | Q: How many episodes were in the South Korean television series in which Ryu Hye−young played Bo−ra?
5 | A: {"How many episodes were in the South Korean television series in which Ryu Hye−young played Bo−ra?": ["In which South Korean television series Ryu Hye−young played Bo−ra?", "How many episodes were <1>?"]}.
6 | Q: Vertical Limit stars which actor who also played astronaut Alan Shepard in "The Right Stuff"?
7 | A: {"Vertical Limit stars which actor who also played astronaut Alan Shepard in \"The Right Stuff\"?": ["Vertical Limit stars which actor?", "Which actor played astronaut Alan Shepard in \"The Right Stuff\"?"]}.
8 | Q: What was the 2014 population of the city where Lake Wales Medical Center is located?
9 | A: {"What was the 2014 population of the city where Lake Wales Medical Center is located?": ["Which city was Lake Wales Medical Center located in?", "What was the 2014 population of <1>?"]}.
10 | Q: Who was born first? Jan de Bont or Raoul Walsh?
11 | A: {"Who was born first? Jan de Bont or Raoul Walsh?": ["When was Jan de Bont born?", "When was Raoul Walsh born?"]}.
12 | Q: In what country was Lost Gravity manufactured?
13 | A: {"In what country was Lost Gravity manufactured?": ["Which company was Lost Gravity manufactured?", "Which country is <1> in?"]}.
14 | Q: Which of the following had a debut album entitled "We Have an Emergency": Hot Hot Heat or The Operation M.D.?
15 | A: {"Which of the following had a debut album entitled \"We Have an Emergency\": Hot Hot Heat or The Operation M.D.?": ["What is the debut album of the band Hot Hot Heat?", "What is the debut album of the band The Operation M.D.?"]}.
16 | Q: In which country did this Australian who was detained in Guantanamo Bay detention camp and published "Guantanamo: My Journey" receive para−military training?
17 | A: {"In which country did this Australian who was detained in Guantanamo Bay detention camp and published \"Guantanamo: My Journey\" receive para−military training?": ["Which Australian was detained in Guantanamo Bay detention camp and published \"Guantanamo: My Journey\"?", "In which country did <1> receive para−military training?"]}.
18 | Q: Does The Border Surrender or Unsane have more members?
19 | A: {"Does The Border Surrender or Unsane have more members?": ["How many members does The Border Surrender have?", "How many members does Unsane have?"]}.
20 | Q: James Paris Lee is best known for investing the Lee−Metford rifle and another rifle often referred to by what acronymn?
21 | A: {"James Paris Lee is best known for investing the Lee−Metford rifle and another rifle often referred to by what acronymn?": ["James Paris Lee is best known for investing the Lee−Metford rifle and which other rifle?", "<1> is often referred to by what acronymn?"]}.
22 | Q: What year did Edburga of Minster−in−Thanet's father die?
23 | A: {"What year did Edburga of Minster−in−Thanet's father die?": ["Who is Edburga of Minster−in−Thanet's father?", "What year did <1> die?"]}.
24 | Q: Were Lonny and Allure both founded in the 1990s?
25 | A: {"Were Lonny and Allure both founded in the 1990s?": ["When was Lonny (magazine) founded?", "When was Allure founded?"]}.
26 | Q: The actor that stars as Joe Proctor on the series "Power" also played a character on "Entourage" that has what last name?
27 | A: {"The actor that stars as Joe Proctor on the series \"Power\" also played a character on \"Entourage\" that has what last name?": ["Which actor stars as Joe Proctor on the series \"Power\"?", "<1> played a character on \"Entourage\" that has what last name?"]}.
28 | Q: How many awards did the "A Girl Like Me" singer win at the American Music Awards of 2012?
29 | A: {"How many awards did the \"A Girl Like Me\" singer win at the American Music Awards of 2012?": ["Who is the singer of \"A Girl Like Me\"?", "How many awards did <1> win at the American Music Awards of 2012?"]}.
30 | Q: Dadi Denis studied at a Maryland college whose name was changed in 1890 to honor what man?
31 | A: {"Dadi Denis studied at a Maryland college whose name was changed in 1890 to honor what man?": ["Dadi Denis studied at which Maryland college?", "<1>'s name was changed in 1890 to honor what man?"]}.
32 | Q: William Orman Beerman was born in a city in northeastern Kansas that is the county seat of what county?
33 | A: {"William Orman Beerman was born in a city in northeastern Kansas that is the county seat of what county?": ["In which city in northeastern Kansas William Orman Beerman was born?", "<1> is the county seat of what county?"]}.
--------------------------------------------------------------------------------
/src/musique/0_generate_tree.sh:
--------------------------------------------------------------------------------
1 | cd ./Tree_Generation
2 | python 0_get_prompt.py
3 | python 1_query.py
4 | python combine.py
5 | python 2_postprocess.py
6 | python 3_postprocess_tree.py
--------------------------------------------------------------------------------
/src/musique/1_conduct_reasoning.sh:
--------------------------------------------------------------------------------
1 | cd ./RoHT
2 | python 1_build_tree.py
3 | python 2_run.py
4 | python 3_get_f1.py
--------------------------------------------------------------------------------
/src/musique/RoHT/1_build_tree.py:
--------------------------------------------------------------------------------
1 | import json
2 | from collections import defaultdict
3 |
4 | raw_data = [json.loads(line.strip()) for line in open('../../../released_data/musique_ans__v2_test_random_500.jsonl')]
5 | q2sub_q = json.load(open("../Tree_Generation/tree.json"))
6 | q2dq = json.load(open("../Tree_Generation/question_decompositions.json"))
7 |
8 | trees = []
9 |
10 | def dfs(q, tree):
11 | sons = []
12 | print(q)
13 | for sub_q in q2sub_q.get(q, [[]])[0]:
14 | son_idx = dfs(sub_q, tree)
15 | sons.append(son_idx)
16 | idx = len(tree)
17 | tree.append({
18 | "idx": idx,
19 | "question_text": q,
20 | "sons": sons,
21 | "qd_logprob": q2sub_q.get(q, [[], None])[1]
22 | })
23 | for son_idx in sons:
24 | tree[son_idx]["fa"] = idx
25 | return idx
26 |
27 | for item in raw_data:
28 | question = item['question_text'].strip()
29 | question = list(q2dq[question].keys())[0]
30 | assert question in q2sub_q, question
31 | tree = []
32 | dfs(question, tree)
33 | trees.append(tree)
34 |
35 | json.dump(trees, open("trees.json", "w"), indent=2)
36 |
37 |
38 |
39 |
40 |
--------------------------------------------------------------------------------
/src/musique/RoHT/2_run.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | import os
4 | from question_answering import *
5 | from tqdm import tqdm
6 | from parallel import parallel_process_data
7 |
8 | PROC_NUM = 50
9 | cnt = 0
10 |
11 | def solve(tree):
12 | global cnt
13 | cnt += 1
14 | print(cnt)
15 | #print(tree[-1])
16 | try:
17 | for node in tree:
18 | #print(node)
19 | question = node["question_text"].strip()
20 | ref_tokens = re.findall(r"#\d+", question)
21 | topic_entities = []
22 | for ref_token in ref_tokens:
23 | if "fa" in node and int(ref_token[1:]) <= len(tree[node["fa"]]["sons"]):
24 | ref_idx = tree[node["fa"]]["sons"][int(ref_token[1:])-1]
25 | if "answer" in tree[ref_idx]:
26 | question = question.replace(ref_token, tree[ref_idx]["answer"][0])
27 | topic_entities.append(tree[ref_idx]["answer"][0])
28 | node["question"] = question
29 | node["cb_answer"] = get_cb_answer(question)
30 | #print(node["cb_answer"])
31 | if len(node["sons"]) == 0:
32 | node["ob_answer"] = get_singlehop_ob_answer(question, topic_entities)
33 | #print(node["ob_answer"])
34 | node["answer"] = aggregate_singlehop_answer(node["cb_answer"], node["ob_answer"])
35 | else:
36 | node["ob_answer"] = get_multihop_ob_answer(node, tree)
37 | #print(node["ob_answer"])
38 | node["child_answer"], node["answer"] = aggregate_multihop_answer(node, tree)
39 | #print(node)
40 | except Exception as e:
41 | print("ERROR CASE")
42 | print(tree[-1])
43 | raise e
44 |
45 |
46 | trees = json.load(open("trees.json", "r"))
47 | print("Total: %d | Start Processing..."%len(trees))
48 | parallel_process_data(trees, solve, PROC_NUM)
49 |
50 |
51 | print("END")
52 | os.makedirs("results", exist_ok=True)
53 | json.dump(trees, open("results/test.json", "w"), indent=2)
--------------------------------------------------------------------------------
/src/musique/RoHT/3_get_f1.py:
--------------------------------------------------------------------------------
1 | import json
2 | import math
3 |
4 | from tqdm import tqdm
5 | from termcolor import colored
6 | from evaluate import update_answer
7 |
8 |
9 |
10 | q2a = {}
11 | raw_data = [json.loads(line.strip()) for line in open('../../../released_data/musique_ans__v2_test_random_500.jsonl')]
12 | q2dq = json.load(open("../Tree_Generation/question_decompositions.json"))
13 | q2gold = {}
14 | for item in raw_data:
15 | question = item['question_text'].strip()
16 | question = list(q2dq[question].keys())[0]
17 | gold = item['answers_objects'][0]['spans'][0]
18 | q_type = item["question_id"].split("hop")[0]+"hop"
19 | q2gold[question] = (gold, q_type)
20 |
21 | trees = json.load(open("./results/test.json", "r"))
22 | metrics = {}
23 | for q_type in ["all", "2hop", "3hop", "4hop"]:
24 | metrics[q_type] = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0, 'N': 0}
25 |
26 | print(len(trees))
27 | for i, tree in enumerate(trees):
28 | node = tree[-1]
29 | question, answer = node["question"], node["answer"][0]
30 | q2a[question] = answer
31 | gold, q_type = q2gold[question]
32 | em, f1, prec, recall = update_answer(metrics["all"], answer, gold)
33 | update_answer(metrics[q_type], answer, gold)
34 | if f1 == 0:
35 | print(colored(question, 'red'))
36 | print(colored(gold, 'blue'))
37 | print(answer)
38 |
39 | for q_type in ["all", "2hop", "3hop", "4hop"]:
40 | print(q_type)
41 | print(metrics[q_type]['N'])
42 |
43 |     for k in ['em', 'f1', 'prec', 'recall']:  # keep N as the raw count; guard against empty hop types
44 |         metrics[q_type][k] /= max(metrics[q_type]['N'], 1)
45 | print(metrics[q_type])
46 |
47 |
48 | json.dump(q2a, open("q2a.json", "w"), indent=2)
--------------------------------------------------------------------------------
/src/musique/RoHT/aggregate/prompt.txt:
--------------------------------------------------------------------------------
1 | Given a question and a context, answer the question and explain why.
2 |
3 | #
4 | Context:
5 | Which famous fashion show Stella Maxwell has been a model for? Victoria's Secret.
6 | Since when Victoria's Secret? 1977.
7 |
8 | Question:
9 | Stella Maxwell has been a model for a famous fashion show since when?
10 |
11 | Answer:
12 | Stella Maxwell has been a model for a famous fashion show, Victoria's Secret, since 2015. So the answer is: since 2015.
13 | #
14 | Context:
15 | Which athlete rode 400 miles across his country to bring attention to the plight of the disabled in the country? Emmanuel Ofosu Yeboah.
16 | What is the title of the documentary narrated by Oprah Winfrey about Emmanuel Ofosu Yeboah? Emmanuel's Gift.
17 |
18 | Question:
19 | Oprah Winfrey narrated a documentary about this athlete who rode 400 miles across his country to bring attention to the plight of the disabled in the country?
20 |
21 | Answer:
22 | Oprah Winfrey narrated a documentary about the athlete Emmanuel Ofosu Yeboah, who rode 400 miles across his country to bring attention to the plight of the disabled in the country. So the answer is: Emmanuel Ofosu Yeboah.
23 | #
24 | Context:
25 | Where is Phu Luong located? Vietnam.
26 | Which country is Vietnam in? Southeast Asia.
27 |
28 | Question:
29 | Which country contains Phu Luong?
30 |
31 | Answer:
32 | Phu Luong is located in the country Vietnam. So the answer is: Vietnam.
33 | #
--------------------------------------------------------------------------------
/src/musique/RoHT/cb/prompt.txt:
--------------------------------------------------------------------------------
1 | Please answer the question by thinking step-by-step.
2 | Q: When did the first large winter carnival take place in the city where CIMI−FM is licensed to broadcast?
3 | A: CIMI−FM is licensed to broadcast in Quebec City. The first large winter carnival in Quebec City took place in 1894. So the answer is: 1894.
4 | Q: When was Neville A. Stanton's employer founded?
5 | A: The employer of Neville A. Stanton is University of Southampton. The University of Southampton was founded in 1862. So the answer is: 1862.
6 | Q: What religion did the black community found?
7 | A: The black community found African Methodist Episcopal Church. So the answer is: African Methodist Episcopal Church.
8 | Q: What county is Hebron located in, in the same province the Heritage Places Protection Act applies to?
9 | A: Heritage Places Protection Act applies to the jurisdiction of Prince Edward Island. Hebron, Prince Edward Island is located in the Prince County. So the answer is: Prince County.
10 | Q: What weekly publication in the Connecticut city with the most Zagat rated restaurants is issued by university of America−Lite: How Imperial Academia Dismantled Our Culture's author?
11 | A: The author of America−Lite: How Imperial Academia Dismantled Our Culture is David Gelernter. David Gelernter was educated at the Yale University. The city in Connecticut that has the highest number of Zagat−rated restaurants is New Haven. The weekly publication in New Haven that is issued by Yale University is Yale Herald. So the answer is: Yale Herald.
12 | Q: What is the headquarters for the organization who sets the standards for ISO 21500?
13 | A: The standards for ISO 21500 were set by International Organization for Standardization. The International Organization for Standardization has headquarters in Geneva. So the answer is: Geneva.
14 | Q: What did the publisher of Banjo−Tooie rely primarily on for its support?
15 | A: The publisher of Banjo−Tooie is Nintendo. Nintendo relied primarily for its support on first−party games. So the answer is: first−party games.
16 | Q: The Collegian was owned by?
17 | A: The Collegian was owned by Houston Baptist University. So the answer is: Houston Baptist University.
18 | Q: In which county was the birthplace of the Smoke in tha City performer?
19 | A: The performer of Smoke in tha City is MC Eiht. MC Eiht's birthplace is Compton. Compton is located in the county of Los Angeles County. So the answer is: Los Angeles County.
20 | Q: What region of the state where Guy Shepherdson was born, contains SMA Negeri 68?
21 | A: Guy Shepherdson was born in Jakarta. SMA Negeri 68 Jakarta is located in Central Jakarta. So the answer is: Central Jakarta.
22 | Q: When did Britain withdraw from the country containing Hoora?
23 | A: Hoora is in the country of Bahrain. Britain withdrew from Bahrain in 1971. So the answer is: 1971.
24 | Q: Where does the Snake River start, in the state where Lima Mountain is located?
25 | A: Lima Mountain is located in the state of Minnesota. The snake river in Minnesota starts in southern Aitkin County. So the answer is: southern Aitkin County.
26 | Q: What shares a border with Rivière−Verte in the province WRSU−FM broadcasts in?
27 | A: WRSU−FM was licensed to broadcast to New Brunswick. Rivière−Verte, New Brunswick shares a border with Edmundston. So the answer is: Edmundston.
28 | Q: When was the state of emergency declared in the country where the Senate is located?
29 | A: The Senate is in the country of Kenya. The state of emergency was declared in Kenya on 20 October 1952. So the answer is: 20 October 1952.
30 | Q: How long is the US border with the country that borders the state where Finding Dory takes place?
31 | A: Finding Dory is supposed to take place in California. The country that shares a border with California is Mexico. The length of the us border with Mexico is 1,989 mi. So the answer is: 1,989 mi.
32 | Q: What genre is the record label of the performer of So Long, See You Tomorrow associated with?
33 | A: The performer of So Long, See You Tomorrow is Bombay Bicycle Club. The record label of Bombay Bicycle Club is Island Records. The genre of Island Records is jazz. So the answer is: jazz.
34 | Q: When did the first large winter carnival happen in Olivier Robitaille's place of birth?
35 | A: Olivier Robitaille was born in Quebec City. The first large winter carnival in Quebec City happened in the 1894. So the answer is: 1894.
36 | Q: What is the genre of the record label of the band that performed on the Crush Tour?
37 | A: The Crush Tour is performed by the band Bon Jovi. The record label of Bon Jovi is Island Records. The genre of Island Records is jazz. So the answer is: jazz.
38 | Q: When was the first railway line constructed between Kotri and the city where Marie Adelaide Leprosy Centre is located?
39 | A: Marie Adelaide Leprosy Centre is located in Karachi. The first railway line between Kotri and Karachi was constructed in April 1858. So the answer is: April 1858.
40 | Q: In which state is Hertfordshire located?
41 | A: Hertfordshire is located in the state East of England. So the answer is: East of England.
42 | Q: Where is the crying stone found in the country in which Raphael Tuju holds citizenship?
43 | A: Raphael Tuju is a citizen of Kenya. The crying stone in Kenya is found along the highway towards Kisumu. So the answer is: along the highway towards Kisumu.
44 | Q: When did Britain withdraw from the country where the village of Wadyan is found?
45 | A: Wadyan is in the country of Bahrain. Britain withdrew from Bahrain in 1971. So the answer is: 1971.
46 | Q: How many countries in Pacific National University's continent are recognized by the organization that mediated the truce ending the Iran−Iraq war?
47 | A: Pacific National University is located in Khabarovsk, Russia. Russia is in the continent of Asia. The organization that mediated the truce which ended the Iran−Iraq War is the UN. The number of member states that the UN recognises in Asia is 53. So the answer is: 53.
--------------------------------------------------------------------------------
/src/musique/RoHT/count.py:
--------------------------------------------------------------------------------
1 | import json
2 | from collections import defaultdict
3 | trees = json.load(open("./results/test.json", "r"))
4 | cnt = defaultdict(int)
5 | total = 0
6 | for tree in trees:
7 | for node in tree:
8 | if "child_answer" in node:
9 | if node["answer"][1] == node["cb_answer"][1]:
10 | cnt["non_leaf_cb"] += 1
11 | elif node["answer"][1] == node["ob_answer"][1]:
12 | cnt["non_leaf_ob"] += 1
13 | else:
14 | cnt["non_leaf_ca"] += 1
15 | else:
16 | if node["answer"][1] == node["cb_answer"][1]:
17 | cnt["leaf_cb"] += 1
18 | else:
19 | cnt["leaf_ob"] += 1
20 | total += 1
21 |
22 | print(cnt)
23 | keys = ["leaf_ob", "leaf_cb"]
24 | print("leaf_cb: ", cnt["leaf_cb"], cnt["leaf_cb"] / (cnt["leaf_ob"] + cnt["leaf_cb"]))
25 | print("leaf_ob: ", cnt["leaf_ob"], cnt["leaf_ob"] / (cnt["leaf_ob"] + cnt["leaf_cb"]))
26 |
27 | print("non_leaf_cb:", cnt["non_leaf_cb"], cnt["non_leaf_cb"] / (cnt["non_leaf_ob"] + cnt["non_leaf_cb"] + cnt["non_leaf_ca"]))
28 | print("non_leaf_ob:", cnt["non_leaf_ob"], cnt["non_leaf_ob"] / (cnt["non_leaf_ob"] + cnt["non_leaf_cb"] + cnt["non_leaf_ca"]))
29 | print("non_leaf_ca:", cnt["non_leaf_ca"], cnt["non_leaf_ca"] / (cnt["non_leaf_ob"] + cnt["non_leaf_cb"] + cnt["non_leaf_ca"]))
30 |
--------------------------------------------------------------------------------
/src/musique/RoHT/evaluate.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import ujson as json
3 | import re
4 | import string
5 | from collections import Counter
6 | import pickle
7 |
8 | def normalize_answer(s):
9 |
10 | def remove_articles(text):
11 | return re.sub(r'\b(a|an|the)\b', ' ', text)
12 |
13 | def white_space_fix(text):
14 | return ' '.join(text.split())
15 |
16 | def remove_punc(text):
17 | exclude = set(string.punctuation)
18 | return ''.join(ch for ch in text if ch not in exclude)
19 |
20 | def lower(text):
21 | return text.lower()
22 |
23 | return white_space_fix(remove_articles(remove_punc(lower(s))))
24 |
25 |
26 | def f1_score(prediction, ground_truth):
27 | normalized_prediction = normalize_answer(prediction)
28 | normalized_ground_truth = normalize_answer(ground_truth)
29 |
30 | ZERO_METRIC = (0, 0, 0)
31 |
32 | if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
33 | return ZERO_METRIC
34 | if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
35 | return ZERO_METRIC
36 |
37 | prediction_tokens = normalized_prediction.split()
38 | ground_truth_tokens = normalized_ground_truth.split()
39 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
40 | num_same = sum(common.values())
41 | if num_same == 0:
42 | return ZERO_METRIC
43 | precision = 1.0 * num_same / len(prediction_tokens)
44 | recall = 1.0 * num_same / len(ground_truth_tokens)
45 | f1 = (2 * precision * recall) / (precision + recall)
46 | return f1, precision, recall
47 |
48 |
49 | def exact_match_score(prediction, ground_truth):
50 | return (normalize_answer(prediction) == normalize_answer(ground_truth))
51 |
52 | def update_answer(metrics, prediction, gold):
53 | em = exact_match_score(prediction, gold)
54 | f1, prec, recall = f1_score(prediction, gold)
55 | metrics['em'] += float(em)
56 | metrics['f1'] += f1
57 | metrics['prec'] += prec
58 | metrics['recall'] += recall
59 | metrics['N'] += 1
60 | return em, f1, prec, recall
61 |
62 | def eval():
63 | gold = [json.loads(line.strip()) for line in open('/data/csl/exp/LLMReasoning/released_data/hotpotqa__v2_test_random_500.jsonl')]
64 | metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0, 'N': 0}
65 | for dp in gold:
66 | print(dp)
67 | print(dp['answers_objects'][0])
68 | answer = dp['answers_objects'][0]['spans'][0]
69 |
70 | # cur_id = dp['_id']
71 | # if cur_id not in prediction['answer']:
72 | # print('missing answer {}'.format(cur_id))
73 | # else:
74 | em, f1, prec, recall = update_answer(
75 | metrics, answer, answer)
76 |
77 | N = len(gold)
78 | for k in metrics.keys():
79 | metrics[k] /= N
80 |
81 | print(metrics)
82 |
83 | if __name__ == '__main__':
84 | eval()
--------------------------------------------------------------------------------
/src/musique/RoHT/ob/get_para.py:
--------------------------------------------------------------------------------
1 | import json, jsonlines
2 | from itertools import chain
3 | from tqdm import tqdm
4 |
5 | train = jsonlines.open("/data/zjj/LLMReasoning/data/musique/musique_ans_v1.0_train.jsonl", "r")
6 | dev = jsonlines.open("/data/zjj/LLMReasoning/data/musique/musique_ans_v1.0_dev.jsonl", "r")
7 |
8 | question = "What genre is the record label of the performer of So Long, See You Tomorrow associated with?"
9 |
10 | for item in tqdm(chain(train, dev)):
11 | if item["question"] != question: continue
12 | pos_para, neg_para = [], []
13 | for para in item['paragraphs']:
14 | if para["is_supporting"]:
15 | pos_para.append([para["title"], para["paragraph_text"]])
16 | else:
17 | neg_para.append([para["title"], para["paragraph_text"]])
18 | break
19 |
20 | print("pos_para:")
21 | for title, para in pos_para:
22 | print(title)
23 | print(para)
24 | print('\n')
25 | exit()
26 | print("neg_para:")
27 | for title, para in neg_para:
28 | print(title)
29 | print(para)
30 | print('\n')
--------------------------------------------------------------------------------
/src/musique/RoHT/ob/multihop_prompt.txt:
--------------------------------------------------------------------------------
1 | Given a question and the relevant Wikipedia text, answer the question and explain why. If you are unsure, answer Unknown.
2 |
3 | #1 Wikipedia Title: Wadyan
4 | Text: Wadyan (Arabic: واديان) is a village in the island of Sitra, Bahrain. A branch of the National Bank of Bahrain and the Sitra police station are located in Wadyan.
5 | #2 Wikipedia Title: Bahrain
6 | Text: According to a January 2006 report by the United Nations Economic and Social Commission for Western Asia, Bahrain has the fastest-growing economy in the Arab world. Bahrain also has the freest economy in the Middle East and is twelfth-freest overall in the world based on the 2011 Index of Economic Freedom published by the Heritage Foundation/"Wall Street Journal".
7 | #3 Wikipedia Title: The Benefit Company
8 | Text: The Benefit Company (TBC) is the local switch in the Kingdom of Bahrain handling ATM and POS transactions among other services. Established in 1997 with a special license from the Central Bank of Bahrain as "Provider of Ancillary Services to the Financial Sector", it is the only financial network of its kind in the country.
9 | #4 Wikipedia Title: British Empire
10 | Text: While the Suez Crisis caused British power in the Middle East to weaken, it did not collapse. Britain again deployed its armed forces to the region, intervening in Oman (1957), Jordan (1958) and Kuwait (1961), though on these occasions with American approval, as the new Prime Minister Harold Macmillan's foreign policy was to remain firmly aligned with the United States. Britain maintained a military presence in the Middle East for another decade. In January 1968, a few weeks after the devaluation of the pound, Prime Minister Harold Wilson and his Defence Secretary Denis Healey announced that British troops would be withdrawn from major military bases East of Suez, which included the ones in the Middle East, and primarily from Malaysia and Singapore. The British withdrew from Aden in 1967, Bahrain in 1971, and Maldives in 1976.
11 | #5 Wikipedia Title: Gulf Air
12 | Text: Gulf Air ( "Ṭayarān al-Khalīj") is the flag carrier of Bahrain. Headquartered in Muharraq, adjacent to Bahrain International Airport, the airline operates scheduled services to 50 destinations in 28 countries across Africa, Asia and Europe. Its main base is Bahrain International Airport. It was formerly a multinational airline owned by Bahrain, UAE, Oman, and Qatar.
13 | Q: When did Britain withdraw from the country where the village of Wadyan is found?
14 | A: Wadyan is in the country of Bahrain. Britain withdrew from Bahrain in 1971. So the answer is: 1971.
15 |
16 | #1 Wikipedia Title: So Long, See You Tomorrow (album)
17 | Text: So Long, See You Tomorrow is the fourth album by the London indie rock band Bombay Bicycle Club, released on 3 February 2014. The album is named after the novel of the same name by William Maxwell.
18 | #2 Wikipedia Title: Hallelujah I Love Her So
19 | Text: ``Hallelujah I Love Her So ''Single by Ray Charles from the album Ray Charles (or, Hallelujah I Love Her So) B - side`` What Would I Do Without You'' Released 1956 Format 7 ''45rpm Recorded 1956 Genre soul rhythm and blues Length 2: 35 Label Atlantic Songwriter (s) Ray Charles Producer (s) Jerry Wexler Ray Charles singles chronology ``A Fool for You'' (1955)`` Hallelujah I Love Her So ''(1956) ``Mary Ann'' (1956)`` A Fool for You ''(1955) ``Hallelujah I Love Her So'' (1956)`` Mary Ann ''(1956)
20 | #3 Wikipedia Title: See You on the Other Side (Mercury Rev album)
21 | Text: See You on the Other Side is the third studio album by American neo-psychedelia band Mercury Rev, released in 1995 by record label Beggars Banquet.
22 | #4 Wikipedia Title: Flaws (album)
23 | Text: Flaws is the second studio album by the British indie rock band Bombay Bicycle Club, released on 9 July 2010 by Island Records. Unlike the band's previous releases, the album is entirely acoustic music, consisting of versions of their own tracks as well as cover versions of other artists. The album was produced in part by the guitarist Jamie MacColl's father, Neil MacColl, with recording taking place in February 2009 at The Church in Crouch End, London. The band started work on the album after completing their first album, "I Had the Blues But I Shook Them Loose".
24 | #5 Wikipedia Title: The Antidote (Ronny Jordan album)
25 | Text: The Antidote is the debut album by English jazz guitarist Ronny Jordan, that was released by Island Records in 1992.
26 | Q: What genre is the record label of the performer of So Long, See You Tomorrow associated with?
27 | A: The performer of So Long, See You Tomorrow is Bombay Bicycle Club. The record label of Bombay Bicycle Club is Island Records. Island Records released album The Antidote of English jazz guitarist Ronny Jordan. Thus Island Records is associated with jazz. So the answer is: jazz.
28 |
29 | #1 Wikipedia Title: Dance in the Country
30 | Text: Dance in the Country (French: "Danse à la campagne") is an 1883 oil painting by French artist Pierre-Auguste Renoir. It is currently kept at the Musée d'Orsay in Paris.
31 | #2 Wikipedia Title: Josip Broz Tito
32 | Text: In 1968, Tito offered Czechoslovak leader Alexander Dubček to fly to Prague on three hours notice if Dubček needed help in facing down the Soviets. In April 1969, Tito removed generals Ivan Gošnjak and Rade Hamović in the aftermath of the invasion of Czechoslovakia due to the unpreparedness of the Yugoslav army to respond to a similar invasion of Yugoslavia.
33 | #3 Wikipedia Title: 1939 German ultimatum to Lithuania
34 | Text: The 1939 German ultimatum to Lithuania was an oral ultimatum which Joachim von Ribbentrop, Foreign Minister of Nazi Germany, presented to Juozas Urbšys, Foreign Minister of Lithuania on 20 March 1939. The Germans demanded that Lithuania give up the Klaipėda Region (also known as the Memel Territory) which had been detached from Germany after World War I, or the Wehrmacht would invade Lithuania. The Lithuanians had been expecting the demand after years of rising tension between Lithuania and Germany, increasing pro-Nazi propaganda in the region, and continued German expansion. It was issued just five days after the Nazi occupation of Czechoslovakia.
35 | #4 Wikipedia Title: Slavs
36 | Text: The word "Slavs" was used in the national anthem of the Slovak Republic (1939–1945), Yugoslavia (1943–1992) and the Federal Republic of Yugoslavia (1992–2003), later Serbia and Montenegro (2003–2006).
37 | #5 Wikipedia Title: United States Army
38 | Text: Currently, the army is divided into the Regular Army, the Army Reserve, and the Army National Guard. The army is also divided into major branches such as Air Defense Artillery, Infantry, Aviation, Signal Corps, Corps of Engineers, and Armor. Before 1903 members of the National Guard were considered state soldiers unless federalized (i.e., activated) by the President. Since the Militia Act of 1903 all National Guard soldiers have held dual status: as National Guardsmen under the authority of the governor of their state or territory and, when activated, as a reserve of the U.S. Army under the authority of the President.
39 | Q: A country's military branch, the equivalent of which in the US contains the Air Defense Artillery, was unprepared for the invasion of the country occupied by the Nazis. When was the word "Slavs" used in the national anthem of the unprepared country?
40 | A: The Air Defense Artillery is a branch of the army in the US. Nazi Germany occupied Czechoslovakia in 1939. The army of Yugoslavia was unprepared for the invasion of Czechoslovakia. The word "Slavs" was used in the national anthem of Yugoslavia from 1943 to 1992. So the answer is: 1943–1992.
--------------------------------------------------------------------------------
/src/musique/RoHT/ob/singlehop_prompt.txt:
--------------------------------------------------------------------------------
1 | Given a question and the relevant Wikipedia text, answer the question and explain why. If you are unsure, answer Unknown.
2 |
3 | #1 Wikipedia Title: Wadyan
4 | Text: Wadyan (Arabic: واديان) is a village in the island of Sitra, Bahrain. A branch of the National Bank of Bahrain and the Sitra police station are located in Wadyan.
5 | #2 Wikipedia Title: Child labour
6 | Text: From European settlement in 1888, child convicts were occasionally sent to Australia where they were made to work. Child labour was not as excessive in Australia as in Britain. With a low population, agricultural productivity was higher and families did not face starvation as in established industrialised countries. Australia also did not have significant industry until the later part of the 20th century when child labour laws, and compulsory schooling had developed under the influence of Britain. From the 1870s Child labour was restricted by compulsorry schooling.
7 | #3 Wikipedia Title: British Empire
8 | Text: While the Suez Crisis caused British power in the Middle East to weaken, it did not collapse. Britain again deployed its armed forces to the region, intervening in Oman (1957), Jordan (1958) and Kuwait (1961), though on these occasions with American approval, as the new Prime Minister Harold Macmillan's foreign policy was to remain firmly aligned with the United States. Britain maintained a military presence in the Middle East for another decade. In January 1968, a few weeks after the devaluation of the pound, Prime Minister Harold Wilson and his Defence Secretary Denis Healey announced that British troops would be withdrawn from major military bases East of Suez, which included the ones in the Middle East, and primarily from Malaysia and Singapore. The British withdrew from Aden in 1967, Bahrain in 1971, and Maldives in 1976.
9 | #4 Wikipedia Title: Mother Country: Britain, the Welfare State, and Nuclear Pollution
10 | Text: Mother Country: Britain, the Welfare State, and Nuclear Pollution (1989) is a work of nonfiction by Marilynne Robinson that tells the story of Sellafield, a government nuclear reprocessing plant located on the coast of the Irish Sea. The book shows how the closest village to Sellafield suffers from death and disease due to decades of waste and radiation from the plant. "Mother Country" was a National Book Award finalist for Nonfiction in 1989. While on sabbatical in England, Robinson's interest in the environmental ramifications of the plant began when she discovered a newspaper article detailing its hazards.
11 | #5 Wikipedia Title: Black Death
12 | Text: The study also found that there were two previously unknown but related clades (genetic branches) of the Y. pestis genome associated with medieval mass graves. These clades (which are thought to be extinct) were found to be ancestral to modern isolates of the modern Y. pestis strains Y. p. orientalis and Y. p. medievalis, suggesting the plague may have entered Europe in two waves. Surveys of plague pit remains in France and England indicate the first variant entered Europe through the port of Marseille around November 1347 and spread through France over the next two years, eventually reaching England in the spring of 1349, where it spread through the country in three epidemics. Surveys of plague pit remains from the Dutch town of Bergen op Zoom showed the Y. pestis genotype responsible for the pandemic that spread through the Low Countries from 1350 differed from that found in Britain and France, implying Bergen op Zoom (and possibly other parts of the southern Netherlands) was not directly infected from England or France in 1349 and suggesting a second wave of plague, different from those in Britain and France, may have been carried to the Low Countries from Norway, the Hanseatic cities or another site.
13 | Q: When did Britain withdraw from Bahrain?
14 | A: Britain withdrew from Bahrain in 1971. So the answer is: 1971.
15 |
16 | #1 Wikipedia Title: So Long, See You Tomorrow (album)
17 | Text: So Long, See You Tomorrow is the fourth album by the London indie rock band Bombay Bicycle Club, released on 3 February 2014. The album is named after the novel of the same name by William Maxwell.
18 | #2 Wikipedia Title: Hallelujah I Love Her So
19 | Text: ``Hallelujah I Love Her So ''Single by Ray Charles from the album Ray Charles (or, Hallelujah I Love Her So) B - side`` What Would I Do Without You'' Released 1956 Format 7 ''45rpm Recorded 1956 Genre soul rhythm and blues Length 2: 35 Label Atlantic Songwriter (s) Ray Charles Producer (s) Jerry Wexler Ray Charles singles chronology ``A Fool for You'' (1955)`` Hallelujah I Love Her So ''(1956) ``Mary Ann'' (1956)`` A Fool for You ''(1955) ``Hallelujah I Love Her So'' (1956)`` Mary Ann ''(1956)
20 | #3 Wikipedia Title: The First Time Ever I Saw Your Face
21 | Text: ``The First Time Ever I Saw Your Face ''Single by Roberta Flack from the album First Take Released March 7, 1972 (1972 - 03 - 07) Recorded 1969 Genre Soul vocal jazz Length 5: 22 4: 15 (1972 radio edit) Label Atlantic 2864 Songwriter (s) Ewan MacColl Producer (s) Joel Dorn Roberta Flack singles chronology`` Will You Still Love Me Tomorrow'' (1972) ``The First Time Ever I Saw Your Face ''(1972)`` Where Is the Love'' (1972) ``Will You Still Love Me Tomorrow ''(1972)`` The First Time Ever I Saw Your Face'' (1972) ``Where Is the Love ''(1972)
22 | #4 Wikipedia Title: See You on the Other Side (Mercury Rev album)
23 | Text: See You on the Other Side is the third studio album by American neo-psychedelia band Mercury Rev, released in 1995 by record label Beggars Banquet.
24 | #5 Wikipedia Title: The Dance (song)
25 | Text: ``The Dance ''Single by Garth Brooks from the album Garth Brooks B - side`` If Tomorrow Never Comes'' Released April 30, 1990 Format CD single, 7 ''45 RPM Recorded 1988 -- 1989 Genre Country Length 3: 40 Label Capitol Nashville 44629 Songwriter (s) Tony Arata Producer (s) Allen Reynolds Garth Brooks singles chronology ``Not Counting You'' (1990)`` The Dance ''(1990) ``Friends in Low Places'' (1990)`` Not Counting You ''(1990) ``The Dance'' (1990)`` Friends in Low Places ''(1990)
26 | Q: Who is the performer of So Long, See You Tomorrow?
27 | A: The performer of So Long, See You Tomorrow is Bombay Bicycle Club. So the answer is: Bombay Bicycle Club.
28 |
29 | #1 Wikipedia Title: Wyddial
30 | Text: Wyddial is a village and civil parish in the East Hertfordshire district of Hertfordshire, England. It is located around a mile and a half north-east of Buntingford (OS grid reference ), and lies due north of Greenwich on the Prime Meridian.
31 | #2 Wikipedia Title: Hertfordshire Fire and Rescue Service
32 | Text: The Service Headquarters is located in Hertford whilst the Training and Development Centre and Fire Control Centre are located in Stevenage. It is administered by a Fire Authority which is an internal part of Hertfordshire County Council. The Chief Fire Officer is Darryl Keen, assisted by Deputy Chief Fire Officer Chris Bigland.
33 | #3 Wikipedia Title: Wareside
34 | Text: Wareside is a small village and civil parish in the East Hertfordshire District, in the county of Hertfordshire. The population of the civil parish as of the 2011 census is 735. It is approximately 3 miles away from the town of Ware (from where it probably took its name) and the larger town of Hertford, the county town of Hertfordshire. Nearby villages include Widford, Hunsdon, Babbs Green and Bakers End. Nearby hamlets include Cold Christmas and Helham Green. The B1004 linking Ware to Bishop's Stortford goes through the village and the main A10 road can be picked up at Thundridge. Fanhams Hall Road also links Wareside back to Ware. Ware railway station on the Hertford East Branch Line is located two and a half miles away.
35 | #4 Wikipedia Title: Hertfordshire
36 | Text: Hertfordshire is the county immediately north of London and is part of the East of England region, a mainly statistical unit. A significant minority of the population across all districts are City of London commuters. To the east is Essex, to the west is Buckinghamshire and to the north are Bedfordshire and Cambridgeshire.
37 | #5 Wikipedia Title: Bengeo Rural
38 | Text: Bengeo Rural is a civil parish in the East Hertfordshire district of Hertfordshire, England. According to the 2001 census it had a population of 601, increasing at the 2011 Census to 644. The parish includes the villages of Tonwell and Chapmore End.
39 | Q: In which state is Hertfordshire located?
40 | A: Hertfordshire is located in the state East of England. So the answer is: East of England.
--------------------------------------------------------------------------------
/src/musique/RoHT/openai_req.py:
--------------------------------------------------------------------------------
1 | import openai
2 | import requests
3 | import time
4 | import os
5 | import json, jsonlines
6 |
7 | class OpenaiReq():
8 | def __init__(self):
9 | self.url = "http://127.0.0.1:10001/api/openai/completion"
10 | self.cache = {}
11 | self.cache_path = "./cache.jsonl"
12 | if os.path.exists(self.cache_path):
13 | with open(self.cache_path, "r") as f:
14 | for i, line in enumerate(f):
15 | #print(i+1)
16 | datum = json.loads(line.strip())
17 | self.cache[tuple(datum["input"])] = datum["response"]
18 | f.close()
19 |
20 | def req2openai(self, prompt, model="text-davinci-003", temperature=0, max_tokens=128, stop=None, logprobs=1, use_cache=True):
21 | assert isinstance(prompt, str)
22 | input = (prompt, model, max_tokens, stop, logprobs)
23 | if use_cache and temperature == 0 and input in self.cache:
24 | return self.cache[input], True
25 | for i in range(3):
26 | try:
27 | response = requests.post(self.url, json = {
28 | "model": model,
29 | "prompt": prompt,
30 | "temperature": temperature,
31 | "max_tokens": max_tokens,
32 | "stop": stop,
33 | "logprobs": logprobs,
34 | })
35 | if response.status_code != 200:
36 | raise Exception(response.text)
37 | break
38 | except Exception as e:
39 | err_msg = str(e)
40 | print(e)
41 | if "reduce your prompt" in err_msg: # this is because the input string too long
42 | return ['too long'], False
43 | try:
44 | response = response.json()['choices']
45 | except:
46 | return ['openai error'], False
47 | if temperature == 0:
48 | input = (prompt, model, max_tokens, stop, logprobs)
49 | res = response[0]
50 | if input not in self.cache:
51 | self.cache[input] = [res]
52 | with open(self.cache_path, "a") as f:
53 | f.write("%s\n"%json.dumps({"input": input, "response": [res]}))
54 | f.close()
55 | return response, True
56 |
57 | if __name__ == "__main__":
58 | caller = OpenaiReq()
59 | res = caller.req2openai("你好", use_cache=True)
60 | print(res)
61 |
62 |
--------------------------------------------------------------------------------
/src/musique/RoHT/parallel.py:
--------------------------------------------------------------------------------
1 | import concurrent.futures, random, time
2 |
3 | def handle_item(data):
4 | waiting = random.random() * 3 + 1
5 | print("Thread %d, Waiting %.2f ..."%(data, waiting))
6 | time.sleep(waiting)
7 | if random.random() < 0.5:
8 | raise Exception()
9 | print("Thread %d, OK."%(data))
10 |
11 | def parallel_process_data(data, handle_item, workers=20, callback=None):
12 | with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
13 | futures = []
14 | for item in data:
15 | future = executor.submit(handle_item, item)
16 | futures.append(future)
17 | for future in concurrent.futures.as_completed(futures):
18 | result = future.result()
19 | if callback:
20 | callback(result)
21 |
22 | if __name__ == "__main__":
23 | parallel_process_data([i for i in range(20)], handle_item)
24 | print("end")
--------------------------------------------------------------------------------
/src/musique/RoHT/question_answering.py:
--------------------------------------------------------------------------------
1 | from openai_req import OpenaiReq
2 | import requests
3 | import os
4 | from transformers import AutoTokenizer
5 |
6 | openai_caller = OpenaiReq()
7 |
8 | tokenizer = AutoTokenizer.from_pretrained("gpt2")
9 |
10 | def bm25_search(question, k):
11 | web = "http://127.0.0.1:1435"
12 | data = {
13 | "query": question,
14 | "k": k
15 | }
16 | for i in range(3):
17 | try:
18 | r = requests.get(web, json=data)
19 | if r.status_code != 200:
20 | raise Exception(r.text)
21 | return r.json()
22 | except Exception as e:
23 | print(e)
24 |     return []  # if all retries fail, fall back to no retrieved passages rather than returning None
25 | def postprocess(response):
26 | response = response[0]
27 | if response == 'too long' or response['finish_reason'] != 'stop':
28 | return 'ERROR: prompt too long', -100, ""
29 | tokens = response['logprobs']['tokens']
30 | token_logprobs = response['logprobs']['token_logprobs']
31 | cot = response['text'].strip()
32 | if len(token_logprobs) == 0:
33 | return 'ERROR: empty output', -100, cot
34 | pos = 0
35 | for idx, token in enumerate(tokens):
36 |         if token.strip() == 'So' and idx + 4 < len(tokens) and tokens[idx + 1].strip() == 'the' and tokens[idx + 2].strip() == 'answer' and tokens[idx + 3].strip() == 'is' and tokens[idx + 4].strip() == ':':
37 | pos = idx
38 | break
39 | if tokens[-1] == '.':
40 | answer_logprobs = token_logprobs[pos+5:-1]
41 | answer = cot.split('So the answer is: ')[-1][:-1]
42 | else:
43 | answer_logprobs = token_logprobs[pos+5:]
44 | answer = cot.split('So the answer is: ')[-1]
45 | cot_process = cot.split('So the answer is: ')[0].strip()
46 | cot_process_logprobs = token_logprobs[:pos]
47 | if len(cot_process_logprobs) == 0:
48 | cot_process_logprob = -100
49 | else:
50 | cot_process_logprob = sum(cot_process_logprobs) / len(cot_process_logprobs)
51 | return answer, cot_process_logprob, cot
52 |
53 | def get_cb_answer(question):
54 | instruction = '\n'.join([_.strip() for _ in open('cb/prompt.txt').readlines()])
55 | prompt = instruction + '\nQ: ' + question + '\nA:'
56 | response, tag = openai_caller.req2openai(prompt=prompt, max_tokens=256, stop='Q:', use_cache=True)
57 | return postprocess(response)
58 |
59 | def get_singlehop_ob_answer(question, topic_entities):
60 | instruction = '\n'.join([_.strip() for _ in open('ob/singlehop_prompt.txt').readlines()])
61 | for k in range(5, 0, -1):
62 | contexts = []
63 | hist = set()
64 | r = bm25_search(question, k)
65 | for datum in r:
66 | title, text = datum["title"], datum["paragraph_text"]
67 | stamp = title + text
68 | if not stamp in hist:
69 | hist.add(stamp)
70 | contexts.append([title, text])
71 |
72 | prompt = instruction + '\n'
73 | for idx, (title, text) in enumerate(contexts):
74 | prompt += '\n#' + str(idx + 1) + ' Wikipedia Title: ' + title + '\nText: ' + text
75 | prompt += '\nQ: ' + question + '\nA:'
76 | if len(tokenizer(prompt).input_ids) + 256 <= 4097:
77 | break
78 | response, tag = openai_caller.req2openai(prompt=prompt, max_tokens=256, stop='\n\n', use_cache=True)
79 | return postprocess(response)
80 |
81 | def aggregate_singlehop_answer(cb_answer, ob_answer):
82 | cb_ans, cb_score, cb_cot = cb_answer
83 | ob_ans, ob_score, ob_cot = ob_answer
84 | if "ERROR" in cb_ans or 'Unknown' in cb_ans:
85 | cb_ans, cb_score = "", -100
86 | if "ERROR" in ob_ans or 'Unknown' in ob_ans:
87 | ob_ans, ob_score = "", -100
88 | return max([(cb_ans, cb_score, cb_cot), (ob_ans, ob_score, ob_cot)], key=lambda x:x[1])
89 |
90 | def get_multihop_ob_answer(node, tree):
91 |
92 | def is_descendant(a, b):
93 | while "fa" in tree[a]:
94 | a = tree[a]["fa"]
95 | if a == b:
96 | return True
97 | return False
98 |
99 | question = node["question"]
100 | instruction = '\n'.join([_.strip() for _ in open('ob/multihop_prompt.txt').readlines()])
101 | k = 5
102 | for sub_k in range(3, 0, -1):
103 | contexts = []
104 | hist = set()
105 | r = bm25_search(question, k)
106 | for datum in r:
107 | title, text = datum["title"], datum["paragraph_text"]
108 | stamp = title + text
109 | if stamp not in hist:
110 | hist.add(stamp)
111 | contexts.append([title, text])
112 |
113 | for idx in range(node["idx"]):
114 | if is_descendant(idx, node["idx"]):
115 | sub_question = tree[idx]["question"]
116 | r = bm25_search(sub_question, sub_k)
117 | for datum in r:
118 | title, text = datum["title"], datum["paragraph_text"]
119 | stamp = title + text
120 | if stamp not in hist:
121 | hist.add(stamp)
122 | contexts.append([title, text])
123 |
124 | prompt = instruction + '\n'
125 | for idx, (title, text) in enumerate(contexts):
126 | prompt += '\n#' + str(idx + 1) + ' Wikipedia Title: ' + title + '\nText: ' + text
127 | prompt += '\nQ: ' + question + '\nA: '
128 | if len(tokenizer(prompt).input_ids) + 256 <= 4097:
129 | break
130 | response, tag = openai_caller.req2openai(prompt=prompt, max_tokens=256, stop='\n\n', use_cache=True)
131 | return postprocess(response)
132 |
133 | def calculate_score1(cot_process_logprob, qd_score, sub_answer_scores):
134 | return cot_process_logprob + qd_score + sum(sub_answer_scores)
135 |
136 | def calculate_score2(cot_process_logprob, qd_score, sub_answer_scores):
137 | return (cot_process_logprob + qd_score + sum(sub_answer_scores)) / (len(sub_answer_scores) + 2)
138 |
139 | def aggregate_multihop_answer(node, tree):
140 | instruction = '\n'.join([_.strip() for _ in open('aggregate/prompt.txt').readlines()])
141 | question = node["question"]
142 | qd_score = node["qd_logprob"]
143 | context = ''
144 | sub_answer_scores = []
145 | for son_idx in node["sons"]:
146 | sub_question = tree[son_idx]["question"]
147 | sub_answer = tree[son_idx]["answer"][0]
148 | sub_answer_scores.append(tree[son_idx]["answer"][1])
149 | context += '\n' + sub_question + ' ' + sub_answer
150 | prompt = instruction + '\nContext:\n{}\n\nQuestion:\n{}\n\nAnswer:'.format(context, question)
151 | response, tag = openai_caller.req2openai(prompt=prompt, max_tokens=256, stop='\n\n\n', use_cache=True)
152 | child_answer, cot_process_logprob, child_cot = postprocess(response)
153 |
154 | child_ans = child_answer
155 | child_score = calculate_score2(cot_process_logprob, qd_score, sub_answer_scores)
156 | res1 = (child_ans, child_score, child_cot)
157 | cb_ans, cb_score, cb_cot = node["cb_answer"]
158 | ob_ans, ob_score, ob_cot = node["ob_answer"]
159 | if "ERROR" in cb_ans or 'Unknown' in cb_ans:
160 | cb_ans, cb_score = "", -100
161 | if "ERROR" in ob_ans or 'Unknown' in ob_ans:
162 | ob_ans, ob_score = "", -100
163 | if "ERROR" in child_ans or "Unknow" in child_ans:
164 | child_ans, child_score = "", -100
165 | res2 = max([(cb_ans, cb_score, cb_cot), (ob_ans, ob_score, ob_cot), (child_ans, child_score, child_cot)], key=lambda x:x[1])
166 | return res1, res2
167 |
168 |
169 |
170 | if __name__ == "__main__":
171 | question = "Which religious order founded Harvard College?"
172 | r = bm25_search(question, k=5)
173 | for x in r:
174 | print(x["title"])
175 | print(x["paragraph_text"])
176 | print()
177 |
178 |
179 |
180 |
181 |
--------------------------------------------------------------------------------
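
Both the closed-book (cb) and open-book (ob) paths in question_answering.py return plain (answer, average_logprob, chain_of_thought) tuples, and aggregation simply keeps the candidate with the highest average log-probability after zeroing out ERROR/Unknown answers. A minimal self-contained sketch of that selection rule with made-up values (no OpenAI or Elasticsearch calls):

    # Illustrative only: mirrors aggregate_singlehop_answer without importing the module.
    def pick_best(*candidates):
        # each candidate is (answer, avg_logprob, cot); invalid answers get score -100 upstream
        return max(candidates, key=lambda x: x[1])

    cb = ("the Puritans", -0.8, "Harvard College was founded by the Puritans. So the answer is: the Puritans.")
    ob = ("", -100, "")  # e.g. retrieval produced nothing usable
    print(pick_best(cb, ob))  # the closed-book candidate wins
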
/src/musique/Tree_Generation/0_get_prompt.py:
--------------------------------------------------------------------------------
1 | import json, jsonlines
2 |
3 | instruction = '\n'.join([_.strip() for _ in open('prompt.txt').readlines()])
4 |
5 | raw_data = jsonlines.open("../../../released_data/musique_ans__v2_test_random_500.jsonl", "r")
6 |
7 | prompts = []
8 | for item in raw_data:
9 | question = item['question_text'].strip()
10 | prompt = instruction + '\nQ: ' + question + '\nA:'
11 | prompts.append(prompt)
12 |
13 | json.dump(prompts, open('prompts.json', 'w'), indent = 2)
14 | print(len(prompts))
--------------------------------------------------------------------------------
/src/musique/Tree_Generation/1_query.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | from openai_req import OpenaiReq
4 | import random
5 | from tqdm import tqdm
6 | import os
7 | from multiprocessing import Pool
8 | from termcolor import colored
9 | random.seed(42)
10 |
11 | MAX_SPLIT = 64
12 | STEP = 4
13 |
14 | def query(rank, prompts):
15 | print('Process rank {} PID {} begin...'.format(rank, os.getpid()))
16 | reqor = OpenaiReq()
17 | queries = prompts[int(len(prompts) * rank / MAX_SPLIT) : int(len(prompts) * (rank + 1) / MAX_SPLIT)]
18 | try:
19 | fout = open('outputs/rank_{}.json'.format(rank), 'w')
20 | if rank == 0:
21 | bar = tqdm(range(len(queries) // STEP + 1))
22 | else:
23 | bar = range(len(queries) // STEP + 1)
24 | for idx in bar:
25 | inputs = queries[idx * STEP : (idx + 1) * STEP]
26 | if len(inputs) == 0:
27 | break
28 | gpt_results = []
29 | for prompt in inputs:
30 | result, tag = reqor.req2openai(prompt, max_tokens = 512, stop = '\n\n')
31 | gpt_results.append(result[0])
32 | for prompt, res in zip(inputs, gpt_results):
33 | # print(res)
34 | fout.write(json.dumps({'prompt': prompt, 'response': res}) + '\n')
35 | fout.flush()
36 | fout.close()
37 | except Exception as err:
38 | print(Exception, err)
39 |
40 | if __name__=='__main__':
41 | prompts = json.load(open('prompts.json'))
42 | os.makedirs("outputs", exist_ok=False)
43 | print("number of prompts: {}".format(len(prompts)))
44 | print('Parent process %s.' % os.getpid())
45 | p = Pool(MAX_SPLIT)
46 | for i in range(MAX_SPLIT):
47 | p.apply_async(query, args=(i, prompts))
48 | print('Waiting for all subprocesses done...')
49 | p.close()
50 | p.join()
51 | print('All subprocesses done.')
--------------------------------------------------------------------------------
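
1_query.py fans the prompts out over MAX_SPLIT worker processes, each taking one contiguous slice. A standalone sanity check of that slicing, assuming the 500-prompt musique test split:

    # Standalone check of the sharding used in query(); hypothetical sizes, not part of the pipeline.
    MAX_SPLIT = 64
    prompts = list(range(500))
    shards = [prompts[int(len(prompts) * r / MAX_SPLIT): int(len(prompts) * (r + 1) / MAX_SPLIT)]
              for r in range(MAX_SPLIT)]
    assert sum(len(s) for s in shards) == len(prompts)   # every prompt assigned exactly once
    assert all(len(s) in (7, 8) for s in shards)         # shard sizes differ by at most one
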
/src/musique/Tree_Generation/2_postprocess.py:
--------------------------------------------------------------------------------
1 | import json
2 | from tqdm import tqdm
3 | from termcolor import colored
4 | import os
5 |
6 | # def findAllFile(base):
7 | # for root, ds, fs in os.walk(base):
8 | # for f in fs:
9 | # yield f
10 | # base = './outputs'
11 | # data = []
12 | # for file_name in findAllFile(base):
13 | # data += [json.loads(line.strip()) for line in open(os.path.join(base, file_name))]
14 | # # data.update(json.load(open(os.path.join(base, file_name))))
15 | # print(len(data))
16 | # json.dump(data, open(os.path.join(base, 'predictions.json'), 'w'), indent = 2, ensure_ascii=False)
17 |
18 | raw_data = json.load(open('outputs/predictions.json', "r"))
19 |
20 | data = {}
21 | for item in tqdm(raw_data):
22 | prompt = item['prompt']
23 | question = prompt.split('\n')[-2][len('Q: '):].strip()
24 | print(colored(question, 'red'))
25 | # print(item['response']['text'])
26 | try:
27 | qds = item['response']['text'].strip()
28 | if qds.endswith('.'):
29 | qds = qds[:-1]
30 | # print(qds)
31 | # if question.startswith('Who is the actress who plays the role of the Queen of Eng'):
32 | # continue
33 | hqdt = json.loads(qds)
34 | except:
35 | hqdt = None
36 | if hqdt is None:
37 | continue
38 |
39 |
40 |
41 |
42 | tokens = item['response']['logprobs']['tokens']
43 | token_logprobs = item['response']['logprobs']['token_logprobs']
44 | if len(token_logprobs) == 0:
45 | continue
46 |
47 | if tokens[-1] == '.':
48 | token_logprobs = token_logprobs[:-1]
49 | # print(answer_logprobs)
50 | # else:
51 | # answer_logprobs = token_logprobs[pos+6:]
52 |
53 | # print(tokens[pos+6:-1])
54 |
55 | st, ed = 0, 0
56 | pos = 0
57 | qds = {}
58 | for sub_question, qd in hqdt.items():
59 | while pos < len(tokens):
60 | if "[" in tokens[pos] and ": [\"" in "".join(tokens[max(pos-1, 0): min(pos+2, len(tokens))]):
61 | st = pos
62 | break
63 | pos += 1
64 | while pos < len(tokens):
65 | if "]" in tokens[pos] and "\"]" in "".join(tokens[max(pos-1, 0): min(pos+2, len(tokens))]):
66 | ed = pos
67 | break
68 | pos += 1
69 | assert pos < len(tokens), sub_question
70 | qd_score = sum(token_logprobs[st:ed+1]) / len(token_logprobs[st:ed+1])
71 | qds[sub_question] = (qd, qd_score)
72 | print(colored(sub_question, 'blue'))
73 | print("".join(tokens[st:ed+1]))
74 |
75 |
76 | # answer_logprob = sum(token_logprobs) / len(token_logprobs)
77 | # data[question] = [hqdt, answer_logprob]
78 | data[question] = qds
79 | json.dump(data, open('question_decompositions.json', 'w'), indent = 2)
--------------------------------------------------------------------------------
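
The resulting question_decompositions.json maps each original question to the questions decomposed under it (the root itself and any intermediate sub-questions), each paired with its sub-question list and the average token log-probability of that decomposition span. An illustrative entry (the log-probability value is made up):

    # Illustrative shape of question_decompositions.json; the score is a made-up value.
    example = {
        "When did Britain withdraw from the country containing Hoora?": {
            "When did Britain withdraw from the country containing Hoora?": [
                ["Which country is Hoora in?", "When did Britain withdraw from #1?"],
                -0.04,
            ]
        }
    }
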
/src/musique/Tree_Generation/3_postprocess_tree.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | raw_data = json.load(open('question_decompositions.json'))
4 |
5 | def check(question):
6 | return '#1' in question or '#2' in question or '#3' in question or '#4' in question
7 |
8 | tree = {}
9 | for father in raw_data:
10 | if check(father):
11 | continue
12 | qds = raw_data[father]
13 | if qds is None:
14 | continue
15 | tree[father] = {}
16 | for question in qds:
17 | if check(question):
18 | continue
19 | if any([x == question for x in qds[question][0]]):
20 | tree[father][question] = [[], None]
21 | else:
22 | tree[father][question] = qds[question]
23 |
24 | question_decompositions = {}
25 | for father in tree:
26 | qds = tree[father]
27 | for q in qds:
28 | if q not in question_decompositions:
29 | question_decompositions[q] = qds[q]
30 | else:
31 | if question_decompositions[q] != qds[q]:
32 | print(question_decompositions[q])
33 | print(qds[q])
34 | else:
35 | print('haha')
36 |
37 | json.dump(question_decompositions, open('tree.json', 'w'), indent = 2)
38 |
39 | print(len(tree))
--------------------------------------------------------------------------------
/src/musique/Tree_Generation/combine.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | def findAllFile(base):
5 | for root, ds, fs in os.walk(base):
6 | for f in fs:
7 | yield f
8 | base = './outputs'
9 | data = []
10 | for file_name in findAllFile(base):
11 | data += [json.loads(line.strip()) for line in open(os.path.join(base, file_name))]
12 | # data.update(json.load(open(os.path.join(base, file_name))))
13 | print(len(data))
14 | json.dump(data, open(os.path.join(base, 'predictions.json'), 'w'), indent = 2)
--------------------------------------------------------------------------------
/src/musique/Tree_Generation/openai_req.py:
--------------------------------------------------------------------------------
1 | import openai
2 | import requests
3 | import time
4 | import os
5 | import json, jsonlines
6 |
7 | class OpenaiReq():
8 | def __init__(self):
9 | self.url = "http://127.0.0.1:10001/api/openai/completion"
10 | self.cache = {}
11 | self.cache_path = "./cache.jsonl"
12 | if os.path.exists(self.cache_path):
13 | with open(self.cache_path, "r") as f:
14 | for i, line in enumerate(f):
15 | #print(i+1)
16 | datum = json.loads(line.strip())
17 | self.cache[tuple(datum["input"])] = datum["response"]
18 | f.close()
19 |
20 | def req2openai(self, prompt, model="text-davinci-003", temperature=0, max_tokens=128, stop=None, logprobs=1, use_cache=True):
21 | assert isinstance(prompt, str)
22 | input = (prompt, model, max_tokens, stop, logprobs)
23 | if use_cache and temperature == 0 and input in self.cache:
24 | return self.cache[input], True
25 | for i in range(3):
26 | try:
27 | response = requests.post(self.url, json = {
28 | "model": model,
29 | "prompt": prompt,
30 | "temperature": temperature,
31 | "max_tokens": max_tokens,
32 | "stop": stop,
33 | "logprobs": logprobs,
34 | })
35 | if response.status_code != 200:
36 | raise Exception(response.text)
37 | break
38 | except Exception as e:
39 | err_msg = str(e)
40 | print(e)
41 | if "reduce your prompt" in err_msg: # this is because the input string too long
42 | return ['too long'], False
43 | try:
44 | response = response.json()['choices']
45 | except:
46 | return ['openai error'], False
47 | if temperature == 0:
48 | input = (prompt, model, max_tokens, stop, logprobs)
49 | res = response[0]
50 | if input not in self.cache:
51 | self.cache[input] = [res]
52 | with open(self.cache_path, "a") as f:
53 | f.write("%s\n"%json.dumps({"input": input, "response": [res]}))
54 | f.close()
55 | return response, True
56 |
57 | if __name__ == "__main__":
58 | caller = OpenaiReq()
59 | res = caller.req2openai("你好", use_cache=True)
60 | print(res)
61 |
62 |
--------------------------------------------------------------------------------
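
req2openai keys its cache on the full request tuple, so each line of cache.jsonl pairs that tuple with the first returned choice. An illustrative record (all contents made up):

    # Illustrative cache.jsonl line written by req2openai; every value here is made up.
    record = {
        "input": ["Q: Which country is Hoora in?\nA:", "text-davinci-003", 512, "\n\n", 1],  # (prompt, model, max_tokens, stop, logprobs)
        "response": [{"text": " Hoora is in Bahrain. So the answer is: Bahrain.",
                      "finish_reason": "stop",
                      "logprobs": {"tokens": ["..."], "token_logprobs": [-0.1]}}],
    }
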
/src/musique/Tree_Generation/prompt.txt:
--------------------------------------------------------------------------------
1 | Please generate a hierarchical question decomposition tree (HQDT) with json format for a given question. In this tree, the root node is the original complex question, and each non-root node is a sub-question of its parent. The leaf nodes are atomic questions that cannot be further decomposed.
2 | Q: When did the first large winter carnival take place in the city where CIMI−FM is licensed to broadcast?
3 | A: {"When did the first large winter carnival take place in the city where CIMI−FM is licensed to broadcast?": ["Which city is CIMI−FM licensed to broadcast?", "When did the first large winter carnival take place in #1?"]}.
4 | Q: What county is Hebron located in, in the same province the Heritage Places Protection Act applies to?
5 | A: {"What county is Hebron located in, in the same province the Heritage Places Protection Act applies to?": ["Which did Heritage Places Protection Act apply to the jurisdiction of?", "which country is Hebron, #1 located in?"]}.
6 | Q: What weekly publication in the Connecticut city with the most Zagat rated restaurants is issued by university of America−Lite: How Imperial Academia Dismantled Our Culture's author?
7 | A: {"What weekly publication in the Connecticut city with the most Zagat rated restaurants is issued by university of America−Lite: How Imperial Academia Dismantled Our Culture's author?": ["Which university was the author of America−Lite: How Imperial Academia Dismantled Our Culture educated at?", "What city in Connecticut has the highest number of Zagat−rated restaurants?", "What is the weekly publication in #2 that is issued by #1?"], "Which university was the author of America−Lite: How Imperial Academia Dismantled Our Culture educated at?": ["Who is the author of America−Lite: How Imperial Academia Dismantled Our Culture?", "Which university was #1 educated at?"]}.
8 | Q: What did the publisher of Banjo−Tooie rely primarily on for its support?
9 | A: {"What did the publisher of Banjo−Tooie rely primarily on for its support?": ["What is the publisher of Banjo−Tooie?", "What did #1 rely primarily for its support on first−party games?"]}.
10 | Q: In which county was the birthplace of the Smoke in tha City performer?
11 | A: {"In which county was the birthplace of the Smoke in tha City performer?": ["What's the birthplace of the Smoke in tha City performer?", "Which country is #1 located in?"], "What's the birthplace of the Smoke in tha City performer?": ["Who is the performer of Smoke in tha City?", "Where was #1 born?"]}.
12 | Q: What region of the state where Guy Shepherdson was born, contains SMA Negeri 68?
13 | A: {"What region of the state where Guy Shepherdson was born, contains SMA Negeri 68?": ["Where was Guy Shepherdson born?", "what region of the state is SMA Negeri 68 #1 located in?"]}.
14 | Q: When did Britain withdraw from the country containing Hoora?
15 | A: {"When did Britain withdraw from the country containing Hoora?": ["Which country is Hoora in?", "When did Britain withdraw from #1?"]}.
16 | Q: How long is the US border with the country that borders the state where Finding Dory takes place?
17 | A: {"How long is the US border with the country that borders the state where Finding Dory takes place?": ["Which country shares a border with the state where Finding Dory is supposed to take place?", "how long is the us border with #1?"], "Which country shares a border with the state where Finding Dory is supposed to take place?": ["where is finding dory supposed to take place", "which country shares a border with #1"]}.
18 | Q: When did the first large winter carnival happen in Olivier Robitaille's place of birth?
19 | A: {"When did the first large winter carnival happen in Olivier Robitaille's place of birth?": ["Where was Olivier Robitaille born?", "when did the first large winter carnival take place in #1?"]}.
20 | Q: When did Britain withdraw from the country where the village of Wadyan is found?
21 | A: {"When did Britain withdraw from the country where the village of Wadyan is found?": ["Which country is Wadyan in ?", "When did Britain withdraw from #1?"]}.
22 | Q: How many countries in Pacific National University's continent are recognized by the organization that mediated the truce ending the Iran−Iraq war?
23 | A: {"How many countries in Pacific National University's continent are recognized by the organization that mediated the truce ending the Iran−Iraq war?": ["What continent is the country of Pacific National University located in?", "Who mediated the truce which ended the Iran-Iraq War?", "the #2 recognises how many regions in #1?"], "What continent is the country of Pacific National University located in?": ["which country is Pacific National University located in?", "What continent is #1 in?"]}.
24 | Q: When was Eritrea annexed by the Horn of Africa country where, along with Somalia and the country where Bissidiro is located, Somali people live?
25 | A: {"When was Eritrea annexed by the Horn of Africa country where, along with Somalia and the country where Bissidiro is located, Somali people live?": ["Along with Kenya, the country where Bissidiro is located and Somalia, in what Horn of Africa country do Somali people live?", "When was Eritrea annexed by #1?"], "Along with Kenya, the country where Bissidiro is located and Somalia, in what Horn of Africa country do Somali people live?": ["Which country is Bissidiro located in?", "Along with Kenya, #1 and Somalia, in what Horn of Africa country do Somali people live?"]}.
26 | Q: What was used to launch the probe of the country where Gao is located to the planet where Hephaestus Fossae is found?
27 | A: {"What was used to launch the probe of the country where Gao is located to the planet where Hephaestus Fossae is found?": ["Where was Goa?", "Where is Hephaestus Fossae found?", "#1 's mangalyaan was sent to the #2 by launching what?"]}.
28 | Q: Where is the lowest place in the country which, along with Eisenhower's VP's country, recognized Gaddafi's government early on?
29 | A: {"Where is the lowest place in the country which, along with Eisenhower's VP's country, recognized Gaddafi's government early on?": ["What country is along with Eisenhower's VP's country, recognized Gaddafi's government early on?", "Where is the lowest place in the #1"], "What country is along with Eisenhower's VP's country, recognized Gaddafi's government early on?": ["Eisenhower's vice president was a president of what country?", "Along with the #1 , what major power recognized Gaddafi's government at an early date?"], "Eisenhower's vice president was a president of what country?": ["Who served as Eisenhower's vice president?", "#1 was a president of what country?"]}.
30 | Q: When did the capital of Virginia moved from John Nicholas's birth city to Charles Oakley's alma mater's city?
31 | A: {"When did the capital of Virginia moved from John Nicholas's birth city to Charles Oakley's alma mater's city?": ["Which city was Charles Oakley's university located in?", "Where was John Nicholas born?", "When did the capital of virginia moved from #2 to #1?"], "Which city was Charles Oakley's university located in?": ["Which university was Charles Oakley educated at?", "Which city was #1 located in?"]}.
32 | Q: How many people whose name new students were once called by others live in the South American country discovered by the country Cristiano Ronaldo plays for?
33 | A: {"How many people whose name new students were once called by others live in the South American country discovered by the country Cristiano Ronaldo plays for?": ["What is the South American country discovered by the country Cristiano Ronaldo plays for?", "What were new students once called by others?", "How many #2 live in #1 ?"], "What is the South American country discovered by the country Cristiano Ronaldo plays for?": ["What country does cristiano ronaldo play for?", "What South American country did #1 discover?"]}.
34 | Q: When did the winners of the Battle of Borodino come to the place in which the island besides St. Barts granted COM status by France in 2007 is located?
35 | A: {"When did the winners of the Battle of Borodino come to the place in which the island besides St. Barts granted COM status by France in 2007 is located?": ["The island besides St. Barts granted COM status by France in 2007 is located on which terrain feature?", "Who won the Battle of Borodino?", "when did the #2 come to the #1?"], "The island besides St. Barts granted COM status by France in 2007 is located on which terrain feature?": ["What island besides St. Barts was granted COM status by France in 2007?", "#1 is located on which terrain feature?"]}.
36 | Q: How many square miles is the source of the most legal immigrants to the location of Gotham's filming from the region where Andy from The Office sailed to?
37 | A: {"How many square miles is the source of the most legal immigrants to the location of Gotham's filming from the region where Andy from The Office sailed to?": ["What is the source of the most legal immigrants to the location of Gotham's filming from the region where Andy from The Office sailed to?", "How many square miles is #1?"], "What is the source of the most legal immigrants to the location of Gotham's filming from the region where Andy from The Office sailed to?": ["where is the tv show gotham filmed at", "where did andy sail to in the office", "What nation provided the most legal immigrants to #1 in the #2 ?"]}.
38 | Q: When did the capitol of Virginia move from Robert Banks' birthplace to the city sharing a border with Laurel's county?
39 | A: {"When did the capitol of Virginia move from Robert Banks' birthplace to the city sharing a border with Laurel's county?": ["What is the city sharing a border with Laurel's county?", "Where is Robert Banks' birthplace?", "When did the capitol of Virginia move from #1 to #2?"], "What is the city sharing a border with Laurel's county?": ["What county is Laurel located in?", "What city shares a border with #1?"]}.
40 | Q: An actor in Nowhere to Run is a national of a European country. That country's King Albert I lived during a major war that Italy joined in what year?
41 | A: {"An actor in Nowhere to Run is a national of a European country. That country's King Albert I lived during a major war that Italy joined in what year?": ["Albert I of the country which has the actor in Nowhere to Run lived during which war?", "When did Italy join #1?"], "Albert I of the country which has the actor in Nowhere to Run lived during which war?": ["Tell me the country which has the actor in Nowhere to Run", "Albert I of #1 lived during which war?"], "Tell me the country which has the actor in Nowhere to Run": ["Nowhere to Run's cast member is whom?", "What is the country of #1?"]}.
--------------------------------------------------------------------------------
/src/service/es/index_2wiki_wiki.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | from collections import Counter, defaultdict
3 | from elasticsearch import Elasticsearch
4 | import html
5 | import json
6 | from tqdm import tqdm
7 | from itertools import chain
8 |
9 | def chunks(l, n):
10 | """Yield successive n-sized chunks from l."""
11 | for i in range(0, len(l), n):
12 | yield l[i:i + n]
13 |
14 | INDEX_NAME = '2wiki_paragraph'
15 | def process_line(data):
16 | item = {'id': data['id'],
17 | 'url': 'empty',
18 | 'title': data['title'],
19 | 'title_unescape': html.unescape(data['title']),
20 | 'text': data['text'],
21 | 'title_bigram': html.unescape(data['title']),
22 | 'title_unescape_bigram': html.unescape(data['title']),
23 | 'text_bigram': data['text'],
24 | 'original_json': data
25 | }
26 | # tell elasticsearch we're indexing documents
27 | return "{}\n{}".format(json.dumps({ 'index': { '_id': 'wiki-{}'.format(data['id']) } }), json.dumps(item))
28 |
29 | es = Elasticsearch(hosts="http://localhost:9200")
30 | def index_chunk(chunk):
31 | res = es.bulk(index=INDEX_NAME, body='\n'.join(chunk), timeout='100s')
32 | assert not res['errors'], res
33 |
34 | def main(args):
35 |
36 | train = json.load(open('../../../data/2wiki/train.json', "r"))
37 | dev = json.load(open('../../../data/2wiki/dev.json', "r"))
38 | test = json.load(open('../../../data/2wiki/test.json', "r"))
39 |
40 |
41 | data = {}
42 | for item in tqdm(chain(train, dev, test)):
43 | for title, sentences in item['context']:
44 | para = " ".join(sentences)
45 | data[para] = title
46 | data = [{"id": i, "text": text, "title": title} for i, (text, title) in enumerate(data.items())]
47 | json.dump(data, open('2wiki_wikipedia.json', 'w'), indent = 2)
48 | print(len(data))
49 |
50 |
51 | # make index
52 | if not args.dry:
53 | es.indices.delete(index=INDEX_NAME, ignore=[400,403])
54 | es.indices.create(index=INDEX_NAME, ignore=400,
55 | mappings = {"doc":{"properties": {
56 | "id": { "type": "keyword" },
57 | "url": { "type": "keyword" },
58 | "title": { "type": "text", "analyzer": "simple", "copy_to": "title_all"},
59 | "title_unescape": { "type": "text", "analyzer": "simple", "copy_to": "title_all"},
60 | "text": { "type": "text", "analyzer": "my_english_analyzer"},
61 | "anchortext": { "type": "text", "analyzer": "my_english_analyzer"},
62 | "title_bigram": { "type": "text", "analyzer": "simple_bigram_analyzer", "copy_to": "title_all_bigram"},
63 | "title_unescape_bigram": { "type": "text", "analyzer": "simple_bigram_analyzer", "copy_to": "title_all_bigram"},
64 | "text_bigram": { "type": "text", "analyzer": "bigram_analyzer"},
65 | "anchortext_bigram": { "type": "text", "analyzer": "bigram_analyzer"},
66 | "original_json": { "type": "string" },
67 | }}
68 | },
69 | settings = {
70 | "analysis": {
71 | "my_english_analyzer": {
72 | "type": "standard",
73 | "stopwords": "_english_",
74 | },
75 | "simple_bigram_analyzer": {
76 | "tokenizer": "standard",
77 | "filter": [
78 | "lowercase", "shingle", "asciifolding"
79 | ]
80 | },
81 | "bigram_analyzer": {
82 | "tokenizer": "standard",
83 | "filter": [
84 | "lowercase", "stop", "shingle", "asciifolding"
85 | ]
86 | }
87 | },
88 | }
89 | )
90 |
91 |
92 | wikipedia_data = json.load(open('2wiki_wikipedia.json'))
93 |
94 | print('Making indexing queries...')
95 | all_queries = []
96 | for item in tqdm(wikipedia_data):
97 | all_queries.append(process_line(item))
98 |
99 | count = sum(len(queries.split('\n')) for queries in all_queries) // 2
100 |
101 | if not args.dry:
102 | print('Indexing...')
103 | chunksize = 100
104 | for chunk in tqdm(chunks(all_queries, chunksize), total=(len(all_queries) + chunksize - 1) // chunksize):
105 | res = es.bulk(index=INDEX_NAME, body='\n'.join(chunk), timeout='100s')
106 | assert not res['errors'], res
107 |
108 | print(f"{count} documents indexed in total")
109 |
110 | if __name__ == '__main__':
111 | parser = ArgumentParser()
112 |
113 | parser.add_argument('--reindex', action='store_true', help="Reindex everything")
114 | parser.add_argument('--dry', action='store_true', help="Dry run")
115 |
116 | args = parser.parse_args()
117 |
118 | main(args)
119 |
--------------------------------------------------------------------------------
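
Each string appended to all_queries here is a two-line Elasticsearch bulk fragment: an action header naming the document _id, then the document JSON itself. A quick check with a made-up paragraph, assuming this script is importable (the elasticsearch client only needs to be installed, not reachable, at import time):

    # Minimal check of the bulk fragment format; the input paragraph is made up.
    from index_2wiki_wiki import process_line

    action_line, doc_line = process_line({"id": 0, "title": "Example", "text": "An example paragraph."}).split("\n")
    print(action_line)  # {"index": {"_id": "wiki-0"}}
    print(doc_line)     # the indexed document, including the original dict under "original_json"
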
/src/service/es/index_hotpotqa_wiki.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import bz2
3 | from elasticsearch import Elasticsearch
4 | from glob import glob
5 | import html
6 | import json
7 | from multiprocessing import Pool
8 | from tqdm import tqdm
9 |
10 | WIKIPEDIA_INDEX_NAME='wikipedia'
11 |
12 | def chunks(l, n):
13 | """Yield successive n-sized chunks from l."""
14 | for i in range(0, len(l), n):
15 | yield l[i:i + n]
16 |
17 | def process_line(line):
18 | data = json.loads(line)
19 | item = {'id': data['id'],
20 | 'url': data['url'],
21 | 'title': data['title'],
22 | 'title_unescape': html.unescape(data['title']),
23 | 'text': ''.join(data['text']),
24 | 'title_bigram': html.unescape(data['title']),
25 | 'title_unescape_bigram': html.unescape(data['title']),
26 | 'text_bigram': ''.join(data['text']),
27 | 'original_json': line
28 | }
29 | # tell elasticsearch we're indexing documents
30 | return "{}\n{}".format(json.dumps({ 'index': { '_id': 'wiki-{}'.format(data['id']) } }), json.dumps(item))
31 |
32 | def generate_indexing_queries_from_bz2(bz2file, dry=False):
33 | if dry:
34 | return
35 |
36 | with bz2.open(bz2file, 'rt') as f:
37 | body = [process_line(line) for line in f]
38 |
39 | return '\n'.join(body)
40 |
41 | # es = Elasticsearch(timeout=100)
42 | es = Elasticsearch(hosts="http://localhost:9200")
43 | def index_chunk(chunk):
44 | res = es.bulk(index=WIKIPEDIA_INDEX_NAME, body='\n'.join(chunk), timeout='100s')
45 | assert not res['errors'], res
46 |
47 | def main(args):
48 | # make index
49 | if not args.dry:
50 | if es.indices.exists(index=WIKIPEDIA_INDEX_NAME) and args.reindex:
51 | es.indices.delete(index=WIKIPEDIA_INDEX_NAME, ignore=[400,403])
52 | if not es.indices.exists(index=WIKIPEDIA_INDEX_NAME):
53 | es.indices.create(index=WIKIPEDIA_INDEX_NAME, ignore=400,
54 | body=json.dumps({
55 | "mappings":{"doc":{"properties": {
56 | "id": { "type": "keyword" },
57 | "url": { "type": "keyword" },
58 | "title": { "type": "text", "analyzer": "simple", "copy_to": "title_all"},
59 | "title_unescape": { "type": "text", "analyzer": "simple", "copy_to": "title_all"},
60 | "text": { "type": "text", "analyzer": "my_english_analyzer"},
61 | "anchortext": { "type": "text", "analyzer": "my_english_analyzer"},
62 | "title_bigram": { "type": "text", "analyzer": "simple_bigram_analyzer", "copy_to": "title_all_bigram"},
63 | "title_unescape_bigram": { "type": "text", "analyzer": "simple_bigram_analyzer", "copy_to": "title_all_bigram"},
64 | "text_bigram": { "type": "text", "analyzer": "bigram_analyzer"},
65 | "anchortext_bigram": { "type": "text", "analyzer": "bigram_analyzer"},
66 | "original_json": { "type": "string" },
67 | }}
68 | },
69 | "settings": {
70 | "analysis": {
71 | "my_english_analyzer": {
72 | "type": "standard",
73 | "stopwords": "_english_",
74 | },
75 | "simple_bigram_analyzer": {
76 | "tokenizer": "standard",
77 | "filter": [
78 | "lowercase", "shingle", "asciifolding"
79 | ]
80 | },
81 | "bigram_analyzer": {
82 | "tokenizer": "standard",
83 | "filter": [
84 | "lowercase", "stop", "shingle", "asciifolding"
85 | ]
86 | }
87 | },
88 | }
89 | }))
90 |
91 | filelist = glob('../../../data/enwiki-20171001-pages-meta-current-withlinks-abstracts/*/wiki_*.bz2')
92 |
93 | print('Making indexing queries...')
94 | pool = Pool()
95 | all_queries = list(tqdm(pool.imap(generate_indexing_queries_from_bz2, filelist), total=len(filelist)))
96 |
97 | count = sum(len(queries.split('\n')) for queries in all_queries) // 2
98 |
99 | if not args.dry:
100 | print('Indexing...')
101 | chunksize = 50
102 | for chunk in tqdm(chunks(all_queries, chunksize), total=(len(all_queries) + chunksize - 1) // chunksize):
103 | res = es.bulk(index=WIKIPEDIA_INDEX_NAME, body='\n'.join(chunk), timeout='100s')
104 | assert not res['errors'], res
105 |
106 | print(f"{count} documents indexed in total")
107 |
108 | if __name__ == '__main__':
109 | parser = ArgumentParser()
110 |
111 | parser.add_argument('--reindex', action='store_true', help="Reindex everything")
112 | parser.add_argument('--dry', action='store_true', help="Dry run")
113 |
114 | args = parser.parse_args()
115 |
116 | main(args)
117 |
--------------------------------------------------------------------------------
/src/service/es/index_musique_wiki.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import bz2
3 | from collections import Counter, defaultdict
4 | from elasticsearch import Elasticsearch
5 | import html
6 | import json
7 | from tqdm import tqdm
8 |
9 | from itertools import chain
10 |
11 | def chunks(l, n):
12 | """Yield successive n-sized chunks from l."""
13 | for i in range(0, len(l), n):
14 | yield l[i:i + n]
15 |
16 | INDEX_NAME = 'musique_wikipedia'
17 | def process_line(data):
18 | item = {'id': data['id'],
19 | 'url': 'empty',
20 | 'title': data['title'],
21 | 'title_unescape': html.unescape(data['title']),
22 | 'text': data['text'],
23 | 'title_bigram': html.unescape(data['title']),
24 | 'title_unescape_bigram': html.unescape(data['title']),
25 | 'text_bigram': data['text'],
26 | 'original_json': data
27 | }
28 | # tell elasticsearch we're indexing documents
29 | return "{}\n{}".format(json.dumps({ 'index': { '_id': 'wiki-{}'.format(data['id']) } }), json.dumps(item))
30 |
31 | es = Elasticsearch(hosts="http://localhost:9200")
32 | def index_chunk(chunk):
33 | res = es.bulk(index=INDEX_NAME, body='\n'.join(chunk), timeout='100s')
34 | assert not res['errors'], res
35 |
36 | def main(args):
37 | # make index
38 | if not args.dry:
39 | es.indices.delete(index=INDEX_NAME, ignore=[400,403])
40 | es.indices.create(index=INDEX_NAME, ignore=400,
41 | mappings = {"doc":{"properties": {
42 | "id": { "type": "keyword" },
43 | "url": { "type": "keyword" },
44 | "title": { "type": "text", "analyzer": "simple", "copy_to": "title_all"},
45 | "title_unescape": { "type": "text", "analyzer": "simple", "copy_to": "title_all"},
46 | "text": { "type": "text", "analyzer": "my_english_analyzer"},
47 | "anchortext": { "type": "text", "analyzer": "my_english_analyzer"},
48 | "title_bigram": { "type": "text", "analyzer": "simple_bigram_analyzer", "copy_to": "title_all_bigram"},
49 | "title_unescape_bigram": { "type": "text", "analyzer": "simple_bigram_analyzer", "copy_to": "title_all_bigram"},
50 | "text_bigram": { "type": "text", "analyzer": "bigram_analyzer"},
51 | "anchortext_bigram": { "type": "text", "analyzer": "bigram_analyzer"},
52 | "original_json": { "type": "string" },
53 | }}
54 | },
55 | settings = {
56 | "analysis": {
57 | "my_english_analyzer": {
58 | "type": "standard",
59 | "stopwords": "_english_",
60 | },
61 | "simple_bigram_analyzer": {
62 | "tokenizer": "standard",
63 | "filter": [
64 | "lowercase", "shingle", "asciifolding"
65 | ]
66 | },
67 | "bigram_analyzer": {
68 | "tokenizer": "standard",
69 | "filter": [
70 | "lowercase", "stop", "shingle", "asciifolding"
71 | ]
72 | }
73 | },
74 | }
75 | )
76 |
77 |
78 | train = [json.loads(line.strip()) for line in open('../../../data/musique/musique_ans_v0.1_train.jsonl')]
79 | dev = [json.loads(line.strip()) for line in open('../../../data/musique/musique_ans_v0.1_dev.jsonl')]
80 | test = [json.loads(line.strip()) for line in open('../../../data/musique/musique_ans_v0.1_test.jsonl')]
81 |
82 | tot = 0
83 | wikipedia_data = []
84 | hist = set()
85 | for item in tqdm(chain(train, dev, test)):
86 | for p in item['paragraphs']:
87 | stamp = p['title'] + ' ' + p['paragraph_text']
88 | if not stamp in hist:
89 | wikipedia_data.append({'id': tot, 'text': p['paragraph_text'], 'title': p['title']})
90 | hist.add(stamp)
91 | tot += 1
92 |
93 | # print(data[-1])
94 | # break
95 | json.dump(wikipedia_data, open('musique_wikipedia.json', 'w'), indent = 2)
96 |
97 | print('Making indexing queries...')
98 | all_queries = []
99 | for item in tqdm(wikipedia_data):
100 | all_queries.append(process_line(item))
101 |
102 | count = sum(len(queries.split('\n')) for queries in all_queries) // 2
103 |
104 | if not args.dry:
105 | print('Indexing...')
106 | chunksize = 100
107 | for chunk in tqdm(chunks(all_queries, chunksize), total=(len(all_queries) + chunksize - 1) // chunksize):
108 | res = es.bulk(index=INDEX_NAME, body='\n'.join(chunk), timeout='100s')
109 | assert not res['errors'], res
110 |
111 | print(f"{count} documents indexed in total")
112 |
113 | if __name__ == '__main__':
114 | parser = ArgumentParser()
115 |
116 | parser.add_argument('--reindex', action='store_true', help="Reindex everything")
117 | parser.add_argument('--dry', action='store_true', help="Dry run")
118 |
119 | args = parser.parse_args()
120 |
121 | main(args)
122 |
--------------------------------------------------------------------------------
/src/service/es/run_2wiki_index.py:
--------------------------------------------------------------------------------
1 | from multiprocessing import Pool
2 | import json
3 | import os, time, random
4 | from elasticsearch import Elasticsearch
5 | from elasticsearch import helpers
6 | import dill
7 | import json
8 | from tqdm import tqdm
9 | from flask import Flask
10 | from flask import request
11 | from flask_cors import CORS, cross_origin
12 | from termcolor import colored
13 | import re
14 |
15 |
16 | app = Flask(__name__)
17 | cors = CORS(app)
18 | app.config['CORS_HEADERS'] = 'Content-Type'
19 |
20 |
21 | WIKIPEDIA_INDEX_NAME='2wiki_paragraph_3'
22 |
23 | core_title_matcher = re.compile('([^()]+[^\s()])(?:\s*\(.+\))?')
24 | core_title_filter = lambda x: core_title_matcher.match(x).group(1) if core_title_matcher.match(x) else x
25 |
26 | class ElasticSearch:
27 | def __init__(self):
28 | self.client = Elasticsearch("http://localhost:9200")
29 |
30 | def _extract_one(self, item, lazy=False):
31 | res = {k: item['_source'][k] for k in ['id', 'url', 'title', 'text', 'title_unescape']}
32 | res['_score'] = item['_score']
33 | #res['data_object'] = item['_source']['original_json'] if lazy else json.loads(item['_source']['original_json'])
34 | return res
35 |
36 |
37 | def rerank_with_query(self, query, results):
38 | def score_boost(item, query):
39 | score = item['_score']
40 | core_title = core_title_filter(item['title_unescape'])
41 | if query.startswith('The ') or query.startswith('the '):
42 | query1 = query[4:]
43 | else:
44 | query1 = query
45 | if query == item['title_unescape'] or query1 == item['title_unescape']:
46 | score *= 1.5
47 | elif query.lower() == item['title_unescape'].lower() or query1.lower() == item['title_unescape'].lower():
48 | score *= 1.2
49 | elif item['title'].lower() in query:
50 | score *= 1.1
51 | elif query == core_title or query1 == core_title:
52 | score *= 1.2
53 | elif query.lower() == core_title.lower() or query1.lower() == core_title.lower():
54 | score *= 1.1
55 | elif core_title.lower() in query.lower():
56 | score *= 1.05
57 |
58 | item['_score'] = score
59 | return item
60 |
61 | return list(sorted([score_boost(item, query) for item in results], key=lambda item: -item['_score']))
62 |
63 | def single_text_query(self, query, topn=10, lazy=False, rerank_topn=50):
64 | constructed_query = {
65 | "multi_match": {
66 | "query": query,
67 | "fields": ["title^1.25", "title_unescape^1.25", "text", "title_bigram^1.25", "title_unescape_bigram^1.25", "text_bigram"]
68 | }
69 | }
70 | res = self.client.search(index=WIKIPEDIA_INDEX_NAME, query = constructed_query, size = max(topn, rerank_topn))
71 |
72 | res = [self._extract_one(x, lazy=lazy) for x in res['hits']['hits']]
73 | res = self.rerank_with_query(query, res)[:topn]
74 | # print(res)
75 | res = [{'title': _['title'], 'paragraph_text': _['text']} for _ in res]
76 | return res
77 |
78 | def search(self, question, k=10):
79 | try:
80 | res = self.single_text_query(query = question, topn = k)
81 | return json.dumps(res, ensure_ascii=False)
82 | except Exception as err:
83 | print(Exception, err)
84 | raise
85 |
86 |
87 | @app.route('/', methods=['POST', 'GET'])
88 | @cross_origin()
89 | def main():
90 | global ES
91 | question = request.json['query']
92 | k = int(request.json['k'])
93 | return ES.search(question, k)
94 |
95 |
96 | if __name__ == '__main__':
97 | ES = ElasticSearch()
98 | app.run(host='0.0.0.0', port=1440, threaded = True)
99 |
100 |
101 |
--------------------------------------------------------------------------------
/src/service/es/run_hotpotqa_index.py:
--------------------------------------------------------------------------------
1 | from multiprocessing import Pool
2 | import json
3 | import os, time, random
4 | from elasticsearch import Elasticsearch
5 | import json
6 | from flask import Flask
7 | from flask import request
8 | from flask_cors import CORS, cross_origin
9 | import re
10 |
11 |
12 | app = Flask(__name__)
13 | cors = CORS(app)
14 | app.config['CORS_HEADERS'] = 'Content-Type'
15 |
16 |
17 | WIKIPEDIA_INDEX_NAME='wikipedia'
18 |
19 | core_title_matcher = re.compile('([^()]+[^\s()])(?:\s*\(.+\))?')
20 | core_title_filter = lambda x: core_title_matcher.match(x).group(1) if core_title_matcher.match(x) else x
21 |
22 | class ElasticSearch:
23 | def __init__(self):
24 | self.client = Elasticsearch(timeout=300,hosts="http://127.0.0.1:9200")
25 |
26 | def _extract_one(self, item, lazy=False):
27 | res = {k: item['_source'][k] for k in ['id', 'url', 'title', 'text', 'title_unescape']}
28 | res['_score'] = item['_score']
29 | res['data_object'] = item['_source']['original_json'] if lazy else json.loads(item['_source']['original_json'])
30 | return res
31 |
38 | def rerank_with_query(self, query, results):
39 | def score_boost(item, query):
40 | score = item['_score']
41 | core_title = core_title_filter(item['title_unescape'])
42 | if query.startswith('The ') or query.startswith('the '):
43 | query1 = query[4:]
44 | else:
45 | query1 = query
46 | if query == item['title_unescape'] or query1 == item['title_unescape']:
47 | score *= 1.5
48 | elif query.lower() == item['title_unescape'].lower() or query1.lower() == item['title_unescape'].lower():
49 | score *= 1.2
50 | elif item['title'].lower() in query:
51 | score *= 1.1
52 | elif query == core_title or query1 == core_title:
53 | score *= 1.2
54 | elif query.lower() == core_title.lower() or query1.lower() == core_title.lower():
55 | score *= 1.1
56 | elif core_title.lower() in query.lower():
57 | score *= 1.05
58 |
59 | item['_score'] = score
60 | return item
61 |
62 | return list(sorted([score_boost(item, query) for item in results], key=lambda item: -item['_score']))
63 |
64 | def single_text_query(self, query, topn=10, lazy=False, rerank_topn=50):
65 | constructed_query = {
66 | "multi_match": {
67 | "query": query,
68 | "fields": ["title^1.25", "title_unescape^1.25", "text", "title_bigram^1.25", "title_unescape_bigram^1.25", "text_bigram"]
69 | }
70 | }
71 | res = self.client.search(index=WIKIPEDIA_INDEX_NAME, query = constructed_query, size = max(topn, rerank_topn))
72 |
73 | res = [self._extract_one(x, lazy=lazy) for x in res['hits']['hits']]
74 | res = self.rerank_with_query(query, res)[:topn]
75 | # print(res)
76 | res = [{'title': _['title'], 'text': _['text']} for _ in res]
77 | return res
78 |
79 | def search(self, question, k=10):
80 | try:
81 | res = self.single_text_query(query = question, topn = k)
82 | return json.dumps(res, ensure_ascii=False)
83 | except Exception as err:
84 | print(Exception, err)
85 |
86 |
87 | @app.route('/', methods=['POST', 'GET'])
88 | @cross_origin()
89 | def main():
90 | global ES
91 | question = request.json['query']
92 | k = int(request.json['k'])
93 | return ES.search(question, k)
94 |
95 |
96 | if __name__ == '__main__':
97 | ES = ElasticSearch()
98 | app.run(host='0.0.0.0', port=1439, threaded = True)
--------------------------------------------------------------------------------
/src/service/es/run_musique_indx.py:
--------------------------------------------------------------------------------
1 | from multiprocessing import Pool
2 | import json
3 | import os, time, random
4 | from elasticsearch import Elasticsearch
5 | from elasticsearch import helpers
6 | import dill
7 | import json
8 | from tqdm import tqdm
9 | from flask import Flask
10 | from flask import request
11 | from flask_cors import CORS, cross_origin
12 | from termcolor import colored
13 | import re
14 |
15 |
16 | app = Flask(__name__)
17 | cors = CORS(app)
18 | app.config['CORS_HEADERS'] = 'Content-Type'
19 |
20 |
21 | WIKIPEDIA_INDEX_NAME='musique_wikipedia_2'
22 |
23 | core_title_matcher = re.compile('([^()]+[^\s()])(?:\s*\(.+\))?')
24 | core_title_filter = lambda x: core_title_matcher.match(x).group(1) if core_title_matcher.match(x) else x
25 |
26 | class ElasticSearch:
27 | def __init__(self):
28 | self.client = Elasticsearch("http://localhost:9200")
29 |
30 | def _extract_one(self, item, lazy=False):
31 | res = {k: item['_source'][k] for k in ['id', 'url', 'title', 'text', 'title_unescape']}
32 | res['_score'] = item['_score']
33 | #res['data_object'] = item['_source']['original_json'] if lazy else json.loads(item['_source']['original_json'])
34 | return res
35 |
36 |
37 | def rerank_with_query(self, query, results):
38 | def score_boost(item, query):
39 | score = item['_score']
40 | core_title = core_title_filter(item['title_unescape'])
41 | if query.startswith('The ') or query.startswith('the '):
42 | query1 = query[4:]
43 | else:
44 | query1 = query
45 | if query == item['title_unescape'] or query1 == item['title_unescape']:
46 | score *= 1.5
47 | elif query.lower() == item['title_unescape'].lower() or query1.lower() == item['title_unescape'].lower():
48 | score *= 1.2
49 | elif item['title'].lower() in query:
50 | score *= 1.1
51 | elif query == core_title or query1 == core_title:
52 | score *= 1.2
53 | elif query.lower() == core_title.lower() or query1.lower() == core_title.lower():
54 | score *= 1.1
55 | elif core_title.lower() in query.lower():
56 | score *= 1.05
57 |
58 | item['_score'] = score
59 | return item
60 |
61 | return list(sorted([score_boost(item, query) for item in results], key=lambda item: -item['_score']))
62 |
63 | def single_text_query(self, query, topn=10, lazy=False, rerank_topn=50):
64 | constructed_query = {
65 | "multi_match": {
66 | "query": query,
67 | "fields": ["title^1.25", "title_unescape^1.25", "text", "title_bigram^1.25", "title_unescape_bigram^1.25", "text_bigram"]
68 | }
69 | }
70 | res = self.client.search(index=WIKIPEDIA_INDEX_NAME, query = constructed_query, size = max(topn, rerank_topn))
71 |
72 | res = [self._extract_one(x, lazy=lazy) for x in res['hits']['hits']]
73 | res = self.rerank_with_query(query, res)[:topn]
74 | # print(res)
75 | res = [{'title': _['title'], 'paragraph_text': _['text']} for _ in res]
76 | return res
77 |
78 | def search(self, question, k=10):
79 | try:
80 | res = self.single_text_query(query = question, topn = k)
81 | return json.dumps(res, ensure_ascii=False)
82 | except Exception as err:
83 | print(Exception, err)
84 | raise
85 |
86 |
87 | @app.route('/', methods=['POST', 'GET'])
88 | @cross_origin()
89 | def main():
90 | global ES
91 | question = request.json['query']
92 | k = int(request.json['k'])
93 | return ES.search(question, k)
94 |
95 |
96 | if __name__ == '__main__':
97 | ES = ElasticSearch()
98 | app.run(host='0.0.0.0', port=1435, threaded = True)
99 |
100 | '''
101 | if __name__=='__main__':
102 | ES = ElasticSearch()
103 | question = "The Hobbit >> part of the series"
104 | contexts = ES.search(question)
105 | print(contexts)
106 | '''
107 |
--------------------------------------------------------------------------------
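
This is the service that bm25_search in src/musique/RoHT/question_answering.py calls. A minimal direct query, assuming the musique index has been built and the service above is running locally on port 1435:

    import requests

    # Minimal client call against the musique retrieval service (assumes it is running on port 1435).
    resp = requests.get("http://127.0.0.1:1435", json={"query": "Which country is Hoora in?", "k": 5})
    for hit in resp.json():
        print(hit["title"], "::", hit["paragraph_text"][:80])
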
/src/service/openai/openai_service.py:
--------------------------------------------------------------------------------
1 | import openai, json, os, sys
2 | from flask import Flask, request, jsonify, abort
3 | from datetime import datetime, timedelta, timezone
4 | import openai.error
5 |
6 | app = Flask(__name__)
7 |
8 | key_pool = [
9 | #put your keys here
10 | ]
11 |
12 |
13 | print(*key_pool,sep="\n")
14 |
15 | class Log:
16 | @staticmethod
17 | def time_str():
18 | current = datetime.now(timezone(timedelta(hours=8)))
19 | return current.strftime("%Y-%m-%d %H:%M:%S")
20 |
21 | @staticmethod
22 | def log(file_name, log_type, content):
23 | content = "[%s] %s | %s"%(log_type, Log.time_str(), str(content).replace("\n", "\n "))
24 | with open(file_name, "a") as f:
25 | f.write("%s\n"%content)
26 |
27 | @staticmethod
28 | def message(file_name, content):
29 | return Log.log( file_name, "MSG",content)
30 |
31 | @staticmethod
32 | def error(file_name, content):
33 | return Log.log(file_name,"ERR", content)
34 |
35 | @staticmethod
36 | def warning(file_name, content):
37 | return Log.log(file_name, "WRN", content)
38 |
39 |
40 | # You exceeded your current quota
41 |
42 | current_key = 0
43 | def next_key():
44 | global current_key
45 | current_key += 1
46 | current_key %= len(key_pool)
47 | return key_pool[current_key]
48 |
49 | from collections import deque
50 | from datetime import datetime, timedelta
51 | import threading
52 |
53 | # Deque to store request timestamps
54 | timestamps = deque()
55 |
56 | # Lock for thread safety
57 | lock = threading.Lock()
58 | def update_speed():
59 | now = datetime.now()
60 | with lock:
61 | timestamps.append(now)
62 | # Remove timestamps older than 5 seconds
63 | while timestamps and now - timestamps[0] > timedelta(seconds=5):
64 | timestamps.popleft()
65 | # Calculate the average request rate
66 | rate = len(timestamps) / 5
67 | print("Current request rate in the latest 5 seconds:", rate, "req/s")
68 |
69 | @app.route('/api/openai/freq', methods = ["GET"])
70 | def freq():
71 | now = datetime.now()
72 | with lock:
73 | while timestamps and now - timestamps[0] > timedelta(seconds=5):
74 | timestamps.popleft()
75 | return jsonify({
76 | "message": "This is the request rate (requests/second) in the latest 5 seconds.",
77 | "request_rate": len(timestamps) / 5,
78 | "availabel_keys": len(key_pool)
79 | })
80 |
81 | @app.route('/api/openai/chat-completion', methods = ["POST"])
82 | def openai_chat_completion():
83 | key = next_key()
84 | tt = datetime.now(timezone(timedelta(hours=8)))
85 | day = tt.strftime("%Y-%m-%d")
86 | hour = tt.strftime("%H")
87 | log_dir = "log/%s/%s"%(day, hour)
88 | log_msg_path = os.path.join(log_dir, "chat-completion.log")
89 | log_data_path = os.path.join(log_dir, "chat-completion.jsonl")
90 | os.makedirs(log_dir, exist_ok=True)
91 | try:
92 | resp = openai.ChatCompletion.create(**request.json, api_key=key, timeout=20)
93 | except openai.error.OpenAIError as e:
94 | Log.error(log_msg_path, str(e))
95 | print("[Error] %s"%(str(e)))
96 | if str(e).find("You exceeded your current quota") != -1 or str(e).find("deactivate") != -1:
97 | key_pool.remove(key)
98 | Log.error("log/exceed.log", key)
99 | return abort(500, str(e))
100 | except Exception as e:
101 | Log.error(log_msg_path, str(e))
102 | print("[Error] %s"%(str(e)))
103 | return abort(500, str(e))
104 | Log.message(log_msg_path, "Successful")
105 | with open(log_data_path, "a+") as f:
106 | f.write("%s\n"%(json.dumps({
107 | "request": request.json,
108 | "response": resp
109 | })))
110 | update_speed()
111 | return jsonify(resp)
112 |
113 | @app.route('/api/openai/completion', methods = ["POST"])
114 | def openai_completion():
115 | key = next_key()
116 | tt = datetime.now(timezone(timedelta(hours=8)))
117 | day = tt.strftime("%Y-%m-%d")
118 | hour = tt.strftime("%H")
119 | log_dir = "log/%s/%s"%(day, hour)
120 | log_msg_path = os.path.join(log_dir, "completion.log")
121 | log_data_path = os.path.join(log_dir, "completion.jsonl")
122 | os.makedirs(log_dir, exist_ok=True)
123 | try:
124 | resp = openai.Completion.create(**request.json, api_key=key, timeout=20)
125 | except openai.error.OpenAIError as e:
126 | Log.error(log_msg_path, str(e))
127 | print("[Error] %s"%(str(e)))
128 | if str(e).find("You exceeded your current quota") != -1 or str(e).find("deactivate") != -1:
129 | key_pool.remove(key)
130 | Log.error("log/exceed.log", key)
131 | return abort(500, str(e))
132 | except Exception as e:
133 | Log.error(log_msg_path, str(e))
134 | print("[Error] %s"%(str(e)))
135 | return abort(500, str(e))
136 | Log.message(log_msg_path, "Succeesful")
137 | with open(log_data_path, "a+") as f:
138 | f.write("%s\n"%(json.dumps({
139 | "request": request.json,
140 | "response": resp
141 | })))
142 | update_speed()
143 | return jsonify(resp)
144 |
145 | if __name__ == '__main__':
146 | app.run("0.0.0.0", port=10001)
147 |
--------------------------------------------------------------------------------
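
OpenaiReq in the RoHT and Tree_Generation code posts to the /api/openai/completion route above. A minimal direct request, assuming the service is running on port 10001 with at least one valid key in key_pool:

    import requests

    # Minimal request to the completion proxy (assumes it is running on localhost:10001 with a valid key).
    resp = requests.post("http://127.0.0.1:10001/api/openai/completion", json={
        "model": "text-davinci-003",
        "prompt": "Q: Which country is Hoora in?\nA:",
        "temperature": 0,
        "max_tokens": 64,
        "stop": "\n\n",
        "logprobs": 1,
    })
    print(resp.status_code)
    print(resp.json()["choices"][0]["text"])
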