├── requirements.txt ├── figures └── teaser_llm.png ├── LICENSE ├── README.md ├── util_clm.py ├── adaptive-retrieval.ipynb └── run_model.py /requirements.txt: -------------------------------------------------------------------------------- 1 | openai 2 | transformers 3 | torch 4 | -------------------------------------------------------------------------------- /figures/teaser_llm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexTMallen/adaptive-retrieval/HEAD/figures/teaser_llm.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Alex Mallen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adaptive Retrieval: Popularity-based LM Augmentation 2 | 3 | This is the official repository for our preprint: Alex Mallen, Akari Asai, Victor Zhong, Rajarshi Das, Hannaneh Hajishirzi, and Daniel Khashabi. [When Not to Trust Language Models: Investigating Effectiveness and Limitations of Parametric and Non-Parametric Memories](https://arxiv.org/abs/2212.10511). 2022. 4 | 5 | In this work, we conduct a large-scale knowledge probing of 10 language models (GPT-Neo series, OPT series and GPT-3 series) and 4 retrieval-augmentation approaches (BM25, Contriever, GenRead and vanilla), using our new open-domain QA dataset, **PopQA**. 6 | 7 |

8 | ![Adaptive Retrieval teaser figure](figures/teaser_llm.png) 9 |

10 | 11 | We further introduce a simple yet effective method, **Adaptive Retrieval**, which retrieves and incorporates non-parametric memories only when necessary. Our experimental results show that Adaptive Retrieval improves performance while also being more efficient in terms of inference-time latency and GPT-3 API cost. 12 | 13 | ### Contact and Citations 14 | For any questions about the paper or the code, please contact the first authors or open an issue. 15 | If you find our code or paper useful, please cite the paper: 16 | ``` 17 | @article{mallen2023llm_memorization, 18 | title={When Not to Trust Language Models: Investigating Effectiveness and Limitations of Parametric and Non-Parametric Memories}, 19 | author={Mallen, Alex and Asai, Akari and Zhong, Victor and Das, Rajarshi and Hajishirzi, Hannaneh and Khashabi, Daniel}, 20 | journal={arXiv preprint arXiv:2212.10511}, 21 | year={2022} 22 | } 23 | ``` 24 | ## Content 25 | 26 | 1. [Installation](#installation) 27 | 2. [PopQA](#popqa) 28 | 3. [Baselines](#baselines) 29 | - [LMs](#lms) 30 | - [Retrieval-augmented LMs](#retrieval-augmented-lms) 31 | - [Adaptive Retrieval](#adaptive-retrieval) 32 | 33 | ## Installation 34 | ``` 35 | pip install -r requirements.txt 36 | ``` 37 | 38 | ## PopQA 39 | We construct an entity-centric open-domain QA dataset consisting of 14k QA pairs, each annotated with a fine-grained Wikidata entity ID, Wikipedia page views, and the relationship type. An example record: 40 | 41 | ``` 42 | {'id': 4222362, 'subj': 'George Rankin', 'prop': 'occupation', 'obj': 'politician', 'subj_id': 1850297, 'prop_id': 22, 'obj_id': 2834605, 's_aliases': '["George James Rankin"]', 'o_aliases': '["political leader","political figure","polit.","pol"]', 's_uri': 'http://www.wikidata.org/entity/Q5543720', 'o_uri': 'http://www.wikidata.org/entity/Q82955', 's_wiki_title': 'George Rankin', 'o_wiki_title': 'Politician', 's_pop': 142, 'o_pop': 25692, 'question': "What is George Rankin's occupation?", 'possible_answers': '["politician", "political leader", "political figure", "polit.", "pol"]'} 43 | ``` 44 | 45 | The data is available at [data/popQA.tsv](data/popQA.tsv). 46 | 47 | PopQA is also available on Hugging Face Datasets: [akariasai/PopQA](https://huggingface.co/datasets/akariasai/PopQA) 48 | ```python 49 | import datasets 50 | popqa = datasets.load_dataset("akariasai/PopQA")["test"] 51 | ``` 52 | ## Baselines 53 | 54 | ### LMs 55 | You can reproduce our zero-shot prompting experiments by running the command below: 56 | ```bash 57 | python run_model.py \ 58 | --model_name MODEL_NAME \ 59 | --input_file data/popQA.tsv \ 60 | --eval_method vanilla 61 | ``` 62 | We use [int8 quantization](https://arxiv.org/abs/2208.07339) (the `--int8bit` flag) to run GPT-NeoX-20B and OPT-13B in our environment (a single V100 GPU with 32 GB of memory). 63 | 64 | ```sh 65 | python run_model.py \ 66 | --model_name EleutherAI/gpt-neox-20b \ 67 | --input_file data/popQA.tsv \ 68 | --eval_method vanilla \ 69 | --int8bit 70 | ``` 71 | 72 | ### Retrieval-augmented LMs 73 | To run retrieval-augmented LMs using BM25 or [Contriever](https://github.com/facebookresearch/contriever), please download the retrieval results [here](https://drive.google.com/drive/folders/1ggeoHbSPobbGOljOlwl_d16yssSygYqy?usp=sharing).
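Each line of the downloaded retrieval file is a JSON object keyed by the question, with a ranked `ctxs` list of retrieved passages. A minimal sketch of how this file is consumed (the loading logic mirrors `run_model.py`; the path below is the same placeholder used in the command that follows):

```python
import json

# Build the question -> retrieval mapping exactly as run_model.py does.
with open("PATH_TO_RETRIEVED_DOCUMENTS.jsonl") as f:
    retrieval_dict = {json.loads(s)["question"]: json.loads(s) for s in f}

# Each value carries a "ctxs" list; only the top-ranked passage is used,
# via its "text", "id", and "hasanswer" fields.
example = next(iter(retrieval_dict.values()))
top_passage = example["ctxs"][0]
print(top_passage["id"], top_passage["hasanswer"])
print(top_passage["text"][:200])
```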
74 | 75 | Then, you can run the retrieval-augmented baselines as follows (set `--eval_method` to `BM25` or `contriever`): 76 | ```sh 77 | python run_model.py \ 78 | --model_name MODEL_NAME \ 79 | --input_file data/popQA.tsv \ 80 | --eval_method contriever \ 81 | --ret_path PATH_TO_RETRIEVED_DOCUMENTS.jsonl 82 | ``` 83 | To run GenRead, you do not need to specify a retrieval file path: 84 | ```sh 85 | python run_model.py \ 86 | --model_name MODEL_NAME \ 87 | --input_file data/popQA.tsv \ 88 | --eval_method genread 89 | ``` 90 | 91 | ### Adaptive Retrieval 92 | See the `adaptive-retrieval.ipynb` notebook: point it at two result files produced by `run_model.py`, one from a parametric method (vanilla or GenRead) and one from a non-parametric method (BM25 or Contriever), and it will compute the Adaptive Retrieval results. 93 | -------------------------------------------------------------------------------- /util_clm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def assert_all_approx_close(a, b, rtol, atol, count): 4 | 5 | idx = torch.isclose(a.float(), b.float(), rtol, atol) 6 | sumval = (idx==0).sum().item() 7 | if sumval > count: 8 | print(f'Too many values not close: assert {sumval} < {count}') 9 | try: 10 | torch.testing.assert_allclose(a, b, rtol, atol) 11 | except Exception as e: 12 | print(e) 13 | 14 | 15 | def get_memory_footprint(model, return_buffers=True): 16 | """ 17 | Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. 18 | Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired from the 19 | PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2 20 | Arguments: 21 | return_buffers (`bool`, *optional*, defaults to `True`): 22 | Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers 23 | are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch 24 | norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2 25 | """ 26 | mem = sum([param.nelement() * param.element_size() for param in model.parameters()]) 27 | if return_buffers: 28 | mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()]) 29 | mem = mem + mem_bufs 30 | return mem 31 | 32 | 33 | def ـreplace_linear_with_int8linear(model, modules_to_not_convert="lm_head"): 34 | for name, module in model.named_children(): 35 | ـreplace_linear_with_int8linear(module, modules_to_not_convert) 36 | 37 | if isinstance(module, torch.nn.Linear) and name != modules_to_not_convert: 38 | model._modules[name] = QuantizedLinearInt8(linear_layer=module) 39 | return 40 | 41 | 42 | class QuantizedLinearInt8(torch.nn.Module): 43 | ''' 44 | A simple but effective implementation of Int8 quantization for linear layers. 45 | The weights are quantized and stored as Int8, which saves ~50% of the gpu memory. 46 | During the forward pass, the weights are de-quantized back to fp16 to do the multiplication. 47 | Pros: 48 | - saves ~50% of the gpu memory 49 | - accurate quantization because only the weights are quantized, and the weights don't suffer 50 | from the "outliers" issue mentioned in the LLM.int8 paper; only the activations do. 51 | - high precision results because the multiplication is done in fp16 52 | - much faster than LLM.int8 53 | Cons: 54 | - a bit slower because of the added computation of dequantization in each forward pass.
In practice, the slowdown 55 | is not large because in the generation application, gpu utilization is not very high. 56 | ''' 57 | def __init__(self, linear_layer): 58 | super().__init__() 59 | self.bias = linear_layer.bias 60 | 61 | weight_bit_width = 8 62 | weight = linear_layer.weight 63 | 64 | self.weight_scale = torch.nn.Parameter( 65 | (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half(), 66 | ) 67 | # print(self.weight_scale.max().item(), self.weight_scale.min().item(), self.weight_scale.mean().item()) 68 | # if self.weight_scale.max().item() > 0.002: 69 | # print(self.weight_scale.max().item()) 70 | self.weight = torch.nn.Parameter( 71 | torch.round(weight.float() / self.weight_scale[:, None]).char(), 72 | requires_grad=False 73 | ) 74 | 75 | def forward(self, x): 76 | weight = self.weight.half() * self.weight_scale[:, None] 77 | return torch.nn.functional.linear(x, weight, self.bias) 78 | 79 | 80 | def convert_model_to_int8_on_gpu(model, device): 81 | """ 82 | Quantize a model to int8 and move it to GPU using a simple method. 83 | """ 84 | if 'cuda' not in device: 85 | raise ValueError(f"Target device should be a gpu. Device {device} is not supported") 86 | 87 | model.half() 88 | 89 | memory_before_quantization = get_memory_footprint(model) # without lm_head 90 | 91 | ـreplace_linear_with_int8linear(model) # replace `Linear` with `QuantizedLinearInt8` 92 | 93 | model.to(device=device) 94 | memory_after_quantization = get_memory_footprint(model) # without lm_head 95 | 96 | saving = round(100 * memory_after_quantization/memory_before_quantization) 97 | memory_before_quantization = round(memory_before_quantization / 2**30, 2) # rounding for printing 98 | memory_after_quantization = round(memory_after_quantization / 2**30, 2) # rounding for printing 99 | 100 | print(f'Quantization memory - before: {memory_before_quantization} GB, after: {memory_after_quantization} GB ({saving}% of the size before)') 101 | return model 102 | -------------------------------------------------------------------------------- /adaptive-retrieval.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# Using adaptive retrieval on the synthetic dataset" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "metadata": {}, 15 | "outputs": [], 16 | "source": [ 17 | "import pandas as pd\n", 18 | "import numpy as np\n", 19 | "import matplotlib.pyplot as plt\n", 20 | "import seaborn as sns\n", 21 | "from statsmodels.stats.proportion import proportion_confint\n", 22 | "from scipy.stats import pearsonr\n", 23 | "import os\n", 24 | "\n", 25 | "seed = 633\n", 26 | "np.random.seed(seed)\n", 27 | "import random\n", 28 | "random.seed(seed)" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 16, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "# take the results of parametric and nonparametric-augmented systems \n", 38 | "# and compute how well adaptive retrieval would perform\n", 39 | "do_plot = False\n", 40 | "n_boot = 100\n", 41 | "parametric_path = \"\" # ADD PATH TO VANILLA RESULTS\n", 42 | "nonparametric_path = \"\" # ADD PATH TO RETRIEVAL-AUGMENTED RESULTS\n", 43 | "def clean(df):\n", 44 | " return df[~df[\"s_pop\"].isna() & (df[\"s_pop\"] >= 0)]\n", 45 | "sample = clean(pd.read_csv(parametric_path))\n", 46 | "sample_ret = clean(pd.read_csv(nonparametric_path))\n", 47 | "sample = 
sample.sort_values(\"question\").reset_index(drop=True)\n", 48 | "sample_ret = sample_ret.sort_values(\"question\").reset_index(drop=True)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 17, 54 | "metadata": {}, 55 | "outputs": [ 56 | { 57 | "name": "stdout", 58 | "output_type": "stream", 59 | "text": [ 60 | "prop adaptive: 1.0\n", 61 | "prop retrieval: 0.7981889543033361\n", 62 | "parametric knowledge only: 0.17163442668909443\n", 63 | "retrieval augmented accuracy: 0.2537706756377908\n", 64 | "hybrid accuracy: 0.2706980656013457 pm 0.000232407510793975\n", 65 | "hybrid accuracy on train: 0.2719317757009346\n", 66 | "overall accuracy gain: 0.016927389963554862\n", 67 | "overall accuracy: 0.2706980656013457\n" 68 | ] 69 | } 70 | ], 71 | "source": [ 72 | "# suppress warnings\n", 73 | "import warnings\n", 74 | "warnings.filterwarnings(\"ignore\")\n", 75 | "props = sample.prop.unique()\n", 76 | "\n", 77 | "test_ret_accs_all = []\n", 78 | "test_param_accs_all = []\n", 79 | "test_hybrid_accs_all = []\n", 80 | "test_count_rets = []\n", 81 | "test_count_params = []\n", 82 | "train_hybrid_accs_all = []\n", 83 | "\n", 84 | "for boot in range(1 if do_plot else n_boot):\n", 85 | " split_proportion = 0.75\n", 86 | " split_mask = [\"train\"] * round(len(sample) * split_proportion) + [\"test\"] * round(len(sample) * (1 - split_proportion))\n", 87 | " np.random.shuffle(split_mask)\n", 88 | " sample[\"split\"] = split_mask\n", 89 | " sample_ret[\"split\"] = split_mask\n", 90 | "\n", 91 | " test_ret_accs = []\n", 92 | " test_param_accs = []\n", 93 | " test_hybrid_accs = []\n", 94 | " train_hybrid_accs = []\n", 95 | " test_count_rets.append(0)\n", 96 | " test_count_params.append(0)\n", 97 | " test_sizes = []\n", 98 | " train_sizes = []\n", 99 | " train_threshs = dict()\n", 100 | " plot_title = \"\"\n", 101 | " if do_plot:\n", 102 | " plt.figure(dpi=200, figsize=(30, 30))\n", 103 | " for i, prop in enumerate(props):\n", 104 | "\n", 105 | " if do_plot:\n", 106 | " plt.subplot(4, 4, i+1)\n", 107 | " cluster_sample = sample[sample.prop == prop].copy()\n", 108 | " cluster_sample_ret_with_pop = sample_ret[sample_ret.prop == prop].copy()\n", 109 | "\n", 110 | " log_pop = np.log(cluster_sample[\"s_pop\"].values)\n", 111 | " cluster_sample[\"log_pop\"] = log_pop\n", 112 | " cluster_sample_ret_with_pop[\"log_pop\"] = log_pop\n", 113 | " ser = log_pop\n", 114 | " _, bin_edges = np.histogram(ser)\n", 115 | "\n", 116 | " c = cluster_sample.is_correct.values\n", 117 | " counts_c, _ = np.histogram(ser[c], bins=bin_edges)\n", 118 | " counts_inc, _ = np.histogram(ser[~c], bins=bin_edges)\n", 119 | " total = counts_c + counts_inc\n", 120 | "\n", 121 | " c_ret = cluster_sample_ret_with_pop.is_correct.values\n", 122 | " counts_c_ret, _ = np.histogram(ser[c_ret], bins=bin_edges)\n", 123 | " counts_inc_ret, _ = np.histogram(ser[~c_ret], bins=bin_edges)\n", 124 | " total_ret = counts_c_ret + counts_inc_ret\n", 125 | "\n", 126 | " width = 0.4*(bin_edges[1] - bin_edges[0])\n", 127 | " thresh_idx = np.argmax(list((sum(counts_c_ret[:i]) + sum(counts_c[i:])) / sum(total_ret) for i in range(len(total_ret) + 1)))\n", 128 | "\n", 129 | " # plt.bar(bin_edges[:-1] - 0.5 * width, counts_c / total, width=width, alpha=0.9, label=\"parametric knowledge (vanilla)\", align='edge')\n", 130 | " # plt.bar(bin_edges[:-1] + 0.5 * width, counts_c_ret / total_ret, width=width, alpha=0.9, label=\"retrieval augmented (BM25)\", align='edge')\n", 131 | " # lo, hi = proportion_confint(counts_c, total, alpha=0.05, 
method='wilson')\n", 132 | " # plt.errorbar(bin_edges[:-1], counts_c / total, yerr=[counts_c / total - lo, hi - counts_c / total], fmt='none', ecolor='black', elinewidth=1, capsize=2)\n", 133 | " # lo, hi = proportion_confint(counts_c_ret, total_ret, alpha=0.05, method='wilson')\n", 134 | " # plt.errorbar(bin_edges[:-1] + width, counts_c_ret / total_ret, yerr=[counts_c_ret / total_ret - lo, hi - counts_c_ret / total_ret], fmt='none', ecolor='black', elinewidth=1, capsize=2)\n", 135 | " # # plt.errorbar(bin_edges[:-1], counts_c / total, , fmt='none', ecolor='black', capsize=2)\n", 136 | " # # plt.errorbar(bin_edges[:-1], counts_c_ret / total_ret, yerr=wilson(counts_c_ret / total_ret, total_ret), fmt='none', ecolor='black', capsize=2)\n", 137 | " # plt.axvline(x=bin_edges[thresh_idx] + 0.7 * (bin_edges[1] - bin_edges[0]), color='red', linestyle='--', label=\"threshold\")\n", 138 | "\n", 139 | " param_acc = sum(counts_c) / sum(total)\n", 140 | " ret_acc = sum(counts_c_ret) / sum(total_ret)\n", 141 | " hybrid_acc = (sum(counts_c_ret[:thresh_idx]) + sum(counts_c[thresh_idx:])) / (sum(total_ret[:thresh_idx]) + sum(total[thresh_idx:]))\n", 142 | " ret_acc_gain = hybrid_acc - ret_acc\n", 143 | " param_acc_gain = hybrid_acc - param_acc\n", 144 | " \n", 145 | " # let the optimal threshold be the one that maximizes the hybrid accuracy\n", 146 | " train_idxs = cluster_sample.split.values == \"train\"\n", 147 | " test_idxs = cluster_sample.split.values == \"test\"\n", 148 | " train_ser = ser[train_idxs]\n", 149 | " test_ser = ser[test_idxs]\n", 150 | " train_c = c[train_idxs]\n", 151 | " test_c = c[test_idxs]\n", 152 | " train_c_ret = c_ret[train_idxs]\n", 153 | " test_c_ret = c_ret[test_idxs]\n", 154 | " train_counts_c, _ = np.histogram(train_ser[train_c], bins=bin_edges)\n", 155 | " train_counts_inc, _ = np.histogram(train_ser[~train_c], bins=bin_edges)\n", 156 | " train_total = train_counts_c + train_counts_inc\n", 157 | " train_counts_c_ret, _ = np.histogram(train_ser[train_c_ret], bins=bin_edges)\n", 158 | " train_counts_inc_ret, _ = np.histogram(train_ser[~train_c_ret], bins=bin_edges)\n", 159 | " train_total_ret = train_counts_c_ret + train_counts_inc_ret\n", 160 | " test_counts_c, _ = np.histogram(test_ser[test_c], bins=bin_edges)\n", 161 | " test_counts_inc, _ = np.histogram(test_ser[~test_c], bins=bin_edges)\n", 162 | " test_total = test_counts_c + test_counts_inc\n", 163 | " test_counts_c_ret, _ = np.histogram(test_ser[test_c_ret], bins=bin_edges)\n", 164 | " test_counts_inc_ret, _ = np.histogram(test_ser[~test_c_ret], bins=bin_edges)\n", 165 | " test_total_ret = test_counts_c_ret + test_counts_inc_ret\n", 166 | " \n", 167 | " # find the optimal threshold\n", 168 | " train_thresh_idx = np.argmax(list((sum(train_counts_c_ret[:i]) + sum(train_counts_c[i:])) / (sum(train_total_ret[:i]) + sum(train_total[i:])) for i in range(len(train_total) + 1)))\n", 169 | " train_thresh = bin_edges[train_thresh_idx] - 0.5 * (bin_edges[1] - bin_edges[0])\n", 170 | " train_threshs[prop] = train_thresh\n", 171 | "\n", 172 | " # calculate the accuracy on the test set\n", 173 | " test_param_acc = sum(test_counts_c) / sum(test_total)\n", 174 | " test_ret_acc = sum(test_counts_c_ret) / sum(test_total_ret)\n", 175 | " test_hybrid_acc = (sum(test_counts_c_ret[:train_thresh_idx]) + sum(test_counts_c[train_thresh_idx:])) / (sum(test_total_ret[:train_thresh_idx]) + sum(test_total[train_thresh_idx:]))\n", 176 | " test_sizes.append(sum(test_total))\n", 177 | " test_count_rets[-1] += 
sum(test_total_ret[:train_thresh_idx])\n", 178 | " test_count_params[-1] += sum(test_total[train_thresh_idx:])\n", 179 | " test_ret_accs.append(test_ret_acc)\n", 180 | " test_param_accs.append(test_param_acc)\n", 181 | " test_hybrid_accs.append(test_hybrid_acc)\n", 182 | " train_hybrid_accs.append((sum(train_counts_c_ret[:train_thresh_idx]) + sum(train_counts_c[train_thresh_idx:])) / (sum(train_total_ret[:train_thresh_idx]) + sum(train_total[train_thresh_idx:])))\n", 183 | " train_sizes.append(sum(train_total))\n", 184 | "\n", 185 | " if do_plot:\n", 186 | " plt.bar(bin_edges[:-1] - 0.5 * width, test_counts_c / test_total, width=width, alpha=0.9, label=\"parametric knowledge (vanilla)\", align='edge', hatch='//')\n", 187 | " plt.bar(bin_edges[:-1] + 0.5 * width, test_counts_c_ret / test_total_ret, width=width, alpha=0.9, label=\"retrieval augmented (BM25)\", align='edge', hatch='//')\n", 188 | " lo, hi = proportion_confint(test_counts_c, test_total, alpha=0.05, method='wilson')\n", 189 | " plt.errorbar(bin_edges[:-1], test_counts_c / test_total, yerr=[test_counts_c / test_total - lo, hi - test_counts_c / test_total], fmt='none', ecolor='black', elinewidth=1, capsize=2)\n", 190 | " lo, hi = proportion_confint(test_counts_c_ret, test_total_ret, alpha=0.05, method='wilson')\n", 191 | " plt.errorbar(bin_edges[:-1] + width, test_counts_c_ret / test_total_ret, yerr=[test_counts_c_ret / test_total_ret - lo, hi - test_counts_c_ret / test_total_ret], fmt='none', ecolor='black', elinewidth=1, capsize=2)\n", 192 | " plt.axvline(x=bin_edges[train_thresh_idx] - 0.8 * width, color='red', linestyle='--', label=\"threshold\")\n", 193 | "\n", 194 | " print(f\"Threshold for {prop}:\", train_thresh)\n", 195 | " print(\"Parametric knowledge only:\", param_acc)\n", 196 | " print(\"Retrieval augmented accuracy:\", ret_acc)\n", 197 | " print(f\"New accuracy with thresh={thresh_idx}:\", hybrid_acc)\n", 198 | " print()\n", 199 | " plt.title(f\"{prop}\")\n", 200 | " plt.ylim([0,1.01])\n", 201 | " if do_plot:\n", 202 | " plt.xlabel(\"log(s_pop)\")\n", 203 | " plt.ylabel(\"proportion correct\")\n", 204 | " plt.legend()\n", 205 | " plt.tight_layout()\n", 206 | " plt.show()\n", 207 | "\n", 208 | " # take the weighted mean by test_size\n", 209 | " test_ret_accs_all.append(np.average(test_ret_accs, weights=test_sizes))\n", 210 | " test_param_accs_all.append(np.average(test_param_accs, weights=test_sizes))\n", 211 | " test_hybrid_accs_all.append(np.average(test_hybrid_accs, weights=test_sizes))\n", 212 | " train_hybrid_accs_all.append(np.average(train_hybrid_accs, weights=train_sizes))\n", 213 | "\n", 214 | "test_size = split_mask.count(\"test\")\n", 215 | "prop_adaptive = (np.mean(test_count_rets) + np.mean(test_count_params)) / test_size\n", 216 | "param_acc = np.mean(test_param_accs_all)\n", 217 | "ret_acc = np.mean(test_ret_accs_all)\n", 218 | "hybrid_acc = np.mean(test_hybrid_accs_all)\n", 219 | "train_hybrid_acc = np.mean(train_hybrid_accs_all)\n", 220 | "overall_hybrid_acc = np.mean(test_hybrid_accs_all) * prop_adaptive + max(np.mean(test_param_accs_all), np.mean(test_ret_accs_all)) * (1 - prop_adaptive)\n", 221 | "sem_test_hybrid_acc = 2 * np.std(test_hybrid_accs_all) / np.sqrt(test_size)\n", 222 | "print(\"prop adaptive:\", prop_adaptive)\n", 223 | "print(\"prop retrieval:\", np.mean(test_count_rets) / test_size)\n", 224 | "print(\"parametric knowledge only:\", param_acc)\n", 225 | "print(\"retrieval augmented accuracy:\", ret_acc)\n", 226 | "print(\"hybrid accuracy:\", hybrid_acc, \"pm\", 
sem_test_hybrid_acc)\n", 227 | "print(\"hybrid accuracy on train:\", train_hybrid_acc)\n", 228 | "print(\"overall accuracy gain:\", overall_hybrid_acc - max(param_acc, ret_acc))\n", 229 | "print(\"overall accuracy:\", overall_hybrid_acc)" 230 | ] 231 | } 232 | ], 233 | "metadata": { 234 | "kernelspec": { 235 | "display_name": "Python 3", 236 | "language": "python", 237 | "name": "python3" 238 | }, 239 | "language_info": { 240 | "codemirror_mode": { 241 | "name": "ipython", 242 | "version": 3 243 | }, 244 | "file_extension": ".py", 245 | "mimetype": "text/x-python", 246 | "name": "python", 247 | "nbconvert_exporter": "python", 248 | "pygments_lexer": "ipython3", 249 | "version": "3.10.3" 250 | }, 251 | "orig_nbformat": 4, 252 | "vscode": { 253 | "interpreter": { 254 | "hash": "949777d72b0d2535278d3dc13498b2535136f6dfe0678499012e853ee9abcab1" 255 | } 256 | } 257 | }, 258 | "nbformat": 4, 259 | "nbformat_minor": 2 260 | } 261 | -------------------------------------------------------------------------------- /run_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: UTF-8 -*- 3 | 4 | import random 5 | import torch 6 | import os 7 | import numpy as np 8 | import pandas as pd 9 | import matplotlib.pyplot as plt 10 | import seaborn as sns 11 | import openai 12 | from tqdm import tqdm 13 | import json 14 | import argparse 15 | sns.set_theme() 16 | 17 | 18 | seed = 633 19 | 20 | torch.backends.cudnn.deterministic = True 21 | random.seed(seed) 22 | np.random.seed(seed) 23 | torch.manual_seed(seed) 24 | torch.cuda.manual_seed_all(seed) 25 | print('Cuda:', torch.cuda.is_available()) 26 | print('pwd', os.getcwd()) 27 | 28 | from transformers import AutoTokenizer, AutoModelForCausalLM 29 | from util_clm import convert_model_to_int8_on_gpu 30 | 31 | 32 | 33 | import jsonlines 34 | 35 | def load_jsonlines(file): 36 | with jsonlines.open(file, 'r') as jsonl_f: 37 | lst = [obj for obj in jsonl_f] 38 | return lst 39 | 40 | q_templates = { 41 | 22: "What is {}'s occupation?", 42 | 218: "In what city was {} born?", 43 | 91: "What genre is {}?", 44 | 257: "Who is the father of {}?", 45 | 182: "In what country is {}?", 46 | 164: "Who was the producer of {}?", 47 | 526: "Who was the director of {}?", 48 | 97: "What is {} the capital of?", 49 | 533: "Who was the screenwriter for {}?", 50 | 639: "Who was the composer of {}?", 51 | 472: "What color is {}?", 52 | 106: "What is the religion of {}?", 53 | 560: "What sport does {} play?", 54 | 484: "Who is the author of {}?", 55 | 292: "Who is the mother of {}?", 56 | 422: "What is the capital of {}?" 57 | } 58 | completion_template = "Q: {} A:" # "{}" # "Query: {}\nResult:" # "Q: {} A:" # "{} The answer is" 59 | genread_template = "Generate a background document from Wikipedia to answer the given question. 
{}" # This prompt comes from the GenRead paper 60 | 61 | def call_request(prompt, model, tokenizer, max_new_tokens=15): 62 | max_inpt_tokens = tokenizer.model_max_length 63 | if len(prompt) > tokenizer.model_max_length: # conservative lower bound, since each token is at least 1 character 64 | inpts = tokenizer(prompt, return_tensors="pt") 65 | new_prompt = tokenizer.decode(inpts.input_ids[0, -(max_inpt_tokens - max_new_tokens):]) 66 | else: 67 | new_prompt = prompt 68 | 69 | # try to get a response from the model multiple times if theres a timeout 70 | for i in range(5): 71 | try: 72 | if i > 0: 73 | print("Retrying request") 74 | response = openai.Completion.create(model=model, prompt=new_prompt, temperature=0.0, max_tokens=max_new_tokens, logprobs=5, top_p=1,frequency_penalty=0.0,presence_penalty=0.0) 75 | break 76 | except Exception as e: 77 | print(e) 78 | print("Timeout, trying again") 79 | 80 | pred = response["choices"][0]["text"] 81 | if pred.startswith("\n\n"): 82 | pred = pred[2:] 83 | pred = pred.split("\n")[0] 84 | return pred, response.to_dict_recursive() 85 | 86 | def call_model(prompt, model, tokenizer, device, max_new_tokens=15, model_max_length=None): 87 | max_inpt_tokens = tokenizer.model_max_length if model_max_length is None else model_max_length 88 | inpts = tokenizer(prompt, return_tensors="pt").to(device) 89 | gen = model.generate(input_ids=inpts.input_ids[:, -(max_inpt_tokens - max_new_tokens):], attention_mask=inpts.attention_mask[:, -(max_inpt_tokens - max_new_tokens):], pad_token_id=tokenizer.eos_token_id, max_new_tokens=max_new_tokens, num_beams=1, do_sample=False) 90 | text = tokenizer.decode(gen[0]) 91 | actual_prompt = tokenizer.decode(inpts.input_ids[0, -(max_inpt_tokens - max_new_tokens):]) 92 | pred = text[len(actual_prompt):] 93 | if pred.startswith("\n\n"): 94 | pred = pred[2:] 95 | pred = pred.split("\n")[0] 96 | return pred, text 97 | 98 | def clip_paragraph(text, eval_method): 99 | if eval_method in ["BM25", "genread"]: 100 | return text 101 | split = text.split(". ") 102 | return ". ".join(split[:-1]) + "." 
103 | 104 | def get_few_shot_text_with_retrieval(row, retrieval_dict, eval_method): 105 | if eval_method == "vanilla": 106 | return completion_template.format(row.question) + " " + row.obj 107 | # retrieval_dict[row.id]["ctxs"][0] 108 | if row.question.replace("?", "").lower() not in retrieval_dict: 109 | print("missing retrieval") 110 | return completion_template.format(row.question) + " " + row.obj 111 | else: 112 | retrieval = retrieval_dict[row.question.replace("?", "").lower()]["ctxs"][0] 113 | retrieved_text = clip_paragraph(retrieval["text"], eval_method) 114 | return retrieved_text + "\n\n" + completion_template.format(row.question) + " " + row.obj 115 | 116 | def get_few_shot_text(row, eval_method): 117 | return completion_template.format(row.question) + " " + row.obj 118 | 119 | def get_genread_passage(question, genread_template, generate_function, max_new_tokens=150): 120 | prompt = genread_template.format(question) 121 | return generate_function(prompt, max_new_tokens=max_new_tokens)[0] 122 | 123 | def get_few_shot_examples_genread(knowledge, generate_function, n_examples, genread_template, is_templatedQA, max_new_tokens=150): 124 | if is_templatedQA: 125 | few_shot_examples = dict() 126 | all_pids = list(q_templates.keys()) 127 | examples_per_template = n_examples // (len(q_templates) - 1) 128 | for pid in all_pids: 129 | for row2 in knowledge[knowledge.prop_id == pid].sample(n=examples_per_template).iloc: 130 | if pid not in few_shot_examples: 131 | few_shot_examples[pid] = [] 132 | generation = get_genread_passage(row2.question, genread_template, generate_function, max_new_tokens=max_new_tokens) 133 | few_shot_examples[pid].append(get_few_shot_text_with_retrieval(row2, {row2.question: {"ctxs": [{"id": -1, "text": generation}]}}, "genread")) 134 | else: 135 | few_shot_examples = [] 136 | for row2 in knowledge.sample(n=n_examples + 1).iloc: 137 | generation = get_genread_passage(row2.question, genread_template, generate_function, max_new_tokens=max_new_tokens) 138 | few_shot_examples.append(get_few_shot_text_with_retrieval(row2, {row2.question: {"ctxs": [{"id": -1, "text": generation}]}}, "genread")) 139 | 140 | return few_shot_examples 141 | 142 | 143 | def main(): 144 | parser = argparse.ArgumentParser() 145 | parser.add_argument('--model_name', type=str) 146 | parser.add_argument('--input_file', type=str) 147 | parser.add_argument('--alias', type=str) 148 | parser.add_argument('--n_examples', type=int, default=15) 149 | parser.add_argument('--eval_method', type=str, default="vanilla", choices=["vanilla", "BM25", "contriever", "genread"]) 150 | parser.add_argument('--ret_path', type=str, default=None, required=False, help="path to retrieved documents jsonl") 151 | parser.add_argument('--device', type=str, default="cuda") 152 | parser.add_argument('--max_new_tokens', type=int, default=15) 153 | parser.add_argument('--sample', type=int, default=0, help="if 0, use all examples") 154 | parser.add_argument('--continue_from', type=str, help="path to previous results file") 155 | parser.add_argument('--int8bit', action="store_true") 156 | parser.add_argument('--parallel', type=str, help="string of format 'i.n_workers' where i is the index of the worker") 157 | 158 | args = parser.parse_args() 159 | 160 | use_gpt3 = args.model_name in {"text-davinci-003", "text-davinci-002", "text-curie-001", "text-babbage-001", "text-ada-001"} 161 | if use_gpt3: 162 | with open("../../openAIkey.txt") as f: 163 | openai.api_key = f.read()[:-1] 164 | tokenizer = AutoTokenizer.from_pretrained("gpt2") 
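# The local GPT-2 tokenizer above is used only to estimate prompt length and truncate over-long
# prompts before calling the OpenAI API; the GPT-3 models use a very similar byte-pair encoding,
# so this gives a close (but not exact) bound on the prompt's token count.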
165 | generate = lambda prompt, max_new_tokens: call_request(prompt, args.model_name, tokenizer, max_new_tokens=max_new_tokens) 166 | else: 167 | gpt = args.model_name 168 | device = args.device 169 | tokenizer = AutoTokenizer.from_pretrained(gpt) 170 | tokenizer.pad_token = tokenizer.eos_token 171 | tokenizer.pad_token_id = tokenizer.eos_token_id 172 | if args.int8bit: 173 | model = convert_model_to_int8_on_gpu(AutoModelForCausalLM.from_pretrained(gpt), device) 174 | else: 175 | model = AutoModelForCausalLM.from_pretrained(gpt).eval().to(device) 176 | if "opt" in args.model_name or args.model_name == "EleutherAI/gpt-neox-20b": 177 | generate = lambda prompt, max_new_tokens: call_model(prompt, model=model, tokenizer=tokenizer, device=device, max_new_tokens=max_new_tokens, model_max_length=2048) 178 | else: 179 | generate = lambda prompt, max_new_tokens: call_model(prompt, model=model, tokenizer=tokenizer, device=device, max_new_tokens=max_new_tokens) 180 | input_path = args.input_file 181 | knowledge = pd.read_csv(input_path, sep="\t") 182 | 183 | if args.continue_from is not None: 184 | results = pd.read_csv(args.continue_from, sep="\t") 185 | knowledge = knowledge[~knowledge.id.isin(results.id)] 186 | n = len(knowledge) if args.sample == 0 else args.sample 187 | sample = knowledge.sample(n=n, replace=False) 188 | if args.parallel is not None: 189 | worker_num, n_workers = map(int, args.parallel.split(".")) 190 | sample = sample.iloc[worker_num::n_workers] 191 | 192 | n_examples = args.n_examples 193 | is_templatedQA = True 194 | examples_per_template = n_examples // (len(q_templates) - 1) 195 | 196 | preds = [] 197 | prompts =[] 198 | accuracy = [] 199 | responses = [] 200 | if args.eval_method in ["BM25", "contriever"]: 201 | has_answer = [] 202 | retrieval_ids = [] 203 | with open(args.ret_path) as f: 204 | retrieval_dict = {json.loads(s)["question"]: json.loads(s) for s in f.readlines()} 205 | # print(retrieval_dict) 206 | if args.eval_method == "genread": 207 | genread_few_shot_examples = get_few_shot_examples_genread(knowledge, generate, n_examples, genread_template, is_templatedQA, max_new_tokens=150) 208 | has_answer = [] 209 | gen_passages = [] 210 | 211 | # main loop 212 | for row in tqdm(sample.iloc, total=n): 213 | 214 | # get few shot examples text 215 | if n_examples == 0: 216 | few_shot_examples_text = "" 217 | else: 218 | few_shot_examples = [] 219 | if args.eval_method == "genread": 220 | if is_templatedQA: 221 | other_pids = list(q_templates.keys()) 222 | other_pids.remove(row.prop_id) 223 | few_shot_examples = [] 224 | for pid in other_pids: 225 | few_shot_examples.extend(random.sample(genread_few_shot_examples[pid], examples_per_template)) 226 | else: 227 | few_shot_examples = random.sample([ex for ex in genread_few_shot_examples if row.question not in ex], n_examples) 228 | else: 229 | if is_templatedQA: 230 | other_pids = list(q_templates.keys()) 231 | other_pids.remove(row.prop_id) 232 | for pid in other_pids: 233 | for row2 in knowledge[knowledge.prop_id == pid].sample(n=examples_per_template).iloc: 234 | few_shot_examples.append(get_few_shot_text_with_retrieval(row2, retrieval_dict, args.eval_method) if args.eval_method in ["BM25", "contriever"] else get_few_shot_text(row2, args.eval_method)) 235 | else: 236 | for row2 in knowledge[knowledge.question != row.question].sample(n=n_examples).iloc: 237 | few_shot_examples.append(get_few_shot_text_with_retrieval(row2, retrieval_dict, args.eval_method) if args.eval_method in ["BM25", "contriever"] else 
get_few_shot_text(row2, args.eval_method)) 238 | 239 | 240 | np.random.shuffle(few_shot_examples) 241 | few_shot_examples_text = "\n\n".join(few_shot_examples) + "\n\n" 242 | 243 | # get prompt 244 | if args.eval_method == "vanilla": 245 | prompt = few_shot_examples_text + completion_template.format(row.question) 246 | elif args.eval_method in ["BM25", "contriever"]: 247 | query = row.question 248 | try: 249 | retrieval = retrieval_dict[query]["ctxs"][0] # retrieval_dict[row.id]["ctxs"][0] 250 | except: 251 | 252 | print("No retrieval for", query, " Example query:", list(retrieval_dict.keys())[0]) 253 | retrieval = {"text": "", "id": np.nan, "hasanswer": False} 254 | retrieved_text = clip_paragraph(retrieval["text"], eval_method=args.eval_method) 255 | retrieval_id = retrieval["id"] 256 | prompt = few_shot_examples_text + retrieved_text + "\n\n" + completion_template.format(row.question) 257 | has_answer.append(retrieval["hasanswer"]) 258 | retrieval_ids.append(retrieval_id) 259 | elif args.eval_method == "genread": 260 | generation = get_genread_passage(row.question, genread_template, generate, max_new_tokens=150) 261 | prompt = few_shot_examples_text + generation + "\n\n" + completion_template.format(row.question) 262 | gen_passages.append(generation) 263 | 264 | # generate response 265 | pred, response = generate(prompt, max_new_tokens=args.max_new_tokens) 266 | prompts.append(prompt) 267 | preds.append(pred) 268 | responses.append(response) 269 | 270 | # compute accuracy 271 | possible_answers = json.loads(row.possible_answers) 272 | is_correct = False 273 | genread_has_answer = False 274 | for pa in possible_answers: 275 | if pa in pred or pa.lower() in pred or pa.capitalize() in pred: 276 | is_correct = True 277 | if args.eval_method == "genread" and pa in response or pa.lower() in response or pa.capitalize() in response: 278 | genread_has_answer = True 279 | accuracy.append(is_correct) 280 | if args.eval_method == "genread": 281 | has_answer.append(genread_has_answer) 282 | 283 | # save results intermittently 284 | if len(preds) % 100 == 0: 285 | temp_sample = sample.iloc[:len(preds)].copy() 286 | temp_sample["pred"] = preds 287 | temp_sample["prompt"] = prompts 288 | temp_sample["generation"] = responses 289 | temp_sample["is_correct"] = accuracy 290 | if args.eval_method in ["BM25", "contriever"]: 291 | temp_sample["has_answer"] = has_answer 292 | temp_sample["retrieval_id"] = retrieval_ids 293 | if args.eval_method == "genread": 294 | temp_sample["has_answer"] = has_answer 295 | temp_sample["gen_passage"] = gen_passages 296 | model_name_alias = args.model_name.replace("/","_") 297 | if not os.path.exists(f"results/temp/"): 298 | os.makedirs(f"results/temp/") 299 | worker_str = "" if args.parallel is None else f"-worker={args.parallel}" 300 | output_path = f"results/temp/model={model_name_alias}-input={args.alias}-method={args.eval_method}-shots={n_examples}-n={len(temp_sample)}{'_int8bit' if args.int8bit is True else ''}{worker_str}.csv" 301 | temp_sample.to_csv(output_path, index=False) 302 | 303 | sample["is_correct"] = accuracy 304 | sample["prompt"] = prompts 305 | sample["pred"] = preds 306 | sample["generation"] = responses 307 | if args.eval_method in ["BM25", "contriever"]: 308 | sample["has_answer"] = has_answer 309 | sample["retrieval_id"] = retrieval_ids 310 | if args.eval_method == "genread": 311 | sample["has_answer"] = has_answer 312 | sample["gen_passage"] = gen_passages 313 | 314 | print(sample.is_correct.mean()) 315 | model_name_alias = 
args.model_name.replace("/","_") 316 | worker_str = "" if args.parallel is None else f"-worker={args.parallel}" 317 | sample.to_csv(f"results/model={model_name_alias}-input={args.alias}-method={args.eval_method}-shots={n_examples}-n={len(sample)}{'_int8bit' if args.int8bit is True else ''}{worker_str}.csv") 318 | 319 | 320 | if __name__ == "__main__": 321 | main() 322 | --------------------------------------------------------------------------------