├── .gitignore ├── LICENSE ├── README.md ├── factscore ├── __init__.py ├── abstain_detection.py ├── atomic_facts.py ├── clm.py ├── download_data.py ├── factscorer.py ├── lm.py ├── npm.py ├── openai_lm.py ├── retrieval.py └── utils.py ├── preprocessing └── preprocess_acl.py ├── pyproject.toml ├── requirements.txt └── roberta_stopwords.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | bin/ 10 | build/ 11 | develop-eggs/ 12 | dist/ 13 | eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | *.egg-info/ 20 | .installed.cfg 21 | *.egg 22 | 23 | # Installer logs 24 | pip-log.txt 25 | pip-delete-this-directory.txt 26 | 27 | # Unit test / coverage reports 28 | .tox/ 29 | .coverage 30 | .cache 31 | nosetests.xml 32 | coverage.xml 33 | 34 | # Translations 35 | *.mo 36 | 37 | # Mr Developer 38 | .mr.developer.cfg 39 | .project 40 | .pydevproject 41 | 42 | # Rope 43 | .ropeproject 44 | 45 | # Django stuff: 46 | *.log 47 | *.pot 48 | 49 | # Sphinx documentation 50 | docs/_build/ 51 | 52 | demos 53 | original_generation 54 | enwiki-20230401.db 55 | api.key 56 | editing-data 57 | editing-demos 58 | data 59 | 60 | llama-7B 61 | inst-llama-7B 62 | inst-llama-7B.zip 63 | 64 | .cache 65 | fs-venv 66 | 67 | poetry.lock 68 | 69 | acl-publication-info.74k.parquet 70 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Sewon Min 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FActScore 2 | 3 | [![made-with-python](https://img.shields.io/badge/Made%20with-Python-red.svg)](#python) 4 | [![arxiv](https://img.shields.io/badge/arXiv-2305.14251-b31b1b.svg)](https://arxiv.org/abs/2305.14251) 5 | [![PyPI version factscore](https://badge.fury.io/py/factscore.svg)](https://pypi.python.org/pypi/factscore/) 6 | [![Downloads](https://pepy.tech/badge/factscore)](https://pepy.tech/project/factscore) 7 | 8 | This is the official release accompanying our EMNLP 2023 paper, [FActScore: Fine-grained Atomic Evaluation of Factual Precision in Long Form Text Generation](https://arxiv.org/abs/2305.14251). FActScore is available as a PIP package as well. 9 | 10 | If you find FActScore useful, please cite: 11 | ``` 12 | @inproceedings{ factscore, 13 | title={ {FActScore}: Fine-grained Atomic Evaluation of Factual Precision in Long Form Text Generation }, 14 | author={ Min, Sewon and Krishna, Kalpesh and Lyu, Xinxi and Lewis, Mike and Yih, Wen-tau and Koh, Pang Wei and Iyyer, Mohit and Zettlemoyer, Luke and Hajishirzi, Hannaneh }, 15 | year={ 2023 }, 16 | booktitle = { EMNLP }, 17 | url={ https://arxiv.org/abs/2305.14251 } 18 | } 19 | ``` 20 | 21 | ## Announcement 22 | * **11/04/2023**: The data we release includes human annotations of factual precision reported in Section 3 of [the paper](https://arxiv.org/abs/2305.14251). If you want to download these human annotated data *only*, without other data, you can download it directly from [this Google Drive link](https://drive.google.com/drive/folders/1kFey69z8hGXScln01mVxrOhrqgM62X7I?usp=sharing). We are also releasing FActScore results of 12 different LMs reported in Section 4.3 of the paper, in case you want to obtain them without running the code. Please refer to [here](#factscore-results-of-the-unlabeled-data). 23 | 24 | ## Install 25 | 30 | 31 | Make a new Python 3.7+ environment using `virtualenv` or `conda`. 32 | 33 | ```bash 34 | pip install --upgrade factscore 35 | python -m spacy download en_core_web_sm 36 | ``` 37 | 38 | ## Download the data 39 | 40 | ```bash 41 | python -m factscore.download_data --llama_7B_HF_path "llama-7B" 42 | ``` 43 | 44 | This command does the following. 45 | 1. Download the knowledge source and example data. 46 | 2. Take the LLAMA 7B model and reconstruct Inst-LLAMA. This requires having access to HuggingFace weights of the LLAMA-7B model, which are added to the `--llama_7B_HF_path` flag. Follow [this guide](https://huggingface.co/docs/transformers/main/model_doc/llama) in order to obtain those weights. Skip the `--llama_7B_HF_path` if you would only like to use the ChatGPT version of FActScore. 47 | 48 | **Optional flags**: 49 | - `--data_dir`: directory to store the knowledge source and example data. `.cache/factscore` by default. 50 | - `--model_dir`: directory to store Inst-LLAMA weights. `.cache/factscore` by default. 51 | 52 | **Troubleshooting**: 53 | - If you get a `ERROR 429: Too Many Requests` error while downloading the DB file, please download the DB from [this Google Drive link](https://drive.google.com/drive/folders/1kFey69z8hGXScln01mVxrOhrqgM62X7I?usp=sharing) and place it under `--data_dir` (`.cache/factscore` by default). 
54 | - If everything else fails, consider downloading the files manually from [this link](https://drive.google.com/drive/folders/1kFey69z8hGXScln01mVxrOhrqgM62X7I?usp=sharing) and placing them in `--data_dir` and `--model_dir`, see [`factscore/download_data.py`](factscore/download_data.py) for more details. 55 | 56 | 57 | ## Running FActScore using a command line 58 | 59 | We expect running FActScore costs about $1 of the API cost per 100 sentences. For instance, if you have 100 generations, each with 5 sentences on average, it costs $5 in total. 60 | 61 | ```bash 62 | python -m factscore.factscorer --input_path {input_path} --model_name {estimator_name} --openai_key {openai_key} 63 | ``` 64 | 65 | - `--input_path` can be something like `data/unlabeled/InstructGPT.jsonl`. It should be a `.jsonl` format where each line contains `topic` (a topic entity that corresponds to the Wikipedia title) and `output` (a generation from the model). 66 | - `--model_name`: `retrieval+ChatGPT` and `retrieval+llama+npm` (You can also use `retrieval+ChatGPT+npm` or `retrieval+llama` but we recommend the former two.) 67 | - `--openai_key`: File containing OpenAI API Key. 68 | 69 | **Optional flags**: 70 | - `--data_dir`: Directory containing knowledge source, etc. `.cache/factscore` by default. 71 | - `--model_dir`: Directory containing Inst-LLAMA weights. Skip if your `model_name` doesn't include `llama`. `.cache/factscore` by default. 72 | - `--cache_dir`: Directory containing cache from API/models. `.cache/factscore` by default. 73 | - `--use_atomic_facts`: If specified, it uses model-generated atomic facts released as part of our data instead of running the atomic fact generator. This will allow reproducing our results with no (or little if it still uses ChatGPT) cost. You can't specify it if you are running new model generations. 74 | - `--gamma`: A hyperparameter for length penalty. `10` by default. It penalizes the score if the number of facts is less than `gamma`. `10` roughly corresponds to 2 sentences, so would penalize if the generation has less than 2 sentences. Usually, this would not change the ranking between systems unless some systems generate overly short responses all the time (e.g., models trained on NLP datasets without long-form generation tasks may do so). If you would like to turn off the length penalty completely, specify `--gamma 0`. 75 | - `--n_samples`: If specified, it runs the model on a subset of the data. 76 | - `--verbose`: If specified, it shows the progress bar. 77 | - `--print_rate_limit_error`: It specified, it prints out rate limit errors from OpenAI API. 78 | - `--cost_estimate`: This flag decides the type of OpenAI API cost estimation that we provide before calling it. It can be `"consider_cache"` (default) or `"ignore_cache"`. 79 | - `--abstain_detection`: This flag optionally enables automatic detection of abstained responses. By default this is disabled, but it is recommended to add your own function tailored to your model. The currently supported detectors are `"generic"` and `"perplexity_ai"`, and their implementations can be found in [`factscore/abstain_detection.py`](factscore/abstain_detection.py). 
There are two methods to add your own abstain function: a) clone our GitHub repository to install `factscore` locally (`pip install --editable .`), and then add your function to [`factscore/abstain_detection.py`](factscore/abstain_detection.py) directly; b) process your abstain detection outside our package, and use empty strings in the `output` key for the JSONL file used in `--input_path`. 80 | - `--knowledge_source`: In case the default knowledge source (Wikipedia - 2023/04/01) will not be used, preprocess it using the [instructions below](#To-use-a-custom-knowledge-source), and then specify the knowledge_source name under this flag. 81 | 82 | ## To evaluate your own LM 83 | 84 | There're two sets of prompt entities, `data/labeled/prompt_entities.txt` (183 entities) and `data/unlabeled/prompt_entities.txt` (500 entities). Each line contains the name of the person (which is also a corresponding Wikipedia title). You can use the labeled version if you want to be compatible with the data under `data/labeled` (Section 3 and Section 4.2 in the paper), and use the unlabeled version if you want to be compatible with the data under `data/unlabeled` (Section 4.3 in the paper). 85 | 86 | You can prompt your LM with your own prompt (we used `Question: Tell me a bio of .`) and use the following code. 87 | 88 | ```python 89 | from factscore.factscorer import FactScorer 90 | 91 | fs = FactScorer(openai_key="...") 92 | 93 | # topics: list of strings (human entities used to generate bios) 94 | # generations: list of strings (model generations) 95 | out = fs.get_score(topics, generations, gamma=10) 96 | print (out["score"]) # FActScore 97 | print (out["init_score"]) # FActScore w/o length penalty 98 | print (out["respond_ratio"]) # % of responding (not abstaining from answering) 99 | print (out["num_facts_per_response"]) # average number of atomic facts per response 100 | ``` 101 | 102 | Alternatively, you can create a .jsonl file, where each line has `topic` (entity name, exactly same as the one from `.txt` file) and `output` (generation from LM), and then use a command line [above](#Running-FActScore-using-a-command-line). 103 | 104 | We recommend using (A) `FactScorer(model_name="retrieval+ChatGPT")` (default) or (B) `FactScorer(model_name="retrieval+llama+npm")`. They have 0.99 Pearson correlation. Here're results of a range of models, which you can easily reproduce through [these command lines](#Running-FActScore-using-a-command-line). 
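For example, the InstructGPT row of the table below (column (A)) can be reproduced with a command along these lines, assuming the example data has been downloaded to the default `--data_dir` and your OpenAI key is stored in `api.key`:

```bash
python -m factscore.factscorer --input_path data/unlabeled/InstructGPT.jsonl --model_name "retrieval+ChatGPT" --openai_key api.key
```

Swapping `--model_name` to `retrieval+llama+npm` gives the column (B) numbers instead.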
105 | 106 | | Model | % respond | # facts | FActScore from (A) | FActScore from (B) | 107 | |---|---|---|---|---| 108 | | [GPT-4](https://arxiv.org/abs/2303.08774) | 88.2 | 60.8 | 73.1 | 59.9 | 109 | | [ChatGPT](https://openai.com/blog/chatgpt) | 84.2 | 37.0 | 71.6 | 60.4 | 110 | | [Alpaca 65B](https://crfm.stanford.edu/2023/03/13/alpaca.html) | 100.0 | 17.1 | 55.6 | 46.3 | 111 | | [InstructGPT](https://openai.com/research/instruction-following) | 99.8 | 27.7 | 52.8 | 41.7 | 112 | | [Alpaca 13B](https://crfm.stanford.edu/2023/03/13/alpaca.html) | 100.0 | 16.6 | 47.7 | 40.3 | 113 | | [Vicuna 13B](https://lmsys.org/blog/2023-03-30-vicuna/) | 76.6 | 50.9 | 46.6 | 40.7 | 114 | | [Alpaca 7B](https://crfm.stanford.edu/2023/03/13/alpaca.html) | 100.0 | 17.4 | 39.7 | 36.5 | 115 | | [Vicuna 7B](https://lmsys.org/blog/2023-03-30-vicuna/) | 91.0 | 45.6 | 38.9 | 36.9 | 116 | | [MPT Chat 7B](https://www.mosaicml.com/blog/mpt-7b) | 88.8 | 37.3 | 30.1 | 27.9 | 117 | | [Oasst Pythia 12B](https://huggingface.co/OpenAssistant/oasst-sft-1-pythia-12b) | 100.0 | 39.7 | 25.1 | 20.8 | 118 | | [Dolly 12B](https://huggingface.co/databricks/dolly-v2-12b) | 100.0 | 24.6 | 21.7 | 17.1 | 119 | | [StableLM tuned 7B](https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b) | 66.6 | 38.0 | 17.3 | 16.3 | 120 | 121 | `% respond` (% of responding instead of abstaining from answering) and `# facts` (# of atomic facts per valid response) indicate "factual recall" (how many pieces of information the model gives) and FActScore indicates "factual precision" (how accurate each piece of information the model gives is). 122 | 123 | ## To use a custom knowledge source 124 | 125 | By default, FActScore uses Wikipedia dump from 2023/04/01. But you can also use your own knowledge source! 126 | 127 | The knolwedge source should be ready in a `.jsonl` format, where each line is a dictionary containing `title` and `text`. `text` can either be a string or a list of strings (e.g., sections). 128 | 129 | ```python 130 | from factscore.factscorer import FactScorer 131 | 132 | fs = FactScorer() 133 | 134 | # this will create a database using your file 135 | # for English Wikipedia (18GB)), it takes ~8 hours 136 | # once DB file is created, you can reuse it by only specifying `db_path` 137 | fs.register_knowledge_source(name_of_your_knowledge_source, 138 | data_path=path_to_jsonl_file, 139 | db_path=path_to_output_db_file) 140 | 141 | # now, when you compute a score, specify knowledge source to use 142 | out = fs.get_score(topics, generations, knowledge_source=name_of_your_knowledge_source) 143 | print (out["score"]) # FActScore 144 | print (out["respond_ratio"]) # % of responding (not abstaining from answering) 145 | print (out["num_facts_per_response"]) # average number of atomic facts per response 146 | ``` 147 | 148 | To see an example of constructing the ACL anthology knowledge source, see [`preprocessing/preprocess_acl.py`](preprocessing/preprocess_acl.py). 149 | 150 | ## FActScore results of the unlabeled data 151 | 152 | You can easily reproduce FActScore results of 12 different LMs reported in Section 4.3 of [the paper](https://arxiv.org/abs/2305.14251) using this code. However, if you would like to obtain their predictions without running the code, you can download it from [this Google Drive link](https://drive.google.com/drive/folders/1kFey69z8hGXScln01mVxrOhrqgM62X7I?usp=sharing). 153 | 154 | Each file corresponds to the subject LM (LM that generates responses that we are validating). 
Each line is a dictionary: 155 | - `prompt`: the initial prompt fed into the LM 156 | - `facts`: atomic facts decomposed by the model 157 | - `LLAMA+NP_labels`: labels to facts, verified by LLAMA+NP 158 | - `ChatGPT_labels`: labels to facts, verified by ChatGPT 159 | 160 | Note that the number of lines may be less than 500, because it excludes the cases where the model abstains from responding (e.g., it says "I don't know"). You can do `# of lines / 500` to calculate the response ratio. 161 | 162 | If you unzip the data and run the following code for verification, you will be able to get statistics that exactly match the statistics reported in the paper (Table 5 and Figure 3). 163 | ```python 164 | dirname = "factscore-unlabeled-predictions" 165 | for fn in os.listdir(dirname): 166 | chatgpt_fs = [] 167 | llama_fs = [] 168 | n_facts = [] 169 | with open(os.path.join(dirname, fn)) as f: 170 | for line in f: 171 | dp = json.loads(line) 172 | n_facts.append(len(dp["facts"])) 173 | if "ChatGPT_Labels" in dp: 174 | chatgpt_fs.append(np.mean([l=="S" for l in dp["ChatGPT_Labels"]])) 175 | llama_fs.append(np.mean([l=="S" for l in dp["LLAMA+NP_Labels"]])) 176 | print ("Model=%s\t(%.1f%% responding, %.1f facts/response)\tFactScore=%.1f (ChatGPT)\t%.1f (LLAMA)" % ( 177 | fn.split(".")[0], len(n_facts)*100/500, np.mean(n_facts), np.mean(chatgpt_fs)*100, np.mean(llama_fs)*100 178 | )) 179 | ``` 180 | 181 | -------------------------------------------------------------------------------- /factscore/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shmsw25/FActScore/f28272deffcf33efc1f1117d5479c10bb75221a9/factscore/__init__.py -------------------------------------------------------------------------------- /factscore/abstain_detection.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import re 3 | 4 | invalid_ppl_mentions = [ 5 | "I could not find any information", 6 | "The search results do not provide", 7 | "There is no information", 8 | "There are no search results", 9 | "there are no provided search results", 10 | "not provided in the search results", 11 | "is not mentioned in the provided search results", 12 | "There seems to be a mistake in the question", 13 | "Not sources found", 14 | "No sources found", 15 | "Try a more general question" 16 | ] 17 | 18 | def remove_citation(text): 19 | # text = re.sub(r'\[\d+\]', '', text) 20 | text = re.sub(r"\s*\[\d+\]\s*","", text) 21 | if text.startswith("According to , "): 22 | text = text.replace("According to , ", "According to the search results, ") 23 | return text 24 | 25 | def is_invalid_ppl(text): 26 | return np.any([text.lower().startswith(mention.lower()) for mention in invalid_ppl_mentions]) 27 | 28 | def is_invalid_paragraph_ppl(text): 29 | return len(text.strip())==0 or np.any([mention.lower() in text.lower() for mention in invalid_ppl_mentions]) 30 | 31 | def perplexity_ai_abstain_detect(generation): 32 | output = remove_citation(generation) 33 | if is_invalid_ppl(output): 34 | return True 35 | valid_paras = [] 36 | for para in output.split("\n\n"): 37 | if is_invalid_paragraph_ppl(para): 38 | break 39 | valid_paras.append(para.strip()) 40 | 41 | if len(valid_paras) == 0: 42 | return True 43 | else: 44 | return False 45 | 46 | def generic_abstain_detect(generation): 47 | return generation.startswith("I'm sorry") or "provide more" in generation 48 | 49 | def is_response_abstained(generation, fn_type): 50 | if 
fn_type == "perplexity_ai": 51 | return perplexity_ai_abstain_detect(generation) 52 | 53 | elif fn_type == "generic": 54 | return generic_abstain_detect(generation) 55 | 56 | else: 57 | return False 58 | 59 | -------------------------------------------------------------------------------- /factscore/atomic_facts.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import re 4 | import functools 5 | import string 6 | import spacy 7 | import sys 8 | import nltk 9 | import openai 10 | from rank_bm25 import BM25Okapi 11 | import os 12 | import time 13 | from nltk.tokenize import sent_tokenize 14 | 15 | from factscore.openai_lm import OpenAIModel 16 | 17 | nltk.download("punkt") 18 | 19 | 20 | class AtomicFactGenerator(object): 21 | def __init__(self, key_path, demon_dir, gpt3_cache_file=None): 22 | self.nlp = spacy.load("en_core_web_sm") 23 | self.is_bio = True 24 | self.demon_path = os.path.join(demon_dir, "demons.json" if self.is_bio else "demons_complex.json") 25 | 26 | self.openai_lm = OpenAIModel("InstructGPT", cache_file=gpt3_cache_file, key_path=key_path) 27 | 28 | # get the demos 29 | with open(self.demon_path, 'r') as f: 30 | self.demons = json.load(f) 31 | 32 | tokenized_corpus = [doc.split(" ") for doc in self.demons.keys()] 33 | self.bm25 = BM25Okapi(tokenized_corpus) 34 | 35 | def save_cache(self): 36 | self.openai_lm.save_cache() 37 | 38 | def run(self, generation, cost_estimate=None): 39 | """Convert the generation into a set of atomic facts. Return a total words cost if cost_estimate != None.""" 40 | assert isinstance(generation, str), "generation must be a string" 41 | paragraphs = [para.strip() for para in generation.split("\n") if len(para.strip()) > 0] 42 | return self.get_atomic_facts_from_paragraph(paragraphs, cost_estimate=cost_estimate) 43 | 44 | def get_atomic_facts_from_paragraph(self, paragraphs, cost_estimate=None): 45 | sentences = [] 46 | para_breaks = [] 47 | for para_idx, paragraph in enumerate(paragraphs): 48 | if para_idx > 0 : 49 | para_breaks.append(len(sentences)) 50 | 51 | initials = detect_initials(paragraph) 52 | 53 | curr_sentences = sent_tokenize(paragraph) 54 | curr_sentences_2 = sent_tokenize(paragraph) 55 | 56 | curr_sentences = fix_sentence_splitter(curr_sentences, initials) 57 | curr_sentences_2 = fix_sentence_splitter(curr_sentences_2, initials) 58 | 59 | # checking this, just to ensure the crediability of the sentence splitter fixing algorithm 60 | assert curr_sentences == curr_sentences_2, (paragraph, curr_sentences, curr_sentences_2) 61 | 62 | sentences += curr_sentences 63 | 64 | atoms_or_estimate = self.get_init_atomic_facts_from_sentence([sent for i, sent in enumerate(sentences) if not (not self.is_bio and ( \ 65 | (i==0 and (sent.startswith("Sure") or sent.startswith("Here are"))) or \ 66 | (i==len(sentences)-1 and (sent.startswith("Please") or sent.startswith("I hope") or sent.startswith("Here are")))))], cost_estimate=cost_estimate) 67 | 68 | if cost_estimate: 69 | return atoms_or_estimate 70 | else: 71 | atoms = atoms_or_estimate 72 | 73 | atomic_facts_pairs = [] 74 | for i, sent in enumerate(sentences): 75 | if not self.is_bio and ( \ 76 | (i==0 and (sent.startswith("Sure") or sent.startswith("Here are"))) or \ 77 | (i==len(sentences)-1 and (sent.startswith("Please") or sent.startswith("I hope") or sent.startswith("Here are")))): 78 | atomic_facts_pairs.append((sent, [])) 79 | elif self.is_bio and sent.startswith("This sentence does not contain any facts"): 80 | 
atomic_facts_pairs.append((sent, [])) 81 | elif sent.startswith("Sure") or sent.startswith("Please") or (i==0 and sent.startswith("Here are")): 82 | atomic_facts_pairs.append((sent, [])) 83 | else: 84 | atomic_facts_pairs.append((sent, atoms[sent])) 85 | 86 | # postprocess_atomic_facts will fix minor issues from InstructGPT 87 | # it is supposed to handle sentence splitter issue too, but since here 88 | # we fixed sentence splitter issue already, 89 | # the new para_breaks should be identical to the original para_breaks 90 | if self.is_bio: 91 | atomic_facts_pairs, para_breaks = postprocess_atomic_facts(atomic_facts_pairs, list(para_breaks), self.nlp) 92 | 93 | return atomic_facts_pairs, para_breaks 94 | 95 | 96 | def get_init_atomic_facts_from_sentence(self, sentences, cost_estimate=None): 97 | """Get the initial atomic facts from the sentences. Return a total words cost if cost_estimate != None.""" 98 | 99 | is_bio = self.is_bio 100 | demons = self.demons 101 | 102 | k = 1 if is_bio else 0 103 | n = 7 if is_bio else 8 104 | 105 | prompts = [] 106 | prompt_to_sent = {} 107 | atoms = {} 108 | for sentence in sentences: 109 | if sentence in atoms: 110 | continue 111 | top_machings = best_demos(sentence, self.bm25, list(demons.keys()), k) 112 | prompt = "" 113 | 114 | for i in range(n): 115 | prompt = prompt + "Please breakdown the following sentence into independent facts: {}\n".format(list(demons.keys())[i]) 116 | for fact in demons[list(demons.keys())[i]]: 117 | prompt = prompt + "- {}\n".format(fact) 118 | prompt = prompt + "\n" 119 | 120 | for match in top_machings: 121 | prompt = prompt + "Please breakdown the following sentence into independent facts: {}\n".format(match) 122 | for fact in demons[match]: 123 | prompt = prompt + "- {}\n".format(fact) 124 | prompt = prompt + "\n" 125 | prompt = prompt + "Please breakdown the following sentence into independent facts: {}\n".format(sentence) 126 | prompts.append(prompt) 127 | prompt_to_sent[prompt] = sentence 128 | 129 | if cost_estimate: 130 | total_words_estimate = 0 131 | for prompt in prompts: 132 | if cost_estimate == "consider_cache" and (prompt.strip() + "_0") in self.openai_lm.cache_dict: 133 | continue 134 | total_words_estimate += len(prompt.split()) 135 | return total_words_estimate 136 | else: 137 | for prompt in prompts: 138 | output, _ = self.openai_lm.generate(prompt) 139 | atoms[prompt_to_sent[prompt]] = text_to_sentences(output) 140 | 141 | for key, value in demons.items(): 142 | if key not in atoms: 143 | atoms[key] = value 144 | 145 | return atoms 146 | 147 | 148 | def best_demos(query, bm25, demons_sents, k): 149 | tokenized_query = query.split(" ") 150 | top_machings = bm25.get_top_n(tokenized_query, demons_sents, k) 151 | return top_machings 152 | 153 | 154 | # transform InstructGPT output into sentences 155 | def text_to_sentences(text): 156 | sentences = text.split("- ")[1:] 157 | sentences = [sent.strip()[:-1] if sent.strip()[-1] == '\n' else sent.strip() for sent in sentences] 158 | if len(sentences) > 0: 159 | if sentences[-1][-1] != '.': 160 | sentences[-1] = sentences[-1] + '.' 
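        # reaching the else branch means the model output contained no "- " bullet lines, i.e. no atomic facts were parsed from it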
161 | else: 162 | sentences = [] 163 | return sentences 164 | 165 | 166 | def normalize_answer(s): 167 | """Lower text and remove punctuation, articles and extra whitespace.""" 168 | def remove_articles(text): 169 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) 170 | return re.sub(regex, ' ', text) 171 | def white_space_fix(text): 172 | return ' '.join(text.split()) 173 | def remove_punc(text): 174 | exclude = set(string.punctuation) 175 | return ''.join(ch for ch in text if ch not in exclude) 176 | def lower(text): 177 | return text.lower() 178 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 179 | 180 | MONTHS = ["January", "February", "March", "April", "May", "June", "July", "August", "September", "October", "November", "December"] 181 | MONTHS = [m.lower() for m in MONTHS] 182 | 183 | def is_num(text): 184 | try: 185 | text = int(text) 186 | return True 187 | except Exception: 188 | return False 189 | 190 | def is_date(text): 191 | text = normalize_answer(text) 192 | for token in text.split(" "): 193 | if (not is_num(token)) and token not in MONTHS: 194 | return False 195 | return True 196 | 197 | def extract_numeric_values(text): 198 | pattern = r'\b\d+\b' # regular expression pattern for integers 199 | numeric_values = re.findall(pattern, text) # find all numeric values in the text 200 | return set([value for value in numeric_values]) # convert the values to float and return as a list 201 | 202 | 203 | def detect_entities(text, nlp): 204 | doc = nlp(text) 205 | entities = set() 206 | 207 | def _add_to_entities(text): 208 | if "-" in text: 209 | for _text in text.split("-"): 210 | entities.add(_text.strip()) 211 | else: 212 | entities.add(text) 213 | 214 | 215 | for ent in doc.ents: 216 | # spacy often has errors with other types of entities 217 | if ent.label_ in ["DATE", "TIME", "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"]: 218 | 219 | if is_date(ent.text): 220 | _add_to_entities(ent.text) 221 | else: 222 | for token in ent.text.split(): 223 | if is_date(token): 224 | _add_to_entities(token) 225 | 226 | for new_ent in extract_numeric_values(text): 227 | if not np.any([new_ent in ent for ent in entities]): 228 | entities.add(new_ent) 229 | 230 | return entities 231 | 232 | def postprocess_atomic_facts(_atomic_facts, para_breaks, nlp): 233 | 234 | verbs = ["born.", " appointed.", " characterized.", " described.", " known.", " member.", " advocate.", "served.", "elected."] 235 | permitted_verbs = ["founding member."] 236 | 237 | atomic_facts = [] 238 | new_atomic_facts = [] 239 | new_para_breaks = [] 240 | 241 | for i, (sent, facts) in enumerate(_atomic_facts): 242 | sent = sent.strip() 243 | if len(sent.split())==1 and i not in para_breaks and i > 0: 244 | assert i not in para_breaks 245 | atomic_facts[-1][0] += " " + sent 246 | atomic_facts[-1][1] += facts 247 | else: 248 | if i in para_breaks: 249 | new_para_breaks.append(len(atomic_facts)) 250 | atomic_facts.append([sent, facts]) 251 | 252 | for i, (sent, facts) in enumerate(atomic_facts): 253 | entities = detect_entities(sent, nlp) 254 | covered_entities = set() 255 | # print (entities) 256 | new_facts = [] 257 | for i, fact in enumerate(facts): 258 | if any([fact.endswith(verb) for verb in verbs]) and not any([fact.endswith(verb) for verb in permitted_verbs]): 259 | if any([fact[:-1] in other_fact for j, other_fact in enumerate(facts) if j != i]): 260 | continue 261 | sent_entities = detect_entities(fact, nlp) 262 | covered_entities |= set([e for e in sent_entities if e in entities]) 263 | 
new_entities = sent_entities - entities 264 | if len(new_entities) > 0: 265 | do_pass = False 266 | for new_ent in new_entities: 267 | pre_ent = None 268 | for ent in entities: 269 | if ent.startswith(new_ent): 270 | pre_ent = ent 271 | break 272 | if pre_ent is None: 273 | do_pass = True 274 | break 275 | fact = fact.replace(new_ent, pre_ent) 276 | covered_entities.add(pre_ent) 277 | if do_pass: 278 | continue 279 | if fact in new_facts: 280 | continue 281 | new_facts.append(fact) 282 | try: 283 | assert entities==covered_entities 284 | except Exception: 285 | new_facts = facts # there is a bug in spacy entity linker, so just go with the previous facts 286 | 287 | new_atomic_facts.append((sent, new_facts)) 288 | 289 | return new_atomic_facts, new_para_breaks 290 | 291 | def is_integer(s): 292 | try: 293 | s = int(s) 294 | return True 295 | except Exception: 296 | return False 297 | 298 | def detect_initials(text): 299 | pattern = r"[A-Z]\. ?[A-Z]\." 300 | match = re.findall(pattern, text) 301 | return [m for m in match] 302 | 303 | def fix_sentence_splitter(curr_sentences, initials): 304 | for initial in initials: 305 | if not np.any([initial in sent for sent in curr_sentences]): 306 | alpha1, alpha2 = [t.strip() for t in initial.split(".") if len(t.strip())>0] 307 | for i, (sent1, sent2) in enumerate(zip(curr_sentences, curr_sentences[1:])): 308 | if sent1.endswith(alpha1 + ".") and sent2.startswith(alpha2 + "."): 309 | # merge sentence i and i+1 310 | curr_sentences = curr_sentences[:i] + [curr_sentences[i] + " " + curr_sentences[i+1]] + curr_sentences[i+2:] 311 | break 312 | sentences = [] 313 | combine_with_previous = None 314 | for sent_idx, sent in enumerate(curr_sentences): 315 | if len(sent.split())<=1 and sent_idx==0: 316 | assert not combine_with_previous 317 | combine_with_previous = True 318 | sentences.append(sent) 319 | elif len(sent.split())<=1: 320 | assert sent_idx > 0 321 | sentences[-1] += " " + sent 322 | combined_with_previous = False 323 | elif sent[0].isalpha() and not sent[0].isupper() and sent_idx > 0: 324 | assert sent_idx > 0, curr_sentences 325 | sentences[-1] += " " + sent 326 | combine_with_previous = False 327 | elif combine_with_previous: 328 | assert sent_idx > 0 329 | sentences[-1] += " " + sent 330 | combine_with_previous = False 331 | else: 332 | assert not combine_with_previous 333 | sentences.append(sent) 334 | return sentences 335 | 336 | 337 | def main(): 338 | generator = AtomicFactGenerator("api.key", "demos", gpt3_cache_dir=None) 339 | atomic_facts, para_breaks = generator.run("Thierry Henry (born 17 August 1977) is a French professional football coach, pundit, and former player. He is considered one of the greatest strikers of all time, and one the greatest players of the Premier League history. He has been named Arsenal F.C's greatest ever player.\n\nHenry made his professional debut with Monaco in 1994 before signing for defending Serie A champions Juventus. However, limited playing time, coupled with disagreements with the club's hierarchy, led to him signing for Premier League club Arsenal for £11 million in 1999.") 340 | 341 | print(atomic_facts) 342 | print(para_breaks) 343 | 344 | if __name__ == "__main__": 345 | main() -------------------------------------------------------------------------------- /factscore/clm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import math 9 | import time 10 | import json 11 | import numpy as np 12 | import torch 13 | from tqdm import tqdm 14 | from collections import defaultdict 15 | 16 | from transformers import AutoModelForCausalLM 17 | from transformers import LlamaTokenizer 18 | 19 | from factscore.utils import convert_model_to_int8_on_gpu 20 | from factscore.lm import LM 21 | 22 | class CLM(LM): 23 | def __init__(self, model_name, model_dir, cache_file=None): 24 | self.model_name = model_name 25 | self.model_dir = model_dir 26 | if cache_file: 27 | super().__init__(cache_file) 28 | 29 | def load_model(self): 30 | self.model = AutoModelForCausalLM.from_pretrained(self.model_dir) 31 | self.model = convert_model_to_int8_on_gpu(self.model, device='cuda') 32 | self.tokenizer = LlamaTokenizer.from_pretrained(self.model_dir) 33 | 34 | def _generate(self, prompts, max_sequence_length=2048, max_output_length=128, 35 | end_if_newline=False, end_if_second_newline=False, verbose=False): 36 | is_single = type(prompts)==str 37 | if is_single: 38 | prompts = [prompts] 39 | 40 | input_ids = self.tokenizer(prompts).input_ids 41 | if verbose: 42 | input_ids = tqdm(input_ids) 43 | 44 | generations = [] 45 | scores = [] 46 | for curr_input_ids in input_ids: 47 | if len(curr_input_ids) > max_sequence_length - max_output_length: 48 | curr_input_ids = curr_input_ids[-(max_sequence_length - max_output_length):] 49 | curr_input_ids = torch.LongTensor([curr_input_ids]).cuda() 50 | gen_outputs = self.model.generate( 51 | curr_input_ids, 52 | max_length=curr_input_ids.shape[1]+max_output_length, 53 | return_dict_in_generate=True, 54 | output_scores=True 55 | ) 56 | gen_tokens = gen_outputs["sequences"] 57 | # saving the logits for the very first token 58 | gen_scores = gen_outputs["scores"][0][0].detach().cpu().numpy() 59 | gen = self.tokenizer.decode(gen_tokens[0, curr_input_ids.shape[-1]:]) 60 | 61 | if end_if_newline: 62 | gen = gen.split("\n")[0].strip() 63 | elif end_if_second_newline: 64 | gen = "\n".join(gen.split("\n")[:2]).strip() 65 | 66 | if verbose and len(generations)==0: 67 | print ("Input:", prompts[0]) 68 | print ("Prediction:", gen) 69 | 70 | if self.model_name.startswith("llama-sni"): 71 | gen = gen.split("")[0] 72 | 73 | generations.append(gen) 74 | scores.append(gen_scores) 75 | 76 | assert len(generations)==len(prompts)==len(scores) 77 | if is_single: 78 | return generations[0], scores[0] 79 | 80 | return generations, scores 81 | 82 | -------------------------------------------------------------------------------- /factscore/download_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import subprocess 4 | import torch 5 | import tqdm 6 | import transformers 7 | 8 | 9 | def download_file(_id, dest, cache_dir): 10 | if os.path.exists(dest) or os.path.exists(os.path.join(cache_dir, dest)): 11 | print ("[Already exists] Skipping", dest) 12 | print ("If you want to download the file in another location, please specify a different path") 13 | return 14 | 15 | if os.path.exists(dest.replace(".zip", "")) or os.path.exists(os.path.join(cache_dir, dest.replace(".zip", ""))): 16 | print ("[Already exists] Skipping", dest) 17 | print ("If you want to download the file in another location, please specify a different path") 18 | return 19 | 20 | if "/" in dest: 21 | dest_dir = 
"/".join(dest.split("/")[:-1]) 22 | if not os.path.isdir(dest_dir): 23 | os.makedirs(dest_dir) 24 | else: 25 | dest_dir = "." 26 | 27 | if _id.startswith("https://"): 28 | command = """wget -O %s %s""" % (dest, _id) 29 | else: 30 | command = """wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=%s' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p')&id=%s" -O %s && rm -rf /tmp/cookies.txt""" % (_id, _id, dest) 31 | 32 | ret_code = subprocess.run([command], shell=True) 33 | if ret_code.returncode != 0: 34 | print("Download {} ... [Failed]".format(dest)) 35 | else: 36 | print("Download {} ... [Success]".format(dest)) 37 | 38 | if dest.endswith(".zip"): 39 | command = """unzip %s -d %s && rm %s""" % (dest, dest_dir, dest) 40 | 41 | ret_code = subprocess.run([command], shell=True) 42 | if ret_code.returncode != 0: 43 | print("Unzip {} ... [Failed]".format(dest)) 44 | else: 45 | print("Unzip {} ... [Success]".format(dest)) 46 | 47 | 48 | 49 | def smart_tokenizer_and_embedding_resize(special_tokens_dict, tokenizer, model): 50 | """Resize tokenizer and embedding. 51 | 52 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 53 | """ 54 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) 55 | model.resize_token_embeddings(len(tokenizer)) 56 | 57 | if num_new_tokens > 0: 58 | input_embeddings = model.get_input_embeddings().weight.data 59 | output_embeddings = model.get_output_embeddings().weight.data 60 | 61 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 62 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 63 | 64 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 65 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 66 | 67 | 68 | def recover_instruct_llama(path_raw, output_path, device="cpu", test_recovered_model=False): 69 | """Heavily adapted from https://github.com/tatsu-lab/stanford_alpaca/blob/main/weight_diff.py.""" 70 | 71 | model_raw = transformers.AutoModelForCausalLM.from_pretrained( 72 | path_raw, 73 | device_map={"": torch.device(device)}, 74 | torch_dtype=torch.float32, 75 | low_cpu_mem_usage=True, 76 | ) 77 | model_recovered = transformers.AutoModelForCausalLM.from_pretrained( 78 | "kalpeshk2011/instruct-llama-7b-wdiff", 79 | device_map={"": torch.device(device)}, 80 | torch_dtype=torch.float32, 81 | low_cpu_mem_usage=True, 82 | ) 83 | 84 | tokenizer_raw = transformers.AutoTokenizer.from_pretrained(path_raw) 85 | if tokenizer_raw.pad_token is None: 86 | smart_tokenizer_and_embedding_resize( 87 | special_tokens_dict=dict(pad_token="[PAD]"), 88 | model=model_raw, 89 | tokenizer=tokenizer_raw, 90 | ) 91 | tokenizer_recovered = transformers.AutoTokenizer.from_pretrained("kalpeshk2011/instruct-llama-7b-wdiff") 92 | 93 | state_dict_recovered = model_recovered.state_dict() 94 | state_dict_raw = model_raw.state_dict() 95 | for key in tqdm.tqdm(state_dict_recovered): 96 | state_dict_recovered[key].add_(state_dict_raw[key]) 97 | 98 | if output_path is not None: 99 | model_recovered.save_pretrained(output_path) 100 | tokenizer_recovered.save_pretrained(output_path) 101 | 102 | if test_recovered_model: 103 | input_text = ( 104 | "Below is an instruction that describes a task. 
" 105 | "Write a response that appropriately completes the request.\r\n\r\n" 106 | "### Instruction:\r\nList three technologies that make life easier.\r\n\r\n### Response:" 107 | ) 108 | inputs = tokenizer_recovered(input_text, return_tensors="pt") 109 | out = model_recovered.generate(inputs=inputs.input_ids, max_new_tokens=100) 110 | output_text = tokenizer_recovered.batch_decode(out, skip_special_tokens=True)[0] 111 | output_text = output_text[len(input_text) :] 112 | print(f"Input: {input_text}\nCompletion: {output_text}") 113 | 114 | return model_recovered, tokenizer_recovered 115 | 116 | if __name__ == '__main__': 117 | 118 | parser = argparse.ArgumentParser() 119 | parser.add_argument('--data_dir', 120 | type=str, 121 | default=".cache/factscore") 122 | parser.add_argument('--model_dir', 123 | type=str, 124 | default=".cache/factscore") 125 | parser.add_argument('--llama_7B_HF_path', 126 | type=str, 127 | default=None) 128 | 129 | args = parser.parse_args() 130 | 131 | if not os.path.exists(args.model_dir): 132 | os.makedirs(args.model_dir) 133 | 134 | if not os.path.exists(args.data_dir): 135 | os.makedirs(args.data_dir) 136 | 137 | download_file("1sbW6pkYl6cc9gooD4WLaeoFKcAj3poZu", "demos.zip", args.data_dir) 138 | download_file("155exEdKs7R21gZF4G-x54-XN3qswBcPo", "data.zip", args.data_dir) 139 | download_file("1Qu4JHWjpUKhGPaAW5UHhS5RJ545CVy4I", "enwiki-20230401.db", args.data_dir) 140 | 141 | if args.llama_7B_HF_path: 142 | recover_instruct_llama(args.llama_7B_HF_path, os.path.join(args.model_dir, "inst-llama-7B")) 143 | 144 | # download the roberta_stopwords.txt file 145 | subprocess.run(["wget https://raw.githubusercontent.com/shmsw25/FActScore/main/roberta_stopwords.txt"], shell=True) 146 | 147 | # move the files to the data directory 148 | subprocess.run(["mv demos %s" % args.data_dir], shell=True) 149 | subprocess.run(["mv enwiki-20230401.db %s" % args.data_dir], shell=True) 150 | 151 | -------------------------------------------------------------------------------- /factscore/factscorer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import string 3 | import json 4 | import numpy as np 5 | import os 6 | import logging 7 | 8 | from tqdm import tqdm 9 | from factscore.abstain_detection import is_response_abstained 10 | from factscore.atomic_facts import AtomicFactGenerator 11 | from factscore.clm import CLM 12 | from factscore.npm import NPM 13 | from factscore.openai_lm import OpenAIModel 14 | from factscore.retrieval import DocDB, Retrieval 15 | 16 | class FactScorer(object): 17 | 18 | def __init__(self, 19 | model_name="retrieval+ChatGPT", 20 | data_dir=".cache/factscore", 21 | model_dir=".cache/factscore", 22 | cache_dir=".cache/factscore", 23 | openai_key="api.key", 24 | cost_estimate="consider_cache", 25 | abstain_detection_type=None, 26 | batch_size=256): 27 | assert model_name in ["retrieval+llama", "retrieval+llama+npm", "retrieval+ChatGPT", "npm", "retrieval+ChatGPT+npm"] 28 | self.model_name = model_name 29 | 30 | self.db = {} 31 | self.retrieval = {} 32 | self.npm = {} 33 | self.batch_size = batch_size # batch size for retrieval 34 | self.openai_key = openai_key 35 | self.abstain_detection_type = abstain_detection_type 36 | 37 | self.data_dir = data_dir 38 | self.cache_dir = cache_dir 39 | if not os.path.exists(cache_dir): 40 | os.makedirs(cache_dir) 41 | 42 | self.af_generator = None 43 | self.cost_estimate = cost_estimate 44 | 45 | if "llama" in model_name: 46 | self.lm = CLM("inst-llama-7B", 47 | 
model_dir=os.path.join(model_dir, "inst-llama-7B"), 48 | cache_file=os.path.join(cache_dir, "inst-llama-7B.pkl")) 49 | elif "ChatGPT" in model_name: 50 | self.lm = OpenAIModel("ChatGPT", 51 | cache_file=os.path.join(cache_dir, "ChatGPT.pkl"), 52 | key_path=openai_key) 53 | else: 54 | self.lm = None 55 | 56 | def save_cache(self): 57 | if self.lm: 58 | self.lm.save_cache() 59 | if "npm" in self.model_name: 60 | for k, v in self.npm.items(): 61 | v.save_cache() 62 | for k, v in self.retrieval.items(): 63 | v.save_cache() 64 | 65 | def register_knowledge_source(self, name="enwiki-20230401", db_path=None, data_path=None): 66 | assert name not in self.retrieval, f"{name} already registered" 67 | if db_path is None: 68 | db_path = os.path.join(self.data_dir, f"{name}.db") 69 | 70 | if data_path is None: 71 | data_path = os.path.join(self.data_dir, f"{name}.jsonl") 72 | 73 | cache_path = os.path.join(self.cache_dir, f"retrieval-{name}.json") 74 | embed_cache_path = os.path.join(self.cache_dir, f"retrieval-{name}.pkl") 75 | 76 | self.db[name] = DocDB(db_path=db_path, data_path=data_path) 77 | self.retrieval[name] = Retrieval(self.db[name], cache_path, embed_cache_path, batch_size=self.batch_size) 78 | if "npm" in self.model_name: 79 | cache_path = os.path.join(self.cache_dir, f"bm25-{name}.json") 80 | embed_cache_path = os.path.join(self.cache_dir, f"bm25-{name}.pkl") 81 | self.npm[name] = NPM(Retrieval(self.db[name], cache_path, embed_cache_path, "bm25"), 82 | "npm-single", 83 | cache_file=os.path.join(self.cache_dir, f"npm-{name}.pkl")) 84 | 85 | 86 | def print_cost_estimates(self, total_words, task, model): 87 | # https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them 88 | # Number of tokens are roughly 4/3 of the number of words 89 | total_tokens = total_words * 4.0 / 3 90 | 91 | # https://openai.com/pricing 92 | # if we use davinci-003, the cost is $0.02 per 1000 tokens 93 | # if we use gpt-3.5-turbo, the cost is $0.002 per 1000 tokens 94 | if model == "davinci-003": 95 | rate = 0.02 96 | elif model == "gpt-3.5-turbo": 97 | rate = 0.002 98 | 99 | total_cost = total_tokens * rate / 1000 100 | 101 | # print the total words, tokens, and cost along with rate 102 | logging.critical("Estimated OpenAI API cost for %s ($%.3f per 1000 tokens): $%.2f for %d words and %d tokens" % (task, rate, total_cost, total_words, total_tokens)) 103 | 104 | def get_score(self, 105 | topics, 106 | generations, 107 | gamma=10, 108 | atomic_facts=None, 109 | knowledge_source=None, 110 | verbose=False): 111 | if knowledge_source is None: 112 | # use the default knowledge source 113 | knowledge_source = "enwiki-20230401" 114 | 115 | if knowledge_source not in self.retrieval: 116 | self.register_knowledge_source(knowledge_source) 117 | 118 | if type(topics)==type(generations)==str: 119 | topics = [topics] 120 | generations = [generations] 121 | else: 122 | assert type(topics)==type(generations)==list, "`topics` and `generations` should be lists." 
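        # topics[i] must be the Wikipedia page title whose article is retrieved as evidence when scoring generations[i]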
123 | assert len(topics)==len(generations), "`topics` and `generations` should have the same length" 124 | 125 | if atomic_facts is not None: 126 | assert len(topics)==len(atomic_facts), "`topics` and `atomic_facts` should have the same length" 127 | else: 128 | if self.af_generator is None: 129 | self.af_generator = AtomicFactGenerator(key_path=self.openai_key, 130 | demon_dir=os.path.join(self.data_dir, "demos"), 131 | gpt3_cache_file=os.path.join(self.cache_dir, "InstructGPT.pkl")) 132 | 133 | # estimate the total cost of atomic fact generation 134 | total_words = 0 135 | for gen in generations: 136 | total_words += self.af_generator.run(gen, cost_estimate=self.cost_estimate) 137 | 138 | self.print_cost_estimates(total_words, task="atomic fact generation", model="davinci-003") 139 | 140 | if verbose: 141 | topics = tqdm(topics) 142 | 143 | atomic_facts = [] 144 | for topic, gen in zip(topics, generations): 145 | # optionally, first detect if the response is abstained 146 | response_abstained = is_response_abstained(gen, self.abstain_detection_type) 147 | if response_abstained: 148 | atomic_facts.append(None) 149 | continue 150 | # continue only when the response is not abstained 151 | curr_afs, _ = self.af_generator.run(gen) 152 | curr_afs = [fact for _, facts in curr_afs for fact in facts] 153 | if len(curr_afs)==0: 154 | atomic_facts.append(None) 155 | else: 156 | atomic_facts.append(curr_afs) 157 | if len(atomic_facts) % 10 == 0: 158 | self.af_generator.save_cache() 159 | 160 | assert len(atomic_facts)==len(topics) 161 | self.af_generator.save_cache() 162 | 163 | respond_ratio = np.mean([facts is not None for facts in atomic_facts]) 164 | 165 | if "ChatGPT" in self.model_name: 166 | # estimate the total cost of response generation 167 | total_words = 0 168 | for topic, generation, facts in zip(topics, generations, atomic_facts): 169 | if facts is not None: 170 | total_words += self._get_score(topic, generation, facts, knowledge_source, cost_estimate=self.cost_estimate) 171 | 172 | self.print_cost_estimates(total_words, task="factscore evaluation", model="gpt-3.5-turbo") 173 | 174 | if verbose: 175 | topics = tqdm(topics) 176 | 177 | scores = [] 178 | init_scores = [] 179 | decisions = [] 180 | for topic, generation, facts in zip(topics, generations, atomic_facts): 181 | if facts is None: 182 | decisions.append(None) 183 | else: 184 | decision = self._get_score(topic, generation, facts, knowledge_source) 185 | score = np.mean([d["is_supported"] for d in decision]) 186 | 187 | if gamma: 188 | init_scores.append(score) 189 | penalty = 1.0 if len(facts)>gamma else np.exp(1-gamma/len(facts)) 190 | score = penalty * score 191 | 192 | decisions.append(decision) 193 | scores.append(score) 194 | if len(scores) % 10 == 0: 195 | self.save_cache() 196 | 197 | self.save_cache() 198 | 199 | out = {"score": np.mean(scores), 200 | "respond_ratio": respond_ratio, 201 | "decisions": decisions, 202 | "num_facts_per_response": np.mean([len(d) for d in decisions if d is not None])} 203 | 204 | if gamma: 205 | out["init_score"] = np.mean(init_scores) 206 | 207 | return out 208 | 209 | def _get_score(self, topic, generation, atomic_facts, knowledge_source, cost_estimate=None): 210 | decisions = [] 211 | total_words = 0 212 | for atom in atomic_facts: 213 | atom = atom.strip() 214 | if self.lm: 215 | passages = self.retrieval[knowledge_source].get_passages(topic, atom, k=5) 216 | definition = "Answer the question about {} based on the given context.\n\n".format(topic) 217 | context = "" 218 | for psg_idx, 
psg in enumerate(reversed(passages)): 219 | context += "Title: {}\nText: {}\n\n".format(psg["title"], psg["text"].replace("", "").replace("", "")) 220 | definition += context.strip() 221 | if not definition[-1] in string.punctuation: 222 | definition += "." 223 | prompt = "{}\n\nInput: {} True or False?\nOutput:".format(definition.strip(), atom.strip()) 224 | 225 | if cost_estimate: 226 | if cost_estimate == "consider_cache" and (prompt.strip() + "_0") not in self.lm.cache_dict: 227 | total_words += len(prompt.split()) 228 | elif cost_estimate == "ignore_cache": 229 | total_words += len(prompt.split()) 230 | continue 231 | 232 | output = self.lm.generate(prompt) 233 | 234 | if type(output[1])==np.ndarray: 235 | # when logits are available 236 | logits = np.array(output[1]) 237 | assert logits.shape[0] in [32000, 32001] 238 | true_score = logits[5852] 239 | false_score = logits[7700] 240 | is_supported = true_score > false_score 241 | else: 242 | # when logits are unavailable 243 | generated_answer = output[0].lower() 244 | if "true" in generated_answer or "false" in generated_answer: 245 | if "true" in generated_answer and "false" not in generated_answer: 246 | is_supported = True 247 | elif "false" in generated_answer and "true" not in generated_answer: 248 | is_supported = False 249 | else: 250 | is_supported = generated_answer.index("true") > generated_answer.index("false") 251 | else: 252 | is_supported = all([keyword not in generated_answer.lower().translate(str.maketrans("", "", string.punctuation)).split() for keyword in ["not", "cannot", "unknown", "information"]]) 253 | 254 | else: 255 | is_supported = True 256 | 257 | if is_supported and "npm" in self.model_name: 258 | npprob = self.npm[knowledge_source].get_probabilty(topic, atom) 259 | is_supported = npprob > 0.3 260 | 261 | decisions.append({"atom": atom, "is_supported": is_supported}) 262 | 263 | if cost_estimate: 264 | return total_words 265 | else: 266 | return decisions 267 | 268 | if __name__ == '__main__': 269 | 270 | parser = argparse.ArgumentParser() 271 | parser.add_argument('--input_path', 272 | type=str, 273 | default="data/labeled/InstructGPT.jsonl") 274 | parser.add_argument('--model_name', 275 | type=str, 276 | default="retrieval+ChatGPT") 277 | parser.add_argument('--gamma', 278 | type=int, 279 | default=10, 280 | help="hyperparameter for length penalty") 281 | 282 | parser.add_argument('--openai_key', 283 | type=str, 284 | default="api.key") 285 | parser.add_argument('--data_dir', 286 | type=str, 287 | default=".cache/factscore/") 288 | parser.add_argument('--model_dir', 289 | type=str, 290 | default=".cache/factscore/") 291 | parser.add_argument('--cache_dir', 292 | type=str, 293 | default=".cache/factscore/") 294 | parser.add_argument('--knowledge_source', 295 | type=str, 296 | default=None) 297 | 298 | 299 | parser.add_argument('--cost_estimate', 300 | type=str, 301 | default="consider_cache", 302 | choices=["consider_cache", "ignore_cache"]) 303 | parser.add_argument('--abstain_detection_type', 304 | type=str, 305 | default=None, 306 | choices=["perplexity_ai", "generic", "none"]) 307 | parser.add_argument('--use_atomic_facts', 308 | action="store_true") 309 | parser.add_argument('--verbose', 310 | action="store_true", 311 | help="for printing out the progress bar") 312 | parser.add_argument('--print_rate_limit_error', 313 | action="store_true", 314 | help="for printing out rate limit error when using OpenAI keys") 315 | parser.add_argument('--n_samples', 316 | type=int, 317 | default=None) 318 | 319 | args 
= parser.parse_args() 320 | 321 | logging.basicConfig(format='%(asctime)s - %(name)s - %(message)s', 322 | datefmt='%m/%d/%Y %H:%M:%S', 323 | level=logging.ERROR if args.print_rate_limit_error else logging.CRITICAL) 324 | 325 | fs = FactScorer(model_name=args.model_name, 326 | data_dir=args.data_dir, 327 | model_dir=args.model_dir, 328 | cache_dir=args.cache_dir, 329 | openai_key=args.openai_key, 330 | cost_estimate=args.cost_estimate, 331 | abstain_detection_type=args.abstain_detection_type) 332 | 333 | tot = 0 334 | topics, generations, atomic_facts = [], [], [] 335 | with open(args.input_path) as f: 336 | for line in f: 337 | dp = json.loads(line) 338 | tot += 1 339 | if args.use_atomic_facts: 340 | assert "annotations" in dp, "You can specify `--use_atomic_facts` only when atomic facts are available in the input data already." 341 | if dp["annotations"] is None: 342 | continue 343 | topics.append(dp["topic"]) 344 | generations.append(dp["output"]) 345 | atomic_facts.append([atom["text"] for sent in dp["annotations"] for atom in sent["model-atomic-facts"]]) 346 | else: 347 | topics.append(dp["topic"]) 348 | generations.append(dp["output"]) 349 | if args.n_samples is not None and tot==args.n_samples: 350 | break 351 | out = fs.get_score(topics=topics, 352 | generations=generations, 353 | gamma=args.gamma, 354 | atomic_facts=atomic_facts if args.use_atomic_facts else None, 355 | knowledge_source=args.knowledge_source, 356 | verbose=args.verbose) 357 | logging.critical("FActScore = %.1f%%" % (100*out["score"])) 358 | if "init_score" in out: 359 | logging.critical("FActScore w/o length penalty = %.1f%%" % (100*out["init_score"])) 360 | logging.critical("Respond ratio = %.1f%%" % (100*out["respond_ratio"])) 361 | logging.critical("# Atomic facts per valid response = %.1f" % (out["num_facts_per_response"])) 362 | 363 | # Save out as a json file 364 | with open(args.input_path.replace(".jsonl", f"_factscore_output.json"), 'w') as f: 365 | f.write(json.dumps(out) + "\n") 366 | 367 | -------------------------------------------------------------------------------- /factscore/lm.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import time 4 | 5 | class LM(object): 6 | 7 | def __init__(self, cache_file): 8 | self.cache_file = cache_file 9 | self.cache_dict = self.load_cache() 10 | self.model = None 11 | self.add_n = 0 12 | 13 | def load_model(self): 14 | # load the model and put it as self.model 15 | raise NotImplementedError() 16 | 17 | def generate(self, prompt, sample_idx=0, max_sequence_length=2048, max_output_length=128): 18 | prompt = prompt.strip() # it's important not to end with a whitespace 19 | cache_key = f"{prompt}_{sample_idx}" 20 | 21 | if cache_key in self.cache_dict: 22 | return self.cache_dict[cache_key] 23 | 24 | if self.model is None: 25 | self.load_model() 26 | 27 | if prompt.endswith(" True or False?\nAnswer:"): 28 | generated = self._generate(prompt, max_sequence_length=max_sequence_length, max_output_length=1) 29 | else: 30 | generated = self._generate(prompt, max_sequence_length=max_sequence_length, max_output_length=max_output_length) 31 | 32 | self.cache_dict[cache_key] = generated 33 | self.add_n += 1 34 | return generated 35 | 36 | def save_cache(self): 37 | if self.add_n == 0: 38 | return 39 | 40 | # load the latest cache first, since if there were other processes running in parallel, cache might have been updated 41 | for k, v in self.load_cache().items(): 42 | self.cache_dict[k] = v 43 | 44 | 
with open(self.cache_file, "wb") as f: 45 | pickle.dump(self.cache_dict, f) 46 | 47 | def load_cache(self, allow_retry=True): 48 | if os.path.exists(self.cache_file): 49 | while True: 50 | try: 51 | with open(self.cache_file, "rb") as f: 52 | cache = pickle.load(f) 53 | break 54 | except Exception: 55 | if not allow_retry: 56 | assert False 57 | print ("Pickle Error: Retry in 5sec...") 58 | time.sleep(5) 59 | else: 60 | cache = {} 61 | return cache 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /factscore/npm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import time 4 | from collections import defaultdict 5 | from transformers import AutoModelForMaskedLM, AutoTokenizer 6 | 7 | from factscore.lm import LM 8 | from factscore.retrieval import Retrieval 9 | 10 | def softmax(x): 11 | return(np.exp(x - np.max(x)) / np.exp(x - np.max(x)).sum()) 12 | 13 | class NPM(LM): 14 | 15 | def __init__(self, bm25, model_name, cache_file): 16 | assert model_name.startswith("npm") 17 | self.bm25 = bm25 18 | self.model_name = model_name 19 | self.model = None 20 | 21 | self.tokenizer = AutoTokenizer.from_pretrained("facebook/" + self.model_name) 22 | self.mask_id = self.tokenizer.mask_token_id 23 | 24 | with open("roberta_stopwords.txt", "r") as f: 25 | self.stopwords = set() 26 | for line in f: 27 | self.stopwords.add(int(line.strip())) 28 | 29 | super().__init__(cache_file=cache_file) 30 | 31 | def load_model(self): 32 | self.model = AutoModelForMaskedLM.from_pretrained("facebook/" + self.model_name) 33 | self.model.cuda() 34 | self.model.eval() 35 | 36 | def save_cache(self): 37 | super().save_cache() 38 | self.bm25.save_cache() 39 | 40 | def tokenize(self, texts, skip_special_tokens=False, padding=True): 41 | assert type(texts)==list 42 | all_input_ids = self.tokenizer(texts)["input_ids"] 43 | if skip_special_tokens: 44 | for i, input_ids in enumerate(all_input_ids): 45 | assert input_ids[0]==0 and input_ids[-1]==2 46 | all_input_ids[i] = input_ids[1:-1] 47 | if not padding: 48 | return all_input_ids 49 | max_length = np.max([len(_ids) for _ids in all_input_ids]) 50 | _all_input_ids = [] 51 | _all_attention_mask = [] 52 | for i, input_ids in enumerate(all_input_ids): 53 | n_valid = len(input_ids) 54 | n_masks = max_length - n_valid 55 | _all_input_ids.append(input_ids + [0 for _ in range(n_masks)]) 56 | _all_attention_mask.append([1 for _ in range(n_valid)] + [0 for _ in range(n_masks)]) 57 | return torch.LongTensor(_all_input_ids), torch.LongTensor(_all_attention_mask) 58 | 59 | def decode(self, input_ids): 60 | return self.tokenizer.decode(input_ids) 61 | 62 | def encode(self, texts, skip_special_tokens=False, gt_input_ids=None): 63 | assert type(texts)==list 64 | if self.model is None: 65 | self.load_model() 66 | if gt_input_ids is not None: 67 | assert len(texts)==len(gt_input_ids) 68 | all_input_ids, all_attention_mask = self.tokenize(texts, skip_special_tokens=skip_special_tokens) 69 | 70 | with torch.no_grad(): 71 | outputs = self.model(all_input_ids.cuda(), 72 | all_attention_mask.cuda(), 73 | output_hidden_states=True, 74 | return_dict=True) 75 | all_logits = outputs["logits"].detach().cpu().numpy() 76 | all_hidden_states = outputs["hidden_states"][-1].detach().cpu().numpy() 77 | 78 | results = [] 79 | for i, (text, input_ids, logits, hidden_states) in enumerate(zip(texts, all_input_ids, all_logits, all_hidden_states)): 80 | input_ids = input_ids.numpy().tolist() 
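            # a query containing [MASK] yields (softmax probability of the gold token at the mask position, hidden state at the mask);
            # a plain passage yields (its token ids, per-token hidden states), with BOS/EOS ids 0 and 2 stripped out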
81 | if self.mask_id in input_ids: 82 | idx = input_ids.index(self.mask_id) 83 | assert gt_input_ids is not None 84 | prob = softmax(logits[idx])[gt_input_ids[i]] 85 | results.append((prob, hidden_states[idx])) 86 | else: 87 | _input_ids = [_id for _id in input_ids if _id not in [0, 2]] 88 | _hidden_states = [h for _id, h in zip(input_ids, hidden_states) if _id not in [0, 2]] 89 | results.append((_input_ids, _hidden_states)) 90 | 91 | return results 92 | 93 | def get_probabilty(self, topic, question): 94 | passages = self.bm25.get_passages(topic, question, k=3) 95 | passages = [p["text"].strip() for p in passages] 96 | cache_key = question + "#" + "#".join(passages) 97 | 98 | if cache_key not in self.cache_dict: 99 | encoded = self.encode(passages, skip_special_tokens=True) 100 | stacked_passage_tokens, stacked_passage_vectors = [], [] 101 | for input_ids, vectors in encoded: 102 | stacked_passage_tokens += input_ids 103 | if len(vectors)>0: 104 | stacked_passage_vectors.append(vectors) 105 | stacked_passage_vectors = np.concatenate(stacked_passage_vectors, 0) 106 | 107 | question_input_ids = self.tokenize(["Fact: " + question], skip_special_tokens=False, padding=False)[0] 108 | if 2 in question_input_ids: 109 | question_input_ids = question_input_ids[:question_input_ids.index(2)] 110 | question_input_ids = question_input_ids[1:] 111 | 112 | ''' 113 | triples = [] 114 | prefix = True 115 | for i, input_id in enumerate(question_input_ids): 116 | if prefix: 117 | if input_id==35: # the end of prefix 118 | prefix = False 119 | continue 120 | if input_id in [0, 2] or input_id in self.stopwords: 121 | continue 122 | new_question = self.decode(question_input_ids[:i] + [self.mask_id] + question_input_ids[i+1:]) 123 | prob, vector = self.encode(new_question, gt_input_id=input_id) 124 | triples.append((prob, vector, input_id)) 125 | ''' 126 | triples = [] 127 | batch = [] 128 | gt_input_ids = [] 129 | prefix = True 130 | for i, input_id in enumerate(question_input_ids): 131 | if prefix: 132 | if input_id==35: # the end of prefix 133 | prefix = False 134 | continue 135 | if input_id in [0, 2] or input_id in self.stopwords: 136 | continue 137 | batch.append(self.decode(question_input_ids[:i] + [self.mask_id] + question_input_ids[i+1:])) 138 | gt_input_ids.append(input_id) 139 | for (prob, vector), gt_input_id in zip(self.encode(batch, gt_input_ids=gt_input_ids), gt_input_ids): 140 | triples.append((prob, vector, gt_input_id)) 141 | 142 | stacked_question_vectors = np.stack([v for _, v, _ in triples], 0) 143 | all_scores = np.exp(np.inner(stacked_question_vectors, stacked_passage_vectors) / np.sqrt(stacked_passage_vectors.shape[-1])) 144 | 145 | probs = [] 146 | for (softmax_prob, vector, input_id), scores in zip(triples, all_scores): 147 | assert len(stacked_passage_tokens)==len(scores) 148 | if input_id not in stacked_passage_tokens: 149 | probs.append(0) 150 | else: 151 | aggregated_scores = defaultdict(list) 152 | for token, score in zip(stacked_passage_tokens, scores): 153 | aggregated_scores[token].append(score) 154 | tot = np.sum([np.sum(v) for v in aggregated_scores.values()]) 155 | prob = np.sum(aggregated_scores[input_id]) / tot 156 | probs.append(prob) 157 | 158 | self.cache_dict[cache_key] = np.mean(probs) 159 | self.add_n += 1 160 | 161 | return self.cache_dict[cache_key] 162 | 163 | 164 | 165 | -------------------------------------------------------------------------------- /factscore/openai_lm.py: -------------------------------------------------------------------------------- 1 | from 
factscore.lm import LM 2 | import openai 3 | import sys 4 | import time 5 | import os 6 | import numpy as np 7 | import logging 8 | 9 | class OpenAIModel(LM): 10 | 11 | def __init__(self, model_name, cache_file=None, key_path="api.key"): 12 | self.model_name = model_name 13 | self.key_path = key_path 14 | self.temp = 0.7 15 | self.save_interval = 100 16 | super().__init__(cache_file) 17 | 18 | def load_model(self): 19 | # load api key 20 | key_path = self.key_path 21 | assert os.path.exists(key_path), f"Please place your OpenAI APT Key in {key_path}." 22 | with open(key_path, 'r') as f: 23 | api_key = f.readline() 24 | openai.api_key = api_key.strip() 25 | self.model = self.model_name 26 | 27 | def _generate(self, prompt, max_sequence_length=2048, max_output_length=128): 28 | if self.add_n % self.save_interval == 0: 29 | self.save_cache() 30 | # return a tuple of string (generated text) and metadata (any format) 31 | # This should be about generating a response from the prompt, no matter what the application is 32 | if self.model_name == "ChatGPT": 33 | # Construct the prompt send to ChatGPT 34 | message = [{"role": "user", "content": prompt}] 35 | # Call API 36 | response = call_ChatGPT(message, temp=self.temp, max_len=max_sequence_length) 37 | # Get the output from the response 38 | output = response["choices"][0]["message"]["content"] 39 | return output, response 40 | elif self.model_name == "InstructGPT": 41 | # Call API 42 | response = call_GPT3(prompt, temp=self.temp) 43 | # Get the output from the response 44 | output = response["choices"][0]["text"] 45 | return output, response 46 | else: 47 | raise NotImplementedError() 48 | 49 | def call_ChatGPT(message, model_name="gpt-3.5-turbo", max_len=1024, temp=0.7, verbose=False): 50 | # call GPT-3 API until result is provided and then return it 51 | response = None 52 | received = False 53 | num_rate_errors = 0 54 | while not received: 55 | try: 56 | response = openai.ChatCompletion.create(model=model_name, 57 | messages=message, 58 | max_tokens=max_len, 59 | temperature=temp) 60 | received = True 61 | except: 62 | # print(message) 63 | num_rate_errors += 1 64 | error = sys.exc_info()[0] 65 | if error == openai.error.InvalidRequestError: 66 | # something is wrong: e.g. prompt too long 67 | logging.critical(f"InvalidRequestError\nPrompt passed in:\n\n{message}\n\n") 68 | assert False 69 | 70 | logging.error("API error: %s (%d). Waiting %dsec" % (error, num_rate_errors, np.power(2, num_rate_errors))) 71 | time.sleep(np.power(2, num_rate_errors)) 72 | return response 73 | 74 | 75 | def call_GPT3(prompt, model_name="text-davinci-003", max_len=512, temp=0.7, num_log_probs=0, echo=False, verbose=False): 76 | # call GPT-3 API until result is provided and then return it 77 | response = None 78 | received = False 79 | num_rate_errors = 0 80 | while not received: 81 | try: 82 | response = openai.Completion.create(model=model_name, 83 | prompt=prompt, 84 | max_tokens=max_len, 85 | temperature=temp, 86 | logprobs=num_log_probs, 87 | echo=echo) 88 | received = True 89 | except: 90 | error = sys.exc_info()[0] 91 | num_rate_errors += 1 92 | if error == openai.error.InvalidRequestError: 93 | # something is wrong: e.g. 
prompt too long 94 | logging.critical(f"InvalidRequestError\nPrompt passed in:\n\n{prompt}\n\n") 95 | assert False 96 | logging.error("API error: %s (%d)" % (error, num_rate_errors)) 97 | time.sleep(np.power(2, num_rate_errors)) 98 | return response 99 | -------------------------------------------------------------------------------- /factscore/retrieval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | import os 4 | 5 | import sqlite3 6 | import numpy as np 7 | import pickle as pkl 8 | 9 | from rank_bm25 import BM25Okapi 10 | 11 | SPECIAL_SEPARATOR = "####SPECIAL####SEPARATOR####" 12 | MAX_LENGTH = 256 13 | 14 | class DocDB(object): 15 | """Sqlite backed document storage. 16 | 17 | Implements get_doc_text(doc_id). 18 | """ 19 | 20 | def __init__(self, db_path=None, data_path=None): 21 | self.db_path = db_path 22 | self.connection = sqlite3.connect(self.db_path, check_same_thread=False) 23 | 24 | cursor = self.connection.cursor() 25 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") 26 | 27 | if len(cursor.fetchall())==0: 28 | assert data_path is not None, f"{self.db_path} is empty. Specify `data_path` in order to create a DB." 29 | print (f"{self.db_path} is empty. start building DB from {data_path}...") 30 | self.build_db(self.db_path, data_path) 31 | 32 | def __enter__(self): 33 | return self 34 | 35 | def __exit__(self, *args): 36 | self.close() 37 | 38 | def path(self): 39 | """Return the path to the file that backs this database.""" 40 | return self.path 41 | 42 | def close(self): 43 | """Close the connection to the database.""" 44 | self.connection.close() 45 | 46 | def build_db(self, db_path, data_path): 47 | from transformers import RobertaTokenizer 48 | tokenizer = RobertaTokenizer.from_pretrained("roberta-large") 49 | 50 | titles = set() 51 | output_lines = [] 52 | tot = 0 53 | start_time = time.time() 54 | c = self.connection.cursor() 55 | c.execute("CREATE TABLE documents (title PRIMARY KEY, text);") 56 | 57 | with open(data_path, "r") as f: 58 | for line in f: 59 | dp = json.loads(line) 60 | title = dp["title"] 61 | text = dp["text"] 62 | if title in titles: 63 | continue 64 | titles.add(title) 65 | if type(text)==str: 66 | text = [text] 67 | passages = [[]] 68 | for sent_idx, sent in enumerate(text): 69 | assert len(sent.strip())>0 70 | tokens = tokenizer(sent)["input_ids"] 71 | max_length = MAX_LENGTH - len(passages[-1]) 72 | if len(tokens) <= max_length: 73 | passages[-1].extend(tokens) 74 | else: 75 | passages[-1].extend(tokens[:max_length]) 76 | offset = max_length 77 | while offset < len(tokens): 78 | passages.append(tokens[offset:offset+MAX_LENGTH]) 79 | offset += MAX_LENGTH 80 | 81 | psgs = [tokenizer.decode(tokens) for tokens in passages if np.sum([t not in [0, 2] for t in tokens])>0] 82 | text = SPECIAL_SEPARATOR.join(psgs) 83 | output_lines.append((title, text)) 84 | tot += 1 85 | 86 | if len(output_lines) == 1000000: 87 | c.executemany("INSERT INTO documents VALUES (?,?)", output_lines) 88 | output_lines = [] 89 | print ("Finish saving %dM documents (%dmin)" % (tot / 1000000, (time.time()-start_time)/60)) 90 | 91 | if len(output_lines) > 0: 92 | c.executemany("INSERT INTO documents VALUES (?,?)", output_lines) 93 | print ("Finish saving %dM documents (%dmin)" % (tot / 1000000, (time.time()-start_time)/60)) 94 | 95 | self.connection.commit() 96 | self.connection.close() 97 | 98 | def get_text_from_title(self, title): 99 | """Fetch the raw text of the doc for 'doc_id'.""" 100 
| cursor = self.connection.cursor() 101 | cursor.execute("SELECT text FROM documents WHERE title = ?", (title,)) 102 | results = cursor.fetchall() 103 | results = [r for r in results] 104 | cursor.close() 105 | assert results is not None and len(results)==1, f"`topic` in your data ({title}) is likely not a valid title in the DB." 106 | results = [{"title": title, "text": para} for para in results[0][0].split(SPECIAL_SEPARATOR)] 107 | assert len(results)>0, f"`topic` in your data ({title}) is likely not a valid title in the DB." 108 | return results 109 | 110 | class Retrieval(object): 111 | 112 | def __init__(self, db, cache_path, embed_cache_path, 113 | retrieval_type="gtr-t5-large", batch_size=None): 114 | self.db = db 115 | self.cache_path = cache_path 116 | self.embed_cache_path = embed_cache_path 117 | self.retrieval_type = retrieval_type 118 | self.batch_size = batch_size 119 | assert retrieval_type=="bm25" or retrieval_type.startswith("gtr-") 120 | 121 | self.encoder = None 122 | self.load_cache() 123 | self.add_n = 0 124 | self.add_n_embed = 0 125 | 126 | def load_encoder(self): 127 | from sentence_transformers import SentenceTransformer 128 | encoder = SentenceTransformer("sentence-transformers/" + self.retrieval_type) 129 | encoder = encoder.cuda() 130 | encoder = encoder.eval() 131 | self.encoder = encoder 132 | assert self.batch_size is not None 133 | 134 | def load_cache(self): 135 | if os.path.exists(self.cache_path): 136 | with open(self.cache_path, "r") as f: 137 | self.cache = json.load(f) 138 | else: 139 | self.cache = {} 140 | if os.path.exists(self.embed_cache_path): 141 | with open(self.embed_cache_path, "rb") as f: 142 | self.embed_cache = pkl.load(f) 143 | else: 144 | self.embed_cache = {} 145 | 146 | def save_cache(self): 147 | if self.add_n > 0: 148 | if os.path.exists(self.cache_path): 149 | with open(self.cache_path, "r") as f: 150 | new_cache = json.load(f) 151 | self.cache.update(new_cache) 152 | 153 | with open(self.cache_path, "w") as f: 154 | json.dump(self.cache, f) 155 | 156 | if self.add_n_embed > 0: 157 | if os.path.exists(self.embed_cache_path): 158 | with open(self.embed_cache_path, "rb") as f: 159 | new_cache = pkl.load(f) 160 | self.embed_cache.update(new_cache) 161 | 162 | with open(self.embed_cache_path, "wb") as f: 163 | pkl.dump(self.embed_cache, f) 164 | 165 | def get_bm25_passages(self, topic, query, passages, k): 166 | if topic in self.embed_cache: 167 | bm25 = self.embed_cache[topic] 168 | else: 169 | bm25 = BM25Okapi([psg["text"].replace("<s>", "").replace("</s>", "").split() for psg in passages]) 170 | self.embed_cache[topic] = bm25 171 | self.add_n_embed += 1 172 | scores = bm25.get_scores(query.split()) 173 | indices = np.argsort(-scores)[:k] 174 | return [passages[i] for i in indices] 175 | 176 | def get_gtr_passages(self, topic, retrieval_query, passages, k): 177 | if self.encoder is None: 178 | self.load_encoder() 179 | if topic in self.embed_cache: 180 | passage_vectors = self.embed_cache[topic] 181 | else: 182 | inputs = [psg["title"] + " " + psg["text"].replace("<s>", "").replace("</s>", "") for psg in passages] 183 | passage_vectors = self.encoder.encode(inputs, batch_size=self.batch_size, device=self.encoder.device) 184 | self.embed_cache[topic] = passage_vectors 185 | self.add_n_embed += 1 186 | query_vectors = self.encoder.encode([retrieval_query], 187 | batch_size=self.batch_size, 188 | device=self.encoder.device)[0] 189 | scores = np.inner(query_vectors, passage_vectors) 190 | indices = np.argsort(-scores)[:k] 191 | return
[passages[i] for i in indices] 192 | 193 | def get_passages(self, topic, question, k): 194 | retrieval_query = topic + " " + question.strip() 195 | cache_key = topic + "#" + retrieval_query 196 | 197 | if cache_key not in self.cache: 198 | passages = self.db.get_text_from_title(topic) 199 | if self.retrieval_type=="bm25": 200 | self.cache[cache_key] = self.get_bm25_passages(topic, retrieval_query, passages, k) 201 | else: 202 | self.cache[cache_key] = self.get_gtr_passages(topic, retrieval_query, passages, k) 203 | assert len(self.cache[cache_key]) in [k, len(passages)] 204 | self.add_n += 1 205 | 206 | 207 | return self.cache[cache_key] 208 | 209 | 210 | 211 | 212 | 213 | -------------------------------------------------------------------------------- /factscore/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import torch 7 | 8 | def assert_all_approx_close(a, b, rtol, atol, count): 9 | 10 | idx = torch.isclose(a.float(), b.float(), rtol, atol) 11 | sumval = (idx==0).sum().item() 12 | if sumval > count: 13 | print(f'Too many values not close: assert {sumval} < {count}') 14 | try: 15 | torch.testing.assert_allclose(a, b, rtol, atol) 16 | except Exception as e: 17 | print(e) 18 | 19 | 20 | def get_memory_footprint(model, return_buffers=True): 21 | """ 22 | Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. 23 | Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired from the 24 | PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2 25 | Arguments: 26 | return_buffers (`bool`, *optional*, defaults to `True`): 27 | Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers 28 | are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch 29 | norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2 30 | """ 31 | mem = sum([param.nelement() * param.element_size() for param in model.parameters()]) 32 | if return_buffers: 33 | mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()]) 34 | mem = mem + mem_bufs 35 | return mem 36 | 37 | 38 | def ـreplace_linear_with_int8linear(model, modules_to_not_convert="lm_head"): 39 | for name, module in model.named_children(): 40 | ـreplace_linear_with_int8linear(module, modules_to_not_convert) 41 | 42 | if isinstance(module, torch.nn.Linear) and name != modules_to_not_convert: 43 | model._modules[name] = QuantizedLinearInt8(linear_layer=module) 44 | return 45 | 46 | 47 | class QuantizedLinearInt8(torch.nn.Module): 48 | ''' 49 | A simple but effictive implmenetion of Int8 quantization for linear layers. 50 | The weights are quantized and stored as Int8, which saves ~50% of the gpu memory. 51 | During the forwared pass, the weights are de-quantized back to fp16 to do multiplication. 52 | Pros: 53 | - saves ~50% of the gpu memory 54 | - accurate quantization because only the weights are quantized, and the weights don't suffer 55 | from the "outliers" issue mentioned in the LLM.int8 paper; only the activations do. 
56 | - high precision results beacuse the multiplication is done in fp16 57 | - much faster than LLM.int8 58 | Cons: 59 | - a bit slower because of the added computation of dequantization in each forward pass. In practice, the slowdown 60 | is not large because in the generation application, gpu utilization is not very high. 61 | ''' 62 | def __init__(self, linear_layer): 63 | super().__init__() 64 | self.bias = linear_layer.bias 65 | 66 | weight_bit_width = 8 67 | weight = linear_layer.weight 68 | 69 | self.weight_scale = torch.nn.Parameter( 70 | (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half(), 71 | ) 72 | # print(self.weight_scale.max().item(), self.weight_scale.min().item(), self.weight_scale.mean().item()) 73 | # if self.weight_scale.max().item() > 0.002: 74 | # print(self.weight_scale.max().item()) 75 | self.weight = torch.nn.Parameter( 76 | torch.round(weight.float() / self.weight_scale[:, None]).char(), 77 | requires_grad=False 78 | ) 79 | 80 | def forward(self, x): 81 | weight = self.weight.half() * self.weight_scale[:, None] 82 | return torch.nn.functional.linear(x, weight, self.bias) 83 | 84 | 85 | def convert_model_to_int8_on_gpu(model, device): 86 | """ 87 | Quantize a model to int8 and move it to GPU using a simple method. 88 | """ 89 | if 'cuda' not in device: 90 | raise ValueError(f"Target device should be a gpu. Device {device} is not supported") 91 | 92 | model.half() 93 | 94 | memory_before_quantization = get_memory_footprint(model) # without lm_head 95 | 96 | ـreplace_linear_with_int8linear(model) # replace `Linear` with `QuantizedLinearInt8` 97 | 98 | model.to(device=device) 99 | memory_after_quantization = get_memory_footprint(model) # without lm_head 100 | 101 | saving = round(100 * memory_after_quantization/memory_before_quantization) 102 | memory_before_quantization = round(memory_before_quantization / 2**30, 2) # rounding for printing 103 | memory_after_quantization = round(memory_after_quantization / 2**30, 2) # rounding for printing 104 | 105 | print(f'Quantization memory - before: {memory_before_quantization} GB, after: {memory_after_quantization} GB ({saving}% of the size before)') 106 | return model 107 | -------------------------------------------------------------------------------- /preprocessing/preprocess_acl.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import tqdm 3 | import json 4 | import openai 5 | from factscore.openai_lm import call_ChatGPT 6 | from factscore.factscorer import FactScorer 7 | 8 | # File downloaded from https://github.com/shauryr/ACL-anthology-corpus 9 | # https://drive.google.com/file/d/1CFCzNGlTls0H-Zcaem4Hg_ETj4ebhcDO/view?usp=sharing 10 | df = pd.read_parquet('acl-publication-info.74k.parquet') 11 | titles = df['title'].tolist() 12 | full_text = df['full_text'].tolist() 13 | 14 | acl_corpus = [] 15 | for x, y in zip(titles, full_text): 16 | if x.strip() == "" or y.strip() == "": 17 | continue 18 | acl_corpus.append({"title": x, "text": y}) 19 | 20 | with open("acl_corpus.jsonl", 'w') as f: 21 | for line in acl_corpus: 22 | f.write(json.dumps(line) + "\n") 23 | 24 | fs = FactScorer() 25 | # this will create a database using your file 26 | # once DB file is created, you can reuse it by only specifying `db_path` 27 | fs.register_knowledge_source("acl_corpus", 28 | data_path="acl_corpus.jsonl", 29 | db_path=None) 30 | 31 | 32 | prompt_titles = [ 33 | "Dense Passage Retrieval for Open-Domain Question Answering", 34 | "AmbigQA: Answering 
Ambiguous Open-domain Questions", 35 | "MetaICL: Learning to Learn In Context", 36 | "Noisy Channel Language Model Prompting for Few-Shot Text Classification", 37 | "Joint Passage Ranking for Diverse Multi-Answer Retrieval", 38 | "Reformulating Unsupervised Style Transfer as Paraphrase Generation", 39 | "Syntactically Supervised Transformers for Faster Neural Machine Translation", 40 | "Hurdles to Progress in Long-form Question Answering", 41 | "Generating Question-Answer Hierarchies", 42 | "Do Long-Range Language Models Actually Use Long-Range Context?" 43 | ] 44 | 45 | prompts_list = [] 46 | 47 | for title in prompt_titles: 48 | prompts_list.append(f"Give me a summary of the research paper titled \"{title}\".") 49 | 50 | with open("api.key", 'r') as f: 51 | api_key = f.readline() 52 | openai.api_key = api_key.strip() 53 | 54 | responses = [] 55 | for ptitle, prompt in tqdm.tqdm(zip(prompt_titles, prompts_list)): 56 | message = [{"role": "user", "content": prompt}] 57 | response = call_ChatGPT(message, model_name="gpt-3.5-turbo-0301") 58 | responses.append({ 59 | "topic": ptitle, 60 | "output": response["choices"][0]["message"]["content"] 61 | }) 62 | 63 | # # write the corpus to a jsonl file 64 | with open("acl_chatgpt_outputs.jsonl", 'w') as f: 65 | for line in responses: 66 | f.write(json.dumps(line) + "\n") 67 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "factscore" 3 | version = "0.2.0" 4 | description = "FactScore is an automatic evaluation metric for factual precision in long-form text generation. It uses large language models and retrieval to break down generations into atomic facts and then measure the correctness with respect to a knowledge source (like Wikipedia)." 
5 | authors = ["Sewon Min ", "Kalpesh Krishna ", "Xinxi Lyu "] 6 | license = "MIT" 7 | readme = "README.md" 8 | 9 | [tool.poetry.dependencies] 10 | python = "^3.7.1" 11 | torch = "^1.13.0" 12 | sentence-transformers = "^2.2.2" 13 | transformers = "^4.29.2" 14 | openai = "^0.27.7" 15 | rank-bm25 = "^0.2.2" 16 | spacy = "^3.5.3" 17 | pysqlite-binary = "^0.5.0" 18 | nltk = "^3.8.1" 19 | 20 | [build-system] 21 | requires = ["poetry-core"] 22 | build-backend = "poetry.core.masonry.api" 23 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | sentence_transformers 3 | transformers 4 | pysqlite3 5 | openai 6 | rank_bm25 7 | spacy 8 | 9 | # sentence_transformers>=2.2.2 10 | # transformers>=4.29.2 11 | # torch>=>=2.0.0 -------------------------------------------------------------------------------- /roberta_stopwords.txt: -------------------------------------------------------------------------------- 1 | 4 2 | 5 3 | 6 4 | 7 5 | 8 6 | 9 7 | 10 8 | 11 9 | 12 10 | 13 11 | 14 12 | 15 13 | 16 14 | 19 15 | 21 16 | 22 17 | 23 18 | 24 19 | 25 20 | 28 21 | 30 22 | 31 23 | 32 24 | 33 25 | 34 26 | 35 27 | 36 28 | 37 29 | 39 30 | 40 31 | 41 32 | 42 33 | 43 34 | 45 35 | 47 36 | 49 37 | 50 38 | 51 39 | 52 40 | 53 41 | 54 42 | 55 43 | 56 44 | 57 45 | 58 46 | 59 47 | 61 48 | 62 49 | 63 50 | 64 51 | 66 52 | 68 53 | 69 54 | 70 55 | 71 56 | 73 57 | 77 58 | 79 59 | 81 60 | 84 61 | 87 62 | 88 63 | 89 64 | 95 65 | 97 66 | 98 67 | 99 68 | 103 69 | 106 70 | 108 71 | 109 72 | 110 73 | 111 74 | 113 75 | 114 76 | 116 77 | 122 78 | 123 79 | 127 80 | 128 81 | 129 82 | 131 83 | 136 84 | 137 85 | 141 86 | 142 87 | 143 88 | 144 89 | 145 90 | 147 91 | 148 92 | 149 93 | 150 94 | 159 95 | 160 96 | 162 97 | 167 98 | 172 99 | 182 100 | 197 101 | 207 102 | 209 103 | 215 104 | 218 105 | 222 106 | 223 107 | 227 108 | 258 109 | 259 110 | 276 111 | 308 112 | 328 113 | 349 114 | 350 115 | 351 116 | 359 117 | 367 118 | 385 119 | 399 120 | 454 121 | 456 122 | 473 123 | 475 124 | 479 125 | 519 126 | 524 127 | 579 128 | 596 129 | 608 130 | 617 131 | 630 132 | 646 133 | 683 134 | 742 135 | 769 136 | 787 137 | 849 138 | 874 139 | 938 140 | 939 141 | 947 142 | 965 143 | 1003 144 | 1009 145 | 1021 146 | 1039 147 | 1065 148 | 1215 149 | 1235 150 | 1423 151 | 1495 152 | 1589 153 | 1629 154 | 1640 155 | 1705 156 | 1721 157 | 1979 158 | 2025 159 | 2055 160 | 2156 161 | 2185 162 | 2220 163 | 2282 164 | 2512 165 | 2661 166 | 2744 167 | 2864 168 | 3226 169 | 3486 170 | 3559 171 | 4288 172 | 4395 173 | 4832 174 | 4839 175 | 5030 176 | 5214 177 | 5457 178 | 5844 179 | 7606 180 | 8061 181 | 9131 182 | 10431 183 | 10975 184 | 12905 185 | 14314 186 | 14434 187 | 15157 188 | 15483 189 | 15698 190 | 17487 191 | 18134 192 | 18212 193 | 19385 194 | 20343 195 | 22209 196 | 23367 197 | 24303 198 | 25522 199 | 25606 200 | 27779 201 | 27785 202 | 28696 203 | 31954 204 | 34437 205 | 35227 206 | 35524 207 | 37249 208 | 37457 209 | 41552 210 | 44128 211 | 45152 212 | --------------------------------------------------------------------------------
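To make the retrieval layer easier to follow, here is a minimal sketch of how `DocDB` and `Retrieval` from `factscore/retrieval.py` fit together. Every path, the topic, and `k` are placeholders, a CUDA GPU is assumed because `load_encoder()` moves the GTR encoder to GPU, and `batch_size` must be set whenever a `gtr-*` encoder is used.

```python
# Illustrative sketch only: all paths, the topic, and k are placeholders.
from factscore.retrieval import DocDB, Retrieval

db = DocDB(db_path="path/to/knowledge_source.db")  # existing DB; pass data_path only to build a new one
retrieval = Retrieval(
    db,
    cache_path="path/to/retrieval-cache.json",
    embed_cache_path="path/to/retrieval-embed.pkl",
    retrieval_type="gtr-t5-large",  # or "bm25"
    batch_size=256,                 # required for gtr-* encoders
)

# Top-k passages from the topic's page for one atomic fact.
passages = retrieval.get_passages("Barack Obama", "Barack Obama was born in Hawaii.", k=5)
for psg in passages:
    print(psg["title"], psg["text"][:80])

retrieval.save_cache()  # persist the query cache (JSON) and per-topic passage embeddings (pickle)
```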
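A similar sketch for `factscore/npm.py`: `NPM` combines a BM25 `Retrieval` instance with a masked LM whose name starts with `npm`, resolved under the `facebook/` namespace on HuggingFace. The checkpoint name and paths below are assumptions, a GPU is required, and `roberta_stopwords.txt` must sit in the working directory because the constructor opens it with a relative path.

```python
# Illustrative sketch only: the checkpoint name and all paths are assumptions.
from factscore.retrieval import DocDB, Retrieval
from factscore.npm import NPM

db = DocDB(db_path="path/to/knowledge_source.db")
bm25 = Retrieval(db,
                 cache_path="path/to/bm25-cache.json",
                 embed_cache_path="path/to/bm25-embed.pkl",
                 retrieval_type="bm25")
npm = NPM(bm25, "npm-single", cache_file="path/to/npm-cache.pkl")

# Average probability, over the masked content tokens of "Fact: <statement>",
# that the gold token is recovered from the top-3 BM25 passages for the topic.
prob = npm.get_probabilty("Barack Obama", "Barack Obama was born in Hawaii.")
npm.save_cache()
```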
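The int8 helpers in `factscore/utils.py` are model-agnostic. The sketch below quantizes an arbitrary HuggingFace causal LM with `convert_model_to_int8_on_gpu`; the model path is a placeholder and a CUDA device is assumed.

```python
# Illustrative sketch only: the model path is a placeholder; a CUDA GPU is required.
from transformers import AutoModelForCausalLM
from factscore.utils import convert_model_to_int8_on_gpu, get_memory_footprint

model = AutoModelForCausalLM.from_pretrained("path/to/llama-7B")
model = convert_model_to_int8_on_gpu(model, device="cuda")
# Every torch.nn.Linear except lm_head is now a QuantizedLinearInt8: int8 weights plus a
# per-output-row fp16 scale, dequantized back to fp16 inside forward(), roughly halving
# GPU memory relative to the fp16 model.
print(round(get_memory_footprint(model) / 2**30, 2), "GB")
```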
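Finally, a sketch of how the ChatGPT summaries written by `preprocessing/preprocess_acl.py` might be scored against the registered `acl_corpus` knowledge source. The `get_score` call and the `"score"` key of its return value are assumptions about the `FactScorer` interface (argument names may differ), and an OpenAI key is expected in `api.key`.

```python
# Illustrative sketch only: get_score(...) and out["score"] are assumed, not verified here.
import json
from factscore.factscorer import FactScorer

fs = FactScorer()
fs.register_knowledge_source("acl_corpus",
                             data_path="acl_corpus.jsonl",
                             db_path=None)  # reuse the DB by passing only db_path on later runs

topics, generations = [], []
with open("acl_chatgpt_outputs.jsonl") as f:
    for line in f:
        dp = json.loads(line)
        topics.append(dp["topic"])
        generations.append(dp["output"])

out = fs.get_score(topics, generations, knowledge_source="acl_corpus")
print(out["score"])  # FActScore of the summaries against the ACL corpus
```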