├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── ReCEvalOverview.png
├── evaluate_receval.py
├── perturb_EB.py
├── requirements.txt
├── run_flan.py
├── train_infogain_pvi.py
└── train_pvi.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Archiki Prasad
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ReCEval: Evaluating Reasoning Chains via Correctness and Informativeness
2 | * Authors: [Archiki Prasad](https://archiki.github.io), [Swarnadeep Saha](https://swarnahub.github.io/), [Xiang Zhou](https://owenzx.github.io/), and [Mohit Bansal](https://www.cs.unc.edu/~mbansal/) (UNC Chapel Hill)
3 | * [Paper](https://arxiv.org/abs/2304.10703)
4 | * **Note:** This is a preliminary version of our code. The complete code to run all experiments in the paper will be added shortly.
5 |
6 |
7 |
8 | ## Dependencies
9 | This code is written using PyTorch and [HuggingFace's Transformers library](https://github.com/huggingface/pytorch-transformers). Running ReCEval requires access to GPUs; the evaluation is fairly lightweight, so a single GPU should suffice. Please download the [Entailment Bank](https://allenai.org/data/entailmentbank) and [GSM-8K](https://github.com/openai/grade-school-math) datasets separately. To use the human-judgment datasets for GSM-8K and to run the baselines, please follow the setup procedure in [ROSCOE](https://github.com/facebookresearch/ParlAI/tree/main/projects/roscoe/) (preferably in a separate environment).
10 |
11 | ## Installation
12 | The simplest way to run our code is to start with a fresh environment.
13 | ```
14 | conda create -n ReCEval python=3.9
15 | source activate ReCEval
16 | pip install -r requirements.txt
17 | ```
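Depending on your environment, you may also need the NLTK data used for sentence/word tokenization and POS tagging (a one-time download; adjust to your setup):
```
python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')"
```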
18 |
19 | ## Running Evaluation
20 | * `evaluate_receval.py` contains the implementation of metrics in ReCEval.
21 | * `train_*_pvi.py` scripts are used to train models for the PVI-based metrics.
22 | * `perturb_EB.py` applies perturbations to the reasoning trees in [Entailment Bank](https://allenai.org/data/entailmentbank).
23 | * `run_flan.py` is used to obtain chain-of-thought responses on the [GSM-8K](https://github.com/openai/grade-school-math) dataset.
24 | * To compute metrics and evaluate, simply run `python evaluate_receval.py` (defaults to Entailment Bank). Default model and data directories can be changed directly within the script (see the sketch below). These variables include:
25 |     * `inp_model_dir`: Model *g* for calculating PVI-based intra-step correctness
26 |     * `no_inp_model_dir`: Model *g'* (the no-input model) for calculating PVI-based intra-step correctness
27 |     * `info_gain_model_dir`: Model for calculating PVI-based information-gain
28 |     * `source_path`: Path containing reasoning chains to be scored or meta-evaluated
29 | * **PVI Models:** Here is a link to the trained [PVI models for entailment](https://drive.google.com/drive/folders/1qhWKqEAFAoIar3ydSUtMoh2YyYctutG6?usp=drive_link). For more training details and how we prepare the data, refer to Appendix A of our paper, and/or consider using off-the-shelf LLMs to compute ReCEval metrics.
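
For reference, here is a minimal sketch of the configuration block near the top of `evaluate_receval.py`; the paths below are the script's defaults and should be pointed at your own checkpoints and data. The input file is expected to be a JSON-lines file whose entries provide the question, the reasoning steps/sentences, and the hypothesis.
```python
# Configuration variables at the top of evaluate_receval.py (default/placeholder paths).
inp_model_dir = 'PVI/inp_models/'             # model g (with input) for intra-step PVI
no_inp_model_dir = 'PVI/noinp_models/'        # model g' (no input) for intra-step PVI
info_gain_model_dir = 'PVI/infogain_models/'  # model for PVI-based information-gain
source_path = 'perturbed_trees'               # reasoning chains to score or meta-evaluate
score_keys = ['entail', 'pvi', 'contradict', 'll-info', 'pvi-info']  # metrics to compute
```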
30 |
31 | ## Reference
32 | Please cite our paper if you use our repository in your work:
33 | ```bibtex
34 |
35 | @article{Prasad2023ReCEval,
36 | title = {ReCEval: Evaluating Reasoning Chains via Correctness and Informativeness},
37 | author = {Archiki Prasad and Swarnadeep Saha and Xiang Zhou and Mohit Bansal},
38 | year = {2023},
39 | archivePrefix = {arXiv},
40 | primaryClass = {cs.CL},
41 | eprint = {2304.10703}
42 | }
43 | ```
44 |
--------------------------------------------------------------------------------
/assets/ReCEvalOverview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/archiki/ReCEval/bc023421f4f8c748111643c1772693cd33686032/assets/ReCEvalOverview.png
--------------------------------------------------------------------------------
/evaluate_receval.py:
--------------------------------------------------------------------------------
1 | from supar import Parser
2 | from nltk import word_tokenize, sent_tokenize
3 | from nltk.tokenize.treebank import TreebankWordDetokenizer
4 | import nltk
5 | import pdb
6 | import json
7 | import tqdm
8 | def tqdm_replacement(iterable_object,*args,**kwargs):
9 | return iterable_object
10 | tqdm_copy = tqdm.tqdm  # keep a handle to the real tqdm for explicit use below
11 | tqdm.tqdm = tqdm_replacement  # silence tqdm progress bars from imported libraries
12 | import os
13 | import torch
14 | import numpy as np
15 | import random
16 | from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig, AutoModelForSeq2SeqLM, AutoModelForCausalLM
17 | from scipy.stats import somersd
18 | from allennlp.predictors.predictor import Predictor
19 | import allennlp_models.tagging
20 | import re, string
21 | import logging
22 | import logging.config
23 | from datasets import load_dataset, load_metric, Dataset
24 | import torch.nn.functional as F
25 | from datasets.utils.logging import disable_progress_bar
26 | import pandas as pd
27 |
28 |
29 | logging.config.dictConfig({
30 | 'version': 1,
31 | 'disable_existing_loggers': True,
32 | })
33 | disable_progress_bar()
34 | random.seed(1)
35 |
36 |
37 | srl_predictor = predictor = Predictor.from_path("https://storage.googleapis.com/allennlp-public-models/structured-prediction-srl-bert.2020.12.15.tar.gz")
38 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
39 | ent_model_name = "MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli"
40 | ent_tokenizer = AutoTokenizer.from_pretrained(ent_model_name)
41 | ent_model = AutoModelForSequenceClassification.from_pretrained(ent_model_name).to(device)
42 |
43 | # Intra-Step PVI arguments
44 | inp_model_dir = 'PVI/inp_models/'
45 | no_inp_model_dir = 'PVI/noinp_models/'
46 | # Info-gain PVI arguments
47 | info_gain_model_dir = 'PVI/infogain_models/'
48 |
49 | max_input_length = 512
50 | max_target_length = 64
51 | padding = "max_length"
52 | model_name = "t5-large"
53 | label_pad_token_id = -100
54 | pad_token = ''
55 | prefix = 'Generate entailed sentence: '
56 |
57 | inp_tokenizer = AutoTokenizer.from_pretrained(inp_model_dir)
58 | inp_config = AutoConfig.from_pretrained(inp_model_dir)
59 | inp_model = AutoModelForSeq2SeqLM.from_pretrained(inp_model_dir, config=inp_config)
60 | inp_model.cuda().eval()
61 | no_inp_tokenizer = AutoTokenizer.from_pretrained(no_inp_model_dir)
62 | no_inp_config = AutoConfig.from_pretrained(no_inp_model_dir)
63 | no_inp_model = AutoModelForSeq2SeqLM.from_pretrained(no_inp_model_dir, config=no_inp_config)
64 | no_inp_model.cuda().eval()
65 |
66 | info_gain_tokenizer = AutoTokenizer.from_pretrained(info_gain_model_dir)
67 | info_gain_config = AutoConfig.from_pretrained(info_gain_model_dir)
68 | info_gain_model = AutoModelForSeq2SeqLM.from_pretrained(info_gain_model_dir, config=info_gain_config)
69 | info_gain_model.cuda().eval()
70 | info_gain_mname = 'gpt2'
71 |
72 |
73 | def init_gpt2():
74 | ll_tokenizer = AutoTokenizer.from_pretrained('gpt2')
75 | ll_model = AutoModelForCausalLM.from_pretrained('gpt2-xl')
76 | ll_model.eval().cuda()
77 | ll_tokenizer.padding_side = "left"
78 | ll_tokenizer.pad_token = ll_tokenizer.eos_token
79 | ll_model.config.pad_token_id = ll_model.config.eos_token_id
80 | return ll_model, ll_tokenizer
81 |
82 | def init_t5():
83 | ll_tokenizer = AutoTokenizer.from_pretrained("t5-large")
84 | config = AutoConfig.from_pretrained("t5-large")
85 | ll_model = AutoModelForSeq2SeqLM.from_pretrained("t5-large", config=config)
86 | ll_model.eval().cuda()
87 | ll_tokenizer.pad_token = pad_token
88 |     ll_model.config.pad_token_id = ll_tokenizer.pad_token_id
89 | return ll_model, ll_tokenizer
90 |
91 | # For Info-Gain PVI
92 | if info_gain_mname == 'gpt2': ll_model, ll_tokenizer = init_gpt2()
93 | elif info_gain_mname == 't5-large': ll_model, ll_tokenizer = init_t5()
94 |
95 | def obtain_entailment_scores(premise, hypothesis):
96 | input = ent_tokenizer(premise, hypothesis, truncation=True, return_tensors="pt").to(device)
97 | with torch.no_grad():
98 | output = ent_model(input["input_ids"].to(device)) # device = "cuda:0" or "cpu"
99 | prediction = torch.softmax(output["logits"][0], -1).tolist()
100 | label_names = ["entailment", "neutral", "contradiction"]
101 | prediction = {name: float(pred) for pred, name in zip(prediction, label_names)}
102 | return prediction['entailment']
103 |
104 | def obtain_contradiction_scores(premise, hypothesis):
105 | input = ent_tokenizer(premise, hypothesis, truncation=True, return_tensors="pt").to(device)
106 | with torch.no_grad():
107 | output = ent_model(input["input_ids"].to(device)) # device = "cuda:0" or "cpu"
108 | prediction = torch.softmax(output["logits"][0], -1).tolist()
109 | label_names = ["entailment", "neutral", "contradiction"]
110 | prediction = {name: float(pred) for pred, name in zip(prediction, label_names)}
111 | return prediction['contradiction']
112 |
113 | def obtain_unit_entailment_score(prem_units, conc_units):
114 | if len(prem_units):
115 | premise = ' and '.join(prem_units)
116 | hypothesis = ' and '.join(conc_units)
117 | score = obtain_entailment_scores(premise, hypothesis)
118 | else:
119 | score = 1
120 | return score
121 |
122 | def obtain_contradiction_score(prem_units, conc_units):
123 | pair_scores = []
124 | hypothesis = ' and '.join(conc_units)
125 | for premise in prem_units:
126 | pair_scores.append(obtain_contradiction_scores(premise, hypothesis))
127 | if len(pair_scores):
128 | score = 1 - max(pair_scores)
129 | else:
130 | score = 1
131 |
132 | return score
133 |
134 |
135 | def detokenize(tokens):
136 | return TreebankWordDetokenizer().detokenize(tokens)
137 |
138 | def verb_modifiers(desc):
139 | filtered_mods = []
140 | mods = re.findall(r"\[ARGM.*?\]", desc)
141 | if not len(mods): return filtered_mods
142 | for mod in mods:
143 | phrase = mod.split(': ')[1].rstrip(']')
144 | verb_match = ['VB' in k[1] for k in nltk.pos_tag(word_tokenize(phrase))]
145 | if sum(verb_match) and len(phrase.split()) > 2: filtered_mods.append(phrase) # put in a length criteria
146 | return filtered_mods
147 |
148 | def remove_modifiers(sent, modifiers):
149 | if not len(modifiers): return sent
150 | for mod in modifiers:
151 | sent = sent.replace(mod, "")
152 | sent = re.sub(' +', ' ', sent) # remove any double spaces
153 | sent = sent.strip(string.punctuation + ' ') # remove stray punctuations
154 | return sent
155 |
156 | def extract_frame(tags, words, desc):
157 | prev = 'O'
158 | start, end = None, None
159 | if len(set(tags)) == 1: return ''
160 | tags = [t if 'C-ARG' not in t else 'O' for t in tags] #check if the modifier is a verb phrase
161 | for w in range(len(words)):
162 | if 'B-' in tags[w] and start is None: start = w
163 | if tags[len(words) - w -1]!='O' and end is None: end = len(words) - w -1
164 |
165 | if end is None: end = start
166 | sent = detokenize(words[start: end + 1]).rstrip('.')
167 | return sent
168 |
169 |
170 | def get_phrases(sent):
171 | # Simple RCU extractor without conjunction check for premises
172 | phrases = []
173 | history = ''
174 | srl_out = predictor.predict(sent)
175 | words = srl_out['words']
176 | frames = [s['tags'] for s in srl_out['verbs']]
177 | descs = [s['description'] for s in srl_out['verbs']]
178 | mod_sent = detokenize(words).rstrip('.')
179 | for frame, desc in zip(frames, descs):
180 | phrase = extract_frame(frame, words, desc)
181 | if phrase == mod_sent: phrase = remove_modifiers(phrase, verb_modifiers(desc))
182 | phrases.append(phrase)
183 | phrases.sort(key=lambda s: len(s), reverse=True)
184 | filtered_phrases = []
185 | for p in phrases:
186 | if p not in history:
187 | history += ' ' + p
188 | filtered_phrases.append(p)
189 | if len(filtered_phrases):
190 | filtered_phrases.sort(key=lambda s: mod_sent.find(s))
191 | left = mod_sent
192 | mod_filt = False
193 | for fp in filtered_phrases: left = left.replace(fp, '#').strip(string.punctuation + ' ')
194 | for l in left.split('#'):
195 | l = l.strip(string.punctuation + ' ')
196 | if len(l.split()) >=4 and l not in " ".join(filtered_phrases):
197 | verb_match = ['VB' in k[1] for k in nltk.pos_tag(word_tokenize(l))]
198 | if sum(verb_match):
199 | filtered_phrases.append(l)
200 | mod_filt = True
201 | if mod_filt: filtered_phrases.sort(key=lambda s: mod_sent.find(s))
202 | return filtered_phrases
203 | else: return [sent.rstrip('.')]
204 |
205 | def get_sent_phrases(para):
206 | sentences = sent_tokenize(para)
207 | phrases = []
208 | for sent in sentences:
209 | phrases.extend(get_phrases(sent))
210 | return phrases
211 |
212 | def get_reasoning_chain_text(steps, sentences):
213 | # If using the reasoning trees directly
214 | step_texts = []
215 | covered_nodes = []
216 | for step in steps:
217 | parent_text = " and ".join([sentences[p] for p in step['parents'] if p not in covered_nodes])
218 | if len(parent_text): step_text = parent_text + ', so ' + sentences[step['child']] + "."
219 | else: step_text = 'so ' + sentences[step['child']] + '.'
220 | covered_nodes.extend(step['parents']); covered_nodes.append(step['child'])
221 | step_texts.append(step_text)
222 | return step_texts
223 |
224 |
225 | def preprocess_and_convert(premise_units, conc_units):
226 | data = {'inputs': [], 'labels': []}
227 | parent_text = " & ".join(premise_units) + ' ->'
228 | child_text = " " + conc_units[0] # assume just one conc unit
229 | data['inputs'].append(parent_text)
230 | data['labels'].append(child_text)
231 | return data
232 |
233 | def postprocess_test_data(examples):
234 | inputs = [prefix + text for text in examples['inputs']]
235 | model_inputs = inp_tokenizer(inputs, max_length=max_input_length, padding=padding, truncation=True, return_tensors="pt")
236 |
237 | # Setup the tokenizer for targets
238 | with inp_tokenizer.as_target_tokenizer():
239 | targets = [pad_token + label for label in examples['labels']]
240 | labels = inp_tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True, return_tensors="pt")
241 | model_inputs["decoder_input_ids"] = labels["input_ids"]
242 | model_inputs["decoder_attention_mask"] = labels["attention_mask"]
243 | return model_inputs
244 |
245 | def noinp_postprocess_test_data(examples):
246 | inputs = [prefix + 'None ->' for text in examples['inputs']]
247 | model_inputs = no_inp_tokenizer(inputs, max_length=max_input_length, padding=padding, truncation=True, return_tensors="pt")
248 |
249 | # Setup the tokenizer for targets
250 | with no_inp_tokenizer.as_target_tokenizer():
251 | targets = [pad_token + label for label in examples['labels']]
252 | labels = no_inp_tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True, return_tensors="pt")
253 | model_inputs["decoder_input_ids"] = labels["input_ids"]
254 | model_inputs["decoder_attention_mask"] = labels["attention_mask"]
255 | return model_inputs
256 |
257 | def obtain_log_prob(predict_dataset, model, tokenizer):
258 | logits = model(input_ids=torch.Tensor(predict_dataset['input_ids']).long().cuda(), attention_mask=torch.Tensor(predict_dataset['attention_mask']).long().cuda(),
259 | decoder_input_ids=torch.Tensor(predict_dataset['decoder_input_ids']).long().cuda(), decoder_attention_mask = torch.Tensor(predict_dataset['decoder_attention_mask']).long().cuda()).logits
260 | all_logprobs = torch.log(torch.softmax(logits, dim=-1))
261 | labels = tokenizer(predict_dataset['labels'], max_length=max_target_length).input_ids
262 | filter_sums = []
263 | for row, label in zip(all_logprobs, labels):
264 | label.pop()
265 | row = row[:len(label), :].detach().cpu().numpy()
266 | vocab_size = row.shape[-1]
267 | loc = F.one_hot(torch.tensor(label), num_classes=vocab_size).numpy().astype(bool)
268 | try: summed_logprob = np.sum(row, where = loc)
269 | except: import pdb; pdb.set_trace()
270 | filter_sums.append(summed_logprob/len(label))
271 | return np.array(filter_sums)
272 |
273 | def obtain_unit_pvi_score(premise_units, conc_units):
274 | dataset = Dataset.from_dict(preprocess_and_convert(premise_units, conc_units))
275 | inp_dataset = dataset.map(postprocess_test_data, batched=True, remove_columns=['inputs'])
276 | no_inp_dataset = dataset.map(noinp_postprocess_test_data, batched=True, remove_columns=['inputs'])
277 | inp_logprob = obtain_log_prob(inp_dataset, inp_model, inp_tokenizer)[0]
278 | no_inp_logprob = obtain_log_prob(no_inp_dataset, no_inp_model, no_inp_tokenizer)[0]
279 | return inp_logprob - no_inp_logprob
280 |
281 |
282 | def slice_select_logits(all_logprobs, label):
283 | filter_sums = []
284 | if info_gain_mname == 'gpt2': row = all_logprobs[-len(label):, :].detach().cpu().numpy()
285 | elif info_gain_mname == 't5-large': row = all_logprobs[:len(label), :].detach().cpu().numpy()
286 | vocab_size = row.shape[-1]
287 | loc = F.one_hot(torch.tensor(label), num_classes=vocab_size).numpy().astype(bool)
288 | try: summed_logprob = np.sum(row, where = loc)
289 | except: import pdb; pdb.set_trace()
290 | filter_sums.append(summed_logprob/len(label))
291 | return np.array(filter_sums)
292 |
293 | def obtain_info_gain_score(prev_steps, current_step, current_conc, target, info_model, info_tokenizer):
294 | if info_gain_mname == 't5-large':
295 | target = pad_token + target
296 | input = " ".join(prev_steps + [current_step]) + ' Therefore,' + target
297 | if len(prev_steps): ref_input = " ".join(prev_steps) + ' Therefore,' + target
298 | else: ref_input = 'Therefore,' + target
299 |
300 | inputs = info_tokenizer(input, return_tensors="pt")
301 | ref = info_tokenizer(ref_input, return_tensors="pt")
302 | labels = info_tokenizer(target, return_tensors="pt")
303 | inputs["decoder_input_ids"] = labels['input_ids']
304 | inputs['decoder_attention_mask'] = labels['attention_mask']
305 | ref["decoder_input_ids"] = labels['input_ids']
306 | ref['decoder_attention_mask'] = labels['attention_mask']
307 |
308 | for i in inputs: inputs[i] = inputs[i].cuda()
309 | for i in ref: ref[i] = ref[i].cuda()
310 | with torch.no_grad():
311 | inp_logits = info_model.forward(**inputs).logits.detach().cpu()
312 | ref_logits = info_model.forward(**ref).logits.detach().cpu()
313 | all_inp_logprobs = torch.log(torch.softmax(inp_logits, dim=-1))
314 | all_ref_logprobs = torch.log(torch.softmax(ref_logits, dim=-1))
315 | labels = labels.input_ids.detach().cpu().tolist()[0][1:]
316 | filtered_inp_logprobs = slice_select_logits(all_inp_logprobs[0,:,:], labels)[0]
317 | filtered_ref_logprobs = slice_select_logits(all_ref_logprobs[0,:,:], labels)[0]
318 |
319 | elif info_gain_mname == 'gpt2':
320 | target = " " + target
321 | input = " " + " ".join(prev_steps + [current_step]) + ' Therefore,' + target
322 | if len(prev_steps): ref_input = " " + " ".join(prev_steps) + ' Therefore,' + target
323 | else: ref_input = ' Therefore,' + target
324 | labels = info_tokenizer(target).input_ids
325 | input_ids = info_tokenizer(input, return_tensors="pt").input_ids.cuda()
326 | ref_input_ids = info_tokenizer(ref_input, return_tensors="pt").input_ids.cuda()
327 | with torch.no_grad():
328 | inp_logits = info_model.forward(input_ids=input_ids, return_dict=True).logits.detach().cpu()
329 | ref_logits = info_model.forward(input_ids=ref_input_ids, return_dict=True).logits.detach().cpu()
330 | all_inp_logprobs = torch.log(torch.softmax(inp_logits, dim=-1))
331 | all_ref_logprobs = torch.log(torch.softmax(ref_logits, dim=-1))
332 | filtered_inp_logprobs = slice_select_logits(all_inp_logprobs[0,:-1,:], labels)[0] #shift probability since at idx i produce distribution of tokens at i+1
333 | filtered_ref_logprobs = slice_select_logits(all_ref_logprobs[0,:-1,:], labels)[0]
334 | return (filtered_inp_logprobs - filtered_ref_logprobs)
335 |
336 |
337 | source_path = 'perturbed_trees'
338 | error_types = os.listdir(source_path)
339 | score_keys = ['entail', 'pvi', 'contradict', 'll-info', 'pvi-info']  # all available metrics
340 | score_keys = ['ll-info']  # restrict to the LL-based info-gain metric; comment this out to compute all of the above
341 | errors_correl = {k:{'somersd':{}, 'pearson':{}} for k in score_keys}
342 | K = 0 # Set how many past steps to look at.
343 |
344 | for error in error_types:
345 | print(error)
346 | epath = os.path.join(source_path, error)
347 | tree_entry = [json.loads(line) for line in open(epath, 'r')]
348 | local_ent_scores = []
349 | alt_local_ent_scores = []
350 | local_pvi_scores = []
351 | global_contradict_scores = []
352 |     info_ll_scores, info_pvi_scores = [], []
353 | for t, entry in tqdm_copy(enumerate(tree_entry)):
354 | # Tree-based Evaluation from EB
355 | # Otherwise, directly iterate over reasoning problems and directly get sentences
356 | # Hypothesis for GSM-8K is concat question and answer
357 | input_context = entry['question']
358 | input_context_sentences = sent_tokenize(input_context)
359 | steps, sentences = entry['steps']['perturbed'], entry['sentences']['perturbed']
360 | reasoning_steps = get_reasoning_chain_text(steps, sentences)
361 | # reasoning_steps = sent_tokenize(entry['steps'])
362 | # Needed keys are: question (input_context), steps, hypothesis
363 | step_ent_scores = []
364 | alt_step_ent_scores = []
365 | step_pvi_scores = []
366 | step_contradict_scores = []
367 | step_redundancy_scores = []
368 | step_ll_scores = []
369 | step_pviinfo_scores = []
370 | running_conc = []
371 | for sid, step in enumerate(reasoning_steps):
372 | units = get_phrases(step)
373 | premise_units, conc_units = [], []
374 | premise_units.extend(units[:-1])
375 | conc_units.append(units[-1])
376 |
377 | # Entail Step Calculation
378 | if 'entail' in score_keys:
379 | alt_step_ent_scores.append(obtain_unit_entailment_score(premise_units + running_conc[-1*K:], conc_units))
380 | # Intra-Step PVI Calculation
381 | if 'pvi' in score_keys:
382 | step_pvi_scores.append(obtain_unit_pvi_score(premise_units + running_conc[-1*K:], conc_units))
383 | # Global Contradiction Check
384 | if 'contradict' in score_keys:
385 | step_contradict_scores.append(obtain_contradiction_score(input_context_sentences + running_conc, conc_units))
386 |
387 | # LL Informativeness Check
388 | if 'll-info' in score_keys:
389 | step_ll_scores.append(obtain_info_gain_score(reasoning_steps[:sid], step, conc_units, sentences['hypothesis'], ll_model, ll_tokenizer))
390 |
391 | # PVI Informativeness Check
392 | if 'pvi-info' in score_keys:
393 | step_pviinfo_scores.append(obtain_info_gain_score(reasoning_steps[:sid], step, conc_units, sentences['hypothesis'], info_gain_model, info_gain_tokenizer))
394 | running_conc.extend(conc_units)
395 |
396 | if 'entail' in score_keys: alt_local_ent_scores.append(min(alt_step_ent_scores))
397 | if 'pvi' in score_keys: local_pvi_scores.append(min(step_pvi_scores))
398 | if 'contradict' in score_keys: global_contradict_scores.append(min(step_contradict_scores))
399 | if 'll-info' in score_keys: info_ll_scores.append(min(step_ll_scores))
400 | if 'pvi-info' in score_keys: info_pvi_scores.append(min(step_pviinfo_scores))
401 | perturbed_ids = [1 - int(e['perturbed']) for e in tree_entry]
402 |
403 | if 'entail' in score_keys: print(somersd(perturbed_ids, alt_local_ent_scores).statistic)
404 | if 'pvi' in score_keys: print(somersd(perturbed_ids, local_pvi_scores).statistic)
405 | if 'contradict' in score_keys: print(somersd(perturbed_ids, global_contradict_scores).statistic)
406 | if 'll-info' in score_keys: print(somersd(perturbed_ids, info_ll_scores).statistic)
407 | if 'pvi-info' in score_keys: print(somersd(perturbed_ids, info_pvi_scores).statistic)
408 |
409 | if 'entail' in score_keys:
410 | errors_correl['entail']['pearson'][error] = np.corrcoef(alt_local_ent_scores, perturbed_ids)[0][1]
411 | errors_correl['entail']['somersd'][error] = somersd(perturbed_ids, alt_local_ent_scores).statistic
412 | if 'pvi' in score_keys:
413 | errors_correl['pvi']['pearson'][error] = np.corrcoef(local_pvi_scores, perturbed_ids)[0][1]
414 | errors_correl['pvi']['somersd'][error] = somersd(perturbed_ids, local_pvi_scores).statistic
415 | if 'contradict' in score_keys:
416 | errors_correl['contradict']['pearson'][error] = np.corrcoef(global_contradict_scores, perturbed_ids)[0][1]
417 | errors_correl['contradict']['somersd'][error] = somersd(perturbed_ids, global_contradict_scores).statistic
418 | if 'pvi-info' in score_keys:
419 |         errors_correl['pvi-info']['pearson'][error] = np.corrcoef(info_pvi_scores, perturbed_ids)[0][1]
420 |         errors_correl['pvi-info']['somersd'][error] = somersd(perturbed_ids, info_pvi_scores).statistic
421 | if 'll-info' in score_keys:
422 | errors_correl['ll-info']['pearson'][error] = np.corrcoef(info_ll_scores, perturbed_ids)[0][1]
423 | errors_correl['ll-info']['somersd'][error] = somersd(perturbed_ids, info_ll_scores).statistic
424 |
425 | os.makedirs('ResultLogs', exist_ok=True)  # make sure the output directory exists
426 | json.dump(errors_correl, open('ResultLogs/correlations.json', 'w'), indent=4)
427 |
--------------------------------------------------------------------------------
/perturb_EB.py:
--------------------------------------------------------------------------------
1 | import graphviz
2 | from graphviz import Digraph
3 | import json
4 | import re
5 | import random
6 | from copy import deepcopy
7 | from checklist.perturb import Perturb
8 | import spacy
9 | nlp = spacy.load('en_core_web_sm')
10 | import os
11 | import jsonlines
12 | from tqdm import tqdm
13 | from transformers import PegasusForConditionalGeneration, PegasusTokenizer
14 | import torch  # used below to pick the device for the paraphrase model
15 | gold_path = 'entailment_bank/data/public_dataset/entailment_trees_emnlp2021_data_v2/dataset/task_1/test.jsonl'
16 | more_path = 'entailment_bank/data/public_dataset/entailment_trees_emnlp2021_data_v2/dataset/task_2/test.jsonl'
17 | preds_path = 'entailment_bank/data/processed_data/predictions/emnlp_2021/task1/T5_11B/test.16K_steps.predictions.tsv'
18 |
19 | gold_examples = [json.loads(line) for line in open(gold_path, 'r')]
20 | more_examples = [json.loads(line) for line in open(more_path, 'r')]
21 | pred_examples = [line.split('=')[-1].strip() for line in open(preds_path, 'r')]
22 |
23 | gold_examples = [g for g in gold_examples if g['proof'].count(';') > 1]
24 | more_examples = [m for m in more_examples if m['proof'].count(';') > 1]
25 |
26 | para_model_name = 'tuner007/pegasus_paraphrase'
27 | torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
28 | para_tokenizer = PegasusTokenizer.from_pretrained(para_model_name)
29 | para_model = PegasusForConditionalGeneration.from_pretrained(para_model_name).to(torch_device).eval()
30 | # Can replace with alternate paraphrase models
31 |
32 | def get_response(input_text,num_return_sequences,num_beams):
33 | batch = para_tokenizer([input_text],truncation=True,padding='longest',max_length=60, return_tensors="pt").to(torch_device)
34 | translated = para_model.generate(**batch,max_length=60,num_beams=num_beams, num_return_sequences=num_return_sequences, temperature=1.5)
35 | tgt_text = para_tokenizer.batch_decode(translated, skip_special_tokens=True)
36 | return tgt_text
37 |
38 | def obtain_paraphrase(phrase):
39 | num_beams = 5
40 | num_return_sequences = 5
41 | paraphrases = get_response(phrase, num_return_sequences, num_beams)
42 |     paraphrase = random.choice(paraphrases)  # pick one of the returned paraphrases
43 | paraphrase = paraphrase.strip('.')
44 | return paraphrase
45 |
46 | def process(proof, alt_triples, hypothesis):
47 | proof = proof.rstrip('; ')
48 | steps = proof.split(';')
49 | step_list = []
50 | triples = deepcopy(alt_triples)
51 | for step in steps:
52 | [parents, leaf] = step.split('->')
53 | parents = parents.split('&')
54 | parents = [p.strip(' ') for p in parents]
55 | leaf = leaf.strip()
56 | leafId = leaf.split(':')[0]
57 | if 'int' in leafId:
58 | leaf_sent = leaf.split(':')[-1].strip()
59 | triples[leafId] = leaf_sent
60 | elif leafId == 'hypothesis':
61 | leaf_sent = hypothesis
62 | triples[leafId] = leaf_sent
63 | step_list.append({'parents':parents, 'child': leafId})
64 | proof = proof + '; '
65 | return step_list, triples
66 |
67 | def reconstruct_proof(steps, sentences):
68 | proof = ''
69 | for step in steps:
70 | proof += " & ".join(step['parents']) + ' -> ' + step['child']
71 | if 'int' in step['child']:
72 | proof += ': ' + sentences[step['child']]
73 | proof += '; '
74 | return proof
75 |
76 | def repeat_steps(in_steps, in_sentences):
77 | steps = deepcopy(in_steps)
78 | sentences = deepcopy(in_sentences)
79 | int_idxs = [s for s in range(len(steps)) if 'int' in steps[s]['child']]
80 | idx = random.choice(int_idxs)
81 | assert idx < len(steps) - 1
82 | repeated_node = steps[idx]['child']
83 | print(idx, repeated_node)
84 | key = 'int' + str(len(int_idxs) + 1)
85 | sentences[key] = sentences[repeated_node]
86 | steps[idx]['child'] = key
87 | steps.insert(idx + 1, {'parents': [key], 'child': repeated_node})
88 | return steps, sentences
89 |
90 | def delete_steps(in_steps, in_sentences):
91 | steps = deepcopy(in_steps)
92 | sentences = deepcopy(in_sentences)
93 | int_idxs = [s for s in range(len(steps)) if 'int' in steps[s]['child']]
94 | idx = random.choice(int_idxs)
95 | assert idx < len(steps) - 1
96 | del_node = steps[idx]['child']
97 | del_parents = steps[idx]['parents']
98 | del steps[idx]
99 | for step in steps:
100 | if del_node in step['parents']:
101 | step['parents'].extend(del_parents)
102 | step['parents'] = [p for p in step['parents'] if p != del_node]
103 | return steps, sentences
104 |
105 | def swapped_steps(in_steps, in_sentences):
106 | steps = deepcopy(in_steps)
107 | sentences = deepcopy(in_sentences)
108 | int_idxs = [s for s in range(len(steps)) if 'int' in steps[s]['child']]
109 | idx = random.choice(int_idxs)
110 | assert idx < len(steps) - 1
111 | swap_node = deepcopy(steps[idx]['child'])
112 | swap_parent = random.choice(steps[idx]['parents'])
113 | print(swap_node, swap_parent)
114 | alt_parents = [p for p in steps[idx]['parents'] if p!= swap_parent]
115 | alt_parents.append(swap_node)
116 | steps[idx]['parents'] = alt_parents
117 | steps[idx]['child'] = swap_parent
118 | for step in steps[idx + 1:]:
119 | if swap_node in step['parents']:
120 | step['parents'].append(swap_parent)
121 | step['parents'] = [p for p in step['parents'] if p != swap_node]
122 | return steps, sentences
123 |
124 | def negate_step(in_steps, in_sentences):
125 | steps = deepcopy(in_steps)
126 | sentences = deepcopy(in_sentences)
127 | int_idxs = [s for s in range(len(steps)) if 'int' in steps[s]['child']]
128 | idx = random.choice(int_idxs)
129 | assert idx < len(steps) - 1
130 | negated_node = steps[idx]['child']
131 | print(negated_node, sentences[negated_node])
132 | sentences[negated_node] = Perturb.add_negation(nlp(sentences[negated_node]))
133 | return steps, sentences
134 |
135 | def hallucinate_step(in_steps, in_sentences, extra_sentences):
136 | steps = deepcopy(in_steps)
137 | sentences = deepcopy(in_sentences)
138 | int_idxs = [s for s in range(len(steps)) if 'int' in steps[s]['child']]
139 | idx = random.choice(int_idxs)
140 | assert idx < len(steps) - 1
141 | hallucinate_node = steps[idx]['child']
142 | print(hallucinate_node, sentences[hallucinate_node])
143 | potential_sents = [v for (k,v) in extra_sentences.items() if v not in in_sentences.values()]
144 | sentences[hallucinate_node] = random.choice(potential_sents)
145 | print(sentences[hallucinate_node])
146 | return steps, sentences
147 |
148 | def paraphrase_steps(in_steps, in_sentences):
149 | steps = deepcopy(in_steps)
150 | sentences = deepcopy(in_sentences)
151 | int_idxs = [s for s in range(len(steps)) if 'int' in steps[s]['child']]
152 | idx = random.choice(int_idxs)
153 | assert idx < len(steps) - 1
154 | repeated_node = steps[idx]['child']
155 | print(idx, repeated_node)
156 | key = 'int' + str(len(int_idxs) + 1)
157 | sentences[key] = obtain_paraphrase(sentences[repeated_node])
158 | steps[idx]['child'] = key
159 | steps.insert(idx + 1, {'parents': [key], 'child': repeated_node})
160 | return steps, sentences
161 |
162 | def redundant_steps(in_steps, in_sentences, extra_sentences):
163 | steps = deepcopy(in_steps)
164 | sentences = deepcopy(in_sentences)
165 | int_idxs = [s for s in range(len(steps)) if 'int' in steps[s]['child']]
166 | idx = random.choice(int_idxs)
167 | assert idx < len(steps) - 1
168 | repeated_node = steps[idx]['child']
169 | print(idx, repeated_node)
170 | key = 'int' + str(len(int_idxs) + 1)
171 | potential_sents = [v for (k,v) in extra_sentences.items() if v not in in_sentences.values()]
172 | sentences[key] = random.choice(potential_sents)
173 | steps[idx]['child'] = key
174 | steps.insert(idx + 1, {'parents': [key], 'child': repeated_node})
175 | return steps, sentences
176 |
177 |
178 |
179 | tree_dest_path = 'perturbed_trees'
180 | if not os.path.exists(tree_dest_path): os.makedirs(tree_dest_path)
181 | random.seed(0)
182 |
183 | unperturbed_path = 'ParlAI/projects/roscoe/roscoe_data/unperturbed_ids.json'
184 | custom_unperturbed_ids = {'entailment_bank_synthetic':{}}
185 |
186 | perturbation_functions = {'DuplicateOneStep': repeat_steps, 'ParaphraseOneStep': paraphrase_steps, 'RedundantOneStep': redundant_steps, 'RemoveOneStep': delete_steps, 'SwapOneStep': swapped_steps, 'NegateStep': negate_step, 'ExtrinsicHallucinatedStep': hallucinate_step}
187 |
188 | for perturb_type, perturb_func in tqdm(perturbation_functions.items()):
189 | type_unperturbed_ids = random.choices(range(len(gold_examples)), k=126) # unperturbed_ids[perturb_type + "_test.jsonl"]
190 | custom_unperturbed_ids[perturb_type + "_test.jsonl"] = type_unperturbed_ids
191 | fname = '50%_' + perturb_type + '_test.jsonl'
192 | revised_entry = []
193 | tree_entry = []
194 | for id, gold_example, more_example in zip(range(len(gold_examples)), gold_examples, more_examples):
195 | steps, sentences = process(gold_example['proof'], gold_example['meta']['triples'], gold_example['hypothesis'])
196 | question = ". ".join(list(gold_example['meta']['triples'].values()) + [gold_example['question']])
197 | answer = gold_example['answer']
198 | int_idxs = [s for s in range(len(steps)) if 'int' in steps[s]['child']]
199 | if not len(int_idxs): continue
200 |
201 | if id in type_unperturbed_ids:
202 | perturbed_steps = steps
203 | perturbed_sentences = sentences
204 | perturbed = False
205 | else:
206 | more_sentences = more_example['meta']['triples']
207 | perturbed_steps, perturbed_sentences = perturb_func(steps, sentences, more_sentences)
208 | if steps == perturbed_steps and sentences == perturbed_sentences:
209 | perturbed = False
210 | custom_unperturbed_ids[perturb_type + "_test.jsonl"].append(id)
211 | else: perturbed = True
212 |         tree_entry.append({'perturbed': perturbed, 'perturbations': perturb_type, 'steps':{'original': steps, 'perturbed': perturbed_steps}, 'sentences':{'original': sentences, 'perturbed': perturbed_sentences}, 'question': question, 'answer': answer})
213 | with jsonlines.open(os.path.join(tree_dest_path, fname), 'w') as writer:
214 | writer.write_all(tree_entry)
215 |
216 |
217 |
218 |
219 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.18.0
2 | allennlp==2.10.1
3 | allennlp-models==2.10.1
4 | -e git+https://github.com/marcotcr/checklist.git@3edd07c9a84e6c6657333450d4d0e70ecb0c00d9#egg=checklist
5 | datasets==2.7.1
6 | en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.4.1/en_core_web_sm-3.4.1-py3-none-any.whl
7 | evaluate==0.4.0
8 | fairscale==0.4.6
9 | graphviz @ file:///opt/conda/conda-bld/python-graphviz_1660063183935/work
10 | huggingface-hub==0.10.1
11 | jsonlines==3.1.0
12 | matplotlib==3.6.2
13 | nltk==3.8
14 | numpy==1.20.3
15 | pandas==1.5.2
16 | protobuf==3.19.6
17 | py-rouge==1.1
18 | rouge-score==0.1.2
19 | scikit-learn==1.2.0
20 | scipy==1.9.3
21 | sentence-transformers==2.2.2
22 | sentencepiece==0.1.97
23 | spacy==3.4.0
24 | textdistance==4.5.0
25 | tokenizers==0.12.1
26 | torch==1.10.0+cu111
27 | torchaudio==0.10.0+cu111
28 | torchvision==0.11.0+cu111
29 | tqdm==4.64.1
30 | transformers==4.20.1
--------------------------------------------------------------------------------
/run_flan.py:
--------------------------------------------------------------------------------
1 | import transformers
2 | import datasets
3 | import json
4 | import numpy as np
5 | import pdb
6 | import torch
7 | from torch.utils.data import DataLoader, Dataset
8 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
9 | import pandas as pd
10 | from datasets import load_dataset
11 | from tqdm import tqdm
12 | from nltk import sent_tokenize
13 | import re
14 | from string import punctuation
15 |
16 | zero_shot_instruction = "Answer the following question by reasoning step-by-step.\n" #Write the answer as a separate sentence starting with 'The answer is'.\n"
17 | few_shot_prefix = """Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
18 | A: We start with 15 trees. Later we have 21 trees. The difference must be the number of trees they planted. So, they must have planted 21 - 15 = 6 trees. The answer is 6.\n
19 | Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
20 | A: There are 3 cars in the parking lot already. 2 more arrive. Now there are 3 + 2 = 5 cars. The answer is 5.\n
21 | Q: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?
22 | A: Leah had 32 chocolates and Leah’s sister had 42. That means there were originally 32 + 42 = 74 chocolates. 35 have been eaten. So in total they still have 74 - 35 = 39 chocolates. The answer is 39.\n
23 | Q: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?
24 | A: Jason had 20 lollipops. Since he only has 12 now, he must have given the rest to Denny. The number of lollipops he has given to Denny must have been 20 - 12 = 8 lollipops. The answer is 8.\n
25 | """
26 | # Examples that won't fit in context from CoT paper
27 | # """
28 | # Q: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does
29 | # he have now?
30 | # A: He has 5 toys. He got 2 from mom, so after that he has 5 + 2 = 7 toys. Then he got 2 more from dad, so in total he has 7 + 2 = 9 toys. The answer is 9.\n
31 | # Q: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?
32 | # A: There are 4 days from monday to thursday. 5 computers were added each day. That means in total 4 * 5 = 20 computers were added. There were 9 computers in the beginning, so now there are 9 + 20 = 29 computers. The answer is 29.\n
33 | # Q: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?
34 | # A: Michael initially had 58 balls. He lost 23 on Tuesday, so after that he has 58 - 23 = 35 balls. On Wednesday he lost 2 more so now he has 35 - 2 = 33 balls. The answer is 33.\n
35 | # Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
36 | # A: She bought 5 bagels for $3 each. This means she spent 5 * $3 = $15 on the bagels. She had $23 in beginning, so now she has $23 - $15 = $8. The answer is 8.\n"""
37 |
38 | def apply_zs_prompt(example, instruction=zero_shot_instruction):
39 | return instruction + '\nQ: ' + example['question'] + '\n' + 'A:'
40 |
41 | def apply_fs_prompt(example, instruction=zero_shot_instruction):
42 | return few_shot_prefix + '\n' + instruction + '\nQ: ' + example['question'] + '\n' + 'A:'
43 |
44 | def generate_math_prompt(prefix, question, instruction):
45 | prompt = ''
46 | prompt += prefix
47 | if len(instruction):
48 | prompt += instruction + '\n'
49 | prompt += question
50 | return prompt
51 |
52 | def extract_answer(answer_text, answer_prefix='answer is ', no_prefix=False, overfit=False):
53 | sentences = sent_tokenize(answer_text)
54 | ans_candidate = sentences[-1]
55 | found = True
56 | if not no_prefix and answer_prefix in ans_candidate:
57 | answer = ans_candidate.partition(answer_prefix)[2].strip(punctuation)
58 | try:
59 | return float(answer)
60 | except: found = False
61 | else: found = False
62 | if no_prefix:
63 | answers = re.findall(r'\d+', ans_candidate)
64 | if len(answers):
65 | return float(answers[-1])
66 | else: found = False
67 | if not found:
68 | if not overfit: return None
69 | else:
70 | answer = ans_candidate.partition('=')[2].strip(punctuation)
71 | try:
72 | return float(answer)
73 | except:
74 | return None
75 |
76 | def extract_gold_answer(ans):
77 | return float(ans.partition('###')[2].replace(',', '').strip(punctuation))
78 |
79 |
80 | dataset = load_dataset("gsm8k", 'main', split='test')
81 |
82 | dataset = dataset.map(lambda example: {'gold': extract_gold_answer(example['answer'])})
83 | dataset = dataset.map(lambda example: {'input_prompt': apply_zs_prompt(example)})
84 | data_loader = DataLoader(dataset, batch_size=4, shuffle=False)
85 |
86 | num_gen_tokens = 128
87 |
88 | tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xxl")
89 | tokenizer.padding_side = "left"
90 | tokenizer.pad_token = tokenizer.eos_token
91 | if torch.cuda.is_available():
92 | model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-xxl", device_map="auto")
93 | else:
94 | model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-xxl")
95 |
96 |
97 | count = 0
98 | correct = 0
99 |
100 | for batch in tqdm(data_loader):
101 | input_text = batch['input_prompt']
102 | if torch.cuda.is_available():
103 | inputs = tokenizer(input_text, return_tensors="pt", padding=True).to("cuda")
104 | else:
105 | inputs = tokenizer(input_text, return_tensors="pt", padding=True)
106 | outputs = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], max_new_tokens=num_gen_tokens)
107 | op_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
108 | op_answers = [extract_answer(op, no_prefix=True, overfit=True) for op in op_texts]
109 | correct += sum(np.array(op_answers) == np.array(batch['gold']))
110 | count += len(batch['gold'])
111 | print('0-shot')
112 | print(round(100*correct/count, 2))
113 |
114 |
--------------------------------------------------------------------------------
/train_infogain_pvi.py:
--------------------------------------------------------------------------------
1 | import os
2 | import transformers
3 | from datasets import load_dataset, load_metric
4 | import json
5 | from transformers import AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, EarlyStoppingCallback
6 | import numpy as np
7 | from transformers.trainer_utils import get_last_checkpoint
8 | import torch
9 | import shutil
10 |
11 |
12 | max_input_length = 512
13 | max_target_length = 64
14 | padding = "max_length"
15 | model_name = "t5-large"
16 | label_pad_token_id = -100
17 | pad_token = ''
18 | tokenizer = AutoTokenizer.from_pretrained(model_name)
19 | config = AutoConfig.from_pretrained(model_name)
20 | model = AutoModelForSeq2SeqLM.from_pretrained(model_name, config=config)
21 | batch_size = 16
22 | output_dir = 'PVI/infogain_models/norepNLnomarkov'
23 | do_train = True
24 | do_eval = True
25 | do_predict = True
26 | global no_input
27 | no_input = False
28 | overwrite_output_dir = True
29 |
30 | def postprocess_test_data(examples):
31 | if not no_input:
32 | inputs = [prefix + text for text in examples['inputs']]
33 | else: inputs = [prefix + text for text in examples['prev_inputs']]
34 | model_inputs = tokenizer(inputs, max_length=max_input_length, padding=padding, truncation=True, return_tensors="pt")
35 |
36 | # Setup the tokenizer for targets
37 | with tokenizer.as_target_tokenizer():
38 | targets = [pad_token + label for label in examples['labels']]
39 | labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True, return_tensors="pt")
40 | model_inputs["decoder_input_ids"] = labels["input_ids"]
41 | model_inputs["decoder_attention_mask"] = labels["attention_mask"]
42 | return model_inputs
43 |
44 | def preprocess_data(examples):
45 | if not no_input:
46 | inputs = [prefix + text for text in examples['inputs']]
47 | else: inputs = [prefix + text for text in examples['prev_inputs']]
48 | model_inputs = tokenizer(inputs, max_length=max_input_length, padding=padding, truncation=True)
49 |
50 | # Setup the tokenizer for targets
51 | with tokenizer.as_target_tokenizer():
52 | labels = tokenizer(examples["labels"], max_length=max_target_length, padding=padding, truncation=True)
53 |
54 | labels["input_ids"] = [[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]]
55 | model_inputs["labels"] = labels["input_ids"]
56 | return model_inputs
57 |
58 | def compute_metrics(eval_pred):
59 | predictions, labels = eval_pred
60 | decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
61 |
62 | # Replace -100 in the labels as we can't decode them.
63 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
64 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
65 | # Compute ROUGE scores
66 | result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
67 |
68 | # Extract ROUGE f1 scores
69 | result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
70 |
71 | return {k: round(v, 2) for k, v in result.items()}
72 |
73 |
74 | prefix = 'Generate final answer: '
75 | dataset = load_dataset('json', data_files = {'train': 'PVI/IG_norepNLnomarkov_tree/train.json', 'dev': 'PVI/IG_norepNLnomarkov_tree/dev.json', 'test': 'PVI/IG_norepNLnomarkov_tree/test.json'}, field="data")
76 | tokenized_dataset = dataset.map(preprocess_data, batched=True)
77 | # predict_dataset = dataset['test'].map(postprocess_test_data, batched=True)
78 |
79 | args = Seq2SeqTrainingArguments(
80 | output_dir = output_dir,
81 | evaluation_strategy="epoch",
82 | logging_strategy="epoch",
83 | save_strategy="epoch",
84 | learning_rate=3e-5,
85 | per_device_train_batch_size=batch_size,
86 | per_device_eval_batch_size=batch_size,
87 | weight_decay=0.01,
88 | save_total_limit=1,
89 | num_train_epochs=10,
90 | predict_with_generate=True,
91 | fp16=True,
92 | load_best_model_at_end=True,
93 | metric_for_best_model="eval_rougeL",
94 | overwrite_output_dir=overwrite_output_dir,
95 | )
96 |
97 | data_collator = DataCollatorForSeq2Seq(tokenizer, model = model, label_pad_token_id=label_pad_token_id)
98 | metric = load_metric("rouge")
99 |
100 | trainer = Seq2SeqTrainer(
101 | model=model,
102 | args=args,
103 | train_dataset=tokenized_dataset["train"],
104 | eval_dataset=tokenized_dataset["dev"],
105 | data_collator=data_collator,
106 | tokenizer=tokenizer,
107 | compute_metrics=compute_metrics,
108 | callbacks = [EarlyStoppingCallback(early_stopping_patience=3)],
109 | )
110 |
111 | if do_train:
112 | checkpoint = None
113 | last_checkpoint = None
114 | if os.path.isdir(output_dir):
115 | last_checkpoint = get_last_checkpoint(output_dir)
116 |
117 | if last_checkpoint is not None:
118 | checkpoint = last_checkpoint
119 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
120 | trainer.save_model() # Saves the tokenizer too for easy upload
121 |
122 | metrics = train_result.metrics
123 | trainer.log_metrics("train", metrics)
124 | trainer.save_metrics("train", metrics)
125 | trainer.save_state()
126 |
127 | # Evaluation
128 | results = {}
129 | if do_eval:
130 | metrics = trainer.evaluate(max_length=max_target_length, num_beams=8, metric_key_prefix="eval")
131 | trainer.log_metrics("eval", metrics)
132 | trainer.save_metrics("eval", metrics)
133 |
134 | # Prediction
135 | if do_predict:
136 |     results = trainer.predict(tokenized_dataset['test'], metric_key_prefix="predict", max_length=max_target_length, num_beams=8)
137 | metrics = results.metrics
138 | trainer.log_metrics("predict", metrics)
139 | trainer.save_metrics("predict", metrics)
140 | last_checkpoint = get_last_checkpoint(output_dir)
141 | shutil.rmtree(last_checkpoint)
142 |
143 |
144 |
145 |
--------------------------------------------------------------------------------
/train_pvi.py:
--------------------------------------------------------------------------------
1 | import os
2 | import transformers
3 | from datasets import load_dataset, load_metric
4 | import json
5 | from transformers import AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, EarlyStoppingCallback
6 | import numpy as np
7 | from transformers.trainer_utils import get_last_checkpoint
8 | import torch
9 |
10 | max_input_length = 512
11 | max_target_length = 64
12 | padding = "max_length"
13 | model_name = "t5-large"
14 | label_pad_token_id = -100
15 | pad_token = ''
16 | tokenizer = AutoTokenizer.from_pretrained(model_name)
17 | config = AutoConfig.from_pretrained(model_name)
18 | model = AutoModelForSeq2SeqLM.from_pretrained(model_name, config=config)
19 | batch_size = 16
20 | output_dir = 'PVI/noinp_models/'
21 | do_train = True
22 | do_eval = True
23 | do_predict = True
24 | global no_input
25 | no_input = True
26 | overwrite_output_dir = True
27 |
28 | def postprocess_test_data(examples):
29 | if not no_input:
30 | inputs = [prefix + text for text in examples['inputs']]
31 | else: inputs = [prefix + 'None ->' for text in examples['inputs']]
32 | model_inputs = tokenizer(inputs, max_length=max_input_length, padding=padding, truncation=True, return_tensors="pt")
33 |
34 | # Setup the tokenizer for targets
35 | with tokenizer.as_target_tokenizer():
36 | targets = [pad_token + label for label in examples['labels']]
37 | labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True, return_tensors="pt")
38 | model_inputs["decoder_input_ids"] = labels["input_ids"]
39 | model_inputs["decoder_attention_mask"] = labels["attention_mask"]
40 | return model_inputs
41 |
42 | def preprocess_data(examples):
43 | if not no_input:
44 | inputs = [prefix + text for text in examples['inputs']]
45 | else: inputs = [prefix + 'None ->' for text in examples['inputs']]
46 | model_inputs = tokenizer(inputs, max_length=max_input_length, padding=padding, truncation=True)
47 |
48 | # Setup the tokenizer for targets
49 | with tokenizer.as_target_tokenizer():
50 | labels = tokenizer(examples["labels"], max_length=max_target_length, padding=padding, truncation=True)
51 |
52 | labels["input_ids"] = [[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]]
53 | model_inputs["labels"] = labels["input_ids"]
54 | return model_inputs
55 |
56 | def compute_metrics(eval_pred):
57 | predictions, labels = eval_pred
58 | decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
59 |
60 | # Replace -100 in the labels as we can't decode them.
61 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
62 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
63 |
64 | # Compute ROUGE scores
65 | result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
66 |
67 | # Extract ROUGE f1 scores
68 | result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
69 | return {k: round(v, 2) for k, v in result.items()}
70 |
71 |
72 | prefix = 'Generate entailed sentence: '
73 | dataset = load_dataset('json', data_files = {'train': 'PVI/train.json', 'dev': 'PVI/dev.json', 'test': 'PVI/test.json'}, field="data")
74 | tokenized_dataset = dataset.map(preprocess_data, batched=True)
75 | predict_dataset = dataset['test'].map(postprocess_test_data, batched=True)
76 |
77 | args = Seq2SeqTrainingArguments(
78 | output_dir = output_dir,
79 | evaluation_strategy="epoch",
80 | logging_strategy="epoch",
81 | save_strategy="epoch",
82 | learning_rate=3e-5,
83 | per_device_train_batch_size=batch_size,
84 | per_device_eval_batch_size=batch_size,
85 | weight_decay=0.01,
86 | save_total_limit=1,
87 | num_train_epochs=10,
88 | predict_with_generate=True,
89 | fp16=True,
90 | load_best_model_at_end=True,
91 | metric_for_best_model="eval_rouge1",
92 | overwrite_output_dir=overwrite_output_dir,
93 | )
94 |
95 | data_collator = DataCollatorForSeq2Seq(tokenizer, model = model, label_pad_token_id=label_pad_token_id)
96 | metric = load_metric("rouge")
97 |
98 | trainer = Seq2SeqTrainer(
99 | model=model,
100 | args=args,
101 | train_dataset=tokenized_dataset["train"],
102 | eval_dataset=tokenized_dataset["dev"],
103 | data_collator=data_collator,
104 | tokenizer=tokenizer,
105 | compute_metrics=compute_metrics,
106 | callbacks = [EarlyStoppingCallback(early_stopping_patience=3)],
107 | )
108 |
109 | if do_train:
110 | checkpoint = None
111 | last_checkpoint = None
112 | if os.path.isdir(output_dir):
113 | last_checkpoint = get_last_checkpoint(output_dir)
114 |
115 | if last_checkpoint is not None:
116 | checkpoint = last_checkpoint
117 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
118 | trainer.save_model() # Saves the tokenizer too for easy upload
119 |
120 | metrics = train_result.metrics
121 | trainer.log_metrics("train", metrics)
122 | trainer.save_metrics("train", metrics)
123 | trainer.save_state()
124 |
125 | # Evaluation
126 | results = {}
127 | if do_eval:
128 |
129 | metrics = trainer.evaluate(max_length=max_target_length, num_beams=8, metric_key_prefix="eval")
130 | trainer.log_metrics("eval", metrics)
131 | trainer.save_metrics("eval", metrics)
132 |
133 | # Prediction
134 | if do_predict:
135 |
136 |     results = trainer.predict(tokenized_dataset['test'], metric_key_prefix="predict", max_length=max_target_length, num_beams=8)
137 | metrics = results.metrics
138 | trainer.log_metrics("predict", metrics)
139 | trainer.save_metrics("predict", metrics)
140 |
141 |
142 |
--------------------------------------------------------------------------------