├── .gitattributes
├── .gitignore
├── README.md
├── data
│   ├── stage1
│   │   └── train.json
│   └── stage2
│       ├── arithmetic
│       │   ├── train.split.1.json
│       │   └── train.split.2.json
│       └── commonsense
│           └── train.json
├── prompts
│   ├── __init__.py
│   ├── coin.py
│   ├── csqa.py
│   ├── gsm8k.py
│   └── strategyqa.py
├── requirements.txt
└── src
    ├── data_utils.py
    ├── decoder.py
    ├── generate_data.py
    ├── model.py
    ├── openai_api_mp.py
    ├── parser_utils.py
    ├── test.py
    ├── train.py
    └── trainer.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.json filter=lfs diff=lfs merge=lfs -text
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deductive Beam Search: Decoding Deducible Rationale for Chain-of-Thought Reasoning
2 |
3 |
4 |
5 |
6 | Source code of the paper [`Deductive Beam Search: Decoding Deducible Rationale for Chain-of-Thought Reasoning`](https://arxiv.org/abs/2401.17686).
7 |
8 | # Quick Start
9 | Clone repo:
10 | ```bash
11 | git clone https://github.com/OSU-NLP-Group/Deductive-Beam-Search.git
12 | cd Deductive-Beam-Search
13 | mkdir outputs logs  # training writes logs to logs/<experiment_name>.log
14 | ```
15 |
16 | Install requirements:
17 | ```bash
18 | conda create --name dbs python=3.10
19 | conda activate dbs
20 | pip install -r requirements.txt
21 | ```
22 |
23 | Training:
24 |
25 | - Stage 1:
26 |
27 | ```bash
28 | python src/train.py \
29 | --train_datapath your_data_path_stage1 \
30 | --experiment_name your_experiment_name_stage1 \
31 | --batch_size 8 --gradient_accumulation_steps 16 \
32 | --learning_rate 1e-5
33 | ```
34 |
35 | - Stage 2:
36 |
37 | ```bash
38 | python src/train.py \
39 | --train_datapath your_data_path_stage2 \
40 | --experiment_name your_experiment_name_stage2 \
41 | --batch_size 8 --gradient_accumulation_steps 16 \
42 | --learning_rate 1e-7
43 | ```
44 |
45 | Inference:
46 | ```bash
47 | DECODER_PATH=your_model_path
48 | DATASETS=("gsm8k" "svamp") # add datasets you want to test
49 |
50 | for TEST_DATASET in ${DATASETS[@]};
51 | do
52 | python src/test.py \
53 | --test_dataset $TEST_DATASET \
54 | --output_dir outputs/$TEST_DATASET \
55 | --decoder_path $DECODER_PATH \
56 | --decode_strategy beam_search \
57 | --num_beams 5 \
58 | --num_sampling 10 \
59 | --num_gpus_decode 4 # how many gpus for inference
60 | done
61 | ```
62 |
63 | # Data
64 |
65 | ## Training Data
66 |
67 | All synthesized training data is in the `data/` folder. The file `data/stage1/train.json` is used to train a general deductive verifier. For stage 2, the arithmetic and symbolic verifier is trained on `data/stage2/arithmetic/train.split.*.json`, and the commonsense verifier is trained on `data/stage2/commonsense/train.json`.
68 |
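`src/train.py` loads these files with `GSM8KRankingMultiNegativeDataset`, so each file is a JSON list of records pairing a reasoning context with one gold next step and several negatives. A minimal sketch of the record format, inferred from `src/data_utils.py` and `src/generate_data.py`; the example text below is illustrative rather than copied from the released files:

```python
# Illustrative record: "context" holds the question plus previously accepted steps,
# "answers" maps "0" to the gold next rationale and the remaining keys to negatives.
record = {
    "context": "John puts $25 in his piggy bank every month for 2 years ... <previously accepted steps>",
    "answers": {
        "0": "He saved money for 2 years, which is equal to 12 x 2 = 24 months.",  # gold next step
        "1": "He saved money for 2 years, which is equal to 12 x 4 = 48 months.",  # equation-perturbed negative
        "2": "The amount of money he saved is $25*48 = $1200.",                    # other false rationale
        "3": "Kate's hair is 14/2 = 7 inches long.",                               # step from a different question
    },
}
```
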
69 | ## Checkpoints
70 |
71 | We provide a checkpoint of the general deductive verifier; please download it from this [link](https://drive.google.com/drive/folders/1GbnAiX160Cz63zAbr2FAgB0QFySfM2Vn?usp=sharing).
72 | You can use it to continue training on our data or to train on your own data.
73 |
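The checkpoint stores the verifier weights under the `"model"` key and is loaded the same way `src/decoder.py` does it. Below is a minimal loading sketch; the checkpoint filename is a placeholder for whatever you downloaded, and `verifier_model_path` must point to the same backbone the checkpoint was trained from:

```python
import torch
from src.model import Ranker
from src.parser_utils import get_parser

config = get_parser().parse_args([])          # start from the repo's default hyperparameters
config.device = torch.device("cuda")
config.verifier_model_path = "path_or_name_of_the_verifier_backbone"  # placeholder

ranker = Ranker(config)
state = torch.load("ckpts/general_verifier.pt", map_location="cpu")["model"]  # placeholder path
ranker.load_state_dict(state, strict=False)  # strict=False tolerates the position_ids buffer difference handled in src/generate_data.py
ranker.to(config.device).eval()

# Score how deducible a candidate next step is from the question and previous steps.
score = ranker(["<question and previously accepted steps>"], ["<candidate next step>"])
```
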
74 | ## Data Generation
75 | The complete data construction process is in `src/generate_data.py`.
76 | If you want to generate data on your own, please modify the data loading part (see the sketch below).
77 | After modifying it, you can run `python src/generate_data.py` to generate data for your own domains.
78 |
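`generate_hard_negative()` only needs each sample to expose a `query` (the question) and a `response` (the list of gold reasoning steps, with the last entry rewritten as a `Final Answer: ...` line), plus a pool of all steps used as random negatives. A hedged sketch of a replacement loader, assuming your raw file is a JSON list with `question`, `rationales`, and `answer` fields; these names and the path are placeholders:

```python
import json

def load_my_domain(path="data/my_domain/train.json"):  # placeholder path
    """Return (datas, all_rationales) in the shape generate_hard_negative() expects."""
    with open(path, "r", encoding="utf-8") as fin:
        raw = json.load(fin)
    datas, all_rationales = [], []
    for item in raw:
        steps = list(item["rationales"])                # gold reasoning steps, one string per step
        steps[-1] = f"Final Answer: {item['answer']}"   # the last step must carry the final answer
        datas.append({"query": item["question"], "response": steps})
        all_rationales.extend(steps)                    # pool sampled from for random negatives
    return datas, all_rationales
```

Replace the MetaMathQA loading block in `generate_hard_negative()` with a call to this function (and swap in a prompt for your domain); the rest of the script follows the same flow.
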
79 | **\[TODO\]** We will improve the code for easier modification and usage.
80 |
81 | # Contact
82 |
83 | If you have any problems, please contact
84 | [Tinghui Zhu](mailto:darthzhu@gmail.com) and
85 | [Kai Zhang](mailto:zhang.13253@osu.edu).
86 |
87 | # Citation Information
88 |
89 | If you find our code and data useful, please consider citing our paper:
90 |
91 | ```
92 | @article{zhu2024deductive,
93 | title={Deductive Beam Search: Decoding Deducible Rationale for Chain-of-Thought Reasoning},
94 | author={Zhu, Tinghui and Zhang, Kai and Xie, Jian and Su, Yu},
95 | journal={arXiv preprint arXiv:2401.17686},
96 | year={2024}
97 | }
98 | ```
--------------------------------------------------------------------------------
/data/stage1/train.json:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:48c570fc3f18ee95499fa664d97fe7fbbca5ace9cf9c0383aa4e7c49eeb854f8
3 | size 16897210
4 |
--------------------------------------------------------------------------------
/data/stage2/arithmetic/train.split.1.json:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:458035c7327280eaa66536d8195507f4dac600735b41cf809f0c9bb93e1491db
3 | size 69144494
4 |
--------------------------------------------------------------------------------
/data/stage2/arithmetic/train.split.2.json:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:458035c7327280eaa66536d8195507f4dac600735b41cf809f0c9bb93e1491db
3 | size 69144494
4 |
--------------------------------------------------------------------------------
/data/stage2/commonsense/train.json:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:b42f1ca5739bf7517310776aeadc8335e2813631be5a676c7db3375cfa449888
3 | size 3443251
4 |
--------------------------------------------------------------------------------
/prompts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSU-NLP-Group/Deductive-Beam-Search/67ed155699cf3ae966d44793f92fbf10379bceca/prompts/__init__.py
--------------------------------------------------------------------------------
/prompts/coin.py:
--------------------------------------------------------------------------------
1 | prompt = """Given the question, output the rationale step by step and give the final answer.
2 |
3 | Example 1
4 | Question:
5 | A coin is heads up. sager does not flip the coin. zyheir flips the coin. Is the coin still heads up?
6 | Answer:
7 | sager does not flip the coin, so the coin is heads up.
8 | zyheir flips the coin, so the coin is tails up.
9 | Final Answer: no
10 |
11 | Example 2
12 | Question:
13 | A coin is heads up. mailey does not flip the coin. maurisa does not flip the coin. Is the coin still heads up?
14 | Answer:
15 | mailey does not flip the coin, so the coin is heads up.
16 | maurisa does not flip the coin, so the coin is heads up.
17 | Final Answer: yes
18 | """
--------------------------------------------------------------------------------
/prompts/csqa.py:
--------------------------------------------------------------------------------
1 | prompt = """Given the question, output the rationale step by step and give the final answer. You should choose the best answer.
2 |
3 | Example 1
4 | Question:
5 | Sammy wanted to go to where the people were. Where might he go?
6 | A. race track
7 | B. populated area
8 | C. the desert
9 | D. apartment
10 | E. roadblock
11 | Answer:
12 | Sammy wanted to go to places with many people.
13 | Race track and apartment do not have many people.
14 | The desert and roadblock have few people.
15 | And, the populated area means that it is the place with many people.
16 | Thus, Sammy should go to populated area.
17 | Final Answer: B
18 |
19 | Example 2
20 | Question:
21 | The fox walked from the city into the forest, what was it looking for?
22 | A. pretty flowers
23 | B. hen house
24 | C. natural habitat
25 | D. storybook
26 | E. dense forest
27 | Answer:
28 | The forest does not have hen house or storybook.
29 | The fox is a carnivore that does not look for flowers and forest.
30 | The forest is a natural habitat for foxes.
31 | Thus, it was looking for a natural habitat.
32 | Final Answer: C
33 | """
34 |
35 | recall_prompt = """Given the question, output the rationale step by step and give the final answer. You should choose the best answer.
36 |
37 | Example 1
38 | Question:
39 | Sammy wanted to go to where the people were. Where might he go?
40 | A. race track
41 | B. populated area
42 | C. the desert
43 | D. apartment
44 | E. roadblock
45 | Answer:
46 | Fact:
47 | Sammy wanted to go to places with many people.
48 | Race track and apartment do not have many people.
49 | The desert and roadblock have few people.
50 | And, the populated area means that it is the place with many people.
51 | Reasoning:
52 | Thus, Sammy should go to populated area.
53 | Final Answer: B
54 |
55 | Example 2
56 | Question:
57 | The fox walked from the city into the forest, what was it looking for?
58 | A. pretty flowers
59 | B. hen house
60 | C. natural habitat
61 | D. storybook
62 | E. dense forest
63 | Answer:
64 | Fact:
65 | The forest does not have hen house or storybook.
66 | The fox is a carnivore that does not look for flowers and forest.
67 | The forest is a natural habitat for foxes.
68 | Reasoning:
69 | Thus, it was looking for a natural habitat.
70 | Final Answer: C
71 | """
--------------------------------------------------------------------------------
/prompts/gsm8k.py:
--------------------------------------------------------------------------------
1 | prompt = """You are a good math solver. Based on the question, please give the rationales step by step and give a final answer.
2 |
3 | Example 1:
4 | Question:
5 | Kate's hair is half as long as Emily's hair. Emily's hair is 6 inches longer than Logan's hair. If Logan hair is 20 inches, how many inches is Kate's hair?
6 | Answer:
7 | Emily's hair is 20-6 = 14 inches long.
8 | Kate's hair 14/2= 7 inches long.
9 | Final Answer:7
10 |
11 | Example 2:
12 | Question:
13 | John puts $25 in his piggy bank every month for 2 years to save up for a vacation. He had to spend $400 from his piggy bank savings last week to repair his car. How many dollars are left in his piggy bank?
14 | Answer:
15 | He saved money for 2 years, which is equal to 12 x 2 = 24 months.
16 | The amount of money he saved is $25*24 = $600.
17 | But he spent some money so there is $600 - $400 = 200 left.
18 | Final Answer:200
19 | """
20 |
21 | """
22 | Example 3:
23 | Question:
24 | After complaints from the residents of Tatoosh about the number of cats on the island, the wildlife service carried out a relocation mission that saw the number of cats on the island drastically reduced. On the first relocation mission, 600 cats were relocated from the island to a neighboring island. On the second mission, half of the remaining cats were relocated to a rescue center inland. If the number of cats originally on the island was 1800, how many cats remained on the island after the rescue mission?
25 | Answer:
26 | After the first mission, the number of cats remaining on the island was 1800-600 = <<1800-600=1200>>1200.
27 | If half of the remaining cats on the island were relocated to a rescue center inland, the number of cats taken by the wildlife service on the second mission is 1200/2 = <<1200/2=600>>600 cats.
28 | The number of cats remaining on the island is 1200-600 = <<1200-600=600>>600
29 | Final Answer:600
30 |
31 | Example 4:
32 | Question:
33 | Paul, Amoura, and Ingrid were to go to a friend's party planned to start at 8:00 a.m. Paul arrived at 8:25. Amoura arrived 30 minutes later than Paul, and Ingrid was three times later than Amoura. How late, in minutes, was Ingrid to the party?
34 | Answer:
35 | If the party was planned to start at 8:00 am, Paul was 8:25-8:00 = <<825-800=25>>25 minutes late.
36 | If Paul was 25 minutes late, Amoura was 25+30 = <<25+30=55>>55 minutes late.
37 | If Ingrid was three times late than Amoura was, she was 3*55 = <<3*55=165>>165 minutes late
38 | Final Answer:165
39 |
40 | Example 5:
41 | Question:
42 | Alicia has to buy some books for the new school year. She buys 2 math books, 3 art books, and 6 science books, for a total of $30. If both the math and science books cost $3 each, what was the cost of each art book?
43 | Answer:
44 | The total cost of maths books is 2*3 = <<2*3=6>>6 dollars
45 | The total cost of science books is 6*3 = <<6*3=18>>18 dollars
46 | The total cost for maths and science books is 6+18 = <<6+18=24>>24 dollars
47 | The cost for art books is 30-24 = <<30-24=6>>6 dollars.
48 | Since he bought 3 art books, the cost for each art book will be 6/3 = 2 dollars
49 | Final Answer:2
50 | """
--------------------------------------------------------------------------------
/prompts/strategyqa.py:
--------------------------------------------------------------------------------
1 | prompt = """Given the question, output the rationale step by step and give the final answer (yes or no).
2 |
3 | Example 1
4 | Question:
5 | Do hamsters provide food for any animals?
6 | Answer:
7 | Hamsters are prey animals.
8 | Prey are food for predators.
9 | Final answer: yes
10 |
11 | Example 2
12 | Question:
13 | Could a llama birth twice during War in Vietnam (1945-46)?
14 | Answer:
15 | The War in Vietnam was 6 months.
16 | The gestation period for a llama is 11 months, which is more than 6 months.
17 | Final answer: no
18 | """
19 |
20 | recall_prompt = """Given the question, output the rationale step by step and give the final answer (yes or no).
21 |
22 | Example 1
23 | Question:
24 | Do hamsters provide food for any animals?
25 | Answer:
26 | Fact:
27 | Hamsters are prey animals.
28 | Prey are food for predators.
29 | Reasoning:
30 | Hamsters are food for some predators.
31 | Final answer: yes
32 |
33 | Example 2
34 | Question:
35 | Could a llama birth twice during War in Vietnam (1945-46)?
36 | Answer:
37 | Fact:
38 | The War in Vietnam was 6 months.
39 | The gestation period for a llama is 11 months, which is more than 6 months.
40 | Reasoning:
41 | A llama could not birth twice during War in Vietnam.
42 | Final answer: no
43 | """
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.21.0
2 | datasets==2.13.1
3 | multiprocess==0.70.14
4 | openai==0.28.0
5 | tiktoken==0.4.0
6 | tokenizers==0.13.3
7 | torch==2.0.1
8 | torchaudio==2.0.2
9 | torchvision==0.15.2
10 | transformers==4.33.2
11 | vllm==0.1.4
--------------------------------------------------------------------------------
/src/data_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | from torch.utils.data import Dataset
3 |
4 | def collate_fn(batch):
5 | return tuple(zip(*batch))
6 |
7 | class GSM8KDataset(Dataset):
8 | def __init__(self, datapath, mode="train") -> None:
9 | if mode == "train":
10 | with open(datapath, "r") as fin:
11 | self.datas = json.load(fin)
12 |
13 | def __len__(self):
14 | return len(self.datas)
15 |
16 | def __getitem__(self, index):
17 | """
18 | {
19 | "context": ,
20 | "answer": ,
21 | "label": 1 for entail; 0 for neutral
22 | }
23 | """
24 | data = self.datas[index]
25 | return data["context"], data["answer"], data["label"]
26 |
27 | class GSM8KRankingDataset(Dataset):
28 | def __init__(self, datapath, mode="train") -> None:
29 | if mode == "train":
30 | with open(datapath, "r") as fin:
31 | self.datas = json.load(fin)
32 |
33 | def __len__(self):
34 | return len(self.datas)
35 |
36 | def __getitem__(self, index):
37 | """
38 | {
39 | "context": ,
40 | "answer": ,
41 | "false_answer": a negative (incorrect) rationale for the same context
42 | }
43 | """
44 | data = self.datas[index]
45 | return data["context"], data["answer"], data["false_answer"]
46 |
47 | class GSM8KRankingMultiNegativeDataset(Dataset):
48 | def __init__(self, datapath, mode="train") -> None:
49 | if mode == "train":
50 | with open(datapath, "r") as fin:
51 | self.datas = json.load(fin)
52 |
53 | def __len__(self):
54 | return len(self.datas)
55 |
56 | def __getitem__(self, index):
57 | """
58 | {
59 | "context": ,
60 | "answers": dict keyed "0"-"3"; "0" is the gold rationale, the rest are negatives
61 | }
62 | """
63 | data = self.datas[index]
64 | return data["context"], data["answers"]
--------------------------------------------------------------------------------
/src/decoder.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 | import torch.nn as nn
4 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
5 | from vllm import LLM, SamplingParams
6 | from collections import Counter
7 |
8 | from src.model import Ranker
9 | from src.openai_api_mp import ChatGPTPool
10 |
11 | class Decoder():
12 | def __init__(self, config) -> None:
13 | self.device = config.device
14 | self.prompt = config.prompt
15 | self.num_sampling = config.num_sampling
16 | self.num_beams = config.num_beams
17 |
18 | if config.decode_strategy == "rank":
19 | self.ranker = Ranker(config)
20 | # self.ranker.from_pretrained(config.resume_path)
21 | checkpoint = torch.load(config.resume_path, map_location="cpu")
22 | checkpoint = checkpoint["model"]
23 | self.ranker.load_state_dict(checkpoint)
24 | self.ranker.to(self.device)
25 | self.decode_fn = self.decode_with_ranking
26 | self.stop = "\n"
27 | elif config.decode_strategy == "beam_search":
28 | self.ranker = Ranker(config)
29 | # self.ranker.from_pretrained(config.resume_path)
30 | checkpoint = torch.load(config.resume_path, map_location="cpu")
31 | checkpoint = checkpoint["model"]
32 | self.ranker.load_state_dict(checkpoint)
33 | self.ranker.to(self.device)
34 | self.decode_fn = self.beam_search
35 | self.stop = "\n"
36 | elif config.decode_strategy == "beam_search_with_fact":
37 | self.ranker = Ranker(config)
38 | # self.ranker.from_pretrained(config.resume_path)
39 | checkpoint = torch.load(config.resume_path, map_location="cpu")
40 | checkpoint = checkpoint["model"]
41 | self.ranker.load_state_dict(checkpoint)
42 | self.ranker.to(self.device)
43 | self.decode_fn = self.beam_search_with_fact
44 | self.stop = "\n"
45 | else:
46 | self.decode_fn = self.decode_direct
47 | self.stop = None
48 |
49 | if "gpt" in config.decoder_path:
50 | self.model = ChatGPTPool(
51 | num_sampling = config.num_sampling,
52 | temperature = 1.0,
53 | max_tokens = 256,
54 | stop = self.stop,
55 | )
56 | else:
57 | if config.is_greedy:
58 | self.sampling_params = SamplingParams(n=1, temperature=0, top_p=1, max_tokens=512)
59 | else:
60 | self.sampling_params = SamplingParams(n=self.num_sampling, temperature=1, top_p=0.95, max_tokens=512, stop=self.stop)
61 | self.model = LLM(model=config.decoder_path, seed=42, max_num_batched_tokens=4096, tensor_parallel_size=config.num_gpus_decode, gpu_memory_utilization=0.85) # , dtype="float32" , tensor_parallel_size=2
62 |
63 | # TODO: batch inference
64 | def decode(self, question):
65 | return self.decode_fn(question)
66 |
67 | @torch.inference_mode()
68 | def decode_with_ranking(self, question):
69 | raise NotImplementedError
70 |
71 | @torch.inference_mode()
72 | def decode_direct(self, question):
73 | raise NotImplementedError
74 |
75 | @torch.inference_mode()
76 | def beam_search(self, question):
77 | raise NotImplementedError
78 |
79 | @torch.inference_mode()
80 | def beam_search_with_fact(self, question):
81 | raise NotImplementedError
82 |
83 | class FastDecoder(Decoder):
84 | def __init__(self, config) -> None:
85 | super().__init__(config)
86 |
87 | @torch.inference_mode()
88 | def decode_with_ranking(self, question):
89 | answers = []
90 | context = "\nQuestion:\n" + question + "\nAnswer:\n"
91 | terminated = False
92 | step = 0
93 | rationales = []
94 | while not terminated and step < 20:
95 | outputs = self.model.generate(self.prompt + context, self.sampling_params, use_tqdm=False)[0].outputs
96 | answers = []
97 | for output in outputs:
98 | answers.append(output.text.strip().split("\n")[0].strip())
99 | contexts = [context] * len(answers)
100 | scores = self.ranker(contexts, answers).cpu().squeeze()
101 | sorted, indices = torch.sort(scores, descending=True)
102 | rationale = answers[indices[0]]
103 | rationales.append(rationale)
104 | context += f"{rationale}\n"
105 | if "Final Answer" in rationale:
106 | terminated = True
107 | step += 1
108 | return rationales
109 |
110 | @torch.inference_mode()
111 | def decode_direct(self, question):
112 | context = "\nQuestion:\n" + question + "\nAnswer:\n"
113 | outputs = self.model.generate(self.prompt + context, self.sampling_params, use_tqdm=False)[0].outputs
114 | answers = []
115 | for output in outputs:
116 | answers.append(output.text.strip())
117 | return answers
118 |
119 | @torch.inference_mode()
120 | def beam_search(self, question):
121 | self.sampling_params = SamplingParams(n=self.num_beams, temperature=1, top_p=0.95, max_tokens=512)
122 | answers = []
123 | context = "\nQuestion:\n" + question + "\nAnswer:\n"
124 | global_terminated = False
125 | step = 0
126 | beams = [(context, 1, False)] * self.num_beams # (current rationale, score, terminated)
127 | completed_rationales = []
128 | while not global_terminated and step < 20:
129 | current_beams = []
130 | for beam, score, terminated in beams:
131 | # if terminated, leave it alone
132 | if terminated:
133 | current_beams.append((beam, score, terminated))
134 | continue
135 |
136 | # otherwise, generate next rationale
137 | outputs = self.model.generate(self.prompt + beam, self.sampling_params, use_tqdm=False)[0].outputs
138 | answers = []
139 | for output in outputs:
140 | answers.append(output.text.strip().split("\n")[0].strip())
141 | contexts_for_ranker = [beam] * len(answers)
142 | scores = self.ranker(contexts_for_ranker, answers).cpu().squeeze()
143 | sorted_scores, indices = torch.sort(scores, descending=True)
144 |
145 | # calculate current score
146 | for _ in range(self.num_beams):
147 | current_beam = beam + answers[indices[_]] + "\n"
148 | current_score = score * scores[indices[_]]
149 | # if terminated, add to completed rationales
150 | if "Final Answer" in answers[indices[_]] or "Final answer" in answers[indices[_]]:
151 | terminated = True
152 | completed_rationales.append((current_beam, current_score.item()))
153 | current_beams.append((current_beam, current_score.item(), terminated))
154 | sorted_beams = sorted(current_beams, key=lambda x: x[1], reverse=True)
155 | beams = sorted_beams[:self.num_beams]
156 | flag = True
157 | for _ , _, terminated in beams:
158 | if not terminated:
159 | flag = False
160 | break
161 | global_terminated = flag
162 | step += 1
163 |
164 | return beams, completed_rationales
165 |
166 | @torch.inference_mode()
167 | def beam_search_with_fact(self, question):
168 | answers = []
169 | context = "\nQuestion:\n" + question + "\nAnswer:\n"
170 | if_verify = False
171 | step = 0
172 | while not if_verify and step < 10:
173 | fact_sampling_params = SamplingParams(n=1, temperature=0, top_p=1, max_tokens=512, stop="\n")
174 | outputs = self.model.generate(self.prompt + context, fact_sampling_params, use_tqdm=False)[0].outputs
175 | context += outputs[0].text
176 | context += "\n"
177 | if "Reasoning" in outputs[0].text:
178 | if_verify = True
179 | step += 1
180 |
181 | global_terminated = False
182 | beams = [(context, 1, False)] * self.num_beams # (current rationale, score, terminated)
183 | completed_rationales = []
184 |
185 | while not global_terminated and step < 10:
186 | current_beams = []
187 | for beam, score, terminated in beams:
188 | # if terminated, leave it alone
189 | if terminated:
190 | current_beams.append((beam, score, terminated))
191 | continue
192 |
193 | # otherwise, generate next rationale
194 | outputs = self.model.generate(self.prompt + beam, self.sampling_params, use_tqdm=False)[0].outputs
195 | answers = []
196 | for output in outputs:
197 | answers.append(output.text.strip().split("\n")[0].strip())
198 | contexts_for_ranker = [beam] * len(answers)
199 | scores = self.ranker(contexts_for_ranker, answers).cpu().squeeze()
200 | sorted_scores, indices = torch.sort(scores, descending=True)
201 |
202 | # calculate current score
203 | for _ in range(self.num_beams):
204 | current_beam = beam + answers[indices[_]] + "\n"
205 | current_score = score * scores[indices[_]]
206 | # if terminated, add to completed rationales
207 | if "Final Answer" in answers[indices[_]] or "Final answer" in answers[indices[_]]:
208 | terminated = True
209 | completed_rationales.append((current_beam, current_score.item()))
210 | current_beams.append((current_beam, current_score.item(), terminated))
211 |
212 | sorted_beams = sorted(current_beams, key=lambda x: x[1], reverse=True)
213 | beams = sorted_beams[:self.num_beams]
214 | flag = True
215 | for _ , _, terminated in beams:
216 | if not terminated:
217 | flag = False
218 | break
219 | global_terminated = flag
220 | step += 1
221 |
222 | return beams, completed_rationales
223 |
224 | class ChatGPTDecoder(Decoder):
225 | def __init__(self, config) -> None:
226 | super().__init__(config)
227 |
228 | @torch.inference_mode()
229 | def decode_direct(self, question):
230 | context = "\nQuestion:\n" + question + "\nAnswer:\n"
231 | outputs = self.model.chat_single_round(self.prompt + context)
232 | answers = []
233 | for output in outputs:
234 | answers.append(output.strip())
235 | return answers
236 |
237 | @torch.inference_mode()
238 | def beam_search(self, question):
239 | answers = []
240 | context = "\nQuestion:\n" + question + "\nAnswer:\n"
241 | global_terminated = False
242 | step = 0
243 | beams = [(context, 1, False)] * self.num_beams # (current rationale, score, terminated)
244 | completed_rationales = []
245 | while not global_terminated and step < 50:
246 | current_beams = []
247 | for beam, score, terminated in beams:
248 | # if terminated, leave it alone
249 | if terminated:
250 | current_beams.append((beam, score, terminated))
251 | continue
252 |
253 | # otherwise, generate next rationale
254 | outputs = self.model.chat_single_round(self.prompt + beam)
255 | answers = []
256 | for output in outputs:
257 | answer = output.strip().split("\n")[0].strip()
258 | if len(answer) < 1:
259 | continue
260 | answers.append(answer)
261 | contexts_for_ranker = [beam] * len(answers)
262 | scores = self.ranker(contexts_for_ranker, answers).cpu().squeeze()
263 | sorted_scores, indices = torch.sort(scores, descending=True)
264 |
265 | # calculate current score
266 | for _ in range(min(len(answers), self.num_beams)):
267 | current_beam = beam + answers[indices[_]] + "\n"
268 | current_score = score * scores[indices[_]]
269 | # if terminated, add to completed rationales
270 | if "Final Answer" in answers[indices[_]] or "Final answer" in answers[indices[_]]:
271 | terminated = True
272 | completed_rationales.append((current_beam, current_score.item()))
273 | current_beams.append((current_beam, current_score.item(), terminated))
274 | sorted_beams = sorted(current_beams, key=lambda x: x[1], reverse=True)
275 | beams = sorted_beams[:self.num_beams]
276 | flag = True
277 | for _ , _, terminated in beams:
278 | if not terminated:
279 | flag = False
280 | break
281 | global_terminated = flag
282 | step += 1
283 | return (beams, completed_rationales)
284 |
285 | @torch.inference_mode()
286 | def beam_search_with_fact(self, question):
287 | answers = []
288 | context = "\nQuestion:\n" + question + "\nAnswer:\n"
289 | if_verify = False
290 | step = 0
291 | self.model.num_sampling = 1
292 | while not if_verify and step < 10:
293 | self.model.num_sampling = 1
294 | outputs = self.model.chat_single_round(self.prompt + context)
295 | context += outputs[0]
296 | context += "\n"
297 | if "Reasoning" in outputs[0]:
298 | if_verify = True
299 | step += 1  # loop-level counter update (as in FastDecoder) so this loop cannot run forever
300 | self.model.num_sampling = self.num_sampling
301 | global_terminated = False
302 | beams = [(context, 1, False)] * self.num_beams # (current rationale, score, terminated)
303 | completed_rationales = []
304 | while not global_terminated and step < 20:
305 | current_beams = []
306 | for beam, score, terminated in beams:
307 | # if terminated, leave it alone
308 | if terminated:
309 | current_beams.append((beam, score, terminated))
310 | continue
311 |
312 | # otherwise, generate next rationale
313 | outputs = self.model.chat_single_round(self.prompt + beam)
314 | answers = []
315 | for output in outputs:
316 | answers.append(output.strip().split("\n")[0].strip())
317 | contexts_for_ranker = [beam] * len(answers)
318 | scores = self.ranker(contexts_for_ranker, answers).cpu().squeeze()
319 | sorted_scores, indices = torch.sort(torch.tensor(scores), descending=True)
320 |
321 | # calculate current score
322 | for _ in range(self.num_beams):
323 | current_beam = beam + answers[indices[_]] + "\n"
324 | current_score = score * scores[indices[_]]
325 | # if terminated, add to completed rationales
326 | if "Final Answer" in answers[indices[_]] or "Final answer" in answers[indices[_]]:
327 | terminated = True
328 | completed_rationales.append((current_beam, current_score.item()))
329 | current_beams.append((current_beam, current_score.item(), terminated))
330 | sorted_beams = sorted(current_beams, key=lambda x: x[1], reverse=True)
331 | beams = sorted_beams[:self.num_beams]
332 | flag = True
333 | for _ , _, terminated in beams:
334 | if not terminated:
335 | flag = False
336 | break
337 | global_terminated = flag
338 | step += 1
339 | return (beams, completed_rationales)
--------------------------------------------------------------------------------
/src/generate_data.py:
--------------------------------------------------------------------------------
1 | import re
2 | import json
3 | import torch
4 | import random
5 | import argparse
6 | from tqdm import tqdm
7 | from datasets import load_dataset
8 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, GenerationConfig
9 | from vllm import LLM, SamplingParams
10 |
11 | from src.decoder import Decoder
12 | from src.parser_utils import get_parser
13 | from prompts.gsm8k import prompt
14 | from src.model import Ranker
15 |
16 | def generate_correct_rationales():
17 | tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
18 | model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16).to(torch.device("cuda:0"))
19 | datasets = load_dataset("gsm8k", "main")
20 | dataset = datasets["train"]
21 | pattern = r"<<(.*?)>>"
22 | datas = []
23 | prompt = "Please rephrase the following sentence and keep the equation between <<>> unchanged. For example:\n{example}"
24 | # model = ChatGPTAPI(prompt)
25 | count_done = 0
26 | with open("data/gsm8k/gsm8k.txt", "r", encoding="utf-8") as fin:
27 | for _ in fin.readlines():
28 | count_done += 1
29 | sid = 0
30 | for sample in tqdm(dataset):
31 | if sid < count_done:
32 | sid += 1
33 | continue
34 | # model.reset()
35 | question = sample["question"]
36 | gold_answer = sample["answer"]
37 | final_answer = gold_answer.split("####")[-1].strip()
38 | # equations = re.findall(pattern, gold_answer)
39 | rationales = gold_answer.split("####")[0].split("\n")[:-1]
40 | diverse_rationales = []
41 | for r in rationales:
42 | model.reset()
43 | # print(r)
44 | rs = [r]
45 | state = f"{prompt}\n{r}"
46 | for _ in range(4):
47 | answer = model.chat_multi_round(state)
48 | # print(answer)
49 | rs.append(answer)
50 | state = None
51 | diverse_rationales.append(rs)
52 | write_data = {
53 | "question": question,
54 | "rationales": diverse_rationales,
55 | "chain_length": len(rationales),
56 | "final_answer": final_answer,
57 | }
58 | with open("data/gsm8k/gsm8k.txt", "a+", encoding="utf-8") as fout:
59 | fout.write(f"{json.dumps(write_data)}\n")
60 |
61 | def confuse(rationales):
62 | """
63 | Generate false samples for given rationales
64 | """
65 |
66 | pattern = r"([\$A-Za-z0-9\%\.]+\s*[\+\-\*\/x])+\s*[\$A-Za-z0-9\%\.]+\s*(=\s*[\$A-Za-z0-9\%\.]+)"
67 | number_set = set()
68 | for rs in rationales:
69 | rationale = rs[0]
70 | if rationale.endswith("."):
71 | rationale = rationale[:-1]
72 | equation = re.search(pattern, rationale)
73 | if equation is None:
74 | continue
75 | equation = equation.group()
76 | eq = equation.split("=")[0].strip()
77 | elements = re.split(r"[\+\-\/\*x]", eq)
78 | for e in elements:
79 | number_set.add(e.strip())
80 | number_set.add(equation.split("=")[-1].strip())
81 | number_set = list(number_set)
82 | print(number_set)
83 | false_rationales = {}
84 | for rid, rs in enumerate(rationales):
85 | f_rs = []
86 | for r in rs:
87 | print(r)
88 | if r.endswith("."):
89 | r = r[:-1]
90 | equation = re.search(pattern, r)
91 | # get some following rationales
92 | following_rationales = []
93 | for fid, tmp_rs in enumerate(rationales):
94 | if fid <= rid:
95 | continue
96 | following_rationales.extend(tmp_rs)
97 | print(len(following_rationales))
98 | if equation is None: # if there is no equation, continue
99 | continue
100 | # if there is an equation, replace some elements in it
101 | e_start = equation.start()
102 | e_end = equation.end()
103 | equation = equation.group()
104 | eq = equation.split("=")[0].strip()
105 | symbols = re.findall(r"[\+\-\/\*x]", eq)
106 | elements = re.split(r"[\+\-\/\*x]", eq)
107 | elements = [ele.strip() for ele in elements]
108 | count = 0
109 | while count < 10:
110 | try:
111 | replaced_index = random.choice(range(len(elements)))
112 | print(elements[replaced_index])
113 | tmp_number_set = [n for n in number_set]
114 | tmp_number_set.remove(elements[replaced_index])
115 | elements[replaced_index] = random.choice(tmp_number_set)
116 | false_eq = f"{elements[0]}"
117 | for sid, symbol in enumerate(symbols):
118 | if symbol == "x":
119 | symbol = "*"
120 | false_eq += f"{symbol}{elements[sid + 1]}"
121 | try:
122 | print(false_eq)
123 | value = eval(false_eq)
124 | f_r = r[:e_start - 1] + f" {false_eq} = {value} " + r[e_end:]
125 | f_rs.append(f_r)
126 | except:
127 | print(false_eq)
128 | value = equation.split("=")[-1].strip()
129 | f_r = r[:e_start - 1] + f" {false_eq} = {value} " + r[e_end:]
130 | f_rs.append(f_r)
131 | break
132 | except:
133 | count += 1
134 | continue
135 | false_rationales[rid] = f_rs
136 | return false_rationales
137 |
138 | def generate_ranking_negative_samples():
139 | """
140 | 1. gold answer
141 | 2. gold answer but with one number changed
142 | 3. similar answer with one number changed
143 | 4. other rationales from a different question
144 | """
145 | def get_some_random_samples(excluded_id):
146 | id_list = list(range(len(id2false_rationales)))
147 | id_list.remove(excluded_id)
148 | id_list = random.sample(id_list, 1)
149 | selected_false_rationales = []
150 | for sid in id_list:
151 | false_rationales = None
152 | while false_rationales is None or len(false_rationales) == 0:
153 | false_rationales = id2false_rationales.get(sid)
154 | sid += 1
155 | s_false_rationales = []
156 | for v in false_rationales.values():
157 | s_false_rationales.extend(v)
158 | selected_false_rationales.append(random.choice(s_false_rationales))
159 | return selected_false_rationales
160 |
161 |
162 | id2false_rationales = {}
163 | with open("data/gsm8k/false_rationales.txt", "r") as fin:
164 | for line in fin.readlines():
165 | data = json.loads(line.strip())
166 | id2false_rationales.update({data["id"]: data["false_rationales"]})
167 | samples = {}
168 | with open("data/gsm8k/gsm8k_clean.txt", "r") as fin:
169 | for line in fin.readlines():
170 | data = json.loads(line.strip())
171 | metadata = {**data}
172 | samples.update({data["id"]: metadata})
173 | with open("data/gsm8k/train_short.json", "r") as fin:
174 | shortened_contexts = json.load(fin)
175 | answer2context = {}
176 | for sample in shortened_contexts:
177 | answer2context.update({sample["answer"]: sample["context"]})
178 | datasets = load_dataset("gsm8k", "main")
179 | dataset = datasets["train"]
180 | write_datas = []
181 | for sid, sample in enumerate(dataset):
182 | false_rationales = id2false_rationales[sid]
183 | correct_metadata = samples[sid]
184 | correct_rationales = correct_metadata["rationales"]
185 | chain_length = correct_metadata["chain_length"]
186 | context = correct_metadata["question"]
187 | for i in range(chain_length):
188 | i_correct_rationales = correct_rationales[i]
189 | i_false_rationales = false_rationales.get(str(i))
190 | if i_false_rationales is None or len(i_false_rationales) <= 1:
191 | continue
192 | cr = i_correct_rationales[0]
193 | answers = {
194 | 0: cr,
195 | 1: i_false_rationales[0],
196 | 2: random.choice(i_false_rationales[1:]),
197 | 3: get_some_random_samples(sid)[0]
198 | }
199 | write_datas.append({
200 | "context": context,
201 | "answers": answers,
202 | })
203 | context += f" {cr}"
204 | print(len(write_datas)) # 111280
205 | with open("data/gsm8k/train_ranking_small_full_multi_negative.json", "w") as fout:
206 | json.dump(write_datas, fout, indent=1, ensure_ascii=False)
207 |
208 |
209 | @torch.inference_mode()
210 | def generate_hard_negative():
211 | torch.manual_seed(42)
212 | random.seed(42)
213 |
214 | parser = get_parser()
215 | config = parser.parse_args()
216 |
217 | config.device = torch.device("cuda")
218 | config.generation_config = GenerationConfig(
219 | do_sample=True,
220 | temperature=1,
221 | top_p=0.95,
222 | max_new_tokens=512,
223 | )
224 | config.prompt = prompt
225 |
226 | sampling_params = SamplingParams(n=10, temperature=1, top_p=0.95, max_tokens=512)
227 |
228 | # decoder = Decoder(config)
229 |
230 | dataset = load_dataset("data/cache/meta-math___json/meta-math--MetaMathQA-b6af0a8ce3115a0e")["train"]
231 | datas = []
232 | all_rationales = []
233 | for sample in dataset:
234 | if sample["type"].startswith("GSM"):
235 | response = sample["response"].split("\n")[:-1]
236 | final_rationale = response[-1].replace("####", "").strip()
237 | response[-1] = f"Final Answer: {final_rationale}"
238 | sample["response"] = response
239 | datas.append(sample)
240 | all_rationales.extend(response)
241 |
242 | llm = LLM(model="~/.cache/huggingface/hub/models--meta-llama--Llama-2-7b-chat-hf/snapshots/08751db2aca9bf2f7f80d2e516117a53d7450235/", seed=42) # , dtype="float32"
243 | ranker = Ranker(config)
244 | checkpoint = torch.load("ckpts/gsm8k/gsm8k_ranking_multi_neg_margin.pt", map_location="cpu")
245 | checkpoint = checkpoint["model"]
246 | del(checkpoint["model.embeddings.position_ids"])
247 | ranker.load_state_dict(checkpoint)
248 | ranker.cuda()
249 | ranker.eval()
250 |
251 | progressbar = tqdm(range(len(datas)))
252 | for did, data in enumerate(datas):
253 | question = data["query"]
254 | answer = data["response"]
255 | context = "\nQuestion:\n" + question + "\nAnswer:\n"
256 | for _ in range(len(answer)):
257 | outputs = llm.generate(prompt + context, sampling_params, use_tqdm=False)[0].outputs
258 | generated_answers = []
259 | for output in outputs:
260 | generated_answers.append(output.text.strip().split("\n")[0].strip())
261 | contexts = [context] * len(generated_answers)
262 | scores = ranker(contexts, generated_answers).cpu().squeeze()
263 | # scores = torch.tensor(scores).squeeze()
264 | sorted, indices = torch.sort(scores, descending=True)
265 | # print(answers)
266 | # input()
267 | write_data = {
268 | "sid": did,
269 | "context": context.strip(),
270 | "answers": {
271 | "0": answer[_],
272 | "1": generated_answers[indices[-1]],
273 | "2": random.choice(all_rationales),
274 | }
275 | }
276 | with open("data/gsm8k/hard_negative_metamathqa.txt", "a+") as fout:
277 | fout.write(f"{json.dumps(write_data)}\n")
278 | context += answer[_] + "\n"
279 | # print(context)
280 | # input()
281 | progressbar.update(1)
282 |
283 | if __name__ == "__main__":
284 | generate_hard_negative()
--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transformers import AutoModel, AutoTokenizer
4 |
5 | class Verifier(nn.Module):
6 | def __init__(self, config) -> None:
7 | super().__init__()
8 | self.config = config
9 | self.tokenizer = AutoTokenizer.from_pretrained(config.verifier_model_path)
10 | self.model = AutoModel.from_pretrained(config.verifier_model_path)
11 |
12 | self.decision_layers = nn.Sequential(
13 | nn.Linear(self.model.config.hidden_size, config.hidden_size),
14 | nn.ReLU(),
15 | nn.Linear(config.hidden_size, 1), # 0 for neutral, 1 for entail
16 | # nn.Softmax(),
17 | nn.Sigmoid(),
18 | )
19 |
20 | def forward(self, contexts, rationales):
21 | inputs = [f"[CLS]{c}[SEP]{r}[SEP]" for c, r in zip(contexts, rationales)]
22 | tokenized = self.tokenizer(
23 | inputs,
24 | add_special_tokens=False,
25 | max_length=self.config.max_length,
26 | padding="max_length",
27 | truncation=True,
28 | return_tensors="pt",
29 | )
30 | input_ids = tokenized.input_ids.to(self.config.device)
31 | attention_mask = tokenized.attention_mask.to(self.config.device)
32 | encoded = self.model(input_ids, attention_mask).last_hidden_state[:, 0, :]
33 | # print(encoded)
34 | logits = self.decision_layers(encoded)
35 | return logits
36 |
37 | def train(self, contexts, rationales, labels):
38 | logits = self.forward(contexts, rationales).squeeze()
39 | # print(logits)
40 | # print(labels)
41 | # input()
42 | return logits, torch.tensor(labels).float().to(logits.device)
43 |
44 | def from_pretrained(self, model_path):
45 | checkpoint = torch.load(model_path, map_location="cpu")
46 | pretrained_dict = checkpoint["model"]
47 | model_dict = self.model.state_dict()
48 | pretrained_dict = {k.replace("model.", ""): v for k, v in pretrained_dict.items() if k.replace("model.", "") in model_dict}
49 | self.model.load_state_dict(pretrained_dict)
50 |
51 | pretrained_dict = checkpoint["model"]
52 | model_dict = self.decision_layers.state_dict()
53 | pretrained_dict = {k.replace("decision_layers.", ""): v for k, v in pretrained_dict.items() if k.replace("decision_layers.", "") in model_dict}
54 | self.decision_layers.load_state_dict(pretrained_dict)
55 | return self
56 |
57 | class Ranker(nn.Module):
58 | def __init__(self, config) -> None:
59 | super().__init__()
60 | self.config = config
61 | self.tokenizer = AutoTokenizer.from_pretrained(config.verifier_model_path)
62 | self.model = AutoModel.from_pretrained(config.verifier_model_path)
63 |
64 | self.decision_layers = nn.Sequential(
65 | nn.Linear(self.model.config.hidden_size, config.hidden_size),
66 | nn.ReLU(),
67 | nn.Linear(config.hidden_size, 1), # 0 for neutral, 1 for entail
68 | nn.Sigmoid(),
69 | )
70 |
71 | def forward(self, contexts, rationales):
72 | inputs = [f"[CLS]{c}[SEP]{r}[SEP]" for c, r in zip(contexts, rationales)]
73 | tokenized = self.tokenizer(
74 | inputs,
75 | add_special_tokens=False,
76 | max_length=self.config.max_length,
77 | padding="max_length",
78 | truncation=True,
79 | return_tensors="pt",
80 | )
81 | input_ids = tokenized.input_ids.to(self.config.device)
82 | attention_mask = tokenized.attention_mask.to(self.config.device)
83 | encoded = self.model(input_ids, attention_mask).last_hidden_state[:, 0, :]
84 | # print(encoded)
85 | logits = self.decision_layers(encoded)
86 | return logits
87 |
88 | def forward_train(self, contexts, rationales, labels):
89 | logits = self.forward(contexts, rationales).squeeze()
90 | # print(logits)
91 | # print(labels)
92 | # input()
93 | return logits, torch.tensor(labels).to(logits.device)
94 |
95 | def from_pretrained(self, model_path):
96 | checkpoint = torch.load(model_path, map_location="cpu")
97 | pretrained_dict = checkpoint["model"]
98 | model_dict = self.model.state_dict()
99 | pretrained_dict = {k.replace("model.", ""): v for k, v in pretrained_dict.items() if k.replace("model.", "") in model_dict}
100 | self.model.load_state_dict(pretrained_dict)
101 |
102 | pretrained_dict = checkpoint["model"]
103 | model_dict = self.decision_layers.state_dict()
104 | pretrained_dict = {k.replace("decision_layers.", ""): v for k, v in pretrained_dict.items() if k.replace("decision_layers.", "") in model_dict}
105 | self.decision_layers.load_state_dict(pretrained_dict)
106 | return self
--------------------------------------------------------------------------------
/src/openai_api_mp.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
3 | import openai
4 | import tiktoken
5 | from openai.error import InvalidRequestError
6 | from typing import List, Dict, AnyStr
7 | from tenacity import stop_after_attempt, retry, wait_random_exponential, retry_if_not_exception_type
8 | from multiprocessing import Pool, Value
9 |
10 | def openai_setup(api_key = None):
11 | if api_key:
12 | openai.api_key = api_key
13 | else:
14 | openai.api_key = os.environ['openai_api_key']
15 |
16 | class ChatGPTPool:
17 | def __init__(
18 | self,
19 | api_key = "sk-your_api_key",
20 | prompt: AnyStr = None,
21 | history_messages: List[Dict[AnyStr, AnyStr]] = None,
22 | num_sampling = 1,
23 | temperature = 1.0,
24 | max_tokens = 256,
25 | stop = "\n",
26 | ) -> None:
27 | openai_setup(api_key=api_key)
28 | if history_messages:
29 | self.messages = history_messages
30 | else:
31 | self.messages = []
32 | if prompt:
33 | msg = {"role": "system","content": prompt}
34 | self.messages.append(msg)
35 |
36 | self.money = 0
37 | self.num_sampling = num_sampling
38 | self.temperature = temperature
39 | self.max_tokens = max_tokens
40 | self.stop = stop
41 | self.prompt = prompt
42 | self.encoding = tiktoken.encoding_for_model("gpt-3.5-turbo-instruct")
43 |
44 | def check_cost(self):
45 | return self.money
46 |
47 | def chat_single_round(
48 | self,
49 | message,
50 | ):
51 | num_tokens = 0
52 | messages = []
53 | if self.prompt:
54 | messages.append({"role": "system", "content": self.prompt})
55 | num_tokens += len(self.encoding.encode(self.prompt))
56 | e = {"role": "user", "content": message}
57 | messages.append(e)
58 | num_tokens += len(self.encoding.encode(message))
59 | self.money += 0.0015 / 1000 * num_tokens * self.num_sampling
60 |
61 | with Pool(self.num_sampling) as p:
62 | contents = p.map(chat_single_round, [(messages, self.temperature, self.max_tokens, self.stop)] * self.num_sampling)
63 |
64 | num_tokens = sum([len(self.encoding.encode(content)) for content in contents])
65 | self.money += 0.002 / 1000 * num_tokens
66 | return contents
67 |
68 |
69 | @retry(stop=stop_after_attempt(5), wait=wait_random_exponential(min=1, max=60), retry=retry_if_not_exception_type(InvalidRequestError))
70 | def chat_single_round(
71 | messages,
72 | ):
73 | temperature = messages[1]
74 | max_tokens = messages[2]
75 | stop = messages[3]
76 | messages = messages[0]
77 |
78 | prompt = []
79 | for message in messages:
80 | prompt.append(message["content"])
81 | prompt = "\n".join(prompt)
82 | completion = openai.Completion.create(
83 | model = 'gpt-3.5-turbo-instruct',
84 | prompt = prompt,
85 | temperature = temperature,
86 | max_tokens = max_tokens,
87 | stop = stop,
88 | )
89 | content = completion.choices[0].text
90 | return content
91 |
--------------------------------------------------------------------------------
/src/parser_utils.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | def get_parser():
4 | parser = ArgumentParser()
5 | parser.add_argument("--experiment_name", default="gsm8k")
6 |
7 | parser.add_argument("--train_datapath", default="data/gsm8k/train_ranking_small_full.json")
8 | parser.add_argument("--valid_datapath", default=None)
9 | parser.add_argument("--test_datapath", default=None)
10 |
11 | parser.add_argument("--save_dir", default="ckpts/gsm8k")
12 | parser.add_argument("--resume_path", default=None)
13 |
14 | parser.add_argument("--verifier_model_path", default="data/cache/deberta/64a8c8eab3e352a784c658aef62be1662607476f")
15 | parser.add_argument("--hidden_size", type=int, default=1024)
16 |
17 | parser.add_argument("--num_workers", type=int, default=8)
18 | parser.add_argument("--max_length", type=int, default=512)
19 | parser.add_argument("--dense_dim", type=int, default=768)
20 | parser.add_argument("--batch_size", type=int, default=16)
21 | parser.add_argument("--num_epochs", type=int, default=10)
22 | parser.add_argument("--num_warmup_steps", type=int, default=200)
23 | parser.add_argument("--learning_rate", type=float, default=5e-5)
24 | parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
25 | parser.add_argument("--warmup_epochs", type=int, default=20)
26 | parser.add_argument("--lr_step_size", type=int, default=5)
27 | parser.add_argument("--lr_gamma", type=float, default=0.1)
28 | parser.add_argument("--weight_decay", type=float, default=0.0005)
29 |
30 | parser.add_argument("--min_margin", type=float, default=0.1)
31 | parser.add_argument("--max_margin", type=float, default=0.3)
32 | parser.add_argument("--margin_increase_step", type=int, default=2000)
33 |
34 | parser.add_argument("--use_gpu", action="store_true", default=True)
35 | parser.add_argument("--device")
36 |
37 | parser.add_argument("--test_dataset", type=str, choices=["gsm8k", "svamp", "aqua", "addsub", "singleeq", "multiarith", "strategyqa", "csqa", "llc", "coin"])
38 | parser.add_argument("--prompt", type=str, default="")
39 | parser.add_argument("--decoder_path", type=str, default="meta-llama/Llama-2-7b-chat-hf")
40 | parser.add_argument("--num_sampling", type=int, default=1)
41 | parser.add_argument("--num_beams", type=int, default=1)
42 | parser.add_argument("--decode_strategy", type=str, default="direct")
43 | parser.add_argument("--shorten_context", action="store_true")
44 | parser.add_argument("--is_greedy", action="store_true", default=False)
45 | parser.add_argument("--num_gpus_decode", type=int, default=1)
46 | parser.add_argument("--output_dir", default="outputs/")
47 |
48 | return parser
49 |
50 | # python train_gsm8k.py --batch_size 3 --save_dir ckpts/debug --train_datapath data/gsm8k/train_ranking_small_full.json --experiment_name gsm8k_ranking_small_full_low_lr
--------------------------------------------------------------------------------
/src/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import random
5 | from datasets import load_dataset
6 |
7 | from src.decoder import FastDecoder, ChatGPTDecoder
8 | from src.parser_utils import get_parser
9 | from prompts.gsm8k import prompt as math_prompt
10 | from prompts.strategyqa import prompt as cs_prompt
11 | from prompts.csqa import prompt as csqa_prompt
12 | from prompts.coin import prompt as coin_prompt
13 | from prompts.strategyqa import recall_prompt as cs_recall_prompt
14 | from prompts.csqa import recall_prompt as csqa_recall_prompt
15 | from tqdm import tqdm
16 |
17 | MATH_DATASET = [
18 | "gsm8k",
19 | "svamp",
20 | "aqua",
21 | "addsub",
22 | "singleeq",
22 | "multiarith",
23 | ]
24 | COMMONSENSE_DATASET = [
25 | "strategyqa",
26 | "csqa",
27 | ]
28 | SYMBOLIC_DATASET = [
29 | "coin",
30 | ]
31 |
32 | MODEL_NAME = ["7b", "13b", "70b", "gpt"]
33 |
34 |
35 | def load_dataset_from_config(config):
36 | """
37 | The returned data should be a list where each element is a dict containing the key "question".
38 | [
39 | {
40 | "id": id,
41 | "question": question,
42 | }
43 | ]
44 | """
45 | datas = []
46 | if config.test_dataset == "gsm8k":
47 | dataset = load_dataset("data/cache/gsm8k", "main")["test"]
48 | for did, data in enumerate(dataset):
49 | datas.append({
50 | "id": did,
51 | "question": data["question"],
52 | })
53 | elif config.test_dataset == "svamp":
54 | dataset = load_dataset("data/cache/ChilleD___json/ChilleD--SVAMP-4bd8179a65d5f05b")["test"]
55 | for data in dataset:
56 | datas.append({
57 | "id": data["ID"],
58 | "question": data["Body"] + " " + data["Question"],
59 | })
60 | elif config.test_dataset == "aqua":
61 | dataset = []
62 | with open("data/cache/AQuA/test.json", "r") as fin:
63 | for line in fin.readlines():
64 | dataset.append(json.loads(line.strip()))
65 | # dataset = json.load(fin)
66 | for did, data in enumerate(dataset):
67 | datas.append({
68 | "id": did,
69 | "question": data["question"],
70 | })
71 | elif config.test_dataset == "addsub":
72 | dataset = load_dataset("data/cache/allenai___lila", "addsub")["test"]
73 | for did, data in enumerate(dataset):
74 | datas.append({
75 | "id": did,
76 | "question": data["input"],
77 | })
78 | elif config.test_dataset == "singleeq":
79 | data_dir = "data/cache/SingleEq"
80 | for i in range(5):
81 | with open(os.path.join(data_dir, f"test{i}")) as fin:
82 | for lid, line in enumerate(fin.readlines()):
83 | if lid % 3 == 0:
84 | datas.append({
85 | "id": int(lid / 3),
86 | "question": line.strip(),
87 | })
88 | elif config.test_dataset == "multiarith":
89 | dataset = load_dataset("data/cache/ChilleD___json/ChilleD--MultiArith-2e3d95e2a4ce9083")["test"]
90 | for did, data in enumerate(dataset):
91 | datas.append({
92 | "id": did,
93 | "question": data["question"],
94 | })
95 | elif config.test_dataset == "strategyqa":
96 | with open("data/cache/strategyqa/test_set.json", "r") as fin:
97 | dataset = json.load(fin)
98 | for data in dataset:
99 | datas.append({
100 | "id": data["qid"],
101 | "question": data["question"],
102 | })
103 | elif config.test_dataset == "csqa":
104 | dataset = load_dataset("data/cache/commonsense_qa")["validation"]
105 | for data in dataset:
106 | labels = data["choices"]["label"]
107 | texts = data["choices"]["text"]
108 | choices = []
109 | for label, text in zip(labels, texts):
110 | choices.append(f"{label}. {text}")
111 | choices = "\n".join(choices)
112 | question = data["question"] + "\n" + choices
113 | datas.append({
114 | "id": data["id"],
115 | "question": question,
116 | })
117 | elif config.test_dataset == "coin":
118 | dataset = load_dataset("data/cache/skrishna___json/skrishna--coin_flip-8305ab6800b027bf")["test"]
119 | for did, data in enumerate(dataset):
120 | question = data["inputs"].replace("Q:", "").strip()
121 | datas.append({
122 | "id": did,
123 | "question": question,
124 | })
125 | else:
126 | print(f"[WARNING] {config.test_dataset} is not an option!")
127 | raise NotImplementedError
128 | return datas
129 |
130 | def load_prompt_from_config(config):
131 | if config.test_dataset in MATH_DATASET:
132 | prompt = math_prompt
133 | elif config.test_dataset in COMMONSENSE_DATASET:
134 | if config.test_dataset == "csqa":
135 | if "fact" in config.decode_strategy:
136 | prompt = csqa_recall_prompt
137 | else:
138 | prompt = csqa_prompt
139 | else:
140 | if "fact" in config.decode_strategy:
141 | prompt = cs_recall_prompt
142 | else:
143 | prompt = cs_prompt
144 | elif config.test_dataset in SYMBOLIC_DATASET:
145 | if config.test_dataset == "coin":
146 | prompt = coin_prompt
147 | else:
148 | print(f"[WARNING] {config.test_dataset} is not an option!")
149 | raise NotImplementedError
150 | return prompt
151 |
152 | def load_dataset_and_prompt(config):
153 | return load_dataset_from_config(config), load_prompt_from_config(config)
154 |
155 |
156 | @torch.inference_mode()
157 | def main():
158 | torch.manual_seed(42)
159 | random.seed(42)
160 |
161 | parser = get_parser()
162 | config = parser.parse_args()
163 |
164 | config.device = torch.device("cuda")
165 |
166 | # load dataset, prompt, and decoder
167 | dataset, prompt = load_dataset_and_prompt(config)
168 | config.prompt = prompt
169 | if "gpt" in config.decoder_path:
170 | decoder = ChatGPTDecoder(config)
171 | else:
172 | decoder = FastDecoder(config)
173 |
174 | # load from previous infer checkpoint
175 | if config.resume_path is not None:
176 | ckpt_name = config.resume_path.split(".")[0].split("/")[-1]
177 | else:
178 | ckpt_name = "None"
179 |
180 | for name in MODEL_NAME:
181 | if name in config.decoder_path:
182 | model_abbr_name = name
183 | break  # stop at the first match so the fallback below does not overwrite it
183 | else:
184 | if "/" in config.decoder_path:
185 | model_abbr_name = config.decoder_path.split("/")[-1]
186 | else:
187 | model_abbr_name = config.decoder_path
188 |
189 | if not os.path.exists(config.output_dir):
190 | os.mkdir(config.output_dir)
191 | config.output_dir = os.path.join(config.output_dir, model_abbr_name)
192 | if not os.path.exists(config.output_dir):
193 | os.mkdir(config.output_dir)
194 | save_path = os.path.join(config.output_dir, f"{config.test_dataset}_{model_abbr_name}_{ckpt_name}_{config.decode_strategy}_N{config.num_sampling}_B{config.num_beams}.txt")
195 | if os.path.exists(save_path):
196 | with open(save_path, "r") as fin:
197 | line = fin.readlines()[-1]
198 | data = json.loads(line.strip())
199 | for key in data.keys():
200 | continue
201 | last_sid = int(key)
202 | else:
203 | last_sid = -1
204 |
205 | # inference
206 | progressbar = tqdm(range(len(dataset)))
207 | for sid, sample in enumerate(dataset):
208 | if sid <= last_sid:
209 | progressbar.update(1)
210 | continue
211 | question = sample["question"]
212 | answers = decoder.decode_fn(question)
213 |
214 | # save inference checkpoint
215 | if "beam_search" in config.decode_strategy:
216 | beams, full_rationales = answers
217 | with open(save_path, "a+") as fout:
218 | fout.write(f"{json.dumps({sid: beams[0]})}\n")
219 | with open(f"{save_path}.full", "a+") as fout:
220 | fout.write(f"{json.dumps({sid: full_rationales})}\n")
221 | else:
222 | with open(save_path, "a+") as fout:
223 | fout.write(f"{json.dumps({sid: answers})}\n")
224 | progressbar.update(1)
225 |
226 | if __name__ == "__main__":
227 | main()
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
3 | import torch
4 |
5 | from src.parser_utils import get_parser
6 | from src.data_utils import GSM8KRankingMultiNegativeDataset
7 | from src.trainer import RankingMultipleNegativeTrainer
8 | from src.model import Ranker
9 |
10 |
11 | def ranker_multi_negative_main():
12 | parser = get_parser()
13 | config = parser.parse_args()
14 |
15 | if config.train_datapath is not None:
16 | train_dataset = GSM8KRankingMultiNegativeDataset(config.train_datapath)
17 | else:
18 | train_dataset = None
19 | if config.valid_datapath is not None:
20 | valid_dataset = GSM8KRankingMultiNegativeDataset(config.valid_datapath)
21 | else:
22 | valid_dataset = None
23 | if config.test_datapath is not None:
24 | test_dataset = GSM8KRankingMultiNegativeDataset(config.test_datapath)
25 | else:
26 | test_dataset = None
27 |
28 | model = Ranker(config)
29 |
30 | trainer = RankingMultipleNegativeTrainer(
31 | config,
32 | model,
33 | train_dataset,
34 | valid_dataset,
35 | test_dataset
36 | )
37 |
38 | trainer.train()
39 |
40 | if __name__ == "__main__":
41 |     ranker_multi_negative_main()
--------------------------------------------------------------------------------
/src/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import json
4 | import torch
5 | import logging
6 | import numpy as np
7 | import torch.nn as nn
8 | from tqdm import tqdm
9 | from torch.utils.data import DataLoader, random_split
10 | from torch.optim import AdamW
11 | from transformers import get_linear_schedule_with_warmup
12 |
13 | from src.data_utils import collate_fn
14 |
15 | class RankingTrainer():
16 | def __init__(self,
17 | config,
18 | model,
19 | train_dataset=None,
20 | valid_dataset=None,
21 | test_dataset=None,):
22 | self.config = config
23 |
24 | # device config
25 | if self.config.use_gpu:
26 | self.config.device = torch.device(config.device)
27 | else:
28 | self.config.device = torch.device("cpu")
29 |
30 | # logger config
31 | logging.basicConfig(
32 | level=logging.INFO,
33 | format="%(asctime)s [%(levelname)s] %(message)s",
34 | handlers=[
35 | logging.FileHandler(f"logs/{self.config.experiment_name}.log", mode="w"),
36 | logging.StreamHandler()
37 | ]
38 | )
39 | self.logger = logging.getLogger(__name__)
40 |
41 | # prepare data loader and target documents
42 | if train_dataset is None and test_dataset is None:
43 | self.logger.error("At least one dataset should be passed")
44 | raise FileNotFoundError
45 | if train_dataset is not None and valid_dataset is not None:
46 | self.train_dataloader = DataLoader(
47 | train_dataset,
48 | batch_size=config.batch_size,
49 | shuffle=True,
50 | num_workers=config.num_workers,
51 | collate_fn=collate_fn,
52 | drop_last=True,
53 | )
54 | self.logger.info(f"Training set length: {len(train_dataset)}")
55 | self.val_dataloader = DataLoader(
56 | valid_dataset,
57 | batch_size=config.batch_size,
58 | shuffle=False,
59 | num_workers=config.num_workers,
60 | collate_fn=collate_fn,
61 | drop_last=True,
62 | )
63 | self.logger.info(f"Validation set length: {len(valid_dataset)}")
64 | elif train_dataset is not None and valid_dataset is None:
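            # no validation set was provided, so hold out 10% of the training data for validation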
65 | datasets = random_split(train_dataset, [int(0.9 * len(train_dataset)), len(train_dataset) - int(0.9 * len(train_dataset))])
66 | # datasets = random_split(train_dataset, [int(0.001 * len(train_dataset)), int(0.001 * len(train_dataset)), len(train_dataset) - 2 * int(0.001 * len(train_dataset))])
67 | train_dataset = datasets[0]
68 | valid_dataset = datasets[1]
69 | self.train_dataloader = DataLoader(
70 | train_dataset,
71 | batch_size=config.batch_size,
72 | shuffle=True,
73 | num_workers=config.num_workers,
74 | collate_fn=collate_fn,
75 | drop_last=True,
76 | )
77 | self.logger.info(f"Training set length: {len(train_dataset)}")
78 | self.val_dataloader = DataLoader(
79 | valid_dataset,
80 | batch_size=config.batch_size,
81 | shuffle=False,
82 | num_workers=config.num_workers,
83 | collate_fn=collate_fn,
84 | drop_last=True,
85 | )
86 | self.logger.info(f"Validation set length: {len(valid_dataset)}")
87 | if test_dataset is not None:
88 | self.test_dataloader = DataLoader(
89 | test_dataset,
90 | batch_size=1,
91 | shuffle=True,
92 | num_workers=config.num_workers,
93 | collate_fn=collate_fn,
94 | drop_last=True,
95 | )
96 | self.logger.info(f"Test set length: {len(test_dataset)}")
97 | self.logger.info("Data init done.")
98 |
99 | # prepare model, preprocessor, loss, optimizer and scheduler
100 | self.model = model
101 |         # start from a fixed margin; train_one_epoch may grow it up to config.max_margin
102 |         self.margin = 0.2
103 |         self.loss = nn.MarginRankingLoss(margin=self.margin)
104 |
105 | no_decay = ["bias", "LayerNorm.weight"]
106 | optimizer_grouped_parameters = [
107 | {
108 | "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
109 | "weight_decay": self.config.weight_decay,
110 | },
111 | {
112 | "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
113 | "weight_decay": 0.0,
114 | },
115 | ]
116 | self.optimizer = AdamW(
117 | optimizer_grouped_parameters,
118 | lr=self.config.learning_rate,
119 | )
120 |
121 | if config.resume_path is None:
122 | self.logger.warning("No checkpoint given!")
123 | self.best_val_loss = 10000
124 | self.best_accu = 0
125 | self.last_epoch = -1
126 | num_training_steps = self.config.num_epochs * len(self.train_dataloader)
127 | self.scheduler = get_linear_schedule_with_warmup(
128 | optimizer=self.optimizer,
129 | num_warmup_steps=self.config.num_warmup_steps,
130 | num_training_steps=num_training_steps,
131 | last_epoch=self.last_epoch
132 | )
133 | else:
134 | self.logger.info(f"Loading model from checkpoint: {self.config.resume_path}.")
135 | checkpoint = torch.load(self.config.resume_path, map_location=self.config.device)
136 | self.model.load_state_dict(checkpoint["model"], strict=False)
137 | self.best_val_loss = checkpoint["best_val_loss"]
138 | self.best_accu = checkpoint["best_accu"]
139 | self.last_epoch = -1
140 | num_training_steps = self.config.num_epochs * len(self.train_dataloader)
141 | self.optimizer.load_state_dict(checkpoint["optimizier"])
142 | self.scheduler = get_linear_schedule_with_warmup(
143 | optimizer=self.optimizer,
144 | num_warmup_steps=self.config.num_warmup_steps,
145 | num_training_steps=num_training_steps,
146 | last_epoch=self.last_epoch
147 | )
148 | self.scheduler.load_state_dict(checkpoint["scheduler"])
149 | self.model.to(self.config.device)
150 |
151 | self.logger.info("Trainer init done.")
152 |
153 | def train_one_epoch(self):
154 | total_loss = 0
155 | num_update_steps_per_epoch = math.ceil(len(self.train_dataloader) / self.config.gradient_accumulation_steps)
156 | progress_bar = tqdm(range(num_update_steps_per_epoch))
157 | for step, batch in enumerate(self.train_dataloader):
158 | contexts, answers, false_answers = batch
159 | logits = self.model(contexts, answers).squeeze()
160 | false_logits = self.model(contexts, false_answers).squeeze()
161 | labels = torch.tensor([1] * len(contexts)).float().to(self.config.device)
162 | loss = self.loss(logits, false_logits, labels)
163 |
164 | total_loss += loss.detach().cpu().item()
165 | loss.backward()
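            # accumulate gradients and only step the optimizer every
            # `gradient_accumulation_steps` mini-batches (and on the final batch)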
166 | if (step != 0 and step % self.config.gradient_accumulation_steps == 0) or step == len(self.train_dataloader) - 1:
167 | self.optimizer.step()
168 | self.scheduler.step()
169 | self.optimizer.zero_grad()
170 | progress_bar.set_postfix(
171 | {
172 | 'loss': total_loss / (step + 1),
173 | }
174 | )
175 | progress_bar.update(1)
176 | if self.config.margin_increase_step > 0 and step != 0 and step % self.config.margin_increase_step == 0 and self.margin < self.config.max_margin:
177 | self.margin += 0.1
178 | self.loss = nn.MarginRankingLoss(self.margin)
179 | return total_loss / (step + 1)
180 |
181 | def train(self):
182 | for epoch in range(self.config.num_epochs):
183 | self.logger.info("========================")
184 | self.logger.info("Training...")
185 | epoch_loss = self.train_one_epoch()
186 | self.logger.info(f"Epoch {epoch} training loss: {epoch_loss}")
187 | self.logger.info("Validating...")
188 | val_loss, val_accu = self.validate()
189 | self.logger.info(f"Epoch {epoch} validation loss: {val_loss}")
190 | self.logger.info(f"Epoch {epoch} validation accuracy: {val_accu}")
191 | # if val_loss < self.best_val_loss:
192 | if val_accu > self.best_accu:
193 | self.best_val_loss = val_loss
194 | self.best_accu = val_accu
195 | checkpoint = {
196 | "model": self.model.state_dict(),
197 | "optimizier": self.optimizer.state_dict(),
198 | "scheduler": self.scheduler.state_dict(),
199 | "last_epoch": self.last_epoch + epoch + 1,
200 | "best_val_loss": self.best_val_loss,
201 | "best_accu": self.best_accu,
202 | "config": self.config,
203 | }
204 | save_path = os.path.join(self.config.save_dir, f"{self.config.experiment_name}.pt")
205 | self.logger.info(f"Saving best checkpoint to {save_path}")
206 | torch.save(checkpoint, save_path)
207 |
208 | @torch.no_grad()
209 | def validate(self):
210 | total_loss = 0
211 | cnt_true, cnt = 0, 0
212 | for step, batch in tqdm(enumerate(self.val_dataloader)):
213 | contexts, answers, false_answers = batch
214 | logits = self.model(contexts, answers).squeeze()
215 | false_logits = self.model(contexts, false_answers).squeeze()
216 | labels = torch.tensor([1] * len(contexts)).float().to(self.config.device)
217 | loss = self.loss(logits, false_logits, labels)
218 | total_loss += loss.detach().cpu().item()
219 |
220 | b_cnt_true, b_cnt = self.evaluate(logits.detach().cpu(), false_logits.detach().cpu())
221 | cnt_true += b_cnt_true
222 | cnt += b_cnt
223 |
224 | return total_loss / (step + 1), cnt_true / cnt
225 |
226 | @torch.no_grad()
227 | def evaluate(self, logits, false_logits):
228 | cnt_true, cnt = 0, 0
229 | for pred, truth in zip(logits, false_logits):
230 | if pred > truth:
231 | cnt_true += 1
232 | cnt += 1
233 | return cnt_true, cnt
234 |
235 | @torch.no_grad()
236 | def metric(self, tp, tn, fp, fn):
237 |         try:
238 |             accu = (tp + tn) / (tp + tn + fp + fn)
239 |         except ZeroDivisionError:
240 |             accu = 0
241 |         try:
242 |             prec = tp / (tp + fp)
243 |         except ZeroDivisionError:
244 |             prec = 0
245 |         try:
246 |             reca = tp / (tp + fn)
247 |         except ZeroDivisionError:
248 |             reca = 0
249 |         try:
250 |             f1 = 2 * prec * reca / (prec + reca)
251 |         except ZeroDivisionError:
252 |             f1 = 0
253 |         print(f"Accuracy: {accu}")
254 |         print(f"Precision: {prec}")
255 | print(f"Recall: {reca}")
256 | print(f"F1: {f1}")
257 | return accu, prec, reca, f1
258 |
259 | class RankingMultipleNegativeTrainer():
260 | def __init__(self,
261 | config,
262 | model,
263 | train_dataset=None,
264 | valid_dataset=None,
265 | test_dataset=None,):
266 | self.config = config
267 |
268 | # device config
269 | if self.config.use_gpu:
270 | self.config.device = torch.device(config.device)
271 | else:
272 | self.config.device = torch.device("cpu")
273 |
274 | # logger config
275 | logging.basicConfig(
276 | level=logging.INFO,
277 | format="%(asctime)s [%(levelname)s] %(message)s",
278 | handlers=[
279 | logging.FileHandler(f"logs/{self.config.experiment_name}.log", mode="w"),
280 | logging.StreamHandler()
281 | ]
282 | )
283 | self.logger = logging.getLogger(__name__)
284 | self.logger.info(self.config)
285 |
286 | # prepare data loader and target documents
287 | if train_dataset is None and test_dataset is None:
288 | self.logger.error("At least one dataset should be passed")
289 | raise FileNotFoundError
290 | if train_dataset is not None and valid_dataset is not None:
291 | self.train_dataloader = DataLoader(
292 | train_dataset,
293 | batch_size=config.batch_size,
294 | shuffle=True,
295 | num_workers=config.num_workers,
296 | collate_fn=collate_fn,
297 | drop_last=True,
298 | )
299 | self.logger.info(f"Training set length: {len(train_dataset)}")
300 | self.val_dataloader = DataLoader(
301 | valid_dataset,
302 | batch_size=config.batch_size,
303 | shuffle=False,
304 | num_workers=config.num_workers,
305 | collate_fn=collate_fn,
306 | drop_last=True,
307 | )
308 | self.logger.info(f"Validation set length: {len(valid_dataset)}")
309 | elif train_dataset is not None and valid_dataset is None:
310 | datasets = random_split(train_dataset, [int(0.9 * len(train_dataset)), len(train_dataset) - int(0.9 * len(train_dataset))])
311 | # datasets = random_split(train_dataset, [int(0.001 * len(train_dataset)), int(0.001 * len(train_dataset)), len(train_dataset) - 2 * int(0.001 * len(train_dataset))])
312 | train_dataset = datasets[0]
313 | valid_dataset = datasets[1]
314 | self.train_dataloader = DataLoader(
315 | train_dataset,
316 | batch_size=config.batch_size,
317 | shuffle=True,
318 | num_workers=config.num_workers,
319 | collate_fn=collate_fn,
320 | drop_last=True,
321 | )
322 | self.logger.info(f"Training set length: {len(train_dataset)}")
323 | self.val_dataloader = DataLoader(
324 | valid_dataset,
325 | batch_size=config.batch_size,
326 | shuffle=False,
327 | num_workers=config.num_workers,
328 | collate_fn=collate_fn,
329 | drop_last=True,
330 | )
331 | self.logger.info(f"Validation set length: {len(valid_dataset)}")
332 | if test_dataset is not None:
333 | self.test_dataloader = DataLoader(
334 | test_dataset,
335 | batch_size=1,
336 | shuffle=True,
337 | num_workers=config.num_workers,
338 | collate_fn=collate_fn,
339 | drop_last=True,
340 | )
341 | self.logger.info(f"Test set length: {len(test_dataset)}")
342 | self.logger.info("Data init done.")
343 |
344 | # prepare model, preprocessor, loss, optimizer and scheduler
345 | self.model = model
346 |
347 | no_decay = ["bias", "LayerNorm.weight"]
348 | optimizer_grouped_parameters = [
349 | {
350 | "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
351 | "weight_decay": self.config.weight_decay,
352 | },
353 | {
354 | "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
355 | "weight_decay": 0.0,
356 | },
357 | ]
358 | self.optimizer = AdamW(
359 | optimizer_grouped_parameters,
360 | lr=self.config.learning_rate,
361 | )
362 |
363 | if config.resume_path is None:
364 | self.logger.warning("No checkpoint given!")
365 | self.best_val_loss = 10000
366 | # self.best_accu = 0
367 | self.last_epoch = -1
368 | num_training_steps = self.config.num_epochs * len(self.train_dataloader)
369 | self.scheduler = get_linear_schedule_with_warmup(
370 | optimizer=self.optimizer,
371 | num_warmup_steps=self.config.num_warmup_steps,
372 | num_training_steps=num_training_steps,
373 | last_epoch=self.last_epoch
374 | )
375 | else:
376 | self.logger.info(f"Loading model from checkpoint: {self.config.resume_path}.")
377 | checkpoint = torch.load(self.config.resume_path, map_location=self.config.device)
378 | self.model.to(self.config.device)
379 | self.model.load_state_dict(checkpoint["model"], strict=False)
380 | # self.best_val_loss = checkpoint["best_val_loss"]
381 | self.best_val_loss = 10000
382 | # self.best_accu = checkpoint["best_accu"]
383 | self.last_epoch = checkpoint["last_epoch"]
384 | num_training_steps = self.config.num_epochs * len(self.train_dataloader)
385 | # self.optimizer.load_state_dict(checkpoint["optimizier"])
386 | self.scheduler = get_linear_schedule_with_warmup(
387 | optimizer=self.optimizer,
388 | num_warmup_steps=self.config.num_warmup_steps,
389 | num_training_steps=num_training_steps,
390 | last_epoch=-1
391 | )
392 | # self.scheduler.load_state_dict(checkpoint["scheduler"])
393 |             del checkpoint
394 | self.model.to(self.config.device)
395 |
396 | self.logger.info("Trainer init done.")
397 |
398 | def train_one_epoch(self):
399 | total_loss = 0
400 | num_update_steps_per_epoch = math.ceil(len(self.train_dataloader) / self.config.gradient_accumulation_steps)
401 | progress_bar = tqdm(range(num_update_steps_per_epoch))
402 | for step, batch in enumerate(self.train_dataloader):
403 | contexts, answers = batch
404 | answers_0 = [answer["0"] for answer in answers] # positive
405 | answers_1 = [answer["1"] for answer in answers] # negative 1, set margin to 0.3
406 | answers_2 = [answer["2"] for answer in answers] # negative 2, set margin to 0.6
407 | answers_3 = [answer["3"] for answer in answers] # negative 3, set margin to 0.9
408 |
409 | logits_0 = self.model(contexts, answers_0)
410 | logits_1 = self.model(contexts, answers_1)
411 | logits_2 = self.model(contexts, answers_2)
412 | logits_3 = self.model(contexts, answers_3)
413 | logits = torch.cat([logits_0, logits_1, logits_2, logits_3], dim=1)
414 |
415 | labels = torch.tensor([1] * len(answers_0)).float().to(self.config.device)
416 | labels = labels.unsqueeze(1)
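            # the positive answer must outscore each negative by a progressively
            # larger margin (0.3 / 0.6 / 0.9), matching the annotations above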
417 | loss = nn.MarginRankingLoss(0.3)(logits_0, logits_1, labels) + nn.MarginRankingLoss(0.6)(logits_0, logits_2, labels) + nn.MarginRankingLoss(0.9)(logits_0, logits_3, labels)
418 |
419 | total_loss += loss.detach().cpu().item()
420 | loss.backward()
421 | if (step != 0 and step % self.config.gradient_accumulation_steps == 0) or step == len(self.train_dataloader) - 1:
422 | self.optimizer.step()
423 | self.scheduler.step()
424 | self.optimizer.zero_grad()
425 | progress_bar.set_postfix(
426 | {
427 | 'loss': total_loss / (step + 1),
428 | }
429 | )
430 | progress_bar.update(1)
431 | # if self.config.margin_increase_step > 0 and step != 0 and step % self.config.margin_increase_step == 0 and self.margin < self.config.max_margin:
432 | # self.margin += 0.1
433 | # self.loss = nn.MarginRankingLoss(self.margin)
434 | return total_loss / (step + 1)
435 |
436 | def train(self):
437 | for epoch in range(self.config.num_epochs):
438 | self.logger.info("========================")
439 | self.logger.info("Training...")
440 | epoch_loss = self.train_one_epoch()
441 | self.logger.info(f"Epoch {epoch} training loss: {epoch_loss}")
442 | self.logger.info("Validating...")
443 | val_loss = self.validate() # , val_accu
444 | self.logger.info(f"Epoch {epoch} validation loss: {val_loss}")
445 | # self.logger.info(f"Epoch {epoch} validation accuracy: {val_accu}")
446 | if val_loss < self.best_val_loss:
447 | # if val_accu > self.best_accu:
448 | self.best_val_loss = val_loss
449 | # self.best_accu = val_accu
450 | checkpoint = {
451 | "model": self.model.state_dict(),
452 | "optimizier": self.optimizer.state_dict(),
453 | "scheduler": self.scheduler.state_dict(),
454 | "last_epoch": self.last_epoch + epoch + 1,
455 | "best_val_loss": self.best_val_loss,
456 | # "best_accu": self.best_accu,
457 | "config": self.config,
458 | }
459 | save_path = os.path.join(self.config.save_dir, f"{self.config.experiment_name}.pt")
460 | self.logger.info(f"Saving best checkpoint to {save_path}")
461 | torch.save(checkpoint, save_path)
462 |
463 | @torch.no_grad()
464 | def validate(self):
465 | total_loss = 0
466 | # cnt_true, cnt = 0, 0
467 | for step, batch in tqdm(enumerate(self.val_dataloader)):
468 | contexts, answers = batch
469 | answers_0 = [answer["0"] for answer in answers] # positive
470 | answers_1 = [answer["1"] for answer in answers] # negative 1, set margin to 0.3
471 | answers_2 = [answer["2"] for answer in answers] # negative 2, set margin to 0.6
472 | answers_3 = [answer["3"] for answer in answers] # negative 3, set margin to 0.9
473 |
474 | logits_0 = self.model(contexts, answers_0)
475 | logits_1 = self.model(contexts, answers_1)
476 | logits_2 = self.model(contexts, answers_2)
477 | logits_3 = self.model(contexts, answers_3)
478 |
479 | labels = torch.tensor([1] * len(answers_0)).float().to(self.config.device)
480 | labels = labels.unsqueeze(1)
481 | loss = nn.MarginRankingLoss(0.3)(logits_0, logits_1, labels) + nn.MarginRankingLoss(0.6)(logits_0, logits_2, labels) + nn.MarginRankingLoss(0.9)(logits_0, logits_3, labels)
482 |
483 | total_loss += loss.detach().cpu().item()
484 |
485 | return total_loss / (step + 1) # , cnt_true / cnt
486 |
487 | @torch.no_grad()
488 | def evaluate(self, logits, false_logits):
489 | cnt_true, cnt = 0, 0
490 | for pred, truth in zip(logits, false_logits):
491 | if pred > truth:
492 | cnt_true += 1
493 | cnt += 1
494 | return cnt_true, cnt
495 |
496 | @torch.no_grad()
497 | def metric(self, tp, tn, fp, fn):
498 |         try:
499 |             accu = (tp + tn) / (tp + tn + fp + fn)
500 |         except ZeroDivisionError:
501 |             accu = 0
502 |         try:
503 |             prec = tp / (tp + fp)
504 |         except ZeroDivisionError:
505 |             prec = 0
506 |         try:
507 |             reca = tp / (tp + fn)
508 |         except ZeroDivisionError:
509 |             reca = 0
510 |         try:
511 |             f1 = 2 * prec * reca / (prec + reca)
512 |         except ZeroDivisionError:
513 |             f1 = 0
514 |         print(f"Accuracy: {accu}")
515 |         print(f"Precision: {prec}")
516 | print(f"Recall: {reca}")
517 | print(f"F1: {f1}")
518 | return accu, prec, reca, f1
--------------------------------------------------------------------------------