├── README.md ├── corpus ├── cot.py ├── kg-to-text.py ├── pa_ablation.py ├── pa_construct_chatgpt.py ├── pa_construct_mistral.py ├── pa_filter.py └── summary.py ├── figs └── 1.png ├── inference ├── closed │ ├── answer │ │ └── answer.py │ └── rewrite │ │ ├── infer_chain.py │ │ ├── infer_pa.py │ │ ├── infer_summary.py │ │ └── infer_text.py └── open │ ├── answer │ ├── answer.py │ └── answer_no.py │ ├── linearize.py │ ├── process_freebase.py │ ├── query_interface.py │ ├── retrieve │ ├── 2hop │ │ ├── 2hop.py │ │ ├── format.py │ │ ├── format │ │ │ ├── GraphQuestions.json │ │ │ └── grailqa.json │ │ ├── query_interface.py │ │ └── sim_compute.py │ └── bm25 │ │ ├── build_index_sparse.sh │ │ ├── format.py │ │ ├── format │ │ ├── GraphQuestions.json │ │ └── grailqa.json │ │ ├── run_search_sparse.sh │ │ └── search.py │ └── rewrite │ ├── infer_chain.py │ ├── infer_pa.py │ ├── infer_summary.py │ └── infer_text.py ├── instruction-tuning ├── build_dataset.py ├── ds_zero2_no_offload.json ├── merge.py ├── run_clm_sft_with_peft-7b.py ├── run_clm_sft_with_peft-8b.py ├── run_dpo-step.sh ├── run_dpo.py ├── run_dpo.sh ├── run_llama-7b.sh └── run_llama-8b.sh ├── requirement1.txt ├── requirement2.txt └── subgraph ├── GraphQuestions ├── gold │ └── test.json ├── graph_query.py ├── query_interface.py └── sparql_utils │ ├── load_kb.py │ ├── misc.py │ ├── sparql_engine.py │ ├── sparql_executor.py │ └── value_class.py ├── gold_graph.py └── grailqa ├── gold └── test.json ├── graph_query.py ├── query_interface.py └── sparql_utils ├── load_kb.py ├── misc.py ├── sparql_engine.py ├── sparql_executor.py └── value_class.py /corpus/cot.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import copy 4 | import time 5 | import random 6 | from openai import OpenAI 7 | from tqdm import tqdm 8 | import os 9 | 10 | client=OpenAI(api_key='YOUR KEY') 11 | 12 | interval=200 13 | DATA='GraphQuestions' 14 | # set EX_RATE 15 | if DATA in ['GraphQuestions','WebQSP']: 16 | EX_RATE=1 17 | if DATA in ['grailqa']: 18 | EX_RATE=0.5 19 | 20 | train=json.load(open('../subgraph/'+DATA+'/graph/train.json','r',encoding='utf-8')) 21 | 22 | os.makedirs(DATA+'/finetune/'+DATA+'/CoT/train/',exist_ok=True) 23 | os.makedirs(DATA+'/finetune/'+DATA+'/CoT/middle/',exist_ok=True) 24 | 25 | kr_prompt='''Your task is to summarize the relevant information that is helpful to answer the question from the following triples. Please think step by step and iteratively generate the reasoning chain and the corresponding knowledge. 26 | Triples: (Oxybutynin Oral, medicine.routed_drug.route_of_administration, Oral administration) (Oxybutynin Oral, medicine.routed_drug.marketed_formulations, Oxybutynin chloride 5 extended release film coated tablet) (Oxybutynin Chloride Oral, medicine.routed_drug.marketed_formulations, Oxybutynin chloride 5 extended release film coated tablet) (Oxybutynin chloride 5 extended release film coated tablet, medicine.drug_formulation.formulation_of, Oxybutynin) 27 | Question: oxybutynin chloride 5 extended release film coated tablet is the ingredients of what routed drug? 28 | Reason 1: I need to know which routed drug contains oxybutynin chloride 5 extended release film coated tablet. 29 | Knowledge 1: "Oxybutynin Chloride Oral" is a type of routed drug and "Oxybutynin chloride 5 extended release film coated tablet" is one of the marketed formulations of "Oxybutynin Chloride Oral". 
30 | 31 | Your task is to summarize the relevant information that is helpful to answer the question from the following triples. Please think step by step and iteratively generate the reasoning chain and the corresponding knowledge. 32 | Triples: (Google, organization.organization.founders, Sergey Brin) (Sergey Brin, people.person.education, CVT1) (CVT1, education.education.institution, University of Maryland, College Park) (Google, organization.organization.founders, Larry Page) (Larry Page, people.person.education, CVT2) (CVT2, education.education.institution, University of Michigan) (CVT2, education.education.institution, Stanford University) 33 | Question: where did the founder of google go to college? 34 | Reason 1: I need to know who the founders of Google are. 35 | Knowledge 1: The founders of Google are Sergey Brin and Larry Page. 36 | Reason 2: I need to know where Sergey Brin and Larry Page went to college. 37 | Knowledge 2: Sergey Brin studied at the University of Maryland, College Park. Larry Page studied at the University of Michigan and Stanford University. 38 | 39 | Your task is to summarize the relevant information that is helpful to answer the question from the following triples. Please think step by step and iteratively generate the reasoning chain and the corresponding knowledge. 40 | Triples: (Rock music, music.genre.artists, Outkast) (Rock music, music.genre.parent_genre, Folk music) (Rock music, music.genre.albums, The Confessions Tour) (Electronica, music.genre.artists, Bright Eyes) (Electronica, music.genre.parent_genre, House music) (Electronica, music.genre.albums, The Confessions Tour) (Electronica, music.genre.artists, t.A.T.u.) 41 | Question: the albums confessions tour is part of what parent genre of a musical genre? 42 | Reason 1: I need to know the musical genre of the albums confessions tour. 43 | Knowledge 1: The album confessions tour is associated with the genres Rock music and Electronica. 44 | Reason 2: I need to know the parent genre of Rock music and Electronica. 45 | Knowledge 2: The parent genre of Rock music is Folk music. The parent genre of Electronica is House music. 46 | 47 | Your task is to summarize the relevant information that is helpful to answer the question from the following triples. Please think step by step and iteratively generate the reasoning chain and the corresponding knowledge. 48 | Triples: {triple} 49 | Question: {ques} 50 | ''' 51 | 52 | kr_prompt1='''Your task is to summarize the relevant information that is helpful to answer the question from the following triples. Please think step by step and iteratively generate the reasoning chain and the corresponding knowledge.
53 | Triples: {triple} 54 | Question: {ques} 55 | ''' 56 | 57 | ans_prompt='''Below are the facts that might be relevant to answer the question: 58 | {knowledge} 59 | Question: {ques} 60 | Answer:''' 61 | 62 | num_dict = { 63 | '0': 'zero', '1': 'one', '2': 'two', '3': 'three', '4': 'four', 64 | '5': 'five', '6': 'six', '7': 'seven', '8': 'eight', '9': 'nine' 65 | } 66 | 67 | def getResponse(prompt,max_retries=10): 68 | # set retries 69 | retries=0 70 | while retries < max_retries: 71 | try: 72 | res = client.chat.completions.create( 73 | model='gpt-3.5-turbo', 74 | #model='gpt-4', 75 | messages=[ 76 | {'role': 'user', 'content': prompt} 77 | ], 78 | temperature=0, 79 | ) 80 | return res.choices[0].message.content 81 | except Exception as e: 82 | print(f"An error occurred: {e}") 83 | print("Retrying in 1 minutes...") 84 | retries += 1 85 | time.sleep(60) 86 | return '' 87 | 88 | data=[] 89 | resume=0 90 | #data=json.load(open('train-'+str(resume)+'.json','r',encoding='utf-8')) 91 | index=resume 92 | for sample in tqdm(train[resume:]): 93 | index+=1 94 | # gold graph 95 | gold_g=set() 96 | for i in sample['restrict_graph']: 97 | for j in i: 98 | temp='('+j[0]+', '+j[1]+', '+j[2]+')' 99 | gold_g.add(temp) 100 | # shuffle gold graph 101 | gold_g=list(gold_g) 102 | random.shuffle(gold_g) 103 | 104 | # extend graph 105 | extend=set() 106 | for i in sample["ex_graph"]: 107 | for j in i: 108 | temp='('+j[0]+', '+j[1]+', '+j[2]+')' 109 | if temp not in gold_g: 110 | extend.add(temp) 111 | extend=list(extend) 112 | random.shuffle(extend) 113 | 114 | # extend number filter 115 | ex_filter=set() 116 | NUM=math.ceil(len(gold_g)*EX_RATE) 117 | # first use no CVT triple 118 | for i in extend: 119 | if 'CVT' not in i: 120 | ex_filter.add(i) 121 | if len(ex_filter)==NUM: 122 | break 123 | # add CVT triple 124 | if len(ex_filter)<NUM: 44 | if num1>num2: 45 | temp=dict() 46 | temp['prompt']=kr_prompt.format(triple=sample["noisy"],ques=sample["question"]) 47 | temp['chosen']=sample["output_list"][0] 48 | temp["rejected"]=sample["output_list"][1] 49 | ablation.append(temp) 50 | if num1<num2: 21 | repeated_words=[word for word, count in word_counts.items() if count > threshold] 22 | 23 | if len(repeated_words)>threshold: 24 | return True 25 | else: 26 | return False 27 | 28 | return repeated_words 29 | 30 | # filter too long sequence 31 | data1=[] 32 | num=0 33 | for sample in tqdm(data): 34 | 35 | p_l=len(tokenizer(sample['prompt'],return_tensors="pt")["input_ids"][0]) 36 | c_l=len(tokenizer(sample['chosen'],return_tensors="pt")["input_ids"][0]) 37 | r_l=len(tokenizer(sample['rejected'],return_tensors="pt")["input_ids"][0]) 38 | t_l=max(c_l,r_l) 39 | s_l=p_l+t_l 40 | 41 | if p_l>1024 or c_l>512 or r_l>1024 or s_l>2048: 42 | continue 43 | 44 | chosen=sample['chosen'].strip().split('\n') 45 | if len(chosen)>6: 46 | continue 47 | FLAG=True 48 | for index,line in enumerate(chosen): 49 | if index%2==0: 50 | if not line.startswith('Reason '): 51 | FLAG=False 52 | break 53 | else: 54 | if not line.startswith('Knowledge '): 55 | FLAG=False 56 | break 57 | if FLAG: 58 | data1.append(sample) 59 | 60 | #if detect_repeated_text(sample['chosen']): 61 | # print(sample) 62 | 63 | print(len(data1)) 64 | 65 | # divide into train and dev 66 | random.shuffle(data1) 67 | train_num=int(len(data1)*0.9) 68 | json.dump(data1[:train_num],open(DATA+'/PA-Mistral/CoT/'+LLM+'/train.json','w',encoding='utf-8'),ensure_ascii=False,indent=2) 69 | json.dump(data1[train_num:],open(DATA+'/PA-Mistral/CoT/'+LLM+'/dev.json','w',encoding='utf-8'),ensure_ascii=False,indent=2) 70 |
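The corpus scripts above share one noise-injection step: gold triples from restrict_graph are mixed with at most ceil(|gold| * EX_RATE) distractor triples from ex_graph, preferring non-CVT triples first. A minimal runnable sketch of just that step, on a toy sample whose field names mirror the ones above (the sample contents themselves are illustrative, not taken from the datasets):

import math, random

sample = {
    "restrict_graph": [[["Google", "organization.organization.founders", "Larry Page"]]],
    "ex_graph": [[["Google", "organization.organization.founders", "Sergey Brin"],
                  ["Google", "common.topic.notable_types", "CVT1"]]],
}
EX_RATE = 1  # GraphQuestions/WebQSP use 1, grailqa uses 0.5 (see cot.py/summary.py)

gold = list({'(' + ', '.join(t) + ')' for g in sample["restrict_graph"] for t in g})
extend = ['(' + ', '.join(t) + ')' for g in sample["ex_graph"] for t in g]
extend = [t for t in extend if t not in gold]
random.shuffle(extend)

NUM = math.ceil(len(gold) * EX_RATE)
# prefer non-CVT distractors, then top up with CVT ones, mirroring the scripts
noise = [t for t in extend if 'CVT' not in t][:NUM]
for t in extend:
    if len(noise) >= NUM:
        break
    if t not in noise:
        noise.append(t)
mixed = gold + noise
random.shuffle(mixed)
print(' '.join(mixed))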
-------------------------------------------------------------------------------- /corpus/summary.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import copy 4 | import time 5 | import random 6 | from openai import OpenAI 7 | from tqdm import tqdm 8 | import os 9 | 10 | interval=500 11 | DATA='GraphQuestions' 12 | # set EX_RATE 13 | if DATA in ['GraphQuestions','WebQSP']: 14 | EX_RATE=1 15 | if DATA in ['grailqa']: 16 | EX_RATE=0.5 17 | 18 | # set client 19 | client=OpenAI(api_key='YOUR KEY') 20 | 21 | train=json.load(open('../subgraph/'+DATA+'/graph/train.json','r',encoding='utf-8')) 22 | 23 | os.makedirs(DATA+'/finetune/'+DATA+'/summary/train/',exist_ok=True) 24 | os.makedirs(DATA+'/finetune/'+DATA+'/summary/middle/',exist_ok=True) 25 | 26 | kr_prompt='''Your task is to summarize the relevant knowledge that is helpful to answer the question from the following triples. 27 | Triples: (Oxybutynin Oral, medicine.routed_drug.route_of_administration, Oral administration) (Oxybutynin Oral, medicine.routed_drug.marketed_formulations, Oxybutynin chloride 5 extended release film coated tablet) (Oxybutynin Chloride Oral, medicine.routed_drug.marketed_formulations, Oxybutynin chloride 5 extended release film coated tablet) (Oxybutynin chloride 5 extended release film coated tablet, medicine.drug_formulation.formulation_of, Oxybutynin) 28 | Question: oxybutynin chloride 5 extended release film coated tablet is the ingredients of what routed drug? 29 | Knowledge: "Oxybutynin Chloride Oral" is a type of routed drug and "Oxybutynin chloride 5 extended release film coated tablet" is one of the marketed formulations of "Oxybutynin Chloride Oral". 30 | 31 | Your task is to summarize the relevant knowledge that is helpful to answer the question from the following triples. 32 | Triples: (Google, organization.organization.founders, Sergey Brin) (Sergey Brin, people.person.education, CVT1) (CVT1, education.education.institution, University of Maryland, College Park) (Google, organization.organization.founders, Larry Page) (Larry Page, people.person.education, CVT2) (CVT2, education.education.institution, University of Michigan) (CVT2, education.education.institution, Stanford University) 33 | Question: where did the founder of google go to college? 34 | Knowledge: The founders of Google are Sergey Brin and Larry Page. Sergey Brin attended the University of Maryland, College Park for his education, while Larry Page attended both the University of Michigan and Stanford University. 35 | 36 | Your task is to summarize the relevant knowledge that is helpful to answer the question from the following triples. 37 | Triples: (Rock music, music.genre.artists, Outkast) (Rock music, music.genre.parent_genre, Folk music) (Rock music, music.genre.albums, The Confessions Tour) (Electronica, music.genre.artists, Bright Eyes) (Electronica, music.genre.parent_genre, House music) (Electronica, music.genre.albums, The Confessions Tour) (Electronica, music.genre.artists, t.A.T.u.) 38 | Question: the albums confessions tour is part of what parent genre of a musical genre? 39 | Knowledge: The album "The Confessions Tour" is associated with both the Rock music and Electronica genres. Folk music is the parent genre of Rock music, while House music is the parent genre of Electronica. 40 | 41 | Your task is to summarize the relevant knowledge that is helpful to answer the question from the following triples.
42 | Triples: {triple} 43 | Question: {ques} 44 | Knowledge: ''' 45 | 46 | kr_prompt1='''Your task is to summarize the relevant knowledge that is helpful to answer the question from the following triples. 47 | Triples: {triple} 48 | Question: {ques} 49 | Knowledge: ''' 50 | 51 | ans_prompt='''Below are the facts that might be relevant to answer the question: 52 | {knowledge} 53 | Question: {ques} 54 | Answer:''' 55 | 56 | num_dict = { 57 | '0': 'zero', '1': 'one', '2': 'two', '3': 'three', '4': 'four', 58 | '5': 'five', '6': 'six', '7': 'seven', '8': 'eight', '9': 'nine' 59 | } 60 | 61 | def getResponse(prompt,max_retries=10): 62 | # set retries 63 | retries=0 64 | while retries < max_retries: 65 | try: 66 | res = client.chat.completions.create( 67 | model='gpt-3.5-turbo', 68 | #model='gpt-4', 69 | messages=[ 70 | {'role': 'user', 'content': prompt} 71 | ], 72 | temperature=0, 73 | ) 74 | return res.choices[0].message.content 75 | except Exception as e: 76 | print(f"An error occurred: {e}") 77 | print("Retrying in 1 minutes...") 78 | retries += 1 79 | time.sleep(60) 80 | return '' 81 | 82 | data=[] 83 | resume=0 84 | #data=json.load(open('train-'+str(resume)+'.json','r',encoding='utf-8')) 85 | index=resume 86 | for sample in tqdm(train[resume:]): 87 | index+=1 88 | if index%interval==0: 89 | json.dump(data,open(DATA+'/finetune/'+DATA+'/summary/middle/all-'+str(index)+'.json','w',encoding='utf-8'),indent=2,ensure_ascii=False) 90 | # gold graph 91 | gold_g=set() 92 | for i in sample['restrict_graph']: 93 | for j in i: 94 | temp='('+j[0]+', '+j[1]+', '+j[2]+')' 95 | gold_g.add(temp) 96 | # shuffle gold graph 97 | gold_g=list(gold_g) 98 | random.shuffle(gold_g) 99 | 100 | # extend graph 101 | extend=set() 102 | for i in sample["ex_graph"]: 103 | for j in i: 104 | temp='('+j[0]+', '+j[1]+', '+j[2]+')' 105 | if temp not in gold_g: 106 | extend.add(temp) 107 | extend=list(extend) 108 | random.shuffle(extend) 109 | 110 | # extend number filter 111 | ex_filter=set() 112 | NUM=math.ceil(len(gold_g)*EX_RATE) 113 | # first use no CVT triple 114 | for i in extend: 115 | if 'CVT' not in i: 116 | ex_filter.add(i) 117 | if len(ex_filter)==NUM: 118 | break 119 | # add CVT triple 120 | if len(ex_filter)<NUM: 148 | if len(triples)>=K: 149 | break 150 | # avoid redundant triples 151 | triples1=[] 152 | for i in triples[:K]: 153 | if i not in triples1: 154 | triples1.append(i) 155 | contents=' '.join(triples1[:50]) 156 | # calculate retrieve metrics 157 | FLAG=False 158 | temp_r=0 159 | for a in answer: 160 | if a.lower() in contents.lower(): 161 | FLAG=True 162 | temp_r+=1 163 | if FLAG: 164 | accuracy+=1 165 | recall+=temp_r/len(answer) 166 | graphdict=dict() 167 | graphdict['question']=question 168 | graphdict['triples']=triples1 169 | graphdict['answers']=answer 170 | retrieve_subgraph.append(graphdict) 171 | print('*'*30,'Current Retrieve Results','*'*30) 172 | print('Accuracy:',accuracy/(index+1)) 173 | print('Recall:',recall/(index+1)) 174 | 175 | # save retrieve results 176 | os.makedirs('results', exist_ok=True) 177 | json.dump(retrieve_subgraph,open('results/'+DATA+'.json','w',encoding='utf-8'),indent=2,ensure_ascii=False) 178 | print('*'*30,'Retrieve Results','*'*30) 179 | print('Accuracy:',accuracy/len(dataset)) 180 | print('Recall:',recall/len(dataset)) -------------------------------------------------------------------------------- /inference/open/retrieve/2hop/format.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | 5 | # grailqa,
GraphQuestions, WebQSP 6 | DATA='grailqa' 7 | NUM=30 8 | 9 | def has_digit(input_string): 10 | for char in input_string: 11 | if char.isdigit(): 12 | return True 13 | return False 14 | 15 | with open('results/'+DATA+'.json', 'r') as rf: 16 | documents = json.load(rf) 17 | 18 | accuracy=0 19 | recall=0 20 | result=[] 21 | for doc in documents: 22 | question=doc["question"] 23 | answer=doc["answers"] 24 | # avoid redundant triples 25 | triplelist=doc["triples"] 26 | triplelist1=[] 27 | for i in triplelist: 28 | if len(i)>100: 29 | continue 30 | if len(i.split(', '))<3: 31 | continue 32 | rel=i.split(', ')[1] 33 | # skip relations 34 | if rel.startswith('common') or rel.startswith('type.object') or rel.startswith('freebase') or rel.endswith('type') or rel.endswith('label'): 35 | continue 36 | #print(i.split(', ')) 37 | # skip triples with too long object 38 | if i not in triplelist1: 39 | triplelist1.append(i) 40 | # construct triple string 41 | triples=' '.join(triplelist1[:NUM]) 42 | # convert mid to cvt 43 | candidate=re.findall(r'm\.[\da-zA-Z_]+', triples) 44 | candidate.extend(re.findall(r'g\.[\da-zA-Z_]+', triples)) 45 | candidate.extend(re.findall(r'n\.[\da-zA-Z_]+', triples)) 46 | cvtmid=[] 47 | for i in candidate: 48 | if has_digit(i): 49 | if i not in cvtmid: 50 | cvtmid.append(i) 51 | cvt_num=1 52 | for i in cvtmid: 53 | triples=triples.replace(i,'CVT'+str(cvt_num)) 54 | cvt_num+=1 55 | samdict=dict() 56 | samdict["question"]=question 57 | samdict["answer"]=answer 58 | samdict["triples"]=triples 59 | result.append(samdict) 60 | FLAG=False 61 | r=0 62 | for i in answer: 63 | if i in triples: 64 | FLAG=True 65 | r+=1 66 | if FLAG: 67 | accuracy+=1 68 | recall+=r/len(answer) 69 | 70 | json.dump(result,open('format/'+DATA+'.json','w',encoding='utf-8'),indent=2,ensure_ascii=False) 71 | print('Accuracy:',accuracy/len(result)) 72 | print('Recall:',recall/len(result)) 73 | -------------------------------------------------------------------------------- /inference/open/retrieve/2hop/sim_compute.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 3 | from sentence_transformers import SentenceTransformer, util 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | class Similarity: 8 | def __init__(self): 9 | # Load model 10 | self.model = SentenceTransformer('../../../../../pretrain/all-MiniLM-L6-v2',device='cuda:0') 11 | 12 | def compute(self, query, relations): 13 | embedding1 = self.model.encode(query, show_progress_bar=False,device='cuda:0',convert_to_tensor=True) 14 | embedding2 = self.model.encode(relations,batch_size=1024,show_progress_bar=False, device='cuda:0',convert_to_tensor=True) 15 | cosine_scores = util.pytorch_cos_sim(embedding1, embedding2)[0] 16 | sim_relations = list(zip(cosine_scores.tolist(), relations)) 17 | sim_relations = sorted(sim_relations, key=lambda x: x[0], reverse=True) 18 | sorted_relations = [relation for _, relation in sim_relations] 19 | return sorted_relations -------------------------------------------------------------------------------- /inference/open/retrieve/bm25/build_index_sparse.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Build the index for the general knowledge base using pyserini. 
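# Note: pyserini's JsonCollection requires each input record to provide "id" and
# "contents" fields; the records indexed here also appear to carry "mid" and
# "triples", which search.py reads back via doc.raw() (field names inferred from search.py).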
4 | 5 | Freebase="../../Freebase/processed" 6 | 7 | python -m pyserini.index.lucene \ 8 | --collection JsonCollection \ 9 | --input ../../Freebase/processed/document \ 10 | --index index \ 11 | --generator DefaultLuceneDocumentGenerator \ 12 | --threads 10 \ 13 | --storePositions --storeDocvectors --storeRaw 14 | -------------------------------------------------------------------------------- /inference/open/retrieve/bm25/format.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | 5 | DATA='grailqa' 6 | NUM=20 7 | 8 | def has_digit(input_string): 9 | for char in input_string: 10 | if char.isdigit(): 11 | return True 12 | return False 13 | 14 | with open('../../../../subgraph/'+DATA+'/data/test.json', 'r') as rf: 15 | data = json.load(rf) 16 | 17 | with open('results/'+DATA+'.json', 'r') as rf: 18 | documents = json.load(rf) 19 | 20 | result=[] 21 | for sample,doc in zip(data,documents): 22 | if DATA in ['WebQSP']: 23 | question=sample["question"] 24 | answer=sample["answername"].split('|') 25 | if DATA in ['GraphQuestions','grailqa']: 26 | question=sample["question"] 27 | answer=[] 28 | for i in sample["answer"]: 29 | if i.get("entity_name"): 30 | answer.append(i["entity_name"]) 31 | else: 32 | answer.append(i["answer_argument"]) 33 | doclist=doc["documents"] 34 | triple_str='' 35 | for d in doclist: 36 | triple_str=triple_str+d["triples"]+' ' 37 | triple_str=triple_str[:-1] 38 | # avoid redundant triples 39 | triplelist=triple_str.split(') (') 40 | triplelist[0]=triplelist[0][1:] 41 | triplelist[-1]=triplelist[-1][:-1] 42 | triplelist1=[] 43 | for i in triplelist: 44 | if len(i)>100: 45 | continue 46 | if len(i.split(', '))<3: 47 | continue 48 | rel=i.split(', ')[1] 49 | # skip relations 50 | if rel.startswith('common') or rel.startswith('type.object') or rel.startswith('freebase') or rel.endswith('type') or rel.endswith('label'): 51 | continue 52 | #print(i.split(', ')) 53 | # skip triples with too long object 54 | if i not in triplelist1: 55 | triplelist1.append(i) 56 | # construct triple string 57 | triples='' 58 | for i in triplelist1[:NUM]: 59 | triples=triples+'('+i+') ' 60 | triples=triples[:-1] 61 | # convert mid to cvt 62 | candidate=re.findall(r'm\.[\da-zA-Z_]+', triples) 63 | candidate.extend(re.findall(r'g\.[\da-zA-Z_]+', triples)) 64 | candidate.extend(re.findall(r'n\.[\da-zA-Z_]+', triples)) 65 | cvtmid=[] 66 | cvt_num=1 67 | for i in candidate: 68 | if has_digit(i): 69 | if i not in cvtmid: 70 | cvtmid.append(i) 71 | for i in cvtmid: 72 | triples=triples.replace(i,'CVT'+str(cvt_num)) 73 | cvt_num+=1 74 | samdict=dict() 75 | samdict["question"]=question 76 | samdict["answer"]=answer 77 | samdict["triples"]=triples 78 | result.append(samdict) 79 | 80 | json.dump(result,open('format/'+DATA+'.json','w',encoding='utf-8'),indent=2,ensure_ascii=False) -------------------------------------------------------------------------------- /inference/open/retrieve/bm25/run_search_sparse.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | index_name="index_bm25" 4 | dataset='grailqa' 5 | output="results/${dataset}.json" 6 | 7 | python search.py \ 8 | --dataset ${dataset} \ 9 | --query_data_path ../../../../subgraph/${dataset}/data/test.json \ 10 | --index_name ${index_name} \ 11 | --output ${output} \ 12 | --top_k 100 \ 13 | --k1 0.4 \ 14 | --b 0.4 \ 15 | --num_process 10 \ 16 | --eval \ 17 | --save 18 | 
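Before search.py below, a condensed sketch of the two pyserini calls it combines: rescoring documents already linked to the head entity, then an open BM25 search to fill the remaining slots. Both calls appear verbatim in the script; the index path, docid, and query here are placeholders.

from pyserini.search.lucene import LuceneSearcher
from pyserini.index import IndexReader

INDEX = "index"  # placeholder: the directory produced by build_index_sparse.sh

# Path 1: score a specific document against the query (used for head-entity docs).
reader = IndexReader(INDEX)
score = reader.compute_query_document_score("doc_0", "where did the founder of google go to college")

# Path 2: open BM25 search over the whole index.
searcher = LuceneSearcher(INDEX)
searcher.set_bm25(0.4, 0.4)  # k1 and b, matching run_search_sparse.sh
for hit in searcher.search("where did the founder of google go to college", k=10):
    print(hit.docid, round(hit.score, 2))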
-------------------------------------------------------------------------------- /inference/open/retrieve/bm25/search.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | 4 | from pyserini.search.lucene import LuceneSearcher 5 | import json 6 | from tqdm import tqdm 7 | import os 8 | import re 9 | import argparse 10 | import pickle 11 | import multiprocessing.pool 12 | from functools import partial 13 | from collections import defaultdict 14 | from pyserini.index import IndexReader 15 | 16 | def has_digit(input_string): 17 | for char in input_string: 18 | if char.isdigit(): 19 | return True 20 | return False 21 | 22 | class Bm25Searcher: 23 | def __init__(self, index_dir, args): 24 | self.index_dir = index_dir 25 | self.args = args 26 | self.searcher = LuceneSearcher(index_dir) 27 | self.searcher.set_bm25(args.k1, args.b) 28 | self.index_reader=IndexReader(index_dir) 29 | if len(args.ignore_string) > 0: 30 | self.ignore_list = args.ignore_string.split(',') 31 | print(f'ignore list: {self.ignore_list}') 32 | else: 33 | self.ignore_list = [] 34 | 35 | # load documents for post process 36 | ''' 37 | self.mid2doc=dict() 38 | for path in tqdm(os.listdir(self.args.documents)): 39 | with open(self.args.documents+'/'+path,'r',encoding='utf-8') as f: 40 | for line in f: 41 | try: 42 | data=json.loads(line) 43 | if self.mid2doc.get(data["mid"]) is None: 44 | self.mid2doc[data["mid"]]=[] 45 | self.mid2doc[data["mid"]].append(data["id"]) 46 | except: 47 | continue 48 | with open('mid2doc.pickle', 'wb') as f: 49 | pickle.dump(self.mid2doc, f) 50 | ''' 51 | with open('mid2doc.pickle', 'rb') as f: 52 | self.mid2doc = pickle.load(f) 53 | 54 | def perform_search(self, sample, top_k, ques_id): 55 | if self.args.dataset in ['WebQSP']: 56 | query=sample["question"] 57 | head=set() 58 | head.add(sample["headmid"]) 59 | if self.args.dataset in ['GraphQuestions','grailqa']: 60 | query=sample["question"] 61 | head=set() 62 | for n in sample["graph_query"]["nodes"]: 63 | if n["node_type"]=="entity": 64 | head.add(n["id"]) 65 | for string in self.ignore_list: 66 | query = query.replace(string, ' ') 67 | query = query.strip() 68 | 69 | # get relevant document for head entity 70 | docid=[] 71 | for i in head: 72 | if self.mid2doc.get(i): 73 | docid.extend(self.mid2doc[i]) 74 | # first search using relevant document 75 | id_score=[] 76 | for i in docid: 77 | score = self.index_reader.compute_query_document_score(i, query) 78 | id_score.append([score,i]) 79 | id_score=sorted(id_score,key=lambda x: x[0], reverse=True) 80 | documents=[] 81 | for i in id_score[:top_k]: 82 | raw_data=self.searcher.doc(i[1]) 83 | documents.append(json.loads(raw_data.raw())) 84 | 85 | # search 86 | if len(documents)<top_k: 93 | if len(documents)>=top_k: 94 | break 95 | except: 96 | continue 97 | context = dict() 98 | context['documents']=documents 99 | context['id']=ques_id 100 | return context 101 | 102 | def search_all(process_idx, num_process, searcher, args): 103 | # load dataset 104 | with open(args.query_data_path, 'r') as rf: 105 | data = json.load(rf) 106 | 107 | output_data = [] 108 | for i, data_i in tqdm(enumerate(data)): 109 | if i % num_process != process_idx: 110 | continue 111 | # search 112 | output_i = searcher.perform_search(data_i, args.top_k,i) 113 | output_data.append(output_i) 114 | return output_data 115 | 116 | def eval_top_k_one(documents, answer,top_k): 117 | recall = 0 118 | # merge into context
119 | context='' 120 | for doc in documents['documents'][:top_k]: 121 | context+=doc['triples'] 122 | for ans in answer: 123 | if ans.lower() in context.lower(): 124 | recall += 1 125 | return recall / (len(answer) + 1e-8) 126 | 127 | def eval_top_k(output_data, answers,top_k_list=[1,2,3,4,5,6,7,8,9,10]): 128 | print("*"*30,"Evaluate the Retrieval Result","*"*30) 129 | hits_dict = defaultdict(int) 130 | recall_dict = defaultdict(float) 131 | top_k_list = [k for k in top_k_list if k <= len(output_data[0]['documents'])] 132 | for documents,answer in tqdm(zip(output_data,answers)): 133 | for k in top_k_list: 134 | recall = eval_top_k_one(documents, answer,k) 135 | if recall > 0: 136 | hits_dict[k] += 1 137 | recall_dict[k] += recall 138 | for k in top_k_list: 139 | print("Top {}".format(k), 140 | "Hits: ", round(hits_dict[k] * 100 / len(output_data), 1), 141 | "Recall: ", round(recall_dict[k] * 100 / len(output_data), 1)) 142 | 143 | # argparse for root_dir, index_dir, query_data_path, output_dir 144 | parser = argparse.ArgumentParser(description='Search using pySerini') 145 | parser.add_argument("--dataset", type=str, default='WebQSP', 146 | help="KBQA dataset") 147 | parser.add_argument("--documents", type=str, default='../../Freebase/processed/document', 148 | help="documents directory") 149 | parser.add_argument("--index_name", type=str, default='Wikidata', 150 | help="directory to store the search index") 151 | parser.add_argument("--query_data_path", type=str, default='', 152 | help="directory to store the queries") 153 | parser.add_argument("--output", type=str, default='', 154 | help="directory to store the retrieved output") 155 | parser.add_argument("--num_process", type=int, default=10, 156 | help="number of processes to use for multi-threading") 157 | parser.add_argument("--top_k", type=int, default=150, 158 | help="number of passages to be retrieved for each query") 159 | parser.add_argument("--ignore_string", type=str, default="", 160 | help="string to ignore in the query, split by comma") 161 | parser.add_argument("--b", type=float, default=0.4, 162 | help="parameter of BM25") 163 | parser.add_argument("--k1", type=float, default=0.9, 164 | help="parameter of BM25") 165 | parser.add_argument("--save", action="store_true", 166 | help="whether to save the output") 167 | parser.add_argument("--eval", action="store_true", 168 | help="whether to evaluate the output") 169 | args = parser.parse_args() 170 | 171 | 172 | if __name__ == '__main__': 173 | index_dir = args.index_name 174 | searcher = Bm25Searcher(index_dir, args) 175 | 176 | num_process = args.num_process 177 | pool = multiprocessing.pool.ThreadPool(processes=num_process) 178 | sampleData = [x for x in range(num_process)] 179 | search_all_part = partial(search_all, 180 | searcher = searcher, 181 | num_process = num_process, 182 | args = args) 183 | results = pool.map(search_all_part, sampleData) 184 | pool.close() 185 | 186 | output_data = [] 187 | for result in results: 188 | output_data += result 189 | 190 | # sort the output data by question id 191 | output_data = sorted(output_data, key=lambda item: item['id']) 192 | if args.eval: 193 | # load answer from original data 194 | answers=[] 195 | with open(args.query_data_path, 'r') as rf: 196 | dataset = json.load(rf) 197 | if args.dataset in ['WebQSP']: 198 | for sample in dataset: 199 | answers.append(sample["answername"].split('|')) 200 | if args.dataset in ['GraphQuestions','grailqa']: 201 | for sample in dataset: 202 | answer=[] 203 | for i in sample["answer"]: 204 | 
if i.get("entity_name"): 205 | answer.append(i["entity_name"]) 206 | else: 207 | answer.append(i["answer_argument"]) 208 | answers.append(answer) 209 | # evaluate output 210 | eval_top_k(output_data, answers, top_k_list=[1,2,3,4,5,6,7,8,9,10,20,30,40,50,60,70,80,90,100]) 211 | 212 | # truncate documents into 10 documents 213 | for i in output_data: 214 | i['documents'] = i['documents'][:10] 215 | 216 | # save output data 217 | # create output dir recursively if not exist 218 | if args.save: 219 | os.makedirs('results', exist_ok=True) 220 | print("saving output data to {}".format(args.output)) 221 | with open(args.output, "w") as wf: 222 | json.dump(output_data, wf, indent=2, ensure_ascii=False) 223 | -------------------------------------------------------------------------------- /inference/open/rewrite/infer_chain.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 3 | import json 4 | import random 5 | from tqdm import tqdm 6 | from transformers import GenerationConfig,AutoModelForCausalLM,AutoTokenizer,AutoModel 7 | import torch 8 | from peft import PeftModel 9 | import sys 10 | import openai 11 | import time 12 | from openai import OpenAI 13 | 14 | # generation config 15 | generation_config = GenerationConfig( 16 | temperature=0.01, 17 | top_k=40, 18 | top_p=0.9, 19 | do_sample=True, 20 | num_beams=1, 21 | repetition_penalty=1.1, 22 | max_new_tokens=1024 23 | ) 24 | 25 | # dataset: grailqa, GraphQuestions 26 | DATA='grailqa' 27 | # llm: llama-2-7b-chat-hf, Meta-Llama-3-8B-Instruct, chatgpt 28 | LLM='chatgpt' 29 | # retrieve method: bm25, 2hop 30 | MODE='2hop' 31 | 32 | # set client 33 | client=OpenAI(api_key='YOUR KEY') 34 | 35 | test=json.load(open('../retrieve/'+MODE+'/format/'+DATA+'.json','r',encoding='utf-8')) 36 | 37 | kr_prompt_llm='''Your task is to summarize the relevant information that is helpful to answer the question from the following triples. Please think step by step and iteratively generate the reasoning chain and the corresponding knowledge. 38 | Triples: {triple} 39 | Question: {ques} 40 | ''' 41 | 42 | kr_prompt_gpt='''Your task is to summarize the relevant information that is helpful to answer the question from the following triples. Please think step by step and iteratively generate the reasoning chain and the corresponding knowledge. 43 | Triples: (Oxybutynin Oral, medicine.routed_drug.route_of_administration, Oral administration) (Oxybutynin Oral, medicine.routed_drug.marketed_formulations, Oxybutynin chloride 5 extended release film coated tablet) (Oxybutynin Chloride Oral, medicine.routed_drug.marketed_formulations, Oxybutynin chloride 5 extended release film coated tablet) (Oxybutynin chloride 5 extended release film coated tablet, medicine.drug_formulation.formulation_of, Oxybutynin) 44 | Question: oxybutynin chloride 5 extended release film coated tablet is the ingredients of what routed drug? 45 | Reason 1: I need to know which routed drug has the marketed formulation of oxybutynin chloride 5 extended release film coated tablet. 46 | Knowledge 1: The routed drugs Oxybutynin Oral and Oxybutynin Chloride Oral have the marketed formulation of oxybutynin chloride 5 extended release film coated tablet. 47 | 48 | Your task is to summarize the relevant information that is helpful to answer the question from the following triples. Please think step by step and iteratively generate the reasoning chain and the corresponding knowledge. 
49 | Triples: (Google, organization.organization.founders, Sergey Brin) (Sergey Brin, people.person.education, CVT1) (CVT1, education.education.institution, University of Maryland, College Park) (Google, organization.organization.founders, Larry Page) (Larry Page, people.person.education, CVT2) (CVT2, education.education.institution, University of Michigan) (CVT2, education.education.institution, Stanford University) 50 | Question: where did the founder of google go to college? 51 | Reason 1: I need to know who the founders of Google are. 52 | Knowledge 1: The founders of Google are Sergey Brin and Larry Page. 53 | Reason 2: I need to know where Sergey Brin and Larry Page went to college. 54 | Knowledge 2: Sergey Brin went to the University of Maryland, College Park for college. Larry Page went to the University of Michigan and Stanford University for college. 55 | 56 | Your task is to summarize the relevant information that is helpful to answer the question from the following triples. Please think step by step and iteratively generate the reasoning chain and the corresponding knowledge. 57 | Triples: (Rock music, music.genre.artists, Outkast) (Rock music, music.genre.parent_genre, Folk music) (Rock music, music.genre.albums, The Confessions Tour) (Electronica, music.genre.artists, Bright Eyes) (Electronica, music.genre.parent_genre, House music) (Electronica, music.genre.albums, The Confessions Tour) (Electronica, music.genre.artists, t.A.T.u.) 58 | Question: the albums confessions tour is part of what parent genre of a musical genre? 59 | Reason 1: I need to know the musical genre of the albums confessions tour. 60 | Knowledge 1: The album confessions tour is associated with the genre Rock music and Electronica. 61 | Reason 2: I need to know the parent genre of Rock music and Electronica. 62 | Knowledge 2: The parent genre of Rock music is Folk music. The parent genre of Electronica is House music. 63 | 64 | Your task is to summarize the relevant information that is helpful to answer the question from the following triples. Please think step by step and iteratively generate the reasoning chain and the corresponding knowledge. 
65 | Triples: {triple} 66 | Question: {ques} 67 | ''' 68 | 69 | num_dict = { 70 | '0': 'zero', '1': 'one', '2': 'two', '3': 'three', '4': 'four', 71 | '5': 'five', '6': 'six', '7': 'seven', '8': 'eight', '9': 'nine', 72 | '10': 'ten', '11': 'eleven', '12': 'twelve', '13': 'thirteen', 73 | '14': 'fourteen', '15': 'fifteen', '16': 'sixteen', '17': 'seventeen', 74 | '18': 'eighteen', '19': 'nineteen', '20': 'twenty' 75 | } 76 | 77 | if LLM!='chatgpt': 78 | # path for LLM 79 | LLM_PATH='../../../../pretrain/'+LLM 80 | # path for tokenizer 81 | TOKENIZER_PATH='../../../../pretrain/'+LLM 82 | # path for lora 83 | PEFT_PATH='../../../instruction-tuning/output-'+DATA+'/CoT/'+LLM+'/best_model' 84 | # load tokenizer and llm 85 | tokenizer=AutoTokenizer.from_pretrained(TOKENIZER_PATH) 86 | llm=AutoModelForCausalLM.from_pretrained(LLM_PATH,torch_dtype=torch.float16,low_cpu_mem_usage=True,device_map='cuda:0') 87 | # merge peft into base LLM 88 | if PEFT_PATH: 89 | llm=PeftModel.from_pretrained(llm, PEFT_PATH,torch_dtype=torch.float16,device_map='cuda:0') 90 | 91 | # result 92 | result='result/'+DATA+'/'+MODE+'/'+LLM+'/chain.json' 93 | os.makedirs('result/'+DATA+'/'+MODE+'/'+LLM,exist_ok = True) 94 | log_file='log/'+DATA+'/'+MODE+'/'+LLM+'/chain.log' 95 | os.makedirs('log/'+DATA+'/'+MODE+'/'+LLM,exist_ok = True) 96 | 97 | # redirect output to log 98 | sys.stdout = open(log_file, 'w') 99 | 100 | def getResponse(prompt,max_retries=10): 101 | # set retries 102 | retries=0 103 | while retries < max_retries: 104 | try: 105 | res = client.chat.completions.create( 106 | model='gpt-3.5-turbo', 107 | #model='gpt-4', 108 | messages=[ 109 | {'role': 'user', 'content': prompt} 110 | ], 111 | temperature=0, 112 | ) 113 | return res.choices[0].message.content 114 | except Exception as e: 115 | print(f"An error occurred: {e}") 116 | print("Retrying in 1 minutes...") 117 | retries += 1 118 | time.sleep(60) 119 | return '' 120 | 121 | def LLMResponse(prompt,llm,tokenizer,cuda): 122 | inputs = tokenizer(prompt,return_tensors="pt") 123 | generation_output = llm.generate( 124 | input_ids=inputs["input_ids"].to(cuda), 125 | attention_mask=inputs['attention_mask'].to(cuda), 126 | eos_token_id=tokenizer.eos_token_id, 127 | pad_token_id=tokenizer.eos_token_id, 128 | generation_config=generation_config 129 | ) 130 | output = tokenizer.decode(generation_output[0],skip_special_tokens=True) 131 | response = output.split(prompt)[-1].strip() 132 | return response 133 | 134 | data=[] 135 | for sample in tqdm(test): 136 | 137 | # knowledge rewriter 138 | if len(sample["triples"])!=0: 139 | if LLM!='chatgpt': 140 | knowledge=LLMResponse(kr_prompt_llm.format(triple=sample["triples"],ques=sample["question"]),llm,tokenizer,'cuda:0') 141 | print(kr_prompt_llm.format(triple=sample["triples"],ques=sample["question"])) 142 | print(knowledge) 143 | else: 144 | knowledge=getResponse(kr_prompt_gpt.format(triple=sample["triples"],ques=sample["question"])) 145 | print(kr_prompt_gpt.format(triple=sample["triples"],ques=sample["question"])) 146 | print(knowledge) 147 | else: 148 | knowledge='' 149 | 150 | # record 151 | temp=dict() 152 | temp['question']=sample['question'] 153 | temp['answer']=sample["answer"] 154 | temp['graph']=sample["triples"] 155 | temp['knowledge']=knowledge 156 | data.append(temp) 157 | 158 | json.dump(data,open(result,'w',encoding='utf-8'),indent=2,ensure_ascii=False) -------------------------------------------------------------------------------- /inference/open/rewrite/infer_pa.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 3 | import json 4 | import random 5 | from tqdm import tqdm 6 | from transformers import GenerationConfig,AutoModelForCausalLM,AutoTokenizer,AutoModel 7 | import torch 8 | from peft import PeftModel 9 | import sys 10 | import openai 11 | import time 12 | from openai import OpenAI 13 | 14 | # generation config 15 | generation_config = GenerationConfig( 16 | temperature=0.01, 17 | top_k=40, 18 | top_p=0.9, 19 | do_sample=True, 20 | num_beams=1, 21 | repetition_penalty=1.1, 22 | max_new_tokens=1024 23 | ) 24 | 25 | # dataset: grailqa, GraphQuestions 26 | DATA='grailqa' 27 | # llm: llama-2-7b-chat-hf, Meta-Llama-3-8B-Instruct 28 | LLM='Meta-Llama-3-8B-Instruct' 29 | # retrieve method: bm25, 2hop 30 | MODE='2hop' 31 | 32 | # set client 33 | client=OpenAI(api_key='YOUR KEY') 34 | 35 | test=json.load(open('../retrieve/'+MODE+'/format/'+DATA+'.json','r',encoding='utf-8')) 36 | 37 | kr_prompt_llm='''Your task is to summarize the relevant information that is helpful to answer the question from the following triples. Please think step by step and iteratively generate the reasoning chain and the corresponding knowledge. 38 | Triples: {triple} 39 | Question: {ques} 40 | ''' 41 | 42 | kr_prompt_gpt='''Your task is to summarize the relevant information that is helpful to answer the question from the following triples. Please think step by step and iteratively generate the reasoning chain and the corresponding knowledge. 43 | Triples: (Oxybutynin Oral, medicine.routed_drug.route_of_administration, Oral administration) (Oxybutynin Oral, medicine.routed_drug.marketed_formulations, Oxybutynin chloride 5 extended release film coated tablet) (Oxybutynin Chloride Oral, medicine.routed_drug.marketed_formulations, Oxybutynin chloride 5 extended release film coated tablet) (Oxybutynin chloride 5 extended release film coated tablet, medicine.drug_formulation.formulation_of, Oxybutynin) 44 | Question: oxybutynin chloride 5 extended release film coated tablet is the ingredients of what routed drug? 45 | Reason 1: I need to know which routed drug has the marketed formulation of oxybutynin chloride 5 extended release film coated tablet. 46 | Knowledge 1: The routed drugs Oxybutynin Oral and Oxybutynin Chloride Oral have the marketed formulation of oxybutynin chloride 5 extended release film coated tablet. 47 | 48 | Your task is to summarize the relevant information that is helpful to answer the question from the following triples. Please think step by step and iteratively generate the reasoning chain and the corresponding knowledge. 49 | Triples: (Google, organization.organization.founders, Sergey Brin) (Sergey Brin, people.person.education, CVT1) (CVT1, education.education.institution, University of Maryland, College Park) (Google, organization.organization.founders, Larry Page) (Larry Page, people.person.education, CVT2) (CVT2, education.education.institution, University of Michigan) (CVT2, education.education.institution, Stanford University) 50 | Question: where did the founder of google go to college? 51 | Reason 1: I need to know who the founders of Google are. 52 | Knowledge 1: The founders of Google are Sergey Brin and Larry Page. 53 | Reason 2: I need to know where Sergey Brin and Larry Page went to college. 54 | Knowledge 2: Sergey Brin went to the University of Maryland, College Park for college. 
Larry Page went to the University of Michigan and Stanford University for college. 55 | 56 | Your task is to summarize the relevant information that is helpful to answer the question from the following triples. Please think step by step and iteratively generate the reasoning chain and the corresponding knowledge. 57 | Triples: (Rock music, music.genre.artists, Outkast) (Rock music, music.genre.parent_genre, Folk music) (Rock music, music.genre.albums, The Confessions Tour) (Electronica, music.genre.artists, Bright Eyes) (Electronica, music.genre.parent_genre, House music) (Electronica, music.genre.albums, The Confessions Tour) (Electronica, music.genre.artists, t.A.T.u.) 58 | Question: the albums confessions tour is part of what parent genre of a musical genre? 59 | Reason 1: I need to know the musical genre of the albums confessions tour. 60 | Knowledge 1: The album confessions tour is associated with the genre Rock music and Electronica. 61 | Reason 2: I need to know the parent genre of Rock music and Electronica. 62 | Knowledge 2: The parent genre of Rock music is Folk music. The parent genre of Electronica is House music. 63 | 64 | Your task is to summarize the relevant information that is helpful to answer the question from the following triples. Please think step by step and iteratively generate the reasoning chain and the corresponding knowledge. 65 | Triples: {triple} 66 | Question: {ques} 67 | ''' 68 | 69 | num_dict = { 70 | '0': 'zero', '1': 'one', '2': 'two', '3': 'three', '4': 'four', 71 | '5': 'five', '6': 'six', '7': 'seven', '8': 'eight', '9': 'nine', 72 | '10': 'ten', '11': 'eleven', '12': 'twelve', '13': 'thirteen', 73 | '14': 'fourteen', '15': 'fifteen', '16': 'sixteen', '17': 'seventeen', 74 | '18': 'eighteen', '19': 'nineteen', '20': 'twenty' 75 | } 76 | 77 | if LLM!='chatgpt': 78 | # path for LLM 79 | LLM_PATH='../../../instruction-tuning/output-'+DATA+'/sft/CoT/'+LLM 80 | # path for tokenizer 81 | TOKENIZER_PATH='../../../../pretrain/'+LLM 82 | # path for lora 83 | PEFT_PATH='../../../instruction-tuning/output-'+DATA+'/PA-chatgpt/CoT/'+LLM+'/best_model' 84 | print(PEFT_PATH) 85 | # load tokenizer and llm 86 | tokenizer=AutoTokenizer.from_pretrained(TOKENIZER_PATH) 87 | llm=AutoModelForCausalLM.from_pretrained(LLM_PATH,torch_dtype=torch.float16,low_cpu_mem_usage=True,device_map='cuda:0') 88 | # merge peft into base LLM 89 | if PEFT_PATH: 90 | llm=PeftModel.from_pretrained(llm, PEFT_PATH,torch_dtype=torch.float16,device_map='cuda:0') 91 | 92 | # result 93 | result='result/'+DATA+'/'+MODE+'/'+LLM+'/pa-chatgpt.json' 94 | os.makedirs('result/'+DATA+'/'+MODE+'/'+LLM,exist_ok = True) 95 | log_file='log/'+DATA+'/'+MODE+'/'+LLM+'/pa-chatgpt.log' 96 | os.makedirs('log/'+DATA+'/'+MODE+'/'+LLM,exist_ok = True) 97 | 98 | # redirect output to log 99 | sys.stdout = open(log_file, 'w') 100 | 101 | def getResponse(prompt,max_retries=10): 102 | # set retries 103 | retries=0 104 | while retries < max_retries: 105 | try: 106 | res = client.chat.completions.create( 107 | model='gpt-3.5-turbo', 108 | #model='gpt-4', 109 | messages=[ 110 | {'role': 'user', 'content': prompt} 111 | ], 112 | temperature=0, 113 | ) 114 | return res.choices[0].message.content 115 | except Exception as e: 116 | print(f"An error occurred: {e}") 117 | print("Retrying in 1 minutes...") 118 | retries += 1 119 | time.sleep(60) 120 | return '' 121 | 122 | def LLMResponse(prompt,llm,tokenizer,cuda): 123 | inputs = tokenizer(prompt,return_tensors="pt") 124 | generation_output = llm.generate( 125 | 
input_ids=inputs["input_ids"].to(cuda), 126 | attention_mask=inputs['attention_mask'].to(cuda), 127 | eos_token_id=tokenizer.eos_token_id, 128 | pad_token_id=tokenizer.eos_token_id, 129 | generation_config=generation_config 130 | ) 131 | output = tokenizer.decode(generation_output[0],skip_special_tokens=True) 132 | response = output.split(prompt)[-1].strip() 133 | return response 134 | 135 | data=[] 136 | for sample in tqdm(test): 137 | 138 | # knowledge rewriter 139 | if len(sample["triples"])!=0: 140 | if LLM!='chatgpt': 141 | knowledge=LLMResponse(kr_prompt_llm.format(triple=sample["triples"],ques=sample["question"]),llm,tokenizer,'cuda:0') 142 | print(kr_prompt_llm.format(triple=sample["triples"],ques=sample["question"])) 143 | print(knowledge) 144 | else: 145 | knowledge=getResponse(kr_prompt_gpt.format(triple=sample["triples"],ques=sample["question"])) 146 | print(kr_prompt_gpt.format(triple=sample["triples"],ques=sample["question"])) 147 | print(knowledge) 148 | else: 149 | knowledge='' 150 | 151 | # record 152 | temp=dict() 153 | temp['question']=sample['question'] 154 | temp['answer']=sample["answer"] 155 | temp['graph']=sample["triples"] 156 | temp['knowledge']=knowledge 157 | data.append(temp) 158 | 159 | json.dump(data,open(result,'w',encoding='utf-8'),indent=2,ensure_ascii=False) 160 | -------------------------------------------------------------------------------- /inference/open/rewrite/infer_summary.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 3 | import json 4 | import random 5 | from tqdm import tqdm 6 | from transformers import GenerationConfig,AutoModelForCausalLM,AutoTokenizer,AutoModel 7 | import torch 8 | from peft import PeftModel 9 | import sys 10 | import openai 11 | import time 12 | from openai import OpenAI 13 | 14 | # generation config 15 | generation_config = GenerationConfig( 16 | temperature=0.01, 17 | top_k=40, 18 | top_p=0.9, 19 | do_sample=True, 20 | num_beams=1, 21 | repetition_penalty=1.1, 22 | max_new_tokens=1024 23 | ) 24 | 25 | # dataset: GraphQuestions, grailqa, WebQSP 26 | DATA='grailqa' 27 | # llm: llama-2-7b-chat-hf, Meta-Llama-3-8B-Instruct, chatgpt 28 | LLM='Meta-Llama-3-8B-Instruct' 29 | # retrieve method: bm25, 2hop 30 | MODE='2hop' 31 | 32 | # set client 33 | client=OpenAI(api_key='YOUR KEY') 34 | 35 | test=json.load(open('../retrieve/'+MODE+'/format/'+DATA+'.json','r',encoding='utf-8')) 36 | 37 | kr_prompt_llm='''Your task is to summarize the relevant knowledge that is helpful to answer the question from the following triples. 38 | Triples: {triple} 39 | Question: {ques} 40 | Knowledge: ''' 41 | 42 | kr_prompt_gpt='''Your task is to summarize the relevant knowledge that is helpful to answer the question from the following triples. 43 | Triples: (Oxybutynin Oral, medicine.routed_drug.route_of_administration, Oral administration) (Oxybutynin Oral, medicine.routed_drug.marketed_formulations, Oxybutynin chloride 5 extended release film coated tablet) (Oxybutynin Chloride Oral, medicine.routed_drug.marketed_formulations, Oxybutynin chloride 5 extended release film coated tablet) (Oxybutynin chloride 5 extended release film coated tablet, medicine.drug_formulation.formulation_of, Oxybutynin) 44 | Question: oxybutynin chloride 5 extended release film coated tablet is the ingredients of what routed drug? 
45 | Knowledge: The Oxybutynin chloride 5 extended release film coated tablet is a marketed formulation of the routed drugs Oxybutynin Oral and Oxybutynin Chloride Oral. 46 | 47 | Your task is to summarize the relevant knowledge that is helpful to answer the question from the following triples. 48 | Triples: (Google, organization.organization.founders, Sergey Brin) (Sergey Brin, people.person.education, CVT1) (CVT1, education.education.institution, University of Maryland, College Park) (Google, organization.organization.founders, Larry Page) (Larry Page, people.person.education, CVT2) (CVT2, education.education.institution, University of Michigan) (CVT2, education.education.institution, Stanford University) 49 | Question: where did the founder of google go to college? 50 | Knowledge: The founders of Google are Sergey Brin and Larry Page. Sergey Brin attended the University of Maryland, College Park for his education. Larry Page attended the University of Michigan and Stanford University for his education. 51 | 52 | Your task is to summarize the relevant knowledge that is helpful to answer the question from the following triples. 53 | Triples: (Rock music, music.genre.artists, Outkast) (Rock music, music.genre.parent_genre, Folk music) (Rock music, music.genre.albums, The Confessions Tour) (Electronica, music.genre.artists, Bright Eyes) (Electronica, music.genre.parent_genre, House music) (Electronica, music.genre.albums, The Confessions Tour) (Electronica, music.genre.artists, t.A.T.u.) 54 | Question: the albums confessions tour is part of what parent genre of a musical genre? 55 | Knowledge: The album confessions tour is associated with the genre Rock music and Electronica. The parent genre of Rock music is Folk music. The parent genre of Electronica is House music. 56 | 57 | Your task is to summarize the relevant knowledge that is helpful to answer the question from the following triples. 
58 | Triples: {triple} 59 | Question: {ques} 60 | Knowledge: ''' 61 | 62 | 63 | num_dict = { 64 | '0': 'zero', '1': 'one', '2': 'two', '3': 'three', '4': 'four', 65 | '5': 'five', '6': 'six', '7': 'seven', '8': 'eight', '9': 'nine', 66 | '10': 'ten', '11': 'eleven', '12': 'twelve', '13': 'thirteen', 67 | '14': 'fourteen', '15': 'fifteen', '16': 'sixteen', '17': 'seventeen', 68 | '18': 'eighteen', '19': 'nineteen', '20': 'twenty' 69 | } 70 | 71 | if LLM!='chatgpt': 72 | # path for LLM 73 | LLM_PATH='../../../../pretrain/'+LLM 74 | # path for tokenizer 75 | TOKENIZER_PATH='../../../../pretrain/'+LLM 76 | # path for lora 77 | PEFT_PATH='../../../instruction-tuning/output-'+DATA+'/summary/'+LLM+'/best_model' 78 | # load tokenizer and llm 79 | tokenizer=AutoTokenizer.from_pretrained(TOKENIZER_PATH) 80 | llm=AutoModelForCausalLM.from_pretrained(LLM_PATH,torch_dtype=torch.float16,low_cpu_mem_usage=True,device_map='cuda:0') 81 | # merge peft into base LLM 82 | if PEFT_PATH: 83 | llm=PeftModel.from_pretrained(llm, PEFT_PATH,torch_dtype=torch.float16,device_map='cuda:0') 84 | 85 | # result 86 | result='result/'+DATA+'/'+MODE+'/'+LLM+'/summary.json' 87 | os.makedirs('result/'+DATA+'/'+MODE+'/'+LLM,exist_ok = True) 88 | log_file='log/'+DATA+'/'+MODE+'/'+LLM+'/summary.log' 89 | os.makedirs('log/'+DATA+'/'+MODE+'/'+LLM,exist_ok = True) 90 | 91 | # redirect output to log 92 | sys.stdout = open(log_file, 'w') 93 | 94 | def getResponse(prompt,max_retries=10): 95 | # set retries 96 | retries=0 97 | while retries < max_retries: 98 | try: 99 | res = client.chat.completions.create( 100 | model='gpt-3.5-turbo', 101 | #model='gpt-4', 102 | messages=[ 103 | {'role': 'user', 'content': prompt} 104 | ], 105 | temperature=0, 106 | ) 107 | return res.choices[0].message.content 108 | except Exception as e: 109 | print(f"An error occurred: {e}") 110 | print("Retrying in 1 minutes...") 111 | retries += 1 112 | time.sleep(60) 113 | return '' 114 | 115 | def LLMResponse(prompt,llm,tokenizer,cuda): 116 | inputs = tokenizer(prompt,return_tensors="pt") 117 | generation_output = llm.generate( 118 | input_ids=inputs["input_ids"].to(cuda), 119 | attention_mask=inputs['attention_mask'].to(cuda), 120 | eos_token_id=tokenizer.eos_token_id, 121 | pad_token_id=tokenizer.eos_token_id, 122 | generation_config=generation_config 123 | ) 124 | output = tokenizer.decode(generation_output[0],skip_special_tokens=True) 125 | response = output.split(prompt)[-1].strip() 126 | return response 127 | 128 | data=[] 129 | for sample in tqdm(test): 130 | 131 | # knowledge rewriter 132 | if len(sample["triples"])!=0: 133 | if LLM!='chatgpt': 134 | knowledge=LLMResponse(kr_prompt_llm.format(triple=sample["triples"],ques=sample["question"]),llm,tokenizer,'cuda:0') 135 | print(kr_prompt_llm.format(triple=sample["triples"],ques=sample["question"])) 136 | print(knowledge) 137 | else: 138 | knowledge=getResponse(kr_prompt_gpt.format(triple=sample["triples"],ques=sample["question"])) 139 | print(kr_prompt_gpt.format(triple=sample["triples"],ques=sample["question"])) 140 | print(knowledge) 141 | else: 142 | knowledge='' 143 | 144 | # record 145 | temp=dict() 146 | temp['question']=sample['question'] 147 | temp['answer']=sample["answer"] 148 | temp['graph']=sample["triples"] 149 | temp['knowledge']=knowledge 150 | data.append(temp) 151 | 152 | json.dump(data,open(result,'w',encoding='utf-8'),indent=2,ensure_ascii=False) 153 | -------------------------------------------------------------------------------- /inference/open/rewrite/infer_text.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 3 | import json 4 | import random 5 | from tqdm import tqdm 6 | from transformers import GenerationConfig,AutoModelForCausalLM,AutoTokenizer,AutoModel 7 | import torch 8 | from peft import PeftModel 9 | import sys 10 | import openai 11 | import time 12 | from openai import OpenAI 13 | 14 | # generation config 15 | generation_config = GenerationConfig( 16 | temperature=0.01, 17 | top_k=40, 18 | top_p=0.9, 19 | do_sample=True, 20 | num_beams=1, 21 | repetition_penalty=1.1, 22 | max_new_tokens=1024 23 | ) 24 | 25 | # dataset: grailqa, GraphQuestions 26 | DATA='grailqa' 27 | # llm: llama-2-7b-chat-hf, Meta-Llama-3-8B-Instruct, chatgpt 28 | LLM='Meta-Llama-3-8B-Instruct' 29 | # retrieve method: bm25, 2hop 30 | MODE='2hop' 31 | 32 | # set client 33 | client=OpenAI(api_key='YOUR KEY') 34 | 35 | test=json.load(open('../retrieve/'+MODE+'/format/'+DATA+'.json','r',encoding='utf-8')) 36 | 37 | kr_prompt_llm='''Your task is to transform a knowledge graph to a sentence or multiple sentences. The knowledge graph is: {triple}. The sentence is: ''' 38 | 39 | kr_prompt_gpt='''Your task is to transform a knowledge graph to a sentence or multiple sentences. The knowledge graph is: (Oxybutynin Oral, medicine.routed_drug.route_of_administration, Oral administration) (Oxybutynin Oral, medicine.routed_drug.marketed_formulations, Oxybutynin chloride 5 extended release film coated tablet) (Oxybutynin Chloride Oral, medicine.routed_drug.marketed_formulations, Oxybutynin chloride 5 extended release film coated tablet) (Oxybutynin chloride 5 extended release film coated tablet, medicine.drug_formulation.formulation_of, Oxybutynin). The sentence is: Oxybutynin Oral is a medication that is administered orally. It is marketed in the form of Oxybutynin chloride 5 extended release film coated tablets. Another marketed formulation is Oxybutynin Chloride Oral. Furthermore, Oxybutynin chloride 5 extended release film coated tablet is a formulation of Oxybutynin. 40 | 41 | Your task is to transform a knowledge graph to a sentence or multiple sentences. The knowledge graph is: (Google, organization.organization.founders, Sergey Brin) (Sergey Brin, people.person.education, CVT1) (CVT1, education.education.institution, University of Maryland, College Park) (Google, organization.organization.founders, Larry Page) (Larry Page, people.person.education, CVT2) (CVT2, education.education.institution, University of Michigan) (CVT2, education.education.institution, Stanford University). The sentence is: Google was founded by Sergey Brin and Larry Page. Sergey Brin was educated at the University of Maryland, College Park, while Larry Page was educated at the University of Michigan and Stanford University. 42 | 43 | Your task is to transform a knowledge graph to a sentence or multiple sentences. The knowledge graph is: (Rock music, music.genre.artists, Outkast) (Rock music, music.genre.parent_genre, Folk music) (Rock music, music.genre.albums, The Confessions Tour) (Electronica, music.genre.artists, Bright Eyes) (Electronica, music.genre.parent_genre, House music) (Electronica, music.genre.albums, The Confessions Tour) (Electronica, music.genre.artists, t.A.T.u.). The sentence is: Rock music, which is a subgenre of Folk music, includes artists like Outkast and albums such as "The Confessions Tour". 
Conversely, Electronica is a daughter genre of House music with artists like Bright Eyes and t.A.T.u., and also features albums like "The Confessions Tour". 44 | 45 | Your task is to transform a knowledge graph to a sentence or multiple sentences. The knowledge graph is: {triple}. The sentence is: ''' 46 | 47 | num_dict = { 48 | '0': 'zero', '1': 'one', '2': 'two', '3': 'three', '4': 'four', 49 | '5': 'five', '6': 'six', '7': 'seven', '8': 'eight', '9': 'nine', 50 | '10': 'ten', '11': 'eleven', '12': 'twelve', '13': 'thirteen', 51 | '14': 'fourteen', '15': 'fifteen', '16': 'sixteen', '17': 'seventeen', 52 | '18': 'eighteen', '19': 'nineteen', '20': 'twenty' 53 | } 54 | 55 | if LLM!='chatgpt': 56 | # path for LLM 57 | LLM_PATH='../../../../pretrain/'+LLM 58 | # path for tokenizer 59 | TOKENIZER_PATH='../../../../pretrain/'+LLM 60 | # path for lora 61 | PEFT_PATH='../../../instruction-tuning/output-'+DATA+'/kg-to-text/'+LLM+'/best_model' 62 | # load tokenizer and llm 63 | tokenizer=AutoTokenizer.from_pretrained(TOKENIZER_PATH) 64 | llm=AutoModelForCausalLM.from_pretrained(LLM_PATH,torch_dtype=torch.float16,low_cpu_mem_usage=True,device_map='cuda:0') 65 | # load the lora adapter onto the base LLM 66 | if PEFT_PATH: 67 | llm=PeftModel.from_pretrained(llm, PEFT_PATH,torch_dtype=torch.float16,device_map='cuda:0') 68 | 69 | # result 70 | result='result/'+DATA+'/'+MODE+'/'+LLM+'/text.json' 71 | os.makedirs('result/'+DATA+'/'+MODE+'/'+LLM,exist_ok = True) 72 | log_file='log/'+DATA+'/'+MODE+'/'+LLM+'/text.log' 73 | os.makedirs('log/'+DATA+'/'+MODE+'/'+LLM,exist_ok = True) 74 | 75 | # redirect output to log 76 | sys.stdout = open(log_file, 'w') 77 | 78 | def getResponse(prompt,max_retries=10): 79 | # set retries 80 | retries=0 81 | while retries < max_retries: 82 | try: 83 | res = client.chat.completions.create( 84 | model='gpt-3.5-turbo', 85 | #model='gpt-4', 86 | messages=[ 87 | {'role': 'user', 'content': prompt} 88 | ], 89 | temperature=0, 90 | ) 91 | return res.choices[0].message.content 92 | except Exception as e: 93 | print(f"An error occurred: {e}") 94 | print("Retrying in 1 minute...") 95 | retries += 1 96 | time.sleep(60) 97 | return '' 98 | 99 | def LLMResponse(prompt,llm,tokenizer,cuda): 100 | inputs = tokenizer(prompt,return_tensors="pt") 101 | generation_output = llm.generate( 102 | input_ids=inputs["input_ids"].to(cuda), 103 | attention_mask=inputs['attention_mask'].to(cuda), 104 | eos_token_id=tokenizer.eos_token_id, 105 | pad_token_id=tokenizer.eos_token_id, 106 | generation_config=generation_config 107 | ) 108 | output = tokenizer.decode(generation_output[0],skip_special_tokens=True) 109 | response = output.split(prompt)[-1].strip() 110 | return response 111 | 112 | data=[] 113 | for sample in tqdm(test): 114 | 115 | # knowledge rewriter 116 | if len(sample["triples"])!=0: 117 | if LLM!='chatgpt': 118 | knowledge=LLMResponse(kr_prompt_llm.format(triple=sample["triples"]),llm,tokenizer,'cuda:0') 119 | print(kr_prompt_llm.format(triple=sample["triples"])) 120 | print(knowledge) 121 | else: 122 | knowledge=getResponse(kr_prompt_gpt.format(triple=sample["triples"])) 123 | print(kr_prompt_gpt.format(triple=sample["triples"])) 124 | print(knowledge) 125 | else: 126 | knowledge='' 127 | 128 | # record 129 | temp=dict() 130 | temp['question']=sample['question'] 131 | temp['answer']=sample["answer"] 132 | temp['graph']=sample["triples"] 133 | temp['knowledge']=knowledge 134 | data.append(temp) 135 | 136 | json.dump(data,open(result,'w',encoding='utf-8'),indent=2,ensure_ascii=False) 137 | 
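A minimal sketch of how the rewriter output dumped above can be consumed downstream (the result path follows the DATA/MODE/LLM defaults set in this script; the record keys are exactly the ones written by the loop):

import json

# each record carries the question, the gold answer, the retrieved triples, and the rewritten knowledge
records = json.load(open('result/grailqa/2hop/Meta-Llama-3-8B-Instruct/text.json', 'r', encoding='utf-8'))
for r in records[:3]:
    print(r['question'])
    print(r['knowledge'])
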
-------------------------------------------------------------------------------- /instruction-tuning/ds_zero2_no_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 100, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1e-10 9 | }, 10 | 11 | "zero_optimization": { 12 | "stage": 2, 13 | "allgather_partitions": true, 14 | "allgather_bucket_size": 1e8, 15 | "overlap_comm": true, 16 | "reduce_scatter": true, 17 | "reduce_bucket_size": 1e8, 18 | "contiguous_gradients": true 19 | }, 20 | 21 | "gradient_accumulation_steps": "auto", 22 | "gradient_clipping": "auto", 23 | "steps_per_print": 2000, 24 | "train_batch_size": "auto", 25 | "train_micro_batch_size_per_gpu": "auto", 26 | "wall_clock_breakdown": false 27 | } 28 | -------------------------------------------------------------------------------- /instruction-tuning/merge.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 4 | from transformers import GenerationConfig,AutoModelForCausalLM,AutoTokenizer,AutoModel 5 | from peft import PeftModel 6 | 7 | # dataset: GraphQuestions, grailqa 8 | DATA='grailqa' 9 | # llm: llama-2-7b-chat-hf, Meta-Llama-3-8B-Instruct 10 | LLM='llama-2-7b-chat-hf' 11 | # mode 12 | MODE='CoT' 13 | # path for LLM 14 | LLM_PATH='../../pretrain/'+LLM 15 | # path for tokenizer 16 | TOKENIZER_PATH='../../pretrain/'+LLM 17 | # path for lora 18 | PEFT_PATH='output-'+DATA+'/'+MODE+'/'+LLM+'/best_model' 19 | # result 20 | result='output-'+DATA+'/sft/'+MODE+'/'+LLM 21 | 22 | tokenizer=AutoTokenizer.from_pretrained(LLM_PATH) 23 | llm=AutoModelForCausalLM.from_pretrained(LLM_PATH,torch_dtype=torch.float16,low_cpu_mem_usage=True,device_map='cuda:0') 24 | llm=PeftModel.from_pretrained(llm, PEFT_PATH,torch_dtype=torch.float16,device_map='cuda:0') 25 | llm=llm.merge_and_unload() 26 | llm.save_pretrained(result) 27 | -------------------------------------------------------------------------------- /instruction-tuning/run_dpo-step.sh: -------------------------------------------------------------------------------- 1 | llm=Meta-Llama-3-8B-Instruct 2 | data=GraphQuestions 3 | MODE=CoT 4 | dataset=${data}/PA-chatgpt/${MODE}/${llm} 5 | load_in_kbits=16 6 | train_file=$dataset/train.json 7 | validation_file=$dataset/dev.json 8 | gpu_id='1' 9 | train_batch_size=1 10 | eval_batch_size=1 11 | accumulation_steps=128 12 | epoch=1 13 | node=1 14 | max_prompt_length=2048 15 | max_target_length=2048 16 | max_seq_length=4096 17 | 18 | lr=1e-4 19 | lora_rank=64 20 | lora_alpha=128 21 | lora_trainable="q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj" 22 | lora_dropout=0.05 23 | pretrained_model=output-${data}/sft/${MODE}/${llm} 24 | chinese_tokenizer_path=../../pretrain/${llm} 25 | per_device_train_batch_size=${train_batch_size} 26 | per_device_eval_batch_size=${eval_batch_size} 27 | gradient_accumulation_steps=${accumulation_steps} 28 | output_dir=output-$dataset 29 | modules_to_save="embed_tokens,lm_head" 30 | deepspeed_config_file=ds_zero2_no_offload.json 31 | 32 | CUDA_VISIBLE_DEVICES=${gpu_id} torchrun --master_port 28610 --nnodes 1 --nproc_per_node ${node} run_dpo.py \ 33 | --model_name_or_path ${pretrained_model} \ 34 | --tokenizer_name_or_path ${chinese_tokenizer_path} \ 35 | --train_file ${train_file} \ 36 | --validation_file ${validation_file} \ 37 | 
--per_device_train_batch_size ${per_device_train_batch_size} \ 38 | --per_device_eval_batch_size ${per_device_eval_batch_size} \ 39 | --do_train \ 40 | --do_eval \ 41 | --seed $RANDOM \ 42 | --fp16 \ 43 | --num_train_epochs ${epoch} \ 44 | --lr_scheduler_type cosine \ 45 | --learning_rate ${lr} \ 46 | --warmup_ratio 0.03 \ 47 | --weight_decay 0 \ 48 | --logging_strategy steps \ 49 | --logging_steps 5 \ 50 | --save_strategy steps \ 51 | --save_steps 5 \ 52 | --evaluation_strategy no \ 53 | --gradient_accumulation_steps ${gradient_accumulation_steps} \ 54 | --preprocessing_num_workers 8 \ 55 | --max_prompt_length ${max_prompt_length} \ 56 | --max_target_length ${max_target_length} \ 57 | --max_seq_length ${max_seq_length} \ 58 | --output_dir ${output_dir} \ 59 | --save_safetensors False \ 60 | --overwrite_output_dir \ 61 | --ddp_timeout 30000 \ 62 | --logging_first_step True \ 63 | --lora_rank ${lora_rank} \ 64 | --lora_alpha ${lora_alpha} \ 65 | --trainable ${lora_trainable} \ 66 | --lora_dropout ${lora_dropout} \ 67 | --torch_dtype float16 \ 68 | --load_in_kbits ${load_in_kbits} \ 69 | --gradient_checkpointing \ 70 | --ddp_find_unused_parameters False \ 71 | --report_to none 72 | 73 | -------------------------------------------------------------------------------- /instruction-tuning/run_dpo.sh: -------------------------------------------------------------------------------- 1 | llm=Meta-Llama-3-8B-Instruct 2 | data=GraphQuestions 3 | MODE=CoT 4 | dataset=${data}/PA-chatgpt/${MODE}/${llm} 5 | load_in_kbits=16 6 | train_file=$dataset/train.json 7 | validation_file=$dataset/dev.json 8 | gpu_id='0' 9 | train_batch_size=1 10 | eval_batch_size=1 11 | accumulation_steps=128 12 | epoch=10 13 | node=1 14 | max_prompt_length=2048 15 | max_target_length=2048 16 | max_seq_length=4096 17 | 18 | lr=1e-4 19 | lora_rank=64 20 | lora_alpha=128 21 | lora_trainable="q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj" 22 | lora_dropout=0.05 23 | pretrained_model=output-${data}/sft/${MODE}/${llm} 24 | chinese_tokenizer_path=../../pretrain/${llm} 25 | per_device_train_batch_size=${train_batch_size} 26 | per_device_eval_batch_size=${eval_batch_size} 27 | gradient_accumulation_steps=${accumulation_steps} 28 | output_dir=output-$dataset 29 | modules_to_save="embed_tokens,lm_head" 30 | deepspeed_config_file=ds_zero2_no_offload.json 31 | 32 | CUDA_VISIBLE_DEVICES=${gpu_id} torchrun --master_port 29920 --nnodes 1 --nproc_per_node ${node} run_dpo.py \ 33 | --model_name_or_path ${pretrained_model} \ 34 | --tokenizer_name_or_path ${chinese_tokenizer_path} \ 35 | --train_file ${train_file} \ 36 | --validation_file ${validation_file} \ 37 | --per_device_train_batch_size ${per_device_train_batch_size} \ 38 | --per_device_eval_batch_size ${per_device_eval_batch_size} \ 39 | --do_train \ 40 | --do_eval \ 41 | --seed $RANDOM \ 42 | --fp16 \ 43 | --num_train_epochs ${epoch} \ 44 | --lr_scheduler_type cosine \ 45 | --learning_rate ${lr} \ 46 | --warmup_ratio 0.03 \ 47 | --weight_decay 0 \ 48 | --logging_strategy steps \ 49 | --logging_steps 10 \ 50 | --save_strategy epoch \ 51 | --save_total_limit 10 \ 52 | --evaluation_strategy epoch \ 53 | --gradient_accumulation_steps ${gradient_accumulation_steps} \ 54 | --preprocessing_num_workers 8 \ 55 | --max_prompt_length ${max_prompt_length} \ 56 | --max_target_length ${max_target_length} \ 57 | --max_seq_length ${max_seq_length} \ 58 | --output_dir ${output_dir} \ 59 | --save_safetensors False \ 60 | --overwrite_output_dir \ 61 | --ddp_timeout 30000 \ 62 | --logging_first_step 
True \ 63 | --lora_rank ${lora_rank} \ 64 | --lora_alpha ${lora_alpha} \ 65 | --trainable ${lora_trainable} \ 66 | --lora_dropout ${lora_dropout} \ 67 | --torch_dtype float16 \ 68 | --load_in_kbits ${load_in_kbits} \ 69 | --gradient_checkpointing \ 70 | --ddp_find_unused_parameters False \ 71 | --load_best_model_at_end True \ 72 | --report_to none 73 | -------------------------------------------------------------------------------- /instruction-tuning/run_llama-7b.sh: -------------------------------------------------------------------------------- 1 | llm='llama-2-7b-chat-hf' 2 | dataset='GraphQuestions/CoT' 3 | train_batch_size=1 4 | eval_batch_size=1 5 | accumulation_steps=64 6 | node=2 7 | max_length=4096 8 | 9 | lr=1e-4 10 | lora_rank=64 11 | lora_alpha=128 12 | lora_trainable="q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj" 13 | lora_dropout=0.05 14 | pretrained_model=../../pretrain/${llm} 15 | chinese_tokenizer_path=../../pretrain/${llm} 16 | per_device_train_batch_size=${train_batch_size} 17 | per_device_eval_batch_size=${eval_batch_size} 18 | gradient_accumulation_steps=${accumulation_steps} 19 | dataset_dir=${dataset}/train/ 20 | output_dir=output-${dataset}/${llm} 21 | validation_file=${dataset}/dev.json 22 | modules_to_save="embed_tokens,lm_head" 23 | 24 | deepspeed_config_file=ds_zero2_no_offload.json 25 | torchrun --master_port 27140 --nnodes 1 --nproc_per_node ${node} run_clm_sft_with_peft-7b.py \ 26 | --deepspeed ${deepspeed_config_file} \ 27 | --model_name_or_path ${pretrained_model} \ 28 | --tokenizer_name_or_path ${chinese_tokenizer_path} \ 29 | --dataset_dir ${dataset_dir} \ 30 | --validation_split_percentage 0.001 \ 31 | --per_device_train_batch_size ${per_device_train_batch_size} \ 32 | --per_device_eval_batch_size ${per_device_eval_batch_size} \ 33 | --do_train \ 34 | --do_eval \ 35 | --seed $RANDOM \ 36 | --fp16 \ 37 | --num_train_epochs 10 \ 38 | --lr_scheduler_type cosine \ 39 | --learning_rate ${lr} \ 40 | --warmup_ratio 0.03 \ 41 | --weight_decay 0 \ 42 | --logging_strategy steps \ 43 | --logging_steps 10 \ 44 | --save_strategy epoch \ 45 | --save_total_limit 2 \ 46 | --evaluation_strategy epoch \ 47 | --gradient_accumulation_steps ${gradient_accumulation_steps} \ 48 | --preprocessing_num_workers 8 \ 49 | --max_seq_length ${max_length} \ 50 | --output_dir ${output_dir} \ 51 | --overwrite_output_dir \ 52 | --ddp_timeout 30000 \ 53 | --logging_first_step True \ 54 | --lora_rank ${lora_rank} \ 55 | --lora_alpha ${lora_alpha} \ 56 | --trainable ${lora_trainable} \ 57 | --modules_to_save ${modules_to_save} \ 58 | --lora_dropout ${lora_dropout} \ 59 | --torch_dtype float16 \ 60 | --validation_file ${validation_file} \ 61 | --gradient_checkpointing \ 62 | --ddp_find_unused_parameters False \ 63 | --load_best_model_at_end True \ 64 | --report_to none 65 | -------------------------------------------------------------------------------- /instruction-tuning/run_llama-8b.sh: -------------------------------------------------------------------------------- 1 | llm='Meta-Llama-3-8B-Instruct' 2 | dataset='grailqa/kg-to-text' 3 | train_batch_size=1 4 | eval_batch_size=1 5 | accumulation_steps=64 6 | node=2 7 | max_length=4096 8 | 9 | lr=1e-4 10 | lora_rank=64 11 | lora_alpha=128 12 | lora_trainable="q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj" 13 | lora_dropout=0.05 14 | pretrained_model=../../pretrain/${llm} 15 | chinese_tokenizer_path=../../pretrain/${llm} 16 | per_device_train_batch_size=${train_batch_size} 17 | 
per_device_eval_batch_size=${eval_batch_size} 18 | gradient_accumulation_steps=${accumulation_steps} 19 | dataset_dir=${dataset}/train/ 20 | output_dir=output-${dataset}/${llm} 21 | validation_file=${dataset}/dev.json 22 | modules_to_save="embed_tokens,lm_head" 23 | 24 | deepspeed_config_file=ds_zero2_no_offload.json 25 | 26 | torchrun --master_port 29548 --nnodes 1 --nproc_per_node ${node} run_clm_sft_with_peft-8b.py \ 27 | --deepspeed ${deepspeed_config_file} \ 28 | --model_name_or_path ${pretrained_model} \ 29 | --tokenizer_name_or_path ${chinese_tokenizer_path} \ 30 | --dataset_dir ${dataset_dir} \ 31 | --validation_split_percentage 0.001 \ 32 | --per_device_train_batch_size ${per_device_train_batch_size} \ 33 | --per_device_eval_batch_size ${per_device_eval_batch_size} \ 34 | --do_train \ 35 | --do_eval \ 36 | --seed $RANDOM \ 37 | --fp16 \ 38 | --num_train_epochs 10 \ 39 | --lr_scheduler_type cosine \ 40 | --learning_rate ${lr} \ 41 | --warmup_ratio 0.03 \ 42 | --weight_decay 0 \ 43 | --logging_strategy steps \ 44 | --logging_steps 10 \ 45 | --save_strategy epoch \ 46 | --save_total_limit 2 \ 47 | --evaluation_strategy epoch \ 48 | --gradient_accumulation_steps ${gradient_accumulation_steps} \ 49 | --preprocessing_num_workers 8 \ 50 | --max_seq_length ${max_length} \ 51 | --output_dir ${output_dir} \ 52 | --overwrite_output_dir \ 53 | --ddp_timeout 30000 \ 54 | --logging_first_step True \ 55 | --lora_rank ${lora_rank} \ 56 | --lora_alpha ${lora_alpha} \ 57 | --trainable ${lora_trainable} \ 58 | --lora_dropout ${lora_dropout} \ 59 | --torch_dtype float16 \ 60 | --validation_file ${validation_file} \ 61 | --gradient_checkpointing \ 62 | --ddp_find_unused_parameters False \ 63 | --load_best_model_at_end True \ 64 | --report_to none 65 | -------------------------------------------------------------------------------- /requirement1.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.30.0 2 | datasets==2.19.1 3 | deepspeed==0.10.0 4 | numpy==1.26.4 5 | openai==1.27.0 6 | pandas==2.2.2 7 | peft==0.10.0 8 | safetensors==0.4.3 9 | sentence-transformers==2.2.2 10 | sentencepiece==0.2.0 11 | tensorboard==2.15.1 12 | torch==2.3.0 13 | transformers==4.40.2 -------------------------------------------------------------------------------- /requirement2.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.30.0 2 | datasets==2.19.1 3 | deepspeed==0.14.2 4 | numpy==1.26.4 5 | openai==1.27.0 6 | pandas==2.2.2 7 | peft==0.6.2 8 | safetensors==0.4.3 9 | sentence-transformers==2.2.2 10 | sentencepiece==0.2.0 11 | tensorboard==2.15.1 12 | torch==2.3.0 13 | transformers==4.40.2 14 | trl==0.8.6 -------------------------------------------------------------------------------- /subgraph/GraphQuestions/graph_query.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | from sparql_utils.sparql_executor import execute_query, execute_query_allvar, get_friendly_name 4 | from query_interface import get_1hop 5 | from tqdm import tqdm 6 | import random 7 | import copy 8 | 9 | EXNUM=10 10 | GRAPHNUM=10 11 | 12 | # change sparql to query all variables 13 | def update_sparql_query(query_string): 14 | # find all ver 15 | var_pattern = r'\?[xy]\d+' 16 | variables = set(re.findall(var_pattern, query_string)) 17 | # sort based on num 18 | variables = sorted(variables, key=lambda x: int(x[2:])) 19 | # extract triples with y for it may not be used in main query 20 
| query_lines = query_string.split('\n') 21 | query_y=[] 22 | for line in query_lines: 23 | if line.startswith("?") and len(line.split(' ')) == 5 and '?y' in line: 24 | query_y.append(line) 25 | query_lines=query_lines[:3]+query_y+query_lines[3:] 26 | # modify select distinct 27 | if query_lines[1].startswith('SELECT '): 28 | # remove SELECT (?x0 AS ?value) WHERE { and last } 29 | query_lines = query_lines[:1] + query_lines[2:-1] 30 | if query_lines[1].startswith('SELECT DISTINCT'): 31 | select_parts = query_lines[1].split(' ') 32 | select_parts[2] = ' '.join(variables) 33 | query_lines[1] = ' '.join(select_parts) 34 | return '\n'.join(query_lines), list(variables) 35 | 36 | # parse sparql to subgraph 37 | def sparql_to_graph(query): 38 | # input: sparql query 39 | # return: str(triples), 40 | lines = query.split('\n') 41 | graph_lines = [] 42 | values = {} 43 | # extract all intermediate entity mid 44 | for line in lines: 45 | if line.startswith("VALUE"): 46 | k = None 47 | v = None 48 | for item in line.split(' '): 49 | if item.startswith("?"): 50 | k = item 51 | if item.startswith(":") or "1 and n[j[0][1:]]['mid'][0:2] in ['m.','n.','g.']: 145 | midset.add(n[j[0][1:]]['mid']) 146 | # j[2] 147 | # make sure j[2] is an entity 148 | if j[2].startswith('?') and len(n[j[2][1:]]['mid'])>1 and n[j[2][1:]]['mid'][0:2] in ['m.','n.','g.']: 149 | midset.add(n[j[2][1:]]['mid']) 150 | # skip type relation 151 | if triple[1]!='type.object.type': 152 | one_graph.append(triple) 153 | midlist.append(midset) 154 | random.shuffle(one_graph) 155 | graph.append(one_graph) 156 | 157 | # graph extend 158 | ex_graph=[] 159 | for index,g in enumerate(graph[:GRAPHNUM]): 160 | # copy g to g1 161 | g1=copy.deepcopy(g) 162 | # iteratively extend triple 163 | ex_triple=[] 164 | # collect mid triple 165 | mid_triple=[] 166 | for j in midlist[index]: 167 | for k in get_1hop(j)[:EXNUM]: 168 | if k not in mid_triple: 169 | mid_triple.append(k) 170 | # avoid redundant triple 171 | unique_triples = set(tuple(triple) for triple in mid_triple) 172 | mid_triple = [list(triple) for triple in unique_triples] 173 | random.shuffle(mid_triple) 174 | # mid to name 175 | for k in mid_triple: 176 | extend=[] 177 | # k[0] 178 | temp='' 179 | # k[0] is in mid_dict 180 | if mid_dict.get(k[0]): 181 | temp=mid_dict[k[0]] 182 | # k[0] is not entity 183 | if len(temp)==0 and (len(k[0])==1 or k[0][0:2] not in ['m.','n.','g.']): 184 | temp=k[0].replace('-08:00','') 185 | # k[0] is entity 186 | if len(temp)==0: 187 | temp=get_friendly_name(k[0]) 188 | if temp=='null': 189 | if cvt_dict.get(k[0]): 190 | temp=cvt_dict[k[0]] 191 | else: 192 | temp='CVT'+str(cvt) 193 | cvt_dict[k[0]]=temp 194 | cvt+=1 195 | else: 196 | temp=temp.replace('-08:00','') 197 | extend.append(temp) 198 | # k[1] 199 | extend.append(k[1]) 200 | # k[2] 201 | temp='' 202 | # k[2] is in mid_dict 203 | if mid_dict.get(k[2]): 204 | temp=mid_dict[k[2]] 205 | # k[2] is not entity 206 | if len(temp)==0 and (len(k[2])==1 or k[2][0:2] not in ['m.','n.','g.']): 207 | temp=k[2].replace('-08:00','') 208 | # k[2] is entity 209 | if len(temp)==0: 210 | temp=get_friendly_name(k[2]) 211 | if temp=='null': 212 | if cvt_dict.get(k[2]): 213 | temp=cvt_dict[k[2]] 214 | else: 215 | temp='CVT'+str(cvt) 216 | cvt_dict[k[2]]=temp 217 | cvt+=1 218 | else: 219 | temp=temp.replace('-08:00','') 220 | extend.append(temp) 221 | #if extend not in g1: 222 | # g1.append(extend) 223 | if extend not in ex_triple: 224 | ex_triple.append(extend) 225 | # add ex_triple to g1 226 | random.shuffle(ex_triple) 227 
| g1.extend(ex_triple) 228 | random.shuffle(g1) 229 | ex_graph.append(g1) 230 | 231 | sample['qid']=one_example['qid'] 232 | sample['question']=one_example['question'] 233 | sample['answer']=one_example['answer'] 234 | sample['sparql_query']=one_example['sparql_query'] 235 | sample['s_expression']=one_example['s_expression'] 236 | sample['graph']=graph 237 | sample['restrict_graph']=graph[:GRAPHNUM] 238 | sample['ex_graph']=ex_graph 239 | graphdata.append(sample) 240 | 241 | json.dump(graphdata,open('graph/'+file_type+'.json','w',encoding='utf-8'),indent=2,ensure_ascii=False) 242 | 243 | query('train') 244 | query('test') -------------------------------------------------------------------------------- /subgraph/GraphQuestions/sparql_utils/misc.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict, Counter, deque 2 | import torch 3 | import json 4 | import pickle 5 | import numpy as np 6 | import torch.nn as nn 7 | import random 8 | import os 9 | import time 10 | ###################################################### 11 | ##################### used in SRN #################### 12 | START_RELATION = 'START_RELATION' 13 | NO_OP_RELATION = 'NO_OP_RELATION' 14 | NO_OP_ENTITY = 'NO_OP_ENTITY' 15 | DUMMY_RELATION = 'DUMMY_RELATION' 16 | DUMMY_ENTITY = 'DUMMY_ENTITY' 17 | 18 | DUMMY_RELATION_ID = 0 19 | START_RELATION_ID = 1 20 | NO_OP_RELATION_ID = 2 21 | DUMMY_ENTITY_ID = 0 22 | NO_OP_ENTITY_ID = 1 23 | 24 | EPSILON = float(np.finfo(float).eps) 25 | HUGE_INT = 1e31 26 | 27 | def format_path(path_trace, id2entity, id2relation): 28 | def get_most_recent_relation(j): 29 | relation_id = int(path_trace[j][0]) 30 | if relation_id == NO_OP_RELATION_ID: 31 | return '' 32 | else: 33 | return id2relation[relation_id] 34 | 35 | def get_most_recent_entity(j): 36 | return id2entity[int(path_trace[j][1])] 37 | 38 | path_str = get_most_recent_entity(0) 39 | for j in range(1, len(path_trace)): 40 | rel = get_most_recent_relation(j) 41 | if not rel.endswith('_inv'): 42 | path_str += ' -{}-> '.format(rel) 43 | else: 44 | path_str += ' <-{}- '.format(rel[:-4]) 45 | path_str += get_most_recent_entity(j) 46 | return path_str 47 | 48 | def pad_and_cat(a, padding_value, padding_dim=1): 49 | max_dim_size = max([x.size()[padding_dim] for x in a]) 50 | padded_a = [] 51 | for x in a: 52 | if x.size()[padding_dim] < max_dim_size: 53 | res_len = max_dim_size - x.size()[1] 54 | pad = nn.ConstantPad1d((0, res_len), padding_value) 55 | padded_a.append(pad(x)) 56 | else: 57 | padded_a.append(x) 58 | return torch.cat(padded_a, dim=0) 59 | 60 | def safe_log(x): 61 | return torch.log(x + EPSILON) 62 | 63 | def entropy(p): 64 | return torch.sum(- p * safe_log(p), 1) 65 | 66 | def init_word2id(): 67 | return { 68 | '<PAD>': 0, 69 | '<UNK>': 1, 70 | 'E_S': 2, 71 | } 72 | def init_entity2id(): 73 | return { 74 | DUMMY_ENTITY: DUMMY_ENTITY_ID, 75 | NO_OP_ENTITY: NO_OP_ENTITY_ID 76 | } 77 | def init_relation2id(): 78 | return { 79 | DUMMY_RELATION: DUMMY_RELATION_ID, 80 | START_RELATION: START_RELATION_ID, 81 | NO_OP_RELATION: NO_OP_RELATION_ID 82 | } 83 | 84 | def add_item_to_x2id(item, x2id): 85 | if not item in x2id: 86 | x2id[item] = len(x2id) 87 | 88 | def tile_along_beam(v, beam_size, dim=0): 89 | """ 90 | Tile a tensor along a specified dimension for the specified beam size. 91 | :param v: Input tensor. 92 | :param beam_size: Beam size.
93 | """ 94 | if dim == -1: 95 | dim = len(v.size()) - 1 96 | v = v.unsqueeze(dim + 1) 97 | v = torch.cat([v] * beam_size, dim=dim+1) 98 | new_size = [] 99 | for i, d in enumerate(v.size()): 100 | if i == dim + 1: 101 | new_size[-1] *= d 102 | else: 103 | new_size.append(d) 104 | return v.view(new_size) 105 | ##################### used in SRN #################### 106 | ###################################################### 107 | 108 | 109 | 110 | def init_vocab(): 111 | return { 112 | '': 0, 113 | '': 1, 114 | '': 2, 115 | '': 3 116 | } 117 | 118 | def invert_dict(d): 119 | return {v: k for k, v in d.items()} 120 | 121 | def load_glove(glove_pt, idx_to_token): 122 | glove = pickle.load(open(glove_pt, 'rb')) 123 | dim = len(glove['the']) 124 | matrix = [] 125 | for i in range(len(idx_to_token)): 126 | token = idx_to_token[i] 127 | tokens = token.split() 128 | if len(tokens) > 1: 129 | v = np.zeros((dim,)) 130 | for token in tokens: 131 | v = v + glove.get(token, glove['the']) 132 | v = v / len(tokens) 133 | else: 134 | v = glove.get(token, glove['the']) 135 | matrix.append(v) 136 | matrix = np.asarray(matrix) 137 | return matrix 138 | 139 | 140 | class SmoothedValue(object): 141 | """Track a series of values and provide access to smoothed values over a 142 | window or the global series average. 143 | """ 144 | 145 | def __init__(self, window_size=20): 146 | self.deque = deque(maxlen=window_size) 147 | self.series = [] 148 | self.total = 0.0 149 | self.count = 0 150 | 151 | def update(self, value): 152 | self.deque.append(value) 153 | self.series.append(value) 154 | self.count += 1 155 | self.total += value 156 | 157 | @property 158 | def median(self): 159 | d = torch.tensor(list(self.deque)) 160 | return d.median().item() 161 | 162 | @property 163 | def avg(self): 164 | d = torch.tensor(list(self.deque)) 165 | return d.mean().item() 166 | 167 | @property 168 | def global_avg(self): 169 | return self.total / self.count 170 | 171 | 172 | class MetricLogger(object): 173 | def __init__(self, delimiter="\t"): 174 | self.meters = defaultdict(SmoothedValue) 175 | self.delimiter = delimiter 176 | 177 | def update(self, **kwargs): 178 | for k, v in kwargs.items(): 179 | if isinstance(v, torch.Tensor): 180 | v = v.item() 181 | assert isinstance(v, (float, int)) 182 | self.meters[k].update(v) 183 | 184 | def __getattr__(self, attr): 185 | if attr in self.meters: 186 | return self.meters[attr] 187 | if attr in self.__dict__: 188 | return self.__dict__[attr] 189 | raise AttributeError("'{}' object has no attribute '{}'".format( 190 | type(self).__name__, attr)) 191 | 192 | def __str__(self): 193 | loss_str = [] 194 | for name, meter in self.meters.items(): 195 | loss_str.append( 196 | "{}: {:.4f} ({:.4f})".format(name, meter.median, meter.global_avg) 197 | ) 198 | return self.delimiter.join(loss_str) 199 | 200 | 201 | def seed_everything(seed=1029): 202 | ''' 203 | 设置整个开发环境的seed 204 | :param seed: 205 | :param device: 206 | :return: 207 | ''' 208 | random.seed(seed) 209 | os.environ['PYTHONHASHSEED'] = str(seed) 210 | np.random.seed(seed) 211 | torch.manual_seed(seed) 212 | torch.cuda.manual_seed(seed) 213 | torch.cuda.manual_seed_all(seed) 214 | # some cudnn methods can be random even after fixing the seed 215 | # unless you tell it to be deterministic 216 | torch.backends.cudnn.deterministic = True 217 | 218 | 219 | class ProgressBar(object): 220 | ''' 221 | custom progress bar 222 | Example: 223 | >>> pbar = ProgressBar(n_total=30,desc='training') 224 | >>> step = 2 225 | >>> pbar(step=step) 226 | 
''' 227 | def __init__(self, n_total,width=30,desc = 'Training'): 228 | self.width = width 229 | self.n_total = n_total 230 | self.start_time = time.time() 231 | self.desc = desc 232 | 233 | def __call__(self, step, info={}): 234 | now = time.time() 235 | current = step + 1 236 | recv_per = current / self.n_total 237 | bar = f'[{self.desc}] {current}/{self.n_total} [' 238 | if recv_per >= 1: 239 | recv_per = 1 240 | prog_width = int(self.width * recv_per) 241 | if prog_width > 0: 242 | bar += '=' * (prog_width - 1) 243 | if current< self.n_total: 244 | bar += ">" 245 | else: 246 | bar += '=' 247 | bar += '.' * (self.width - prog_width) 248 | bar += ']' 249 | show_bar = f"\r{bar}" 250 | time_per_unit = (now - self.start_time) / current 251 | if current < self.n_total: 252 | eta = time_per_unit * (self.n_total - current) 253 | if eta > 3600: 254 | eta_format = ('%d:%02d:%02d' % 255 | (eta // 3600, (eta % 3600) // 60, eta % 60)) 256 | elif eta > 60: 257 | eta_format = '%d:%02d' % (eta // 60, eta % 60) 258 | else: 259 | eta_format = '%ds' % eta 260 | time_info = f' - ETA: {eta_format}' 261 | else: 262 | if time_per_unit >= 1: 263 | time_info = f' {time_per_unit:.1f}s/step' 264 | elif time_per_unit >= 1e-3: 265 | time_info = f' {time_per_unit * 1e3:.1f}ms/step' 266 | else: 267 | time_info = f' {time_per_unit * 1e6:.1f}us/step' 268 | 269 | show_bar += time_info 270 | if len(info) != 0: 271 | show_info = f'{show_bar} ' + \ 272 | "-".join([f' {key}: {value:.4f} ' for key, value in info.items()]) 273 | print(show_info, end='') 274 | else: 275 | print(show_bar, end='') -------------------------------------------------------------------------------- /subgraph/GraphQuestions/sparql_utils/value_class.py: -------------------------------------------------------------------------------- 1 | def comp(a, b, op): 2 | """ 3 | Args: 4 | - a (ValueClass): attribute value of a certain entity 5 | - b (ValueClass): comparison target 6 | - op: =/>/</>=/<=/!= 7 | """ 8 | # NOTE: ValueClass only implements __eq__, __lt__ and __gt__, 9 | # so <= and >= are composed from those operators 10 | if op == '=': 11 | return a == b 12 | elif op == '<': 13 | return a < b 14 | elif op == '<=': 15 | return a < b or a == b 16 | elif op == '>': 17 | return a > b 18 | elif op == '>=': 19 | return a > b or a == b 20 | elif op == '!=': 21 | return a != b 22 | 23 | 24 | 25 | class ValueClass(): 26 | def __init__(self, type, value, unit=None): 27 | """ 28 | When type is 29 | - string, value is a str 30 | - quantity, value is a number and unit is required 31 | - year, value is an int 32 | - date, value is a date object 33 | """ 34 | self.type = type 35 | self.value = value 36 | self.unit = unit 37 | 38 | def isTime(self): 39 | return self.type in {'year', 'date'} 40 | 41 | def can_compare(self, other): 42 | if self.type == 'string': 43 | return other.type == 'string' 44 | elif self.type == 'quantity': 45 | # NOTE: two quantities can compare only when they have the same unit 46 | return other.type == 'quantity' and other.unit == self.unit 47 | else: 48 | # year can compare with date 49 | return other.type == 'year' or other.type == 'date' 50 | 51 | def contains(self, other): 52 | """ 53 | check whether self contains other, which is different from __eq__ and the result is asymmetric 54 | used for conditions like whether 2001-01-01 in 2001, or whether 2001 in 2001-01-01 55 | """ 56 | if self.type == 'year': # year can contain year and date 57 | other_value = other.value if other.type == 'year' else other.value.year 58 | return self.value == other_value 59 | elif self.type == 'date': # date can only contain date 60 | return other.type == 'date' and self.value == other.value 61 | else: 62 | raise Exception('not supported type: %s' % self.type) 63 | 64 | 65 | def __eq__(self, other): 66 | """ 67 | 2001 and 2001-01-01 are not equal 68 | """ 69 | assert 
self.can_compare(other) 70 | return self.type == other.type and self.value == other.value 71 | 72 | def __lt__(self, other): 73 | """ 74 | Comparison between a year and a date will convert them both to year 75 | """ 76 | assert self.can_compare(other) 77 | if self.type == 'string': 78 | raise Exception('try to compare two string') 79 | elif self.type == 'quantity': 80 | return self.value < other.value 81 | elif self.type == 'year': 82 | other_value = other.value if other.type == 'year' else other.value.year 83 | return self.value < other_value 84 | elif self.type == 'date': 85 | if other.type == 'year': 86 | return self.value.year < other.value 87 | else: 88 | return self.value < other.value 89 | 90 | def __gt__(self, other): 91 | assert self.can_compare(other) 92 | if self.type == 'string': 93 | raise Exception('try to compare two string') 94 | elif self.type == 'quantity': 95 | return self.value > other.value 96 | elif self.type == 'year': 97 | other_value = other.value if other.type == 'year' else other.value.year 98 | return self.value > other_value 99 | elif self.type == 'date': 100 | if other.type == 'year': 101 | return self.value.year > other.value 102 | else: 103 | return self.value > other.value 104 | 105 | def __str__(self): 106 | if self.type == 'string': 107 | return self.value 108 | elif self.type == 'quantity': 109 | if self.value - int(self.value) < 1e-5: 110 | v = int(self.value) 111 | else: 112 | v = self.value 113 | return '{} {}'.format(v, self.unit) if self.unit != '1' else str(v) 114 | elif self.type == 'year': 115 | return str(self.value) 116 | elif self.type == 'date': 117 | return self.value.isoformat() 118 | -------------------------------------------------------------------------------- /subgraph/gold_graph.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import math 4 | import random 5 | from tqdm import tqdm 6 | 7 | # dataset: grailqa, GraphQuestions 8 | DATA='grailqa' 9 | 10 | # result for subgraph 11 | result=DATA+'/gold/test.json' 12 | 13 | # load data 14 | data=json.load(open(DATA+'/graph/test.json','r',encoding='utf-8')) 15 | 16 | num_dict = { 17 | '0': 'zero', '1': 'one', '2': 'two', '3': 'three', '4': 'four', 18 | '5': 'five', '6': 'six', '7': 'seven', '8': 'eight', '9': 'nine', 19 | '10': 'ten', '11': 'eleven', '12': 'twelve', '13': 'thirteen', 20 | '14': 'fourteen', '15': 'fifteen', '16': 'sixteen', '17': 'seventeen', 21 | '18': 'eighteen', '19': 'nineteen', '20': 'twenty' 22 | } 23 | 24 | MAX_NUM=10 25 | 26 | samplelist=[] 27 | for sample in tqdm(data): 28 | # graph sample 29 | graphset=set() 30 | for i in sample['graph'][:MAX_NUM]: 31 | for j in i: 32 | graphset.add('('+j[0]+', '+j[1]+', '+j[2]+')') 33 | # avoid too many triples 34 | graphlist=list(graphset) 35 | 36 | 37 | # gold answer extraction 38 | if DATA=='WebQSP': 39 | gold=sample["answer"] 40 | else: 41 | gold=[] 42 | for i in sample["answer"]: 43 | if i.get("entity_name"): 44 | gold.append(i["entity_name"]) 45 | else: 46 | gold.append(i["answer_argument"]) 47 | 48 | 49 | # save 50 | temp=dict() 51 | temp['question']=sample['question'] 52 | temp["triples"]=' '.join(graphlist) 53 | temp['answer']=gold 54 | samplelist.append(temp) 55 | 56 | json.dump(samplelist,open(result,'w',encoding='utf-8'),indent=2,ensure_ascii=False) 57 | -------------------------------------------------------------------------------- /subgraph/grailqa/graph_query.py: -------------------------------------------------------------------------------- 1 | 
import json 2 | import re 3 | from sparql_utils.sparql_executor import execute_query, execute_query_allvar, get_friendly_name 4 | from query_interface import get_1hop 5 | from tqdm import tqdm 6 | import random 7 | import copy 8 | 9 | # max extend triple number for each entity in gold graph 10 | EXNUM=10 11 | # max graph number for extend 12 | GRAPHNUM=10 13 | 14 | # change sparql to query all variables 15 | def update_sparql_query(query_string): 16 | # find all ver 17 | var_pattern = r'\?[xy]\d+' 18 | variables = set(re.findall(var_pattern, query_string)) 19 | # sort based on num 20 | variables = sorted(variables, key=lambda x: int(x[2:])) 21 | # extract triples with y for it may not be used in main query 22 | query_lines = query_string.split('\n') 23 | query_y=[] 24 | for line in query_lines: 25 | if line.startswith("?") and len(line.split(' ')) == 5 and '?y' in line: 26 | query_y.append(line) 27 | query_lines=query_lines[:3]+query_y+query_lines[3:] 28 | # modify select distinct 29 | if query_lines[1].startswith('SELECT '): 30 | # remove SELECT (?x0 AS ?value) WHERE { and last } 31 | query_lines = query_lines[:1] + query_lines[2:-1] 32 | if query_lines[1].startswith('SELECT DISTINCT'): 33 | select_parts = query_lines[1].split(' ') 34 | select_parts[2] = ' '.join(variables) 35 | query_lines[1] = ' '.join(select_parts) 36 | return '\n'.join(query_lines), list(variables) 37 | 38 | # parse sparql to subgraph 39 | def sparql_to_graph(query): 40 | # input: sparql query 41 | # return: str(triples), 42 | lines = query.split('\n') 43 | graph_lines = [] 44 | values = {} 45 | # extract all intermediate entity mid 46 | for line in lines: 47 | if line.startswith("VALUE"): 48 | k = None 49 | v = None 50 | for item in line.split(' '): 51 | if item.startswith("?"): 52 | k = item 53 | if item.startswith(":") or "1 and n[j[0][1:]]['mid'][0:2] in ['m.','n.','g.']: 147 | midset.add(n[j[0][1:]]['mid']) 148 | # j[2] 149 | # make sure j[2] is an entity 150 | if j[2].startswith('?') and len(n[j[2][1:]]['mid'])>1 and n[j[2][1:]]['mid'][0:2] in ['m.','n.','g.']: 151 | midset.add(n[j[2][1:]]['mid']) 152 | # skip type relation 153 | if triple[1]!='type.object.type': 154 | one_graph.append(triple) 155 | midlist.append(midset) 156 | random.shuffle(one_graph) 157 | graph.append(one_graph) 158 | 159 | # graph extend 160 | ex_graph=[] 161 | for index,g in enumerate(graph[:GRAPHNUM]): 162 | # copy g to g1 163 | g1=copy.deepcopy(g) 164 | # iteratively extend triple 165 | ex_triple=[] 166 | # collect mid triple 167 | mid_triple=[] 168 | for j in midlist[index]: 169 | for k in get_1hop(j)[:EXNUM]: 170 | if k not in mid_triple: 171 | mid_triple.append(k) 172 | # avoid redundant triple 173 | unique_triples = set(tuple(triple) for triple in mid_triple) 174 | mid_triple = [list(triple) for triple in unique_triples] 175 | random.shuffle(mid_triple) 176 | # mid to name 177 | for k in mid_triple: 178 | extend=[] 179 | # k[0] 180 | temp='' 181 | # k[0] is in mid_dict 182 | if mid_dict.get(k[0]): 183 | temp=mid_dict[k[0]] 184 | # k[0] is not entity 185 | if len(temp)==0 and (len(k[0])==1 or k[0][0:2] not in ['m.','n.','g.']): 186 | temp=k[0].replace('-08:00','') 187 | # k[0] is entity 188 | if len(temp)==0: 189 | temp=get_friendly_name(k[0]) 190 | if temp=='null': 191 | if cvt_dict.get(k[0]): 192 | temp=cvt_dict[k[0]] 193 | else: 194 | temp='CVT'+str(cvt) 195 | cvt_dict[k[0]]=temp 196 | cvt+=1 197 | else: 198 | temp=temp.replace('-08:00','') 199 | extend.append(temp) 200 | # k[1] 201 | extend.append(k[1]) 202 | # k[2] 203 | temp='' 
204 | # k[2] is in mid_dict 205 | if mid_dict.get(k[2]): 206 | temp=mid_dict[k[2]] 207 | # k[2] is not entity 208 | if len(temp)==0 and (len(k[2])==1 or k[2][0:2] not in ['m.','n.','g.']): 209 | temp=k[2].replace('-08:00','') 210 | # k[2] is entity 211 | if len(temp)==0: 212 | temp=get_friendly_name(k[2]) 213 | if temp=='null': 214 | if cvt_dict.get(k[2]): 215 | temp=cvt_dict[k[2]] 216 | else: 217 | temp='CVT'+str(cvt) 218 | cvt_dict[k[2]]=temp 219 | cvt+=1 220 | else: 221 | temp=temp.replace('-08:00','') 222 | extend.append(temp) 223 | #if extend not in g1: 224 | # g1.append(extend) 225 | if extend not in ex_triple: 226 | ex_triple.append(extend) 227 | # add ex_triple to g1 228 | random.shuffle(ex_triple) 229 | g1.extend(ex_triple) 230 | random.shuffle(g1) 231 | ex_graph.append(g1) 232 | 233 | sample['qid']=one_example['qid'] 234 | sample['question']=one_example['question'] 235 | sample['answer']=one_example['answer'] 236 | sample['sparql_query']=one_example['sparql_query'] 237 | sample['s_expression']=one_example['s_expression'] 238 | sample['graph']=graph 239 | sample['restrict_graph']=graph[:GRAPHNUM] 240 | sample['ex_graph']=ex_graph 241 | graphdata.append(sample) 242 | 243 | json.dump(graphdata,open('graph/'+file_type+'.json','w',encoding='utf-8'),indent=2,ensure_ascii=False) 244 | 245 | query('train') 246 | query('dev') -------------------------------------------------------------------------------- /subgraph/grailqa/sparql_utils/misc.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict, Counter, deque 2 | import torch 3 | import json 4 | import pickle 5 | import numpy as np 6 | import torch.nn as nn 7 | import random 8 | import os 9 | import time 10 | ###################################################### 11 | ##################### used in SRN #################### 12 | START_RELATION = 'START_RELATION' 13 | NO_OP_RELATION = 'NO_OP_RELATION' 14 | NO_OP_ENTITY = 'NO_OP_ENTITY' 15 | DUMMY_RELATION = 'DUMMY_RELATION' 16 | DUMMY_ENTITY = 'DUMMY_ENTITY' 17 | 18 | DUMMY_RELATION_ID = 0 19 | START_RELATION_ID = 1 20 | NO_OP_RELATION_ID = 2 21 | DUMMY_ENTITY_ID = 0 22 | NO_OP_ENTITY_ID = 1 23 | 24 | EPSILON = float(np.finfo(float).eps) 25 | HUGE_INT = 1e31 26 | 27 | def format_path(path_trace, id2entity, id2relation): 28 | def get_most_recent_relation(j): 29 | relation_id = int(path_trace[j][0]) 30 | if relation_id == NO_OP_RELATION_ID: 31 | return '' 32 | else: 33 | return id2relation[relation_id] 34 | 35 | def get_most_recent_entity(j): 36 | return id2entity[int(path_trace[j][1])] 37 | 38 | path_str = get_most_recent_entity(0) 39 | for j in range(1, len(path_trace)): 40 | rel = get_most_recent_relation(j) 41 | if not rel.endswith('_inv'): 42 | path_str += ' -{}-> '.format(rel) 43 | else: 44 | path_str += ' <-{}- '.format(rel[:-4]) 45 | path_str += get_most_recent_entity(j) 46 | return path_str 47 | 48 | def pad_and_cat(a, padding_value, padding_dim=1): 49 | max_dim_size = max([x.size()[padding_dim] for x in a]) 50 | padded_a = [] 51 | for x in a: 52 | if x.size()[padding_dim] < max_dim_size: 53 | res_len = max_dim_size - x.size()[1] 54 | pad = nn.ConstantPad1d((0, res_len), padding_value) 55 | padded_a.append(pad(x)) 56 | else: 57 | padded_a.append(x) 58 | return torch.cat(padded_a, dim=0) 59 | 60 | def safe_log(x): 61 | return torch.log(x + EPSILON) 62 | 63 | def entropy(p): 64 | return torch.sum(- p * safe_log(p), 1) 65 | 66 | def init_word2id(): 67 | return { 68 | '<PAD>': 0, 69 | '<UNK>': 1, 70 | 'E_S': 2, 71 | } 72 | def 
init_entity2id(): 73 | return { 74 | DUMMY_ENTITY: DUMMY_ENTITY_ID, 75 | NO_OP_ENTITY: NO_OP_ENTITY_ID 76 | } 77 | def init_relation2id(): 78 | return { 79 | DUMMY_RELATION: DUMMY_RELATION_ID, 80 | START_RELATION: START_RELATION_ID, 81 | NO_OP_RELATION: NO_OP_RELATION_ID 82 | } 83 | 84 | def add_item_to_x2id(item, x2id): 85 | if not item in x2id: 86 | x2id[item] = len(x2id) 87 | 88 | def tile_along_beam(v, beam_size, dim=0): 89 | """ 90 | Tile a tensor along a specified dimension for the specified beam size. 91 | :param v: Input tensor. 92 | :param beam_size: Beam size. 93 | """ 94 | if dim == -1: 95 | dim = len(v.size()) - 1 96 | v = v.unsqueeze(dim + 1) 97 | v = torch.cat([v] * beam_size, dim=dim+1) 98 | new_size = [] 99 | for i, d in enumerate(v.size()): 100 | if i == dim + 1: 101 | new_size[-1] *= d 102 | else: 103 | new_size.append(d) 104 | return v.view(new_size) 105 | ##################### used in SRN #################### 106 | ###################################################### 107 | 108 | 109 | 110 | def init_vocab(): 111 | return { 112 | '<PAD>': 0, 113 | '<UNK>': 1, 114 | '<START>': 2, 115 | '<END>': 3 116 | } 117 | 118 | def invert_dict(d): 119 | return {v: k for k, v in d.items()} 120 | 121 | def load_glove(glove_pt, idx_to_token): 122 | glove = pickle.load(open(glove_pt, 'rb')) 123 | dim = len(glove['the']) 124 | matrix = [] 125 | for i in range(len(idx_to_token)): 126 | token = idx_to_token[i] 127 | tokens = token.split() 128 | if len(tokens) > 1: 129 | v = np.zeros((dim,)) 130 | for token in tokens: 131 | v = v + glove.get(token, glove['the']) 132 | v = v / len(tokens) 133 | else: 134 | v = glove.get(token, glove['the']) 135 | matrix.append(v) 136 | matrix = np.asarray(matrix) 137 | return matrix 138 | 139 | 140 | class SmoothedValue(object): 141 | """Track a series of values and provide access to smoothed values over a 142 | window or the global series average. 
143 | """ 144 | 145 | def __init__(self, window_size=20): 146 | self.deque = deque(maxlen=window_size) 147 | self.series = [] 148 | self.total = 0.0 149 | self.count = 0 150 | 151 | def update(self, value): 152 | self.deque.append(value) 153 | self.series.append(value) 154 | self.count += 1 155 | self.total += value 156 | 157 | @property 158 | def median(self): 159 | d = torch.tensor(list(self.deque)) 160 | return d.median().item() 161 | 162 | @property 163 | def avg(self): 164 | d = torch.tensor(list(self.deque)) 165 | return d.mean().item() 166 | 167 | @property 168 | def global_avg(self): 169 | return self.total / self.count 170 | 171 | 172 | class MetricLogger(object): 173 | def __init__(self, delimiter="\t"): 174 | self.meters = defaultdict(SmoothedValue) 175 | self.delimiter = delimiter 176 | 177 | def update(self, **kwargs): 178 | for k, v in kwargs.items(): 179 | if isinstance(v, torch.Tensor): 180 | v = v.item() 181 | assert isinstance(v, (float, int)) 182 | self.meters[k].update(v) 183 | 184 | def __getattr__(self, attr): 185 | if attr in self.meters: 186 | return self.meters[attr] 187 | if attr in self.__dict__: 188 | return self.__dict__[attr] 189 | raise AttributeError("'{}' object has no attribute '{}'".format( 190 | type(self).__name__, attr)) 191 | 192 | def __str__(self): 193 | loss_str = [] 194 | for name, meter in self.meters.items(): 195 | loss_str.append( 196 | "{}: {:.4f} ({:.4f})".format(name, meter.median, meter.global_avg) 197 | ) 198 | return self.delimiter.join(loss_str) 199 | 200 | 201 | def seed_everything(seed=1029): 202 | ''' 203 | 设置整个开发环境的seed 204 | :param seed: 205 | :param device: 206 | :return: 207 | ''' 208 | random.seed(seed) 209 | os.environ['PYTHONHASHSEED'] = str(seed) 210 | np.random.seed(seed) 211 | torch.manual_seed(seed) 212 | torch.cuda.manual_seed(seed) 213 | torch.cuda.manual_seed_all(seed) 214 | # some cudnn methods can be random even after fixing the seed 215 | # unless you tell it to be deterministic 216 | torch.backends.cudnn.deterministic = True 217 | 218 | 219 | class ProgressBar(object): 220 | ''' 221 | custom progress bar 222 | Example: 223 | >>> pbar = ProgressBar(n_total=30,desc='training') 224 | >>> step = 2 225 | >>> pbar(step=step) 226 | ''' 227 | def __init__(self, n_total,width=30,desc = 'Training'): 228 | self.width = width 229 | self.n_total = n_total 230 | self.start_time = time.time() 231 | self.desc = desc 232 | 233 | def __call__(self, step, info={}): 234 | now = time.time() 235 | current = step + 1 236 | recv_per = current / self.n_total 237 | bar = f'[{self.desc}] {current}/{self.n_total} [' 238 | if recv_per >= 1: 239 | recv_per = 1 240 | prog_width = int(self.width * recv_per) 241 | if prog_width > 0: 242 | bar += '=' * (prog_width - 1) 243 | if current< self.n_total: 244 | bar += ">" 245 | else: 246 | bar += '=' 247 | bar += '.' 
* (self.width - prog_width) 248 | bar += ']' 249 | show_bar = f"\r{bar}" 250 | time_per_unit = (now - self.start_time) / current 251 | if current < self.n_total: 252 | eta = time_per_unit * (self.n_total - current) 253 | if eta > 3600: 254 | eta_format = ('%d:%02d:%02d' % 255 | (eta // 3600, (eta % 3600) // 60, eta % 60)) 256 | elif eta > 60: 257 | eta_format = '%d:%02d' % (eta // 60, eta % 60) 258 | else: 259 | eta_format = '%ds' % eta 260 | time_info = f' - ETA: {eta_format}' 261 | else: 262 | if time_per_unit >= 1: 263 | time_info = f' {time_per_unit:.1f}s/step' 264 | elif time_per_unit >= 1e-3: 265 | time_info = f' {time_per_unit * 1e3:.1f}ms/step' 266 | else: 267 | time_info = f' {time_per_unit * 1e6:.1f}us/step' 268 | 269 | show_bar += time_info 270 | if len(info) != 0: 271 | show_info = f'{show_bar} ' + \ 272 | "-".join([f' {key}: {value:.4f} ' for key, value in info.items()]) 273 | print(show_info, end='') 274 | else: 275 | print(show_bar, end='') -------------------------------------------------------------------------------- /subgraph/grailqa/sparql_utils/sparql_engine.py: -------------------------------------------------------------------------------- 1 | import rdflib 2 | from rdflib import URIRef, BNode, Literal, XSD 3 | from rdflib.plugins.stores import sparqlstore 4 | from itertools import chain 5 | from tqdm import tqdm 6 | import argparse 7 | 8 | import sys 9 | from sparql_utils.load_kb import DataForSPARQL 10 | from sparql_utils.value_class import ValueClass 11 | 12 | 13 | virtuoso_address = "http://10.201.190.172:8890/sparql" 14 | # virtuoso_graph_uri = 'KQApro' 15 | virtuoso_graph_uri = 'freebase' 16 | 17 | 18 | 19 | def legal(s): 20 | # convert predicate and attribute keys to legal format 21 | return s.replace(' ', '_') 22 | 23 | def esc_escape(s): 24 | ''' 25 | Why we need this: 26 | If there is an escape in Literal, such as '\EUR', the query string will be something like '?pv "\\EUR"'. 27 | However, in virtuoso engine, \\ is connected with E, and \\E forms a bad escape sequence. 28 | So we must repeat \\, and virtuoso will consider "\\\\EUR" as "\EUR". 29 | 30 | Note this must be applied before esc_quot, as esc_quot will introduce extra escapes. 31 | ''' 32 | return s.replace('\\', '\\\\') 33 | 34 | def esc_quot(s): 35 | ''' 36 | Why we need this: 37 | We use "<value>" to represent a literal value in the sparql query. 38 | If the <value> has a double quotation mark itself, we must escape it to make sure the query is valid for the virtuoso engine.
39 | ''' 40 | return s.replace('"', '\\"') 41 | 42 | class SparqlEngine(): 43 | gs1 = None 44 | PRED_INSTANCE = 'pred:instance_of' 45 | PRED_NAME = 'pred:name' 46 | 47 | PRED_VALUE = 'pred:value' # link packed value node to its literal value 48 | PRED_UNIT = 'pred:unit' # link packed value node to its unit 49 | 50 | PRED_YEAR = 'pred:year' # link packed value node to its year value, which is an integer 51 | PRED_DATE = 'pred:date' # link packed value node to its date value, which is a date 52 | 53 | PRED_FACT_H = 'pred:fact_h' # link qualifier node to its head 54 | PRED_FACT_R = 'pred:fact_r' 55 | PRED_FACT_T = 'pred:fact_t' 56 | 57 | SPECIAL_PREDICATES = (PRED_INSTANCE, PRED_NAME, PRED_VALUE, PRED_UNIT, PRED_YEAR, PRED_DATE, PRED_FACT_H, PRED_FACT_R, PRED_FACT_T) 58 | def __init__(self, data, ttl_file=''): 59 | self.nodes = nodes = {} 60 | for i in chain(data.concepts, data.entities): 61 | nodes[i] = URIRef(i) 62 | for p in chain(data.predicates, data.attribute_keys, SparqlEngine.SPECIAL_PREDICATES): 63 | nodes[p] = URIRef(legal(p)) 64 | 65 | self.graph = graph = rdflib.Graph() 66 | 67 | for i in chain(data.concepts, data.entities): 68 | name = data.get_name(i) 69 | graph.add((nodes[i], nodes[SparqlEngine.PRED_NAME], Literal(name))) 70 | 71 | for ent_id in tqdm(data.entities, desc='Establishing rdf graph'): 72 | for con_id in data.get_all_concepts(ent_id): 73 | graph.add((nodes[ent_id], nodes[SparqlEngine.PRED_INSTANCE], nodes[con_id])) 74 | for (k, v, qualifiers) in data.get_attribute_facts(ent_id): 75 | h, r = nodes[ent_id], nodes[k] 76 | t = self._get_value_node(v) 77 | graph.add((h, r, t)) 78 | fact_node = self._new_fact_node(h, r, t) 79 | 80 | for qk, qvs in qualifiers.items(): 81 | for qv in qvs: 82 | h, r = fact_node, nodes[qk] 83 | t = self._get_value_node(qv) 84 | if len(list(graph[t])) == 0: 85 | print(t) 86 | graph.add((h, r, t)) 87 | 88 | for (pred, obj_id, direction, qualifiers) in data.get_relation_facts(ent_id): 89 | if direction == 'backward': 90 | if data.is_concept(obj_id): 91 | h, r, t = nodes[obj_id], nodes[pred], nodes[ent_id] 92 | else: 93 | continue 94 | else: 95 | h, r, t = nodes[ent_id], nodes[pred], nodes[obj_id] 96 | graph.add((h, r, t)) 97 | fact_node = self._new_fact_node(h, r, t) 98 | for qk, qvs in qualifiers.items(): 99 | for qv in qvs: 100 | h, r = fact_node, nodes[qk] 101 | t = self._get_value_node(qv) 102 | graph.add((h, r, t)) 103 | 104 | if ttl_file: 105 | print('Save graph to {}'.format(ttl_file)) 106 | graph.serialize(ttl_file, format='turtle') 107 | 108 | 109 | def _get_value_node(self, v): 110 | # we use a URIRef node, because we need its reference in query results, which is not supported by BNode 111 | if v.type == 'string': 112 | node = BNode() 113 | self.graph.add((node, self.nodes[SparqlEngine.PRED_VALUE], Literal(v.value))) 114 | return node 115 | elif v.type == 'quantity': 116 | # we use a node to pack value and unit 117 | node = BNode() 118 | self.graph.add((node, self.nodes[SparqlEngine.PRED_VALUE], Literal(v.value, datatype=XSD.double))) 119 | self.graph.add((node, self.nodes[SparqlEngine.PRED_UNIT], Literal(v.unit))) 120 | return node 121 | elif v.type == 'year': 122 | node = BNode() 123 | self.graph.add((node, self.nodes[SparqlEngine.PRED_YEAR], Literal(v.value))) 124 | return node 125 | elif v.type == 'date': 126 | # use a node to pack year and date 127 | node = BNode() 128 | self.graph.add((node, self.nodes[SparqlEngine.PRED_YEAR], Literal(v.value.year))) 129 | self.graph.add((node, self.nodes[SparqlEngine.PRED_DATE], Literal(v.value, 
datatype=XSD.date))) 130 | return node 131 | 132 | def _new_fact_node(self, h, r, t): 133 | node = BNode() 134 | self.graph.add((node, self.nodes[SparqlEngine.PRED_FACT_H], h)) 135 | self.graph.add((node, self.nodes[SparqlEngine.PRED_FACT_R], r)) 136 | self.graph.add((node, self.nodes[SparqlEngine.PRED_FACT_T], t)) 137 | return node 138 | 139 | 140 | def query_virtuoso(q): 141 | endpoint = virtuoso_address 142 | store=sparqlstore.SPARQLUpdateStore(endpoint) 143 | gs = rdflib.ConjunctiveGraph(store) 144 | gs.open((endpoint, endpoint)) 145 | gs1 = gs.get_context(rdflib.URIRef(virtuoso_graph_uri)) 146 | res = gs1.query(q) 147 | return res 148 | 149 | 150 | 151 | def get_sparql_answer(sparql, data): 152 | """ 153 | data: DataForSPARQL object, we need the key_type 154 | """ 155 | try: 156 | # infer the parse_type based on sparql 157 | if sparql.startswith('SELECT DISTINCT ?e') or sparql.startswith('SELECT ?e'): 158 | parse_type = 'name' 159 | elif sparql.startswith('SELECT (COUNT(DISTINCT ?e)'): 160 | parse_type = 'count' 161 | elif sparql.startswith('SELECT DISTINCT ?p '): 162 | parse_type = 'pred' 163 | elif sparql.startswith('ASK'): 164 | parse_type = 'bool' 165 | else: 166 | tokens = sparql.split() 167 | tgt = tokens[2] 168 | for i in range(len(tokens)-1, 1, -1): 169 | if tokens[i]=='.' and tokens[i-1]==tgt: 170 | key = tokens[i-2] 171 | break 172 | key = key[1:-1].replace('_', ' ') 173 | t = data.key_type[key] 174 | parse_type = 'attr_{}'.format(t) 175 | 176 | parsed_answer = None 177 | res = query_virtuoso(sparql) 178 | if res.vars: 179 | res = [[binding[v] for v in res.vars] for binding in res.bindings] 180 | if len(res) != 1: 181 | return None 182 | else: 183 | res = res.askAnswer 184 | assert parse_type == 'bool' 185 | 186 | if parse_type == 'name': 187 | node = res[0][0] 188 | sp = 'SELECT DISTINCT ?v WHERE {{ <{}> <{}> ?v . }}'.format(node, SparqlEngine.PRED_NAME) 189 | res = query_virtuoso(sp) 190 | res = [[binding[v] for v in res.vars] for binding in res.bindings] 191 | name = res[0][0].value 192 | parsed_answer = name 193 | elif parse_type == 'count': 194 | count = res[0][0].value 195 | parsed_answer = str(count) 196 | elif parse_type.startswith('attr_'): 197 | node = res[0][0] 198 | v_type = parse_type.split('_')[1] 199 | unit = None 200 | if v_type == 'string': 201 | sp = 'SELECT DISTINCT ?v WHERE {{ <{}> <{}> ?v . }}'.format(node, SparqlEngine.PRED_VALUE) 202 | elif v_type == 'quantity': 203 | # Note: For those large number, ?v is truncated by virtuoso (e.g., 14756087 to 1.47561e+07) 204 | # To obtain the accurate ?v, we need to cast it to str 205 | sp = 'SELECT DISTINCT ?v,?u,(str(?v) as ?sv) WHERE {{ <{}> <{}> ?v ; <{}> ?u . }}'.format(node, SparqlEngine.PRED_VALUE, SparqlEngine.PRED_UNIT) 206 | elif v_type == 'year': 207 | sp = 'SELECT DISTINCT ?v WHERE {{ <{}> <{}> ?v . }}'.format(node, SparqlEngine.PRED_YEAR) 208 | elif v_type == 'date': 209 | sp = 'SELECT DISTINCT ?v WHERE {{ <{}> <{}> ?v . }}'.format(node, SparqlEngine.PRED_DATE) 210 | else: 211 | raise Exception('unsupported parse type') 212 | res = query_virtuoso(sp) 213 | res = [[binding[v] for v in res.vars] for binding in res.bindings] 214 | # if there is no specific date, then convert the type to year 215 | if len(res)==0 and v_type == 'date': 216 | v_type = 'year' 217 | sp = 'SELECT DISTINCT ?v WHERE {{ <{}> <{}> ?v . 
}}'.format(node, SparqlEngine.PRED_YEAR) 218 | res = query_virtuoso(sp) 219 | res = [[binding[v] for v in res.vars] for binding in res.bindings] 220 | if v_type == 'quantity': 221 | value = float(res[0][2].value) 222 | unit = res[0][1].value 223 | else: 224 | value = res[0][0].value 225 | value = ValueClass(v_type, value, unit) 226 | parsed_answer = str(value) 227 | elif parse_type == 'bool': 228 | parsed_answer = 'yes' if res else 'no' 229 | elif parse_type == 'pred': 230 | parsed_answer = str(res[0][0]) 231 | parsed_answer = parsed_answer.replace('_', ' ') 232 | return parsed_answer 233 | except Exception: 234 | return None 235 | 236 | 237 | if __name__ == '__main__': 238 | parser = argparse.ArgumentParser() 239 | # input and output 240 | parser.add_argument('--kb_path', required=True) 241 | parser.add_argument('--ttl_path', required=True) 242 | args = parser.parse_args() 243 | 244 | data = DataForSPARQL(args.kb_path) 245 | engine = SparqlEngine(data, args.ttl_path) 246 | -------------------------------------------------------------------------------- /subgraph/grailqa/sparql_utils/value_class.py: -------------------------------------------------------------------------------- 1 | def comp(a, b, op): 2 | """ 3 | Args: 4 | - a (ValueClass): attribute value of a certain entity 5 | - b (ValueClass): comparison target 6 | - op: =/>/</>=/<=/!= 7 | """ 8 | # NOTE: ValueClass only implements __eq__, __lt__ and __gt__, 9 | # so <= and >= are composed from those operators 10 | if op == '=': 11 | return a == b 12 | elif op == '<': 13 | return a < b 14 | elif op == '<=': 15 | return a < b or a == b 16 | elif op == '>': 17 | return a > b 18 | elif op == '>=': 19 | return a > b or a == b 20 | elif op == '!=': 21 | return a != b 22 | 23 | 24 | 25 | class ValueClass(): 26 | def __init__(self, type, value, unit=None): 27 | """ 28 | When type is 29 | - string, value is a str 30 | - quantity, value is a number and unit is required 31 | - year, value is an int 32 | - date, value is a date object 33 | """ 34 | self.type = type 35 | self.value = value 36 | self.unit = unit 37 | 38 | def isTime(self): 39 | return self.type in {'year', 'date'} 40 | 41 | def can_compare(self, other): 42 | if self.type == 'string': 43 | return other.type == 'string' 44 | elif self.type == 'quantity': 45 | # NOTE: two quantities can compare only when they have the same unit 46 | return other.type == 'quantity' and other.unit == self.unit 47 | else: 48 | # year can compare with date 49 | return other.type == 'year' or other.type == 'date' 50 | 51 | def contains(self, other): 52 | """ 53 | check whether self contains other, which is different from __eq__ and the result is asymmetric 54 | used for conditions like whether 2001-01-01 in 2001, or whether 2001 in 2001-01-01 55 | """ 56 | if self.type == 'year': # year can contain year and date 57 | other_value = other.value if other.type == 'year' else other.value.year 58 | return self.value == other_value 59 | elif self.type == 'date': # date can only contain date 60 | return other.type == 'date' and self.value == other.value 61 | else: 62 | raise Exception('not supported type: %s' % self.type) 63 | 64 | 65 | def __eq__(self, other): 66 | """ 67 | 2001 and 2001-01-01 are not equal 68 | """ 69 | assert self.can_compare(other) 70 | return self.type == other.type and self.value == other.value 71 | 72 | def __lt__(self, other): 73 | """ 74 | Comparison between a year and a date will convert them both to year 75 | """ 76 | assert self.can_compare(other) 77 | if self.type == 'string': 78 | raise Exception('try to compare two string') 79 | elif self.type == 'quantity': 80 | return self.value < other.value 81 | elif self.type == 'year': 82 | other_value = other.value if other.type == 'year' else other.value.year 83 | return self.value < other_value 84 | elif self.type == 'date': 85 | if other.type == 'year': 86 | return 
self.value.year < other.value 87 | else: 88 | return self.value < other.value 89 | 90 | def __gt__(self, other): 91 | assert self.can_compare(other) 92 | if self.type == 'string': 93 | raise Exception('try to compare two string') 94 | elif self.type == 'quantity': 95 | return self.value > other.value 96 | elif self.type == 'year': 97 | other_value = other.value if other.type == 'year' else other.value.year 98 | return self.value > other_value 99 | elif self.type == 'date': 100 | if other.type == 'year': 101 | return self.value.year > other.value 102 | else: 103 | return self.value > other.value 104 | 105 | def __str__(self): 106 | if self.type == 'string': 107 | return self.value 108 | elif self.type == 'quantity': 109 | if self.value - int(self.value) < 1e-5: 110 | v = int(self.value) 111 | else: 112 | v = self.value 113 | return '{} {}'.format(v, self.unit) if self.unit != '1' else str(v) 114 | elif self.type == 'year': 115 | return str(self.value) 116 | elif self.type == 'date': 117 | return self.value.isoformat() 118 | --------------------------------------------------------------------------------
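A minimal usage sketch of the comparison semantics defined in value_class.py above (the values constructed here are illustrative, not from the repo):

from datetime import date

y = ValueClass('year', 2001)
d = ValueClass('date', date(2001, 1, 1))
print(y.contains(d))   # True: the year 2001 contains the date 2001-01-01
print(y == d)          # False: per __eq__, a year and a date are never equal
q1 = ValueClass('quantity', 5.0, 'metre')
q2 = ValueClass('quantity', 7.0, 'metre')
print(q1 < q2)         # True: quantities with the same unit compare by value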