├── semantic_uncertainty ├── uncertainty │ ├── __init__.py │ ├── models │ │ ├── __init__.py │ │ ├── base_model.py │ │ └── huggingface_models.py │ ├── utils │ │ ├── openai.py │ │ ├── eval_utils.py │ │ └── utils.py │ ├── uncertainty_measures │ │ ├── p_ik.py │ │ ├── p_true.py │ │ └── semantic_entropy.py │ └── data │ │ └── data_utils.py ├── analyze_results.py ├── generate_answers.py └── compute_uncertainty_measures.py ├── semantic_entropy_probes ├── models │ └── Llama2-7b_inference.pkl └── README.md ├── LICENSE ├── slurm └── run.sh ├── .gitignore ├── README.md └── sep_enviroment.yaml /semantic_uncertainty/uncertainty/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /semantic_uncertainty/uncertainty/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /semantic_entropy_probes/models/Llama2-7b_inference.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OATML/semantic-entropy-probes/HEAD/semantic_entropy_probes/models/Llama2-7b_inference.pkl -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 OATML 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /semantic_uncertainty/uncertainty/utils/openai.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import hashlib 4 | from tenacity import (retry, stop_after_attempt, # for exponential backoff 5 | wait_random_exponential) 6 | 7 | from openai import OpenAI 8 | 9 | 10 | CLIENT = OpenAI(api_key=os.environ['OPENAI_API_KEY']) 11 | 12 | 13 | @retry(wait=wait_random_exponential(min=1, max=10)) 14 | def predict(prompt, temperature=1.0, model='gpt-4'): 15 | """Predict with an OpenAI chat model (GPT-4 by default).""" 16 | 17 | if isinstance(prompt, str): 18 | messages = [ 19 | {"role": "user", "content": prompt}, 20 | ] 21 | else: 22 | messages = prompt 23 | 24 | if model == 'gpt-4': 25 | model = 'gpt-4-turbo' # or 'gpt-4o' 26 | elif model == 'gpt-3.5': 27 | model = 'gpt-3.5-turbo' 28 | 29 | output = CLIENT.chat.completions.create( 30 | model=model, 31 | messages=messages, 32 | max_tokens=200, 33 | temperature=temperature, 34 | ) 35 | response = output.choices[0].message.content 36 | return response 37 | 38 | 39 | def md5hash(string): 40 | return int(hashlib.md5(string.encode('utf-8')).hexdigest(), 16) 41 | -------------------------------------------------------------------------------- /semantic_uncertainty/uncertainty/models/base_model.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import List, Text 3 | 4 | 5 | STOP_SEQUENCES = ['\n\n\n\n', '\n\n\n', '\n\n', '\n', 'Question:', 'Context:'] 6 | 7 | 8 | class BaseModel(ABC): 9 | 10 | stop_sequences: List[Text] 11 | 12 | @abstractmethod 13 | def predict(self, input_data, temperature): 14 | pass 15 | 16 | @abstractmethod 17 | def get_p_true(self, input_data): 18 | pass 19 | 20 | def get_character_start_stop_indices(self, input_data_offset, answer): 21 | """Remove any output following (and including) a stop_sequence. 22 | 23 | Some outputs start with newlines (unfortunately). We strip these, in 24 | order to ensure generations with greater-than-zero length. 25 | """ 26 | start_index = input_data_offset 27 | 28 | # Strip zero-length generations from beginning and add to `input_data_offset`. 29 | newline = '\n' 30 | while answer[start_index:].startswith(newline): 31 | start_index += len(newline) 32 | 33 | # Get character index of first stop sequence. 34 | stop_index = len(answer) 35 | for word in self.stop_sequences: 36 | index = answer[start_index:].find(word) 37 | if index != -1 and index + start_index < stop_index: 38 | stop_index = index + start_index 39 | 40 | return start_index, stop_index 41 | -------------------------------------------------------------------------------- /slurm/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # SBATCH --cpus-per-task=24 3 | # SBATCH --partition=your_partition_name 4 | # SBATCH --gres=gpu:a100:2 5 | # SBATCH --job-name="nlg_uncertainty_linearprobe" 6 | 7 | # Update conda environment (adjust as appropriate) 8 | ~/miniconda3/bin/conda-env update -f ../sep_enviroment.yaml 9 | source ~/miniconda3/bin/activate se_probes 10 | 11 | datasets=("squad" "nq" "trivia_qa" "bioasq") 12 | 13 | # Short-form generation. Run the scripts with specified parameters.
14 | for dataset in "${datasets[@]}"; do 15 | srun python ../semantic_uncertainty/generate_answers.py \ 16 | --model_name=Llama-2-7b-chat \ 17 | --dataset=$dataset \ 18 | --num_samples=2000 \ 19 | --random_seed=20 \ 20 | --no-compute_p_ik \ 21 | --no-compute_p_ik_answerable 22 | # e.g. Mistral-7B-Instruct-v0.1, Llama-2-7b-chat, Phi-3-mini-128k-instruct, Meta-Llama-3-8B-Instruct, etc. 23 | done 24 | 25 | # Long-form generation. Run the scripts with specified parameters. 26 | for dataset in "${datasets[@]}"; do 27 | srun python ../semantic_uncertainty/generate_answers.py \ 28 | --model_name=Llama-2-70b-chat \ 29 | --dataset=$dataset \ 30 | --num_samples=1000 \ 31 | --random_seed=20 \ 32 | --no-compute_p_ik \ 33 | --no-compute_p_ik_answerable \ 34 | --p_true_num_fewshot=10 \ 35 | --num_generations=10 \ 36 | --num_few_shot=0 \ 37 | --model_max_new_tokens=100 \ 38 | --brief_prompt=chat \ 39 | --metric=llm_gpt-4 \ 40 | --entailment_model=llm_gpt-3.5 41 | # e.g. Meta-Llama-3-70B-Instruct, Llama-2-70b-chat, etc. 42 | done 43 | 44 | -------------------------------------------------------------------------------- /semantic_uncertainty/uncertainty/utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | """Functions for performance evaluation, mainly used in analyze_results.py.""" 2 | import numpy as np 3 | import scipy 4 | from sklearn import metrics 5 | 6 | 7 | # pylint: disable=missing-function-docstring 8 | 9 | 10 | def bootstrap(function, rng, n_resamples=1000): 11 | def inner(data): 12 | bs = scipy.stats.bootstrap( 13 | (data, ), function, n_resamples=n_resamples, confidence_level=0.9, 14 | random_state=rng) 15 | return { 16 | 'std_err': bs.standard_error, 17 | 'low': bs.confidence_interval.low, 18 | 'high': bs.confidence_interval.high 19 | } 20 | return inner 21 | 22 | 23 | def auroc(y_true, y_score): 24 | fpr, tpr, thresholds = metrics.roc_curve(y_true, y_score) 25 | del thresholds 26 | return metrics.auc(fpr, tpr) 27 | 28 | 29 | def accuracy_at_quantile(accuracies, uncertainties, quantile): 30 | cutoff = np.quantile(uncertainties, quantile) 31 | select = uncertainties <= cutoff 32 | return np.mean(accuracies[select]) 33 | 34 | 35 | def area_under_thresholded_accuracy(accuracies, uncertainties): 36 | quantiles = np.linspace(0.1, 1, 20) 37 | select_accuracies = np.array([accuracy_at_quantile(accuracies, uncertainties, q) for q in quantiles]) 38 | dx = quantiles[1] - quantiles[0] 39 | area = (select_accuracies * dx).sum() 40 | return area 41 | 42 | 43 | # Need wrappers because scipy expects 1D data. 
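# Usage sketch (illustrative, not part of the original module): the wrappers below
# pack paired (y_true, y_score) arrays into a single 1-D sequence of dicts so that
# `scipy.stats.bootstrap` can resample the pairs jointly. For example, assuming
# equal-length 1-D numpy arrays `y_true` and `y_score`:
#
#     rng = np.random.default_rng(41)
#     auroc_ci = compatible_bootstrap(auroc, rng)(y_true, y_score)
#     # -> {'std_err': ..., 'low': ..., 'high': ...}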
44 | def compatible_bootstrap(func, rng): 45 | def helper(y_true_y_score): 46 | # this function is called in the bootstrap 47 | y_true = np.array([i['y_true'] for i in y_true_y_score]) 48 | y_score = np.array([i['y_score'] for i in y_true_y_score]) 49 | out = func(y_true, y_score) 50 | return out 51 | 52 | def wrap_inputs(y_true, y_score): 53 | return [{'y_true': i, 'y_score': j} for i, j in zip(y_true, y_score)] 54 | 55 | def converted_func(y_true, y_score): 56 | y_true_y_score = wrap_inputs(y_true, y_score) 57 | return bootstrap(helper, rng=rng)(y_true_y_score) 58 | return converted_func 59 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | figures/ 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | 132 | 133 | # Slurm 134 | *.out 135 | 136 | # VSCODE 137 | .vscode 138 | 139 | # WanDB 140 | */wandb/ 141 | 142 | .env 143 | -------------------------------------------------------------------------------- /semantic_entropy_probes/README.md: -------------------------------------------------------------------------------- 1 | # Probe Semantic Entropy in Latent Space 2 | 3 | ## Overview 4 | 5 | We present our probe training as a [notebook](./train-latent-probe.ipynb) that executes while logging and producing visualizations (e.g. loss curves, performance comparisons) to aid understanding. 6 | 7 | ## Tutorial 8 | 9 | The approach retrieves the model hidden states at two token positions (TBG, SLT), on which we train linear probes to predict the model's semantic uncertainty or correctness. 10 | 11 | We save model hidden states during SE generation runs (as in the [model implementation](../semantic_uncertainty/uncertainty/models/huggingface_models.py)). If you have completed SE runs using `wandb`, the model hidden states (in `validation_generations.pkl`) and uncertainty measures such as `p_true`, token `log likelihoods`, and `semantic entropy` should already be in place. These are the only prerequisites for running the training notebook. 12 | 13 | We also support saving probes (essentially trained logistic regression models) as pickle files to the `models` folder (created upon running). You can run inference with a saved probe with only a minor adaptation of the notebook: apply the probe (SEP or Acc. Pr.) to the concatenated hidden states at the relevant token positions (e.g. SLT or TBG), and it will output labels (or logits) predicting how semantically certain the model is and how likely it is to produce a faithful answer (see the sketch below). 14 | 15 | For tutorial purposes, we provide [example runs](https://wandb.ai/jiatongg/public_semantic_uncertainty) for the Llama-2-7B model (short-form generations), matching the setup in our paper. 16 | 17 | Please refer to [our paper](https://arxiv.org/abs/2406.15927) for terminology and other technical details. 18 | 19 | ## Notebook Structure 20 | 21 | The notebook is arranged into the following sections: 22 | 23 | * `Imports and Downloads` loads wandb runs into local storage; 24 | * `Data Preparation` prepares the training data, encapsulates the training and evaluation code, and contains some visualization tools; 25 | * `Probing Acc/SE from Hidden States (IID)` binarizes SE and trains SEPs and Acc. Pr. in the in-distribution setting, where we train and test on different splits of the same dataset; 26 | * `Test probes trained with one dataset on others` tests how well SEPs and Acc. Pr. predict model correctness on other datasets; 27 | * The remaining sections cover performance comparisons with baselines and model saving.
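The snippet below is a minimal sketch of such probe inference. It assumes the saved pickle holds a fitted scikit-learn `LogisticRegression` (as trained in the notebook) and that you have already extracted hidden states for the chosen token position; the file names are placeholders, so adapt them to your own run.

```
import pickle

import numpy as np

# Load a probe saved by the notebook (example path; adjust to your run).
with open("models/Llama2-7b_inference.pkl", "rb") as f:
    probe = pickle.load(f)  # assumed: a fitted sklearn LogisticRegression

# Hidden states of shape (n_examples, hidden_dim), e.g. extracted at the SLT or
# TBG token position from your own generations (placeholder file name).
hidden_states = np.load("hidden_states_slt.npy")

labels = probe.predict(hidden_states)              # hard labels: semantically certain vs. uncertain
scores = probe.predict_proba(hidden_states)[:, 1]  # soft scores, e.g. for ranking or AUROC
```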
28 | 29 | ## Citation 30 | 31 | ``` 32 | @misc{kossen2024semanticentropyprobesrobust, 33 | title={Semantic Entropy Probes: Robust and Cheap Hallucination Detection in LLMs}, 34 | author={Jannik Kossen and Jiatong Han and Muhammed Razzak and Lisa Schut and Shreshth Malik and Yarin Gal}, 35 | year={2024}, 36 | eprint={2406.15927}, 37 | archivePrefix={arXiv}, 38 | primaryClass={cs.CL}, 39 | url={https://arxiv.org/abs/2406.15927}, 40 | } 41 | ``` 42 | 43 | 44 | -------------------------------------------------------------------------------- /semantic_uncertainty/uncertainty/uncertainty_measures/p_ik.py: -------------------------------------------------------------------------------- 1 | """Predict model correctness from linear classifier.""" 2 | import os 3 | import logging 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import torch 8 | import wandb 9 | 10 | from sklearn.linear_model import LogisticRegression 11 | from sklearn.metrics import accuracy_score 12 | from sklearn.metrics import roc_auc_score 13 | from sklearn.model_selection import train_test_split 14 | 15 | 16 | def get_p_ik(train_embeddings, is_false, eval_embeddings=None, eval_is_false=None): 17 | """Fit linear classifier to embeddings to predict model correctness.""" 18 | 19 | logging.info('Accuracy of model on Task: %f.', 1 - torch.tensor(is_false).mean()) # pylint: disable=no-member 20 | 21 | # Convert the list of tensors to a 2D tensor. 22 | train_embeddings_tensor = torch.cat(train_embeddings, dim=0) # pylint: disable=no-member 23 | # Convert the tensor to a numpy array. 24 | embeddings_array = train_embeddings_tensor.cpu().numpy() 25 | 26 | # Split the data into training and test sets. 27 | X_train, X_test, y_train, y_test = train_test_split( # pylint: disable=invalid-name 28 | embeddings_array, is_false, test_size=0.2, random_state=42) # pylint: disable=invalid-name 29 | 30 | # Fit a logistic regression model. 31 | model = LogisticRegression() 32 | model.fit(X_train, y_train) 33 | 34 | # Predict deterministically and probabilistically and compute accuracy and auroc for all splits. 35 | X_eval = torch.cat(eval_embeddings, dim=0).cpu().numpy() # pylint: disable=no-member,invalid-name 36 | y_eval = eval_is_false 37 | 38 | Xs = [X_train, X_test, X_eval] # pylint: disable=invalid-name 39 | ys = [y_train, y_test, y_eval] # pylint: disable=invalid-name 40 | suffixes = ['train_train', 'train_test', 'eval'] 41 | 42 | metrics, y_preds_proba = {}, {} 43 | fig, axes = plt.subplots(1, 3, figsize=(20, 8)) 44 | fig.suptitle('Log Probabilities Predicted by p_ik when the true label is [False/True]') 45 | 46 | for ax, suffix, X, y_true in zip(axes, suffixes, Xs, ys): # pylint: disable=invalid-name 47 | 48 | # If suffix is eval, we fit a new model on the entire training data set rather than just a split of the 49 | # training data set. 50 | if suffix == 'eval': 51 | model = LogisticRegression() 52 | model.fit(embeddings_array, is_false) 53 | convergence = {'n_iter': model.n_iter_[0], 'converged': (model.n_iter_ < model.max_iter)[0]} 54 | 55 | y_pred = model.predict(X) 56 | y_pred_proba = model.predict_proba(X) 57 | y_preds_proba[suffix] = y_pred_proba 58 | acc_p_ik_train = accuracy_score(y_true, y_pred) 59 | auroc_p_ik_train = roc_auc_score(y_true, y_pred_proba[:, 1]) 60 | split_metrics = { 61 | f'acc_p_ik_{suffix}': acc_p_ik_train, 62 | f'auroc_p_ik_{suffix}': auroc_p_ik_train} 63 | metrics.update(split_metrics) 64 | 65 | # Plotting. 
66 | probabilities_of_false_points = y_pred_proba[:, 1][np.array(y_true) == 1.0] 67 | probabilities_of_true_points = y_pred_proba[:, 1][np.array(y_true) == 0.0] 68 | ax.hist(probabilities_of_false_points, bins=20, alpha=0.5, label='False') 69 | ax.hist(probabilities_of_true_points, bins=20, alpha=0.5, label='True') 70 | ax.legend(loc='upper right', title='True Label') 71 | fmt = {k: f"{v:.2f}" for k, v in split_metrics.items()} 72 | ax.set_title(f'Set: {suffix} \n {fmt}') 73 | 74 | # Plotting. 75 | axes[0].set_ylabel('Counts') 76 | axes[1].set_xlabel('Predicted Probabilities') 77 | os.system('mkdir -p figures') 78 | plt.savefig('figures/p_ik.png') # Can be viewed in vscode w/o plugins. 79 | plt.savefig('figures/p_ik.pdf') # Vector graphics are nice. 80 | 81 | logging.info('Metrics for p_ik classifier: %s.', metrics) 82 | wandb.log({**metrics, **convergence}) 83 | 84 | # Return model predictions on the eval set. 85 | return y_preds_proba['eval'][:, 1] 86 | -------------------------------------------------------------------------------- /semantic_uncertainty/uncertainty/uncertainty_measures/p_true.py: -------------------------------------------------------------------------------- 1 | """Compute p_true uncertainty metric.""" 2 | import logging 3 | from evaluate import load 4 | 5 | 6 | squad_metric = load("squad_v2") 7 | 8 | 9 | def construct_few_shot_prompt( 10 | *, model, dataset, indices, prompt, brief, brief_always, make_prompt, 11 | num_generations, metric): 12 | """Construct few shot prompt for p_true uncertainty metric.""" 13 | 14 | # Call model n_shots many times 15 | few_shot_prompt = [] 16 | all_responses = dict() 17 | for it, i in enumerate(indices): 18 | prompt_candidate = [] 19 | example = dataset[i] 20 | question = example["question"] 21 | context = example["context"] 22 | if it != 0: 23 | prompt_candidate += ['\n'] 24 | prompt_candidate += ['Question: ' + question] 25 | prompt_candidate += ['\nBrainstormed Answers: '] 26 | current_question = make_prompt(context, question, None, brief, brief_always) 27 | local_prompt = prompt + current_question 28 | logging.info('P_TRUE >> Current Question: '.ljust(25) + current_question) 29 | 30 | responses = [] 31 | for j in range(num_generations + 1): 32 | 33 | if j == 0: 34 | temperature = 0.1 35 | else: 36 | temperature = 1.0 37 | 38 | response, _, _ = model.predict(local_prompt, temperature) 39 | logging.info('P_TRUE >> Current Response: '.ljust(25) + response) 40 | 41 | responses.append(response) 42 | prompt_candidate += [f'{response.strip()} \n'] 43 | if j == 0: 44 | # Save most likely response and compute correctness metric for it. 
45 | most_likely_response = response 46 | is_correct = metric(response, example, model) 47 | answers = [answer for answer in example['answers']['text']] 48 | logging.info('P_TRUE >> LOW-T >> true answer: '.ljust(35) + str(answers)) 49 | logging.info('P_TRUE >> LOW-T >> acc: '.ljust(35) + str(is_correct)) 50 | 51 | all_responses[i] = dict( 52 | responses=responses, most_likely_response=most_likely_response, 53 | is_correct=is_correct) 54 | 55 | prompt_candidate += ['Possible answer: ' + most_likely_response + '\n'] 56 | prompt_candidate += ['Is the possible answer:\n'] 57 | prompt_candidate += ['A) True\n'] 58 | prompt_candidate += ['B) False\n'] 59 | prompt_candidate += ['The possible answer is:'] 60 | prompt_candidate += [' A' if is_correct else ' B'] 61 | 62 | prompt_len = len(model.tokenizer.encode(''.join(few_shot_prompt + prompt_candidate))) 63 | # At test time, get a maximum of `num_generations * model.token_limit` extra tokens 64 | # 200 buffer for question and 'Possible Answer'. 65 | max_input_len = prompt_len + num_generations * model.max_new_tokens + 200 66 | 67 | if max_input_len < model.token_limit: 68 | few_shot_prompt.extend(prompt_candidate) 69 | else: 70 | logging.warning('Cutting of p_true prompt at length %d.', it) 71 | break 72 | 73 | return ''.join(few_shot_prompt), all_responses, it 74 | 75 | 76 | def calculate_p_true( 77 | model, question, most_probable_answer, brainstormed_answers, 78 | few_shot_prompt, hint=False): 79 | """Calculate p_true uncertainty metric.""" 80 | 81 | if few_shot_prompt: 82 | prompt = few_shot_prompt + '\n' 83 | else: 84 | prompt = '' 85 | 86 | prompt += 'Question: ' + question 87 | prompt += '\nBrainstormed Answers: ' 88 | for answer in brainstormed_answers + [most_probable_answer]: 89 | prompt += answer.strip() + '\n' 90 | prompt += 'Possible answer: ' + most_probable_answer + '\n' 91 | if not hint: 92 | prompt += 'Is the possible answer:\n' 93 | prompt += 'A) True\n' 94 | prompt += 'B) False\n' 95 | prompt += 'The possible answer is:' 96 | else: 97 | prompt += 'Do the brainstormed answers match the possible answer? Respond with A if they do, if they do not respond with B. Answer:' 98 | 99 | log_prob = model.get_p_true(prompt) 100 | 101 | return log_prob 102 | -------------------------------------------------------------------------------- /semantic_uncertainty/uncertainty/data/data_utils.py: -------------------------------------------------------------------------------- 1 | """Data Loading Utilities.""" 2 | import logging 3 | import os 4 | import json 5 | import hashlib 6 | import datasets 7 | 8 | 9 | def load_ds(dataset_name, seed, add_options=None): 10 | """Load dataset.""" 11 | train_dataset, validation_dataset = None, None 12 | if dataset_name == "squad": 13 | dataset = datasets.load_dataset("squad_v2") 14 | train_dataset = dataset["train"] 15 | validation_dataset = dataset["validation"] 16 | 17 | elif dataset_name == 'svamp': 18 | dataset = datasets.load_dataset('ChilleD/SVAMP') 19 | 20 | train_dataset = dataset["train"] 21 | validation_dataset = dataset["test"] 22 | 23 | reformat = lambda x: { 24 | 'question': x['Question'], 'context': x['Body'], 'type': x['Type'], 25 | 'equation': x['Equation'], 'id': x['ID'], 26 | 'answers': {'text': [str(x['Answer'])]}} 27 | 28 | train_dataset = [reformat(d) for d in train_dataset] 29 | _validation_dataset = [reformat(d) for d in validation_dataset] 30 | # For semantic entropy generation: merge training with test set for more samples. 
31 | validation_dataset = _validation_dataset + train_dataset 32 | 33 | elif dataset_name == 'nq': 34 | dataset = datasets.load_dataset("nq_open") 35 | train_dataset = dataset["train"] 36 | validation_dataset = dataset["validation"] 37 | md5hash = lambda s: str(int(hashlib.md5(s.encode('utf-8')).hexdigest(), 16)) 38 | 39 | reformat = lambda x: { 40 | 'question': x['question']+'?', 41 | 'answers': {'text': x['answer']}, 42 | 'context': '', 43 | 'id': md5hash(str(x['question'])), 44 | } 45 | 46 | train_dataset = [reformat(d) for d in train_dataset] 47 | validation_dataset = [reformat(d) for d in validation_dataset] 48 | 49 | elif dataset_name == "trivia_qa": 50 | dataset = datasets.load_dataset('TimoImhof/TriviaQA-in-SQuAD-format')['unmodified'] 51 | 52 | dataset = dataset.train_test_split(test_size=0.2, seed=seed) 53 | train_dataset = dataset['train'] 54 | validation_dataset = dataset['test'] 55 | 56 | elif dataset_name == "med_qa": 57 | dataset = datasets.load_dataset("bigbio/med_qa") 58 | logging.info('Dataset: %s', dataset) 59 | for key in 'train', 'validation': 60 | ids = ['train' + str(i) for i in range(len(dataset[key]))] 61 | dataset[key] = dataset[key].add_column("id", ids) 62 | 63 | new_column = [None] * len(dataset[key]) 64 | dataset[key] = dataset[key].add_column("context", new_column) 65 | 66 | answers = [ 67 | {'text': [answer], 'answer_start': [0]} 68 | for answer in dataset[key][:]['answer'] 69 | ] 70 | dataset[key] = dataset[key].add_column("answers", answers) 71 | 72 | if add_options: 73 | options = dataset[key][:]['options'] 74 | options_string = [ 75 | [option['value'] + '\n' for option in option_list] 76 | for option_list in options 77 | ] 78 | questions = dataset[key][:]['question'] 79 | # zip questions and options 80 | questions_options = [ 81 | question + '\n' + ''.join(option_list) 82 | for question, option_list in zip(questions, options_string) 83 | ] 84 | 85 | dataset[key] = dataset[key].remove_columns(['question']) 86 | dataset[key] = dataset[key].add_column( 87 | "question", questions_options) 88 | 89 | train_dataset = dataset["train"] 90 | validation_dataset = dataset["validation"] 91 | 92 | elif dataset_name == "bioasq": 93 | # http://participants-area.bioasq.org/datasets/ we are using training 11b 94 | # could also download from here https://zenodo.org/records/7655130 95 | # scratch_dir = os.getenv('SCRATCH_DIR', '.') 96 | path = "~/uncertainty/data/bioasq/training11b.json" 97 | with open(path, "rb") as file: 98 | data = json.load(file) 99 | 100 | questions = data["questions"] 101 | dataset_dict = { 102 | "question": [], 103 | "answers": [], 104 | "id": [] 105 | } 106 | 107 | for question in questions: 108 | if "exact_answer" not in question: 109 | continue 110 | dataset_dict["question"].append(question["body"]) 111 | if "exact_answer" in question: 112 | 113 | if isinstance(question['exact_answer'], list): 114 | exact_answers = [ 115 | ans[0] if isinstance(ans, list) else ans 116 | for ans in question['exact_answer'] 117 | ] 118 | else: 119 | exact_answers = [question['exact_answer']] 120 | 121 | dataset_dict["answers"].append({ 122 | "text": exact_answers, 123 | "answer_start": [0] * len(question["exact_answer"]) 124 | }) 125 | else: 126 | dataset_dict["answers"].append({ 127 | "text": question["ideal_answer"], 128 | "answer_start": [0] 129 | }) 130 | dataset_dict["id"].append(question["id"]) 131 | 132 | dataset_dict["context"] = [None] * len(dataset_dict["id"]) 133 | 134 | dataset = datasets.Dataset.from_dict(dataset_dict) 135 | 136 | # split into training 
and validation set 137 | dataset = dataset.train_test_split(test_size=0.8, seed=seed) 138 | train_dataset = dataset['train'] 139 | validation_dataset = dataset['test'] 140 | 141 | elif dataset_name == "record": 142 | # Load the JSON file 143 | for split in ["train", "dev"]: 144 | dataset_dictionary = { 145 | "id": [], "question": [], "context": [], "answers": []} 146 | path = f"~/uncertainty/data/record/{split}.json" 147 | with open(path, "rb") as file: 148 | data = json.load(file) 149 | 150 | # Extract the relevant information and create a dictionary 151 | for item in data["data"]: 152 | for qa in item["qas"]: # pylint: disable=invalid-name 153 | dataset_dictionary["id"].append(qa["id"]) 154 | dataset_dictionary["question"].append(qa["query"]) 155 | dataset_dictionary["context"].append( 156 | item["passage"]["text"]) 157 | list_of_answer_strings = [] 158 | list_of_answer_starts = [] 159 | for answer in qa["answers"]: 160 | list_of_answer_strings.append(answer["text"]) 161 | list_of_answer_starts.append(answer["start"]) 162 | dataset_dictionary["answers"].append({ 163 | "text": list_of_answer_strings, 164 | "answer_start": list_of_answer_starts}) 165 | 166 | # Create the Hugging Face dataset 167 | if split == "train": 168 | train_dataset = datasets.Dataset.from_dict(dataset_dictionary) 169 | logging.info('train_dataset[0]: %s', train_dataset[0]) 170 | else: 171 | validation_dataset = datasets.Dataset.from_dict(dataset_dictionary) 172 | logging.info('validation_dataset[0]: %s', validation_dataset[0]) 173 | 174 | return train_dataset, validation_dataset 175 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Semantic Entropy Probes: Robust and Cheap Hallucination Detection in LLMs 2 | 3 | Jannik Kossen*, Jiatong Han*, Muhammed Razzak*, Lisa Schut, Shreshth Malik, Yarin Gal 4 | 5 | | **[Abstract](#Abstract)** 6 | | **[Citation](#Citation)** 7 | | **[Requirements](#Requirements)** 8 | | **[Installation](#Installation)** 9 | | **[Tutorial](#Tutorial)** 10 | | **[Codebase](#Codebase)** 11 | 12 | [![arXiv](https://img.shields.io/badge/arXiv-2406.15927-b31b1b.svg)](https://arxiv.org/abs/2406.15927) 13 | [![Python 3.11](https://img.shields.io/badge/python-3.11-blue.svg)](https://www.python.org/downloads/release/python-3110/) 14 | [![PyTorch](https://img.shields.io/badge/PyTorch-2.1-red.svg)](https://pytorch.org/get-started/locally/) 15 | [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) 16 | [![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://GitHub.com/Naereen/StrapDown.js/graphs/commit-activity) 17 | 18 | ## Abstract 19 | We propose semantic entropy probes (SEPs), a cheap and reliable method for uncertainty quantification in Large Language Models (LLMs). Hallucinations, which are plausible-sounding but factually incorrect and arbitrary model generations, present a major challenge to the practical adoption of LLMs. Recent work by [Farquhar et al. (2024)](https://www.nature.com/articles/s41586-024-07421-0) proposes semantic entropy (SE), which can detect hallucinations by estimating uncertainty in the space semantic meaning for a set of model generations. However, the 5-to-10-fold increase in computation cost associated with SE computation hinders practical adoption. 
To address this, we propose SEPs, which directly approximate SE from the hidden states of a single generation. SEPs are simple to train and do not require sampling multiple model generations at test time, reducing the overhead of semantic uncertainty quantification to almost zero. We show that SEPs retain high performance for hallucination detection and generalize better to out-of-distribution data than previous probing methods that directly predict model accuracy. Our results across models and tasks suggest that model hidden states capture SE, and our ablation studies give further insights into the token positions and model layers for which this is the case. 20 | 21 | ## Citation 22 | ``` 23 | @misc{kossen2024semanticentropyprobesrobust, 24 | title={Semantic Entropy Probes: Robust and Cheap Hallucination Detection in LLMs}, 25 | author={Jannik Kossen and Jiatong Han and Muhammed Razzak and Lisa Schut and Shreshth Malik and Yarin Gal}, 26 | year={2024}, 27 | eprint={2406.15927}, 28 | archivePrefix={arXiv}, 29 | primaryClass={cs.CL}, 30 | url={https://arxiv.org/abs/2406.15927}, 31 | } 32 | ``` 33 | 34 | ## Requirements 35 | 36 | ### Hardware Dependencies 37 | 38 | To obtain the hidden states of the large language model, you need to run forward passes through the model on the relevant prompts/answers. Our code runs inference on GPUs at FP16 precision. 39 | 40 | Common memory requirements per model size: 41 | - 7B models: ~24 GB 42 | - 13B models: ~48 GB 43 | - 70B models: ~160 GB 44 | 45 | 46 | ### Software Dependencies 47 | 48 | Dependencies for this code include Python 3.11 and PyTorch 2.1. 49 | 50 | In `sep_enviroment.yaml`, we list the precise versions of all Python packages. 51 | 52 | ## Installation 53 | 54 | 55 | To install Python with all necessary dependencies, we recommend you use conda. 56 | 57 | We refer to [https://conda.io/](https://conda.io/) for an installation guide. 58 | 59 | After installing conda, you can set up and activate the conda environment by executing the following commands at the root folder of this repository: 60 | 61 | ``` 62 | conda-env update -f sep_enviroment.yaml 63 | conda activate se_probes 64 | ``` 65 | 66 | 67 | Our experiments rely on [Weights & Biases](https://wandb.ai/) to log results. You may need to log in with your wandb API key upon initial execution. 68 | 69 | Our experiments rely on Hugging Face for all LLM models and most of the datasets. Set the environment variable `HUGGING_FACE_HUB_TOKEN` to the token associated with your Hugging Face account. For Llama models, [apply for access](https://huggingface.co/meta-llama) to use the official repository of Meta's Llama-2 models. 70 | 71 | 72 | Our experiments with sentence-length generation use GPT models from the OpenAI API. 73 | Please set the environment variable `OPENAI_API_KEY` to your OpenAI API key in order to use these models. 74 | Costs for reproducing our results vary depending on experiment configuration, but, without any guarantee, should lie somewhere between 10 and 100 USD. 75 | 76 | 77 | For almost all tasks, the dataset is downloaded automatically from the Hugging Face Datasets library upon first execution. 78 | Only for `bioasq` does the data need to be [downloaded](http://participants-area.bioasq.org/datasets) manually.
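As an optional sanity check before launching runs, the short sketch below verifies that these environment variables are set and that the manually downloaded BioASQ file sits where the data loader expects it (the path mirrors the one hard-coded in `uncertainty/data/data_utils.py`); it is illustrative only and not part of the repository.

```
import os
from pathlib import Path

# Tokens required for Hugging Face models and, for sentence-length runs, the OpenAI API.
for var in ("HUGGING_FACE_HUB_TOKEN", "OPENAI_API_KEY"):
    if not os.environ.get(var):
        print(f"Warning: environment variable {var} is not set.")

# BioASQ must be downloaded manually; this path matches the one used in data_utils.load_ds.
bioasq_path = Path("~/uncertainty/data/bioasq/training11b.json").expanduser()
if not bioasq_path.exists():
    print(f"Note: {bioasq_path} not found; the 'bioasq' dataset will not load.")
```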
79 | 80 | 81 | ## Tutorial 82 | 83 | ### Generate Semantic Entropy Probes Dataset 84 | 85 | Execute 86 | 87 | ``` 88 | python generate_answers.py --model_name=Llama-2-7b-chat --dataset=trivia_qa 89 | ``` 90 | 91 | to reproduce results for short-phrase generation with Llama-2 Chat (7B) on the TriviaQA dataset. 92 | 93 | The expected runtime of this demo is 1 hour using an A100 GPU, 24 cores of an Intel(R) Xeon(R) Gold 6248R CPU @ 3.00GHz, and 192 GB of RAM. 94 | Runtime may be longer upon first execution, as models need to be downloaded first. 95 | 96 | Note down the wandb id assigned to your demo run. 97 | 98 | To obtain a barplot similar to those in the paper, open the notebook `semantic_entropy_probes/train-latent-probe.ipynb`, populate `wandb_id` with the id of your demo run, and execute all cells. 99 | 100 | ### Training Semantic Entropy Probes 101 | 102 | We retrieve saved model hidden states at two token positions (TBG, SLT), on which we train linear probes to predict the model's semantic uncertainty and, in turn, its correctness. 103 | 104 | See [this notebook](./semantic_entropy_probes/train-latent-probe.ipynb) for a step-by-step guide to training SEPs; it also contains handy tools for data loading, visualization, and computing baselines. 105 | 106 | ## Codebase 107 | ### Repository Structure 108 | 109 | * Code to generate the semantic entropy is contained in the `semantic_uncertainty` folder and is adapted from the [semantic uncertainty](https://github.com/jlko/semantic_uncertainty) repository. Within this, a standard SE generation run executes the following three scripts in order: 110 | 111 | 1. `generate_answers.py`: Sample responses (and their likelihoods/hidden states) from the models for the questions. 112 | 2. `compute_uncertainty_measures.py`: Compute uncertainty metrics given the responses. 113 | 3. `analyze_results.py`: Compute aggregate performance metrics. 114 | * Once these are computed, you can use the `train-latent-probe.ipynb` notebook, contained in the `semantic_entropy_probes` folder, to train your SEPs. 115 | 116 | ### Reproducing the Experiments 117 | 118 | To reproduce the experiments of the paper, you only need to run the above demo for the various combinations of models and datasets. 119 | 120 | The simplest way is to execute `slurm/run.sh` (which contains commands to generate both short-form and long-form answers for all datasets) if you are using `slurm`. 121 | 122 | Alternatively, you may execute the following command for each configuration 123 | 124 | ``` 125 | python generate_answers.py --model_name=$MODEL --dataset=$DATASET $EXTRA_CFG 126 | ``` 127 | 128 | where 129 | 130 | * `$MODEL` is one of: [`Llama-2-7b, Llama-2-13b, Llama-2-70b, Llama-2-7b-chat, Llama-2-13b-chat, Llama-2-70b-chat, falcon-7b, falcon-40b, falcon-7b-instruct, falcon-40b-instruct, Mistral-7B-v0.1, Mistral-7B-Instruct-v0.1, Phi-3-mini-128k-instruct`], 131 | * `$DATASET` is one of [`trivia_qa, squad, med_qa, bioasq, record, nq, svamp`], 132 | * and `$EXTRA_CFG` is empty for short-phrase generation; for sentence-length generation, use `EXTRA_CFG=--num_few_shot=0 --model_max_new_tokens=100 --brief_prompt=chat --metric=llm_gpt-4 --entailment_model=gpt-3.5 --no-compute_accuracy_at_all_temps`. A minimal driver loop over these combinations is sketched at the end of this README. 133 | 134 | The results for any run can be obtained by passing its `wandb_id` to an evaluation notebook identical to the demonstration in `semantic_entropy_probes/train-latent-probe.ipynb`.
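Below is a minimal, illustrative driver loop for these combinations. The script name and flags are taken from the commands above; the model and dataset lists shown are only example subsets, and the whole snippet is a sketch rather than part of the repository, so adapt paths and extra flags to the runs you want to reproduce.

```
import subprocess

# Example subsets only; see the full lists of supported $MODEL and $DATASET values above.
models = ["Llama-2-7b-chat", "Mistral-7B-Instruct-v0.1"]
datasets = ["trivia_qa", "squad"]
extra_cfg = []  # empty for short-phrase generation; add the sentence-length flags listed above if needed

for model in models:
    for dataset in datasets:
        cmd = [
            "python", "generate_answers.py",
            f"--model_name={model}",
            f"--dataset={dataset}",
            *extra_cfg,
        ]
        print("Running:", " ".join(cmd))
        subprocess.run(cmd, check=True)  # raise if a run fails, so errors are not silently skipped
```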
135 | -------------------------------------------------------------------------------- /semantic_uncertainty/analyze_results.py: -------------------------------------------------------------------------------- 1 | """Analyze uncertainty predictions.""" 2 | import argparse 3 | import functools 4 | import logging 5 | import os 6 | import pickle 7 | 8 | import numpy as np 9 | import wandb 10 | 11 | from uncertainty.utils import utils 12 | from uncertainty.utils.eval_utils import ( 13 | bootstrap, compatible_bootstrap, auroc, accuracy_at_quantile, 14 | area_under_thresholded_accuracy) 15 | 16 | 17 | utils.setup_logger() 18 | 19 | result_dict = {} 20 | 21 | UNC_MEAS = 'uncertainty_measures.pkl' 22 | 23 | 24 | def init_wandb(wandb_runid, assign_new_wandb_id, experiment_lot, entity): 25 | '''Initialize wandb session.''' 26 | user = os.environ['USER'] 27 | slurm_jobid = os.getenv('SLURM_JOB_ID') 28 | scratch_dir = os.getenv('SCRATCH_DIR', '.') 29 | kwargs = dict( 30 | entity=entity, 31 | project='semantic_uncertainty', 32 | dir=f'{scratch_dir}/{user}/uncertainty', 33 | notes=f'slurm_id: {slurm_jobid}, experiment_lot: {experiment_lot}', 34 | ) 35 | if not assign_new_wandb_id: 36 | # Restore wandb session. 37 | wandb.init( 38 | id=wandb_runid, 39 | resume=True, 40 | **kwargs) 41 | wandb.restore(UNC_MEAS) 42 | else: 43 | api = wandb.Api() 44 | wandb.init(**kwargs) 45 | 46 | old_run = api.run(f'{entity}/semantic_uncertainty/{wandb_runid}') 47 | old_run.file(UNC_MEAS).download( 48 | replace=True, exist_ok=False, root=wandb.run.dir) 49 | 50 | 51 | def analyze_run( 52 | wandb_runid, assign_new_wandb_id=False, answer_fractions_mode='default', 53 | experiment_lot=None, entity=None): 54 | 55 | '''Analyze the uncertainty measures for a given wandb run id.''' 56 | logging.info('Analyzing wandb_runid `%s`.', wandb_runid) 57 | 58 | # Set up evaluation metrics. 59 | if answer_fractions_mode == 'default': 60 | answer_fractions = [0.8, 0.9, 0.95, 1.0] 61 | elif answer_fractions_mode == 'finegrained': 62 | answer_fractions = [round(i, 3) for i in np.linspace(0, 1, 20+1)] 63 | else: 64 | raise ValueError 65 | 66 | rng = np.random.default_rng(41) 67 | eval_metrics = dict(zip( 68 | ['AUROC', 'area_under_thresholded_accuracy', 'mean_uncertainty'], 69 | list(zip( 70 | [auroc, area_under_thresholded_accuracy, np.mean], 71 | [compatible_bootstrap, compatible_bootstrap, bootstrap] 72 | )), 73 | )) 74 | for answer_fraction in answer_fractions: 75 | key = f'accuracy_at_{answer_fraction}_answer_fraction' 76 | eval_metrics[key] = [ 77 | functools.partial(accuracy_at_quantile, quantile=answer_fraction), 78 | compatible_bootstrap] 79 | 80 | if wandb.run is None: 81 | init_wandb( 82 | wandb_runid, assign_new_wandb_id=assign_new_wandb_id, 83 | experiment_lot=experiment_lot, entity=entity) 84 | 85 | elif wandb.run.id != wandb_runid: 86 | raise 87 | 88 | # Load the results dictionary from a pickle file. 89 | with open(f'{wandb.run.dir}/{UNC_MEAS}', 'rb') as file: 90 | results_old = pickle.load(file) 91 | 92 | result_dict = {'performance': {}, 'uncertainty': {}} 93 | 94 | # First: Compute Simple Accuracy metrics of the model predictions. 
95 | all_accuracies = dict() 96 | if 'alt_validation_is_false' in results_old: 97 | all_accuracies.update({name: 1 - np.array(data) for name, data in results_old['alt_validation_is_false'].items()}) 98 | all_accuracies['accuracy'] = 1 - np.array(results_old['validation_is_false']) 99 | 100 | for name, target in all_accuracies.items(): 101 | result_dict['performance'][name] = {} 102 | result_dict['performance'][name]['mean'] = np.mean(target) 103 | result_dict['performance'][name]['bootstrap'] = bootstrap(np.mean, rng)(target) 104 | 105 | rum = results_old['uncertainty_measures'] 106 | if 'p_false' in rum and 'p_false_fixed' not in rum: 107 | # Restore log probs true: y = 1 - x --> x = 1 - y. 108 | # Convert to probs --> np.exp(1 - y). 109 | # Convert to p_false --> 1 - np.exp(1 - y). 110 | rum['p_false_fixed'] = [1 - np.exp(1 - x) for x in rum['p_false']] 111 | 112 | # Next: Uncertainty Measures. 113 | # Iterate through the dictionary and compute additional metrics for each measure. 114 | for measure_name, measure_values in rum.items(): 115 | logging.info('Computing for uncertainty measure `%s`.', measure_name) 116 | 117 | # Validation accuracy. 118 | validation_is_falses = [ 119 | results_old['validation_is_false'], 120 | results_old['validation_unanswerable'] 121 | ] 122 | 123 | logging_names = ['', '_UNANSWERABLE'] 124 | 125 | # Check if we have additional predictions for this measure. 126 | if 'alt_validation_is_false' in results_old: 127 | if measure_name in (u_m := results_old['alt_validation_is_false']): 128 | validation_is_falses.append(u_m[measure_name]) 129 | logging_names.append(f'_max_from_{measure_name}') 130 | 131 | # Iterate over predictions of 'falseness' or 'answerability'. 132 | for validation_is_false, logging_name in zip(validation_is_falses, logging_names): 133 | name = measure_name + logging_name 134 | result_dict['uncertainty'][name] = {} 135 | 136 | validation_is_false = np.array(validation_is_false) 137 | validation_accuracy = 1 - validation_is_false 138 | if len(measure_values) > len(validation_is_false): 139 | # This can happen, but only for p_false. 140 | if 'p_false' not in measure_name: 141 | raise ValueError 142 | logging.warning( 143 | 'More measure values for %s than in validation_is_false. Len(measure values): %d, Len(validation_is_false): %d', 144 | measure_name, len(measure_values), len(validation_is_false)) 145 | measure_values = measure_values[:len(validation_is_false)] 146 | 147 | fargs = { 148 | 'AUROC': [validation_is_false, measure_values], 149 | 'area_under_thresholded_accuracy': [validation_accuracy, measure_values], 150 | 'mean_uncertainty': [measure_values]} 151 | 152 | for answer_fraction in answer_fractions: 153 | fargs[f'accuracy_at_{answer_fraction}_answer_fraction'] = [validation_accuracy, measure_values] 154 | 155 | for fname, (function, bs_function) in eval_metrics.items(): 156 | metric_i = function(*fargs[fname]) 157 | result_dict['uncertainty'][name][fname] = {} 158 | result_dict['uncertainty'][name][fname]['mean'] = metric_i 159 | logging.info("%s for measure name `%s`: %f", fname, name, metric_i) 160 | result_dict['uncertainty'][name][fname]['bootstrap'] = bs_function( 161 | function, rng)(*fargs[fname]) 162 | 163 | wandb.log(result_dict) 164 | logging.info( 165 | 'Analysis for wandb_runid `%s` finished. 
Full results dict: %s', 166 | wandb_runid, result_dict 167 | ) 168 | 169 | 170 | if __name__ == '__main__': 171 | parser = argparse.ArgumentParser() 172 | parser.add_argument('--wandb_runids', nargs='+', type=str, 173 | help='Wandb run ids of the datasets to evaluate on.') 174 | parser.add_argument('--assign_new_wandb_id', default=True, 175 | action=argparse.BooleanOptionalAction) 176 | parser.add_argument('--answer_fractions_mode', type=str, default='default') 177 | parser.add_argument( 178 | "--experiment_lot", type=str, default='Unnamed Experiment', 179 | help="Keep default wandb clean.") 180 | parser.add_argument( 181 | "--entity", type=str, help="Wandb entity.") 182 | 183 | args, unknown = parser.parse_known_args() 184 | if unknown: 185 | raise ValueError(f'Unkown args: {unknown}') 186 | 187 | wandb_runids = args.wandb_runids 188 | for wid in wandb_runids: 189 | logging.info('Evaluating wandb_runid `%s`.', wid) 190 | analyze_run( 191 | wid, args.assign_new_wandb_id, args.answer_fractions_mode, 192 | experiment_lot=args.experiment_lot, entity=args.entity) 193 | -------------------------------------------------------------------------------- /sep_enviroment.yaml: -------------------------------------------------------------------------------- 1 | name: se_probes 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=5.1=1_gnu 9 | - abseil-cpp=20211102.0=hd4dd3e8_0 10 | - aiohttp=3.8.5=py311h5eee18b_0 11 | - aiosignal=1.2.0=pyhd3eb1b0_0 12 | - arrow=1.2.3=py311h06a4308_1 13 | - arrow-cpp=11.0.0=h374c478_2 14 | - async-timeout=4.0.2=py311h06a4308_0 15 | - attrs=23.1.0=py311h06a4308_0 16 | - aws-c-common=0.6.8=h5eee18b_1 17 | - aws-c-event-stream=0.1.6=h6a678d5_6 18 | - aws-checksums=0.1.11=h5eee18b_2 19 | - aws-sdk-cpp=1.8.185=h721c034_1 20 | - binaryornot=0.4.4=pyhd3eb1b0_1 21 | - blas=1.0=mkl 22 | - boost-cpp=1.82.0=hdb19cb5_2 23 | - bottleneck=1.3.5=py311hbed6279_0 24 | - brotlipy=0.7.0=py311h5eee18b_1002 25 | - bzip2=1.0.8=h7b6447c_0 26 | - c-ares=1.19.1=h5eee18b_0 27 | - ca-certificates=2023.08.22=h06a4308_0 28 | - certifi=2023.11.17=py311h06a4308_0 29 | - cffi=1.15.1=py311h5eee18b_3 30 | - chardet=4.0.0=py311h06a4308_1003 31 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 32 | - click=8.0.4=py311h06a4308_0 33 | - cookiecutter=1.7.3=pyhd3eb1b0_0 34 | - cryptography=41.0.3=py311hdda0065_0 35 | - cuda-cudart=11.8.89=0 36 | - cuda-cupti=11.8.87=0 37 | - cuda-libraries=11.8.0=0 38 | - cuda-nvrtc=11.8.89=0 39 | - cuda-nvtx=11.8.86=0 40 | - cuda-runtime=11.8.0=0 41 | - datasets=2.12.0=py311h06a4308_0 42 | - dill=0.3.6=py311h06a4308_0 43 | - evaluate=0.4.0=py311h06a4308_0 44 | - ffmpeg=4.3=hf484d3e_0 45 | - filelock=3.9.0=py311h06a4308_0 46 | - freetype=2.12.1=h4a9f257_0 47 | - frozenlist=1.3.3=py311h5eee18b_0 48 | - fsspec=2023.9.2=py311h06a4308_0 49 | - gflags=2.2.2=he6710b0_0 50 | - giflib=5.2.1=h5eee18b_3 51 | - glog=0.5.0=h2531618_0 52 | - gmp=6.2.1=h295c915_3 53 | - gmpy2=2.1.2=py311hc9b5ff0_0 54 | - gnutls=3.6.15=he1e5248_0 55 | - grpc-cpp=1.48.2=he1ff14a_1 56 | - huggingface_hub=0.17.3=py311h06a4308_0 57 | - icu=73.1=h6a678d5_0 58 | - idna=3.4=py311h06a4308_0 59 | - intel-openmp=2023.1.0=hdb19cb5_46305 60 | - jinja2=3.1.2=py311h06a4308_0 61 | - jinja2-time=0.2.0=pyhd3eb1b0_3 62 | - jpeg=9e=h5eee18b_1 63 | - krb5=1.20.1=h143b758_1 64 | - lame=3.100=h7b6447c_0 65 | - lcms2=2.12=h3be6417_0 66 | - ld_impl_linux-64=2.38=h1181459_1 67 | - lerc=3.0=h295c915_0 68 | - libboost=1.82.0=h109eef0_2 69 | - 
libbrotlicommon=1.0.9=h5eee18b_7 70 | - libbrotlidec=1.0.9=h5eee18b_7 71 | - libbrotlienc=1.0.9=h5eee18b_7 72 | - libcublas=11.11.3.6=0 73 | - libcufft=10.9.0.58=0 74 | - libcufile=1.8.0.34=0 75 | - libcurand=10.3.4.52=0 76 | - libcurl=8.4.0=h251f7ec_0 77 | - libcusolver=11.4.1.48=0 78 | - libcusparse=11.7.5.86=0 79 | - libdeflate=1.17=h5eee18b_1 80 | - libedit=3.1.20221030=h5eee18b_0 81 | - libev=4.33=h7f8727e_1 82 | - libevent=2.1.12=hdbd6064_1 83 | - libffi=3.4.4=h6a678d5_0 84 | - libgcc-ng=11.2.0=h1234567_1 85 | - libgomp=11.2.0=h1234567_1 86 | - libiconv=1.16=h7f8727e_2 87 | - libidn2=2.3.4=h5eee18b_0 88 | - libjpeg-turbo=2.0.0=h9bf148f_0 89 | - libnghttp2=1.57.0=h2d74bed_0 90 | - libnpp=11.8.0.86=0 91 | - libnvjpeg=11.9.0.86=0 92 | - libpng=1.6.39=h5eee18b_0 93 | - libprotobuf=3.20.3=he621ea3_0 94 | - libssh2=1.10.0=hdbd6064_2 95 | - libstdcxx-ng=11.2.0=h1234567_1 96 | - libtasn1=4.19.0=h5eee18b_0 97 | - libthrift=0.15.0=h1795dd8_2 98 | - libtiff=4.5.1=h6a678d5_0 99 | - libunistring=0.9.10=h27cfd23_0 100 | - libuuid=1.41.5=h5eee18b_0 101 | - libwebp=1.3.2=h11a3e52_0 102 | - libwebp-base=1.3.2=h5eee18b_0 103 | - llvm-openmp=14.0.6=h9e868ea_0 104 | - lz4-c=1.9.4=h6a678d5_0 105 | - markupsafe=2.1.1=py311h5eee18b_0 106 | - mkl=2023.1.0=h213fc3f_46343 107 | - mkl-service=2.4.0=py311h5eee18b_1 108 | - mkl_fft=1.3.8=py311h5eee18b_0 109 | - mkl_random=1.2.4=py311hdb19cb5_0 110 | - mpc=1.1.0=h10f8cd9_1 111 | - mpfr=4.0.2=hb69a4c5_1 112 | - mpmath=1.3.0=py311h06a4308_0 113 | - multidict=6.0.2=py311h5eee18b_0 114 | - multiprocess=0.70.14=py311h06a4308_0 115 | - ncurses=6.4=h6a678d5_0 116 | - nettle=3.7.3=hbbd107a_1 117 | - networkx=3.1=py311h06a4308_0 118 | - numexpr=2.8.7=py311h65dcdc2_0 119 | - numpy=1.26.0=py311h08b1b3b_0 120 | - numpy-base=1.26.0=py311hf175353_0 121 | - openh264=2.1.1=h4ff587b_0 122 | - openjpeg=2.4.0=h3ad879b_0 123 | - openssl=3.0.12=h7f8727e_0 124 | - orc=1.7.4=hb3bc3d3_1 125 | - packaging=23.1=py311h06a4308_0 126 | - pillow=10.0.1=py311ha6cbd5a_0 127 | - pip=23.3.1=py311h06a4308_0 128 | - poyo=0.5.0=pyhd3eb1b0_0 129 | - pyarrow=11.0.0=py311hd8e8d9b_1 130 | - pycparser=2.21=pyhd3eb1b0_0 131 | - pyopenssl=23.2.0=py311h06a4308_0 132 | - pysocks=1.7.1=py311h06a4308_0 133 | - python=3.11.5=h955ad1f_0 134 | - python-dateutil=2.8.2=pyhd3eb1b0_0 135 | - python-slugify=5.0.2=pyhd3eb1b0_0 136 | - python-tzdata=2023.3=pyhd3eb1b0_0 137 | - python-xxhash=2.0.2=py311h5eee18b_1 138 | - pytorch=2.1.1=py3.11_cuda11.8_cudnn8.7.0_0 139 | - pytorch-cuda=11.8=h7e8668a_5 140 | - pytorch-mutex=1.0=cuda 141 | - pytz=2023.3.post1=py311h06a4308_0 142 | - pyyaml=6.0=py311h5eee18b_1 143 | - re2=2022.04.01=h295c915_0 144 | - readline=8.2=h5eee18b_0 145 | - requests=2.31.0=py311h06a4308_0 146 | - responses=0.13.3=pyhd3eb1b0_0 147 | - setuptools=68.0.0=py311h06a4308_0 148 | - six=1.16.0=pyhd3eb1b0_1 149 | - snappy=1.1.9=h295c915_0 150 | - sqlite=3.41.2=h5eee18b_0 151 | - sympy=1.11.1=py311h06a4308_0 152 | - tbb=2021.8.0=hdb19cb5_0 153 | - text-unidecode=1.3=pyhd3eb1b0_0 154 | - tk=8.6.12=h1ccaba5_0 155 | - torchaudio=2.1.1=py311_cu118 156 | - torchtriton=2.1.0=py311 157 | - torchvision=0.16.1=py311_cu118 158 | - typing-extensions=4.7.1=py311h06a4308_0 159 | - typing_extensions=4.7.1=py311h06a4308_0 160 | - tzdata=2023c=h04d1e81_0 161 | - unidecode=1.2.0=pyhd3eb1b0_0 162 | - urllib3=1.26.16=py311h06a4308_0 163 | - utf8proc=2.6.1=h27cfd23_0 164 | - wheel=0.41.2=py311h06a4308_0 165 | - xxhash=0.8.0=h7f8727e_3 166 | - xz=5.4.2=h5eee18b_0 167 | - yaml=0.2.5=h7b6447c_0 168 | - yarl=1.8.1=py311h5eee18b_0 169 
| - zlib=1.2.13=h5eee18b_0 170 | - zstd=1.5.5=hc292b87_0 171 | - pip: 172 | - absl-py==2.0.0 173 | - accelerate==0.25.0 174 | - annotated-types==0.6.0 175 | - antlr4-python3-runtime==4.9.3 176 | - anyio==3.7.1 177 | - appdirs==1.4.4 178 | - argon2-cffi==23.1.0 179 | - argon2-cffi-bindings==21.2.0 180 | - asttokens==2.4.0 181 | - async-lru==2.0.4 182 | - babel==2.13.0 183 | - backcall==0.2.0 184 | - beautifulsoup4==4.12.2 185 | - bitsandbytes==0.41.2.post2 186 | - bleach==6.1.0 187 | - comm==0.1.4 188 | - contextlib2==21.6.0 189 | - contourpy==1.1.1 190 | - cycler==0.12.1 191 | - debugpy==1.8.0 192 | - decorator==5.1.1 193 | - defusedxml==0.7.1 194 | - distro==1.8.0 195 | - docker-pycreds==0.4.0 196 | - einops==0.7.0 197 | - executing==2.0.0 198 | - fastjsonschema==2.18.1 199 | - flake8==6.1.0 200 | - fonttools==4.43.1 201 | - fqdn==1.5.1 202 | - gitdb==4.0.11 203 | - gitpython==3.1.40 204 | - h11==0.14.0 205 | - httpcore==1.0.1 206 | - httpx==0.25.1 207 | - ipykernel==6.25.2 208 | - ipython==8.16.1 209 | - ipywidgets==8.1.1 210 | - isoduration==20.11.0 211 | - jedi==0.19.1 212 | - joblib==1.3.2 213 | - json5==0.9.14 214 | - jsonpointer==2.4 215 | - jsonschema==4.19.1 216 | - jsonschema-specifications==2023.7.1 217 | - jupyter-client==8.4.0 218 | - jupyter-core==5.4.0 219 | - jupyter-events==0.8.0 220 | - jupyter-lsp==2.2.0 221 | - jupyter-server==2.8.0 222 | - jupyter-server-terminals==0.4.4 223 | - jupyterlab==4.0.9 224 | - jupyterlab-pygments==0.2.2 225 | - jupyterlab-server==2.25.0 226 | - jupyterlab-widgets==3.0.9 227 | - kiwisolver==1.4.5 228 | - lightning-utilities==0.9.0 229 | - matplotlib==3.8.2 230 | - matplotlib-inline==0.1.6 231 | - mccabe==0.7.0 232 | - mistune==3.0.2 233 | - ml-collections==0.1.1 234 | - nbclient==0.8.0 235 | - nbconvert==7.9.2 236 | - nbformat==5.9.2 237 | - nest-asyncio==1.5.8 238 | - nltk==3.8.1 239 | - notebook==7.0.6 240 | - notebook-shim==0.2.3 241 | - omegaconf==2.3.0 242 | - openai==1.3.7 243 | - overrides==7.4.0 244 | - pandas==2.1.3 245 | - pandocfilters==1.5.0 246 | - parso==0.8.3 247 | - pathtools==0.1.2 248 | - pexpect==4.8.0 249 | - pickleshare==0.7.5 250 | - platformdirs==3.11.0 251 | - prometheus-client==0.17.1 252 | - prompt-toolkit==3.0.39 253 | - protobuf==4.24.4 254 | - psutil==5.9.6 255 | - ptyprocess==0.7.0 256 | - pure-eval==0.2.2 257 | - pycodestyle==2.11.1 258 | - pydantic==2.4.2 259 | - pydantic-core==2.10.1 260 | - pyflakes==3.1.0 261 | - pygments==2.16.1 262 | - pyparsing==3.1.1 263 | - python-json-logger==2.0.7 264 | - pyzmq==25.1.1 265 | - referencing==0.30.2 266 | - regex==2023.10.3 267 | - rfc3339-validator==0.1.4 268 | - rfc3986-validator==0.1.1 269 | - rpds-py==0.10.6 270 | - safetensors==0.4.1 271 | - scikit-learn==1.3.2 272 | - scipy==1.11.4 273 | - seaborn==0.13.0 274 | - send2trash==1.8.2 275 | - sentencepiece==0.1.99 276 | - sentry-sdk==1.32.0 277 | - setproctitle==1.3.3 278 | - smmap==5.0.1 279 | - sniffio==1.3.0 280 | - soupsieve==2.5 281 | - stack-data==0.6.3 282 | - tenacity==8.2.3 283 | - terminado==0.17.1 284 | - threadpoolctl==3.2.0 285 | - tiktoken==0.5.2 286 | - tinycss2==1.2.1 287 | - tokenizers==0.15.0 288 | - torchmetrics==1.2.1 289 | - tornado==6.3.3 290 | - tqdm==4.66.1 291 | - traitlets==5.11.2 292 | - transformers==4.35.2 293 | - uri-template==1.3.0 294 | - wandb==0.16.0 295 | - wcwidth==0.2.8 296 | - webcolors==1.13 297 | - webencodings==0.5.1 298 | - websocket-client==1.6.4 299 | - widgetsnbextension==4.0.9 300 | -------------------------------------------------------------------------------- 
/semantic_uncertainty/uncertainty/uncertainty_measures/semantic_entropy.py: -------------------------------------------------------------------------------- 1 | """Implement semantic entropy.""" 2 | import os 3 | import pickle 4 | import logging 5 | 6 | import random 7 | import numpy as np 8 | import wandb 9 | import openai 10 | import torch 11 | import torch.nn.functional as F 12 | 13 | from transformers import AutoModelForSequenceClassification, AutoTokenizer 14 | 15 | from uncertainty.models.huggingface_models import HuggingfaceModel 16 | from uncertainty.utils import openai as oai 17 | from uncertainty.utils import utils 18 | 19 | 20 | random.seed(10) 21 | 22 | # Set up OpenAI API credentials 23 | openai.api_key = os.getenv("OPENAI_API_KEY") 24 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 25 | 26 | 27 | class BaseEntailment: 28 | 29 | def save_prediction_cache(self): 30 | pass 31 | 32 | 33 | class EntailmentDeberta(BaseEntailment): 34 | def __init__(self): 35 | self.tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v2-xlarge-mnli") 36 | self.model = AutoModelForSequenceClassification.from_pretrained( 37 | "microsoft/deberta-v2-xlarge-mnli").to(DEVICE) 38 | 39 | def check_implication(self, text1, text2, *args, **kwargs): 40 | inputs = self.tokenizer(text1, text2, return_tensors="pt").to(DEVICE) 41 | # The model checks if text1 -> text2, i.e. if text2 follows from text1. 42 | # check_implication('The weather is good', 'The weather is good and I like you') --> 1 43 | # check_implication('The weather is good and I like you', 'The weather is good') --> 2 44 | outputs = self.model(**inputs) 45 | logits = outputs.logits 46 | # Deberta-mnli returns `neutral` and `entailment` classes at indices 1 and 2. 47 | largest_index = torch.argmax(F.softmax(logits, dim=1)) # pylint: disable=no-member 48 | prediction = largest_index.cpu().item() 49 | if os.environ.get('DEBERTA_FULL_LOG', False): 50 | logging.info('Deberta Input: %s -> %s', text1, text2) 51 | logging.info('Deberta Prediction: %s', prediction) 52 | 53 | return prediction 54 | 55 | 56 | class EntailmentLLM(BaseEntailment): 57 | 58 | entailment_file = 'entailment_cache.pkl' 59 | 60 | def __init__(self, entailment_cache_id, entailment_cache_only): 61 | self.prediction_cache = self.init_prediction_cache(entailment_cache_id) 62 | self.entailment_cache_only = entailment_cache_only 63 | 64 | def init_prediction_cache(self, entailment_cache_id): 65 | if entailment_cache_id is None: 66 | return dict() 67 | 68 | logging.info('Restoring prediction cache from %s', entailment_cache_id) 69 | 70 | api = wandb.Api() 71 | run = api.run(entailment_cache_id) 72 | run.file(self.entailment_file).download( 73 | replace=True, exist_ok=False, root=wandb.run.dir) 74 | 75 | with open(f'{wandb.run.dir}/{self.entailment_file}', "rb") as infile: 76 | return pickle.load(infile) 77 | 78 | def save_prediction_cache(self): 79 | # write the dictionary to a pickle file 80 | utils.save(self.prediction_cache, self.entailment_file) 81 | 82 | def check_implication(self, text1, text2, example=None): 83 | if example is None: 84 | raise ValueError 85 | prompt = self.equivalence_prompt(text1, text2, example['question']) 86 | 87 | logging.info('%s input: %s', self.name, prompt) 88 | 89 | hashed = oai.md5hash(prompt) 90 | if hashed in self.prediction_cache: 91 | logging.info('Restoring hashed instead of predicting with model.') 92 | response = self.prediction_cache[hashed] 93 | else: 94 | if self.entailment_cache_only: 95 | raise ValueError 96 | response = 
self.predict(prompt, temperature=0.02) 97 | self.prediction_cache[hashed] = response 98 | 99 | logging.info('%s prediction: %s', self.name, response) 100 | 101 | binary_response = response.lower()[:30] 102 | if 'entailment' in binary_response: 103 | return 2 104 | elif 'neutral' in binary_response: 105 | return 1 106 | elif 'contradiction' in binary_response: 107 | return 0 108 | else: 109 | logging.warning('MANUAL NEUTRAL!') 110 | return 1 111 | 112 | 113 | class EntailmentGPT4(EntailmentLLM): 114 | 115 | def __init__(self, entailment_cache_id, entailment_cache_only): 116 | super().__init__(entailment_cache_id, entailment_cache_only) 117 | self.name = 'gpt-4' 118 | 119 | def equivalence_prompt(self, text1, text2, question): 120 | 121 | prompt = f"""We are evaluating answers to the question \"{question}\"\n""" 122 | 123 | # Ask for a directional entailment judgement between the two answers. 124 | prompt += "Here are two possible answers:\n" 125 | 126 | 127 | 128 | prompt += f"Possible Answer 1: {text1}\nPossible Answer 2: {text2}\n" 129 | prompt += "Does Possible Answer 1 semantically entail Possible Answer 2? Respond with entailment, contradiction, or neutral." 130 | 131 | return prompt 132 | 133 | def predict(self, prompt, temperature): 134 | return oai.predict(prompt, temperature, model=self.name) 135 | 136 | 137 | class EntailmentGPT35(EntailmentGPT4): 138 | 139 | def __init__(self, entailment_cache_id, entailment_cache_only): 140 | super().__init__(entailment_cache_id, entailment_cache_only) 141 | self.name = 'gpt-3.5' 142 | 143 | 144 | class EntailmentLlama(EntailmentLLM): 145 | 146 | def __init__(self, entailment_cache_id, entailment_cache_only, name): 147 | super().__init__(entailment_cache_id, entailment_cache_only) 148 | self.name = name 149 | self.model = HuggingfaceModel( 150 | name, stop_sequences='default', max_new_tokens=30) 151 | 152 | def equivalence_prompt(self, text1, text2, question): 153 | 154 | prompt = f"""We are evaluating answers to the question \"{question}\"\n""" 155 | 156 | prompt += "Here are two possible answers:\n" 157 | prompt += f"Possible Answer 1: {text1}\nPossible Answer 2: {text2}\n" 158 | prompt += "Does Possible Answer 1 semantically entail Possible Answer 2? 
Respond only with entailment, contradiction, or neutral.\n""" 159 | prompt += "Response:""" 160 | 161 | return prompt 162 | 163 | def predict(self, prompt, temperature): 164 | predicted_answer, _, _ = self.model.predict(prompt, temperature) 165 | return predicted_answer 166 | 167 | 168 | def context_entails_response(context, responses, model): 169 | votes = [] 170 | for response in responses: 171 | votes.append(model.check_implication(context, response)) 172 | return 2 - np.mean(votes) 173 | 174 | 175 | def get_semantic_ids(strings_list, model, strict_entailment=False, example=None): 176 | """Group list of predictions into semantic meaning.""" 177 | 178 | def are_equivalent(text1, text2): 179 | 180 | implication_1 = model.check_implication(text1, text2, example=example) 181 | implication_2 = model.check_implication(text2, text1, example=example) # pylint: disable=arguments-out-of-order 182 | assert (implication_1 in [0, 1, 2]) and (implication_2 in [0, 1, 2]) 183 | 184 | if strict_entailment: 185 | semantically_equivalent = (implication_1 == 2) and (implication_2 == 2) 186 | 187 | else: 188 | implications = [implication_1, implication_2] 189 | # Check if none of the implications are 0 (contradiction) and not both of them are neutral. 190 | semantically_equivalent = (0 not in implications) and ([1, 1] != implications) 191 | 192 | return semantically_equivalent 193 | 194 | # Initialise all ids with -1. 195 | semantic_set_ids = [-1] * len(strings_list) 196 | # Keep track of current id. 197 | next_id = 0 198 | for i, string1 in enumerate(strings_list): 199 | # Check if string1 already has an id assigned. 200 | if semantic_set_ids[i] == -1: 201 | # If string1 has not been assigned an id, assign it next_id. 202 | semantic_set_ids[i] = next_id 203 | for j in range(i+1, len(strings_list)): 204 | # Search through all remaining strings. If they are equivalent to string1, assign them the same id. 205 | if are_equivalent(string1, strings_list[j]): 206 | semantic_set_ids[j] = next_id 207 | next_id += 1 208 | 209 | assert -1 not in semantic_set_ids 210 | 211 | return semantic_set_ids 212 | 213 | 214 | def logsumexp_by_id(semantic_ids, log_likelihoods, agg='sum'): 215 | """Sum probabilities with the same semantic id. 216 | 217 | Log-Sum-Exp because input and output probabilities in log space. 218 | """ 219 | unique_ids = sorted(list(set(semantic_ids))) 220 | assert unique_ids == list(range(len(unique_ids))) 221 | log_likelihood_per_semantic_id = [] 222 | 223 | for uid in unique_ids: 224 | id_indices = [pos for pos, x in enumerate(semantic_ids) if x == uid] 225 | id_log_likelihoods = [log_likelihoods[i] for i in id_indices] 226 | if agg == 'sum': 227 | logsumexp_value = np.log(np.sum(np.exp(id_log_likelihoods))) - 5.0 228 | elif agg == 'sum_normalized': 229 | log_lik_norm = id_log_likelihoods - np.log(np.sum(np.exp(log_likelihoods))) 230 | logsumexp_value = np.log(np.sum(np.exp(log_lik_norm))) 231 | elif agg == 'mean': 232 | logsumexp_value = np.log(np.mean(np.exp(id_log_likelihoods))) 233 | else: 234 | raise ValueError 235 | log_likelihood_per_semantic_id.append(logsumexp_value) 236 | 237 | return log_likelihood_per_semantic_id 238 | 239 | 240 | def predictive_entropy(log_probs): 241 | """Compute MC estimate of entropy. 242 | 243 | `E[-log p(x)] ~= -1/N sum_i log p(x_i)` where i are the is the sequence 244 | likelihood, i.e. the average token likelihood. 
245 | """ 246 | 247 | entropy = -np.sum(log_probs) / len(log_probs) 248 | 249 | return entropy 250 | 251 | 252 | def predictive_entropy_rao(log_probs): 253 | entropy = -np.sum(np.exp(log_probs) * log_probs) 254 | return entropy 255 | 256 | 257 | def cluster_assignment_entropy(semantic_ids): 258 | """Estimate semantic uncertainty from how often different clusters get assigned. 259 | 260 | We estimate the categorical distribution over cluster assignments from the 261 | semantic ids. The uncertainty is then given by the entropy of that 262 | distribution. This estimate does not use token likelihoods, it relies soley 263 | on the cluster assignments. If probability mass is spread of between many 264 | clusters, entropy is larger. If probability mass is concentrated on a few 265 | clusters, entropy is small. 266 | 267 | Input: 268 | semantic_ids: List of semantic ids, e.g. [0, 1, 2, 1]. 269 | Output: 270 | cluster_entropy: Entropy, e.g. (-p log p).sum() for p = [1/4, 2/4, 1/4]. 271 | """ 272 | 273 | n_generations = len(semantic_ids) 274 | counts = np.bincount(semantic_ids) 275 | probabilities = counts/n_generations 276 | assert np.isclose(probabilities.sum(), 1) 277 | entropy = - (probabilities * np.log(probabilities)).sum() 278 | return entropy 279 | -------------------------------------------------------------------------------- /semantic_uncertainty/generate_answers.py: -------------------------------------------------------------------------------- 1 | """Predict with LLM on task.""" 2 | import gc 3 | import os 4 | import logging 5 | import random 6 | from tqdm import tqdm 7 | 8 | import numpy as np 9 | import torch 10 | import openai 11 | import wandb 12 | 13 | from uncertainty.data.data_utils import load_ds 14 | from uncertainty.utils import utils 15 | from uncertainty.uncertainty_measures import p_true as p_true_utils 16 | from compute_uncertainty_measures import main as main_compute 17 | 18 | 19 | utils.setup_logger() 20 | openai.api_key = os.getenv("OPENAI_API_KEY") # Set up OpenAI API credentials. 21 | 22 | 23 | def main(args): 24 | if args.dataset == 'svamp': 25 | if not args.use_context: 26 | logging.info('Forcing `use_context=True` for svamp dataset.') 27 | args.use_context = True 28 | elif args.dataset == 'squad': 29 | if not args.answerable_only: 30 | logging.info('Forcing `answerable_only=True` for squad dataset.') 31 | args.answerable_only = True 32 | 33 | experiment_details = {'args': args} 34 | random.seed(args.random_seed) 35 | 36 | # Implement 37 | user = os.environ['USER'] 38 | entity = os.environ['WANDB_ENT'] 39 | slurm_jobid = os.getenv('SLURM_JOB_ID', None) 40 | scratch_dir = os.getenv('SCRATCH_DIR', '.') 41 | if not os.path.exists(f"{scratch_dir}/{user}/uncertainty"): 42 | os.makedirs(f"{scratch_dir}/{user}/uncertainty") 43 | 44 | wandb.init( 45 | entity=entity, 46 | project="semantic_uncertainty" if not args.debug else "semantic_uncertainty_debug", 47 | dir=f"{scratch_dir}/{user}/uncertainty", 48 | config=args, 49 | notes=f'slurm_id: {slurm_jobid}, experiment_lot: {args.experiment_lot}', 50 | ) 51 | logging.info('Finished wandb init.') 52 | 53 | metric = utils.get_metric(args.metric) 54 | 55 | train_dataset, validation_dataset = load_ds( 56 | args.dataset, add_options=args.use_mc_options, seed=args.random_seed) 57 | if args.ood_train_dataset is not None: 58 | logging.warning( 59 | 'Using OOD dataset %s to construct few-shot prompts and train p_ik.', 60 | args.ood_train_dataset) 61 | # Get indices of answerable and unanswerable questions and construct prompt. 
62 | train_dataset, _ = load_ds(args.ood_train_dataset, add_options=args.use_mc_options) 63 | if not isinstance(train_dataset, list): 64 | logging.info('Train dataset: %s', train_dataset) 65 | 66 | # Get indices of answerable and unanswerable questions and construct prompt. 67 | answerable_indices, unanswerable_indices = utils.split_dataset(train_dataset) 68 | 69 | if args.answerable_only: 70 | unanswerable_indices = [] 71 | val_answerable, val_unanswerable = utils.split_dataset(validation_dataset) 72 | del val_unanswerable 73 | validation_dataset = [validation_dataset[i] for i in val_answerable] 74 | 75 | prompt_indices = random.sample(answerable_indices, args.num_few_shot) 76 | experiment_details['prompt_indices'] = prompt_indices 77 | remaining_answerable = list(set(answerable_indices) - set(prompt_indices)) 78 | 79 | # Create Few-Shot prompt. 80 | make_prompt = utils.get_make_prompt(args) 81 | BRIEF = utils.BRIEF_PROMPTS[args.brief_prompt] 82 | arg = args.brief_always if args.enable_brief else True 83 | prompt = utils.construct_fewshot_prompt_from_indices( 84 | train_dataset, prompt_indices, BRIEF, arg, make_prompt) 85 | experiment_details['prompt'] = prompt 86 | experiment_details['BRIEF'] = BRIEF 87 | logging.info('Prompt is: %s', prompt) 88 | 89 | # Initialize model. 90 | model = utils.init_model(args) 91 | 92 | # Initialize prompt for p_true baseline. 93 | if args.compute_p_true: 94 | logging.info(80*'#') 95 | logging.info('Constructing few-shot prompt for p_true.') 96 | 97 | p_true_indices = random.sample(answerable_indices, args.p_true_num_fewshot) 98 | remaining_answerable = list(set(remaining_answerable) - set(p_true_indices)) 99 | p_true_few_shot_prompt, p_true_responses, len_p_true = p_true_utils.construct_few_shot_prompt( 100 | model=model, dataset=train_dataset, indices=p_true_indices, 101 | prompt=prompt, brief=BRIEF, 102 | brief_always=args.brief_always and args.enable_brief, 103 | make_prompt=make_prompt, num_generations=args.num_generations, 104 | metric=metric) 105 | wandb.config.update( 106 | {'p_true_num_fewshot': len_p_true}, allow_val_change=True) 107 | wandb.log(dict(len_p_true=len_p_true)) 108 | experiment_details['p_true_indices'] = p_true_indices 109 | experiment_details['p_true_responses'] = p_true_responses 110 | experiment_details['p_true_few_shot_prompt'] = p_true_few_shot_prompt 111 | logging.info('Finished constructing few-shot prompt for p_true.') 112 | logging.info(80*'#') 113 | logging.info('p_true_few_shot_prompt: %s', p_true_few_shot_prompt) 114 | logging.info(80*'#') 115 | 116 | # Start answer generation. 117 | logging.info(80 * '=') 118 | logging.info('Generating answers: ') 119 | logging.info(80 * '=') 120 | for dataset_split in ['train', 'validation']: 121 | logging.info(80 * 'x') 122 | logging.info('Starting with dataset_split %s.', dataset_split) 123 | logging.info(80 * 'x') 124 | 125 | # This will store all input data and model predictions. 126 | accuracies, generations, results_dict, p_trues = [], {}, {}, [] 127 | 128 | if dataset_split == 'train': 129 | if not args.get_training_set_generations: 130 | logging.info('Skip training data.') 131 | continue 132 | dataset = train_dataset 133 | possible_indices = list(set(remaining_answerable) | set(unanswerable_indices)) 134 | 135 | else: 136 | dataset = validation_dataset 137 | possible_indices = range(0, len(dataset)) 138 | 139 | # Evaluate over random subset of the datasets. 
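# `min` guards against splits smaller than `args.num_samples`; in that case the warning
# below fires and every example in the split is used once (sampling is without replacement).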
140 | indices = random.sample(possible_indices, min(args.num_samples, len(dataset))) 141 | experiment_details[dataset_split] = {'indices': indices} 142 | 143 | if args.num_samples > len(dataset): 144 | logging.warning('Not enough samples in dataset. Using all %d samples.', len(dataset)) 145 | 146 | it = 0 147 | for index in tqdm(indices): 148 | if (it + 1) % 10 == 0: # Free memory every 10 iterations. 149 | gc.collect() 150 | torch.cuda.empty_cache() 151 | it += 1 152 | 153 | # Grab example at index. 154 | example = dataset[index] 155 | question, context = example["question"], example['context'] 156 | generations[example['id']] = {'question': question, 'context': context} 157 | correct_answer = example['answers']['text'] 158 | 159 | current_input = make_prompt( 160 | context, question, None, BRIEF, args.brief_always and args.enable_brief) 161 | local_prompt = prompt + current_input 162 | 163 | logging.info('Current input: '.ljust(15) + current_input) 164 | 165 | full_responses = [] 166 | 167 | # We sample one low-temperature answer, on which accuracy is computed, and 168 | # args.num_generations high-temperature answers, which are used to estimate 169 | # the entropy. 170 | 171 | if dataset_split == 'train' and args.get_training_set_generations_most_likely_only: 172 | num_generations = 1 173 | else: 174 | num_generations = args.num_generations + 1 175 | 176 | for i in range(num_generations): 177 | 178 | # Temperature for first generation is always `0.1`. 179 | temperature = 0.1 if i == 0 else args.temperature 180 | 181 | predicted_answer, token_log_likelihoods, (embedding, emb_last_before_gen, emb_before_eos) = model.predict(local_prompt, temperature, return_latent=True) 182 | 183 | # Last token embedding 184 | embedding = embedding.cpu() if embedding is not None else None 185 | emb_last_before_gen = emb_last_before_gen.cpu() if emb_last_before_gen is not None else None 186 | emb_before_eos = emb_before_eos.cpu() if emb_before_eos is not None else None 187 | 188 | compute_acc = args.compute_accuracy_at_all_temps or (i == 0) 189 | if correct_answer and compute_acc: 190 | acc = metric(predicted_answer, example, model) 191 | else: 192 | acc = 0.0 # pylint: disable=invalid-name 193 | 194 | if i == 0: 195 | # Logging. 196 | logging.info('Iteration ' + str(it) + ': ' + 80*'#') 197 | if args.use_context: 198 | logging.info('context: '.ljust(15) + str(context)) 199 | logging.info('question: '.ljust(15) + question) 200 | logging.info('low-t prediction: '.ljust(15) + predicted_answer) 201 | logging.info('correct answer: '.ljust(15) + str(correct_answer)) 202 | logging.info('accuracy: '.ljust(15) + str(acc)) 203 | 204 | accuracies.append(acc) 205 | most_likely_answer_dict = { 206 | 'response': predicted_answer, 207 | 'token_log_likelihoods': token_log_likelihoods, 208 | 'embedding': embedding, 209 | 'accuracy': acc, 210 | 'emb_last_tok_before_gen': emb_last_before_gen, 211 | 'emb_tok_before_eos': emb_before_eos, 212 | } 213 | 214 | generations[example['id']].update({ 215 | 'most_likely_answer': most_likely_answer_dict, 216 | 'reference': utils.get_reference(example), 217 | }) 218 | else: 219 | logging.info('high-t prediction '.ljust(15) + str(i) + ' : ' + predicted_answer) 220 | # Aggregate predictions over num_generations. 221 | full_responses.append( 222 | (predicted_answer, token_log_likelihoods, embedding, acc)) 223 | 224 | # Append all predictions for this example to `generations`. 
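# Each element of `full_responses` is a (predicted_answer, token_log_likelihoods,
# embedding, accuracy) tuple from the high-temperature generations; the single
# low-temperature answer is stored separately above under 'most_likely_answer'.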
225 | generations[example['id']]['responses'] = full_responses 226 | 227 | if args.compute_p_true and dataset_split == 'validation': 228 | # Already compute p_true here. Avoid cost of generations in compute_uncertainty script. 229 | p_true = p_true_utils.calculate_p_true( 230 | model, question, most_likely_answer_dict['response'], 231 | [r[0] for r in full_responses], p_true_few_shot_prompt, 232 | hint=args.p_true_hint) 233 | p_trues.append(p_true) 234 | logging.info('p_true: %s', p_true) 235 | 236 | # Save generations for that split. 237 | utils.save(generations, f'{dataset_split}_generations.pkl') 238 | 239 | # Log overall accuracy. 240 | accuracy = np.mean(accuracies) 241 | print(f"Overall {dataset_split} split accuracy: {accuracy}") 242 | wandb.log({f"{dataset_split}_accuracy": accuracy}) 243 | 244 | if dataset_split == 'validation': 245 | if args.compute_p_true: 246 | results_dict['uncertainty_measures'] = { 247 | 'p_false': [1 - p for p in p_trues], 248 | 'p_false_fixed': [1 - np.exp(p) for p in p_trues], 249 | } 250 | utils.save(results_dict, 'uncertainty_measures.pkl') 251 | 252 | utils.save(experiment_details, 'experiment_details.pkl') 253 | logging.info('Run complete.') 254 | del model 255 | 256 | 257 | if __name__ == '__main__': 258 | 259 | parser = utils.get_parser() 260 | args, unknown = parser.parse_known_args() 261 | logging.info('Starting new run with args: %s', args) 262 | 263 | if unknown: 264 | raise ValueError(f'Unkown args: {unknown}') 265 | 266 | if args.compute_uncertainties: 267 | args.assign_new_wandb_id = False 268 | 269 | logging.info('STARTING `generate_answers`!') 270 | main(args) 271 | logging.info('FINISHED `generate_answers`!') 272 | 273 | if args.compute_uncertainties: 274 | logging.info(50 * '#X') 275 | logging.info('STARTING `compute_uncertainty_measures`!') 276 | main_compute(args) 277 | logging.info('FINISHED `compute_uncertainty_measures`!') 278 | -------------------------------------------------------------------------------- /semantic_uncertainty/uncertainty/utils/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions.""" 2 | import os 3 | import logging 4 | import argparse 5 | import pickle 6 | 7 | import wandb 8 | 9 | from evaluate import load 10 | 11 | from uncertainty.models.huggingface_models import HuggingfaceModel 12 | from uncertainty.utils import openai as oai 13 | 14 | BRIEF_PROMPTS = { 15 | 'default': "Answer the following question as briefly as possible.\n", 16 | 'chat': 'Answer the following question in a single brief but complete sentence.\n'} 17 | 18 | 19 | def get_parser(stages=['generate', 'compute']): 20 | entity = os.getenv('WANDB_SEM_UNC_ENTITY', None) 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument( 24 | "--debug", action=argparse.BooleanOptionalAction, default=False, 25 | help="Keep default wandb clean.") 26 | parser.add_argument('--entity', type=str, default=entity) 27 | parser.add_argument('--random_seed', type=int, default=10) 28 | parser.add_argument( 29 | "--metric", type=str, default="squad", 30 | choices=['squad', 'llm', 'llm_gpt-3.5', 'llm_gpt-4'], 31 | help="Metric to assign accuracy to generations.") 32 | parser.add_argument( 33 | "--compute_accuracy_at_all_temps", 34 | action=argparse.BooleanOptionalAction, default=True, 35 | help="Compute accuracy at all temperatures or only t<<1.") 36 | parser.add_argument( 37 | "--experiment_lot", type=str, default='Unnamed Experiment', 38 | help="Keep default wandb clean.") 39 | if 'generate' in 
stages: 40 | parser.add_argument( 41 | "--model_name", type=str, default="Llama-2-7b-chat", help="Model name", 42 | ) 43 | parser.add_argument( 44 | "--model_max_new_tokens", type=int, default=50, 45 | help="Max number of tokens generated.", 46 | ) 47 | parser.add_argument( 48 | "--dataset", type=str, default="trivia_qa", 49 | choices=['trivia_qa', 'squad', 'bioasq', 'nq', 'svamp'], 50 | help="Dataset to use") 51 | parser.add_argument( 52 | "--ood_train_dataset", type=str, default=None, 53 | choices=['trivia_qa', 'squad', 'bioasq', 'nq', 'svamp'], 54 | help="Dataset to use to assemble few-shot prompt, p_true prompt, and train p_ik.") 55 | parser.add_argument( 56 | "--num_samples", type=int, default=400, 57 | help="Number of samples to use") 58 | parser.add_argument( 59 | "--num_few_shot", type=int, default=5, 60 | help="Number of few shot examples to use") 61 | parser.add_argument( 62 | "--p_true_num_fewshot", type=int, default=20, 63 | help="Number of few-shot examples for the p_true prompt") 64 | parser.add_argument( 65 | "--p_true_hint", default=False, 66 | action=argparse.BooleanOptionalAction, 67 | help="Use hint in p_true prompt construction?") 68 | parser.add_argument( 69 | "--num_generations", type=int, default=10, 70 | help="Number of generations to use") 71 | parser.add_argument( 72 | "--temperature", type=float, default=1.0, 73 | help="Temperature") 74 | parser.add_argument( 75 | "--use_mc_options", type=bool, default=True, 76 | help="Include multiple-choice options in the question?") 77 | parser.add_argument( 78 | "--get_training_set_generations", default=True, 79 | action=argparse.BooleanOptionalAction, 80 | help="Get generations for training set?") 81 | parser.add_argument( 82 | "--use_context", default=False, 83 | action=argparse.BooleanOptionalAction, 84 | help="Include context in the prompt?") 85 | parser.add_argument( 86 | "--get_training_set_generations_most_likely_only", default=True, 87 | action=argparse.BooleanOptionalAction, 88 | help=( 89 | "Only get embedding of most likely answer for training set. 
" 90 | "This is all that's needed for p_true.")) 91 | parser.add_argument('--compute_p_true', default=True, 92 | action=argparse.BooleanOptionalAction) 93 | parser.add_argument( 94 | "--brief_always", default=False, action=argparse.BooleanOptionalAction) 95 | parser.add_argument( 96 | "--enable_brief", default=True, action=argparse.BooleanOptionalAction) 97 | parser.add_argument( 98 | "--brief_prompt", default='default', type=str) 99 | parser.add_argument( 100 | "--prompt_type", default='default', type=str) 101 | parser.add_argument( 102 | "--compute_uncertainties", default=True, 103 | action=argparse.BooleanOptionalAction, 104 | help='Trigger compute_uncertainty_measures.py') 105 | parser.add_argument( 106 | "--answerable_only", default=False, 107 | action=argparse.BooleanOptionalAction, 108 | help='Exclude unanswerable questions.') 109 | 110 | if 'compute' in stages: 111 | parser.add_argument('--recompute_accuracy', 112 | default=False, action=argparse.BooleanOptionalAction) 113 | parser.add_argument('--eval_wandb_runid', type=str, 114 | help='wandb run id of the dataset to evaluate on') 115 | parser.add_argument('--train_wandb_runid', type=str, default=None, 116 | help='wandb run id of the dataset from which training embeddings and p_true samples will be taken') 117 | parser.add_argument('--num_eval_samples', type=int, default=int(1e19)) 118 | parser.add_argument('--compute_predictive_entropy', 119 | default=True, action=argparse.BooleanOptionalAction) 120 | parser.add_argument('--compute_p_ik', default=True, 121 | action=argparse.BooleanOptionalAction) 122 | parser.add_argument('--compute_p_ik_answerable', default=False, 123 | action=argparse.BooleanOptionalAction) 124 | parser.add_argument('--compute_context_entails_response', default=False, 125 | action=argparse.BooleanOptionalAction) 126 | parser.add_argument('--analyze_run', default=True, 127 | action=argparse.BooleanOptionalAction) 128 | parser.add_argument('--assign_new_wandb_id', default=True, 129 | action=argparse.BooleanOptionalAction) 130 | parser.add_argument('--restore_entity_eval', type=str, default=entity) 131 | parser.add_argument('--restore_entity_train', type=str, default=entity) 132 | parser.add_argument('--condition_on_question', 133 | default=True, action=argparse.BooleanOptionalAction) 134 | parser.add_argument('--strict_entailment', 135 | default=True, action=argparse.BooleanOptionalAction) 136 | parser.add_argument('--use_all_generations', default=True, action=argparse.BooleanOptionalAction) 137 | parser.add_argument('--use_num_generations', type=int, default=-1) 138 | parser.add_argument("--entailment_model", default='deberta', type=str) 139 | parser.add_argument( 140 | "--entailment_cache_id", default=None, type=str, 141 | help='Restore entailment predictions from previous run for GPT-4/LLaMa-Entailment.') 142 | parser.add_argument('--entailment_cache_only', default=False, action=argparse.BooleanOptionalAction) 143 | parser.add_argument('--compute_p_true_in_compute_stage', 144 | default=False, action=argparse.BooleanOptionalAction) 145 | parser.add_argument('--reuse_entailment_model', 146 | default=False, action=argparse.BooleanOptionalAction, 147 | help='Use entailment model as p_true model.') 148 | return parser 149 | 150 | 151 | def setup_logger(): 152 | """Setup logger to always print time and level.""" 153 | logging.basicConfig( 154 | format='%(asctime)s %(levelname)-8s %(message)s', 155 | level=logging.INFO, 156 | datefmt='%Y-%m-%d %H:%M:%S') 157 | logging.getLogger().setLevel(logging.INFO) # 
logging.DEBUG 158 | 159 | 160 | def construct_fewshot_prompt_from_indices(dataset, example_indices, brief, brief_always, make_prompt): 161 | """Given a dataset and indices, construct a fewshot prompt.""" 162 | if not brief_always: 163 | prompt = brief 164 | else: 165 | prompt = '' 166 | 167 | for example_index in example_indices: 168 | 169 | example = dataset[example_index] 170 | context = example["context"] 171 | question = example["question"] 172 | answer = example["answers"]["text"][0] 173 | 174 | prompt = prompt + make_prompt(context, question, answer, brief, brief_always) 175 | 176 | return prompt 177 | 178 | 179 | def split_dataset(dataset): 180 | """Get indices of answerable and unanswerable questions.""" 181 | 182 | def clen(ex): 183 | return len(ex["answers"]["text"]) 184 | 185 | answerable_indices = [i for i, ex in enumerate(dataset) if clen(ex) > 0] 186 | unanswerable_indices = [i for i, ex in enumerate(dataset) if clen(ex) == 0] 187 | 188 | # union == full dataset 189 | assert set(answerable_indices) | set( 190 | unanswerable_indices) == set(range(len(dataset))) 191 | # no overlap 192 | assert set(answerable_indices) - \ 193 | set(unanswerable_indices) == set(answerable_indices) 194 | 195 | return answerable_indices, unanswerable_indices 196 | 197 | 198 | def model_based_metric(predicted_answer, example, model): 199 | if 'answers' in example: 200 | correct_answers = example['answers']['text'] 201 | elif 'reference' in example: 202 | correct_answers = example['reference']['answers']['text'] 203 | else: 204 | raise ValueError 205 | 206 | prompt = f'We are assessing the quality of answers to the following question: {example["question"]}\n' 207 | if len(correct_answers) == 1: 208 | prompt += f"The expected answer is: {correct_answers[0]}.\n" 209 | else: 210 | prompt += f"The following are expected answers to this question: {correct_answers}.\n" 211 | 212 | prompt += f"The proposed answer is: {predicted_answer}\n" 213 | 214 | if len(correct_answers) == 1: 215 | prompt += "Within the context of the question, does the proposed answer mean the same as the expected answer?" 216 | else: 217 | prompt += "Within the context of the question, does the proposed answer mean the same as any of the expected answers?" 218 | 219 | prompt += " Respond only with yes or no.\nResponse:" 220 | 221 | if 'gpt' in model.model_name.lower(): 222 | predicted_answer = model.predict(prompt, 0.01) 223 | else: 224 | predicted_answer, _, _ = model.predict(prompt, 0.01) 225 | 226 | if 'yes' in predicted_answer.lower(): 227 | return 1.0 228 | elif 'no' in predicted_answer.lower(): 229 | return 0.0 230 | else: 231 | logging.warning('Redo llm check.') 232 | predicted_answer = model.predict(prompt, 1) 233 | if 'yes' in predicted_answer.lower(): 234 | return 1.0 235 | elif 'no' in predicted_answer.lower(): 236 | return 0.0 237 | 238 | logging.warning('Answer neither no nor yes. 
Defaulting to no!') 239 | return 0.0 240 | 241 | 242 | def llm_metric(predicted_answer, example, model): 243 | return model_based_metric(predicted_answer, example, model) 244 | 245 | 246 | def get_gpt_metric(metric_name): 247 | 248 | model_name = '_'.join(metric_name.split('_')[1:]) 249 | 250 | class EntailmentGPT(): 251 | def __init__(self, model_name): 252 | self.model_name = model_name 253 | 254 | def predict(self, prompt, temperature): 255 | return oai.predict(prompt, temperature, model=self.model_name) 256 | 257 | gpt_model = EntailmentGPT(model_name) 258 | 259 | def gpt_metric(predicted_answer, example, model): 260 | del model 261 | return model_based_metric(predicted_answer, example, gpt_model) 262 | 263 | return gpt_metric 264 | 265 | 266 | def get_reference(example): 267 | if 'answers' not in example: 268 | example = example['reference'] 269 | answers = example['answers'] 270 | answer_starts = answers.get('answer_start', []) 271 | reference = {'answers': {'answer_start': answer_starts, 'text': answers['text']}, 'id': example['id']} 272 | return reference 273 | 274 | 275 | def init_model(args): 276 | mn = args.model_name 277 | if 'llama' in mn.lower() or 'falcon' in mn.lower() or 'mistral' in mn.lower() or 'phi' in mn.lower(): 278 | model = HuggingfaceModel( 279 | mn, stop_sequences='default', 280 | max_new_tokens=args.model_max_new_tokens) 281 | else: 282 | raise ValueError(f'Unknown model_name `{mn}`.') 283 | return model 284 | 285 | 286 | def get_make_prompt(args): 287 | if args.prompt_type == 'default': 288 | def make_prompt(context, question, answer, brief, brief_always): 289 | prompt = '' 290 | if brief_always: 291 | prompt += brief 292 | if args.use_context and (context is not None): 293 | prompt += f"Context: {context}\n" 294 | prompt += f"Question: {question}\n" 295 | if answer: 296 | prompt += f"Answer: {answer}\n\n" 297 | else: 298 | prompt += 'Answer:' 299 | return prompt 300 | else: 301 | raise ValueError 302 | 303 | return make_prompt 304 | 305 | 306 | def get_metric(metric): 307 | if metric == 'squad': 308 | 309 | squad_metric = load("squad_v2") 310 | 311 | def metric(response, example, *args, **kwargs): 312 | # Compatibility with recomputation. 313 | if 'id' in example: 314 | exid = example['id'] 315 | elif 'id' in example['reference']: 316 | exid = example['reference']['id'] 317 | else: 318 | raise ValueError 319 | 320 | prediction = {'prediction_text': response, 'no_answer_probability': 0.0, 'id': exid} 321 | results = squad_metric.compute( 322 | predictions=[prediction], 323 | references=[get_reference(example)]) 324 | return 1.0 if (results['f1'] >= 50.0) else 0.0 325 | 326 | # Reuses the globally active model for these. 
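# 'llm' scores answers with the same model object that produced them, whereas the
# 'llm_gpt-*' variants build a thin OpenAI wrapper in `get_gpt_metric` and ignore
# the model argument passed at call time.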
327 | elif metric == 'llm': 328 | metric = llm_metric 329 | elif metric == 'llm_gpt-3.5': 330 | metric = get_gpt_metric(metric) 331 | elif metric == 'llm_gpt-4': 332 | metric = get_gpt_metric(metric) 333 | else: 334 | raise ValueError 335 | 336 | return metric 337 | 338 | 339 | def save(object, file): 340 | with open(f'{wandb.run.dir}/{file}', 'wb') as f: 341 | pickle.dump(object, f) 342 | wandb.save(f'{wandb.run.dir}/{file}') -------------------------------------------------------------------------------- /semantic_uncertainty/uncertainty/models/huggingface_models.py: -------------------------------------------------------------------------------- 1 | """Implement HuggingfaceModel models.""" 2 | import copy 3 | import logging 4 | import os 5 | from collections import Counter 6 | 7 | import accelerate 8 | import torch 9 | from accelerate import Accelerator 10 | 11 | from transformers import AutoTokenizer 12 | from transformers import AutoConfig 13 | from transformers import AutoModelForCausalLM 14 | from transformers import BitsAndBytesConfig 15 | from transformers import StoppingCriteria 16 | from transformers import StoppingCriteriaList 17 | from huggingface_hub import snapshot_download 18 | 19 | 20 | from uncertainty.models.base_model import BaseModel 21 | from uncertainty.models.base_model import STOP_SEQUENCES 22 | 23 | 24 | class StoppingCriteriaSub(StoppingCriteria): 25 | """Stop generations when they match a particular text or token.""" 26 | def __init__(self, stops, tokenizer, match_on='text', initial_length=None): 27 | super().__init__() 28 | self.stops = stops 29 | self.initial_length = initial_length 30 | self.tokenizer = tokenizer 31 | self.match_on = match_on 32 | if self.match_on == 'tokens': 33 | self.stops = [torch.tensor(self.tokenizer.encode(i)).to('cuda') for i in self.stops] 34 | print(self.stops) 35 | 36 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): 37 | del scores 38 | for stop in self.stops: 39 | if self.match_on == 'text': 40 | generation = self.tokenizer.decode(input_ids[0][self.initial_length:], skip_special_tokens=False) 41 | match = stop in generation 42 | elif self.match_on == 'tokens': 43 | # Can be dangerous due to tokenizer ambiguities. 44 | match = stop in input_ids[0][-len(stop):] 45 | else: 46 | raise 47 | if match: 48 | return True 49 | return False 50 | 51 | 52 | def remove_split_layer(device_map_in): 53 | """Modify device maps s.t. individual layers are not spread across devices.""" 54 | 55 | device_map = copy.deepcopy(device_map_in) 56 | destinations = list(device_map.keys()) 57 | 58 | counts = Counter(['.'.join(i.split('.')[:2]) for i in destinations]) 59 | 60 | found_split = False 61 | for layer, count in counts.items(): 62 | if count == 1: 63 | continue 64 | 65 | if found_split: 66 | # Only triggers if we find more than one split layer! 
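# `found_split` is set once the first split layer has been merged below; hitting a
# second split layer means this heuristic cannot repair the device map, so fail loudly.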
67 | raise ValueError( 68 | 'More than one split layer.\n' 69 | f'Currently at layer {layer}.\n' 70 | f'In map: {device_map_in}\n' 71 | f'Out map: {device_map}\n') 72 | 73 | logging.info(f'Split layer is {layer}.') 74 | 75 | # remove split for that layer 76 | for name in list(device_map.keys()): 77 | if name.startswith(layer): 78 | print(f'pop {name}') 79 | device = device_map.pop(name) 80 | 81 | device_map[layer] = device 82 | found_split = True 83 | 84 | return device_map 85 | 86 | 87 | class HuggingfaceModel(BaseModel): 88 | """HuggingfaceModel.""" 89 | 90 | def __init__(self, model_name, stop_sequences=None, max_new_tokens=None): 91 | if max_new_tokens is None: 92 | raise 93 | self.max_new_tokens = max_new_tokens 94 | 95 | if stop_sequences == 'default': 96 | stop_sequences = STOP_SEQUENCES 97 | print(model_name) 98 | if 'llama' in model_name.lower(): 99 | 100 | if model_name.endswith('-8bit'): 101 | kwargs = {'quantization_config': BitsAndBytesConfig( 102 | load_in_8bit=True,)} 103 | model_name = model_name[:-len('-8bit')] 104 | eightbit = True 105 | else: 106 | kwargs = {} 107 | eightbit = False 108 | 109 | if 'Llama-2' in model_name or 'Llama-3' in model_name: 110 | base = 'meta-llama' 111 | model_name = model_name + '-hf' if 'Llama-2' in model_name else model_name 112 | else: 113 | base = 'huggyllama' 114 | 115 | self.tokenizer = AutoTokenizer.from_pretrained( 116 | f"{base}/{model_name}", device_map="auto", 117 | token_type_ids=None) 118 | 119 | llama65b = '65b' in model_name.lower() and base == 'huggyllama' 120 | llama2or3_70b = '70b' in model_name.lower() and base == 'meta-llama' 121 | 122 | if ('7b' in model_name or '13b' in model_name) or eightbit: 123 | self.model = AutoModelForCausalLM.from_pretrained( 124 | f"{base}/{model_name}", device_map="auto", 125 | max_memory={0: '80GIB'}, **kwargs,) 126 | 127 | elif llama2or3_70b or llama65b: 128 | path = snapshot_download( 129 | repo_id=f'{base}/{model_name}', 130 | allow_patterns=['*.json', '*.model', '*.safetensors'], 131 | ignore_patterns=['pytorch_model.bin.index.json'] 132 | ) 133 | config = AutoConfig.from_pretrained(f"{base}/{model_name}") 134 | with accelerate.init_empty_weights(): 135 | self.model = AutoModelForCausalLM.from_config(config) 136 | self.model.tie_weights() 137 | if 'chat' in model_name: 138 | max_mem = 17.5 * 4686198491 139 | else: 140 | max_mem = 15 * 4686198491 141 | 142 | device_map = accelerate.infer_auto_device_map( 143 | self.model.model, 144 | max_memory={0: max_mem, 1: max_mem}, 145 | dtype='float16' 146 | ) 147 | device_map = remove_split_layer(device_map) 148 | full_model_device_map = {f"model.{k}": v for k, v in device_map.items()} 149 | full_model_device_map["lm_head"] = 0 150 | 151 | self.model = accelerate.load_checkpoint_and_dispatch( 152 | self.model, path, device_map=full_model_device_map, 153 | dtype='float16', skip_keys='past_key_values') 154 | 155 | else: 156 | raise ValueError 157 | 158 | elif 'mistral' in model_name.lower(): 159 | 160 | if model_name.endswith('-8bit'): 161 | kwargs = {'quantization_config': BitsAndBytesConfig( 162 | load_in_8bit=True,)} 163 | model_name = model_name[:-len('-8bit')] 164 | if model_name.endswith('-4bit'): 165 | kwargs = {'quantization_config': BitsAndBytesConfig( 166 | load_in_4bit=True,)} 167 | model_name = model_name[:-len('-8bit')] 168 | else: 169 | kwargs = {} 170 | 171 | model_id = f'mistralai/{model_name}' 172 | self.tokenizer = AutoTokenizer.from_pretrained( 173 | model_id, device_map='auto', token_type_ids=None, 174 | 
clean_up_tokenization_spaces=False) 175 | 176 | self.model = AutoModelForCausalLM.from_pretrained( 177 | model_id, 178 | device_map='auto', 179 | max_memory={0: '80GIB'}, 180 | **kwargs, 181 | ) 182 | 183 | elif 'falcon' in model_name: 184 | model_id = f'tiiuae/{model_name}' 185 | self.tokenizer = AutoTokenizer.from_pretrained( 186 | model_id, device_map='auto', token_type_ids=None, 187 | clean_up_tokenization_spaces=False) 188 | 189 | kwargs = {'quantization_config': BitsAndBytesConfig( 190 | load_in_8bit=True,)} 191 | 192 | self.model = AutoModelForCausalLM.from_pretrained( 193 | model_id, 194 | trust_remote_code=True, 195 | device_map='auto', 196 | **kwargs, 197 | ) 198 | elif 'phi' in model_name.lower(): 199 | model_id = f'microsoft/{model_name}' # e.g. Phi-3-mini-128k-instruct 200 | self.tokenizer = AutoTokenizer.from_pretrained( 201 | model_id, device_map='auto', token_type_ids=None, 202 | clean_up_tokenization_spaces=False) 203 | self.model = AutoModelForCausalLM.from_pretrained( 204 | model_id, 205 | trust_remote_code=True, 206 | device_map='auto', 207 | ) 208 | elif 'gemma' in model_name: 209 | model_id = f'google/{model_name}' # e.g. gemma-7b-it 210 | self.tokenizer = AutoTokenizer.from_pretrained( 211 | model_id, device_map='auto', token_type_ids=None, 212 | clean_up_tokenization_spaces=False) 213 | self.model = AutoModelForCausalLM.from_pretrained( 214 | model_id, 215 | trust_remote_code=True, 216 | device_map='auto', 217 | torch_dtype=torch.bfloat16 218 | ) 219 | else: 220 | raise ValueError 221 | 222 | self.model_name = model_name 223 | self.stop_sequences = stop_sequences + [self.tokenizer.eos_token] 224 | self.token_limit = 4096 if 'Llama-2' in model_name else 2048 225 | 226 | 227 | def predict(self, input_data, temperature, return_full=False, return_latent=False): 228 | 229 | if isinstance(input_data, tuple): 230 | logging.WARNING("INPUT IS A TUPLE.") 231 | input_data = input_data[0] 232 | 233 | inputs = self.tokenizer(input_data, return_tensors="pt").to("cuda") 234 | 235 | if 'llama' in self.model_name.lower() or 'falcon' in self.model_name or 'mistral' in self.model_name.lower(): 236 | if 'token_type_ids' in inputs: # HF models seems has changed. 237 | del inputs['token_type_ids'] 238 | pad_token_id = self.tokenizer.eos_token_id 239 | else: 240 | pad_token_id = None 241 | 242 | if self.stop_sequences is not None: 243 | stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub( 244 | stops=self.stop_sequences, 245 | initial_length=len(inputs['input_ids'][0]), 246 | tokenizer=self.tokenizer)]) 247 | else: 248 | stopping_criteria = None 249 | 250 | logging.debug('temperature: %f', temperature) 251 | with torch.no_grad(): 252 | outputs = self.model.generate( 253 | **inputs, 254 | max_new_tokens=self.max_new_tokens, 255 | return_dict_in_generate=True, 256 | output_scores=True, 257 | output_hidden_states=True, 258 | temperature=temperature, 259 | do_sample=True, 260 | stopping_criteria=stopping_criteria, 261 | pad_token_id=pad_token_id, 262 | ) 263 | 264 | if len(outputs.sequences[0]) > self.token_limit: 265 | raise ValueError( 266 | 'Generation exceeding token limit %d > %d', 267 | len(outputs.sequences[0]), self.token_limit) 268 | 269 | full_answer = self.tokenizer.decode( 270 | outputs.sequences[0], skip_special_tokens=True) 271 | 272 | if return_full: 273 | return full_answer 274 | 275 | # For some models, we need to remove the input_data from the answer. 
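# The decoded sequence includes the prompt, so the generated answer is recovered by
# dropping the first `len(input_data)` characters; if decoding does not reproduce the
# prompt verbatim, the check below raises.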
276 | if full_answer.startswith(input_data): 277 | input_data_offset = len(input_data) 278 | else: 279 | raise ValueError('Have not tested this in a while.') 280 | 281 | # Remove input from answer. 282 | answer = full_answer[input_data_offset:] 283 | 284 | # Remove stop_words from answer. 285 | stop_at = len(answer) 286 | sliced_answer = answer 287 | if self.stop_sequences is not None: 288 | for stop in self.stop_sequences: 289 | if answer.endswith(stop): 290 | stop_at = len(answer) - len(stop) 291 | sliced_answer = answer[:stop_at] 292 | break 293 | if not all([stop not in sliced_answer for stop in self.stop_sequences]): 294 | error_msg = 'Error: Stop words not removed successfully!' 295 | error_msg += f'Answer: >{answer}< ' 296 | error_msg += f'Sliced Answer: >{sliced_answer}<' 297 | logging.error(error_msg) 298 | 299 | # Remove whitespaces from answer (in particular from beginning.) 300 | sliced_answer = sliced_answer.strip() 301 | token_stop_index = self.tokenizer(full_answer[:input_data_offset + stop_at], return_tensors="pt")['input_ids'].shape[1] 302 | n_input_token = len(inputs['input_ids'][0]) 303 | n_generated = token_stop_index - n_input_token 304 | 305 | if n_generated == 0: 306 | logging.warning('Only stop_words were generated. For likelihoods and embeddings, taking stop word instead.') 307 | n_generated = 1 308 | 309 | if 'decoder_hidden_states' in outputs.keys(): 310 | hidden = outputs.decoder_hidden_states 311 | else: 312 | hidden = outputs.hidden_states 313 | 314 | if len(hidden) == 1: 315 | logging.warning( 316 | 'Taking first and only generation for hidden! ' 317 | 'n_generated: %d, n_input_token: %d, token_stop_index %d, ' 318 | 'last_token: %s, generation was: %s', 319 | n_generated, n_input_token, token_stop_index, 320 | self.tokenizer.decode(outputs['sequences'][0][-1]), 321 | full_answer, 322 | ) 323 | last_input = hidden[0] 324 | elif ((n_generated - 1) >= len(hidden)): 325 | # if access idx is larger/equal 326 | logging.error( 327 | 'Taking last state because n_generated is too large' 328 | 'n_generated: %d, n_input_token: %d, token_stop_index %d, ' 329 | 'last_token: %s, generation was: %s, slice_answer: %s', 330 | n_generated, n_input_token, token_stop_index, 331 | self.tokenizer.decode(outputs['sequences'][0][-1]), 332 | full_answer, sliced_answer 333 | ) 334 | last_input = hidden[-1] 335 | else: 336 | last_input = hidden[n_generated - 1] 337 | 338 | # Then access last layer for input 339 | last_layer = last_input[-1] 340 | # Then access last token in input. 341 | last_token_embedding = last_layer[:, -1, :].cpu() 342 | 343 | if return_latent: 344 | # Stack second last token embeddings from all layers 345 | if len(hidden) == 1: # FIX: runtime error for mistral-7b on bioasq 346 | sec_last_input = hidden[0] 347 | elif ((n_generated - 2) >= len(hidden)): 348 | sec_last_input = hidden[-2] 349 | else: 350 | sec_last_input = hidden[n_generated - 2] 351 | sec_last_token_embedding = torch.stack([layer[:, -1, :] for layer in sec_last_input]).cpu() 352 | 353 | # Get the last input token embeddings (before generated tokens) 354 | last_tok_bef_gen_input = hidden[0] 355 | last_tok_bef_gen_embedding = torch.stack([layer[:, -1, :] for layer in last_tok_bef_gen_input]).cpu() 356 | 357 | # Get log_likelihoods. 
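# `compute_transition_scores` with `normalize_logits=True` returns the log-probability
# of each generated token; only the first `n_generated` entries belong to the kept
# answer, the remainder typically covers stop sequences.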
358 | transition_scores = self.model.compute_transition_scores( 359 | outputs.sequences, outputs.scores, normalize_logits=True) 360 | log_likelihoods = [score.item() for score in transition_scores[0]] 361 | if len(log_likelihoods) == 1: 362 | logging.warning('Taking first and only generation for log likelihood!') 363 | log_likelihoods = log_likelihoods 364 | else: 365 | log_likelihoods = log_likelihoods[:n_generated] 366 | 367 | if len(log_likelihoods) == self.max_new_tokens: 368 | logging.warning('Generation interrupted by max_token limit.') 369 | 370 | if len(log_likelihoods) == 0: 371 | raise ValueError 372 | 373 | hidden_states = (last_token_embedding,) 374 | 375 | if return_latent: 376 | hidden_states += (sec_last_token_embedding, last_tok_bef_gen_embedding) 377 | else: 378 | hidden_states += (None, None) 379 | 380 | return_values = (sliced_answer, log_likelihoods, hidden_states) 381 | 382 | return return_values 383 | 384 | def get_p_true(self, input_data): 385 | """Get the probability of the model anwering A (True) for the given input""" 386 | 387 | input_data += ' A' 388 | tokenized_prompt_true = self.tokenizer(input_data, return_tensors='pt').to('cuda')['input_ids'] 389 | 390 | target_ids_true = tokenized_prompt_true.clone() 391 | # Set all target_ids except the last one to -100. 392 | target_ids_true[0, :-1] = -100 393 | 394 | with torch.no_grad(): 395 | model_output_true = self.model(tokenized_prompt_true, labels=target_ids_true) 396 | 397 | loss_true = model_output_true.loss 398 | 399 | return -loss_true.item() 400 | 401 | def get_perplexity(self, input_data): 402 | """Get the probability of the model anwering A (True) for the given input""" 403 | 404 | tokenized_data = self.tokenizer(input_data, return_tensors='pt').to('cuda')['input_ids'] 405 | 406 | with torch.no_grad(): 407 | model_output_true = self.model(tokenized_data, labels=tokenized_data) 408 | 409 | perplexity = - model_output_true.loss.item() 410 | 411 | 412 | return perplexity 413 | -------------------------------------------------------------------------------- /semantic_uncertainty/compute_uncertainty_measures.py: -------------------------------------------------------------------------------- 1 | """Compute uncertainty measures after generating answers.""" 2 | from collections import defaultdict 3 | from copy import deepcopy 4 | import logging 5 | import os 6 | import pickle 7 | import random 8 | import numpy as np 9 | import wandb 10 | 11 | from analyze_results import analyze_run 12 | from uncertainty.data.data_utils import load_ds 13 | from uncertainty.uncertainty_measures.p_ik import get_p_ik 14 | from uncertainty.uncertainty_measures.semantic_entropy import get_semantic_ids 15 | from uncertainty.uncertainty_measures.semantic_entropy import logsumexp_by_id 16 | from uncertainty.uncertainty_measures.semantic_entropy import predictive_entropy 17 | from uncertainty.uncertainty_measures.semantic_entropy import predictive_entropy_rao 18 | from uncertainty.uncertainty_measures.semantic_entropy import cluster_assignment_entropy 19 | from uncertainty.uncertainty_measures.semantic_entropy import context_entails_response 20 | from uncertainty.uncertainty_measures.semantic_entropy import EntailmentDeberta 21 | from uncertainty.uncertainty_measures.semantic_entropy import EntailmentGPT4 22 | from uncertainty.uncertainty_measures.semantic_entropy import EntailmentGPT35 23 | from uncertainty.uncertainty_measures.semantic_entropy import EntailmentLlama 24 | from uncertainty.uncertainty_measures import p_true as p_true_utils 
25 | from uncertainty.utils import utils 26 | 27 | 28 | utils.setup_logger() 29 | 30 | EXP_DETAILS = 'experiment_details.pkl' 31 | 32 | def main(args): 33 | 34 | if args.train_wandb_runid is None: 35 | args.train_wandb_runid = args.eval_wandb_runid 36 | 37 | user = os.environ['USER'] 38 | scratch_dir = os.getenv('SCRATCH_DIR', '.') 39 | wandb_dir = f'{scratch_dir}/{user}/uncertainty' 40 | slurm_jobid = os.getenv('SLURM_JOB_ID', None) 41 | project = "semantic_uncertainty" if not args.debug else "semantic_uncertainty_debug" 42 | if args.assign_new_wandb_id: 43 | logging.info('Assign new wandb_id.') 44 | api = wandb.Api() 45 | old_run = api.run(f'{args.restore_entity_eval}/{project}/{args.eval_wandb_runid}') 46 | wandb.init( 47 | entity=args.entity, 48 | project=project, 49 | dir=wandb_dir, 50 | notes=f'slurm_id: {slurm_jobid}, experiment_lot: {args.experiment_lot}', 51 | config={**old_run.config, **args.__dict__}, 52 | ) 53 | 54 | def restore(filename): 55 | old_run.file(filename).download( 56 | replace=False, exist_ok=True, root=wandb.run.dir) 57 | 58 | class Restored: 59 | name = f'{wandb.run.dir}/{filename}' 60 | 61 | return Restored 62 | else: 63 | logging.info('Reuse active wandb id.') 64 | 65 | def restore(filename): 66 | class Restored: 67 | name = f'{wandb.run.dir}/{filename}' 68 | return Restored 69 | 70 | if args.train_wandb_runid != args.eval_wandb_runid: 71 | logging.info( 72 | "Distribution shift for p_ik. Training on embeddings from run %s but evaluating on run %s", 73 | args.train_wandb_runid, args.eval_wandb_runid) 74 | 75 | is_ood_eval = True # pylint: disable=invalid-name 76 | api = wandb.Api() 77 | old_run_train = api.run(f'{args.restore_entity_train}/semantic_uncertainty/{args.train_wandb_runid}') 78 | filename = 'train_generations.pkl' 79 | old_run_train.file(filename).download( 80 | replace=True, exist_ok=False, root=wandb.run.dir) 81 | with open(f'{wandb.run.dir}/{filename}', "rb") as infile: 82 | train_generations = pickle.load(infile) 83 | wandb.config.update( 84 | {"ood_training_set": old_run_train.config['dataset']}, allow_val_change=True) 85 | else: 86 | is_ood_eval = False # pylint: disable=invalid-name 87 | if args.compute_p_ik or args.compute_p_ik_answerable: 88 | train_generations_pickle = restore('train_generations.pkl') 89 | with open(train_generations_pickle.name, 'rb') as infile: 90 | train_generations = pickle.load(infile) 91 | 92 | wandb.config.update({"is_ood_eval": is_ood_eval}, allow_val_change=True) 93 | 94 | if args.compute_predictive_entropy: 95 | logging.info('Beginning loading for entailment model.') 96 | if args.entailment_model == 'deberta': 97 | entailment_model = EntailmentDeberta() 98 | elif args.entailment_model == 'gpt-4': 99 | entailment_model = EntailmentGPT4(args.entailment_cache_id, args.entailment_cache_only) 100 | elif args.entailment_model == 'gpt-3.5': 101 | entailment_model = EntailmentGPT35(args.entailment_cache_id, args.entailment_cache_only) 102 | elif 'llama' in args.entailment_model.lower(): 103 | entailment_model = EntailmentLlama(args.entailment_cache_id, args.entailment_cache_only, args.entailment_model) 104 | else: 105 | raise ValueError 106 | logging.info('Entailment model loading complete.') 107 | 108 | if args.compute_p_true_in_compute_stage: 109 | old_exp_file = restore(EXP_DETAILS) 110 | with open(old_exp_file.name, "rb") as infile: 111 | old_exp = pickle.load(infile) 112 | 113 | if args.reuse_entailment_model: 114 | pt_model = entailment_model.model 115 | else: 116 | pt_model = utils.init_model(old_exp['args']) 117 
| 118 | pt_train_dataset, pt_validation_dataset = load_ds( 119 | old_exp['args'].dataset, add_options=old_exp['args'].use_mc_options, 120 | seed=args.random_seed) 121 | 122 | # Reduce num generations used in p_true if needed! 123 | if not args.use_all_generations: 124 | if args.use_num_generations == -1: 125 | raise ValueError 126 | num_gen = args.use_num_generations 127 | else: 128 | num_gen = 10 129 | 130 | answerable_indices, unanswerable_indices = utils.split_dataset(pt_train_dataset) 131 | p_true_indices = random.sample(answerable_indices, 20) # args.p_true_num_fewshot = 20 132 | 133 | p_true_few_shot_prompt, p_true_responses, len_p_true = p_true_utils.construct_few_shot_prompt( 134 | model=pt_model, 135 | dataset=pt_train_dataset, 136 | indices=p_true_indices, 137 | prompt=old_exp['prompt'], 138 | brief=old_exp['BRIEF'], 139 | brief_always=old_exp['args'].brief_always and old_exp['args'].enable_brief, 140 | make_prompt=utils.get_make_prompt(old_exp['args']), 141 | num_generations=num_gen, 142 | metric=utils.get_metric(old_exp['args'].metric)) 143 | wandb.config.update( 144 | {'p_true_num_fewshot': len_p_true}, allow_val_change=True) 145 | wandb.log(dict(len_p_true=len_p_true)) 146 | del p_true_responses, pt_train_dataset 147 | 148 | logging.info('Generated few-shot prompt for p_true.') 149 | logging.info(80*'#') 150 | logging.info('p_true_few_shot_prompt: %s', p_true_few_shot_prompt) 151 | logging.info(80*'#') 152 | 153 | if args.recompute_accuracy: 154 | logging.warning('Recompute accuracy enabled. This does not apply to precomputed p_true!') 155 | metric = utils.get_metric(args.metric) 156 | 157 | result_dict_pickle = restore('uncertainty_measures.pkl') 158 | with open(result_dict_pickle.name, "rb") as infile: 159 | result_dict = pickle.load(infile) 160 | 161 | if 'semantic_ids' not in result_dict: 162 | result_dict['semantic_ids'] = [] 163 | 164 | validation_generations_pickle = restore('validation_generations.pkl') 165 | with open(validation_generations_pickle.name, 'rb') as infile: 166 | validation_generations = pickle.load(infile) 167 | 168 | entropies, accuracies = defaultdict(list), defaultdict(list) 169 | validation_embeddings, validation_is_true, validation_answerable = [], [], [] 170 | p_trues = [] 171 | count = 0 # pylint: disable=invalid-name 172 | 173 | def is_answerable(generation): 174 | return len(generation['reference']['answers']['text']) > 0 175 | 176 | # Loop over datapoints and compute validation embeddings, accuracies and entropies. 
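# `validation_generations` maps example id -> dict holding the question, context, the
# low-temperature 'most_likely_answer', and the list of high-temperature 'responses'
# written by generate_answers.py.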
177 | for idx, tid in enumerate(validation_generations): 178 | example = validation_generations[tid] 179 | question = example['question'] 180 | context = example['context'] 181 | full_responses = example["responses"] 182 | most_likely_answer = example['most_likely_answer'] 183 | 184 | if not args.use_all_generations: 185 | if args.use_num_generations == -1: 186 | raise ValueError 187 | responses = [fr[0] for fr in full_responses[:args.use_num_generations]] 188 | else: 189 | responses = [fr[0] for fr in full_responses] 190 | 191 | if args.recompute_accuracy: 192 | logging.info('Recomputing accuracy!') 193 | if is_answerable(example): 194 | acc = metric(most_likely_answer['response'], example, None) 195 | else: 196 | acc = 0.0 # pylint: disable=invalid-name 197 | validation_is_true.append(acc) 198 | logging.info('Recomputed accuracy!') 199 | 200 | else: 201 | validation_is_true.append(most_likely_answer['accuracy']) 202 | 203 | validation_answerable.append(is_answerable(example)) 204 | validation_embeddings.append(most_likely_answer['embedding']) 205 | logging.info('validation_is_true: %f', validation_is_true[-1]) 206 | 207 | if args.compute_predictive_entropy: 208 | # Token log likelihoods. Shape = (n_sample, n_tokens) 209 | if not args.use_all_generations: 210 | log_liks = [r[1] for r in full_responses[:args.use_num_generations]] 211 | else: 212 | log_liks = [r[1] for r in full_responses] 213 | 214 | for i in log_liks: 215 | assert i 216 | 217 | if args.compute_context_entails_response: 218 | # Compute context entails answer baseline. 219 | entropies['context_entails_response'].append(context_entails_response( 220 | context, responses, entailment_model)) 221 | 222 | if args.condition_on_question and args.entailment_model == 'deberta': 223 | responses = [f'{question} {r}' for r in responses] 224 | 225 | # Compute semantic ids. 226 | semantic_ids = get_semantic_ids( 227 | responses, model=entailment_model, 228 | strict_entailment=args.strict_entailment, example=example) 229 | 230 | result_dict['semantic_ids'].append(semantic_ids) 231 | 232 | # Compute entropy from frequencies of cluster assignments. 233 | entropies['cluster_assignment_entropy'].append(cluster_assignment_entropy(semantic_ids)) 234 | 235 | # Compute entropies with and without length normalized token probabilities. 236 | for agg_name, agg_func in zip(['', '_sum'], [np.mean, np.sum]): 237 | log_liks_agg = [agg_func(log_lik) for log_lik in log_liks] 238 | 239 | # Compute standard entropy. 240 | entropies['regular_entropy' + agg_name].append(predictive_entropy(log_liks_agg)) 241 | 242 | # Compute semantic entropies with summing and with averaging probabilities within the cluster. 
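# For each aggregation, `logsumexp_by_id` returns one log-probability per semantic
# cluster (e.g. for semantic_ids=[0, 0, 1] it returns two entries). The '-rao' variant
# then uses predictive_entropy_rao (-sum_c p(c) log p(c)); the others use the Monte
# Carlo estimate in predictive_entropy.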
243 | cluster_agg_names = ['', '_sum-normalized', '_sum-normalized-rao', '_cmean'] 244 | cluster_aggs = ['sum', 'sum_normalized', 'sum_normalized', 'mean'] 245 | for cluster_agg_name, cluster_agg in zip(cluster_agg_names, cluster_aggs): 246 | log_likelihood_per_semantic_id = logsumexp_by_id(semantic_ids, log_liks_agg, agg=cluster_agg) 247 | name = 'semantic_entropy' + agg_name + cluster_agg_name 248 | 249 | if cluster_agg_name != '_sum-normalized-rao': 250 | pe = predictive_entropy(log_likelihood_per_semantic_id) 251 | else: 252 | pe = predictive_entropy_rao(log_likelihood_per_semantic_id) 253 | 254 | entropies[name].append(pe) 255 | 256 | # For the semantic uncertainties, we can also change the prediction, by first selecting the semantic 257 | # cluster with the highest probability, and then selecting the generation with the highest probability 258 | # within that cluster. 259 | # NOTE: nanargmax because we currently have some clusters with empty generations. 260 | max_cluster_id = np.nanargmax(log_likelihood_per_semantic_id) 261 | # Filter log_liks to max cluster. 262 | generations_in_cluster = np.array(log_liks_agg) 263 | generations_in_cluster[np.array(semantic_ids) != max_cluster_id] = -np.inf 264 | # Select generation with new max probability. 265 | max_idx_in_cluster = np.argmax(generations_in_cluster) 266 | # Accuracies for alternative generations saved at last index. 267 | accuracies[name].append(full_responses[max_idx_in_cluster][-1]) 268 | 269 | # pylint: disable=invalid-name 270 | log_str = 'semantic_ids: %s, avg_token_log_likelihoods: %s, entropies: %s' 271 | entropies_fmt = ', '.join([f'{i}:{j[-1]:.2f}' for i, j in entropies.items()]) 272 | # pylint: enable=invalid-name 273 | logging.info(80*'#') 274 | logging.info('NEW ITEM %d at id=`%s`.', idx, tid) 275 | logging.info('Context:') 276 | logging.info(example['context']) 277 | logging.info('Question:') 278 | logging.info(question) 279 | logging.info('True Answers:') 280 | logging.info(example['reference']) 281 | logging.info('Low Temperature Generation:') 282 | logging.info(most_likely_answer['response']) 283 | logging.info('Low Temperature Generation Accuracy:') 284 | logging.info(most_likely_answer['accuracy']) 285 | logging.info('High Temp Generation:') 286 | logging.info([r[0] for r in full_responses]) 287 | logging.info('High Temp Generation:') 288 | logging.info(log_str, semantic_ids, log_liks_agg, entropies_fmt) 289 | 290 | if args.compute_p_true_in_compute_stage: 291 | p_true = p_true_utils.calculate_p_true( 292 | pt_model, question, most_likely_answer['response'], 293 | responses, p_true_few_shot_prompt, 294 | hint=old_exp['args'].p_true_hint) 295 | p_trues.append(p_true) 296 | logging.info('p_true: %s', np.exp(p_true)) 297 | 298 | count += 1 299 | if count >= args.num_eval_samples: 300 | logging.info('Breaking out of main loop.') 301 | break 302 | 303 | logging.info('Accuracy on original task: %f', np.mean(validation_is_true)) 304 | validation_is_false = [1.0 - is_t for is_t in validation_is_true] 305 | result_dict['validation_is_false'] = validation_is_false 306 | 307 | validation_unanswerable = [1.0 - is_a for is_a in validation_answerable] 308 | result_dict['validation_unanswerable'] = validation_unanswerable 309 | logging.info('Unanswerable prop on validation: %f', np.mean(validation_unanswerable)) 310 | 311 | if 'uncertainty_measures' not in result_dict: 312 | result_dict['uncertainty_measures'] = dict() 313 | 314 | if args.compute_predictive_entropy: 315 | result_dict['uncertainty_measures'].update(entropies) 
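# `accuracies` holds, per entropy variant, the accuracy of an alternative answer: the
# highest-likelihood generation inside the most probable semantic cluster (chosen via
# nanargmax above); that accuracy sits at the last position of each `full_responses` tuple.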
316 | accuracies_mean = {k: np.mean(v) for k, v in accuracies.items()} 317 | logging.info('Accuracy on original task from cluster-based generations: %s', accuracies_mean) 318 | 319 | result_dict['alt_validation_accuracies_mean'] = accuracies_mean 320 | result_dict['alt_validation_is_false'] = {k: [1 - vi for vi in v] for k, v in accuracies.items()} 321 | 322 | if args.compute_p_ik or args.compute_p_ik_answerable: 323 | # Assemble training data for embedding classification. 324 | train_is_true, train_embeddings, train_answerable = [], [], [] 325 | for tid in train_generations: 326 | most_likely_answer = train_generations[tid]['most_likely_answer'] 327 | train_embeddings.append(most_likely_answer['embedding']) 328 | train_is_true.append(most_likely_answer['accuracy']) 329 | train_answerable.append(is_answerable(train_generations[tid])) 330 | train_is_false = [0.0 if is_t else 1.0 for is_t in train_is_true] 331 | train_unanswerable = [0.0 if is_t else 1.0 for is_t in train_answerable] 332 | logging.info('Unanswerable prop on p_ik training: %f', np.mean(train_unanswerable)) 333 | 334 | if args.compute_p_ik: 335 | logging.info('Starting training p_ik on train embeddings.') 336 | # Train classifier of correct/incorrect. 337 | p_ik_predictions = get_p_ik( 338 | train_embeddings=train_embeddings, is_false=train_is_false, 339 | eval_embeddings=validation_embeddings, eval_is_false=validation_is_false) 340 | result_dict['uncertainty_measures']['p_ik'] = p_ik_predictions 341 | logging.info('Finished training p_ik on train embeddings.') 342 | 343 | if args.compute_p_ik_answerable: 344 | # Train classifier of answerable/unanswerable: 345 | p_ik_predictions = get_p_ik( 346 | train_embeddings=train_embeddings, is_false=train_unanswerable, 347 | eval_embeddings=validation_embeddings, eval_is_false=validation_unanswerable) 348 | result_dict['uncertainty_measures']['p_ik_unanswerable'] = p_ik_predictions 349 | 350 | if args.compute_p_true_in_compute_stage: 351 | result_dict['uncertainty_measures']['p_false'] = [1 - p for p in p_trues] 352 | result_dict['uncertainty_measures']['p_false_fixed'] = [1 - np.exp(p) for p in p_trues] 353 | 354 | utils.save(result_dict, 'uncertainty_measures.pkl') 355 | 356 | if args.compute_predictive_entropy: 357 | entailment_model.save_prediction_cache() 358 | 359 | if args.analyze_run: 360 | logging.info(50 * '#X') 361 | logging.info('STARTING `analyze_run`!') 362 | analyze_run(wandb.run.id) 363 | logging.info(50 * '#X') 364 | logging.info('FINISHED `analyze_run`!') 365 | 366 | 367 | if __name__ == '__main__': 368 | parser = utils.get_parser(stages=['compute']) 369 | args, unknown = parser.parse_known_args() # pylint: disable=invalid-name 370 | if unknown: 371 | raise ValueError(f'Unkown args: {unknown}') 372 | 373 | logging.info("Args: %s", args) 374 | 375 | main(args) 376 | --------------------------------------------------------------------------------