├── .gitignore ├── LICENSE ├── README.md ├── assets └── ReCEvalOverview.png ├── evaluate_receval.py ├── perturb_EB.py ├── requirements.txt ├── run_flan.py ├── train_infogain_pvi.py └── train_pvi.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 Archiki Prasad
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ReCEval: Evaluating Reasoning Chains via Correctness and Informativeness
2 | * Authors: [Archiki Prasad](https://archiki.github.io), [Swarnadeep Saha](https://swarnahub.github.io/), [Xiang Zhou](https://owenzx.github.io/), and [Mohit Bansal](https://www.cs.unc.edu/~mbansal/) (UNC Chapel Hill)
3 | * [Paper](https://arxiv.org/abs/2304.10703)
4 | * **Note:** This is a preliminary version of our code. The complete code to run all experiments in the paper will be added shortly.
5 | 
6 | ![teaser image](assets/ReCEvalOverview.png)
7 | 
8 | ## Dependencies
9 | This code is written using PyTorch and [HuggingFace's Transformers library](https://github.com/huggingface/pytorch-transformers). Running ReCEval requires access to GPUs; the evaluation is quite lightweight, so one GPU should suffice. Please download the [Entailment Bank](https://allenai.org/data/entailmentbank) and [GSM-8K](https://github.com/openai/grade-school-math) datasets separately. To use the human-judgement datasets for GSM-8K and to run the baselines, please follow the setup procedure in [ROSCOE](https://github.com/facebookresearch/ParlAI/tree/main/projects/roscoe/) (preferably in a separate environment).
10 | 
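`perturb_EB.py` reads Entailment Bank from fixed relative paths (set at the top of that script and easy to edit); by default it expects the files at:
```
entailment_bank/data/public_dataset/entailment_trees_emnlp2021_data_v2/dataset/task_1/test.jsonl
entailment_bank/data/public_dataset/entailment_trees_emnlp2021_data_v2/dataset/task_2/test.jsonl
entailment_bank/data/processed_data/predictions/emnlp_2021/task1/T5_11B/test.16K_steps.predictions.tsv
```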
11 | ## Installation
12 | The simplest way to run our code is to start with a fresh environment.
13 | ```
14 | conda create -n ReCEval python=3.9
15 | source activate ReCEval
16 | pip install -r requirements.txt
17 | ```
18 | 
19 | ## Running Evaluation
20 | * `evaluate_receval.py` contains the implementation of the ReCEval metrics.
21 | * `train_*_pvi.py` scripts are used to train models for the PVI-based metrics.
22 | * `perturb_EB.py` applies perturbations to the reasoning trees in [Entailment Bank](https://allenai.org/data/entailmentbank).
23 | * `run_flan.py` is used to obtain chain-of-thought responses on the [GSM-8K](https://github.com/openai/grade-school-math) dataset.
24 | * To compute metrics and evaluate, simply run `python evaluate_receval.py` (by default, on Entailment Bank). Default model and data directories can be changed directly within the script. These variables include:
25 |     * `inp_model_dir`: Model *g* for calculating PVI-based intra-step correctness
26 |     * `no_inp_model_dir`: Model *g'* for calculating PVI-based intra-step correctness
27 |     * `info_gain_model_dir`: Model for calculating PVI-based information-gain
28 |     * `source_path`: Path containing reasoning chains to be scored or meta-evaluated
29 | * **PVI Models:** Here is a link to trained [PVI models for entailment](https://drive.google.com/drive/folders/1qhWKqEAFAoIar3ydSUtMoh2YyYctutG6?usp=drive_link). For more training details and how we prepare the data, refer to Appendix A of our paper, and/or consider using off-the-shelf LLMs to compute ReCEval metrics.
30 | 
31 | ## Reference
32 | Please cite our paper if you use our repository in your work:
33 | ```bibtex
34 | 
35 | @article{Prasad2023ReCEval,
36 |   title = {ReCEval: Evaluating Reasoning Chains via Correctness and Informativeness},
37 |   author = {Archiki Prasad and Swarnadeep Saha and Xiang Zhou and Mohit Bansal},
38 |   year = {2023},
39 |   archivePrefix = {arXiv},
40 |   primaryClass = {cs.CL},
41 |   eprint = {2304.10703}
42 | }
43 | ```
44 | 
--------------------------------------------------------------------------------
/assets/ReCEvalOverview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/archiki/ReCEval/bc023421f4f8c748111643c1772693cd33686032/assets/ReCEvalOverview.png
--------------------------------------------------------------------------------
/evaluate_receval.py:
--------------------------------------------------------------------------------
1 | from supar import Parser
2 | from nltk import word_tokenize, sent_tokenize
3 | from nltk.tokenize.treebank import TreebankWordDetokenizer
4 | import nltk
5 | import pdb
6 | import json
7 | import tqdm
8 | def tqdm_replacement(iterable_object,*args,**kwargs):
9 |     return iterable_object
10 | tqdm_copy = tqdm.tqdm
11 | tqdm.tqdm = tqdm_replacement
12 | import os
13 | import torch
14 | import numpy as np
15 | import random
16 | from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig, AutoModelForSeq2SeqLM, AutoModelForCausalLM
17 | from scipy.stats import somersd
18 | from allennlp.predictors.predictor import Predictor
19 | import allennlp_models.tagging
20 | import re, string
21 | import logging
22 | import logging.config
23 | from datasets import load_dataset, load_metric, Dataset
24 | import torch.nn.functional as F
25 | from datasets.utils.logging import disable_progress_bar
26 | import pandas as pd
27 | 
28 | 
29 | logging.config.dictConfig({
30 |     'version': 1,
31 |     'disable_existing_loggers': True,
32 | })
33 | disable_progress_bar()
34 | random.seed(1)
35 | 
36 | 
37 | srl_predictor = predictor = 
Predictor.from_path("https://storage.googleapis.com/allennlp-public-models/structured-prediction-srl-bert.2020.12.15.tar.gz") 38 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 39 | ent_model_name = "MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli" 40 | ent_tokenizer = AutoTokenizer.from_pretrained(ent_model_name) 41 | ent_model = AutoModelForSequenceClassification.from_pretrained(ent_model_name).to(device) 42 | 43 | # Intra-Step PVI arguments 44 | inp_model_dir = 'PVI/inp_models/' 45 | no_inp_model_dir = 'PVI/noinp_models/' 46 | # Infor-gain PVI arguments 47 | info_gain_model_dir = 'PVI/infogain_models/' 48 | 49 | max_input_length = 512 50 | max_target_length = 64 51 | padding = "max_length" 52 | model_name = "t5-large" 53 | label_pad_token_id = -100 54 | pad_token = '' 55 | prefix = 'Generate entailed sentence: ' 56 | 57 | inp_tokenizer = AutoTokenizer.from_pretrained(inp_model_dir) 58 | inp_config = AutoConfig.from_pretrained(inp_model_dir) 59 | inp_model = AutoModelForSeq2SeqLM.from_pretrained(inp_model_dir, config=inp_config) 60 | inp_model.cuda().eval() 61 | no_inp_tokenizer = AutoTokenizer.from_pretrained(no_inp_model_dir) 62 | no_inp_config = AutoConfig.from_pretrained(no_inp_model_dir) 63 | no_inp_model = AutoModelForSeq2SeqLM.from_pretrained(no_inp_model_dir, config=no_inp_config) 64 | no_inp_model.cuda().eval() 65 | 66 | info_gain_tokenizer = AutoTokenizer.from_pretrained(info_gain_model_dir) 67 | info_gain_config = AutoConfig.from_pretrained(info_gain_model_dir) 68 | info_gain_model = AutoModelForSeq2SeqLM.from_pretrained(info_gain_model_dir, config=info_gain_config) 69 | info_gain_model.cuda().eval() 70 | info_gain_mname = 'gpt2' 71 | 72 | 73 | def init_gpt2(): 74 | ll_tokenizer = AutoTokenizer.from_pretrained('gpt2') 75 | ll_model = AutoModelForCausalLM.from_pretrained('gpt2-xl') 76 | ll_model.eval().cuda() 77 | ll_tokenizer.padding_side = "left" 78 | ll_tokenizer.pad_token = ll_tokenizer.eos_token 79 | ll_model.config.pad_token_id = ll_model.config.eos_token_id 80 | return ll_model, ll_tokenizer 81 | 82 | def inti_t5(): 83 | ll_tokenizer = AutoTokenizer.from_pretrained("t5-large") 84 | config = AutoConfig.from_pretrained("t5-large") 85 | ll_model = AutoModelForSeq2SeqLM.from_pretrained("t5-large", config=config) 86 | ll_model.eval().cuda() 87 | ll_tokenizer.pad_token = pad_token 88 | ll_model.config.pad_token_id = pad_token 89 | return ll_model, ll_tokenizer 90 | 91 | # For Info-Gain PVI 92 | if info_gain_mname == 'gpt2': ll_model, ll_tokenizer = init_gpt2() 93 | elif info_gain_mname == 't5-large': ll_model, ll_tokenizer = inti_t5() 94 | 95 | def obtain_entailment_scores(premise, hypothesis): 96 | input = ent_tokenizer(premise, hypothesis, truncation=True, return_tensors="pt").to(device) 97 | with torch.no_grad(): 98 | output = ent_model(input["input_ids"].to(device)) # device = "cuda:0" or "cpu" 99 | prediction = torch.softmax(output["logits"][0], -1).tolist() 100 | label_names = ["entailment", "neutral", "contradiction"] 101 | prediction = {name: float(pred) for pred, name in zip(prediction, label_names)} 102 | return prediction['entailment'] 103 | 104 | def obtain_contradiction_scores(premise, hypothesis): 105 | input = ent_tokenizer(premise, hypothesis, truncation=True, return_tensors="pt").to(device) 106 | with torch.no_grad(): 107 | output = ent_model(input["input_ids"].to(device)) # device = "cuda:0" or "cpu" 108 | prediction = torch.softmax(output["logits"][0], -1).tolist() 109 | label_names = ["entailment", "neutral", 
"contradiction"] 110 | prediction = {name: float(pred) for pred, name in zip(prediction, label_names)} 111 | return prediction['contradiction'] 112 | 113 | def obtain_unit_entailment_score(prem_units, conc_units): 114 | if len(prem_units): 115 | premise = ' and '.join(prem_units) 116 | hypothesis = ' and '.join(conc_units) 117 | score = obtain_entailment_scores(premise, hypothesis) 118 | else: 119 | score = 1 120 | return score 121 | 122 | def obtain_contradiction_score(prem_units, conc_units): 123 | pair_scores = [] 124 | hypothesis = ' and '.join(conc_units) 125 | for premise in prem_units: 126 | pair_scores.append(obtain_contradiction_scores(premise, hypothesis)) 127 | if len(pair_scores): 128 | score = 1 - max(pair_scores) 129 | else: 130 | score = 1 131 | 132 | return score 133 | 134 | 135 | def detokenize(tokens): 136 | return TreebankWordDetokenizer().detokenize(tokens) 137 | 138 | def verb_modifiers(desc): 139 | filtered_mods = [] 140 | mods = re.findall(r"\[ARGM.*?\]", desc) 141 | if not len(mods): return filtered_mods 142 | for mod in mods: 143 | phrase = mod.split(': ')[1].rstrip(']') 144 | verb_match = ['VB' in k[1] for k in nltk.pos_tag(word_tokenize(phrase))] 145 | if sum(verb_match) and len(phrase.split()) > 2: filtered_mods.append(phrase) # put in a length criteria 146 | return filtered_mods 147 | 148 | def remove_modifiers(sent, modifiers): 149 | if not len(modifiers): return sent 150 | for mod in modifiers: 151 | sent = sent.replace(mod, "") 152 | sent = re.sub(' +', ' ', sent) # remove any double spaces 153 | sent = sent.strip(string.punctuation + ' ') # remove stray punctuations 154 | return sent 155 | 156 | def extract_frame(tags, words, desc): 157 | prev = 'O' 158 | start, end = None, None 159 | if len(set(tags)) == 1: return '' 160 | tags = [t if 'C-ARG' not in t else 'O' for t in tags] #check if the modifier is a verb phrase 161 | for w in range(len(words)): 162 | if 'B-' in tags[w] and start is None: start = w 163 | if tags[len(words) - w -1]!='O' and end is None: end = len(words) - w -1 164 | 165 | if end is None: end = start 166 | sent = detokenize(words[start: end + 1]).rstrip('.') 167 | return sent 168 | 169 | 170 | def get_phrases(sent): 171 | # Simple RCU extractor without conjunction check for premises 172 | phrases = [] 173 | history = '' 174 | srl_out = predictor.predict(sent) 175 | words = srl_out['words'] 176 | frames = [s['tags'] for s in srl_out['verbs']] 177 | descs = [s['description'] for s in srl_out['verbs']] 178 | mod_sent = detokenize(words).rstrip('.') 179 | for frame, desc in zip(frames, descs): 180 | phrase = extract_frame(frame, words, desc) 181 | if phrase == mod_sent: phrase = remove_modifiers(phrase, verb_modifiers(desc)) 182 | phrases.append(phrase) 183 | phrases.sort(key=lambda s: len(s), reverse=True) 184 | filtered_phrases = [] 185 | for p in phrases: 186 | if p not in history: 187 | history += ' ' + p 188 | filtered_phrases.append(p) 189 | if len(filtered_phrases): 190 | filtered_phrases.sort(key=lambda s: mod_sent.find(s)) 191 | left = mod_sent 192 | mod_filt = False 193 | for fp in filtered_phrases: left = left.replace(fp, '#').strip(string.punctuation + ' ') 194 | for l in left.split('#'): 195 | l = l.strip(string.punctuation + ' ') 196 | if len(l.split()) >=4 and l not in " ".join(filtered_phrases): 197 | verb_match = ['VB' in k[1] for k in nltk.pos_tag(word_tokenize(l))] 198 | if sum(verb_match): 199 | filtered_phrases.append(l) 200 | mod_filt = True 201 | if mod_filt: filtered_phrases.sort(key=lambda s: mod_sent.find(s)) 202 | 
return filtered_phrases 203 | else: return [sent.rstrip('.')] 204 | 205 | def get_sent_phrases(para): 206 | sentences = sent_tokenize(para) 207 | phrases = [] 208 | for sent in sentences: 209 | phrases.extend(get_phrases(sent)) 210 | return phrases 211 | 212 | def get_reasoning_chain_text(steps, sentences): 213 | # If using the reasoning trees directly 214 | step_texts = [] 215 | covered_nodes = [] 216 | for step in steps: 217 | parent_text = " and ".join([sentences[p] for p in step['parents'] if p not in covered_nodes]) 218 | if len(parent_text): step_text = parent_text + ', so ' + sentences[step['child']] + "." 219 | else: step_text = 'so ' + sentences[step['child']] + '.' 220 | covered_nodes.extend(step['parents']); covered_nodes.append(step['child']) 221 | step_texts.append(step_text) 222 | return step_texts 223 | 224 | 225 | def preprocess_and_convert(premise_units, conc_units): 226 | data = {'inputs': [], 'labels': []} 227 | parent_text = " & ".join(premise_units) + ' ->' 228 | child_text = " " + conc_units[0] # assume just one conc unit 229 | data['inputs'].append(parent_text) 230 | data['labels'].append(child_text) 231 | return data 232 | 233 | def postprocess_test_data(examples): 234 | inputs = [prefix + text for text in examples['inputs']] 235 | model_inputs = inp_tokenizer(inputs, max_length=max_input_length, padding=padding, truncation=True, return_tensors="pt") 236 | 237 | # Setup the tokenizer for targets 238 | with inp_tokenizer.as_target_tokenizer(): 239 | targets = [pad_token + label for label in examples['labels']] 240 | labels = inp_tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True, return_tensors="pt") 241 | model_inputs["decoder_input_ids"] = labels["input_ids"] 242 | model_inputs["decoder_attention_mask"] = labels["attention_mask"] 243 | return model_inputs 244 | 245 | def noinp_postprocess_test_data(examples): 246 | inputs = [prefix + 'None ->' for text in examples['inputs']] 247 | model_inputs = no_inp_tokenizer(inputs, max_length=max_input_length, padding=padding, truncation=True, return_tensors="pt") 248 | 249 | # Setup the tokenizer for targets 250 | with no_inp_tokenizer.as_target_tokenizer(): 251 | targets = [pad_token + label for label in examples['labels']] 252 | labels = no_inp_tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True, return_tensors="pt") 253 | model_inputs["decoder_input_ids"] = labels["input_ids"] 254 | model_inputs["decoder_attention_mask"] = labels["attention_mask"] 255 | return model_inputs 256 | 257 | def obtain_log_prob(predict_dataset, model, tokenizer): 258 | logits = model(input_ids=torch.Tensor(predict_dataset['input_ids']).long().cuda(), attention_mask=torch.Tensor(predict_dataset['attention_mask']).long().cuda(), 259 | decoder_input_ids=torch.Tensor(predict_dataset['decoder_input_ids']).long().cuda(), decoder_attention_mask = torch.Tensor(predict_dataset['decoder_attention_mask']).long().cuda()).logits 260 | all_logprobs = torch.log(torch.softmax(logits, dim=-1)) 261 | labels = tokenizer(predict_dataset['labels'], max_length=max_target_length).input_ids 262 | filter_sums = [] 263 | for row, label in zip(all_logprobs, labels): 264 | label.pop() 265 | row = row[:len(label), :].detach().cpu().numpy() 266 | vocab_size = row.shape[-1] 267 | loc = F.one_hot(torch.tensor(label), num_classes=vocab_size).numpy().astype(bool) 268 | try: summed_logprob = np.sum(row, where = loc) 269 | except: import pdb; pdb.set_trace() 270 | filter_sums.append(summed_logprob/len(label)) 271 | return 
np.array(filter_sums) 272 | 273 | def obtain_unit_pvi_score(premise_units, conc_units): 274 | dataset = Dataset.from_dict(preprocess_and_convert(premise_units, conc_units)) 275 | inp_dataset = dataset.map(postprocess_test_data, batched=True, remove_columns=['inputs']) 276 | no_inp_dataset = dataset.map(noinp_postprocess_test_data, batched=True, remove_columns=['inputs']) 277 | inp_logprob = obtain_log_prob(inp_dataset, inp_model, inp_tokenizer)[0] 278 | no_inp_logprob = obtain_log_prob(no_inp_dataset, no_inp_model, no_inp_tokenizer)[0] 279 | return inp_logprob - no_inp_logprob 280 | 281 | 282 | def slice_select_logits(all_logprobs, label): 283 | filter_sums = [] 284 | if info_gain_mname == 'gpt2': row = all_logprobs[-len(label):, :].detach().cpu().numpy() 285 | elif info_gain_mname == 't5-large': row = all_logprobs[:len(label), :].detach().cpu().numpy() 286 | vocab_size = row.shape[-1] 287 | loc = F.one_hot(torch.tensor(label), num_classes=vocab_size).numpy().astype(bool) 288 | try: summed_logprob = np.sum(row, where = loc) 289 | except: import pdb; pdb.set_trace() 290 | filter_sums.append(summed_logprob/len(label)) 291 | return np.array(filter_sums) 292 | 293 | def obtain_info_gain_score(prev_steps, current_step, current_conc, target, info_model, info_tokenizer): 294 | if info_gain_mname == 't5-large': 295 | target = pad_token + target 296 | input = " ".join(prev_steps + [current_step]) + ' Therefore,' + target 297 | if len(prev_steps): ref_input = " ".join(prev_steps) + ' Therefore,' + target 298 | else: ref_input = 'Therefore,' + target 299 | 300 | inputs = info_tokenizer(input, return_tensors="pt") 301 | ref = info_tokenizer(ref_input, return_tensors="pt") 302 | labels = info_tokenizer(target, return_tensors="pt") 303 | inputs["decoder_input_ids"] = labels['input_ids'] 304 | inputs['decoder_attention_mask'] = labels['attention_mask'] 305 | ref["decoder_input_ids"] = labels['input_ids'] 306 | ref['decoder_attention_mask'] = labels['attention_mask'] 307 | 308 | for i in inputs: inputs[i] = inputs[i].cuda() 309 | for i in ref: ref[i] = ref[i].cuda() 310 | with torch.no_grad(): 311 | inp_logits = info_model.forward(**inputs).logits.detach().cpu() 312 | ref_logits = info_model.forward(**ref).logits.detach().cpu() 313 | all_inp_logprobs = torch.log(torch.softmax(inp_logits, dim=-1)) 314 | all_ref_logprobs = torch.log(torch.softmax(ref_logits, dim=-1)) 315 | labels = labels.input_ids.detach().cpu().tolist()[0][1:] 316 | filtered_inp_logprobs = slice_select_logits(all_inp_logprobs[0,:,:], labels)[0] 317 | filtered_ref_logprobs = slice_select_logits(all_ref_logprobs[0,:,:], labels)[0] 318 | 319 | elif info_gain_mname == 'gpt2': 320 | target = " " + target 321 | input = " " + " ".join(prev_steps + [current_step]) + ' Therefore,' + target 322 | if len(prev_steps): ref_input = " " + " ".join(prev_steps) + ' Therefore,' + target 323 | else: ref_input = ' Therefore,' + target 324 | labels = info_tokenizer(target).input_ids 325 | input_ids = info_tokenizer(input, return_tensors="pt").input_ids.cuda() 326 | ref_input_ids = info_tokenizer(ref_input, return_tensors="pt").input_ids.cuda() 327 | with torch.no_grad(): 328 | inp_logits = info_model.forward(input_ids=input_ids, return_dict=True).logits.detach().cpu() 329 | ref_logits = info_model.forward(input_ids=ref_input_ids, return_dict=True).logits.detach().cpu() 330 | all_inp_logprobs = torch.log(torch.softmax(inp_logits, dim=-1)) 331 | all_ref_logprobs = torch.log(torch.softmax(ref_logits, dim=-1)) 332 | filtered_inp_logprobs = 
slice_select_logits(all_inp_logprobs[0,:-1,:], labels)[0] #shift probability since at idx i produce distribution of tokens at i+1
333 |         filtered_ref_logprobs = slice_select_logits(all_ref_logprobs[0,:-1,:], labels)[0]
334 |     return (filtered_inp_logprobs - filtered_ref_logprobs)
335 | 
336 | 
337 | source_path = 'perturbed_trees'
338 | error_types = os.listdir(source_path)
339 | score_keys = ['entail', 'pvi', 'contradict', 'll-info', 'pvi-info']
340 | score_keys = ['ll-info']  # restrict scoring to a subset of the metrics above; comment this line out to compute all of them
341 | errors_correl = {k:{'somersd':{}, 'pearson':{}} for k in score_keys}
342 | K = 0 # Set how many past steps to look at.
343 | 
344 | for error in error_types:
345 |     print(error)
346 |     epath = os.path.join(source_path, error)
347 |     tree_entry = [json.loads(line) for line in open(epath, 'r')]
348 |     local_ent_scores = []
349 |     alt_local_ent_scores = []
350 |     local_pvi_scores = []
351 |     global_contradict_scores = []
352 |     info_ll_scores, info_pvi_scores = [], []
353 |     for t, entry in tqdm_copy(enumerate(tree_entry)):
354 |         # Tree-based Evaluation from EB
355 |         # Otherwise, directly iterate over reasoning problems and directly get sentences
356 |         # Hypothesis for GSM-8K is concat question and answer
357 |         input_context = entry['question']
358 |         input_context_sentences = sent_tokenize(input_context)
359 |         steps, sentences = entry['steps']['perturbed'], entry['sentences']['perturbed']
360 |         reasoning_steps = get_reasoning_chain_text(steps, sentences)
361 |         # reasoning_steps = sent_tokenize(entry['steps'])
362 |         # Needed keys are: question (input_context), steps, hypothesis
363 |         step_ent_scores = []
364 |         alt_step_ent_scores = []
365 |         step_pvi_scores = []
366 |         step_contradict_scores = []
367 |         step_redundancy_scores = []
368 |         step_ll_scores = []
369 |         step_pviinfo_scores = []
370 |         running_conc = []
371 |         for sid, step in enumerate(reasoning_steps):
372 |             units = get_phrases(step)
373 |             premise_units, conc_units = [], []
374 |             premise_units.extend(units[:-1])
375 |             conc_units.append(units[-1])
376 | 
377 |             # Entail Step Calculation
378 |             if 'entail' in score_keys:
379 |                 alt_step_ent_scores.append(obtain_unit_entailment_score(premise_units + running_conc[-1*K:], conc_units))
380 |             # Intra-Step PVI Calculation
381 |             if 'pvi' in score_keys:
382 |                 step_pvi_scores.append(obtain_unit_pvi_score(premise_units + running_conc[-1*K:], conc_units))
383 |             # Global Contradiction Check
384 |             if 'contradict' in score_keys:
385 |                 step_contradict_scores.append(obtain_contradiction_score(input_context_sentences + running_conc, conc_units))
386 | 
387 |             # LL Informativeness Check
388 |             if 'll-info' in score_keys:
389 |                 step_ll_scores.append(obtain_info_gain_score(reasoning_steps[:sid], step, conc_units, sentences['hypothesis'], ll_model, ll_tokenizer))
390 | 
391 |             # PVI Informativeness Check
392 |             if 'pvi-info' in score_keys:
393 |                 step_pviinfo_scores.append(obtain_info_gain_score(reasoning_steps[:sid], step, conc_units, sentences['hypothesis'], info_gain_model, info_gain_tokenizer))
394 |             running_conc.extend(conc_units)
395 | 
396 |         if 'entail' in score_keys: alt_local_ent_scores.append(min(alt_step_ent_scores))
397 |         if 'pvi' in score_keys: local_pvi_scores.append(min(step_pvi_scores))
398 |         if 'contradict' in score_keys: global_contradict_scores.append(min(step_contradict_scores))
399 |         if 'll-info' in score_keys: info_ll_scores.append(min(step_ll_scores))
400 |         if 'pvi-info' in score_keys: info_pvi_scores.append(min(step_pviinfo_scores))
401 |     perturbed_ids = [1 - int(e['perturbed']) for e in tree_entry]
402 | 
403 |     if 'entail' in score_keys: print(somersd(perturbed_ids, alt_local_ent_scores).statistic)
404 |     if 'pvi' in score_keys: print(somersd(perturbed_ids, local_pvi_scores).statistic)
405 |     if 'contradict' in score_keys: print(somersd(perturbed_ids, global_contradict_scores).statistic)
406 |     if 'll-info' in score_keys: print(somersd(perturbed_ids, info_ll_scores).statistic)
407 |     if 'pvi-info' in score_keys: print(somersd(perturbed_ids, info_pvi_scores).statistic)
408 | 
409 |     if 'entail' in score_keys:
410 |         errors_correl['entail']['pearson'][error] = np.corrcoef(alt_local_ent_scores, perturbed_ids)[0][1]
411 |         errors_correl['entail']['somersd'][error] = somersd(perturbed_ids, alt_local_ent_scores).statistic
412 |     if 'pvi' in score_keys:
413 |         errors_correl['pvi']['pearson'][error] = np.corrcoef(local_pvi_scores, perturbed_ids)[0][1]
414 |         errors_correl['pvi']['somersd'][error] = somersd(perturbed_ids, local_pvi_scores).statistic
415 |     if 'contradict' in score_keys:
416 |         errors_correl['contradict']['pearson'][error] = np.corrcoef(global_contradict_scores, perturbed_ids)[0][1]
417 |         errors_correl['contradict']['somersd'][error] = somersd(perturbed_ids, global_contradict_scores).statistic
418 |     if 'pvi-info' in score_keys:
419 |         errors_correl['pvi-info']['pearson'][error] = np.corrcoef(info_pvi_scores, perturbed_ids)[0][1]
420 |         errors_correl['pvi-info']['somersd'][error] = somersd(perturbed_ids, info_pvi_scores).statistic
421 |     if 'll-info' in score_keys:
422 |         errors_correl['ll-info']['pearson'][error] = np.corrcoef(info_ll_scores, perturbed_ids)[0][1]
423 |         errors_correl['ll-info']['somersd'][error] = somersd(perturbed_ids, info_ll_scores).statistic
424 | 
425 | os.makedirs('ResultLogs', exist_ok=True); f = open('ResultLogs/correlations.json', 'w+')
426 | json.dump(errors_correl, f, indent=4)
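# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original pipeline): the metric
# functions above can also score an arbitrary chain directly. Assumes
# `chain_steps` is a list of step strings (e.g. from get_reasoning_chain_text
# or sent_tokenize) and `hypothesis` is the final answer sentence; the
# chain-level score is the minimum over steps, as in the loop above, and the
# LL-based info-gain model (ll_model) is used. The main loop can additionally
# append previous conclusions to the premises, which is omitted here.
def score_custom_chain(chain_steps, hypothesis):
    ent_scores, info_scores = [], []
    for sid, step in enumerate(chain_steps):
        units = get_phrases(step)
        premise_units, conc_units = units[:-1], [units[-1]]
        # intra-step correctness: does the step's conclusion follow from its premises?
        ent_scores.append(obtain_unit_entailment_score(premise_units, conc_units))
        # informativeness: how much does adding this step help predict the final answer?
        info_scores.append(obtain_info_gain_score(chain_steps[:sid], step, conc_units, hypothesis, ll_model, ll_tokenizer))
    return min(ent_scores), min(info_scores)
# ---------------------------------------------------------------------------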
427 | 
--------------------------------------------------------------------------------
/perturb_EB.py:
--------------------------------------------------------------------------------
1 | import graphviz
2 | from graphviz import Digraph
3 | import json
4 | import re
5 | import random
6 | from copy import deepcopy
7 | from checklist.perturb import Perturb
8 | import spacy
9 | nlp = spacy.load('en_core_web_sm')
10 | import os
11 | import jsonlines
12 | from tqdm import tqdm
13 | from transformers import PegasusForConditionalGeneration, PegasusTokenizer
14 | import torch, numpy as np  # used below for the device check and for sampling paraphrases
15 | gold_path = 'entailment_bank/data/public_dataset/entailment_trees_emnlp2021_data_v2/dataset/task_1/test.jsonl'
16 | more_path = 'entailment_bank/data/public_dataset/entailment_trees_emnlp2021_data_v2/dataset/task_2/test.jsonl'
17 | preds_path = 'entailment_bank/data/processed_data/predictions/emnlp_2021/task1/T5_11B/test.16K_steps.predictions.tsv'
18 | 
19 | gold_examples = [json.loads(line) for line in open(gold_path, 'r')]
20 | more_examples = [json.loads(line) for line in open(more_path, 'r')]
21 | pred_examples = [line.split('=')[-1].strip() for line in open(preds_path, 'r')]
22 | 
23 | gold_examples = [g for g in gold_examples if g['proof'].count(';') > 1]
24 | more_examples = [m for m in more_examples if m['proof'].count(';') > 1]
25 | 
26 | para_model_name = 'tuner007/pegasus_paraphrase'
27 | torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
28 | para_tokenizer = PegasusTokenizer.from_pretrained(para_model_name)
29 | para_model = PegasusForConditionalGeneration.from_pretrained(para_model_name).to(torch_device).eval()
30 | # Can replace with alternate paraphrase models
31 | 
32 | def get_response(input_text,num_return_sequences,num_beams):
33 |     batch = 
para_tokenizer([input_text],truncation=True,padding='longest',max_length=60, return_tensors="pt").to(torch_device) 34 | translated = para_model.generate(**batch,max_length=60,num_beams=num_beams, num_return_sequences=num_return_sequences, temperature=1.5) 35 | tgt_text = para_tokenizer.batch_decode(translated, skip_special_tokens=True) 36 | return tgt_text 37 | 38 | def obtain_paraphrase(phrase): 39 | num_beams = 5 40 | num_return_sequences = 5 41 | paraphrases = get_response(phrase, num_return_sequences, num_beams) 42 | paraphrase = np.random.choice(paraphrases, 1)[0] 43 | paraphrase = paraphrase.strip('.') 44 | return paraphrase 45 | 46 | def process(proof, alt_triples, hypothesis): 47 | proof = proof.rstrip('; ') 48 | steps = proof.split(';') 49 | step_list = [] 50 | triples = deepcopy(alt_triples) 51 | for step in steps: 52 | [parents, leaf] = step.split('->') 53 | parents = parents.split('&') 54 | parents = [p.strip(' ') for p in parents] 55 | leaf = leaf.strip() 56 | leafId = leaf.split(':')[0] 57 | if 'int' in leafId: 58 | leaf_sent = leaf.split(':')[-1].strip() 59 | triples[leafId] = leaf_sent 60 | elif leafId == 'hypothesis': 61 | leaf_sent = hypothesis 62 | triples[leafId] = leaf_sent 63 | step_list.append({'parents':parents, 'child': leafId}) 64 | proof = proof + '; ' 65 | return step_list, triples 66 | 67 | def reconstruct_proof(steps, sentences): 68 | proof = '' 69 | for step in steps: 70 | proof += " & ".join(step['parents']) + ' -> ' + step['child'] 71 | if 'int' in step['child']: 72 | proof += ': ' + sentences[step['child']] 73 | proof += '; ' 74 | return proof 75 | 76 | def repeat_steps(in_steps, in_sentences): 77 | steps = deepcopy(in_steps) 78 | sentences = deepcopy(in_sentences) 79 | int_idxs = [s for s in range(len(steps)) if 'int' in steps[s]['child']] 80 | idx = random.choice(int_idxs) 81 | assert idx < len(steps) - 1 82 | repeated_node = steps[idx]['child'] 83 | print(idx, repeated_node) 84 | key = 'int' + str(len(int_idxs) + 1) 85 | sentences[key] = sentences[repeated_node] 86 | steps[idx]['child'] = key 87 | steps.insert(idx + 1, {'parents': [key], 'child': repeated_node}) 88 | return steps, sentences 89 | 90 | def delete_steps(in_steps, in_sentences): 91 | steps = deepcopy(in_steps) 92 | sentences = deepcopy(in_sentences) 93 | int_idxs = [s for s in range(len(steps)) if 'int' in steps[s]['child']] 94 | idx = random.choice(int_idxs) 95 | assert idx < len(steps) - 1 96 | del_node = steps[idx]['child'] 97 | del_parents = steps[idx]['parents'] 98 | del steps[idx] 99 | for step in steps: 100 | if del_node in step['parents']: 101 | step['parents'].extend(del_parents) 102 | step['parents'] = [p for p in step['parents'] if p != del_node] 103 | return steps, sentences 104 | 105 | def swapped_steps(in_steps, in_sentences): 106 | steps = deepcopy(in_steps) 107 | sentences = deepcopy(in_sentences) 108 | int_idxs = [s for s in range(len(steps)) if 'int' in steps[s]['child']] 109 | idx = random.choice(int_idxs) 110 | assert idx < len(steps) - 1 111 | swap_node = deepcopy(steps[idx]['child']) 112 | swap_parent = random.choice(steps[idx]['parents']) 113 | print(swap_node, swap_parent) 114 | alt_parents = [p for p in steps[idx]['parents'] if p!= swap_parent] 115 | alt_parents.append(swap_node) 116 | steps[idx]['parents'] = alt_parents 117 | steps[idx]['child'] = swap_parent 118 | for step in steps[idx + 1:]: 119 | if swap_node in step['parents']: 120 | step['parents'].append(swap_parent) 121 | step['parents'] = [p for p in step['parents'] if p != swap_node] 122 | return steps, 
sentences 123 | 124 | def negate_step(in_steps, in_sentences): 125 | steps = deepcopy(in_steps) 126 | sentences = deepcopy(in_sentences) 127 | int_idxs = [s for s in range(len(steps)) if 'int' in steps[s]['child']] 128 | idx = random.choice(int_idxs) 129 | assert idx < len(steps) - 1 130 | negated_node = steps[idx]['child'] 131 | print(negated_node, sentences[negated_node]) 132 | sentences[negated_node] = Perturb.add_negation(nlp(sentences[negated_node])) 133 | return steps, sentences 134 | 135 | def hallucinate_step(in_steps, in_sentences, extra_sentences): 136 | steps = deepcopy(in_steps) 137 | sentences = deepcopy(in_sentences) 138 | int_idxs = [s for s in range(len(steps)) if 'int' in steps[s]['child']] 139 | idx = random.choice(int_idxs) 140 | assert idx < len(steps) - 1 141 | hallucinate_node = steps[idx]['child'] 142 | print(hallucinate_node, sentences[hallucinate_node]) 143 | potential_sents = [v for (k,v) in extra_sentences.items() if v not in in_sentences.values()] 144 | sentences[hallucinate_node] = random.choice(potential_sents) 145 | print(sentences[hallucinate_node]) 146 | return steps, sentences 147 | 148 | def paraphrase_steps(in_steps, in_sentences): 149 | steps = deepcopy(in_steps) 150 | sentences = deepcopy(in_sentences) 151 | int_idxs = [s for s in range(len(steps)) if 'int' in steps[s]['child']] 152 | idx = random.choice(int_idxs) 153 | assert idx < len(steps) - 1 154 | repeated_node = steps[idx]['child'] 155 | print(idx, repeated_node) 156 | key = 'int' + str(len(int_idxs) + 1) 157 | sentences[key] = obtain_paraphrase(sentences[repeated_node]) 158 | steps[idx]['child'] = key 159 | steps.insert(idx + 1, {'parents': [key], 'child': repeated_node}) 160 | return steps, sentences 161 | 162 | def redundant_steps(in_steps, in_sentences, extra_sentences): 163 | steps = deepcopy(in_steps) 164 | sentences = deepcopy(in_sentences) 165 | int_idxs = [s for s in range(len(steps)) if 'int' in steps[s]['child']] 166 | idx = random.choice(int_idxs) 167 | assert idx < len(steps) - 1 168 | repeated_node = steps[idx]['child'] 169 | print(idx, repeated_node) 170 | key = 'int' + str(len(int_idxs) + 1) 171 | potential_sents = [v for (k,v) in extra_sentences.items() if v not in in_sentences.values()] 172 | sentences[key] = random.choice(potential_sents) 173 | steps[idx]['child'] = key 174 | steps.insert(idx + 1, {'parents': [key], 'child': repeated_node}) 175 | return steps, sentences 176 | 177 | 178 | 179 | tree_dest_path = 'perturbed_trees' 180 | if not os.path.exists(tree_dest_path): os.makedirs(tree_dest_path) 181 | random.seed(0) 182 | 183 | unperturbed_path = 'ParlAI/projects/roscoe/roscoe_data/unperturbed_ids.json' 184 | custom_unperturbed_ids = {'entailment_bank_synthetic':{}} 185 | 186 | perturbation_functions = {'DuplicateOneStep': repeat_steps, 'ParaphraseOneStep': paraphrase_steps, 'RedundantOneStep': redundant_steps, 'RemoveOneStep': delete_steps, 'SwapOneStep': swapped_steps, 'NegateStep': negate_step, 'ExtrinsicHallucinatedStep': hallucinate_step} 187 | 188 | for perturb_type, perturb_func in tqdm(perturbation_functions.items()): 189 | type_unperturbed_ids = random.choices(range(len(gold_examples)), k=126) # unperturbed_ids[perturb_type + "_test.jsonl"] 190 | custom_unperturbed_ids[perturb_type + "_test.jsonl"] = type_unperturbed_ids 191 | fname = '50%_' + perturb_type + '_test.jsonl' 192 | revised_entry = [] 193 | tree_entry = [] 194 | for id, gold_example, more_example in zip(range(len(gold_examples)), gold_examples, more_examples): 195 | steps, sentences = 
process(gold_example['proof'], gold_example['meta']['triples'], gold_example['hypothesis'])
196 |         question = ". ".join(list(gold_example['meta']['triples'].values()) + [gold_example['question']])
197 |         answer = gold_example['answer']
198 |         int_idxs = [s for s in range(len(steps)) if 'int' in steps[s]['child']]
199 |         if not len(int_idxs): continue
200 | 
201 |         if id in type_unperturbed_ids:
202 |             perturbed_steps = steps
203 |             perturbed_sentences = sentences
204 |             perturbed = False
205 |         else:
206 |             more_sentences = more_example['meta']['triples']
207 |             perturbed_steps, perturbed_sentences = perturb_func(steps, sentences, more_sentences) if perturb_func in (redundant_steps, hallucinate_step) else perturb_func(steps, sentences)  # only these two perturbations draw distractor sentences from the task-2 context
208 |             if steps == perturbed_steps and sentences == perturbed_sentences:
209 |                 perturbed = False
210 |                 custom_unperturbed_ids[perturb_type + "_test.jsonl"].append(id)
211 |             else: perturbed = True
212 |         tree_entry.append({'perturbed': perturbed, 'perturbations': perturb_type, 'steps':{'original': steps, 'perturbed': perturbed_steps}, 'sentences':{'original': sentences, 'perturbed': perturbed_sentences}, 'question': question, 'answer': answer})  # written step texts are reconstructed from steps/sentences at evaluation time
213 |     with jsonlines.open(os.path.join(tree_dest_path, fname), 'w') as writer:
214 |         writer.write_all(tree_entry)
215 | 
216 | 
217 | 
218 | 
219 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.18.0
2 | allennlp==2.10.1
3 | allennlp-models==2.10.1
4 | -e git+https://github.com/marcotcr/checklist.git@3edd07c9a84e6c6657333450d4d0e70ecb0c00d9#egg=checklist
5 | datasets==2.7.1
6 | en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.4.1/en_core_web_sm-3.4.1-py3-none-any.whl
7 | evaluate==0.4.0
8 | fairscale==0.4.6
9 | graphviz @ file:///opt/conda/conda-bld/python-graphviz_1660063183935/work
10 | huggingface-hub==0.10.1
11 | jsonlines==3.1.0
12 | matplotlib==3.6.2
13 | nltk==3.8
14 | numpy==1.20.3
15 | pandas==1.5.2
16 | protobuf==3.19.6
17 | py-rouge==1.1
18 | rouge-score==0.1.2
19 | scikit-learn==1.2.0
20 | scipy==1.9.3
21 | sentence-transformers==2.2.2
22 | sentencepiece==0.1.97
23 | spacy==3.4.0
24 | textdistance==4.5.0
25 | tokenizers==0.12.1
26 | torch==1.10.0+cu111
27 | torchaudio==0.10.0+rocm4.1
28 | torchvision==0.11.0+cu111
29 | tqdm==4.64.1
30 | transformers==4.20.1
--------------------------------------------------------------------------------
/run_flan.py:
--------------------------------------------------------------------------------
1 | import transformers
2 | import datasets
3 | import json
4 | import numpy as np
5 | import pdb
6 | import torch
7 | from torch.utils.data import DataLoader, Dataset
8 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
9 | import pandas as pd
10 | from datasets import load_dataset
11 | from tqdm import tqdm
12 | from nltk import sent_tokenize
13 | import re
14 | from string import punctuation
15 | 
16 | zero_shot_instruction = "Answer the following question by reasoning step-by-step.\n" #Write the answer as a separate sentence starting with 'The answer is'.\n"
17 | few_shot_prefix = """Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
18 | A: We start with 15 trees. Later we have 21 trees. The difference must be the number of trees they planted. 
So, they must have planted 21 - 15 = 6 trees. The answer is 6.\n 19 | Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot? 20 | A: There are 3 cars in the parking lot already. 2 more arrive. Now there are 3 + 2 = 5 cars. The answer is 5.\n 21 | Q: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total? 22 | A: Leah had 32 chocolates and Leah’s sister had 42. That means there were originally 32 + 42 = 74 chocolates. 35 have been eaten. So in total they still have 74 - 35 = 39 chocolates. The answer is 39.\n 23 | Q: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny? 24 | A: Jason had 20 lollipops. Since he only has 12 now, he must have given the rest to Denny. The number of lollipops he has given to Denny must have been 20 - 12 = 8 lollipops. The answer is 8.\n 25 | """ 26 | # Examples that won't fit in context from CoT paper 27 | # """ 28 | # Q: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does 29 | # he have now? 30 | # A: He has 5 toys. He got 2 from mom, so after that he has 5 + 2 = 7 toys. Then he got 2 more from dad, so in total he has 7 + 2 = 9 toys. The answer is 9.\n 31 | # Q: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room? 32 | # A: There are 4 days from monday to thursday. 5 computers were added each day. That means in total 4 * 5 = 20 computers were added. There were 9 computers in the beginning, so now there are 9 + 20 = 29 computers. The answer is 29.\n 33 | # Q: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday? 34 | # A: Michael initially had 58 balls. He lost 23 on Tuesday, so after that he has 58 - 23 = 35 balls. On Wednesday he lost 2 more so now he has 35 - 2 = 33 balls. The answer is 33.\n 35 | # Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left? 36 | # A: She bought 5 bagels for $3 each. This means she spent 5 * $3 = $15 on the bagels. She had $23 in beginning, so now she has $23 - $15 = $8. 
The answer is 8.\n""" 37 | 38 | def apply_zs_prompt(example, instruction=zero_shot_instruction): 39 | return instruction + '\nQ: ' + example['question'] + '\n' + 'A:' 40 | 41 | def apply_fs_prompt(example, instruction=zero_shot_instruction): 42 | return few_shot_prefix + '\n' + instruction + '\nQ: ' + example['question'] + '\n' + 'A:' 43 | 44 | def generate_math_prompt(prefix, question, instruction): 45 | prompt = '' 46 | prompt += prefix 47 | if len(instruction): 48 | prompt += instruction + '\n' 49 | prompt += question 50 | return prompt 51 | 52 | def extract_answer(answer_text, answer_prefix='answer is ', no_prefix=False, overfit=False): 53 | sentences = sent_tokenize(answer_text) 54 | ans_candidate = sentences[-1] 55 | found = True 56 | if not no_prefix and answer_prefix in ans_candidate: 57 | answer = ans_candidate.partition(answer_prefix)[2].strip(punctuation) 58 | try: 59 | return float(answer) 60 | except: found = False 61 | else: found = False 62 | if no_prefix: 63 | answers = re.findall(r'\d+', ans_candidate) 64 | if len(answers): 65 | return float(answers[-1]) 66 | else: found = False 67 | if not found: 68 | if not overfit: return None 69 | else: 70 | answer = ans_candidate.partition('=')[2].strip(punctuation) 71 | try: 72 | return float(answer) 73 | except: 74 | return None 75 | 76 | def extract_gold_answer(ans): 77 | return float(ans.partition('###')[2].replace(',', '').strip(punctuation)) 78 | 79 | 80 | dataset = load_dataset("gsm8k", 'main', split='test') 81 | 82 | dataset = dataset.map(lambda example: {'gold': extract_gold_answer(example['answer'])}) 83 | dataset = dataset.map(lambda example: {'input_prompt': apply_zs_prompt(example)}) 84 | data_loader = DataLoader(dataset, batch_size=4, shuffle=False) 85 | 86 | num_gen_tokens = 128 87 | 88 | tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xxl") 89 | tokenizer.padding_side = "left" 90 | tokenizer.pad_token = tokenizer.eos_token 91 | if torch.cuda.is_available(): 92 | model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-xxl", device_map="auto") 93 | else: 94 | model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-xxl") 95 | 96 | 97 | count = 0 98 | correct = 0 99 | 100 | for batch in tqdm(data_loader): 101 | input_text = batch['input_prompt'] 102 | if torch.cuda.is_available(): 103 | inputs = tokenizer(input_text, return_tensors="pt", padding=True).to("cuda") 104 | else: 105 | inputs = tokenizer(input_text, return_tensors="pt", padding=True) 106 | outputs = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], max_new_tokens=num_gen_tokens) 107 | op_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True) 108 | op_answers = [extract_answer(op, no_prefix=True, overfit=True) for op in op_texts] 109 | correct += sum(np.array(op_answers) == np.array(batch['gold'])) 110 | count += len(batch['gold']) 111 | print('0-shot') 112 | print(round(100*correct/count, 2)) 113 | 114 | -------------------------------------------------------------------------------- /train_infogain_pvi.py: -------------------------------------------------------------------------------- 1 | import os 2 | import transformers 3 | from datasets import load_dataset, load_metric 4 | import json 5 | from transformers import AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, EarlyStoppingCallback 6 | import numpy as np 7 | from transformers.trainer_utils import get_last_checkpoint 8 | import torch 9 | import shutil 10 | 11 | 12 | max_input_length 
= 512 13 | max_target_length = 64 14 | padding = "max_length" 15 | model_name = "t5-large" 16 | label_pad_token_id = -100 17 | pad_token = '' 18 | tokenizer = AutoTokenizer.from_pretrained(model_name) 19 | config = AutoConfig.from_pretrained(model_name) 20 | model = AutoModelForSeq2SeqLM.from_pretrained(model_name, config=config) 21 | batch_size = 16 22 | output_dir = 'PVI/infogain_models/norepNLnomarkov' 23 | do_train = True 24 | do_eval = True 25 | do_predict = True 26 | global no_input 27 | no_input = False 28 | overwrite_output_dir = True 29 | 30 | def postprocess_test_data(examples): 31 | if not no_input: 32 | inputs = [prefix + text for text in examples['inputs']] 33 | else: inputs = [prefix + text for text in examples['prev_inputs']] 34 | model_inputs = tokenizer(inputs, max_length=max_input_length, padding=padding, truncation=True, return_tensors="pt") 35 | 36 | # Setup the tokenizer for targets 37 | with tokenizer.as_target_tokenizer(): 38 | targets = [pad_token + label for label in examples['labels']] 39 | labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True, return_tensors="pt") 40 | model_inputs["decoder_input_ids"] = labels["input_ids"] 41 | model_inputs["decoder_attention_mask"] = labels["attention_mask"] 42 | return model_inputs 43 | 44 | def preprocess_data(examples): 45 | if not no_input: 46 | inputs = [prefix + text for text in examples['inputs']] 47 | else: inputs = [prefix + text for text in examples['prev_inputs']] 48 | model_inputs = tokenizer(inputs, max_length=max_input_length, padding=padding, truncation=True) 49 | 50 | # Setup the tokenizer for targets 51 | with tokenizer.as_target_tokenizer(): 52 | labels = tokenizer(examples["labels"], max_length=max_target_length, padding=padding, truncation=True) 53 | 54 | labels["input_ids"] = [[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]] 55 | model_inputs["labels"] = labels["input_ids"] 56 | return model_inputs 57 | 58 | def compute_metrics(eval_pred): 59 | predictions, labels = eval_pred 60 | decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True) 61 | 62 | # Replace -100 in the labels as we can't decode them. 
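# (-100 is `label_pad_token_id`, which the sequence-to-sequence loss ignores; mapping it back to the pad token gives batch_decode valid token ids.)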
63 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 64 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 65 | # Compute ROUGE scores 66 | result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True) 67 | 68 | # Extract ROUGE f1 scores 69 | result = {key: value.mid.fmeasure * 100 for key, value in result.items()} 70 | 71 | return {k: round(v, 2) for k, v in result.items()} 72 | 73 | 74 | prefix = 'Generate final answer: ' 75 | dataset = load_dataset('json', data_files = {'train': 'PVI/IG_norepNLnomarkov_tree/train.json', 'dev': 'PVI/IG_norepNLnomarkov_tree/dev.json', 'test': 'PVI/IG_norepNLnomarkov_tree/test.json'}, field="data") 76 | tokenized_dataset = dataset.map(preprocess_data, batched=True) 77 | # predict_dataset = dataset['test'].map(postprocess_test_data, batched=True) 78 | 79 | args = Seq2SeqTrainingArguments( 80 | output_dir = output_dir, 81 | evaluation_strategy="epoch", 82 | logging_strategy="epoch", 83 | save_strategy="epoch", 84 | learning_rate=3e-5, 85 | per_device_train_batch_size=batch_size, 86 | per_device_eval_batch_size=batch_size, 87 | weight_decay=0.01, 88 | save_total_limit=1, 89 | num_train_epochs=10, 90 | predict_with_generate=True, 91 | fp16=True, 92 | load_best_model_at_end=True, 93 | metric_for_best_model="eval_rougeL", 94 | overwrite_output_dir=overwrite_output_dir, 95 | ) 96 | 97 | data_collator = DataCollatorForSeq2Seq(tokenizer, model = model, label_pad_token_id=label_pad_token_id) 98 | metric = load_metric("rouge") 99 | 100 | trainer = Seq2SeqTrainer( 101 | model=model, 102 | args=args, 103 | train_dataset=tokenized_dataset["train"], 104 | eval_dataset=tokenized_dataset["dev"], 105 | data_collator=data_collator, 106 | tokenizer=tokenizer, 107 | compute_metrics=compute_metrics, 108 | callbacks = [EarlyStoppingCallback(early_stopping_patience=3)], 109 | ) 110 | 111 | if do_train: 112 | checkpoint = None 113 | last_checkpoint = None 114 | if os.path.isdir(output_dir): 115 | last_checkpoint = get_last_checkpoint(output_dir) 116 | 117 | if last_checkpoint is not None: 118 | checkpoint = last_checkpoint 119 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 120 | trainer.save_model() # Saves the tokenizer too for easy upload 121 | 122 | metrics = train_result.metrics 123 | trainer.log_metrics("train", metrics) 124 | trainer.save_metrics("train", metrics) 125 | trainer.save_state() 126 | 127 | # Evaluation 128 | results = {} 129 | if do_eval: 130 | metrics = trainer.evaluate(max_length=max_target_length, num_beams=8, metric_key_prefix="eval") 131 | trainer.log_metrics("eval", metrics) 132 | trainer.save_metrics("eval", metrics) 133 | 134 | # Prediction 135 | if do_predict: 136 | results = trainer.predict(tokenized_dataset['test'], dataset['test']) 137 | metrics = results.metrics 138 | trainer.log_metrics("predict", metrics) 139 | trainer.save_metrics("predict", metrics) 140 | last_checkpoint = get_last_checkpoint(output_dir) 141 | shutil.rmtree(last_checkpoint) 142 | 143 | 144 | 145 | -------------------------------------------------------------------------------- /train_pvi.py: -------------------------------------------------------------------------------- 1 | import os 2 | import transformers 3 | from datasets import load_dataset, load_metric 4 | import json 5 | from transformers import AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, EarlyStoppingCallback 6 | import numpy as np 7 | from 
transformers.trainer_utils import get_last_checkpoint 8 | import torch 9 | 10 | max_input_length = 512 11 | max_target_length = 64 12 | padding = "max_length" 13 | model_name = "t5-large" 14 | label_pad_token_id = -100 15 | pad_token = '' 16 | tokenizer = AutoTokenizer.from_pretrained(model_name) 17 | config = AutoConfig.from_pretrained(model_name) 18 | model = AutoModelForSeq2SeqLM.from_pretrained(model_name, config=config) 19 | batch_size = 16 20 | output_dir = 'PVI/noinp_models/' 21 | do_train = True 22 | do_eval = True 23 | do_predict = True 24 | global no_input 25 | no_input = True 26 | overwrite_output_dir = True 27 | 28 | def postprocess_test_data(examples): 29 | if not no_input: 30 | inputs = [prefix + text for text in examples['inputs']] 31 | else: inputs = [prefix + 'None ->' for text in examples['inputs']] 32 | model_inputs = tokenizer(inputs, max_length=max_input_length, padding=padding, truncation=True, return_tensors="pt") 33 | 34 | # Setup the tokenizer for targets 35 | with tokenizer.as_target_tokenizer(): 36 | targets = [pad_token + label for label in examples['labels']] 37 | labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True, return_tensors="pt") 38 | model_inputs["decoder_input_ids"] = labels["input_ids"] 39 | model_inputs["decoder_attention_mask"] = labels["attention_mask"] 40 | return model_inputs 41 | 42 | def preprocess_data(examples): 43 | if not no_input: 44 | inputs = [prefix + text for text in examples['inputs']] 45 | else: inputs = [prefix + 'None ->' for text in examples['inputs']] 46 | model_inputs = tokenizer(inputs, max_length=max_input_length, padding=padding, truncation=True) 47 | 48 | # Setup the tokenizer for targets 49 | with tokenizer.as_target_tokenizer(): 50 | labels = tokenizer(examples["labels"], max_length=max_target_length, padding=padding, truncation=True) 51 | 52 | labels["input_ids"] = [[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]] 53 | model_inputs["labels"] = labels["input_ids"] 54 | return model_inputs 55 | 56 | def compute_metrics(eval_pred): 57 | predictions, labels = eval_pred 58 | decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True) 59 | 60 | # Replace -100 in the labels as we can't decode them. 
61 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 62 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 63 | 64 | # Compute ROUGE scores 65 | result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True) 66 | 67 | # Extract ROUGE f1 scores 68 | result = {key: value.mid.fmeasure * 100 for key, value in result.items()} 69 | return {k: round(v, 2) for k, v in result.items()} 70 | 71 | 72 | prefix = 'Generate entailed sentence: ' 73 | dataset = load_dataset('json', data_files = {'train': 'PVI/train.json', 'dev': 'PVI/dev.json', 'test': 'PVI/test.json'}, field="data") 74 | tokenized_dataset = dataset.map(preprocess_data, batched=True) 75 | predict_dataset = dataset['test'].map(postprocess_test_data, batched=True) 76 | 77 | args = Seq2SeqTrainingArguments( 78 | output_dir = output_dir, 79 | evaluation_strategy="epoch", 80 | logging_strategy="epoch", 81 | save_strategy="epoch", 82 | learning_rate=3e-5, 83 | per_device_train_batch_size=batch_size, 84 | per_device_eval_batch_size=batch_size, 85 | weight_decay=0.01, 86 | save_total_limit=1, 87 | num_train_epochs=10, 88 | predict_with_generate=True, 89 | fp16=True, 90 | load_best_model_at_end=True, 91 | metric_for_best_model="eval_rouge1", 92 | overwrite_output_dir=overwrite_output_dir, 93 | ) 94 | 95 | data_collator = DataCollatorForSeq2Seq(tokenizer, model = model, label_pad_token_id=label_pad_token_id) 96 | metric = load_metric("rouge") 97 | 98 | trainer = Seq2SeqTrainer( 99 | model=model, 100 | args=args, 101 | train_dataset=tokenized_dataset["train"], 102 | eval_dataset=tokenized_dataset["dev"], 103 | data_collator=data_collator, 104 | tokenizer=tokenizer, 105 | compute_metrics=compute_metrics, 106 | callbacks = [EarlyStoppingCallback(early_stopping_patience=3)], 107 | ) 108 | 109 | if do_train: 110 | checkpoint = None 111 | last_checkpoint = None 112 | if os.path.isdir(output_dir): 113 | last_checkpoint = get_last_checkpoint(output_dir) 114 | 115 | if last_checkpoint is not None: 116 | checkpoint = last_checkpoint 117 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 118 | trainer.save_model() # Saves the tokenizer too for easy upload 119 | 120 | metrics = train_result.metrics 121 | trainer.log_metrics("train", metrics) 122 | trainer.save_metrics("train", metrics) 123 | trainer.save_state() 124 | 125 | # Evaluation 126 | results = {} 127 | if do_eval: 128 | 129 | metrics = trainer.evaluate(max_length=max_target_length, num_beams=8, metric_key_prefix="eval") 130 | trainer.log_metrics("eval", metrics) 131 | trainer.save_metrics("eval", metrics) 132 | 133 | # Prediction 134 | if do_predict: 135 | 136 | results = trainer.predict(tokenized_dataset['test'], dataset['test']) 137 | metrics = results.metrics 138 | trainer.log_metrics("predict", metrics) 139 | trainer.save_metrics("predict", metrics) 140 | 141 | 142 | --------------------------------------------------------------------------------
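# ---------------------------------------------------------------------------
# Usage note (an assumption about the intended workflow, not enforced here):
# evaluate_receval.py points `no_inp_model_dir` at the checkpoint this script
# writes ('PVI/noinp_models/'), i.e. the null-input model g'. Re-running with
# `no_input = False` and `output_dir = 'PVI/inp_models/'` presumably trains the
# input-conditioned model g, and the intra-step PVI computed there is
#     log p_g(conclusion | premises)  -  log p_g'(conclusion | 'None ->')
# (see obtain_unit_pvi_score in evaluate_receval.py).
# ---------------------------------------------------------------------------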