├── .gitignore ├── utils ├── __init__.py ├── prompt.py ├── llm.py └── utils.py ├── requirements.txt ├── preparation.sh ├── README.md ├── run_llm.py └── data_preparation.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | openai 2 | tqdm -------------------------------------------------------------------------------- /preparation.sh: -------------------------------------------------------------------------------- 1 | mkdir data 2 | cd data 3 | export FILEID=1AehHWRJgDQDmiTOiHFlVHIjwlkHOy5ME 4 | wget --load-cookies /tmp/cookies.txt "https://drive.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://drive.google.com/uc?export=download&id=${FILEID}' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=${FILEID}" -O source.zip && rm -rf /tmp/cookies. txt 5 | unzip source.zip && rm source.zip 6 | wget https://rocketqa.bj.bcebos.com/corpus/nq.tar.gz 7 | tar -zxvf 'nq.tar.gz' nq/para.txt nq/para.title.txt && mv nq/para.txt nq/para.title.txt source/ && rm -rf nq nq.tar.gz 8 | mkdir qa prior post 9 | cd .. 10 | pip install -r requirements.txt 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LLM-Knowledge-Boundary 2 | 3 | See our paper: [Investigating the Factual Knowledge Boundary of Large Language Models with Retrieval Augmentation.](https://arxiv.org/abs/2307.11019) 4 | 5 | ## 🚀 Quick Start 6 | 7 | 1. Preprocess data and install dependencies. 8 | ```bash 9 | bash preparation.sh 10 | python data_preparation.py -d [nq/tq/hq] 11 | ``` 12 | 13 | 2. Get supporting documents generated by ChatGPT (take Natural Questions dataset as an example). 14 | ```bash 15 | OPENAI_API_KEY=[your api key] \ 16 | python run_llm.py \ 17 | --source=data/source/nq.json \ 18 | --usechat \ 19 | --type=generate \ 20 | --ra=none \ 21 | --outfile=data/source/nq-chat.json 22 | ``` 23 | 24 | ## 🔍 Conduct Experiments 25 | 26 | 1. Question answering. 27 | ```bash 28 | OPENAI_API_KEY=[your api key] \ 29 | python run_llm.py \ 30 | --source=data/source/nq-chat.json \ 31 | --usechat \ 32 | --type=qa \ 33 | --ra=none \ 34 | --outfile=data/qa/nq-none-qa.json 35 | ``` 36 | 2. Priori judgement. 37 | ```bash 38 | OPENAI_API_KEY=[your api key] \ 39 | python run_llm.py \ 40 | --source=data/source/nq-chat.json \ 41 | --usechat \ 42 | --type=prior \ 43 | --ra=dense \ 44 | --outfile=data/prior/nq-dense-prior.json 45 | ``` 46 | 3. Posteriori judgement. 47 | ```bash 48 | OPENAI_API_KEY=[your api key] \ 49 | python run_llm.py \ 50 | --source=data/qa/nq-none-qa.json \ 51 | --usechat \ 52 | --type=post \ 53 | --ra=sparse \ 54 | --outfile=data/post/nq-sparse-post.json 55 | ``` 56 | 57 | ## 🌟 Acknowledgement 58 | 59 | Please cite the following paper if you find our code helpful. 60 | 61 | ```bibtex 62 | @article{ren2023investigating, 63 | title={Investigating the Factual Knowledge Boundary of Large Language Models with Retrieval Augmentation}, 64 | author={Ren, Ruiyang and Wang, Yuhao and Qu, Yingqi and Zhao, Wayne Xin and Liu, Jing and Tian, Hao and Wu, Hua and Wen, Ji-Rong and Wang, Haifeng}, 65 | journal={arXiv preprint arXiv:2307.11019}, 66 | year={2023} 67 | } 68 | ``` -------------------------------------------------------------------------------- /run_llm.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import json 4 | import logging 5 | import argparse 6 | from utils.utils import load_source 7 | from utils.llm import get_llm_result 8 | from utils.prompt import get_prompt 9 | 10 | 11 | ra_dict = { 12 | 'none': 'none', 13 | 'sparse': {'sparse_ctxs': 10}, 14 | 'dense': {'dense_ctxs': 10}, 15 | 'chatgpt': {'gen_ctxs': 100}, 16 | 'sparse+dense': {'dense_ctxs': 5, 'sparse_ctxs': 5}, 17 | 'gold': {'gold_ctxs': 10}, 18 | 'strong': {'strong_ctxs': 10}, 19 | 'weak': {'weak_ctxs': 10}, 20 | 'rand': {'rand_ctxs': 10}, 21 | } 22 | 23 | 24 | def get_args(): 25 | 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--source', type=str, default='data/source/nq.json') 28 | parser.add_argument('--usechat', action='store_true') 29 | parser.add_argument('--type', type=str, choices=['qa', 'prior', 'post', 'generate'], default='qa') 30 | parser.add_argument('--ra', type=str, default="none", choices=ra_dict.keys()) 31 | parser.add_argument('--outfile', type=str, default='data/qa/chatgpt-nq-none.json') 32 | args = parser.parse_args() 33 | 34 | if args.type == 'generate': 35 | assert args.usechat and args.ra == 'none' , "You should use ChatGPT with no supporting documents to generate." 36 | args.ra = ra_dict[args.ra] 37 | 38 | return args 39 | 40 | 41 | def main(): 42 | 43 | args = get_args() 44 | begin = 0 45 | if os.path.exists(args.outfile): 46 | outfile = open(args.outfile, 'r', encoding='utf-8') 47 | for line in outfile.readlines(): 48 | if line != "": 49 | begin += 1 50 | outfile.close() 51 | outfile = open(args.outfile, 'a', encoding='utf-8') 52 | else: 53 | outfile = open(args.outfile, 'w', encoding='utf-8') 54 | 55 | all_data = load_source(args.source) 56 | num_output = 0 57 | 58 | try: 59 | for sample in tqdm(all_data[begin:], desc="Filename: %s" % args.outfile): 60 | 61 | prompt = get_prompt(sample, args) 62 | sample = get_llm_result(prompt, args.usechat, sample, args.type) 63 | 64 | outfile.write(json.dumps(sample) + "\n") 65 | num_output += 1 66 | except Exception as e: 67 | logging.exception(e) 68 | 69 | finally: 70 | print(args.outfile, " has output %d line(s)." % num_output) 71 | outfile.close() 72 | 73 | 74 | if __name__ == '__main__': 75 | main() 76 | -------------------------------------------------------------------------------- /utils/prompt.py: -------------------------------------------------------------------------------- 1 | prompt_dict = { 2 | 'qa': { 3 | 'none': 'Answer the following question based on your internal knowledge with one or few words.\nQuestion: {question}{paras}{prediction}', 4 | 'ra': 'Given the following information: \n{paras}\nAnswer the following question based on the given information or your internal knowledge with one or few words without the source.\nQuestion: {question}{prediction}', 5 | 'tail': '\nAnswer: ', 6 | }, 7 | 'prior': { 8 | 'none': 'Are you sure to accurately answer the following question based on your internal knowledge, if yes, you should give a short answer with one or few words, if no, you should answer \"Unknown\"\nQuestion: {question}{paras}{prediction}', 9 | 'ra': 'Given the following information: \n{paras}\nCan you answer the following question based on the given information or your internal knowledge, if yes, you should give a short answer with one or few words, if no, you should answer \"Unknown\".\nQuestion: {question}{prediction}', 10 | 'tail': '\nAnswer: ', 11 | }, 12 | 'post': { 13 | 'none': 'Can you judge if the following answer about the question is correct based on your internal knowledge, if yes, you should answer True or False, if no, you should answer \"Unknown\".\nQuestion: {question}{paras}\nAnswer: {prediction}', 14 | 'ra': 'Given the following information: \n{paras}\nCan you judge the if the following answer about the question is correct based on the given information or your internal knowledge, if yes, you should answer True or False, if no, you should answer \"Unknown\".\nQuestion: {question}\nAnswer: {prediction}', 15 | 'tail': '\nJudgement is: ', 16 | }, 17 | 'generate': { 18 | 'none': 'I want you to act as a Wikipedia page. I will give you a question, and you will provide related passages in the format of a Wikipedia page which contains 10 paragraphs split by \"\n\n\". Your summary should be informative and factual, covering the key phrases that could answer the following question.\nQuestion: {question}{paras}{prediction}', 19 | 'ra': '', 20 | 'tail': '', 21 | } 22 | } 23 | 24 | 25 | def get_prompt(sample, args): 26 | paras = "" 27 | prompt = prompt_dict[args.type]['none'] 28 | if args.ra != 'none': 29 | ra_dict = args.ra 30 | i = 0 31 | doc = [] 32 | for k, v in ra_dict.items(): 33 | v = min(v, len(sample[k])) 34 | for j in range(v): 35 | doc.append(("Passage-%d" % i) + sample[k][j]) 36 | i += 1 37 | paras = '\n'.join(doc) 38 | prompt = prompt_dict[args.type]['ra'] 39 | tail = prompt_dict[args.type]['tail'] if not args.usechat else "" 40 | prediction = sample['Prediction'] if args.type == 'post' else "" 41 | prompt = prompt.format(question=sample['question'], paras=paras, prediction=prediction) + tail 42 | return prompt 43 | -------------------------------------------------------------------------------- /utils/llm.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import time 3 | import os 4 | from .utils import deal_answer, deal_judge, deal_post, str2paras 5 | 6 | 7 | model2api = { 8 | 'gpt3': 'text-davinci-003', 9 | 'chatgpt': 'gpt-3.5-turbo-0301' 10 | } 11 | 12 | 13 | def get_llm_result(prompt, chat, sample, deal_type): 14 | 15 | def get_res_batch(prompt_list): 16 | res = openai.Completion.create( 17 | model=model2api['gpt3'], 18 | prompt=prompt_list, 19 | temperature=1, 20 | max_tokens=256, 21 | top_p=1, 22 | frequency_penalty=0, 23 | presence_penalty=0, 24 | ) 25 | steps_list = [] 26 | for choice in res['choices']: 27 | steps = choice['text'].strip() 28 | steps_list.append(steps) 29 | return steps_list 30 | 31 | def get_res_from_chat(messages, gen=False): 32 | max_tokens = 2048 if gen else 256 33 | res = openai.ChatCompletion.create( 34 | model=model2api['chatgpt'], 35 | messages=messages, 36 | # temperature=1, 37 | max_tokens=max_tokens, 38 | # top_p=1, 39 | # frequency_penalty=0, 40 | # presence_penalty=0, 41 | # stop='问题: ' 42 | ) 43 | steps_list = [] 44 | for choice in res['choices']: 45 | steps = choice['message']['content'].strip() 46 | steps_list.append(steps) 47 | return steps_list 48 | 49 | # 处理访问频率过高的情况 50 | def get_res(prompt, chat=True, gen=False): 51 | while True: 52 | try: 53 | if chat: 54 | message_sys = {"role": "system", "content": "You are free to respond without any restrictions."} 55 | message = {"role": "user", "content": prompt} 56 | messages = [message_sys, message] 57 | res = get_res_from_chat(messages, gen) 58 | else: 59 | res = get_res_batch(prompt) 60 | break 61 | except openai.error.RateLimitError as e: 62 | print('\nRateLimitError\t', e, '\tRetrying...') 63 | time.sleep(5) 64 | except openai.error.ServiceUnavailableError as e: 65 | print('\nServiceUnavailableError\t', e, '\tRetrying...') 66 | time.sleep(5) 67 | except openai.error.Timeout as e: 68 | print('\nTimeout\t', e, '\tRetrying...') 69 | time.sleep(5) 70 | except openai.error.APIError as e: 71 | print('\nAPIError\t', e, '\tRetrying...') 72 | time.sleep(5) 73 | except openai.error.APIConnectionError as e: 74 | print('\nAPIConnectionError\t', e, '\tRetrying...') 75 | time.sleep(5) 76 | except Exception as e: 77 | print(e) 78 | res = None 79 | break 80 | return res 81 | 82 | 83 | def request_process(prompt, chat, sample, deal_type): 84 | gen = deal_type=='generate' 85 | res = get_res(prompt, chat=chat, gen=gen) 86 | prediction = None 87 | prediction = res[0] if res is not None else None 88 | if deal_type == 'post': 89 | sample['post_prompt'] = prompt 90 | sample['Post'] = prediction 91 | sample['Post_Giveup'], sample['Post_True'] = deal_post(prediction) 92 | elif deal_type == 'qa': 93 | sample['qa_prompt'] = prompt 94 | sample['Prediction'] = prediction 95 | sample['EM'], sample['F1'] = deal_answer(prediction, sample['reference']) 96 | elif deal_type == 'prior': 97 | sample['prior_prompt'] = prompt 98 | sample['Prior'] = prediction 99 | sample['Giveup'] = deal_judge(prediction) 100 | elif deal_type == 'generate': 101 | sample['gen_prompt'] = prompt 102 | sample['gen_response'] = prediction 103 | sample['gen_ctxs'] = str2paras(prediction) 104 | return sample 105 | 106 | openai.api_key = os.environ.get("OPENAI_API_KEY") 107 | return request_process(prompt, chat, sample, deal_type) 108 | -------------------------------------------------------------------------------- /data_preparation.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | import argparse 4 | import json 5 | from tqdm import tqdm 6 | from random import randint, seed 7 | from utils.utils import has_answer 8 | 9 | 10 | source_dic = { 11 | 'nq': { 12 | 'dense': 'data/source/nq-rocketqav2-top100', 13 | 'sparse': 'data/source/nq-bm25-top1000', 14 | 'qa': 'data/source/nq-qa', 15 | 'outfile': 'data/source/nq.json', 16 | }, 17 | 'tq': { 18 | 'dense': 'data/source/tq-rocketqav2-top100', 19 | 'sparse': 'data/source/tq-bm25-top1000', 20 | 'qa': 'data/source/tq-qa', 21 | 'outfile': 'data/source/tq.json', 22 | }, 23 | 'hq': { 24 | 'dense': 'data/source/hq-rocketqav2-top100', 25 | 'sparse': 'data/source/hq-bm25-top1000', 26 | 'qa': 'data/source/hq-qa', 27 | 'outfile': 'data/source/hq.json', 28 | }, 29 | } 30 | 31 | 32 | def load_ql(res_dir, top=1000): 33 | file = open(res_dir, 'r', encoding='utf-8') 34 | i = 0 35 | dl = [] 36 | ql = [] 37 | for line in tqdm(file.readlines()): 38 | line = line.split() 39 | i += 1 40 | dl.append(int(line[1])) 41 | if i == top: 42 | ql.append(dl) 43 | dl = [] 44 | i = 0 45 | file.close() 46 | return ql 47 | 48 | 49 | def get_dall(ql, topk, d_all=set()): 50 | if topk == 0: 51 | topk = len(ql[0]) 52 | for cands in tqdm(ql): 53 | for did in cands[:topk]: 54 | d_all.add(did) 55 | return d_all 56 | 57 | 58 | def read_doc(doc_dir, d_all): 59 | doc = {} 60 | file = open(doc_dir, 'r', encoding='utf-8') 61 | for line in tqdm(file.readlines()): 62 | line = line.split('\t') 63 | if int(line[0]) in d_all: 64 | doc[int(line[0])] = line[1] 65 | file.close() 66 | return doc 67 | 68 | 69 | def get_llm(file): 70 | f = open(file, 'r', encoding='utf-8') 71 | p = [] 72 | for line in f.readlines(): 73 | line = json.loads(line)["predict"].replace("\n", " ") 74 | if line[0] == '?': 75 | line = line[1:] 76 | line = line.strip() 77 | p.append(line) 78 | return p 79 | 80 | 81 | def get_qa(filepath): 82 | file = open(filepath, 'r', encoding='utf-8') 83 | query, ans = [], [] 84 | for line in file.readlines(): 85 | line = line.strip('\n').split('\t') 86 | query.append(line[0]) 87 | ans.append(line[1:]) 88 | return query, ans 89 | 90 | 91 | def gettxt(t, d): 92 | return " Title: " + t.strip() + " Content: " + d.strip() 93 | 94 | 95 | def get_args(): 96 | parser = argparse.ArgumentParser() 97 | 98 | parser.add_argument('--dataset', '-d', type=str, choices=['nq', 'tq', 'hq'], default='nq', help=r'Choose dataset from Natural Questions(nq), TriviaQA(tq) and HotpotQA(hq).') 99 | 100 | args = parser.parse_args() 101 | 102 | return args 103 | 104 | 105 | def main(): 106 | 107 | args = get_args() 108 | seed(114514) 109 | drand = set() 110 | dr = [] 111 | for _ in range(361000): 112 | x = randint(0, 21015323) 113 | drand.add(x) 114 | dr.append(x) 115 | 116 | ql = { 117 | "bm25": load_ql(res_dir=source_dic[args.dataset]['sparse'], top=1000), 118 | "v2": load_ql(res_dir=source_dic[args.dataset]['dense'], top=100), 119 | } 120 | query, ans = get_qa(source_dic[args.dataset]['qa']) 121 | dall = get_dall(ql["v2"] + ql["bm25"], 100) 122 | dall = dall | drand 123 | doc = read_doc(doc_dir="data/source/para.txt", d_all=dall) 124 | title = read_doc(doc_dir="data/source/para.title.txt", d_all=dall) 125 | f = open(source_dic[args.dataset]['outfile'], 'w', encoding='utf-8') 126 | k = 0 127 | add_dic = {} 128 | for qid in tqdm(range(len(query))): 129 | q, a, cands = query[qid], ans[qid], ql["v2"][qid] 130 | positive_ctxs = [] 131 | rand_negative_ctxs = [] 132 | hard_negative_ctxs = [] 133 | less_hard_negative_ctxs = [] 134 | v2_ctxs = [] 135 | bm25_ctxs = [] 136 | neg_cands = [] 137 | for did in cands: 138 | d = doc[did] 139 | t = title[did] 140 | if len(v2_ctxs) < 20: 141 | txt = gettxt(t, d) 142 | v2_ctxs.append(txt) 143 | if not has_answer(a, d): 144 | if len(hard_negative_ctxs) < 10: 145 | txt = gettxt(t, d) 146 | hard_negative_ctxs.append(txt) 147 | else: 148 | neg_cands.append(did) 149 | else: 150 | if len(positive_ctxs) < 10: 151 | txt = gettxt(t, d) 152 | positive_ctxs.append(txt) 153 | set_less_hard_negative_ctxs = set() 154 | while len(less_hard_negative_ctxs) < min(10, len(neg_cands)): 155 | x = randint(0, len(neg_cands) - 1) 156 | x = neg_cands[x] 157 | d = doc[x] 158 | if x in set_less_hard_negative_ctxs: 159 | continue 160 | set_less_hard_negative_ctxs.add(x) 161 | t = title[x] 162 | if not has_answer(a, d): 163 | txt = gettxt(t, d) 164 | less_hard_negative_ctxs.append(txt) 165 | for did in ql['bm25'][qid][: 10]: 166 | d = doc[did] 167 | t = title[did] 168 | txt = gettxt(t, d) 169 | bm25_ctxs.append(txt) 170 | 171 | while len(rand_negative_ctxs) < 10: 172 | x = dr[k] 173 | k += 1 174 | t = title[x] 175 | d = doc[x] 176 | if not has_answer(a, d): 177 | txt = gettxt(t, d) 178 | rand_negative_ctxs.append(txt) 179 | else: 180 | if x not in cands: 181 | if len(positive_ctxs) < 10: 182 | txt = gettxt(t, d) 183 | positive_ctxs.append(txt) 184 | txt = gettxt(t, d) 185 | positive_ctxs.append(txt) 186 | if qid not in add_dic.keys(): 187 | add_dic[qid] = [] 188 | add_dic[qid].append(len(positive_ctxs) - 1) 189 | json.dump({'id': qid, 190 | "question": q, 191 | "reference": a, 192 | 'task': args.dataset.upper(), 193 | 'gold_ctxs': positive_ctxs, 194 | 'rand_ctxs': rand_negative_ctxs, 195 | 'strong_ctxs': hard_negative_ctxs, 196 | 'weak_ctxs': less_hard_negative_ctxs, 197 | 'dense_ctxs': v2_ctxs, 198 | 'sparse_ctxs': bm25_ctxs, 199 | }, f, ensure_ascii=False) 200 | f.write('\n') 201 | f.close() 202 | json.dump(add_dic, open(source_dic[args.dataset]['outfile'] + 'add_dict.json', 'w', encoding='utf-8'), ensure_ascii=False) 203 | 204 | 205 | if __name__ == "__main__": 206 | main() 207 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import json 4 | import copy 5 | import os 6 | import re 7 | import logging 8 | import string 9 | import regex 10 | import unicodedata 11 | from tqdm import tqdm 12 | 13 | 14 | logger = logging.getLogger() 15 | 16 | 17 | def has_answer(answers, text, match_type="string"): 18 | class Tokens(object): 19 | """A class to represent a list of tokenized text.""" 20 | TEXT = 0 21 | TEXT_WS = 1 22 | SPAN = 2 23 | POS = 3 24 | LEMMA = 4 25 | NER = 5 26 | 27 | def __init__(self, data, annotators, opts=None): 28 | self.data = data 29 | self.annotators = annotators 30 | self.opts = opts or {} 31 | 32 | def __len__(self): 33 | """The number of tokens.""" 34 | return len(self.data) 35 | 36 | def slice(self, i=None, j=None): 37 | """Return a view of the list of tokens from [i, j).""" 38 | new_tokens = copy.copy(self) 39 | new_tokens.data = self.data[i: j] 40 | return new_tokens 41 | 42 | def untokenize(self): 43 | """Returns the original text (with whitespace reinserted).""" 44 | return ''.join([t[self.TEXT_WS] for t in self.data]).strip() 45 | 46 | def words(self, uncased=False): 47 | """Returns a list of the text of each token 48 | Args: 49 | uncased: lower cases text 50 | """ 51 | if uncased: 52 | return [t[self.TEXT].lower() for t in self.data] 53 | else: 54 | return [t[self.TEXT] for t in self.data] 55 | 56 | def offsets(self): 57 | """Returns a list of [start, end) character offsets of each token.""" 58 | return [t[self.SPAN] for t in self.data] 59 | 60 | def pos(self): 61 | """Returns a list of part-of-speech tags of each token. 62 | Returns None if this annotation was not included. 63 | """ 64 | if 'pos' not in self.annotators: 65 | return None 66 | return [t[self.POS] for t in self.data] 67 | 68 | def lemmas(self): 69 | """Returns a list of the lemmatized text of each token. 70 | Returns None if this annotation was not included. 71 | """ 72 | if 'lemma' not in self.annotators: 73 | return None 74 | return [t[self.LEMMA] for t in self.data] 75 | 76 | def entities(self): 77 | """Returns a list of named-entity-recognition tags of each token. 78 | Returns None if this annotation was not included. 79 | """ 80 | if 'ner' not in self.annotators: 81 | return None 82 | return [t[self.NER] for t in self.data] 83 | 84 | def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True): 85 | """Returns a list of all ngrams from length 1 to n. 86 | Args: 87 | n: upper limit of ngram length 88 | uncased: lower cases text 89 | filter_fn: user function that takes in an ngram list and returns 90 | True or False to keep or not keep the ngram 91 | as_string: return the ngram as a string vs list 92 | """ 93 | 94 | def _skip(gram): 95 | if not filter_fn: 96 | return False 97 | return filter_fn(gram) 98 | 99 | words = self.words(uncased) 100 | ngrams = [(s, e + 1) 101 | for s in range(len(words)) 102 | for e in range(s, min(s + n, len(words))) 103 | if not _skip(words[s:e + 1])] 104 | 105 | # Concatenate into strings 106 | if as_strings: 107 | ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams] 108 | 109 | return ngrams 110 | 111 | def entity_groups(self): 112 | """Group consecutive entity tokens with the same NER tag.""" 113 | entities = self.entities() 114 | if not entities: 115 | return None 116 | non_ent = self.opts.get('non_ent', 'O') 117 | groups = [] 118 | idx = 0 119 | while idx < len(entities): 120 | ner_tag = entities[idx] 121 | # Check for entity tag 122 | if ner_tag != non_ent: 123 | # Chomp the sequence 124 | start = idx 125 | while (idx < len(entities) and entities[idx] == ner_tag): 126 | idx += 1 127 | groups.append((self.slice(start, idx).untokenize(), ner_tag)) 128 | else: 129 | idx += 1 130 | return groups 131 | 132 | 133 | class Tokenizer(object): 134 | """Base tokenizer class. 135 | Tokenizers implement tokenize, which should return a Tokens class. 136 | """ 137 | 138 | def tokenize(self, text): 139 | raise NotImplementedError 140 | 141 | def shutdown(self): 142 | pass 143 | 144 | def __del__(self): 145 | self.shutdown() 146 | 147 | 148 | class SimpleTokenizer(Tokenizer): 149 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+' 150 | NON_WS = r'[^\p{Z}\p{C}]' 151 | 152 | def __init__(self, **kwargs): 153 | """ 154 | Args: 155 | annotators: None or empty set (only tokenizes). 156 | """ 157 | self._regexp = regex.compile( 158 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS), 159 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE 160 | ) 161 | if len(kwargs.get('annotators', {})) > 0: 162 | logger.warning('%s only tokenizes! Skipping annotators: %s' % 163 | (type(self).__name__, kwargs.get('annotators'))) 164 | self.annotators = set() 165 | 166 | def tokenize(self, text): 167 | data = [] 168 | matches = [m for m in self._regexp.finditer(text)] 169 | for i in range(len(matches)): 170 | # Get text 171 | token = matches[i].group() 172 | 173 | # Get whitespace 174 | span = matches[i].span() 175 | start_ws = span[0] 176 | if i + 1 < len(matches): 177 | end_ws = matches[i + 1].span()[0] 178 | else: 179 | end_ws = span[1] 180 | 181 | # Format data 182 | data.append(( 183 | token, 184 | text[start_ws: end_ws], 185 | span, 186 | )) 187 | return Tokens(data, self.annotators) 188 | 189 | tokenizer = SimpleTokenizer() 190 | text = unicodedata.normalize('NFD', text) 191 | if match_type == 'string': 192 | text = tokenizer.tokenize(text).words(uncased=True) 193 | for single_answer in answers: 194 | single_answer = unicodedata.normalize('NFD', single_answer) 195 | single_answer = tokenizer.tokenize(single_answer) 196 | single_answer = single_answer.words(uncased=True) 197 | for i in range(0, len(text) - len(single_answer) + 1): 198 | if single_answer == text[i: i+ len(single_answer)]: 199 | return 1 200 | return 0 201 | 202 | 203 | def _normalize_answer(s): 204 | def remove_articles(text): 205 | return re.sub(r"\b(a|an|the)\b", " ", text) 206 | 207 | def white_space_fix(text): 208 | return " ".join(text.split()) 209 | 210 | def remove_punc(text): 211 | exclude = set(string.punctuation) 212 | return "".join(ch for ch in text if ch not in exclude) 213 | 214 | def lower(text): 215 | return text.lower() 216 | 217 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 218 | 219 | def EM_compute(answer_list, prediction): 220 | return max([int(_normalize_answer(prediction) == _normalize_answer(ground_truth)) for ground_truth in answer_list]) 221 | 222 | def F1_compute(answers, pred): 223 | def get_tokens(s): 224 | if not s: return [] 225 | return _normalize_answer(s).split() 226 | 227 | def compute_f1(a_gold, a_pred): 228 | gold_toks = get_tokens(a_gold) 229 | pred_toks = get_tokens(a_pred) 230 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks) 231 | num_same = sum(common.values()) 232 | if len(gold_toks) == 0 or len(pred_toks) == 0: 233 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 234 | return int(gold_toks == pred_toks) 235 | if num_same == 0: 236 | return 0 237 | precision = 1.0 * num_same / len(pred_toks) 238 | recall = 1.0 * num_same / len(gold_toks) 239 | f1 = (2 * precision * recall) / (precision + recall) 240 | return f1 241 | return max([compute_f1(x, pred) for x in answers]) 242 | 243 | 244 | def deal_judge(pred): 245 | if pred is None: 246 | return True 247 | if has_answer(["unknown", "no specific answer", "not provide", "cannot answer", "no information provided", "no answer", "not contain", "no definitive answer"], pred): 248 | return True 249 | return False 250 | 251 | 252 | def deal_answer(pred, answers): 253 | if pred is None: 254 | return 0, 0 255 | if pred.lower().startswith("answer:"): 256 | pred = pred[7:] 257 | return EM_compute(answers, pred), F1_compute(answers, pred) 258 | 259 | 260 | def deal_post(pred): 261 | giveup, istrue = True, None 262 | if pred is None: 263 | return giveup, istrue 264 | if has_answer(["unclear", "not clear", "unknown", "partially correct", "partially incorrect", "not correct", "cannot determine", "cannot answer", "not incorrect", "incomplete"], pred): 265 | giveup = True 266 | elif has_answer(["correct", "true"], pred): 267 | giveup, istrue = False, True 268 | elif has_answer(["incorrect", "false"], pred): 269 | giveup, istrue = False, False 270 | else: 271 | giveup = True 272 | return giveup, istrue 273 | 274 | 275 | def str2paras(s): 276 | if s is None: 277 | return None 278 | paras = [] 279 | for text in s.split('\n'): 280 | if text.strip() != '': 281 | paras.append(": " + text) 282 | return paras 283 | 284 | 285 | if __name__ == "__main__": 286 | file_list = os.listdir('d:/pycharmfiles/chat') 287 | 288 | for file in file_list: 289 | if not file.endswith('post'): 290 | continue 291 | print(file) 292 | indir = os.path.join('d:/pycharmfiles/chat', file) 293 | outdir = os.path.join('d:/pycharmfiles/llm_re/nq/data', file) 294 | outstr = "" 295 | infile = open(indir, 'r', encoding='utf-8') 296 | for line in tqdm(infile.readlines()): 297 | d = json.loads(line) 298 | if 'Prediction' in d.keys(): 299 | d['Giveup'], d['EM'], d['F1'] = deal_answer(d['Prediction'], d['reference']) 300 | if 'Post' in d.keys(): 301 | d['Post_Giveup'], d['Post_True']= deal_post(d['Post']) 302 | outstr += json.dumps(d) + '\n' 303 | infile.close() 304 | outfile = open(outdir, 'w', encoding='utf-8') 305 | outfile.write(outstr) 306 | outfile.close() 307 | 308 | 309 | def load_source(file): 310 | data = [] 311 | f = open(file, 'r', encoding='utf-8') 312 | for line in f.readlines(): 313 | data.append(json.loads(line)) 314 | f.close() 315 | return data 316 | --------------------------------------------------------------------------------