├── api_secrets.py
├── requirements.txt
├── evaluation.py
├── README.md
├── engine.py
├── knowledge_conflict.py
└── abstention.py

/api_secrets.py:
--------------------------------------------------------------------------------
def get_api_key():
    return "MY-OPENAI-KEY"
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy
openai
scipy
tiktoken
tqdm
scikit-learn
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
import re
import string

def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)
    def white_space_fix(text):
        return ' '.join(text.split())
    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)
    def lower(text):
        return text.lower()
    return white_space_fix(remove_articles(remove_punc(lower(s))))

def exact_match_score(prediction, ground_truth):
    return (normalize_answer(prediction) == normalize_answer(ground_truth))

def recall_score(prediction, ground_truth):
    prediction = normalize_answer(prediction)
    ground_truth = normalize_answer(ground_truth)
    return (ground_truth in prediction)

def get_score(preds, golds):
    em, recall = 0, 0
    for pred, gold in zip(preds, golds):
        if isinstance(gold, list):
            _em, _recall = 0, 0
            for g in gold:
                _em = max(exact_match_score(pred, g), _em)
                _recall = max(recall_score(pred, g), _recall)
            em += _em
            recall += _recall
        else:
            em += exact_match_score(pred, gold)
            recall += recall_score(pred, gold)
    em = em * 100 / (len(preds) + 1e-5)
    recall = recall * 100 / (len(preds) + 1e-5)
    return em, recall
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Context-faithful Prompting for Large Language Models

Code and data for the paper [Context-faithful Prompting for Large Language Models](https://arxiv.org/abs/2303.11315).

## How to Use

### Step 1: Install Required Packages

Before you begin, make sure to install the necessary packages: ``openai``, ``scipy``, ``numpy``, ``tiktoken``, ``tqdm``, and ``scikit-learn``. To do so, run the following command:
``pip install -r requirements.txt``.

### Step 2: Download the Datasets

Download the NQ and RealtimeQA datasets from [Google Drive](https://drive.google.com/file/d/1DJ1ajmLNAKVTBWnM7SkP93EYQ2cav3Mk/view?usp=sharing) and extract them to the repository folder. Please note that the TACRED dataset is not included due to its LDC license.

### Step 3: Add Your OpenAI API Key

Insert your OpenAI API key into ``api_secrets.py``.
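
For reference, ``api_secrets.py`` (shown above) simply returns the key as a plain string, so after this step it should look roughly as follows (``sk-...`` is a placeholder, not a real key):

```python
# api_secrets.py -- hypothetical example after inserting your own key
def get_api_key():
    return "sk-..."  # replace with your actual OpenAI API key
```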

### Step 4: Run Experiments

Run experiments on the NQ dataset in the knowledge conflict setting using the following command:
``python knowledge_conflict.py --schema ${SCHEMA} --demo_mode ${DEMO_MODE}``

To perform experiments on the RealTime QA dataset in the abstention setting, use this command:
``python abstention.py --schema ${SCHEMA} --demo_mode ${DEMO_MODE}``

The ``SCHEMA`` parameter refers to the prompting templates described in the paper and can take the following values: ``base``, ``attr``, ``instr``, ``opin``, or ``instr+opin``. The ``DEMO_MODE`` parameter represents the demonstration method, with possible values being ``none`` (zero-shot), ``counter`` (counterfactual demonstrations, applicable only in the knowledge conflict setting), and ``original`` (original demonstrations).

**Please be aware that running experiments can be costly. Few-shot evaluation on the full dataset is estimated to cost around $150 for NQ and $30 for RealTime QA when using the ``text-davinci-003`` engine for each prompting template.**

## Citation
```bibtex
@article{zhou2023context,
  title={Context-faithful Prompting for Large Language Models},
  author={Zhou, Wenxuan and Zhang, Sheng and Poon, Hoifung and Chen, Muhao},
  journal={arXiv preprint arXiv:2303.11315},
  year={2023}
}
```
--------------------------------------------------------------------------------
/engine.py:
--------------------------------------------------------------------------------
import openai
from api_secrets import get_api_key
from time import sleep
import tiktoken


openai.api_key = get_api_key()

length_limit = {
    'text-davinci-003': 4096,
    'text-curie-001': 2048,
    'text-babbage-001': 2048,
    'text-ada-001': 2048,
}

class Engine:
    def __init__(self, engine='text-davinci-003'):
        self.engine = engine
        self.tokenizer = tiktoken.encoding_for_model(engine)

    def check_prompt_length(self, prompt, max_tokens=64):
        prompt_length = len(self.tokenizer.encode(prompt))
        if prompt_length + max_tokens >= length_limit[self.engine]: # Prompt is too long
            return True
        return False

    def complete(self, prompt, max_tokens=64):
        num_retry = 0
        while True:
            try:
                response = openai.Completion.create(
                    engine=self.engine,
                    prompt=prompt,
                    max_tokens=max_tokens,
                )
            except Exception as e:
                print(e)
                if num_retry >= 5: # Retried too many times
                    print('Retried too many times, skip this instance.')
                    return None
                sleep(2)
                num_retry += 1
                continue
            break
        answer = response.choices[0].text
        return answer

    def get_prob(self, prompt, num_tokens):
        num_retry = 0
        while True:
            try:
                response = openai.Completion.create(
                    engine=self.engine,
                    prompt=prompt,
                    max_tokens=0,
                    logprobs=1,
                    echo=True,
                )
                token_logprobs = response.choices[0].logprobs.token_logprobs[-num_tokens:]
                seq_prob = sum(token_logprobs)
            except Exception as e:
                print(e)
                if num_retry >= 5: # Retried too many times
                    print('Retried too many times, skip this instance.')
                    return None
                sleep(2)
                num_retry += 1
                continue
            break
        return seq_prob
--------------------------------------------------------------------------------
/knowledge_conflict.py:
--------------------------------------------------------------------------------
import json
import re
import argparse
from tqdm import tqdm
from engine import Engine
from evaluation import get_score


def qa_to_prompt(query, context, schema, demos=[], num_demos=16):
    def get_prompt(query, context, schema, answer=''):
        if schema == 'base':
            prompt = '{}\nQ:{}\nA:{}'.format(context, query, answer)
        elif schema == 'opin':
            context = context.replace('"', "")
            prompt = 'Bob said "{}"\nQ: {} in Bob\'s opinion?\nA:{}'.format(context, query[:-1], answer)
        elif schema == 'instr+opin':
            context = context.replace('"', "")
            prompt = 'Bob said "{}"\nQ: {} in Bob\'s opinion?\nA:{}'.format(context, query[:-1], answer)
        elif schema == 'attr':
            prompt = '{}\nQ:{} based on the given text?\nA:{}'.format(context, query[:-1], answer)
        elif schema == 'instr':
            prompt = '{}\nQ:{}\nA:{}'.format(context, query, answer)
        return prompt
    prompt = ''
    if schema in ('instr', 'instr+opin'):
        prompt = 'Instruction: read the given information and answer the corresponding question.\n\n'
    for demo in demos[-num_demos:]:
        answer = demo['answer'] if isinstance(demo['answer'], str) else demo['answer'][0]
        demo_prompt = get_prompt(demo['question'], demo['context'], schema=schema, answer=answer)
        prompt = prompt + demo_prompt + '\n\n'
    prompt = prompt + get_prompt(query, context, schema=schema)
    return prompt

def eval(pred_answers, orig_answers, gold_answers):
    em, ps = get_score(pred_answers, gold_answers)
    _, po = get_score(pred_answers, orig_answers)
    mr = po / (ps + po + 1e-10) * 100
    print('ps {}, po {}, mr {}, em {}.'.format(ps, po, mr, em))

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--orig_path", default="./datasets/nq/orig_dev_filtered.json", type=str)
    parser.add_argument("--counter_path", default="./datasets/nq/conflict_dev_filtered.json", type=str)
    parser.add_argument("--engine", default="text-davinci-003", type=str)
    parser.add_argument("--schema", default="base", type=str, help="Choose from the following prompting templates: base, attr, instr, opin, instr+opin.")
    parser.add_argument("--demo_mode", default="none", help="Choose from the following demonstrations: none, original, counter.")
    parser.add_argument("--num_demos", default=16, type=int)
    parser.add_argument("--log_path", default='', type=str)
    args = parser.parse_args()
    with open(args.orig_path, 'r') as fh:
        orig_examples = json.load(fh)
    with open(args.counter_path, 'r') as fh:
        counter_examples = json.load(fh)
    print('Loaded {} instances.'.format(len(counter_examples)))
    engine = Engine(args.engine)

    step = 0
    gold_answers, pred_answers, orig_answers = [], [], []
    for oe, ce in tqdm(zip(orig_examples, counter_examples), total=len(orig_examples)):
        if step % 100 == 0:
            eval(pred_answers, orig_answers, gold_answers)
        step += 1
        query, context, answer = ce['question'], ce['context'], ce['answer']
        orig_answer = oe['answer']
        if orig_answer is None:
            continue
        if args.demo_mode == 'none':
            demos = []
        elif args.demo_mode == 'counter':
            demos = ce['ic_examples']
        elif args.demo_mode == 'original':
            demos = ce['ico_examples']
        for num_demos in range(args.num_demos, 1, -1): # Use fewer demos if prompt is too long
            prompt = qa_to_prompt(query, context, schema=args.schema, demos=demos, num_demos=num_demos)
            if not engine.check_prompt_length(prompt):
                break
        if engine.check_prompt_length(prompt):
            continue
        pred = engine.complete(prompt)
        if pred is None:
            continue
        pred_answers.append(pred)
        gold_answers.append(answer)
        orig_answers.append(orig_answer)
        # Logs
        ce['prediction'] = pred
        ce['orig_answer'] = orig_answer
        ce['schema'] = args.schema
        ce['demo_mode'] = args.demo_mode
        if args.log_path:
            with open(args.log_path, 'w') as fh:
                json.dump(counter_examples, fh)
    eval(pred_answers, orig_answers, gold_answers)

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/abstention.py:
--------------------------------------------------------------------------------
import json
from engine import Engine
from tqdm import tqdm
import numpy as np
from scipy.special import softmax
import tiktoken
import argparse
from sklearn.metrics import brier_score_loss

def qa_to_prompt(query, context, choices, schema, answer=''):
    context = context.replace('“', '"').replace('”', '"').replace('’', "'")
    if schema == 'base':
        prompt = '{}\n\nQ: {}\nChoices: {}\nA: {}'.format(context, query, choices, answer)
    elif schema == 'opin':
        context = context.replace('"', "")
        prompt = 'Bob said, "{}"\n\nQ: {} in Bob\'s opinion?\nChoices: {}\nA: {}'.format(context, query[:-1], choices, answer)
    elif schema == 'attr':
        prompt = '{}\n\nQ:{} based on the given text?\nChoices: {}\nA: {}'.format(context, query[:-1], choices, answer)
    elif schema == 'instr':
        prompt = '{}\n\nQ: {}\nChoices: {}\nA: {}'.format(context, query, choices, answer)
    elif schema == 'instr+opin':
        context = context.replace('"', "")
        prompt = 'Bob said, "{}"\n\nQ: {} in Bob\'s opinion?\nChoices: {}\nA: {}'.format(context, query[:-1], choices, answer)
    return prompt

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", default="./datasets/realtime_qa/realtime_qa_data.json", type=str)
    parser.add_argument("--demo_path", default="./datasets/realtime_qa/realtime_qa_demo_data.json", type=str)
    parser.add_argument("--engine", default="text-davinci-003", type=str)
    parser.add_argument("--schema", default="base", type=str, help="Choose from the following prompting templates: base, attr, instr, opin, instr+opin.")
    parser.add_argument("--demo_mode", default="none", help="Choose from the following demonstrations: none, original.")
    parser.add_argument("--log_path", default='', type=str)
    args = parser.parse_args()
    with open(args.data_path, 'r') as fh:
        test_data = json.load(fh)
    with open(args.demo_path, 'r') as fh:
        demo_data = json.load(fh)
    engine = Engine(args.engine)
    tokenizer = tiktoken.encoding_for_model(args.engine)
    abs_golds, abs_probs, preds, golds = [], [], [], []
    for d in tqdm(test_data):
        context, question, choices, answer = d['context'], d['question'], d['choices'], d['answer']
        probs = []
        for choice in choices.split(';'):
            choice = choice.strip()
            assert len(choice) > 0
            prompt = ''
            if args.schema in ('instr', 'instr+opin'):
                prompt = 'Instruction: answer a question based on the provided input-output pairs.\n\n'
            if args.demo_mode == 'original':
                for demo in demo_data:
                    prompt += (qa_to_prompt(demo['question'], demo['context'], demo['choices'], args.schema, answer=demo['answer']) + '\n\n')
            choice = choice.strip() + '.'
            prompt += qa_to_prompt(question, context, choices, args.schema)
            prompt = prompt + choice
            if engine.check_prompt_length(prompt):
                continue
            num_tokens = len(tokenizer.encode(' ' + choice))
            prob = engine.get_prob(prompt, num_tokens)
            if prob is not None:
                probs.append(prob)
        if len(probs) != len(choices.split(';')):
            continue
        choice_probs = softmax(np.array(probs))
        choices = [s.strip() for s in choices.split(';')]
        pred = choices[probs.index(max(probs))]
        d['pred'] = pred
        d['probs'] = choice_probs.tolist()
        abs_gold = 1 if answer == 'I don\'t know' else 0
        abs_golds.append(abs_gold)
        abs_probs.append(choice_probs.tolist()[-1])
        preds.append(pred)
        golds.append(answer)
    # Evaluation
    has_ans_correct, no_ans_correct, has_ans_wrong, no_ans_wrong = 0, 0, 0, 0
    for pred, gold in zip(preds, golds):
        if pred == gold:
            if gold != 'I don\'t know':
                has_ans_correct += 1
            else:
                no_ans_correct += 1
        else:
            if gold != 'I don\'t know':
                has_ans_wrong += 1
            else:
                no_ans_wrong += 1
    hasans_acc = has_ans_correct / (has_ans_correct + has_ans_wrong)
    noans_acc = no_ans_correct / (no_ans_correct + no_ans_wrong)
    acc = (has_ans_correct + no_ans_correct) / (has_ans_correct + has_ans_wrong + no_ans_correct + no_ans_wrong)
    brier = brier_score_loss(np.array(abs_golds), np.array(abs_probs))
    print("HasAns Acc {}, NoAns Acc {}, Acc {}, Brier {}.".format(hasans_acc, noans_acc, acc, brier))
    if args.log_path:
        with open(args.log_path, 'w') as fh:
            json.dump(test_data, fh)

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
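
A minimal, hypothetical sanity check (not part of the repository) for inspecting the prompts that ``knowledge_conflict.py`` builds: it assumes the requirements are installed and is run from the repository root, makes no API calls, and uses a made-up question with a counterfactual context.

```python
# check_prompt.py -- hypothetical helper, assumed to sit in the repository root
from knowledge_conflict import qa_to_prompt

query = "Who wrote Hamlet?"  # example question; the templates expect a trailing '?'
context = "Hamlet is a tragedy written by Christopher Marlowe."  # counterfactual context
print(qa_to_prompt(query, context, schema='opin'))
# Expected output:
# Bob said "Hamlet is a tragedy written by Christopher Marlowe."
# Q: Who wrote Hamlet in Bob's opinion?
# A:
```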