├── .gitattributes ├── .gitignore ├── README.md ├── data ├── stage1 │ └── train.json └── stage2 │ ├── arithmetic │ ├── train.split.1.json │ └── train.split.2.json │ └── commonsense │ └── train.json ├── prompts ├── __init__.py ├── coin.py ├── csqa.py ├── gsm8k.py └── strategyqa.py ├── requirements.txt └── src ├── data_utils.py ├── decoder.py ├── generate_data.py ├── model.py ├── openai_api_mp.py ├── parser_utils.py ├── test.py ├── train.py └── trainer.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.json filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

Deductive Beam Search
Decoding Deducible Rationale for Chain-of-Thought Reasoning

2 | 3 | ![Static Badge](https://img.shields.io/badge/task-reasoning-purple) 4 | 5 | 6 | Source code for the paper [`Deductive Beam Search: Decoding Deducible Rationale for Chain-of-Thought Reasoning`](https://arxiv.org/abs/2401.17686). 7 | 8 | # Quick Start 9 | Clone repo: 10 | ```bash 11 | git clone https://github.com/OSU-NLP-Group/Deductive-Beam-Search.git 12 | cd Deductive-Beam-Search 13 | mkdir outputs 14 | ``` 15 | 16 | Install requirements: 17 | ```bash 18 | conda create --name dbs python=3.10 19 | conda activate dbs 20 | pip install -r requirements.txt 21 | ``` 22 | 23 | Training: 24 | 25 | - Stage 1: 26 | 27 | ```bash 28 | python src/train.py \ 29 | --train_datapath your_data_path_stage1 \ 30 | --experiment_name your_experiment_name_stage1 \ 31 | --batch_size 8 --gradient_accumulation_steps 16 \ 32 | --learning_rate 1e-5 33 | ``` 34 | 35 | - Stage 2: 36 | 37 | ```bash 38 | python src/train.py \ 39 | --train_datapath your_data_path_stage2 \ 40 | --experiment_name your_experiment_name_stage2 \ 41 | --batch_size 8 --gradient_accumulation_steps 16 \ 42 | --learning_rate 1e-7 43 | ``` 44 | 45 | Inference: 46 | ```bash 47 | DECODER_PATH=your_model_path 48 | DATASETS=("gsm8k" "svamp") # add datasets you want to test 49 | 50 | for TEST_DATASET in ${DATASETS[@]}; 51 | do 52 | python src/test.py \ 53 | --test_dataset $TEST_DATASET \ 54 | --output_dir outputs/$TEST_DATASET \ 55 | --decoder_path $DECODER_PATH \ 56 | --decode_strategy beam_search \ 57 | --num_beams 5 \ 58 | --num_sampling 10 \ 59 | --num_gpus_decode 4 # how many gpus for inference 60 | done 61 | ``` 62 | 63 | # Data 64 | 65 | ## Training Data 66 | 67 | All synthesized training data is in the `data/` folder. The training file `data/stage1/train.json` is used for training a general deductive verifier. For stage 2, the arithmetic and symbolic verifier is trained on `data/stage2/arithmetic/train.split.*.json`, and the commonsense verifier is trained on `data/stage2/commonsense/train.json`. 
68 | 69 | ## Checkpoints 70 | 71 | We provide the checkpoint of a general deductive verifier; please download it from this [link](https://drive.google.com/drive/folders/1GbnAiX160Cz63zAbr2FAgB0QFySfM2Vn?usp=sharing). 72 | You can use it to continue training on our data or train on your own data. 73 | 74 | ## Data Generation 75 | The complete process of data construction is in `src/generate_data.py`. 76 | If you want to generate data on your own, please modify the data loading part. 77 | After modifying it, you can run `python src/generate_data.py` to generate data for your own domains. 78 | 79 | **\[TODO\]** We will improve the code for easier modification and usage. 80 | 81 | # Contact 82 | 83 | If you have any problems, please contact 84 | [Tinghui Zhu](mailto:darthzhu@gmail.com) and 85 | [Kai Zhang](mailto:zhang.13253@osu.edu). 86 | 87 | # Citation Information 88 | 89 | If you find our code and data useful, please consider citing our paper: 90 | 91 | ``` 92 | @article{zhu2024deductive, 93 | title={Deductive Beam Search: Decoding Deducible Rationale for Chain-of-Thought Reasoning}, 94 | author={Zhu, Tinghui and Zhang, Kai and Xie, Jian and Su, Yu}, 95 | journal={arXiv preprint arXiv:2401.17686}, 96 | year={2024} 97 | } 98 | ``` -------------------------------------------------------------------------------- /data/stage1/train.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:48c570fc3f18ee95499fa664d97fe7fbbca5ace9cf9c0383aa4e7c49eeb854f8 3 | size 16897210 4 | -------------------------------------------------------------------------------- /data/stage2/arithmetic/train.split.1.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:458035c7327280eaa66536d8195507f4dac600735b41cf809f0c9bb93e1491db 3 | size 69144494 4 | 
-------------------------------------------------------------------------------- /data/stage2/arithmetic/train.split.2.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:458035c7327280eaa66536d8195507f4dac600735b41cf809f0c9bb93e1491db 3 | size 69144494 4 | -------------------------------------------------------------------------------- /data/stage2/commonsense/train.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:b42f1ca5739bf7517310776aeadc8335e2813631be5a676c7db3375cfa449888 3 | size 3443251 4 | -------------------------------------------------------------------------------- /prompts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSU-NLP-Group/Deductive-Beam-Search/67ed155699cf3ae966d44793f92fbf10379bceca/prompts/__init__.py -------------------------------------------------------------------------------- /prompts/coin.py: -------------------------------------------------------------------------------- 1 | prompt = """Given the question, output the rationale step by step and give the final answer. 2 | 3 | Example 1 4 | Question: 5 | A coin is heads up. sager does not flip the coin. zyheir flips the coin. Is the coin still heads up? 6 | Answer: 7 | sager does not flip the coin, so the coin is heads up. 8 | zyheir flips the coins, so the coin is tails up. 9 | Final Answer: no 10 | 11 | Example 2 12 | Question: 13 | A coin is heads up. mailey does not flip the coin. maurisa does not flip the coin. Is the coin still heads up? 14 | Answer: 15 | mailye does not flip the coin, so the coin is heads up. 16 | maurisa does not flip the coin, so the coin is heads up. 
17 | Final Answer: yes 18 | """ -------------------------------------------------------------------------------- /prompts/csqa.py: -------------------------------------------------------------------------------- 1 | prompt = """Given the question, output the rationale step by step and give the final answer. You should choose the best answer. 2 | 3 | Example 1 4 | Question: 5 | Sammy wanted to go to where the people were. Where might he go? 6 | A. race track 7 | B. populated area 8 | C. the desert 9 | D. apartment 10 | E. roadblock 11 | Answer: 12 | Sammy wanted to go to places with many people. 13 | Race track and apartment do not have many people. 14 | The desert and roadblock have few people. 15 | And, the populated area means that it is the place with many people. 16 | Thus, Sammy should go to populated area. 17 | Final Answer: B 18 | 19 | Example 2 20 | Question: 21 | The fox walked from the city into the forest, what was it looking for? 22 | A. pretty flowers 23 | B. hen house 24 | C. natural habitat 25 | D. storybook 26 | E. dense forest 27 | Answer: 28 | The forest does not have hen house or storybook. 29 | The fox is a carnivore that does not look for flowers and forest. 30 | The forest is a natural habitat for foxes. 31 | Thus, it was looking for a natural habitat. 32 | Final Answer: C 33 | """ 34 | 35 | recall_prompt = """Given the question, output the rationale step by step and give the final answer. You should choose the best answer. 36 | 37 | Example 1 38 | Question: 39 | Sammy wanted to go to where the people were. Where might he go? 40 | A. race track 41 | B. populated area 42 | C. the desert 43 | D. apartment 44 | E. roadblock 45 | Answer: 46 | Fact: 47 | Sammy wanted to go to places with many people. 48 | Race track and apartment do not have many people. 49 | The desert and roadblock have few people. 50 | And, the populated area means that it is the place with many people. 51 | Reasoning: 52 | Thus, Sammy should go to populated area. 
53 | Final Answer: B 54 | 55 | Example 2 56 | Question: 57 | The fox walked from the city into the forest, what was it looking for? 58 | A. pretty flowers 59 | B. hen house 60 | C. natural habitat 61 | D. storybook 62 | E. dense forest 63 | Answer: 64 | Fact: 65 | The forest does not have hen house or storybook. 66 | The fox is a carnivore that does not look for flowers and forest. 67 | The forest is a natural habitat for foxes. 68 | Reasoning: 69 | Thus, it was looking for a natural habitat. 70 | Final Answer: C 71 | """ -------------------------------------------------------------------------------- /prompts/gsm8k.py: -------------------------------------------------------------------------------- 1 | prompt = """You are a good math solver. Based on the question, please give the rationales step by step and give a final answer. 2 | 3 | Example 1: 4 | Question: 5 | Kate's hair is half as long as Emily's hair. Emily's hair is 6 inches longer than Logan's hair. If Logan hair is 20 inches, how many inches is Kate's hair? 6 | Answer: 7 | Emily's hair is 20-6 = 14 inches long. 8 | Kate's hair 14/2= 7 inches long. 9 | Final Answer:7 10 | 11 | Example 2: 12 | Question: 13 | John puts $25 in his piggy bank every month for 2 years to save up for a vacation. He had to spend $400 from his piggy bank savings last week to repair his car. How many dollars are left in his piggy bank? 14 | Answer: 15 | He saved money for 2 years, which is equal to 12 x 2 = 24 months. 16 | The amount of money he saved is $25*24 = $600. 17 | But he spent some money so there is $600 - $400 = 200 left. 18 | Final Answer:200 19 | """ 20 | 21 | """ 22 | Example 3: 23 | Question: 24 | After complaints from the residents of Tatoosh about the number of cats on the island, the wildlife service carried out a relocation mission that saw the number of cats on the island drastically reduced. On the first relocation mission, 600 cats were relocated from the island to a neighboring island. 
On the second mission, half of the remaining cats were relocated to a rescue center inland. If the number of cats originally on the island was 1800, how many cats remained on the island after the rescue mission? 25 | Answer: 26 | After the first mission, the number of cats remaining on the island was 1800-600 = <<1800-600=1200>>1200. 27 | If half of the remaining cats on the island were relocated to a rescue center inland, the number of cats taken by the wildlife service on the second mission is 1200/2 = <<1200/2=600>>600 cats. 28 | The number of cats remaining on the island is 1200-600 = <<1200-600=600>>600 29 | Final Answer:600 30 | 31 | Example 4: 32 | Question: 33 | Paul, Amoura, and Ingrid were to go to a friend's party planned to start at 8:00 a.m. Paul arrived at 8:25. Amoura arrived 30 minutes later than Paul, and Ingrid was three times later than Amoura. How late, in minutes, was Ingrid to the party? 34 | Answer: 35 | If the party was planned to start at 8:00 am, Paul was 8:25-8:00 = <<825-800=25>>25 minutes late. 36 | If Paul was 25 minutes late, Amoura was 25+30 = <<25+30=55>>55 minutes late. 37 | If Ingrid was three times late than Amoura was, she was 3*55 = <<3*55=165>>165 minutes late 38 | Final Answer:165 39 | 40 | Example 5: 41 | Question: 42 | Alicia has to buy some books for the new school year. She buys 2 math books, 3 art books, and 6 science books, for a total of $30. If both the math and science books cost $3 each, what was the cost of each art book? 43 | Answer: 44 | The total cost of maths books is 2*3 = <<2*3=6>>6 dollars 45 | The total cost of science books is 6*3 = <<6*3=18>>18 dollars 46 | The total cost for maths and science books is 6+18 = <<6+18=24>>24 dollars 47 | The cost for art books is 30-24 = <<30-24=6>>6 dollars. 
48 | Since he bought 3 art books, the cost for each art book will be 6/3 = 2 dollars 49 | Final Answer:2 50 | """ -------------------------------------------------------------------------------- /prompts/strategyqa.py: -------------------------------------------------------------------------------- 1 | prompt = """Given the question, output the rationale step by step and give the final answer (yes or no). 2 | 3 | Example 1 4 | Question: 5 | Do hamsters provide food for any animals? 6 | Answer: 7 | Hamsters are prey animals. 8 | Prey are food for predators. 9 | Final answer: yes 10 | 11 | Example 2 12 | Question: 13 | Could a llama birth twice during War in Vietnam (1945-46)? 14 | Answer: 15 | The War in Vietnam was 6 months. 16 | The gestation period for a llama is 11 months, which is more than 6 months. 17 | Final answer: no 18 | """ 19 | 20 | recall_prompt = """Given the question, output the rationale step by step and give the final answer (yes or no). 21 | 22 | Example 1 23 | Question: 24 | Do hamsters provide food for any animals? 25 | Answer: 26 | Fact: 27 | Hamsters are prey animals. 28 | Prey are food for predators. 29 | Reasoning: 30 | Hamsters are food for some predators. 31 | Final answer: yes 32 | 33 | Example 2 34 | Question: 35 | Could a llama birth twice during War in Vietnam (1945-46)? 36 | Answer: 37 | Fact: 38 | The War in Vietnam was 6 months. 39 | The gestation period for a llama is 11 months, which is more than 6 months. 40 | Reasoning: 41 | A llama could not birth twice during War in Vietnam. 
42 | Final answer: no 43 | """ -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.21.0 2 | datasets==2.13.1 3 | multiprocess==0.70.14 4 | openai==0.28.0 5 | tiktoken==0.4.0 6 | tokenizers==0.13.3 7 | torch==2.0.1 8 | torchaudio==2.0.2 9 | torchvision==0.15.2 10 | transformers==4.33.2 11 | vllm==0.1.4 -------------------------------------------------------------------------------- /src/data_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from torch.utils.data import Dataset 3 | 4 | def collate_fn(batch): 5 | return tuple(zip(*batch)) 6 | 7 | class GSM8KDataset(Dataset): 8 | def __init__(self, datapath, mode="train") -> None: 9 | if mode == "train": 10 | with open(datapath, "r") as fin: 11 | self.datas = json.load(fin) 12 | 13 | def __len__(self): 14 | return len(self.datas) 15 | 16 | def __getitem__(self, index): 17 | """ 18 | { 19 | "context": , 20 | "answer": , 21 | "label": 1 for entail; 0 for neutral 22 | } 23 | """ 24 | data = self.datas[index] 25 | return data["context"], data["answer"], data["label"] 26 | 27 | class GSM8KRankingDataset(Dataset): 28 | def __init__(self, datapath, mode="train") -> None: 29 | if mode == "train": 30 | with open(datapath, "r") as fin: 31 | self.datas = json.load(fin) 32 | 33 | def __len__(self): 34 | return len(self.datas) 35 | 36 | def __getitem__(self, index): 37 | """ 38 | { 39 | "context": , 40 | "answer": , 41 | "false_answer": 1 for entail; 0 for neutral 42 | } 43 | """ 44 | data = self.datas[index] 45 | return data["context"], data["answer"], data["false_answer"] 46 | 47 | class GSM8KRankingMultiNegativeDataset(Dataset): 48 | def __init__(self, datapath, mode="train") -> None: 49 | if mode == "train": 50 | with open(datapath, "r") as fin: 51 | self.datas = json.load(fin) 52 | 53 | def __len__(self): 54 | return 
len(self.datas) 55 | 56 | def __getitem__(self, index): 57 | """ 58 | { 59 | "context": , 60 | "answers": , "0", "1", "2", "3" 61 | } 62 | """ 63 | data = self.datas[index] 64 | return data["context"], data["answers"] -------------------------------------------------------------------------------- /src/decoder.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import torch.nn as nn 4 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel 5 | from vllm import LLM, SamplingParams 6 | from collections import Counter 7 | 8 | from src.model import Ranker 9 | from src.openai_api_mp import ChatGPTPool 10 | 11 | class Decoder(): 12 | def __init__(self, config) -> None: 13 | self.device = config.device 14 | self.prompt = config.prompt 15 | self.num_sampling = config.num_sampling 16 | self.num_beams = config.num_beams 17 | 18 | if config.decode_strategy == "rank": 19 | self.ranker = Ranker(config) 20 | # self.ranker.from_pretrained(config.resume_path) 21 | checkpoint = torch.load(config.resume_path, map_location="cpu") 22 | checkpoint = checkpoint["model"] 23 | self.ranker.load_state_dict(checkpoint) 24 | self.ranker.to(self.device) 25 | self.decode_fn = self.decode_with_ranking 26 | self.stop = "\n" 27 | elif config.decode_strategy == "beam_search": 28 | self.ranker = Ranker(config) 29 | # self.ranker.from_pretrained(config.resume_path) 30 | checkpoint = torch.load(config.resume_path, map_location="cpu") 31 | checkpoint = checkpoint["model"] 32 | self.ranker.load_state_dict(checkpoint) 33 | self.ranker.to(self.device) 34 | self.decode_fn = self.beam_search 35 | self.stop = "\n" 36 | elif config.decode_strategy == "beam_search_with_fact": 37 | self.ranker = Ranker(config) 38 | # self.ranker.from_pretrained(config.resume_path) 39 | checkpoint = torch.load(config.resume_path, map_location="cpu") 40 | checkpoint = checkpoint["model"] 41 | self.ranker.load_state_dict(checkpoint) 42 | 
self.ranker.to(self.device) 43 | self.decode_fn = self.beam_search_with_fact 44 | self.stop = "\n" 45 | else: 46 | self.decode_fn = self.decode_direct 47 | self.stop = None 48 | 49 | if "gpt" in config.decoder_path: 50 | self.model = ChatGPTPool( 51 | num_sampling = config.num_sampling, 52 | temperature = 1.0, 53 | max_tokens = 256, 54 | stop = self.stop, 55 | ) 56 | else: 57 | if config.is_greedy: 58 | self.sampling_params = SamplingParams(n=1, temperature=0, top_p=1, max_tokens=512) 59 | else: 60 | self.sampling_params = SamplingParams(n=self.num_sampling, temperature=1, top_p=0.95, max_tokens=512, stop=self.stop) 61 | self.model = LLM(model=config.decoder_path, seed=42, max_num_batched_tokens=4096, tensor_parallel_size=config.num_gpus_decode, gpu_memory_utilization=0.85) # , dtype="float32" , tensor_parallel_size=2 62 | 63 | # TODO: batch inference 64 | def decode(self, question): 65 | return self.decode_fn(question) 66 | 67 | @torch.inference_mode() 68 | def decode_with_ranking(self, question): 69 | raise NotImplementedError 70 | 71 | @torch.inference_mode() 72 | def decode_direct(self, question): 73 | raise NotImplementedError 74 | 75 | @torch.inference_mode() 76 | def beam_search(self, question): 77 | raise NotImplementedError 78 | 79 | @torch.inference_mode() 80 | def beam_search_with_fact(self, question): 81 | raise NotImplementedError 82 | 83 | class FastDecoder(Decoder): 84 | def __init__(self, config) -> None: 85 | super().__init__(config) 86 | 87 | @torch.inference_mode() 88 | def decode_with_ranking(self, question): 89 | answers = [] 90 | context = "\nQuestion:\n" + question + "\nAnswer:\n" 91 | terminated = False 92 | step = 0 93 | rationales = [] 94 | while not terminated and step < 20: 95 | outputs = self.model.generate(self.prompt + context, self.sampling_params, use_tqdm=False)[0].outputs 96 | answers = [] 97 | for output in outputs: 98 | answers.append(output.text.strip().split("\n")[0].strip()) 99 | contexts = [context] * len(answers) 100 | 
scores = self.ranker(contexts, answers).cpu().squeeze() 101 | sorted, indices = torch.sort(scores, descending=True) 102 | rationale = answers[indices[0]] 103 | rationales.append(rationale) 104 | context += f"{rationale}\n" 105 | if "Final Answer" in rationale: 106 | terminated = True 107 | step += 1 108 | return rationales 109 | 110 | @torch.inference_mode() 111 | def decode_direct(self, question): 112 | context = "\nQuestion:\n" + question + "\nAnswer:\n" 113 | outputs = self.model.generate(self.prompt + context, self.sampling_params, use_tqdm=False)[0].outputs 114 | answers = [] 115 | for output in outputs: 116 | answers.append(output.text.strip()) 117 | return answers 118 | 119 | @torch.inference_mode() 120 | def beam_search(self, question): 121 | self.sampling_params = SamplingParams(n=self.num_beams, temperature=1, top_p=0.95, max_tokens=512) 122 | answers = [] 123 | context = "\nQuestion:\n" + question + "\nAnswer:\n" 124 | global_terminated = False 125 | step = 0 126 | beams = [(context, 1, False)] * self.num_beams # (current rationale, score, terminated) 127 | completed_rationales = [] 128 | while not global_terminated and step < 20: 129 | current_beams = [] 130 | for beam, score, terminated in beams: 131 | # if terminated, leave it alone 132 | if terminated: 133 | current_beams.append((beam, score, terminated)) 134 | continue 135 | 136 | # otherwise, generate next rationale 137 | outputs = self.model.generate(self.prompt + beam, self.sampling_params, use_tqdm=False)[0].outputs 138 | answers = [] 139 | for output in outputs: 140 | answers.append(output.text.strip().split("\n")[0].strip()) 141 | contexts_for_ranker = [beam] * len(answers) 142 | scores = self.ranker(contexts_for_ranker, answers).cpu().squeeze() 143 | sorted_scores, indices = torch.sort(scores, descending=True) 144 | 145 | # calculate current score 146 | for _ in range(self.num_beams): 147 | current_beam = beam + answers[indices[_]] + "\n" 148 | current_score = score * scores[indices[_]] 149 | 
# if termintated, add to completed rationales 150 | if "Final Answer" in answers[indices[_]] or "Final answer" in answers[indices[_]]: 151 | terminated = True 152 | completed_rationales.append((current_beam, current_score.item())) 153 | current_beams.append((current_beam, current_score.item(), terminated)) 154 | sorted_beams = sorted(current_beams, key=lambda x: x[1], reverse=True) 155 | beams = sorted_beams[:self.num_beams] 156 | flag = True 157 | for _ , _, terminated in beams: 158 | if not terminated: 159 | flag = False 160 | break 161 | global_terminated = flag 162 | step += 1 163 | 164 | return beams, completed_rationales 165 | 166 | @torch.inference_mode() 167 | def beam_search_with_fact(self, question): 168 | answers = [] 169 | context = "\nQuestion:\n" + question + "\nAnswer:\n" 170 | if_verify = False 171 | step = 0 172 | while not if_verify and step < 10: 173 | fact_sampling_params = SamplingParams(n=1, temperature=0, top_p=1, max_tokens=512, stop="\n") 174 | outputs = self.model.generate(self.prompt + context, fact_sampling_params, use_tqdm=False)[0].outputs 175 | context += outputs[0].text 176 | context += "\n" 177 | if "Reasoning" in outputs[0].text: 178 | if_verify = True 179 | step += 1 180 | 181 | global_terminated = False 182 | beams = [(context, 1, False)] * self.num_beams # (current rationale, score, terminated) 183 | completed_rationales = [] 184 | 185 | while not global_terminated and step < 10: 186 | current_beams = [] 187 | for beam, score, terminated in beams: 188 | # if terminated, leave it alone 189 | if terminated: 190 | current_beams.append((beam, score, terminated)) 191 | continue 192 | 193 | # otherwise, generate next rationale 194 | outputs = self.model.generate(self.prompt + beam, self.sampling_params, use_tqdm=False)[0].outputs 195 | answers = [] 196 | for output in outputs: 197 | answers.append(output.text.strip().split("\n")[0].strip()) 198 | contexts_for_ranker = [beam] * len(answers) 199 | scores = 
self.ranker(contexts_for_ranker, answers).cpu().squeeze() 200 | sorted_scores, indices = torch.sort(scores, descending=True) 201 | 202 | # calculate current score 203 | for _ in range(self.num_beams): 204 | current_beam = beam + answers[indices[_]] + "\n" 205 | current_score = score * scores[indices[_]] 206 | # if termintated, add to completed rationales 207 | if "Final Answer" in answers[indices[_]] or "Final answer" in answers[indices[_]]: 208 | terminated = True 209 | completed_rationales.append((current_beam, current_score.item())) 210 | current_beams.append((current_beam, current_score.item(), terminated)) 211 | 212 | sorted_beams = sorted(current_beams, key=lambda x: x[1], reverse=True) 213 | beams = sorted_beams[:self.num_beams] 214 | flag = True 215 | for _ , _, terminated in beams: 216 | if not terminated: 217 | flag = False 218 | break 219 | global_terminated = flag 220 | step += 1 221 | 222 | return beams, completed_rationales 223 | 224 | class ChatGPTDecoder(Decoder): 225 | def __init__(self, config) -> None: 226 | super().__init__(config) 227 | 228 | @torch.inference_mode() 229 | def decode_direct(self, question): 230 | context = "\nQuestion:\n" + question + "\nAnswer:\n" 231 | outputs = self.model.chat_single_round(self.prompt + context) 232 | answers = [] 233 | for output in outputs: 234 | answers.append(output.strip()) 235 | return answers 236 | 237 | @torch.inference_mode() 238 | def beam_search(self, question): 239 | answers = [] 240 | context = "\nQuestion:\n" + question + "\nAnswer:\n" 241 | global_terminated = False 242 | step = 0 243 | beams = [(context, 1, False)] * self.num_beams # (current rationale, score, terminated) 244 | completed_rationales = [] 245 | while not global_terminated and step < 50: 246 | current_beams = [] 247 | for beam, score, terminated in beams: 248 | # if terminated, leave it alone 249 | if terminated: 250 | current_beams.append((beam, score, terminated)) 251 | continue 252 | 253 | # otherwise, generate next rationale 
254 | outputs = self.model.chat_single_round(self.prompt + beam) 255 | answers = [] 256 | for output in outputs: 257 | answer = output.strip().split("\n")[0].strip() 258 | if len(answer) < 1: 259 | continue 260 | answers.append(answer) 261 | contexts_for_ranker = [beam] * len(answers) 262 | scores = self.ranker(contexts_for_ranker, answers).cpu().squeeze() 263 | sorted_scores, indices = torch.sort(scores, descending=True) 264 | 265 | # calculate current score 266 | for _ in range(min(len(answers), self.num_beams)): 267 | current_beam = beam + answers[indices[_]] + "\n" 268 | current_score = score * scores[indices[_]] 269 | # if termintated, add to completed rationales 270 | if "Final Answer" in answers[indices[_]] or "Final answer" in answers[indices[_]]: 271 | terminated = True 272 | completed_rationales.append((current_beam, current_score.item())) 273 | current_beams.append((current_beam, current_score.item(), terminated)) 274 | sorted_beams = sorted(current_beams, key=lambda x: x[1], reverse=True) 275 | beams = sorted_beams[:self.num_beams] 276 | flag = True 277 | for _ , _, terminated in beams: 278 | if not terminated: 279 | flag = False 280 | break 281 | global_terminated = flag 282 | step += 1 283 | return (beams, completed_rationales) 284 | 285 | @torch.inference_mode() 286 | def beam_search_with_fact(self, question): 287 | answers = [] 288 | context = "\nQuestion:\n" + question + "\nAnswer:\n" 289 | if_verify = False 290 | step = 0 291 | self.model.num_sampling = 1 292 | while not if_verify and step < 10: 293 | self.model.num_sampling = 1 294 | outputs = self.model.chat_single_round(self.prompt + context) 295 | context += outputs[0] 296 | context += "\n" 297 | if "Reasoning" in outputs[0]: 298 | if_verify = True 299 | 300 | self.model.num_sampling = self.num_sampling 301 | global_terminated = False 302 | beams = [(context, 1, False)] * self.num_beams # (current rationale, score, terminated) 303 | completed_rationales = [] 304 | while not global_terminated 
and step < 20: 305 | current_beams = [] 306 | for beam, score, terminated in beams: 307 | # if terminated, leave it alone 308 | if terminated: 309 | current_beams.append((beam, score, terminated)) 310 | continue 311 | 312 | # otherwise, generate next rationale 313 | outputs = self.model.chat_single_round(self.prompt + beam) 314 | answers = [] 315 | for output in outputs: 316 | answers.append(output.strip().split("\n")[0].strip()) 317 | contexts_for_ranker = [beam] * len(answers) 318 | scores = self.ranker(contexts_for_ranker, answers).cpu().squeeze() 319 | sorted_scores, indices = torch.sort(torch.tensor(scores), descending=True) 320 | 321 | # calculate current score 322 | for _ in range(self.num_beams): 323 | current_beam = beam + answers[indices[_]] + "\n" 324 | current_score = score * scores[indices[_]] 325 | # if termintated, add to completed rationales 326 | if "Final Answer" in answers[indices[_]] or "Final answer" in answers[indices[_]]: 327 | terminated = True 328 | completed_rationales.append((current_beam, current_score.item())) 329 | current_beams.append((current_beam, current_score.item(), terminated)) 330 | sorted_beams = sorted(current_beams, key=lambda x: x[1], reverse=True) 331 | beams = sorted_beams[:self.num_beams] 332 | flag = True 333 | for _ , _, terminated in beams: 334 | if not terminated: 335 | flag = False 336 | break 337 | global_terminated = flag 338 | step += 1 339 | return (beams, completed_rationales) -------------------------------------------------------------------------------- /src/generate_data.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import torch 4 | import random 5 | import argparse 6 | from tqdm import tqdm 7 | from datasets import load_dataset 8 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, GenerationConfig 9 | from vllm import LLM, SamplingParams 10 | 11 | from src.decoder import Decoder 12 | from src.parser_utils import 
def confuse(rationales):
    """
    Generate false samples for given rationales.

    ``rationales`` is a list of groups, each group being a list of rationale
    strings. The first string of every group is scanned for an arithmetic
    equation (``a <op> b ... = c``); all operands and results found this way
    form the pool of replacement values.

    For every rationale string that contains an equation, one operand on the
    left-hand side is swapped for a different value from the pool, the result
    is recomputed when the corrupted expression is evaluable (otherwise the
    original result is kept), and the rewritten sentence is recorded.

    Returns a dict mapping the group index to the list of corrupted strings
    produced for that group (possibly empty).
    """
    equation_re = re.compile(
        r"([\$A-Za-z0-9\%\.]+\s*[\+\-\*\/x])+\s*[\$A-Za-z0-9\%\.]+\s*(=\s*[\$A-Za-z0-9\%\.]+)"
    )
    operator_re = re.compile(r"[\+\-\/\*x]")

    def strip_period(text):
        # A trailing full stop would otherwise glue onto the last number.
        return text[:-1] if text.endswith(".") else text

    def parse(equation):
        # Split "a + b = c" into (["+"], ["a", "b"], "c").
        lhs = equation.split("=")[0].strip()
        result = equation.split("=")[-1].strip()
        operands = [part.strip() for part in operator_re.split(lhs)]
        return operator_re.findall(lhs), operands, result

    # Collect every operand/result that appears in the first rationale of a group.
    number_set = set()
    for rs in rationales:
        found = equation_re.search(strip_period(rs[0]))
        if found is None:
            continue
        _, operands, result = parse(found.group())
        number_set.update(operands)
        number_set.add(result)
    # Sorted so that a seeded `random` yields the same corruption on every run
    # (iteration order of a set of strings varies between interpreter runs).
    number_pool = sorted(number_set)

    false_rationales = {}
    for rid, rs in enumerate(rationales):
        f_rs = []
        for r in rs:
            r = strip_period(r)
            found = equation_re.search(r)
            if found is None:  # no equation, nothing to corrupt
                continue
            symbols, elements, original_value = parse(found.group())
            # Text around the equation. The prefix is right-stripped instead of
            # sliced with `start - 1`, which wrapped around to `r[:-1]` when the
            # equation sat at the very beginning of the sentence.
            prefix = r[:found.start()].rstrip()
            suffix = r[found.end():]
            operators = ["*" if symbol == "x" else symbol for symbol in symbols]

            # Up to 10 attempts to find an operand that can be replaced.
            for _ in range(10):
                replaced_index = random.randrange(len(elements))
                current = elements[replaced_index]
                if current not in number_set:
                    continue
                candidates = [n for n in number_pool if n != current]
                if not candidates:
                    continue
                elements[replaced_index] = random.choice(candidates)
                false_eq = elements[0] + "".join(
                    f"{op}{operand}" for op, operand in zip(operators, elements[1:])
                )
                try:
                    # NOTE(security): `eval` runs text extracted from rationales.
                    # The regex only admits operands/operators (no parentheses),
                    # and builtins are disabled, but do not feed untrusted
                    # corpora through this without review.
                    value = eval(false_eq, {"__builtins__": {}}, {})
                except Exception:
                    # Not evaluable ("$5", "20%", division by zero, ...):
                    # keep the original right-hand side.
                    value = original_value
                f_rs.append(f"{prefix} {false_eq} = {value} {suffix}")
                break
        false_rationales[rid] = f_rs
    return false_rationales
@torch.inference_mode()
def generate_hard_negative():
    """
    Build hard-negative ranking samples from the GSM subset of MetaMathQA.

    For every gold reasoning step the LLM samples 10 continuations of the
    current context, the ranker scores them, and one record is appended to
    ``data/gsm8k/hard_negative_metamathqa.txt`` containing:

    * ``"0"``: the gold step,
    * ``"1"``: the sampled continuation with the lowest ranker score,
    * ``"2"``: a random step drawn from all collected gold rationales.

    Requires CUDA, the cached dataset, the ranker checkpoint and the local
    Llama-2 snapshot referenced below.
    """
    # Local import: this module's top-level imports do not include `os`.
    import os

    torch.manual_seed(42)
    random.seed(42)

    config = get_parser().parse_args()
    config.device = torch.device("cuda")
    config.generation_config = GenerationConfig(
        do_sample=True,
        temperature=1,
        top_p=0.95,
        max_new_tokens=512,
    )
    config.prompt = prompt

    sampling_params = SamplingParams(n=10, temperature=1, top_p=0.95, max_tokens=512)

    dataset = load_dataset("data/cache/meta-math___json/meta-math--MetaMathQA-b6af0a8ce3115a0e")["train"]
    datas = []
    all_rationales = []
    for sample in dataset:
        if not sample["type"].startswith("GSM"):
            continue
        # Drop the last line, then turn the "#### x" line into "Final Answer: x".
        response = sample["response"].split("\n")[:-1]
        if not response:
            # One-line responses have no step left; indexing [-1] would fail.
            continue
        final_rationale = response[-1].replace("####", "").strip()
        response[-1] = f"Final Answer: {final_rationale}"
        sample["response"] = response
        datas.append(sample)
        all_rationales.extend(response)

    # expanduser: a literal "~" is not a valid directory for path lookups.
    model_path = os.path.expanduser(
        "~/.cache/huggingface/hub/models--meta-llama--Llama-2-7b-chat-hf/snapshots/08751db2aca9bf2f7f80d2e516117a53d7450235/"
    )
    llm = LLM(model=model_path, seed=42)  # , dtype="float32"
    ranker = Ranker(config)
    checkpoint = torch.load("ckpts/gsm8k/gsm8k_ranking_multi_neg_margin.pt", map_location="cpu")
    checkpoint = checkpoint["model"]
    # Tolerate checkpoints that do not carry this buffer instead of raising KeyError.
    checkpoint.pop("model.embeddings.position_ids", None)
    ranker.load_state_dict(checkpoint)
    ranker.cuda()
    ranker.eval()

    output_path = "data/gsm8k/hard_negative_metamathqa.txt"
    for did, data in enumerate(tqdm(datas)):
        context = "\nQuestion:\n" + data["query"] + "\nAnswer:\n"
        for gold_step in data["response"]:
            outputs = llm.generate(prompt + context, sampling_params, use_tqdm=False)[0].outputs
            # Keep only the first line of each sampled continuation.
            generated_answers = [
                output.text.strip().split("\n")[0].strip() for output in outputs
            ]
            contexts = [context] * len(generated_answers)
            scores = ranker(contexts, generated_answers).cpu().squeeze()
            # Descending order: the last index is the lowest-scored sample.
            _, order = torch.sort(scores, descending=True)
            write_data = {
                "sid": did,
                "context": context.strip(),
                "answers": {
                    "0": gold_step,
                    "1": generated_answers[order[-1]],
                    "2": random.choice(all_rationales),
                }
            }
            # Append per step so partial progress survives an interruption.
            with open(output_path, "a+") as fout:
                fout.write(f"{json.dumps(write_data)}\n")
            # The gold step (not the sample) extends the context.
            context += gold_step + "\n"
class Verifier(nn.Module):
    """
    Encoder plus a small MLP head that scores a (context, rationale) pair.

    The head ends in a sigmoid, so ``forward`` returns one value in (0, 1)
    per pair (0 for neutral, 1 for entail).
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained(config.verifier_model_path)
        self.model = AutoModel.from_pretrained(config.verifier_model_path)

        self.decision_layers = nn.Sequential(
            nn.Linear(self.model.config.hidden_size, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, 1),  # 0 for neutral, 1 for entail
            nn.Sigmoid(),
        )

    def forward(self, contexts, rationales):
        """Score each ``(context, rationale)`` pair; returns a (batch, 1) tensor."""
        # Special tokens are written into the text, hence add_special_tokens=False.
        inputs = [f"[CLS]{c}[SEP]{r}[SEP]" for c, r in zip(contexts, rationales)]
        tokenized = self.tokenizer(
            inputs,
            add_special_tokens=False,
            max_length=self.config.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        input_ids = tokenized.input_ids.to(self.config.device)
        attention_mask = tokenized.attention_mask.to(self.config.device)
        # Hidden state of the first token is used as the pair representation.
        encoded = self.model(input_ids, attention_mask).last_hidden_state[:, 0, :]
        logits = self.decision_layers(encoded)
        return logits

    def forward_train(self, contexts, rationales, labels):
        """Return ``(scores, float_labels)`` with labels on the scores' device."""
        logits = self.forward(contexts, rationales).squeeze()
        return logits, torch.tensor(labels).float().to(logits.device)

    def train(self, contexts=True, rationales=None, labels=None, *, mode=None):
        """
        Dual-purpose ``train`` kept for backward compatibility.

        The original signature ``train(contexts, rationales, labels)`` shadowed
        ``nn.Module.train(mode)``, so ``model.eval()`` / ``model.train()`` raised
        ``TypeError``. Now:

        * ``train()``, ``train(True/False)`` or ``train(mode=...)`` behave like
          ``nn.Module.train`` and return ``self``;
        * ``train(contexts, rationales, labels)`` still returns
          ``(scores, float_labels)``; new code should call ``forward_train``.
        """
        if mode is not None:
            return super().train(mode)
        if rationales is None and labels is None:
            # Called the nn.Module way: the first positional is the mode flag.
            return super().train(contexts)
        return self.forward_train(contexts, rationales, labels)

    def from_pretrained(self, model_path):
        """Load encoder and head weights from a trainer checkpoint; returns ``self``."""
        checkpoint = torch.load(model_path, map_location="cpu")

        # Encoder weights: checkpoint keys carry a "model." prefix.
        pretrained_dict = checkpoint["model"]
        model_dict = self.model.state_dict()
        pretrained_dict = {k.replace("model.", ""): v for k, v in pretrained_dict.items() if k.replace("model.", "") in model_dict}
        self.model.load_state_dict(pretrained_dict)

        # Head weights: checkpoint keys carry a "decision_layers." prefix.
        pretrained_dict = checkpoint["model"]
        model_dict = self.decision_layers.state_dict()
        pretrained_dict = {k.replace("decision_layers.", ""): v for k, v in pretrained_dict.items() if k.replace("decision_layers.", "") in model_dict}
        self.decision_layers.load_state_dict(pretrained_dict)
        return self
def openai_setup(api_key=None):
    """
    Set the module-level OpenAI API key.

    A truthy ``api_key`` is used directly; otherwise the key is read from the
    ``openai_api_key`` environment variable (``KeyError`` if it is unset).
    """
    openai.api_key = api_key if api_key else os.environ['openai_api_key']
def get_parser():
    """
    Build the command-line parser shared by training, inference and data
    generation scripts.

    Returns the ``ArgumentParser``; callers invoke ``parse_args()`` themselves.
    """
    parser = ArgumentParser()
    parser.add_argument("--experiment_name", default="gsm8k")

    # data paths
    parser.add_argument("--train_datapath", default="data/gsm8k/train_ranking_small_full.json")
    parser.add_argument("--valid_datapath", default=None)
    parser.add_argument("--test_datapath", default=None)

    # checkpoints
    parser.add_argument("--save_dir", default="ckpts/gsm8k")
    parser.add_argument("--resume_path", default=None)

    # verifier model
    parser.add_argument("--verifier_model_path", default="data/cache/deberta/64a8c8eab3e352a784c658aef62be1662607476f")
    parser.add_argument("--hidden_size", type=int, default=1024)

    # optimisation
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--max_length", type=int, default=512)
    parser.add_argument("--dense_dim", type=int, default=768)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--num_epochs", type=int, default=10)
    parser.add_argument("--num_warmup_steps", type=int, default=200)
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
    parser.add_argument("--warmup_epochs", type=int, default=20)
    parser.add_argument("--lr_step_size", type=int, default=5)
    parser.add_argument("--lr_gamma", type=float, default=0.1)
    # float, not int: with type=int any fractional value given on the command
    # line (e.g. "--weight_decay 0.01") was rejected by argparse.
    parser.add_argument("--weight_decay", type=float, default=0.0005)

    # ranking margin schedule
    parser.add_argument("--min_margin", type=float, default=0.1)
    parser.add_argument("--max_margin", type=float, default=0.3)
    parser.add_argument("--margin_increase_step", type=int, default=2000)

    # device
    # NOTE: store_true combined with default=True means this flag is always True.
    parser.add_argument("--use_gpu", action="store_true", default=True)
    parser.add_argument("--device")

    # decoding / inference
    parser.add_argument("--test_dataset", type=str, choices=["gsm8k", "svamp", "aqua", "addsub", "singleeq", "multiarith", "strategyqa", "csqa", "llc", "coin"])
    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--decoder_path", type=str, default="meta-llama/Llama-2-7b-chat-hf")
    parser.add_argument("--num_sampling", type=int, default=1)
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--decode_strategy", type=str, default="direct")
    parser.add_argument("--shorten_context", action="store_true")
    parser.add_argument("--is_greedy", action="store_true", default=False)
    parser.add_argument("--num_gpus_decode", type=int, default=1)
    parser.add_argument("--output_dir", default="outputs/")

    return parser
38 | [ 39 | { 40 | "id": id, 41 | "question": question, 42 | } 43 | ] 44 | """ 45 | datas = [] 46 | if config.test_dataset == "gsm8k": 47 | dataset = load_dataset("data/cache/gsm8k", "main")["test"] 48 | for did, data in enumerate(dataset): 49 | datas.append({ 50 | "id": did, 51 | "question": data["question"], 52 | }) 53 | elif config.test_dataset == "svamp": 54 | dataset = load_dataset("data/cache/ChilleD___json/ChilleD--SVAMP-4bd8179a65d5f05b")["test"] 55 | for data in dataset: 56 | datas.append({ 57 | "id": data["ID"], 58 | "question": data["Body"] + " " + data["Question"], 59 | }) 60 | elif config.test_dataset == "aqua": 61 | dataset = [] 62 | with open("data/cache/AQuA/test.json", "r") as fin: 63 | for line in fin.readlines(): 64 | dataset.append(json.loads(line.strip())) 65 | # dataset = json.load(fin) 66 | for did, data in enumerate(dataset): 67 | datas.append({ 68 | "id": did, 69 | "question": data["question"], 70 | }) 71 | elif config.test_dataset == "addsub": 72 | dataset = load_dataset("data/cache/allenai___lila", "addsub")["test"] 73 | for did, data in enumerate(dataset): 74 | datas.append({ 75 | "id": did, 76 | "question": data["input"], 77 | }) 78 | elif config.test_dataset == "singleeq": 79 | data_dir = "data/cache/SingleEq" 80 | for i in range(5): 81 | with open(os.path.join(data_dir, f"test{i}")) as fin: 82 | for lid, line in enumerate(fin.readlines()): 83 | if lid % 3 == 0: 84 | datas.append({ 85 | "id": int(lid / 3), 86 | "question": line.strip(), 87 | }) 88 | elif config.test_dataset == "multiarith": 89 | dataset = load_dataset("data/cache/ChilleD___json/ChilleD--MultiArith-2e3d95e2a4ce9083")["test"] 90 | for did, data in enumerate(dataset): 91 | datas.append({ 92 | "id": did, 93 | "question": data["question"], 94 | }) 95 | elif config.test_dataset == "strategyqa": 96 | with open("data/cache/strategyqa/test_set.json", "r") as fin: 97 | dataset = json.load(fin) 98 | for data in dataset: 99 | datas.append({ 100 | "id": data["qid"], 101 | 
def _abbreviate_model_name(decoder_path):
    """Return the first MODEL_NAME tag found in the path, else the path's last component."""
    for name in MODEL_NAME:
        if name in decoder_path:
            # First match wins. Without stopping here, later iterations
            # overwrote the tag with the fallback name.
            return name
    return decoder_path.rstrip("/").split("/")[-1]


def _last_saved_sid(save_path):
    """Return the sample index stored on the last line of ``save_path``, or -1 if none."""
    if not os.path.exists(save_path):
        return -1
    with open(save_path, "r") as fin:
        lines = [line for line in fin if line.strip()]
    if not lines:
        # File exists but nothing was written yet: start from the beginning.
        return -1
    record = json.loads(lines[-1])
    if not record:
        return -1
    # Each line is a single-entry dict {sid: answer}; take its (last) key.
    return int(list(record)[-1])


@torch.inference_mode()
def main():
    """
    Run inference over the configured test dataset.

    Results are appended line by line to a file under ``config.output_dir`` so
    an interrupted run resumes after the last saved sample. For beam search a
    sibling ``.full`` file additionally receives all completed rationales.
    """
    torch.manual_seed(42)
    random.seed(42)

    parser = get_parser()
    config = parser.parse_args()

    config.device = torch.device("cuda")

    # load dataset, prompt, and decoder
    dataset, prompt = load_dataset_and_prompt(config)
    config.prompt = prompt
    if "gpt" in config.decoder_path:
        decoder = ChatGPTDecoder(config)
    else:
        decoder = FastDecoder(config)

    # name of the verifier checkpoint, used only in the output file name
    if config.resume_path is not None:
        # basename/splitext instead of split(".")[0], which truncated paths
        # such as "./ckpts/x.pt" or "ckpts/v1.2/x.pt" at the first dot.
        ckpt_name = os.path.splitext(os.path.basename(config.resume_path))[0]
    else:
        ckpt_name = "None"

    model_abbr_name = _abbreviate_model_name(config.decoder_path)

    # makedirs also handles nested --output_dir values whose parent is missing.
    config.output_dir = os.path.join(config.output_dir, model_abbr_name)
    os.makedirs(config.output_dir, exist_ok=True)
    save_path = os.path.join(config.output_dir, f"{config.test_dataset}_{model_abbr_name}_{ckpt_name}_{config.decode_strategy}_N{config.num_sampling}_B{config.num_beams}.txt")

    # load from previous infer checkpoint
    last_sid = _last_saved_sid(save_path)

    # inference
    progressbar = tqdm(range(len(dataset)))
    for sid, sample in enumerate(dataset):
        if sid <= last_sid:
            # already answered in a previous run
            progressbar.update(1)
            continue
        question = sample["question"]
        answers = decoder.decode_fn(question)

        # save inference checkpoint
        if "beam_search" in config.decode_strategy:
            beams, full_rationales = answers
            with open(save_path, "a+") as fout:
                fout.write(f"{json.dumps({sid: beams[0]})}\n")
            with open(f"{save_path}.full", "a+") as fout:
                fout.write(f"{json.dumps({sid: full_rationales})}\n")
        else:
            with open(save_path, "a+") as fout:
                fout.write(f"{json.dumps({sid: answers})}\n")
        progressbar.update(1)
transformers import get_linear_schedule_with_warmup 12 | 13 | from src.data_utils import collate_fn 14 | 15 | class RankingTrainer(): 16 | def __init__(self, 17 | config, 18 | model, 19 | train_dataset=None, 20 | valid_dataset=None, 21 | test_dataset=None,): 22 | self.config = config 23 | 24 | # device config 25 | if self.config.use_gpu: 26 | self.config.device = torch.device(config.device) 27 | else: 28 | self.config.device = torch.device("cpu") 29 | 30 | # logger config 31 | logging.basicConfig( 32 | level=logging.INFO, 33 | format="%(asctime)s [%(levelname)s] %(message)s", 34 | handlers=[ 35 | logging.FileHandler(f"logs/{self.config.experiment_name}.log", mode="w"), 36 | logging.StreamHandler() 37 | ] 38 | ) 39 | self.logger = logging.getLogger(__name__) 40 | 41 | # prepare data loader and target documents 42 | if train_dataset is None and test_dataset is None: 43 | self.logger.error("At least one dataset should be passed") 44 | raise FileNotFoundError 45 | if train_dataset is not None and valid_dataset is not None: 46 | self.train_dataloader = DataLoader( 47 | train_dataset, 48 | batch_size=config.batch_size, 49 | shuffle=True, 50 | num_workers=config.num_workers, 51 | collate_fn=collate_fn, 52 | drop_last=True, 53 | ) 54 | self.logger.info(f"Training set length: {len(train_dataset)}") 55 | self.val_dataloader = DataLoader( 56 | valid_dataset, 57 | batch_size=config.batch_size, 58 | shuffle=False, 59 | num_workers=config.num_workers, 60 | collate_fn=collate_fn, 61 | drop_last=True, 62 | ) 63 | self.logger.info(f"Validation set length: {len(valid_dataset)}") 64 | elif train_dataset is not None and valid_dataset is None: 65 | datasets = random_split(train_dataset, [int(0.9 * len(train_dataset)), len(train_dataset) - int(0.9 * len(train_dataset))]) 66 | # datasets = random_split(train_dataset, [int(0.001 * len(train_dataset)), int(0.001 * len(train_dataset)), len(train_dataset) - 2 * int(0.001 * len(train_dataset))]) 67 | train_dataset = datasets[0] 68 | 
valid_dataset = datasets[1] 69 | self.train_dataloader = DataLoader( 70 | train_dataset, 71 | batch_size=config.batch_size, 72 | shuffle=True, 73 | num_workers=config.num_workers, 74 | collate_fn=collate_fn, 75 | drop_last=True, 76 | ) 77 | self.logger.info(f"Training set length: {len(train_dataset)}") 78 | self.val_dataloader = DataLoader( 79 | valid_dataset, 80 | batch_size=config.batch_size, 81 | shuffle=False, 82 | num_workers=config.num_workers, 83 | collate_fn=collate_fn, 84 | drop_last=True, 85 | ) 86 | self.logger.info(f"Validation set length: {len(valid_dataset)}") 87 | if test_dataset is not None: 88 | self.test_dataloader = DataLoader( 89 | test_dataset, 90 | batch_size=1, 91 | shuffle=True, 92 | num_workers=config.num_workers, 93 | collate_fn=collate_fn, 94 | drop_last=True, 95 | ) 96 | self.logger.info(f"Test set length: {len(test_dataset)}") 97 | self.logger.info("Data init done.") 98 | 99 | # prepare model, preprocessor, loss, optimizer and scheduler 100 | self.model = model 101 | # self.margin = self.config.min_margin 102 | # self.loss = nn.MarginRankingLoss(margin=self.margin) 103 | self.loss = nn.MarginRankingLoss(margin=0.2) 104 | 105 | no_decay = ["bias", "LayerNorm.weight"] 106 | optimizer_grouped_parameters = [ 107 | { 108 | "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], 109 | "weight_decay": self.config.weight_decay, 110 | }, 111 | { 112 | "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], 113 | "weight_decay": 0.0, 114 | }, 115 | ] 116 | self.optimizer = AdamW( 117 | optimizer_grouped_parameters, 118 | lr=self.config.learning_rate, 119 | ) 120 | 121 | if config.resume_path is None: 122 | self.logger.warning("No checkpoint given!") 123 | self.best_val_loss = 10000 124 | self.best_accu = 0 125 | self.last_epoch = -1 126 | num_training_steps = self.config.num_epochs * len(self.train_dataloader) 127 | self.scheduler = get_linear_schedule_with_warmup( 128 | 
def train_one_epoch(self):
    """Run one pass over ``self.train_dataloader`` and return the mean batch loss.

    Every batch unpacks into ``(contexts, answers, false_answers)``. The model
    scores the correct and the corrupted answers for the same contexts and
    ``self.loss`` (a margin ranking loss with target ``1``) pushes the score of
    the correct answer above the corrupted one.

    Gradients are accumulated over ``config.gradient_accumulation_steps``
    batches; the optimizer and the scheduler advance once per window and once
    more for a trailing partial window.

    Returns:
        float: sum of the per-batch losses divided by the number of batches.

    Raises:
        RuntimeError: if the training dataloader yields no batch.
    """
    num_batches = len(self.train_dataloader)
    if num_batches == 0:
        # Previously this path ended in an UnboundLocalError on ``step``.
        raise RuntimeError(
            "Training dataloader produced no batches; "
            "check the dataset size against batch_size (drop_last=True)."
        )

    accumulation_steps = max(1, self.config.gradient_accumulation_steps)
    progress_bar = tqdm(range(math.ceil(num_batches / accumulation_steps)))
    margin_increase_step = self.config.margin_increase_step

    total_loss = 0.0
    for step, batch in enumerate(self.train_dataloader):
        contexts, answers, false_answers = batch
        # reshape(-1) instead of squeeze(): identical for batches larger than
        # one, but a batch of one stays 1-d so it still matches ``labels``.
        logits = self.model(contexts, answers).reshape(-1)
        false_logits = self.model(contexts, false_answers).reshape(-1)
        labels = torch.ones(len(contexts), device=self.config.device)
        loss = self.loss(logits, false_logits, labels)

        total_loss += loss.item()
        loss.backward()

        # Count completed batches (step + 1) so that the first update happens
        # after exactly ``accumulation_steps`` batches, not one batch late.
        window_full = (step + 1) % accumulation_steps == 0
        if window_full or step == num_batches - 1:
            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()
            progress_bar.set_postfix(
                {
                    'loss': total_loss / (step + 1),
                }
            )
            progress_bar.update(1)

        # Optional margin curriculum: widen the margin every
        # ``margin_increase_step`` batches until ``config.max_margin``.
        if margin_increase_step > 0 and step != 0 and step % margin_increase_step == 0:
            # ``self.margin`` may never have been initialised; fall back to the
            # margin of the loss currently in use instead of raising.
            current_margin = getattr(self, "margin", self.loss.margin)
            if current_margin < self.config.max_margin:
                self.margin = current_margin + 0.1
                self.loss = nn.MarginRankingLoss(self.margin)

    return total_loss / num_batches
class RankingMultipleNegativeTrainer():
    """Trainer that ranks one positive answer above three graded negatives.

    Each example carries an ``answers`` mapping with the keys ``"0"``
    (positive) and ``"1"``, ``"2"``, ``"3"`` (negatives). The positive score
    must exceed each negative score by a margin that grows with the key
    (0.3, 0.6, 0.9); the three margin ranking losses are summed.

    The checkpoint with the lowest validation loss is written to
    ``<config.save_dir>/<config.experiment_name>.pt``.
    """

    # (key inside the ``answers`` mapping, margin against the positive answer)
    NEGATIVE_MARGINS = (("1", 0.3), ("2", 0.6), ("3", 0.9))

    def __init__(self,
                 config,
                 model,
                 train_dataset=None,
                 valid_dataset=None,
                 test_dataset=None,):
        """Set up logging, dataloaders, optimizer and LR scheduler.

        Args:
            config: run configuration; the attributes read here are
                ``use_gpu``, ``device``, ``experiment_name``, ``batch_size``,
                ``num_workers``, ``weight_decay``, ``learning_rate``,
                ``resume_path``, ``num_epochs``, ``num_warmup_steps`` and
                ``gradient_accumulation_steps``. ``config.device`` is
                overwritten with a ``torch.device``.
            model: the scoring model, called as ``model(contexts, answers)``.
            train_dataset: training data; when given without
                ``valid_dataset`` it is split 90/10 at random.
            valid_dataset: optional validation data.
            test_dataset: optional test data.

        Raises:
            FileNotFoundError: if neither a training nor a test dataset is
                passed.
        """
        self.config = config

        # device config
        if self.config.use_gpu:
            self.config.device = torch.device(config.device)
        else:
            self.config.device = torch.device("cpu")

        # logger config; the FileHandler does not create missing directories
        os.makedirs("logs", exist_ok=True)
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s [%(levelname)s] %(message)s",
            handlers=[
                logging.FileHandler(f"logs/{self.config.experiment_name}.log", mode="w"),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)
        self.logger.info(self.config)

        # prepare data loaders
        if train_dataset is None and test_dataset is None:
            self.logger.error("At least one dataset should be passed")
            raise FileNotFoundError("At least one dataset should be passed")

        self.train_dataloader = None
        self.val_dataloader = None
        self.test_dataloader = None
        if train_dataset is not None:
            if valid_dataset is None:
                # No validation set supplied: hold out 10% of the training data.
                num_train = int(0.9 * len(train_dataset))
                train_dataset, valid_dataset = random_split(
                    train_dataset, [num_train, len(train_dataset) - num_train]
                )
            self.train_dataloader = self._build_loader(train_dataset, config.batch_size, shuffle=True)
            self.logger.info(f"Training set length: {len(train_dataset)}")
            self.val_dataloader = self._build_loader(valid_dataset, config.batch_size, shuffle=False)
            self.logger.info(f"Validation set length: {len(valid_dataset)}")
        if test_dataset is not None:
            self.test_dataloader = self._build_loader(test_dataset, 1, shuffle=True)
            self.logger.info(f"Test set length: {len(test_dataset)}")
        self.logger.info("Data init done.")

        # prepare model, optimizer and scheduler
        self.model = model

        # Biases and LayerNorm weights are excluded from weight decay.
        no_decay = ("bias", "LayerNorm.weight")
        decay_params, no_decay_params = [], []
        for name, param in self.model.named_parameters():
            if any(marker in name for marker in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)
        self.optimizer = AdamW(
            [
                {"params": decay_params, "weight_decay": self.config.weight_decay},
                {"params": no_decay_params, "weight_decay": 0.0},
            ],
            lr=self.config.learning_rate,
        )

        self.best_val_loss = 10000
        if config.resume_path is None:
            self.logger.warning("No checkpoint given!")
            self.last_epoch = -1
        else:
            # Warm start: only the model weights and the epoch counter are
            # restored; optimizer, scheduler and best loss start fresh.
            self.logger.info(f"Loading model from checkpoint: {self.config.resume_path}.")
            # NOTE(review): checkpoints written by train() also hold the config
            # object, so this relies on full unpickling; load trusted files only.
            checkpoint = torch.load(self.config.resume_path, map_location=self.config.device)
            self.model.to(self.config.device)
            self.model.load_state_dict(checkpoint["model"], strict=False)
            self.last_epoch = checkpoint["last_epoch"]
            del checkpoint

        if self.train_dataloader is not None:
            # The scheduler advances once per optimizer update, so its length
            # is counted in accumulation windows rather than in batches.
            num_training_steps = self.config.num_epochs * self._num_update_steps(
                len(self.train_dataloader), self.config.gradient_accumulation_steps
            )
            self.scheduler = get_linear_schedule_with_warmup(
                optimizer=self.optimizer,
                num_warmup_steps=self.config.num_warmup_steps,
                num_training_steps=num_training_steps,
                last_epoch=-1
            )
        else:
            # Test-only use: there is nothing to schedule.
            self.scheduler = None
        self.model.to(self.config.device)

        self.logger.info("Trainer init done.")

    def _build_loader(self, dataset, batch_size, shuffle):
        """Wrap *dataset* in a DataLoader with the shared worker/collate settings."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.config.num_workers,
            collate_fn=collate_fn,
            drop_last=True,
        )

    @staticmethod
    def _num_update_steps(num_batches, accumulation_steps):
        """Return how many optimizer updates *num_batches* batches produce."""
        return math.ceil(num_batches / max(1, accumulation_steps))

    @staticmethod
    def _is_update_step(step, num_batches, accumulation_steps):
        """Tell whether the optimizer should step after 0-based batch *step*.

        True when an accumulation window has just been filled or when *step*
        is the last batch (flushes a trailing partial window).
        """
        window_full = (step + 1) % max(1, accumulation_steps) == 0
        return window_full or step == num_batches - 1

    @classmethod
    def _multi_margin_loss(cls, positive_logits, negative_logits):
        """Sum the margin ranking losses of the positive against each negative.

        Args:
            positive_logits: scores of the positive answers.
            negative_logits: one score tensor per entry of
                ``NEGATIVE_MARGINS``, in the same order and with the same
                shape as ``positive_logits``.
        """
        # Target 1 means "the first input should rank higher".
        labels = torch.ones_like(positive_logits)
        loss = 0
        for (_, margin), logits in zip(cls.NEGATIVE_MARGINS, negative_logits):
            loss = loss + nn.MarginRankingLoss(margin)(positive_logits, logits, labels)
        return loss

    def _batch_loss(self, batch):
        """Score the positive and the three negatives of *batch*; return the loss."""
        contexts, answers = batch
        positive_logits = self.model(contexts, [answer["0"] for answer in answers])
        negative_logits = [
            self.model(contexts, [answer[key] for answer in answers])
            for key, _ in self.NEGATIVE_MARGINS
        ]
        return self._multi_margin_loss(positive_logits, negative_logits)

    def train_one_epoch(self):
        """Run one pass over the training data and return the mean batch loss.

        Gradients are accumulated over ``config.gradient_accumulation_steps``
        batches before the optimizer and the scheduler advance.

        Raises:
            RuntimeError: if the training dataloader yields no batch.
        """
        num_batches = len(self.train_dataloader)
        if num_batches == 0:
            raise RuntimeError(
                "Training dataloader produced no batches; "
                "check the dataset size against batch_size (drop_last=True)."
            )
        accumulation_steps = self.config.gradient_accumulation_steps
        progress_bar = tqdm(range(self._num_update_steps(num_batches, accumulation_steps)))

        total_loss = 0.0
        for step, batch in enumerate(self.train_dataloader):
            loss = self._batch_loss(batch)
            total_loss += loss.item()
            loss.backward()
            if self._is_update_step(step, num_batches, accumulation_steps):
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()
                progress_bar.set_postfix(
                    {
                        'loss': total_loss / (step + 1),
                    }
                )
                progress_bar.update(1)
        return total_loss / num_batches

    def train(self):
        """Train for ``config.num_epochs`` epochs, keeping the best checkpoint.

        After every epoch the validation loss is computed; whenever it improves
        on ``self.best_val_loss`` the model, optimizer and scheduler states are
        saved, overwriting the previous best checkpoint.
        """
        for epoch in range(self.config.num_epochs):
            self.logger.info("========================")
            self.logger.info("Training...")
            epoch_loss = self.train_one_epoch()
            self.logger.info(f"Epoch {epoch} training loss: {epoch_loss}")
            self.logger.info("Validating...")
            val_loss = self.validate()
            self.logger.info(f"Epoch {epoch} validation loss: {val_loss}")
            if val_loss >= self.best_val_loss:
                continue

            self.best_val_loss = val_loss
            checkpoint = {
                "model": self.model.state_dict(),
                # The misspelled key is kept: existing checkpoints use it.
                "optimizier": self.optimizer.state_dict(),
                "scheduler": self.scheduler.state_dict(),
                "last_epoch": self.last_epoch + epoch + 1,
                "best_val_loss": self.best_val_loss,
                "config": self.config,
            }
            if self.config.save_dir:
                os.makedirs(self.config.save_dir, exist_ok=True)
            save_path = os.path.join(self.config.save_dir, f"{self.config.experiment_name}.pt")
            self.logger.info(f"Saving best checkpoint to {save_path}")
            torch.save(checkpoint, save_path)

    @torch.no_grad()
    def validate(self):
        """Return the mean batch loss over the validation dataloader.

        The model is switched to eval mode for the duration of the pass and
        put back into whatever mode it was in before.

        Raises:
            RuntimeError: if the validation dataloader yields no batch.
        """
        was_training = self.model.training
        self.model.eval()
        total_loss = 0.0
        num_batches = 0
        try:
            for batch in tqdm(self.val_dataloader):
                total_loss += self._batch_loss(batch).item()
                num_batches += 1
        finally:
            self.model.train(was_training)

        if num_batches == 0:
            # Previously this path ended in an UnboundLocalError on ``step``.
            raise RuntimeError(
                "Validation dataloader produced no batches; "
                "check the dataset size against batch_size (drop_last=True)."
            )
        return total_loss / num_batches

    @torch.no_grad()
    def evaluate(self, logits, false_logits):
        """Count the pairs in which the positive score beats the negative one.

        Args:
            logits: scores of the correct answers.
            false_logits: scores of the corrupted answers, pairwise aligned.

        Returns:
            tuple: ``(number of pairs with logit > false logit, number of pairs)``.
        """
        cnt_true, cnt = 0, 0
        for score, false_score in zip(logits, false_logits):
            cnt += 1
            if score > false_score:
                cnt_true += 1
        return cnt_true, cnt

    @torch.no_grad()
    def metric(self, tp, tn, fp, fn):
        """Print and return accuracy, precision, recall and F1.

        Any ratio whose denominator is zero is reported as ``0``.

        Returns:
            tuple: ``(accuracy, precision, recall, f1)``.
        """
        def ratio(numerator, denominator):
            # Only the empty-denominator case is mapped to 0.
            try:
                return numerator / denominator
            except ZeroDivisionError:
                return 0

        accu = ratio(tp + tn, tp + tn + fp + fn)
        prec = ratio(tp, tp + fp)
        reca = ratio(tp, tp + fn)
        f1 = ratio(2 * prec * reca, prec + reca)
        print(f"Accuracy: {accu}")
        print(f"Preision: {prec}")
        print(f"Recall: {reca}")
        print(f"F1: {f1}")
        return accu, prec, reca, f1