├── Figures
│   ├── intro_figure.pdf
│   └── intro_figure.png
├── requirement.txt
├── tot
│   ├── tasks
│   │   ├── base.py
│   │   ├── __init__.py
│   │   ├── fever.py
│   │   └── bamboogle.py
│   ├── prompts
│   │   ├── bamboogle.py
│   │   └── fever.py
│   ├── llama_models.py
│   └── methods
│       └── bfs_test.py
├── README.md
├── clean_dataset.py
├── dpo_training.py
├── load_data.py
└── run_test.py
/Figures/intro_figure.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/CPO/HEAD/Figures/intro_figure.pdf -------------------------------------------------------------------------------- /Figures/intro_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/CPO/HEAD/Figures/intro_figure.png -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | torch==2.1.2+cu121 2 | transformers==4.38.2 3 | accelerate==0.27.1 4 | trl==0.7.11 5 | datasets==2.18.0 6 | -------------------------------------------------------------------------------- /tot/tasks/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | DATA_PATH = os.path.join(os.path.dirname(__file__), '..', 'data') 3 | 4 | class Task: 5 | def __init__(self): 6 | pass 7 | 8 | def __len__(self) -> int: 9 | pass 10 | 11 | def get_input(self, idx: int) -> str: 12 | pass 13 | 14 | def test_output(self, idx: int, output: str): 15 | pass 16 | -------------------------------------------------------------------------------- /tot/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | def get_task(name): 2 | if name == 'game24': 3 | from tot.tasks.game24 import Game24Task 4 | return Game24Task() 5 | elif name == 'text': 6 | from tot.tasks.text import TextTask 7 | return TextTask() 8 | elif name == 'crosswords': 9 | from tot.tasks.crosswords import MiniCrosswordsTask 10 | return MiniCrosswordsTask() 11 | elif name == 'bamboogle': 12 | from tot.tasks.bamboogle import FactualQA 13 | return FactualQA() 14 | elif name == '2wiki': 15 | from tot.tasks.wiki import FactualQA 16 | return FactualQA() 17 | elif name == 'qasc': 18 | from tot.tasks.qasc import FactualQA 19 | return FactualQA() 20 | elif name == 'fever': 21 | from tot.tasks.fever import FactualQA 22 | return FactualQA() 23 | else: 24 | raise NotImplementedError 25 | -------------------------------------------------------------------------------- /tot/prompts/bamboogle.py: -------------------------------------------------------------------------------- 1 | cot_prompt = ''' 2 | Task: Answer the given question step-by-step, and conclude with the phrase 'so the final answer is: '. 3 | Question: Who lived longer, Theodor Haecker or Harry Vaughan Watkins? Answer: Step 1, when did Theodor Haecker die? Theodor Haecker was 65 years old when he died. Step 2, when did Harry Vaughan Watkins die? Harry Vaughan Watkins was 69 years old when he died. Step 3, so the final answer is: Harry Vaughan Watkins. 4 | Question: Why did the founder of Versus die? Answer: Step 1, who is the founder of Versus? The founder of Versus was Gianni Versace. Step 2, why did Gianni Versace die? Gianni Versace was shot and killed on the steps of his Miami Beach mansion on July 15, 1997. Step 3, so the final answer is: Shot. 5 | Question: Who is the grandchild of Dambar Shah? Answer: Step 1, who is the son of Dambar Shah?
Dambar Shah (? - 1645) was the father of Krishna Shah. Step 2, who is the son of Krishna Shah? Krishna Shah (? - 1661) was the father of Rudra Shah. Step 3, so the final answer is: Rudra Shah. 6 | Question: {input} 7 | ''' 8 | -------------------------------------------------------------------------------- /tot/prompts/fever.py: -------------------------------------------------------------------------------- 1 | cot_prompt = '''Task: Determine if there is Observation that SUPPORTS or REFUTES a Claim, or if there is NOT ENOUGH INFO. 2 | Claim: Reg Watson is a current television producer. Answer: Step 1, who is Reg Watson? Reginald James Watson AM was an Australian television producer and screenwriter. Step 2, when did Reginald James Watson AM die? Reginald James Watson AM died on 8 October 2019. Step 3, so the final answer is: REFUTES. 3 | Claim: The Gadsden flag was named by Christopher Gadsden. Answer: Step 1, what is the origin of the name of the Gadsden flag? The Gadsden flag is named after politician Christopher Gadsden. Step 2,who named the Gadsden flag? there is no information on who named the Gadsden flag. Step 3, so the final answer is: NOT ENOUGH INFO. 4 | Claim: Black Mirror is about society. Answer: Step 1, what is the son of Black Mirror? Black Mirror is a British anthology television series. Step 2, what issues does this series discuss? The series uses technology to comment on contemporary social issues. Step 3, so the final answer is: SUPPORTS. 5 | Claim: {input} 6 | ''' 7 | -------------------------------------------------------------------------------- /tot/llama_models.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | import fire 4 | import gradio as gr 5 | import torch 6 | import transformers 7 | from typing import List 8 | from peft import PeftModel 9 | from transformers import GenerationConfig, AutoModelForCausalLM, AutoTokenizer 10 | from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList, \ 11 | STOPPING_CRITERIA_INPUTS_DOCSTRING, add_start_docstrings 12 | 13 | 14 | def gpt(prompt, tokenizer, GenerationConfig, model, device, temperature=0.5, max_tokens=1000, n=1, stop=None,top_p=0.8, 15 | ): 16 | stopping_criteria = None 17 | def evaluate( 18 | prompt, 19 | input=None, 20 | temperature=0.4, 21 | top_p=top_p, 22 | top_k=40, 23 | num_beams=n, 24 | max_new_tokens=32, 25 | **kwargs, 26 | ): 27 | if isinstance(prompt,list): 28 | try: 29 | inputs = tokenizer(prompt, padding=True, return_tensors="pt", truncation=True, max_length=4000) 30 | except: 31 | print('tokenizer') 32 | print(prompt) 33 | exit() 34 | else: 35 | inputs = tokenizer(prompt, return_tensors="pt") 36 | input_ids = inputs["input_ids"].to(device) 37 | attention_mask = inputs["attention_mask"].to(device) 38 | if isinstance(prompt,list): 39 | # generation_output = model.generate(input_ids=input_ids, max_new_tokens=max_new_tokens, attention_mask=attention_mask, repetition_penalty=1.1, do_sample=True, num_beams=1, temperature=temperature, num_return_sequences=1) 40 | try: 41 | generation_output = model.generate(input_ids=input_ids, pad_token_id = 2, max_new_tokens=max_new_tokens, attention_mask=attention_mask, do_sample=False, num_return_sequences=1) 42 | s = tokenizer.batch_decode(generation_output, skip_special_tokens = True) 43 | except Exception as error: 44 | print(error) 45 | print(prompt) 46 | s = [] 47 | else: 48 | if n == 10: 49 | generation_output = model.generate(input_ids=input_ids, 
max_new_tokens=max_new_tokens, attention_mask=attention_mask, do_sample=False, num_return_sequences=1) 50 | s = tokenizer.batch_decode(generation_output, skip_special_tokens = True) 51 | else: 52 | try: 53 | s = [] 54 | for i in range(n): 55 | generation_output = model.generate(input_ids=input_ids, pad_token_id = 2, max_new_tokens=max_new_tokens, attention_mask=attention_mask, repetition_penalty=1.1, do_sample=True, temperature=temperature, num_return_sequences=1) 56 | s.append(tokenizer.batch_decode(generation_output, skip_special_tokens = True)[0]) 57 | except Exception as error: 58 | print(error) 59 | # print('===oom===') 60 | print(prompt) 61 | s = [] 62 | 63 | return s, input_ids 64 | 65 | return evaluate(prompt, temperature=temperature, max_new_tokens=max_tokens,stopping_criteria=stopping_criteria) 66 | 67 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CPO: Chain of Preference Optimization
2 | 
3 | The official implementation of the paper: [Chain of Preference Optimization: Improving Chain-of-Thought Reasoning in LLMs](https://arxiv.org/pdf/2406.09136).
4 | 
5 | ## Overview
6 | 
7 | The recent development of chain-of-thought (CoT) decoding has enabled large language models (LLMs) to generate explicit logical reasoning paths for complex problem-solving. However, research indicates that these paths are not always deliberate and optimal. The tree-of-thought (ToT) method employs tree-searching to extensively explore the reasoning space and find better reasoning paths that CoT decoding might overlook. This deliberation, however, comes at the cost of significantly increased inference complexity. In this work, we demonstrate that fine-tuning LLMs leveraging the search tree constructed by ToT allows CoT to achieve similar or better performance, thereby avoiding the substantial inference burden. This is achieved through *Chain of Preference Optimization* (CPO), where LLMs are fine-tuned to align each step of the CoT reasoning paths with those of ToT using the inherent preference information in the tree-search process. Extensive experimental results show that CPO significantly improves LLM performance in solving a variety of complex problems, including question answering, fact verification, and arithmetic reasoning, demonstrating its effectiveness.
8 | 
9 | ![](https://github.com/sail-sg/CPO/blob/main/Figures/intro_figure.png)
10 | 
11 | ## Setup
12 | 
13 | ```
14 | pip install -r requirement.txt
15 | ```
16 | 
17 | ## Quick Start
18 | 
19 | We show examples for one task. By simply changing the task name, the approach can be applied to other tasks.
20 | 
21 | ### Data Collection via ToT
22 | 
23 | 1. Select reasoning paths via ToT.
24 | 
25 | ```
26 | accelerate launch run_test.py --task bamboogle --method_generate sample --method_evaluate value --method_select greedy --n_evaluate_sample 5 --n_generate_sample 15 --n_select_sample 3 --base_model ./model/Llama-2-7b-hf --data_json_file bamboogle_7b.json --train True >>logs/bamboogle_7b_tot_test.out
27 | ```
28 | 
29 | 2. Collect paired preference thoughts for optimization.
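At every tree depth, `clean_dataset.py` pairs a thought that lies on a correct ToT reasoning path (`chosen`) with the other candidate thoughts generated at the same depth (`rejected`), and uses the few-shot CoT prompt plus any preceding correct steps as the `prompt`. Running the command below writes these triples to a JSON file (`bamboogle_7b_data.json` with the defaults in the script). The entry shown here is only an illustrative sketch; the question and thought texts are invented:

```json
[
  {
    "prompt": "Task: Answer the given question step-by-step, and conclude with the phrase 'so the final answer is: '. [...] Question: Who was the first president of the country where the Eiffel Tower is located? Answer: ",
    "chosen": "step 1, in which country is the Eiffel Tower located? The Eiffel Tower is located in France.",
    "rejected": "step 1, how tall is the Eiffel Tower? The Eiffel Tower is about 330 metres tall."
  }
]
```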
30 | 
31 | ```
32 | python clean_dataset.py
33 | ```
34 | 
35 | ### Training via CPO
36 | 
37 | ```
38 | accelerate launch dpo_training.py --dataset bam_7b_data.json --wandb_name dpo_7b_bam --base_model ./model/Llama-2-7b-hf --output_dir ./results/results_bam_7b_dpo
39 | ```
40 | 
41 | ### Testing via CoT
42 | 
43 | ```
44 | accelerate launch run_test.py --task bamboogle --naive_run --method_generate greedy --base_model ./results/results_bam_7b_dpo >>logs/bam_7b_dpo.out
45 | 
46 | ```
47 | 
48 | ## Reference Repositories
49 | 
50 | - Tree-of-Thought (ToT) [https://github.com/princeton-nlp/tree-of-thought-llm/](https://github.com/princeton-nlp/tree-of-thought-llm/)
51 | - Direct Preference Optimization (DPO) [https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py)
52 | 
53 | ## Citation
54 | 
55 | If you find CPO helpful or intriguing and decide to use it, kindly acknowledge the paper by citing it and consider starring this repo, thanks!
56 | 
57 | ```bibtex
58 | @article{zhang2024chain,
59 |   title={Chain of Preference Optimization: Improving Chain-of-Thought Reasoning in LLMs},
60 |   author={Zhang, Xuan and Du, Chao and Pang, Tianyu and Liu, Qian and Gao, Wei and Lin, Min},
61 |   journal={arXiv preprint arXiv:2406.09136},
62 |   year={2024}
63 | }
64 | ```
65 | 
--------------------------------------------------------------------------------
/tot/tasks/fever.py:
--------------------------------------------------------------------------------
1 | import re 2 | import os 3 | # import sympy 4 | import pandas as pd 5 | from tot.tasks.base import Task, DATA_PATH 6 | from tot.prompts.fever import * 7 | 8 | wiki_evaluate = '''Evaluate whether the language model can effectively decompose the claim into relevant sub-questions, and assess whether this decomposition helps in partially or directly verifying the original claim. The outcome will determine if this process of decomposition is "Likely" or "Impossible" to aid in verifying the claim. 9 | 10 | Evaluation Steps: Check if the language model can identify and decompose key sub-questions that are directly related to the original question. 11 | Evaluation Process: 1. Analyze whether each sub-question identified by the model is directly relevant to verifying the original claim. 2. Determine if the decomposition of these sub-questions forms a reasonable verification process for the original claim. 12 | Evaluation Result: 1. Likely: If the language model successfully decomposes the original claim into relevant sub-questions that help construct the final answer. 2. Impossible: If the language model fails to effectively decompose the claim, or if the decomposed sub-questions are not directly relevant to making the verification. 13 | 14 | Claim: Reg Watson is a current television producer. 15 | Thought Process: Step 1, who is Reg Watson? Reginald James Watson AM was an Australian television producer and screenwriter. 16 | Evaluation Process: 17 | Relevance of Sub-Questions: The sub-question of identifying who Reg Watson is, is directly relevant as it establishes the necessary context to further explore the original claim about his current status as a television producer. 18 | Effectiveness of Decomposition: By first identifying Reg Watson and then investigating his current professional activities, this approach forms a reasonable verification process. 19 | Evaluation Result: 20 | Likely 21 | 22 | Claim: The Gadsden flag was named by Christopher Gadsden.
23 | Thought: Step 1, why did Christopher Gadsden die? Gadsden died from head injuries suffered in a fall near his home. 24 | Evaluation Process: 25 | Relevance of Sub-Questions: The sub-question about the cause of Gadsden's death does not directly contribute to verifying the claim about the naming of the Gadsden flag. Instead, questions should focus on his involvement with the flag and any direct actions or contributions he made towards its naming. 26 | Verification Process: A more effective verification process would involve gathering evidence of Christopher Gadsden's direct involvement with the flag, including any documented instances where he is credited with its naming, as well as understanding the historical context of the flag's creation and use. 27 | Evaluation Results: 28 | Impossible 29 | 30 | Claim: Black Mirror is about society. 31 | Thought Process: Step 1, what is the son of Black Mirror? Black Mirror is a British anthology television series. Step 2, what issues does this series discuss? The series uses technology to comment on contemporary social issues. 32 | Evaluation Process: 33 | Relevance of Sub-Questions: Each sub-question is directly relevant and helps verify the original claim. The first establishes the series' nature and scope, and the second addresses the thematic content, specifically its societal focus. 34 | Verification Process: This process is reasonable for verifying the original claim. First, it establishes what "Black Mirror" is, laying the groundwork for further inquiry. Then, it dives into the series' thematic concerns, confirming its focus on societal issues through the lens of technology. This approach not only verifies the claim but also provides insight into how the series approaches its critique of society. 35 | Evaluation Results: 36 | Likely 37 | 38 | Claim: ''' 39 | 40 | Final_evaluate = '''Evaluate whether the conclusion can be drawn based on reasoning logic in the thought process." (Likely/Impossible). 41 | 42 | Claim: The Gadsden flag was named by Christopher Gadsden. 43 | Thought Process: Step 1, why did Christopher Gadsden die? Gadsden died from head injuries suffered in a fall near his home. Step 2, what is the origin of the name of the Gadsden flag? The Gadsden flag is named after politician Christopher Gadsden. 44 | Conclusion: so the final answer is: REFUTES. 45 | Evaluation Process: 46 | The thought process includes a correct statement about the origin of the Gadsden flag's name that aligns with the claim, but then concludes incorrectly that this information refutes the claim. 47 | Evaluation Results: 48 | Impossible 49 | 50 | Claim: Black Mirror is about society. 51 | Thought Process: Step 1, what is the son of Black Mirror? Black Mirror is a British anthology television series. Step 2, what issues does this series discuss? The series uses technology to comment on contemporary social issues. 52 | Conclusion: so the final answer is: SUPPORTS. 53 | Evaluation Process: 54 | The conclusion logically follows from the information provided, supporting the claim that "Black Mirror" is about society. 
55 | Evaluation Results: 56 | Likely 57 | 58 | Claim: ''' 59 | 60 | class FactualQA(Task): 61 | """ 62 | Input (x)  : a claim to be verified 63 | Output (y) : a trajectory of 3 reasoning steps ending with 'so the final answer is: ...' 64 | Reward (r) : 0 or 1, depending on whether the final verdict matches the ground truth 65 | Input Example: 66 | Black Mirror is about society. 67 | Output Example: 68 | Step 1, what is Black Mirror? Black Mirror is a British anthology television series. 69 | Step 2, what issues does this series discuss? The series uses technology to comment on contemporary social issues. 70 | Step 3, so the final answer is: SUPPORTS. 71 | 72 | """ 73 | def __init__(self, file='Bamboogle Prerelease - Sheet1.csv'): 74 | """ 75 | file: a csv file (fixed) 76 | """ 77 | super().__init__() 78 | path = os.path.join(DATA_PATH, 'bamboogle', file) 79 | self.data = list(pd.read_csv(path)['Question']) 80 | self.ground_truth = list(pd.read_csv(path)['Answer']) 81 | self.value_cache = {} 82 | self.steps = 3 83 | self.stops = ['.', '.','Question'] 84 | 85 | def __len__(self) -> int: 86 | return len(self.data) 87 | 88 | def get_input(self, idx: int) -> str: 89 | return self.data[idx] 90 | 91 | def test_output(self, ground_truth: str, output: str, out): 92 | if 'answer is' not in output: 93 | print('====output====') 94 | print(output) 95 | return {'r':0}, out 96 | expression = output.strip().split('answer is')[1].lower().split('\n')[0] 97 | expression = expression.replace(': ', '') 98 | ground_truth = str(ground_truth) 99 | print('====GR===='+str(ground_truth) +'====Pre===='+str(expression)) 100 | # if re.search(ground_truth, expression, re.IGNORECASE): 101 | if ground_truth in expression: 102 | return {'r': 1}, out 103 | else: 104 | expression_ = re.sub(r'\W+', '', expression, flags=re.IGNORECASE) 105 | ground_truth = re.sub(r'\W+', '', ground_truth, flags=re.IGNORECASE) 106 | if re.search(ground_truth, expression_, re.IGNORECASE): 107 | return {'r': 1}, out 108 | else: 109 | ground_truth = ground_truth.split(' ') 110 | tmp = 1 111 | i = 0 112 | flag = 0 113 | while tmp: 114 | tmp = re.search(ground_truth[i], expression, re.IGNORECASE) 115 | i += 1 116 | if i == len(ground_truth): 117 | if tmp: 118 | flag = 1 119 | break 120 | if flag == 1: 121 | return {'r': 1}, out 122 | else: 123 | return {'r': 0}, out 124 | 125 | 126 | 127 | 128 | 129 | 130 | @staticmethod 131 | def cot_prompt_wrap(x: str, y:str='') -> str: 132 | return cot_prompt.format(input=x) + y 133 | 134 | 135 | @staticmethod 136 | def value_prompt_wrap(x: str, y: str) -> str: 137 | # return Final_evaluate + x + '\n' + y + '\nEvaluation Process: \n' 138 | if 'the final answer is' not in y.lower(): 139 | return wiki_evaluate + x +'\nThought Process: ' + y + '\nEvaluation Process:\n' 140 | else: 141 | try: 142 | return Final_evaluate + x + '\nThought Process:' + y.split('step 3, ')[0] + '\nConclusion: ' + y.split('step 3, ')[1] + '\nEvaluation Process: \n' 143 | except: 144 | print(y) 145 | exit() 146 | -------------------------------------------------------------------------------- /clean_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tot.tasks import get_task 3 | import random 4 | from load_data import * 5 | import re 6 | 7 | # task = get_task('fever') 8 | # path = 'feverous_13b_2.json' 9 | # file_name = 'feverous_13b_2_data.json' 10 | # final_sentence = 'step 3, so the final answer is: ' 11 | # thought_number = 3 12 | 13 | task = get_task('bamboogle') 14 | path = 'bamboogle_7b.json' 15 | file_name = 'bamboogle_7b_data.json' 16 | final_sentence = 'step 3, so the final answer is: ' 17 | thought_number = 3 18 | 19 | final_thought = str(thought_number-1) 20 | with open(path, 'r', encoding='utf-8') as
f: 21 | instances = json.load(f) 22 | Corpus = {} 23 | for instance in instances: 24 | sample = list(instance.keys())[0] 25 | try: 26 | correct_predict = instance[sample]['correct'][0].split(final_sentence)[1] 27 | except: 28 | # correct_predict = instance[sample]['correct'][1].split('final_sentence')[1] 29 | # print(instance[sample]['correct'][0]) 30 | continue 31 | 32 | if ('fever' in path) or (('vitaminc' in path)): 33 | if ('suport' in correct_predict) or ('support' in correct_predict): 34 | correct_predict = 'supports' 35 | if ('refu' in correct_predict) or ('reject' in correct_predict): 36 | correct_predict = 'refutes' 37 | if ('not enough' in correct_predict) or ('no enough' in correct_predict): 38 | correct_predict = 'not enough information' 39 | if correct_predict.replace('.','').strip() not in ['supports', 'refutes', 'refuted', 'not enough info', 'not enough information']: 40 | # print(correct_predict) 41 | continue 42 | for j in range(len(instance[sample][final_thought]['candiate'])): 43 | try: 44 | if (final_sentence) not in instance[sample][final_thought]['candiate'][j]: 45 | continue 46 | pre = instance[sample][final_thought]['candiate'][j].split(final_sentence)[1] 47 | except: 48 | print(instance[sample][final_thought]['candiate'][j]) 49 | print(pre) 50 | exit() 51 | if 'fever' in path: 52 | if correct_predict.replace('.','').strip() in ['supports']: 53 | if 'support' in pre: 54 | instance[sample]['correct'].append(instance[sample][final_thought]['candiate'][j]) 55 | elif correct_predict.replace('.','').strip() in ['refutes', 'refuted']: 56 | if 'refute' in pre: 57 | instance[sample]['correct'].append(instance[sample][final_thought]['candiate'][j]) 58 | elif correct_predict.replace('.','').strip() in ['not enough info', 'not enough information']: 59 | if 'not enough' in pre: 60 | instance[sample]['correct'].append(instance[sample][final_thought]['candiate'][j]) 61 | else: 62 | if (correct_predict.replace('.','').strip() in pre.replace('.','').strip() )&(instance[sample][final_thought]['candiate'][j] not in instance[sample]['correct']): 63 | instance[sample]['correct'].append(instance[sample][final_thought]['candiate'][j]) 64 | if (pre.replace('.','').strip() in correct_predict.replace('.','').strip() )&(instance[sample][final_thought]['candiate'][j] not in instance[sample]['correct']): 65 | instance[sample]['correct'].append(instance[sample][final_thought]['candiate'][j]) 66 | if ('yes' in pre.lower()) & ('yes' in correct_predict.lower()) & (instance[sample][final_thought]['candiate'][j] not in instance[sample]['correct']): 67 | instance[sample]['correct'].append(instance[sample][final_thought]['candiate'][j]) 68 | if ('no' in pre.lower()) & ('no' in correct_predict.lower()) & (instance[sample][final_thought]['candiate'][j] not in instance[sample]['correct']): 69 | instance[sample]['correct'].append(instance[sample][final_thought]['candiate'][j]) 70 | 71 | if 'correct' in instance[sample]: 72 | if isinstance(instance[sample]['correct'], str): 73 | instance[sample]['correct'] = [instance[sample]['correct']] 74 | Corpus[sample] = {} 75 | for cor in instance[sample]['correct']: 76 | correct_list = cor.lower().replace(' * ','*').replace(' + ','+').replace(' +','+').replace('+ ','+').replace(' = ','=').replace('= ','=').replace(' - ','-').replace(' -','-').replace('- ','-').replace(' x ','x').replace(' / ','/').replace('/ ','/').replace('. so the final answer is', '. 
step 3, so the final answer is').split('step') 77 | correct_list = correct_list[1:] 78 | # print(1111) 79 | if len(correct_list) < thought_number: 80 | continue 81 | if final_sentence.split(', ')[1] not in correct_list[int(final_thought)]: 82 | continue 83 | for thought_idx in range(thought_number): 84 | if str(thought_idx) not in Corpus[sample]: 85 | Corpus[sample][str(thought_idx)] = {} 86 | if 'pos' not in Corpus[sample][str(thought_idx)]: 87 | Corpus[sample][str(thought_idx)]['pos'] = [] 88 | Corpus[sample][str(thought_idx)]['neg'] = [] 89 | Corpus[sample][str(thought_idx)]['prompt'] = [] 90 | choice_list = instance[sample][str(thought_idx)]['candiate'] 91 | if thought_idx == 0: 92 | pos = 'step' + correct_list[thought_idx] 93 | pos = pos.strip().lower() 94 | if pos in Corpus[sample][str(thought_idx)]['pos']: 95 | continue 96 | Corpus[sample][str(thought_idx)]['pos'].append(pos) 97 | neg_template = [] 98 | for choice in choice_list: 99 | choice = choice.strip().replace(' * ','*').replace(' + ','+').replace('+ ','+').replace(' +','+').replace(' ?','?').lower().replace(' = ','=').replace('= ','=').replace(' x ','x').replace(' - ','-').replace(' -','-').replace('- ','-').replace('answer: ','').replace('..','.').replace(' ',' ').replace('\"','\'').replace(' / ','/').replace('/ ','/').replace(' ?','?') 100 | # choice = choice.replace('\'','') 101 | if choice != pos: 102 | neg_template.append(choice) 103 | Corpus[sample][str(thought_idx)]['neg'].append(list(set(neg_template))) 104 | Corpus[sample][str(thought_idx)]['prompt'].append('') 105 | else: 106 | # print(correct_list) 107 | pos = 'step' + correct_list[thought_idx] 108 | pos = pos.lower().strip() 109 | neg_template = [] 110 | neg_correct_part = 'step'+'step'.join(correct_list[:thought_idx]) 111 | neg_correct_part = neg_correct_part.strip() 112 | if (pos in Corpus[sample][str(thought_idx)]['pos'])&(neg_correct_part in Corpus[sample][str(thought_idx)]['prompt']): 113 | continue 114 | Corpus[sample][str(thought_idx)]['pos'].append(pos) 115 | for choice in choice_list: 116 | choice = choice.strip().lower().replace(' * ','*').replace(' + ','+').replace('+ ','+').replace(' +','+').replace(' ?','?').replace(' = ','=').replace('= ','=').replace(' - ','-').replace(' -','-').replace('- ','-').replace('answer: ','').replace(' ',' ').replace('..','.').replace('\"','\'').replace(' / ','/').replace('/ ','/').replace(' ?','?').replace('therefore, the final answer is','so the final answer is').replace('. so the final answer is', '. 
step 3, so the final answer is') 117 | if (thought_idx == 1)&('step 3' in choice): 118 | choice = choice.split('step 3')[0] 119 | if len(choice.split('step 3'))>2: 120 | choice = 'step 3'.join(choice.split('step 3')[:-1]) 121 | if neg_correct_part in choice: 122 | choice = choice.replace(neg_correct_part,'').strip() 123 | if choice.replace('answer: ','').strip().startswith('step') == False: 124 | choice = 'step'+ 'step'.join(choice.split('step')[1:]) 125 | if choice != pos: 126 | neg_template.append(choice) 127 | Corpus[sample][str(thought_idx)]['neg'].append(list(set(neg_template))) 128 | Corpus[sample][str(thought_idx)]['prompt'].append(neg_correct_part) 129 | for thought_idx in range(thought_number): 130 | neg_samples = Corpus[sample][str(thought_idx)]['neg'] 131 | for j in range(len(neg_samples)): 132 | neg_sample = neg_samples[j] 133 | filtered_neg = [] 134 | for i in range(len(neg_sample)): 135 | neg_s = neg_sample[i].replace('answer: ','').replace(' + ','+').replace('+ ','+').replace(' +','+').replace(' / ','/').replace('/ ','/').replace(' ',' ').replace('..','.').replace('= ','=').replace(' - ','-').replace(' -','-').replace('- ','-').replace(' x ','x').replace('\"','\'').replace(' ?','?').replace('therefore, the final answer is','so the final answer is').replace('. so the final answer is', '. step 3, so the final answer is') 136 | # print(neg_s) 137 | # print(Corpus[sample][str(thought_idx)]['pos']) 138 | if neg_s not in Corpus[sample][str(thought_idx)]['pos']: 139 | filtered_neg.append(neg_s) 140 | neg_samples[j] = filtered_neg 141 | Corpus[sample][str(thought_idx)]['neg'] = neg_samples 142 | if len(Corpus[sample]) == 0: 143 | del Corpus[sample] 144 | 145 | 146 | paired_data = [] 147 | for instance in Corpus: 148 | for thought_idx in range(thought_number): 149 | 150 | for i, pos in enumerate(Corpus[instance][str(thought_idx)]['pos']): 151 | if task == 'math': 152 | propmt = create_demo_text() + "Q: " + instance + "\nA: " + Corpus[instance][str(thought_idx)]['prompt'][i] 153 | else: 154 | propmt = task.cot_prompt_wrap(instance, Corpus[instance][str(thought_idx)]['prompt'][i]) 155 | if len(Corpus[instance][str(thought_idx)]['neg'][i])==0: 156 | continue 157 | for neg in Corpus[instance][str(thought_idx)]['neg'][i]: 158 | if len(neg) <=3: 159 | continue 160 | if ('fever' in path) or (('vitaminc' in path)): 161 | if final_sentence in pos: 162 | if ('suport' in pos) or ('support' in pos): 163 | correct_predict = 'supports' 164 | if ('refu' in pos) or ('reject' in pos): 165 | correct_predict = 'refutes' 166 | if ('not enough' in pos) or ('no enough' in pos): 167 | correct_predict = 'not enough information' 168 | pos = final_sentence + correct_predict + '.' 
169 | if ('math' in path) or ('svamp' in path): 170 | if 'step 4' in pos: 171 | if '(arabic numerals) is ' in pos: 172 | # print(111) 173 | pos_ = pos.split('(arabic numerals) is ')[1] 174 | pos_ = pos_.replace(',','') 175 | numbers = re.findall(r'\d+', pos_) 176 | if len(numbers) == 0: 177 | # print(instance) 178 | # print(pos_) 179 | # exit() 180 | continue 181 | else: 182 | continue 183 | 184 | # if random.randint(0, 1) == 0: 185 | # pos,neg = neg,pos 186 | pos = pos.replace(' ',' ').strip().replace('..','.').replace(' .','.').replace('..','.').replace('..','.') 187 | neg = neg.replace('answer: ','').replace('..','.').strip().replace(' .','.').replace('..','.').replace('..','.') 188 | if ('=' in pos): 189 | if ('-' not in pos) & ('+' not in pos) & ('/' not in pos) & ('*' not in pos): 190 | continue 191 | if ('=' in neg): 192 | if ('-' not in neg) & ('+' not in neg) & ('/' not in neg) & ('*' not in neg): 193 | continue 194 | if '841 34.' in pos: 195 | continue 196 | if pos[-1] != '.': 197 | pos_split = pos.split('.') 198 | if len(pos_split)<=1: 199 | continue 200 | else: 201 | pos = '.'.join(pos_split[:-1]) + '.' 202 | if neg[-1] != '.': 203 | neg_split = neg.split('.') 204 | if len(neg_split)<=1: 205 | continue 206 | else: 207 | neg = '.'.join(neg_split[:-1]) + '.' 208 | if '[tex]' in neg: 209 | continue 210 | if { 211 | "prompt": 212 | propmt 213 | , 214 | "chosen": pos, # rated better than k 215 | "rejected": neg, # rated worse than j 216 | } not in paired_data: 217 | paired_data.append( 218 | { 219 | "prompt": 220 | propmt 221 | , 222 | "chosen": pos, # rated better than k 223 | "rejected": neg, # rated worse than j 224 | } 225 | ) 226 | 227 | 228 | print(len(paired_data)) 229 | 230 | with open(file_name,'w','utf-8') as f: 231 | f.write(json.dumps(paired_data,ensure_ascii=False,indent=4)) 232 | 233 | -------------------------------------------------------------------------------- /tot/tasks/bamboogle.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | # import sympy 4 | import pandas as pd 5 | from tot.tasks.base import Task, DATA_PATH 6 | from tot.prompts.bamboogle import * 7 | 8 | choose_evaluate = '''Evaluate whether the provided answer matches any of the options listed in the question (Likely/Impossible). 9 | 10 | Question: You are presented with the question "What do veins carry?" and the following answer choices: - copper - glucose - Energy - Length - oxygen - warmth - Cells - voltage Now knowing that veins generally carry deoxygenated blood and blood carries blood cells, choose the best answer. 11 | Answer: Blood 12 | Evaluation Process: The listed choices include copper, glucose, energy, length, oxygen, warmth, cells, and voltage, but blood is not among these options. 13 | Evaluation Result: 14 | Impossible 15 | 16 | Question: Please answer the following question: You are presented with the question "What converts chemical energy into sound?" and the following answer choices: - a firework - sensory neurons - Gunshots - a bottle - a battery - animals - a flashlight - engines Now knowing that afirecracker converts chemical energy into sound and fireworks are illegal, including firecrackers, choose the best answer. 17 | Answer: a firework 18 | Evaluation Process: Among the provided options, which include sensory neurons, gunshots, a bottle, a battery, animals, a flashlight, and engines, a firework is indeed one of the choices. 
19 | Evaluation Results: 20 | Likely 21 | 22 | Question: You are presented with the question "A good way for older adults to strengthen bones is to " and the following answer choices: - Through play - drive more slowly - exercise - use a hearing aid - quadriceps - donate bone marrow - sweating - movement Now knowing that exercise increases a body 's strength and strength training is also effective for increasing bone strength in older adults, choose the best answer. 23 | Answer: exercise 24 | Evaluation Process: Among the options listed, such as through play, driving more slowly, using a hearing aid, quadriceps, donating bone marrow, sweating, and movement, exercise is indeed included. 25 | Evaluation Results: 26 | Likely 27 | 28 | Question: You are presented with the question "what is a negative impact on an organism" and the following answer choices: - malnutrition - plants - hyperthyroidism - sweat - laughter - smallpox - leukemia Now knowing that disease has a negative impact on an organism and infectious diseases and diseases of malnutrition are prevalent, choose the best answer. 29 | Answer: cancer 30 | Evaluation Process: The options provided include malnutrition, plants, hyperthyroidism, sweat, laughter, smallpox, and leukemia, but cancer is not listed among these choices. 31 | Evaluation Results: 32 | Impossible 33 | 34 | Question: ''' 35 | 36 | wiki_evaluate = '''Evaluate whether the language model can effectively decompose the question into relevant sub-questions, and assess whether this decomposition helps in partially or directly answering the original question. The outcome will determine if this process of decomposition is "Likely" or "Impossible" to aid in finding the answer. 37 | 38 | Evaluation Steps: Check if the language model can identify and decompose key sub-questions that are directly related to the original question. 39 | Evaluation Process: 1. Analyze whether each sub-question identified by the model is directly relevant to the answer to the original question. 2. Determine if the decomposition of these sub-questions forms a reasonable response to the original question. 40 | Evaluation Result: 1. Likely: If the language model successfully decomposes the original question into relevant sub-questions that help construct the final answer. 2. Impossible: If the language model fails to effectively decompose the question, or if the decomposed sub-questions are not directly relevant to finding the answer. 41 | 42 | Question: Who lived longer, Theodor Haecker or Harry Vaughan Watkins? 43 | Thought Process: Step 1, when did Theodor Haecker die? Theodor Haecker was 65 years old when he died. Step 2, when did Harry Vaughan Watkins die? Harry Vaughan Watkins was 69 years old when he died. 44 | Evaluation Process: 45 | Relevance of Sub-Questions: The sub-question regarding Theodor Haecker's age at death is directly relevant to the main question, as it provides necessary information to determine his lifespan. Similarly, the sub-question about Harry Vaughan Watkins' age at death is also directly relevant for the same reason. 46 | Effectiveness of Decomposition: The decomposition into two key sub-questions (ages at death of both individuals) is an effective strategy. It breaks down the main question (comparison of lifespans) into specific, answerable elements. Each sub-question contributes a crucial piece of information required to compare the lifespans of the two individuals. 
47 | Evaluation Result: 48 | Likely 49 | 50 | Question: When did the last king from Britain's House of Hanover die? 51 | Thought: Step 1, when did the last king from Britain's House of Hanover born? 52 | Evaluation Process: 53 | The thought process focuses on the birth date of the last king from Britain's House of Hanover. However, knowing the birth date does not directly help in determining the date of death, which is the actual question. The lifespan of an individual can vary widely and cannot be accurately inferred from their birth date alone. Therefore, this thought process is unlikely to lead to the correct answer without additional information. 54 | So the evaluation result is: this thought is impossible to help pariticially or directly answer the question. 55 | Evaluation Results: 56 | Impossible 57 | 58 | Question: What is the highest mountain in the world? 59 | Thought Process: Step 1, identify the tallest mountains known globally. Mount Everest is commonly known as the highest mountain peak in the world. 60 | Evaluation Process: 61 | The thought process begins with identifying the tallest mountains known globally, which is a logical first step. Since Mount Everest is commonly known and recognized as the highest mountain peak in the world, this thought directly leads to the answer to the question. Therefore, this approach is very likely to help in answering the question correctly. 62 | So, the evaluation result is: this thought is likely to help partially or directly answer the question. 63 | Evaluation Results: 64 | Likely 65 | 66 | Question: How many planets are in our solar system? 67 | Thought Process: Step 1, consider the composition of the Sun and its impact on the solar system. 68 | Evaluation Process: 69 | The thought process of considering the composition of the Sun and its impact on the solar system does not directly lead to an answer for the number of planets in our solar system. The Sun's composition and its effects are more relevant to solar physics and do not provide specific information about the count or existence of planets. The question requires knowledge about the classification and count of planets in the solar system, which is unrelated to the Sun's composition. 70 | So, the evaluation result is: this thought is impossible to help partially or directly answer the question. 71 | Evaluation Results: 72 | Impossible 73 | 74 | Question: ''' 75 | Final_evaluate = '''Evaluate if the given sentence is possible to answer the question (Likely/Impossible). 76 | 77 | Question: Who was the President of the United States in the year that Citibank was founded? 78 | So the final answer is: james madison. 79 | Evaluation Process: 80 | Yes, james madison is a person and is likely to answer a question start with 'who'. 81 | Evaluation Results: 82 | Likely 83 | 84 | Question: What rocket was the first spacecraft that ever approached Uranus launched on? 85 | So the final answer is: Voyager 2. 86 | Evaluation Process: 87 | Voyager 2 is not a rocket, so it can not answer a question start with 'what rokect'. 
88 | Evaluation Results: 89 | Impossible 90 | 91 | Question: ''' 92 | class FactualQA(Task): 93 | """ 94 | Input (x) : a string of 4 numbers 95 | Output (y) : a trajectory of 3 steps to reach 24 96 | Reward (r) : 0 or 1, depending on whether the trajectory is correct 97 | Input Example: 98 | 1 2 3 4 99 | Output Example: 100 | 1 + 2 = 3 (left: 3 3 4) 101 | 3 + 3 = 6 (left: 4 6) 102 | 6 * 4 = 24 (left: 24) 103 | (1 + 2 + 3) * 4 = 24 104 | """ 105 | def __init__(self, file='Bamboogle Prerelease - Sheet1.csv'): 106 | """ 107 | file: a csv file (fixed) 108 | """ 109 | super().__init__() 110 | path = os.path.join(DATA_PATH, 'bamboogle', file) 111 | self.data = list(pd.read_csv(path)['Question']) 112 | self.ground_truth = list(pd.read_csv(path)['Answer']) 113 | self.value_cache = {} 114 | self.steps = 3 115 | self.stops = ['.', '.','Question'] 116 | 117 | def __len__(self) -> int: 118 | return len(self.data) 119 | 120 | def get_input(self, idx: int) -> str: 121 | return self.data[idx] 122 | 123 | def test_output(self, ground_truth: str, output: str, out): 124 | if 'answer is' not in output: 125 | print('====output====') 126 | print(output) 127 | return {'r':0}, out 128 | expression = output.strip().lower().split('so the final answer is')[1].lower().split('\n')[0] 129 | expression = expression.replace(': ', '') 130 | ground_truth = str(ground_truth) 131 | print('====GR===='+str(ground_truth) +'====Pre===='+str(expression)) 132 | # if re.search(ground_truth, expression, re.IGNORECASE): 133 | if ground_truth in expression: 134 | return {'r': 1}, out 135 | else: 136 | expression_ = re.sub(r'\W+', '', expression, flags=re.IGNORECASE) 137 | ground_truth = re.sub(r'\W+', '', ground_truth, flags=re.IGNORECASE) 138 | if re.search(ground_truth, expression_, re.IGNORECASE): 139 | return {'r': 1}, out 140 | else: 141 | ground_truth = ground_truth.split(' ') 142 | tmp = 1 143 | i = 0 144 | flag = 0 145 | while tmp: 146 | tmp = re.search(ground_truth[i], expression, re.IGNORECASE) 147 | i += 1 148 | if i == len(ground_truth): 149 | if tmp: 150 | flag = 1 151 | break 152 | if flag == 1: 153 | return {'r': 1}, out 154 | else: 155 | return {'r': 0}, out 156 | 157 | 158 | 159 | 160 | # @staticmethod 161 | # def standard_prompt_wrap(x: str, y:str='') -> str: 162 | # return standard_prompt.format(input=x) + y 163 | 164 | @staticmethod 165 | def cot_prompt_wrap(x: str, y:str='') -> str: 166 | return cot_prompt.format(input=x) + y 167 | 168 | # @staticmethod 169 | # def value_prompt_wrap(x: str, y: str) -> str: 170 | # return wiki_evaluate + x + '\nThought Process: ' + y + '\nEvaluation Process:' 171 | @staticmethod 172 | def value_prompt_wrap(x: str, y: str) -> str: 173 | # return Final_evaluate + x + '\n' + y + '\nEvaluation Process: \n' 174 | if 'the final answer is' not in y.lower(): 175 | return wiki_evaluate + x +'\nThought Process: ' + y + '\nEvaluation Process:\n' 176 | else: 177 | if 'choose the best answer' in x.lower(): 178 | return choose_evaluate + x +'\nAnswer: ' + y.lower().split('the final answer is')[1].replace(': ','') + '\nEvaluation Process:\n' 179 | else: 180 | return Final_evaluate + x + '\n' + y + '\nEvaluation Process: \n' 181 | -------------------------------------------------------------------------------- /dpo_training.py: -------------------------------------------------------------------------------- 1 | # 0. 
imports 2 | import os 3 | from dataclasses import dataclass, field 4 | from typing import Dict, Optional 5 | import torch 6 | from datasets import Dataset, load_dataset 7 | from peft import LoraConfig 8 | from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments 9 | 10 | from trl import DPOTrainer 11 | import argparse 12 | import json 13 | # Define and parse arguments. 14 | 15 | 16 | def parse_args(): 17 | args = argparse.ArgumentParser() 18 | args.add_argument('--percentage', type=float, default=1) 19 | args.add_argument('--output_dir', type=str, default="./results_hotpot_7b_base") 20 | args.add_argument('--base_model', type=str, default="") 21 | args.add_argument('--wandb_name', type=str, default='dpo_llama_2') 22 | args.add_argument('--dataset', type=str, default='hotpotqa_7b_data.json') 23 | args.add_argument('--bs', type=int, default=4) 24 | args.add_argument('--lora_r', type=int, default=8) 25 | args.add_argument('--mixed', type=bool, default=False) 26 | args.add_argument('--randomseed', type=int, default=False) 27 | args = args.parse_args() 28 | return args 29 | 30 | args = parse_args() 31 | pct = args.percentage 32 | bs = args.bs 33 | r = args.lora_r 34 | mixed = args.mixed 35 | @dataclass 36 | class ScriptArguments: 37 | """ 38 | The arguments for the DPO training script. 39 | """ 40 | 41 | # data parameters 42 | beta: Optional[float] = field(default=0.2, metadata={"help": "the beta parameter for DPO loss"}) 43 | 44 | # training parameters 45 | base_model: Optional[str] = field( 46 | default=args.base_model, 47 | metadata={"help": "the location of the SFT model name or path"}, 48 | ) 49 | learning_rate: Optional[float] = field(default=5e-6, metadata={"help": "optimizer learning rate"}) 50 | lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"}) 51 | warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"}) 52 | weight_decay: Optional[float] = field(default=0.00, metadata={"help": "the weight decay"}) 53 | optimizer_type: Optional[str] = field(default="adamw_torch", metadata={"help": "the optimizer type"}) 54 | mixed: Optional[bool] = field(default=mixed, metadata={"help": "whether training with mixed datasets"}) 55 | per_device_train_batch_size: Optional[int] = field(default=bs, metadata={"help": "train batch size per device"}) 56 | per_device_eval_batch_size: Optional[int] = field(default=bs, metadata={"help": "eval batch size per device"}) 57 | randomseed: Optional[int] = field(default=0, metadata={"help": "randomseed"}) 58 | gradient_accumulation_steps: Optional[int] = field( 59 | default=1, metadata={"help": "the number of gradient accumulation steps"} 60 | ) 61 | gradient_checkpointing: Optional[bool] = field( 62 | default=True, metadata={"help": "whether to use gradient checkpointing"} 63 | ) 64 | percentage: float = field(default=1.0, metadata={"help": "Description of the percentage parameter."}) 65 | bs: float = field(default=4, metadata={"help": "Description of the batch_size parameter."}) 66 | 67 | lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"}) 68 | lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"}) 69 | lora_r: Optional[int] = field(default=r, metadata={"help": "the lora r parameter"}) 70 | 71 | max_prompt_length: Optional[int] = field(default=512, metadata={"help": "the maximum prompt length"}) 72 | max_length: Optional[int] = field(default=1024, 
metadata={"help": "the maximum sequence length"}) 73 | max_steps: Optional[int] = field(default=900, metadata={"help": "max number of training steps"}) 74 | logging_steps: Optional[int] = field(default=100, metadata={"help": "the logging frequency"}) 75 | save_steps: Optional[int] = field(default=300, metadata={"help": "the saving frequency"}) 76 | eval_steps: Optional[int] = field(default=100, metadata={"help": "the evaluation frequency"}) 77 | wandb_name: Optional[str] = field(default="dpo_llama_2", metadata={"help": "the output directory"}) 78 | dataset: Optional[str] = field(default="hotpotqa_7b_data.json", metadata={"help": "the output directory"}) 79 | 80 | output_dir: Optional[str] = field(default="./results_hotpot_7b_base", metadata={"help": "the output directory"}) 81 | log_freq: Optional[int] = field(default=100, metadata={"help": "the logging frequency"}) 82 | 83 | # instrumentation 84 | sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"}) 85 | report_to: Optional[str] = field( 86 | default="wandb", 87 | metadata={ 88 | "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,' 89 | '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. ' 90 | 'Use `"all"` to report to all integrations installed, `"none"` for no integrations.' 91 | }, 92 | ) 93 | # debug argument for distributed training 94 | ignore_bias_buffers: Optional[bool] = field( 95 | default=False, 96 | metadata={ 97 | "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See" 98 | "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992" 99 | }, 100 | ) 101 | 102 | 103 | def get_stack_exchange_paired( 104 | data_dir: str = "data/rl", 105 | sanity_check: bool = False, 106 | cache_dir: str = None, 107 | num_proc=24, 108 | ) -> Dataset: 109 | """Load the stack-exchange-paired dataset from Hugging Face and convert it to the necessary format. 110 | 111 | The dataset is converted to a dictionary with the following structure: 112 | { 113 | 'prompt': List[str], 114 | 'chosen': List[str], 115 | 'rejected': List[str], 116 | } 117 | 118 | Prompts are structured as follows: 119 | "Question: " + + "\n\nAnswer: " 120 | """ 121 | dataset = load_dataset( 122 | "lvwerra/stack-exchange-paired", 123 | split="train", 124 | cache_dir=cache_dir, 125 | data_dir=data_dir, 126 | ) 127 | original_columns = dataset.column_names 128 | 129 | if sanity_check: 130 | dataset = dataset.select(range(min(len(dataset), 1000))) 131 | 132 | def return_prompt_and_responses(samples) -> Dict[str, str]: 133 | return { 134 | "prompt": ["Question: " + question + "\n\nAnswer: " for question in samples["question"]], 135 | "chosen": samples["response_j"], 136 | "rejected": samples["response_k"], 137 | } 138 | 139 | return dataset.map( 140 | return_prompt_and_responses, 141 | batched=True, 142 | num_proc=num_proc, 143 | remove_columns=original_columns, 144 | ) 145 | 146 | 147 | if __name__ == "__main__": 148 | parser = HfArgumentParser(ScriptArguments) 149 | script_args = parser.parse_args_into_dataclasses()[0] 150 | 151 | # 1. 
load a pretrained model 152 | print('=====load a pretrained model====') 153 | model = AutoModelForCausalLM.from_pretrained( 154 | args.base_model, 155 | low_cpu_mem_usage=True, 156 | torch_dtype=torch.bfloat16, 157 | # load_in_4bit=True, 158 | ) 159 | 160 | model.config.use_cache = False 161 | 162 | if script_args.ignore_bias_buffers: 163 | # torch distributed hack 164 | model._ddp_params_and_buffers_to_ignore = [ 165 | name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool 166 | ] 167 | 168 | tokenizer = AutoTokenizer.from_pretrained(args.base_model) 169 | tokenizer.pad_token = tokenizer.eos_token 170 | 171 | # 2. Load the Stack-exchange paired dataset 172 | print('====Load the Stack-exchange paired dataset====') 173 | ori_dataset = [] 174 | if args.mixed == False: 175 | with open(args.dataset, 'r') as f: 176 | ori_dataset.extend(json.load(f)) 177 | 178 | len_data = round(len(ori_dataset)*pct) 179 | if pct == 1: 180 | ori_dataset = ori_dataset[:len_data] 181 | else: 182 | import random 183 | random.seed(args.randomseed) 184 | random_numbers = random.sample(range(0, len(ori_dataset)), len_data) 185 | selected_dataset = [] 186 | for i, d in enumerate(ori_dataset): 187 | if i in random_numbers: 188 | selected_dataset.append(d) 189 | # else: 190 | # d['chosen'], d['rejected'] = d['rejected'], d['chosen'] 191 | # selected_dataset.append(d) 192 | ori_dataset = selected_dataset 193 | # if 'negative' in args.output_dir: 194 | # ori_dataset = ori_dataset[:3000] 195 | print('number of paired_data: ' + str(len(ori_dataset))) 196 | # 将数据转换为适合的字典格式 197 | data_dict = {key: [item[key] for item in ori_dataset] for key in ori_dataset[0]} 198 | # 创建datasets.Dataset对象 199 | dataset = Dataset.from_dict(data_dict) 200 | dataset = dataset.train_test_split(test_size=0.1) 201 | train_dataset = dataset['train'] 202 | warmup_steps = round(0.1*len(train_dataset)/(4*bs)) 203 | if warmup_steps < 10: 204 | warmup_steps = 10 205 | # 3. Load evaluation dataset 206 | print('====Load evaluation dataset====') 207 | eval_dataset =dataset['test'] 208 | 209 | 210 | # 4. initialize training arguments: 211 | print('====initialize training arguments:====') 212 | training_args = TrainingArguments( 213 | per_device_train_batch_size=script_args.per_device_train_batch_size, 214 | per_device_eval_batch_size=script_args.per_device_eval_batch_size, 215 | # max_steps=script_args.max_steps, 216 | max_steps=round(len(train_dataset)/(4*bs))*3, 217 | logging_steps=script_args.logging_steps, 218 | # save_steps=script_args.save_steps, 219 | save_steps=round(len(train_dataset)/(4*bs)*0.5), 220 | gradient_accumulation_steps=script_args.gradient_accumulation_steps, 221 | gradient_checkpointing=script_args.gradient_checkpointing, 222 | learning_rate=script_args.learning_rate, 223 | evaluation_strategy="steps", 224 | eval_steps=script_args.eval_steps, 225 | output_dir=args.output_dir, 226 | report_to=script_args.report_to, 227 | lr_scheduler_type=script_args.lr_scheduler_type, 228 | warmup_steps=warmup_steps, 229 | optim=script_args.optimizer_type, 230 | bf16=True, 231 | remove_unused_columns=False, 232 | run_name=args.wandb_name, 233 | ) 234 | 235 | peft_config = LoraConfig( 236 | r=script_args.lora_r, 237 | lora_alpha=script_args.lora_alpha, 238 | lora_dropout=script_args.lora_dropout, 239 | target_modules=[ 240 | "q_proj", 241 | "v_proj", 242 | "k_proj", 243 | "out_proj", 244 | "fc_in", 245 | "fc_out", 246 | "wte", 247 | ], 248 | bias="none", 249 | task_type="CAUSAL_LM", 250 | ) 251 | 252 | # 5. 
initialize the DPO trainer 253 | print('====initialize the DPO trainer====') 254 | dpo_trainer = DPOTrainer( 255 | model, 256 | None, 257 | args=training_args, 258 | beta=script_args.beta, 259 | train_dataset=train_dataset, 260 | eval_dataset=eval_dataset, 261 | tokenizer=tokenizer, 262 | peft_config=peft_config, 263 | max_prompt_length=script_args.max_prompt_length, 264 | max_length=script_args.max_length, 265 | ) 266 | 267 | # 6. train 268 | print('====train====') 269 | dpo_trainer.train() 270 | dpo_trainer.save_model(script_args.output_dir) 271 | 272 | # 7. save 273 | print('====save====') 274 | output_dir = os.path.join(script_args.output_dir, "final_checkpoint") 275 | dpo_trainer.model.save_pretrained(output_dir) 276 | -------------------------------------------------------------------------------- /load_data.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | import re 3 | import random 4 | import json 5 | 6 | def remove_substrings_with_double_angle_brackets(input_string): 7 | # Define the pattern to match substrings within double angled brackets 8 | pattern = r"<<[^>]+>>" 9 | # Use the sub() function from the re module to replace matching substrings with an empty string 10 | result = re.sub(pattern, "", input_string) 11 | return result 12 | 13 | def load_gsm8k_test(path: str = "gsm8k", subset: str = "main", split="test"): 14 | samples = [] 15 | # dataset = dataset[:200] 16 | i = 0 17 | for raw in load_dataset(path, subset, split=split): 18 | i +=1 19 | explanation, answer = raw["answer"].split("####") 20 | explanation = remove_substrings_with_double_angle_brackets(explanation) 21 | samples.append( 22 | { 23 | 'question':raw["question"].strip(), 24 | 'explanation':explanation.strip(), 25 | 'answer':answer.strip(), 26 | } 27 | ) 28 | if i == 200: 29 | break 30 | return samples 31 | 32 | 33 | def load_svamp_test(path: str = 'tot/data/SVAMP/train.json'): 34 | samples = [] 35 | with open(path,'r') as f: 36 | ins = json.load(f) 37 | for d in ins: 38 | samples.append( 39 | { 40 | 'question':d['Body'].strip()+d["Question"].strip(), 41 | 'answer':str(d['Answer']).strip(), 42 | }) 43 | if 'train' in path: 44 | if len(samples)>=300: 45 | break 46 | return samples 47 | 48 | def create_demo_text(cot_flag=True): 49 | x, z, y = [], [], [] 50 | # example sentences ... 51 | if 1: 52 | 53 | x.append("There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?") 54 | z.append("Step 1, There are 15 trees originally, Then there were 21 trees after some more were planted. Step 2, So there must have been 21 - 15 = 6.") 55 | y.append("6") 56 | 57 | x.append("If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?") 58 | z.append("Step 1, There are originally 3 cars, and 2 more cars arrive. Step 2, 3 + 2 = 5.") 59 | y.append("5") 60 | 61 | x.append("Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?") 62 | z.append("Step 1, Originally, Leah had 32 chocolates, and her sister had 42. Step 2, So in total they had 32 + 42 = 74. Step 3, After eating 35, they had 74 - 35 = 39.") 63 | y.append("39") 64 | 65 | x.append("Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?") 66 | z.append("Step 1, Jason started with 20 lollipops. 
Then he had 12 after giving some to Denny. Step 2, So he gave Denny 20 - 12 = 8.") 67 | y.append("8") 68 | 69 | 70 | else: 71 | raise ValueError("dataset is not properly defined ...") 72 | 73 | # randomize order of the examples ... 74 | index_list = list(range(len(x))) 75 | random.shuffle(index_list) 76 | 77 | # Concatenate demonstration examples ... 78 | direct_answer_trigger_for_fewshot = "The answer (arabic numerals) is " 79 | demo_text = "" 80 | for i in index_list: 81 | if cot_flag: 82 | if 'Step 3' in z[i]: 83 | demo_text += "Q: " + x[i] + "\nA: " + z[i] + " " + \ 84 | 'Step 4, ' + direct_answer_trigger_for_fewshot + " " + y[i] + ".\n\n" 85 | else: 86 | demo_text += "Q: " + x[i] + "\nA: " + z[i] + " " + \ 87 | 'Step 3, ' + direct_answer_trigger_for_fewshot + " " + y[i] + ".\n\n" 88 | else: 89 | demo_text += "Q: " + x[i] + "\nA: " + \ 90 | direct_answer_trigger_for_fewshot + " " + y[i] + ".\n\n" 91 | 92 | return demo_text 93 | 94 | math_evaluate = '''Evaluate whether the thought helps in partially or directly answering the original question (likely/impossible). 95 | 96 | Question: Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market? 97 | Thought: step 1, 16 eggs per day means 16 * 24 hours = 384 eggs per day. 98 | Evaluation Process: To answer the question ‘How much in dollars does she make every day at the farmers' market?’, we need know how much of eggs are laid per day and the price of each egg. The given thought multiplies 16 eggs per day by 24 hours, resulting in 384 eggs per day. The actual statement clearly says Janet's ducks lay 16 eggs per day in total, not per hour. Thus, the thought will lead to a wrong answer. So the final evaluation is impossible. 99 | Impossible 100 | 101 | Question: Two trains leave San Rafael at the same time. They begin traveling westward, both traveling for 80 miles. The next day, they travel northwards, covering 150 miles. What's the distance covered by each train in the two days 102 | Thought: step 1, 80 miles was covered in one day, and 150 miles was covered in another day. 103 | Evaluation Process: To answer the question ‘What's the distance covered by each train in the two days?’, the thought provide helpful information without any unnecessary details about the distance covered in both days respectively. Although it cannot directly answer the question, this is very helpful in promoting the next step towards the correct reasoning. So the final evaluation is likely. 104 | Likely 105 | 106 | Question: Toulouse has twice as many sheep as Charleston. Charleston has 4 times as many sheep as Seattle. How many sheep do Toulouse, Charleston, and Seattle have together if Seattle has 20 sheep? 107 | Thought: step 1, 2x as many means that toulouse has 2x as many sheep as charleston. step 2, 4x as many means that charleston has 4x as many sheep as seattle. step 3, 2(2x) = 4x means that toulouse has 4x as many sheep as charleston. 108 | Evaluation Process: To answer the question ‘How many sheep do Toulouse, Charleston, and Seattle have together if Seattle has 20 sheep?’, the thought process makes a critical error in step 3 by stating "2(2x) = 4x means that Toulouse has 4x as many sheep as Charleston." This is incorrect. 
The initial information that Toulouse has twice as many sheep as Charleston is accurate and should remain the basis for calculation. The error seems to be in misunderstanding the multiplication of relationships. Thus, the thought will lead to a wrong answer. So the final evaluation is impossible. 109 | Impossible 110 | 111 | Question: There are 87 oranges and 290 bananas in Philip's collection. If the bananas are organized into 2 groups and oranges are organized into 93 groups. How big is each group of bananas? 112 | Thought: step 1, 87 oranges and 290 bananas are given. 113 | Evaluation Process: To answer the question ‘How big is each group of bananas?’, we need to know the number of bananas and the number of groups respectively. The thought provides one of the key pieces of information, which is the number of bananas. Although it cannot directly answer the question, this is very helpful in promoting the next step towards the correct reasoning. So the final evaluation is likely. 114 | Likely 115 | 116 | Question: Mary is baking a cake. The recipe calls for 6 cups of flour 8 cups of sugar and 7 cups of salt. She already put in 5 cups of flour. How many more cups of sugar than cups of salt does she need to add now? 117 | Thought: step 1, 6 cups of flour + 8 cups of sugar + 7 cups of salt = 21 cups. 118 | Evaluation Process: To answer the question ‘How many more cups of sugar than cups of salt does she need to add now?’, we need to know the number of cups of sugar and salt respectively, and then calculate the difference. The given thought focuses on summing up the total amount of ingredients needed for the cake, which does not directly address the specific question asked. Thus, the thought will lead to a wrong answer. So the final evaluation is impossible. 119 | Impossible 120 | 121 | Question: A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take? 122 | Thought: Step 1, 2 bolts of blue fiber and half that much white fiber means 2 + 2/2 = 3. 123 | Evaluation Process: To answer the question ‘How many bolts in total does it take?’, it correctly identifies that if the robe requires 2 bolts of blue fiber and half that amount of white fiber, then the total amount of white fiber needed is half of 2 bolts, which is 1 bolt. By adding the 2 bolts of blue fiber to the 1 bolt of white fiber, it correctly calculates that a total of 3 bolts of fiber are needed to make the robe. This thought process directly addresses the question by accurately calculating the total number of bolts required for the robe. So the final evaluation is likely. 124 | Likely 125 | 126 | ''' 127 | 128 | 129 | math_final_evaluate = '''Evaluate whether the thought answers the question (likely/impossible). 130 | 131 | Question: Paul got a box of 110 crayons for his birthday. During the school year he gave 90 crayons to his friends while he lost 412 crayons. How many more crayons did he lose than those he gave to his friends? 132 | Thought: step 1, 110 crayons was given by paul's parents as a gift. step 2, 90 crayons were given away by him. step 3, 90-110= -20. step 4, the answer (arabic numerals) is -20. 133 | Evaluation Process: The question asks for the difference in the number of crayons he lost compared to those he gave to his friends. However, the thought process incorrectly focuses on the crayons received by Paul from his parents and those he gave away, which is a different comparison. 134 | Impossible 135 | 136 | Question: There are 87 oranges and 290 bananas in Philip's collection.
If the bananas are organized into 2 groups and oranges are organized into 93 groups. How big is each group of bananas? 137 | Thought: step 1, 87 oranges and 290 bananas are given. step 2, 290 divided by 2 is 145. step 3, 145 bananas are in one group. step 4, the answer (arabic numerals) is 145. 138 | Evaluation Process: The question specifically asks how many bananas would be in each of the two equally divided groups, and the thought process divides the 290 bananas by 2 to get 145 per group, so the thought process is correct. 139 | Likely 140 | 141 | Question: Randy has 95 blocks. He uses 20 blocks to build a house and 50 blocks to build a tower. How many more blocks did he use to build the tower than he did to build the house? 142 | Thought: step 1, 95 blocks. step 2, 20 blocks to build a house and 50 blocks to build a tower. step 3, 20 + 50 = 70 blocks used to build the tower. step 4, the answer (arabic numerals) is 70. 143 | Evaluation Process: The question specifically asks for the difference in the number of blocks used to build the tower compared to the house. However, the thought process erroneously adds the number of blocks used for both the house and the tower, leading to a total number of blocks used for both constructions, not the difference between the two as asked. 144 | Impossible 145 | 146 | Question: You have 104 dollars. How many packs of dvds can you buy if each pack costs 26 dollars? 147 | Thought: step 1, 104 dollars will buy 104 / 26 = 4 packs of dvd's. step 2, 4 packs of dvd's cost 4 x 26 = 104 dollars. step 3, 104 dollars buys 4 packs of dvd's. step 4, the answer (arabic numerals) is 4. 148 | Evaluation Process: The question asks how many packs of DVDs can be purchased with 104 dollars if each pack costs 26 dollars. The calculation performed in the thought (104 / 26 = 4) correctly determines the number of DVD packs that can be bought, which is 4 packs.
149 | Likely 150 | 151 | ''' 152 | 153 | def math_test_output(d, y, out): 154 | # print('======testing========') 155 | expression = y.strip().replace(',','') 156 | if 'arabic numerals' in expression: 157 | expression = expression.split('(arabic numerals) is ')[-1] 158 | numbers = re.findall(r'\d+', expression) 159 | try: 160 | problem_numbers = re.findall(r'\d+', d['answer'][0]) 161 | except: 162 | print(d) 163 | return {'r': 0}, out 164 | # print('====GR===='+str(problem_numbers) +'====Pre===='+str(numbers)) 165 | if len(numbers)>0: 166 | numbers = numbers[-1] 167 | if len(problem_numbers)>0: 168 | problem_numbers = problem_numbers[0] 169 | print('====GR===='+str(problem_numbers) +'====Pre===='+str(numbers)) 170 | if numbers != problem_numbers: 171 | return {'r': 0}, out 172 | else: 173 | return {'r': 1}, out 174 | 175 | def value_outputs_unwrap(x: str, y: str, value_outputs: list) -> float: 176 | value_map = {'impossible': 0.001, 'unlikely': 0.001, 'likely': 1, 'sure': 1} # TODO: ad hoc 177 | value = 0.001 178 | for item in value_outputs: 179 | if item.lower() in value_map: 180 | value = value_map[item.lower()] 181 | return value 182 | for v in ['impossible','unlikely','likely', 'sure']: 183 | for item in value_outputs: 184 | if v in item.lower(): 185 | value = value_map[v] 186 | return value 187 | # value = sum(value * value_names.count(name) for name, value in value_map.items()) 188 | return value 189 | -------------------------------------------------------------------------------- /run_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import torch 5 | import sys 6 | import random 7 | from tot.tasks import get_task 8 | from tot.methods.bfs_test import solve, naive_solve 9 | # from tot.models import gpt_usage 10 | import warnings 11 | import csv 12 | import transformers 13 | from peft import PeftModel, LoraConfig 14 | from transformers import GenerationConfig, AutoModel, AutoModelForCausalLM, LlamaForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig, LlamaTokenizer 15 | from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model 16 | from trl.core import respond_to_batch 17 | from accelerate import Accelerator 18 | from accelerate.state import AcceleratorState 19 | # import tensor_parallel as tp 20 | import pandas as pd 21 | from load_data import * 22 | from datasets import load_dataset 23 | # from vllm import LLM 24 | import torch.distributed as dist 25 | import time 26 | 27 | 28 | warnings.filterwarnings("ignore") 29 | 30 | def run(args, load_8bit: bool = False, 31 | base_model: str = "", 32 | 33 | instruct_dir: str = "", 34 | use_lora: bool = False, 35 | lora_weights: str = "", 36 | # The prompt template to use, will default to med_template. 
37 | prompt_template: str = "med_template"): 38 | 39 | lora_config = LoraConfig( 40 | lora_alpha=16, 41 | lora_dropout=0.1, 42 | r=8, 43 | bias="none", 44 | task_type="CAUSAL_LM" 45 | ) 46 | import time 47 | 48 | print('===base_model===') 49 | base_model = args.base_model 50 | print(base_model) 51 | start_time = time.time() 52 | if args.fast_test: 53 | print('=====fast test using distilgpt2=====') 54 | checkpoint = "distilgpt2" 55 | tokenizer = AutoTokenizer.from_pretrained(checkpoint) 56 | model = AutoModelForCausalLM.from_pretrained(checkpoint) 57 | else: 58 | tokenizer = LlamaTokenizer.from_pretrained(base_model,padding_side = "left") 59 | end_time = time.time() 60 | model = LlamaForCausalLM.from_pretrained( 61 | base_model, 62 | torch_dtype=torch.float16, 63 | ) 64 | end_time = time.time() 65 | 66 | if use_lora: 67 | print(f"using lora {lora_weights}") 68 | model = PeftModel.from_pretrained( 69 | model, 70 | lora_weights, 71 | torch_dtype=torch.float32, 72 | ) 73 | # LoRA Config 74 | 75 | 76 | model.config.pad_token_id = tokenizer.pad_token_id = 2 # unk 77 | tokenizer.pad_token = tokenizer.eos_token 78 | model.config.bos_token_id = 1 79 | model.config.eos_token_id = 2 80 | model.config.pad_token_id = model.config.eos_token_id 81 | # model.resize_token_embeddings(len(tokenizer)) 82 | # if not load_8bit: 83 | # model.half() # seems to fix bugs for some users. 84 | end_time = time.time() 85 | 86 | 87 | accelerator = Accelerator() 88 | device = accelerator.device 89 | # model = model.to(device) 90 | model = accelerator.prepare_model(model) 91 | model.eval() 92 | 93 | end_time = time.time() 94 | 95 | logs, cnt_avg, cnt_any = [], 0, 0 96 | 97 | print("Let's use", torch.cuda.device_count(), "GPUs!") 98 | # model = torch.nn.DataParallel(model) 99 | 100 | end_time = time.time() 101 | 102 | if args.add_more != '': 103 | with open(args.add_more,'r') as f: 104 | lines = f.readlines() 105 | q_list = [] 106 | for i, line in enumerate(lines): 107 | if i == 0: 108 | continue 109 | if (' sum(accs)' in line) & (lines[i-1].startswith('====GR====')): 110 | q = line.split(' sum(accs)')[0] 111 | if q not in q_list: 112 | q_list.append(q) 113 | start_time = time.time() 114 | if args.task in ['game24', 'bamboogle', '2wiki', 'qasc', 'hotpotqa','fever','feverous','tabfacts','vitaminc']: 115 | if args.task == 'game24': 116 | data = list(pd.read_csv('tot/data/24/24.csv')['Puzzles'])[args.task_start_index:args.task_end_index] 117 | elif args.task == 'bamboogle': 118 | if args.train == False: 119 | data = list(pd.read_csv('tot/data/bamboogle/Bamboogle Prerelease - Sheet1.csv')['Question']) 120 | ground_truths = list(pd.read_csv('tot/data/bamboogle/Bamboogle Prerelease - Sheet1.csv')['Answer']) 121 | data_dic = {} 122 | for i in range(len(data)): 123 | data_dic[data[i]] = ground_truths[i] 124 | else: 125 | file = json.load(open('tot/data/bamboogle/train.json')) 126 | data = [] 127 | data_dic = {} 128 | # for i in range(len(file['data'])): 129 | for i in range(200): 130 | data.append(file['data'][i]['Question']) 131 | data_dic[file['data'][i]['Question']] = file['data'][i]['Answer'][0] 132 | 133 | elif args.task == '2wiki': 134 | if args.train == False: 135 | file = json.load(open('tot/data/2wiki/dev.json')) 136 | else: 137 | file = json.load(open('tot/data/2wiki/train.json')) 138 | data = [] 139 | data_dic = {} 140 | for i in range(len(file)): 141 | data.append(file[i]['question']) 142 | data_dic[file[i]['question']] = file[i]['answer'] 143 | args.task = 'bamboogle' 144 | elif args.task == 'hotpotqa': 145 | if 
args.train == True: 146 | dataset = load_dataset('hotpot_qa', 'fullwiki', split='train') 147 | else: 148 | dataset = load_dataset('hotpot_qa', 'fullwiki', split='validation') 149 | 150 | data = [] 151 | data_dic = {} 152 | for d in dataset: 153 | # collect questions and their gold answers 154 | data.append(d['question']) 155 | data_dic[d['question']] = d['answer'] 156 | args.task = 'bamboogle' 157 | elif args.task in ['fever','feverous','tabfacts','vitaminc']: 158 | if args.train == False: 159 | if args.task == 'fever': 160 | path_fever = 'tot/data/fever/dev.jsonl' 161 | if args.task == 'tabfacts': 162 | path_fever = 'tot/data/tabfact/test_data.jsonl' 163 | if args.task == 'vitaminc': 164 | path_fever = 'tot/data/vitaminc/test.jsonl' 165 | if args.task == 'feverous': 166 | path_fever = 'tot/data/feverous/feverous_dev_challenges.jsonl' 167 | with open(path_fever) as jsonl_file: 168 | lines = jsonl_file.readlines() 169 | file = {} 170 | for i, concept in enumerate(lines): 171 | concept_item = json.loads(concept) 172 | file[i] = {} 173 | file[i]['question'] = concept_item['claim'] 174 | file[i]['answer'] = concept_item['label'] 175 | else: 176 | if args.task == 'fever': 177 | path_fever = 'tot/data/fever/train.jsonl' 178 | if args.task == 'tabfacts': 179 | path_fever = 'tot/data/tabfact/train_data.jsonl' 180 | if args.task == 'vitaminc': 181 | path_fever = 'tot/data/vitaminc/train.jsonl' 182 | if args.task == 'feverous': 183 | path_fever = 'tot/data/feverous/feverous_train_challenges.jsonl' 184 | with open(path_fever) as jsonl_file: 185 | lines = jsonl_file.readlines() 186 | file = {} 187 | for i, concept in enumerate(lines): 188 | concept_item = json.loads(concept) 189 | file[i] = {} 190 | 191 | file[i]['question'] = concept_item['claim'] 192 | file[i]['answer'] = concept_item['label'] 193 | dev_ids_ori = [] 194 | data = [] 195 | data_dic = {} 196 | for i in range(len(file)): 197 | data.append(file[i]['question']) 198 | data_dic[file[i]['question']] = file[i]['answer'] 199 | 200 | args.task = 'fever' 201 | task = get_task(args.task) 202 | d_l = data 203 | 204 | i = args.task_start_index 205 | if args.data_json_file == 'output.json': 206 | instances = [] 207 | with open('output.json','r') as f: 208 | ins = json.load(f) 209 | for in_ in ins: 210 | instances.append(list(in_.keys())[0]) 211 | file_out = open(args.data_json_file, 'a') 212 | 213 | dic = [] 214 | if len(args.add_more)>1: 215 | with open(args.add_more,'r') as f: 216 | lines = f.readlines() 217 | for line in lines: 218 | if 'sum(accs)' in line: 219 | if line.split(' sum(accs)')[0] not in dic: 220 | dic.append(line.split(' sum(accs)')[0]) 221 | # print(dic[0]) 222 | data_filtered = [] 223 | for d in data: 224 | if len(args.add_more)>1: 225 | ## filter out questions that have already been tested 226 | if d in dic: 227 | continue 228 | 229 | if args.data_json_file == 'output.json': 230 | if d in instances: 231 | continue 232 | data_filtered.append(d) 233 | if len(data_filtered) == 0: 234 | exit() 235 | 236 | data = torch.utils.data.DataLoader(data_filtered,batch_size=1) 237 | data = accelerator.prepare_data_loader(data) 238 | 239 | for d in data: 240 | if args.naive_run: 241 | # d is batch_size*instances 242 | args.n_generate_sample = 10 243 | ys, info = naive_solve(args, task, d, model, tokenizer, device) 244 | out = {} 245 | else: 246 | ys, info, out = solve(args, task, d, model, tokenizer, device) 247 | # log 248 | for d_idx, d_i in enumerate(d): 249 | infos = [] 250 | y = ys[d_idx] 251 | if args.naive_run: 252 | out[d_i] = {} 253 | if args.task in ['2wiki', 'bamboogle','bbh',
'qasc','fever']: 254 | if args.task in ['2wiki', 'bamboogle','bbh','qasc','fever']: 255 | if 'the final answer is' not in y.lower(): 256 | continue 257 | info, out = task.test_output(data_dic[d_i], y, out) 258 | else: 259 | info, out = task.test_output(d_i, y, out) 260 | infos.append(info) 261 | # infos, out = [task.test_output(d, y, out) for y in ys] 262 | # info.update({'idx': i, 'ys': ys, 'infos': infos, 'usage_so_far': gpt_usage(args.backend)}) 263 | info.update({'idx': d_i, 'ys': ys, 'infos': infos}) 264 | logs.append(info) 265 | # with open(file, 'w') as f: 266 | # json.dump(logs, f, indent=4) 267 | # log main metric 268 | accs = [info['r'] for info in infos] 269 | if len(accs) == 0: 270 | print('====wrong case===='+d_i) 271 | continue 272 | cnt_avg += sum(accs) / len(accs) 273 | cnt_any += any(accs) 274 | if len(args.output_file) == 0: 275 | print(d_i, 'sum(accs)', sum(accs), 'cnt_avg', cnt_avg, 'cnt_any', cnt_any, '\n') 276 | else: 277 | process_id = dist.get_rank() # get the ID (rank) of the current process 278 | output_file = args.output_file + str(process_id) + '.out' 279 | with open(output_file, "a") as f: 280 | print(d_i, 'sum(accs)', sum(accs), 'cnt_avg', cnt_avg, 'cnt_any', cnt_any, '\n', file=f) 281 | # ppo_trainer.save_model() 282 | file_out.write(json.dumps(out,ensure_ascii=False,indent=1)) 283 | file_out.write(',\n') 284 | file_out.flush() 285 | # n = args.task_end_index - args.task_start_index 286 | n = len(data) 287 | print(cnt_avg / n, cnt_any / n) 288 | # print('usage_so_far', gpt_usage(args.backend)) 289 | file_out.write('\n]') 290 | elif args.task in ['math','gsm8k','svamp','asdiv']: 291 | if args.train == True: 292 | if args.task == 'svamp': 293 | data = load_svamp_test('tot/data/SVAMP/train.json') 294 | elif args.task == 'asdiv': 295 | data = load_svamp_test('tot/data/asdiv/train.json') 296 | else: 297 | data = load_gsm8k_test(split="train") 298 | else: 299 | if args.task == 'svamp': 300 | data = load_svamp_test('tot/data/SVAMP/test.json') 301 | elif args.task == 'asdiv': 302 | data = load_svamp_test('tot/data/asdiv/test.json') 303 | else: 304 | data = load_gsm8k_test(split="test") 305 | d_l = data 306 | args.task = 'math' 307 | data = torch.utils.data.DataLoader(data) 308 | data = accelerator.prepare_data_loader(data) 309 | i = args.task_start_index 310 | if args.data_json_file != 'test.json': 311 | instances = [] 312 | with open(args.data_json_file,'r') as f: 313 | ins = json.load(f) 314 | for in_ in ins: 315 | instances.append(list(in_.keys())[0]) 316 | file_out = open(args.data_json_file, 'a') 317 | # file_out.write('[\n') 318 | task_prompt = create_demo_text() 319 | for d in data: 320 | if args.data_json_file != 'test.json': 321 | if d['question'][0] in instances: 322 | continue 323 | if args.add_more != '': 324 | if str(d) in q_list: 325 | continue 326 | if args.naive_run: 327 | args.n_generate_sample = 10 328 | ys, info = naive_solve(args, task_prompt, d['question'][0], model, tokenizer, device) 329 | out = {} 330 | else: 331 | ys, info, out = solve(args, task_prompt, d['question'][0], model, tokenizer, device) 332 | infos = [] 333 | for y in ys: 334 | # print(y) 335 | info, out = math_test_output(d, y, out) 336 | infos.append(info) 337 | # infos, out = [task.test_output(d, y, out) for y in ys] 338 | # info.update({'idx': i, 'ys': ys, 'infos': infos, 'usage_so_far': gpt_usage(args.backend)}) 339 | info.update({'idx': d['question'][0], 'ys': ys, 'infos': infos}) 340 | logs.append(info) 341 | # with open(file, 'w') as f: 342 | # json.dump(logs, f, indent=4) 343 | # log main metric 344 |
accs = [info['r'] for info in infos] 345 | if len(accs) == 0: 346 | print(d, 'sum(accs)', sum(accs)) 347 | else: 348 | cnt_avg += sum(accs) / len(accs) 349 | cnt_any += any(accs) 350 | print(d, 'sum(accs)', sum(accs), 'cnt_avg', cnt_avg, 'cnt_any', cnt_any, '\n') 351 | # ppo_trainer.save_model() 352 | file_out.write(json.dumps(out,ensure_ascii=False,indent=1)) 353 | file_out.write(',\n') 354 | file_out.flush() 355 | # n = args.task_end_index - args.task_start_index 356 | print(cnt_avg / len(data), cnt_any / len(data)) 357 | # print('usage_so_far', gpt_usage(args.backend)) 358 | file_out.write('\n]') 359 | end_time = time.time() 360 | 361 | 362 | 363 | 364 | 365 | def parse_args(): 366 | args = argparse.ArgumentParser() 367 | args.add_argument('--backend', type=str, choices=['gpt-4', 'gpt-3.5-turbo', 'llama2-7b'], default='llama2-7b') 368 | args.add_argument('--temperature', type=float, default=0.9) 369 | 370 | args.add_argument('--task', type=str, required=True, choices=['game24', 'text', 'crosswords', 'math','gsm8k','svamp','asdiv', 'bamboogle', '2wiki', 'bbh', 'qasc', 'hotpotqa','fever','feverous','tabfacts','vitaminc']) 371 | args.add_argument('--task_start_index', type=int, default=900) 372 | args.add_argument('--task_end_index', type=int, default=1000) 373 | args.add_argument('--base_model', type=str, default='') 374 | args.add_argument('--naive_run', action='store_true') 375 | args.add_argument('--sr_run', action='store_true') 376 | args.add_argument('--prompt_sample', type=str, choices=['standard', 'cot'], default='cot') # only used when method_generate = sample, or naive_run 377 | 378 | args.add_argument('--method_generate', type=str, choices=['sample', 'propose']) 379 | args.add_argument('--method_evaluate', type=str, choices=['value', 'vote']) 380 | args.add_argument('--method_select', type=str, choices=['sample', 'greedy'], default='greedy') 381 | args.add_argument('--n_generate_sample', type=int, default=1) # only thing needed if naive_run 382 | args.add_argument('--n_evaluate_sample', type=int, default=1) 383 | args.add_argument('--n_select_sample', type=int, default=1) 384 | args.add_argument('--data_json_file', type=str, default='test.json') 385 | args.add_argument('--fast_test', type=bool, default=False) 386 | args.add_argument('--train', type=bool, default=False) 387 | args.add_argument('--add_more', type=str, default='') 388 | args.add_argument('--percentage', type=float, default=1.0) 389 | args.add_argument('--epoch', type=float, default=-1) 390 | args.add_argument('--output_file', type=str, default='') 391 | args = args.parse_args() 392 | return args 393 | 394 | if __name__ == '__main__': 395 | args = parse_args() 396 | print(args) 397 | run(args) 398 | -------------------------------------------------------------------------------- /tot/methods/bfs_test.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import numpy as np 3 | from functools import partial 4 | import re 5 | import sys 6 | import json 7 | import fire 8 | import gradio as gr 9 | import torch 10 | import bisect 11 | import transformers 12 | from peft import PeftModel 13 | from transformers import GenerationConfig, AutoModelForCausalLM, AutoTokenizer 14 | from load_data import * 15 | 16 | import re 17 | import string 18 | 19 | def normalize_answer(s): 20 | 21 | def remove_articles(text): 22 | return re.sub(r'\b(a|an|the)\b', ' ', text) 23 | 24 | def white_space_fix(text): 25 | return ' '.join(text.split()) 26 | 27 | def remove_punc(text): 28 | 
exclude = set(string.punctuation) 29 | return ''.join(ch for ch in text if ch not in exclude) 30 | 31 | def lower(text): 32 | return text.lower() 33 | 34 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 35 | 36 | def math_test_output(y): 37 | expression = y.strip() 38 | if 'arabic numerals' in expression: 39 | expression = expression.split('(arabic numerals) is ')[1] 40 | expression = expression.replace(',','') 41 | numbers = re.findall(r'\d+', expression) 42 | if len(numbers)>0: 43 | numbers = numbers[-1] 44 | else: 45 | # print(y) 46 | numbers = '-1' 47 | return numbers 48 | 49 | def final_evaluate_fever(new_ys, values, final_sentence): 50 | choices_set = {} 51 | for i in range(len(new_ys)): 52 | new_ys[i] = new_ys[i].replace('\"','\'').replace('-',' ').replace(' ',' ').split('\n\n')[0].lower().replace('so the final answer is ', 'so the final answer is: ') 53 | choices_i_list = new_ys[i].split('\n') 54 | if final_sentence not in new_ys[i]: 55 | # print(new_ys[i]) 56 | continue 57 | if len(choices_i_list)>1: 58 | while final_sentence not in choices_i_list[-1].lower(): 59 | choices_i_list = choices_i_list[:-1] 60 | if len(choices_i_list)<=1: 61 | break 62 | if len(choices_i_list)<=1: 63 | continue 64 | new_ys[i] = ' '.join(choices_i_list) 65 | if new_ys[i] not in choices_set: 66 | choices_set[new_ys[i]] = [float(values[i])] 67 | else: 68 | choices_set[new_ys[i]].append(float(values[i])) 69 | for i in range(len(choices_set)): 70 | choices = list(choices_set.keys()) 71 | choices_set[choices[i]] = sum(choices_set[choices[i]])/len(choices_set[choices[i]]) 72 | sorted_dict = sorted(choices_set.items(), key=lambda item: item[1], reverse=True) 73 | sorted_dict = {item[0]: item[1] for item in sorted_dict} 74 | pre_set = {} 75 | for choice_item in list(sorted_dict.keys()): 76 | pre = choice_item.lower().split(final_sentence)[1].replace('.','') 77 | if '-' in pre: 78 | continue 79 | if 'refute' in pre: 80 | pre = 'refutes' 81 | if 'reje' in pre: 82 | pre = 'refutes' 83 | if ('support' in pre) or ('suport' in pre) or ('correct' in pre): 84 | pre = 'supports' 85 | if ('not enough info' in pre )or ('no enough info' in pre) or ('no evidence'in pre): 86 | pre = 'not enough info' 87 | if pre not in ['refutes','supports','not enough info']: 88 | print(pre) 89 | continue 90 | if pre not in pre_set: 91 | pre_set[pre] = {} 92 | pre_set[pre]['value'] = [choices_set[choice_item]] 93 | pre_set[pre]['item'] = [choice_item] 94 | else: 95 | pre_set[pre]['value'].append(choices_set[choice_item]) 96 | pre_set[pre]['item'].append(choice_item) 97 | 98 | if len(pre_set)>0: 99 | pre_ = '' 100 | len_ = 0 101 | max_val = 0 102 | # for pre_item in list(pre_set.keys()): 103 | # if sum(pre_set[pre_item]['value'])>max_val: 104 | # pre_ = pre_item 105 | # len_ = len(pre_set[pre_item]['value']) 106 | # max_val = sum(pre_set[pre_item]['value']) 107 | for pre_item in list(pre_set.keys()): 108 | if len(pre_set[pre_item]['value'])>len_: 109 | pre_ = pre_item 110 | len_ = len(pre_set[pre_item]['value']) 111 | max_val = sum(pre_set[pre_item]['value']) 112 | if pre_ == '': 113 | return [], [] 114 | else: 115 | return pre_set[pre_]['item'], pre_set[pre_]['value'] 116 | else: 117 | return [],[] 118 | 119 | def final_evaluate(new_ys, values, final_sentence): 120 | choices_set = {} 121 | for i in range(len(new_ys)): 122 | new_ys[i] = new_ys[i].replace('\"','\'').replace(' ',' ').split('\n\n')[0].lower().replace('so the final answer is ', 'so the final answer is: ') 123 | choices_i_list = new_ys[i].split('\n') 124 | 125 | if 
final_sentence not in new_ys[i]: 126 | print('line 57') 127 | print(new_ys[i]) 128 | continue 129 | 130 | if len(choices_i_list)>1: 131 | while final_sentence not in choices_i_list[-1].lower(): 132 | choices_i_list = choices_i_list[:-1] 133 | if len(choices_i_list)<=1: 134 | break 135 | if len(choices_i_list)==1: 136 | if final_sentence not in choices_i_list[0]: 137 | print('line 68') 138 | print(new_ys[i]) 139 | continue 140 | if len(choices_i_list)==0: 141 | print('line 72') 142 | print(new_ys[i]) 143 | continue 144 | 145 | new_ys[i] = ' '.join(choices_i_list) 146 | if new_ys[i] not in choices_set: 147 | choices_set[new_ys[i]] = [float(format(values[i],'.3f'))] 148 | else: 149 | choices_set[new_ys[i]].append(float(format(values[i],'.3f'))) 150 | for i in range(len(choices_set)): 151 | choices = list(choices_set.keys()) 152 | choices_set[choices[i]] = sum(choices_set[choices[i]])/len(choices_set[choices[i]]) 153 | sorted_dict = sorted(choices_set.items(), key=lambda item: item[1], reverse=True) 154 | sorted_dict = {item[0]: item[1] for item in sorted_dict} 155 | pre_set = {} 156 | for choice_item in list(sorted_dict.keys()): 157 | 158 | pre = choice_item.lower().split(final_sentence)[1].replace('.','') 159 | 160 | pre = normalize_answer(pre) 161 | if pre not in pre_set: 162 | pre_set[pre] = {} 163 | pre_set[pre]['value'] = [choices_set[choice_item]] 164 | pre_set[pre]['item'] = [choice_item] 165 | else: 166 | pre_set[pre]['value'].append(choices_set[choice_item]) 167 | pre_set[pre]['item'].append(choice_item) 168 | 169 | if len(pre_set)>0: 170 | pre_ = '' 171 | len_ = 0 172 | max_val = 0 173 | tem_list = [] 174 | for pre_item in list(pre_set.keys()): 175 | if sum(pre_set[pre_item]['value'])>max_val: 176 | if 'arabic numerals' in final_sentence: 177 | if math_test_output(pre_item) =='-1': 178 | continue 179 | pre_ = pre_item 180 | len_ = len(pre_set[pre_item]['value']) 181 | max_val = sum(pre_set[pre_item]['value']) 182 | tem_list = pre_set[pre_item]['value'] 183 | 184 | elif sum(pre_set[pre_item]['value'])==max_val: 185 | if sorted(pre_set[pre_item]['value'])[0]>sorted(tem_list)[0]: 186 | pre_ = pre_item 187 | len_ = len(pre_set[pre_item]['value']) 188 | max_val = sum(pre_set[pre_item]['value']) 189 | tem_list = pre_set[pre_item]['value'] 190 | # for pre_item in list(pre_set.keys()): 191 | # if len(pre_set[pre_item]['value'])>len_: 192 | # pre_ = pre_item 193 | # len_ = len(pre_set[pre_item]['value']) 194 | # max_val = sum(pre_set[pre_item]['value']) 195 | if pre_ == '': 196 | print(pre_set) 197 | return [], [] 198 | else: 199 | return pre_set[pre_]['item'], pre_set[pre_]['value'] 200 | else: 201 | return [],[] 202 | 203 | def get_value(task, x, y, n_evaluate_sample, tokenizer, GenerationConfig, model, device, step,cache_value=True): 204 | # if x.lower().strip() in y.lower().strip(): 205 | # return 0.001 206 | if isinstance(task,str) == False: 207 | if (task.steps == 3): 208 | if (step == 2): 209 | if 'the final answer is' not in y.lower(): 210 | return 0 211 | if task.stops == 'qasc': 212 | if (step == 3): 213 | if 'so the final answer is' not in y.lower(): 214 | return 0 215 | if y.lower().split('so the final answer is ')[1].replace(': ','').replace('.','') not in x.lower(): 216 | return 0 217 | if (y.lower().split('so the final answer is')[1].strip().startswith('(')) ==False: 218 | return 0 219 | if (task.steps == 4): 220 | value_prompts = task.value_prompt_wrap(x, y, step) 221 | if value_prompts == 0: 222 | return 0 223 | else: 224 | value_prompts = task.value_prompt_wrap(x, y) 225 | 
else: 226 | value_prompts = math_evaluate + 'Question: ' + x + '\nThought: ' + y +'\nEvaluation Process: ' 227 | # if 'answer (arabic numerals) is ' in y: 228 | # value_prompts = math_final_evaluate + 'Question: ' + x + '\nThought: ' + y +'\nEvaluation Process: ' 229 | # else: 230 | # value_prompts = math_evaluate + 'Question: ' + x + '\nThought: ' + y +'\nEvaluation Process: ' 231 | if isinstance(value_prompts,list): 232 | value_prompts_list = value_prompts 233 | value = 1 234 | for value_prompts in value_prompts_list: 235 | value_outputs, _ = gpt(value_prompts, tokenizer, GenerationConfig, model, device, n=n_evaluate_sample, temperature=0.4,stop=None) 236 | 237 | value_item = 0 238 | for i in range(len(value_outputs)): 239 | value_output = value_outputs[i] 240 | # if isinstance(value_prompts,list): 241 | # value_prompt = value_prompts[i] 242 | # else: 243 | value_prompt = value_prompts 244 | # print(f'========\n {value_output}') 245 | value_output = value_output.replace(value_prompt, "").replace('', '').replace('', '') 246 | # print(f'========\nQuestion: {x}\nThought: {y}\nvalue_output: {value_output}\n========') 247 | value_output = value_output.split('\n') 248 | value_item += value_outputs_unwrap(x, y, value_output) 249 | value = value*value_item 250 | else: 251 | value_outputs, _ = gpt(value_prompts, tokenizer, GenerationConfig, model, device, n=n_evaluate_sample, max_tokens=256, temperature=0.4,stop=None) 252 | value = 0 253 | 254 | for i in range(len(value_outputs)): 255 | value_output = value_outputs[i] 256 | if isinstance(value_prompts,list): 257 | value_prompt = value_prompts[i] 258 | else: 259 | value_prompt = value_prompts 260 | # print(f'========\n {value_output}') 261 | value_output = value_output.replace(value_prompt, "").replace('', '').replace('', '') 262 | # print(f'========\nQuestion: {x}\nThought: {y}\nvalue_output: {value_output}\n_generate_sample========') 263 | value_output = value_output.split('\n') 264 | value += value_outputs_unwrap(x, y, value_output) 265 | return value 266 | 267 | def get_values(task, x, ys, n_evaluate_sample, tokenizer, GenerationConfig, model, device, step,cache_value=True): 268 | values = [] 269 | for y in ys: # each partial output 270 | value = get_value(task, x, y, n_evaluate_sample, tokenizer, GenerationConfig, model, device,step, cache_value=cache_value) 271 | values.append(round(value, 3)) 272 | return values 273 | 274 | def get_votes(task, x, ys, n_evaluate_sample): 275 | vote_prompt = task.vote_prompt_wrap(x, ys) 276 | vote_outputs = gpt(vote_prompt, n=n_evaluate_sample, stop=None) 277 | values = task.vote_outputs_unwrap(vote_outputs, len(ys)) 278 | return values 279 | 280 | def get_proposals(task, x, y, tokenizer, GenerationConfig, model): 281 | propose_prompt = task.propose_prompt_wrap(x, y) 282 | proposals = gpt(propose_prompt, tokenizer, GenerationConfig, model, n=1, stop='None').split('\n') 283 | indices = [index for index, value in enumerate(proposals) if 'Input' in value] 284 | if len(indices)>2: 285 | proposals = proposals[indices[1]+2:indices[2]] 286 | else: 287 | proposals = proposals[indices[1]+2:] 288 | return [y + _ + '\n' for _ in proposals] 289 | 290 | def get_samples_wiki(task, x, y, n_generate_sample, tokenizer, GenerationConfig, model, device, prompt_sample, step, bbh_flag): 291 | if bbh_flag != None: 292 | prompt = '' 293 | with open('/home/ducunxiao/Distill-ToT/tot/BIG-Bench-Hard/cot-prompts/' + bbh_flag.split('.')[0]+'.json') as f: 294 | lines = f.readlines() 295 | for line in lines: 296 | prompt += line 297 | 
prompt = prompt + '\nQ: '+ x + '\nA: ' 298 | else: 299 | if prompt_sample == 'standard': 300 | prompt = task.standard_prompt_wrap(x, y) 301 | elif prompt_sample == 'cot': 302 | prompt = task.cot_prompt_wrap(x, y) 303 | else: 304 | raise ValueError(f'prompt_sample {prompt_sample} not recognized') 305 | out_sample = [] 306 | # if 'so the final answer is' in y.lower(): 307 | # return [y], [] 308 | # else: 309 | if step == 0: 310 | prompt_ = prompt + 'Answer: Step 1, ' 311 | else: 312 | prompt_ = prompt + ' Step ' + str(step+1) + ', ' 313 | if bbh_flag != None: 314 | if step == 4: 315 | prompt_ = prompt_ + 'so the final answer is: ' 316 | else: 317 | if 'Options:' in prompt: 318 | if step == 3: 319 | prompt_ = prompt_ + 'So the final answer is (' 320 | else: 321 | if step == 2: 322 | prompt_ = prompt_ + 'So the final answer is: ' 323 | # while len(out_sample)', '').replace('', '').lower() 333 | if 'question:' in sample: 334 | sample = sample.split('question:')[0] 335 | if 'claim:' in sample: 336 | sample = sample.split('claim:')[0] 337 | if 'task:' in sample: 338 | sample = sample.split('task:')[0] 339 | if len(sample.strip()) < 2: 340 | continue 341 | if 'step' in sample: 342 | splited_sample = sample.split('step') 343 | if len(splited_sample) >2: 344 | sample = splited_sample[0]+'step' + splited_sample[1] 345 | else: 346 | print('====clean_generated_thought====') 347 | print(sample) 348 | print('====prompt====') 349 | print(prompt) 350 | exit() 351 | sample = sample.split('\n')[0] 352 | if sample.count('so the final answer is') > 1: 353 | first_idx = sample.lower().find('so the final answer is') 354 | sample = sample[:sample.lower().find('so the final answer is', first_idx+1)] 355 | if len(sample.strip()) < 2: 356 | continue 357 | sample = sample.strip() 358 | if (y + sample) not in out_sample: 359 | out_sample.append(y + sample) 360 | return out_sample 361 | 362 | def get_samples(task, x, y, n_generate_sample, tokenizer, GenerationConfig, model, device, prompt_sample, stop): 363 | if prompt_sample == 'standard': 364 | prompt = task.standard_prompt_wrap(x, y) 365 | elif prompt_sample == 'cot': 366 | prompt = task.cot_prompt_wrap(x, y) 367 | else: 368 | raise ValueError(f'prompt_sample {prompt_sample} not recognized') 369 | samples, input_ids = gpt(prompt, tokenizer, GenerationConfig, model, device, n=n_generate_sample, temperature=0.3, stop=stop) 370 | out_sample = [] 371 | for i in range(len(samples)): 372 | sample = samples[i] 373 | sample = sample.replace(prompt, "").replace('', '').replace('', '') 374 | splited_sample = sample.split('\n') 375 | if 'Steps' in splited_sample[0]: 376 | samples[i] = splited_sample[0] + '\n' + splited_sample[1] + '\n' 377 | elif 'Answer' in splited_sample[0]: 378 | samples[i] = splited_sample[0] 379 | else: 380 | samples[i] = splited_sample[0] + '\n' 381 | if len(samples[i].strip()) < 2: 382 | continue 383 | out_sample.append(y + samples[i]) 384 | return out_sample, input_ids 385 | 386 | def get_samples_cot(task, x, y, n_generate_sample, tokenizer, GenerationConfig, model, device, prompt_sample, stop=None,bbh_flag=None): 387 | if bbh_flag != None: 388 | prompt = '' 389 | with open('/home/ducunxiao/Distill-ToT/tot/BIG-Bench-Hard/cot-prompts/' + bbh_flag.split('.')[0]+'.txt') as f: 390 | lines = f.readlines() 391 | for line in lines: 392 | prompt += line 393 | prompt_ = prompt + '\nQ: '+ x + '\nA: ' 394 | else: 395 | if isinstance(x,str): 396 | if prompt_sample == 'standard': 397 | prompt = task.standard_prompt_wrap(x, y) 398 | elif prompt_sample == 'cot': 399 
| prompt = task.cot_prompt_wrap(x, y) 400 | else: 401 | raise ValueError(f'prompt_sample {prompt_sample} not recognized') 402 | prompt_ = prompt+' Answer: Step 1, ' 403 | else: 404 | prompt_ = [] 405 | prompts = [] 406 | for x_i in x: 407 | if prompt_sample == 'standard': 408 | prompt = task.standard_prompt_wrap(x_i, y) 409 | elif prompt_sample == 'cot': 410 | prompt = task.cot_prompt_wrap(x_i, y) 411 | else: 412 | raise ValueError(f'prompt_sample {prompt_sample} not recognized') 413 | prompt_.append(prompt+' Answer: Step 1, ') 414 | prompts.append(prompt) 415 | 416 | while True: 417 | samples, input_ids = gpt(prompt_, tokenizer, GenerationConfig, model, device, n=n_generate_sample, stop=stop, temperature=0.4) 418 | for i in range(len(samples)): 419 | sample = samples[i] 420 | sample = sample.replace(prompts[i], "").replace('', '').replace('', '') 421 | splited_sample = sample.split('\n') 422 | tem_s = [] 423 | for splited_s in splited_sample: 424 | if 'nput' in splited_s: 425 | break 426 | if 'question:' in splited_s.lower(): 427 | break 428 | if 'nswer' in splited_s: 429 | tem_s.append(splited_s) 430 | break 431 | if 'claim: 'in splited_s.lower(): 432 | break 433 | tem_s.append(splited_s) 434 | samples[i] = '\n'.join(tem_s) 435 | if bbh_flag != None: 436 | break 437 | 438 | break_flag= True 439 | for i in range(len(samples)): 440 | if 'step 2' not in samples[i].lower(): 441 | prompt_[i] = prompts[i]+samples[i]+' Step 2, ' 442 | break_flag = False 443 | elif 'step 3, so the final answer is:' not in samples[i].lower(): 444 | prompt_[i] = prompts[i]+samples[i]+' Step 3, so the final answer is: ' 445 | break_flag = False 446 | else: 447 | prompt_[i] = prompts[i]+samples[i] 448 | if break_flag: 449 | break 450 | 451 | return [y + _ for _ in samples], input_ids 452 | 453 | 454 | def get_samples_math(task_prompt, x, ys, n_generate_sample, tokenizer, GenerationConfig, model, device, prompt_sample, stop=None): 455 | if isinstance(ys,list): 456 | prompt_ = [] 457 | task_prompt_ = [] 458 | for y in ys: 459 | task_prompt = task_prompt+'Q: '+x+'\nA: '+y + 'The answer (arabic numerals) is ' 460 | task_prompt_.append(task_prompt) 461 | prompt_.append(task_prompt_) 462 | else: 463 | y = ys 464 | task_prompt = task_prompt+'Q: '+x+'\n' 465 | task_prompt_ = (task_prompt + 'A: Step 1, '+y) 466 | prompt_ = [task_prompt_] 467 | task_prompt = [task_prompt] 468 | splited_samples = [] 469 | while True: 470 | samples, input_ids = gpt(prompt_, tokenizer, GenerationConfig, model, device, n=n_generate_sample, stop=stop) 471 | splited_samples = [] 472 | for i in range(len(samples)): 473 | sample = samples[i] 474 | sample = sample.replace(task_prompt[i], "").replace('', '').replace('', '').lower() 475 | if 'q:' in sample: 476 | sample = sample.split('q:')[0] 477 | if '\n\n' in sample: 478 | sample = sample.split('\n\n')[0] 479 | samples[i] = sample 480 | break_flag= True 481 | for i in range(len(samples)): 482 | if 'step 2' not in samples[i].lower(): 483 | prompt_[i] = prompt_[i]+samples[i]+' Step 2, ' 484 | break_flag = False 485 | elif 'step 3' not in samples[i].lower(): 486 | prompt_[i] = prompt_[i]+samples[i]+' Step 3, ' 487 | break_flag = False 488 | elif 'step 4' not in samples[i].lower(): 489 | prompt_[i] = prompt_[i]+samples[i]+' Step 4, The answer (arabic numerals) is ' 490 | break_flag = False 491 | elif ('step 4' in samples[i].lower()) & ('answer (arabic numerals' not in samples[i].lower()): 492 | prompt_[i] = prompt_[i]+samples[i].lower().split('step 4')[0]+' Step 4, The answer (arabic numerals) is ' 493 | 
else: 494 | prompt_[i] = prompt_[i]+samples[i] 495 | if break_flag: 496 | break 497 | # if '(arabic numerals) is ' in sample: 498 | # splited_samples.append(sample) 499 | return samples, samples, input_ids 500 | 501 | def get_samples_math_tot(task_prompt, x, y, n_generate_sample, tokenizer, GenerationConfig, model, device, prompt_sample, step): 502 | task_prompt = task_prompt+'Q: '+x+'\nA: '+y 503 | sample_last = y 504 | if step == 3: 505 | prompt_ = task_prompt + 'Step 4, The answer (arabic numerals) is ' 506 | else: 507 | prompt_ = task_prompt + 'Step ' + str(step+1) + ', ' 508 | samples, input_ids = gpt(prompt_, tokenizer, GenerationConfig, model, device, n=n_generate_sample, temperature=0.7, max_tokens=128, stop=None,top_p=0.9) 509 | for i in range(len(samples)): 510 | sample = samples[i] 511 | sample = sample.replace(task_prompt, "").replace('', '').replace('', '').replace('\n',' ').lower() 512 | if '\n\nq:' in sample: 513 | sample = sample.split('\n\nq:')[0] 514 | if 'q:' in sample: 515 | sample = sample.split('q:')[0] 516 | if 'step' in sample: 517 | splited_sample = sample.split('step') 518 | if len(splited_sample) >2: 519 | sample = splited_sample[0]+'step' + splited_sample[1] 520 | 521 | samples[i] = y + sample 522 | # print(samples) 523 | # print(len(samples)) 524 | return samples, input_ids 525 | 526 | def solve(args, task, x, model, tokenizer, device, to_print=True, 527 | # The prompt template to use, will default to med_template. 528 | prompt_template: str = "med_template", bbh_flag = None): 529 | if args.backend == 'llama2-7b': 530 | from tot.llama_models import gpt 531 | else: 532 | from tot.models import gpt 533 | global gpt 534 | gpt = partial(gpt) 535 | # print(gpt) 536 | out = {} 537 | ys = [''] # current output candidates 538 | infos = [] 539 | if 'math' not in args.task: 540 | x = x[0] 541 | out[x] = {} 542 | if isinstance(task, str): 543 | steps = 4 544 | stops = ['.']*steps 545 | else: 546 | steps = task.steps 547 | stops = task.stops 548 | if bbh_flag != None: 549 | steps = 5 550 | for step in range(steps): 551 | # generation 552 | select_new_ys = [] 553 | n_select_sample = args.n_select_sample-len(select_new_ys) 554 | if args.method_generate == 'sample': 555 | if 'math' in args.task: 556 | new_ys = [get_samples_math_tot(task, x, y, args.n_generate_sample, tokenizer, GenerationConfig, model, device, prompt_sample=args.prompt_sample, step=step)[0] for y in ys] 557 | elif '24' in args.task: 558 | new_ys = [get_samples(task, x, y, args.n_generate_sample, tokenizer, GenerationConfig, model, device, prompt_sample=args.prompt_sample, stop=stops[step])[0] for y in ys] 559 | else: 560 | new_ys = [get_samples_wiki(task, x, y, args.n_generate_sample, tokenizer, GenerationConfig, model, device, prompt_sample=args.prompt_sample, step=step, bbh_flag = bbh_flag)[0] for y in ys] 561 | elif args.method_generate == 'propose': 562 | new_ys = [get_proposals(task, x, y, tokenizer, GenerationConfig, model) for y in ys] 563 | new_ys = list(itertools.chain(*new_ys)) 564 | new_ys = list(set(new_ys)) 565 | out[x][str(step)] = {} 566 | out[x][str(step)]['candiate'] = new_ys 567 | ids = list(range(len(new_ys))) 568 | 569 | # evaluation 570 | if args.method_evaluate == 'vote': 571 | values = get_votes(task, x, new_ys, args.n_evaluate_sample) 572 | elif args.method_evaluate == 'value': 573 | values = get_values(task, x, new_ys, args.n_evaluate_sample, tokenizer, GenerationConfig, model, device = device, step=step) 574 | out[x][str(step)]['values'] = values 575 | 576 | # selection 577 | if 
args.method_select == 'sample': 578 | ps = np.array(values) / sum(values) 579 | select_ids = np.random.choice(ids, size=n_select_sample, p=ps).tolist() 580 | elif args.method_select == 'greedy': 581 | if ((step == 2 )&(args.task not in ['math','game24','qasc'])): 582 | if args.task == 'fever': 583 | new_select_new_ys, sorted_values = final_evaluate_fever(new_ys, values,'so the final answer is: ') 584 | else: 585 | new_select_new_ys, sorted_values = final_evaluate(new_ys, values,'so the final answer is: ') 586 | select_new_ys.extend(new_select_new_ys) 587 | print(f'-- new_ys --: {new_ys}\n-- sol values --: {values}\n-- choices --: {select_new_ys}\n') 588 | elif (step == 3 )&(args.task in ['math']): 589 | new_select_new_ys, sorted_values = final_evaluate(new_ys, values,'answer (arabic numerals) is ') 590 | select_new_ys.extend(new_select_new_ys) 591 | print(f'-- new_ys --: {new_ys}\n-- sol values --: {values}\n-- choices --: {select_new_ys}\n') 592 | elif ((step == 3 )&(args.task in ['qasc'])): 593 | new_select_new_ys, sorted_values = final_evaluate(new_ys, values,'so the final answer is: ') 594 | select_new_ys.extend(new_select_new_ys) 595 | print(f'-- new_ys --: {new_ys}\n-- sol values --: {values}\n-- choices --: {select_new_ys}\n') 596 | else: 597 | sorted_ids = sorted(ids, key=lambda x: values[x], reverse=True) 598 | select_ids = sorted_ids[:n_select_sample] 599 | 600 | for i in range(n_select_sample,len(sorted_ids)): 601 | id = sorted_ids[i] 602 | p_id = sorted_ids[i-1] 603 | if (values[id]==values[p_id]): 604 | select_ids.append(id) 605 | else: 606 | break 607 | select_new_ys.extend([new_ys[select_id] for select_id in select_ids]) 608 | # log 609 | if to_print: 610 | try: 611 | sorted_new_ys, sorted_values = zip(*sorted(zip(new_ys, values), key=lambda x: x[1], reverse=True)) 612 | except: 613 | print('====error====') 614 | print('new_ys') 615 | print(new_ys) 616 | print('values') 617 | print(values) 618 | exit() 619 | print(f'-- new_ys --: {sorted_new_ys}\n-- sol values --: {sorted_values}\n-- choices --: {select_new_ys}\n') 620 | infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': new_ys, 'values': values, 'select_new_ys': select_new_ys}) 621 | ys = select_new_ys 622 | if to_print: 623 | print(ys) 624 | out[x]['correct'] = ys 625 | return ys, {'steps': infos}, out 626 | 627 | def naive_solve(args, task, x, model, tokenizer, device, to_print=True, bbh_flag = None): 628 | if args.backend == 'llama2-7b': 629 | from tot.llama_models import gpt 630 | else: 631 | from tot.models import gpt 632 | global gpt 633 | out = {} 634 | gpt = partial(gpt, temperature=args.temperature) 635 | # x = task.get_input(idx) # input 636 | if isinstance(x,list): 637 | for x_i in x: 638 | out[x_i] = {} 639 | else: 640 | out[x] = {} 641 | if 'math' in args.task: 642 | ys,ys_ori,_ = get_samples_math(task, x, '', args.n_generate_sample, tokenizer, GenerationConfig, model, device, args.prompt_sample, stop=None) 643 | if len(ys) == 0: 644 | ys,ys_ori,_ = get_samples_math(task, x, ys_ori, args.n_generate_sample, tokenizer, GenerationConfig, model, device, args.prompt_sample, stop=None) 645 | else: 646 | ys,_ = get_samples_cot(task, x, '', args.n_generate_sample, tokenizer, GenerationConfig, model, device, args.prompt_sample, stop=None, bbh_flag = bbh_flag) 647 | print(ys) 648 | return ys, out 649 | 650 | --------------------------------------------------------------------------------
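Note: the search in tot/methods/bfs_test.py follows the standard Tree-of-Thoughts breadth-first loop. At every step, solve() samples candidate continuations for each kept partial answer (the get_samples_* helpers), scores each candidate with the value prompt (get_values, mapped to numbers by value_outputs_unwrap), and greedily keeps the highest-valued candidates for the next step. The sketch below is a minimal, self-contained illustration of that loop only; propose_step and score_step are hypothetical stand-ins for the repository's gpt-backed generation and evaluation helpers, and the real solve() additionally keeps ties with the top score and aggregates same-answer candidates with final_evaluate / final_evaluate_fever at the last step.

# Illustrative sketch only; not part of the repository.
from typing import Callable, List

def bfs_solve(question: str,
              propose_step: Callable[[str, str], List[str]],  # stand-in for get_samples_* (hypothetical)
              score_step: Callable[[str, str], float],        # stand-in for get_values (hypothetical)
              steps: int = 3,
              n_select: int = 1) -> List[str]:
    ys = ['']  # current partial reasoning chains, as in solve()
    for _ in range(steps):
        # generation: expand every kept chain with sampled next thoughts
        candidates = [y + t for y in ys for t in propose_step(question, y)]
        candidates = list(set(candidates))  # deduplicate, mirroring set(new_ys) in solve()
        # evaluation: value each candidate, e.g. 'likely' -> 1, 'impossible' -> 0.001
        values = [score_step(question, y) for y in candidates]
        # selection: greedy, keep the n_select highest-valued candidates
        ranked = sorted(zip(candidates, values), key=lambda pair: pair[1], reverse=True)
        ys = [y for y, _ in ranked[:n_select]]
    return ys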