├── requirements.txt
├── figures
│   └── teaser_llm.png
├── LICENSE
├── README.md
├── util_clm.py
├── adaptive-retrieval.ipynb
└── run_model.py
/requirements.txt:
--------------------------------------------------------------------------------
1 | openai
2 | transformers
3 | torch
4 | pandas
5 | numpy
6 | matplotlib
7 | seaborn
8 | tqdm
9 | jsonlines
10 | statsmodels
11 | scipy
12 |
--------------------------------------------------------------------------------
/figures/teaser_llm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexTMallen/adaptive-retrieval/HEAD/figures/teaser_llm.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Alex Mallen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Adaptive Retrieval: Popularity-based LM Augmentation
2 |
3 | This is the official repository for our preprint: Alex Mallen, Akari Asai, Victor Zhong, Rajarshi Das, Hannaneh Hajishirzi, and Daniel Khashabi. [When Not to Trust Language Models: Investigating Effectiveness and Limitations of Parametric and Non-Parametric Memories](https://arxiv.org/abs/2212.10511). 2022.
4 |
5 | In this work, we conduct a large-scale knowledge probing of 10 language models (the GPT-Neo, OPT, and GPT-3 series) under 4 approaches (vanilla prompting and BM25, Contriever, and GenRead retrieval augmentation), using our new open-domain QA dataset, **PopQA**.
6 |
7 | ![Overview of PopQA and Adaptive Retrieval](figures/teaser_llm.png)
8 |
9 |
10 |
11 | We further introduce a simple-yet-effective method, **Adaptive Retrieval**, which retrieves and incorporates non-parametric memories only when necessary. Our experimental results show that Adaptive Retrieval not only achieves higher accuracy but is also more efficient in terms of inference-time latency and GPT-3 API cost.
12 |
13 | ### Contact and Citations
14 | For any questions about the paper or the code, please contact the first authors or open an issue in this repository.
15 | If you find our code or paper useful, please cite the paper:
16 | ```
17 | @article{mallen2023llm_memorization,
18 |   title={When Not to Trust Language Models: Investigating Effectiveness and Limitations of Parametric and Non-Parametric Memories},
19 |   author={Mallen, Alex and Asai, Akari and Zhong, Victor and Das, Rajarshi and Hajishirzi, Hannaneh and Khashabi, Daniel},
20 |   journal={arXiv preprint},
21 |   year={2022}
22 | }
23 | ```
24 | ## Content
25 |
26 | 1. [Installation](#installation)
27 | 2. [PopQA](#popqa)
28 | 3. [Baselines](#baselines)
29 | - [LMs](#lms)
30 | - [Retrieval-augmented LMs](#retrieval-augmented-lms)
31 | - [Adaptive Retrieval](#adaptive-retrieval)
32 |
33 | ## Installation
34 | ```
35 | pip install -r requirements.txt
36 | ```
37 |
38 | ## PopQA
39 | We construct an entity-centric open-domain QA dataset consisting of 14k QA pairs, each annotated with fine-grained Wikidata entity IDs, Wikipedia page views, and relationship type information. Each instance looks like this:
40 |
41 | ```
42 | {'id': 4222362, 'subj': 'George Rankin', 'prop': 'occupation', 'obj': 'politician', 'subj_id': 1850297, 'prop_id': 22, 'obj_id': 2834605, 's_aliases': '["George James Rankin"]', 'o_aliases': '["political leader","political figure","polit.","pol"]', 's_uri': 'http://www.wikidata.org/entity/Q5543720', 'o_uri': 'http://www.wikidata.org/entity/Q82955', 's_wiki_title': 'George Rankin', 'o_wiki_title': 'Politician', 's_pop': 142, 'o_pop': 25692, 'question': "What is George Rankin's occupation?", 'possible_answers': '["politician", "political leader", "political figure", "polit.", "pol"]'}
43 | ```
44 |
45 | The data is available at [data](data/popQA.tsv).
46 |
47 | PopQA is also available on Hugging Face Datasets: [akariasai/PopQA](https://huggingface.co/datasets/akariasai/PopQA)
48 | ```python
49 | import datasets
50 | popqa = datasets.load_dataset("akariasai/PopQA")["test"]
51 | ```
52 | ## Baselines
53 |
54 | ### LMs
55 | You can reproduce our zero-shot prompting experiments by running the command below:
56 | ```bash
57 | python run_model.py \
58 | --model_name MODEL_NAME \
59 | --input_file data/popQA.tsv \
60 | --eval_method vanilla
61 | ```
62 | We use [int8 quantization](https://arxiv.org/abs/2208.07339) (the `--int8bit` flag) to run GPT-NeoX-20B and OPT-13B in our environment (a single V100 with 32 GB of GPU memory).
63 |
64 | ```sh
65 | python run_model.py \
66 | --model_name EleutherAI/gpt-neox-20b \
67 | --input_file data/popQA.tsv \
68 | --eval_method vanilla \
69 | --int8bit
70 | ```
71 |
72 | ### Retrieval-augmented LMs
73 | To run retrieval-augmented LMs using BM25 or [Contriever](https://github.com/facebookresearch/contriever), please download the retrieval results [here](https://drive.google.com/drive/folders/1ggeoHbSPobbGOljOlwl_d16yssSygYqy?usp=sharing).
74 |
75 | Then, you can run the retrieval-augmented baselines as follows:
76 | ```sh
77 | python run_model.py \
78 | --model_name MODEL_NAME \
79 | --input_file data/popQA.tsv \
80 | --eval_method contriever \
81 | --ret_path PATH_TO_RETRIEVED_DOCUMENTS.jsonl
82 | ```
83 | To run GenRead, you don't need to specify the retrieval file path.
84 | ```sh
85 | python run_model.py \
86 | --model_name MODEL_NAME \
87 | --input_file data/popQA.tsv \
88 | --eval_method genread
89 | ```
90 |
91 | ### Adaptive Retrieval
92 | See the `adaptive-retrieval.ipynb` notebook: point it to two results files produced by `run_model.py`, one from a parametric run (vanilla or GenRead) and one from a retrieval-augmented run (BM25 or Contriever), and it will compute the Adaptive Retrieval results from them.
93 |
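94 | Concretely, Adaptive Retrieval reduces to a per-relationship-type popularity threshold: retrieval is used only for questions whose subject entity's Wikipedia page views fall below a threshold chosen on held-out data, and the model's parametric knowledge is trusted otherwise. The snippet below is a minimal, illustrative sketch of that decision rule; the threshold values and the `use_retrieval` helper are hypothetical, and the actual threshold selection lives in `adaptive-retrieval.ipynb`.
95 |
96 | ```python
97 | import numpy as np
98 |
99 | # Hypothetical per-relationship thresholds on log Wikipedia page views (prop_id -> threshold);
100 | # the real values are selected on a training split in adaptive-retrieval.ipynb.
101 | LOG_POP_THRESHOLDS = {22: 9.0, 218: 8.5}  # illustrative values only
102 |
103 | def use_retrieval(prop_id, s_pop):
104 |     """Retrieve only for less popular subjects (log popularity below the relationship's threshold)."""
105 |     thresh = LOG_POP_THRESHOLDS.get(prop_id, np.inf)  # unseen relationship types default to retrieving
106 |     return np.log(s_pop) < thresh
107 |
108 | # Route each PopQA question to the retrieval-augmented or parametric (vanilla) prediction:
109 | # answer = retrieval_pred if use_retrieval(row["prop_id"], row["s_pop"]) else vanilla_pred
110 | ```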
--------------------------------------------------------------------------------
/util_clm.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def assert_all_approx_close(a, b, rtol, atol, count):
4 |     # Check that tensors a and b are element-wise close, tolerating up to `count` mismatched values.
5 |     idx = torch.isclose(a.float(), b.float(), rtol, atol)
6 | sumval = (idx==0).sum().item()
7 | if sumval > count:
8 | print(f'Too many values not close: assert {sumval} < {count}')
9 | try:
10 | torch.testing.assert_allclose(a, b, rtol, atol)
11 | except Exception as e:
12 | print(e)
13 |
14 |
15 | def get_memory_footprint(model, return_buffers=True):
16 | """
17 | Get the memory footprint of a model. This will return the memory footprint of the current model in bytes.
18 | Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired from the
19 | PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2
20 | Arguments:
21 | return_buffers (`bool`, *optional*, defaults to `True`):
22 | Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers
23 | are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch
24 | norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
25 | """
26 | mem = sum([param.nelement() * param.element_size() for param in model.parameters()])
27 | if return_buffers:
28 | mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()])
29 | mem = mem + mem_bufs
30 | return mem
31 |
32 |
33 | def _replace_linear_with_int8linear(model, modules_to_not_convert="lm_head"):
34 |     for name, module in model.named_children():
35 |         _replace_linear_with_int8linear(module, modules_to_not_convert)
36 |
37 | if isinstance(module, torch.nn.Linear) and name != modules_to_not_convert:
38 | model._modules[name] = QuantizedLinearInt8(linear_layer=module)
39 | return
40 |
41 |
42 | class QuantizedLinearInt8(torch.nn.Module):
43 | '''
44 |     A simple but effective implementation of Int8 quantization for linear layers.
45 |     The weights are quantized and stored as Int8, which saves ~50% of the GPU memory.
46 |     During the forward pass, the weights are de-quantized back to fp16 to do the multiplication.
47 | Pros:
48 | - saves ~50% of the gpu memory
49 | - accurate quantization because only the weights are quantized, and the weights don't suffer
50 | from the "outliers" issue mentioned in the LLM.int8 paper; only the activations do.
51 |     - high precision results because the multiplication is done in fp16
52 |     - much faster than LLM.int8
53 |     Cons:
54 |     - a bit slower because of the added dequantization computation in each forward pass. In practice, the slowdown
55 |       is not large because GPU utilization is not very high during generation.
56 | '''
57 | def __init__(self, linear_layer):
58 | super().__init__()
59 | self.bias = linear_layer.bias
60 |
61 | weight_bit_width = 8
62 | weight = linear_layer.weight
63 |
64 | self.weight_scale = torch.nn.Parameter(
65 | (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half(),
66 | )
67 | # print(self.weight_scale.max().item(), self.weight_scale.min().item(), self.weight_scale.mean().item())
68 | # if self.weight_scale.max().item() > 0.002:
69 | # print(self.weight_scale.max().item())
70 | self.weight = torch.nn.Parameter(
71 | torch.round(weight.float() / self.weight_scale[:, None]).char(),
72 | requires_grad=False
73 | )
74 |
75 | def forward(self, x):
76 | weight = self.weight.half() * self.weight_scale[:, None]
77 | return torch.nn.functional.linear(x, weight, self.bias)
78 |
79 |
80 | def convert_model_to_int8_on_gpu(model, device):
81 | """
82 | Quantize a model to int8 and move it to GPU using a simple method.
83 | """
84 | if 'cuda' not in device:
85 | raise ValueError(f"Target device should be a gpu. Device {device} is not supported")
86 |
87 | model.half()
88 |
89 | memory_before_quantization = get_memory_footprint(model) # without lm_head
90 |
91 |     _replace_linear_with_int8linear(model)  # replace `Linear` with `QuantizedLinearInt8`
92 |
93 | model.to(device=device)
94 | memory_after_quantization = get_memory_footprint(model) # without lm_head
95 |
96 | saving = round(100 * memory_after_quantization/memory_before_quantization)
97 | memory_before_quantization = round(memory_before_quantization / 2**30, 2) # rounding for printing
98 | memory_after_quantization = round(memory_after_quantization / 2**30, 2) # rounding for printing
99 |
100 | print(f'Quantization memory - before: {memory_before_quantization} GB, after: {memory_after_quantization} GB ({saving}% of the size before)')
101 | return model
102 |
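103 | # Usage sketch (mirroring run_model.py's --int8bit path); assumes a CUDA device is available:
104 | #   from transformers import AutoModelForCausalLM
105 | #   model = AutoModelForCausalLM.from_pretrained("facebook/opt-13b")
106 | #   model = convert_model_to_int8_on_gpu(model, device="cuda")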
--------------------------------------------------------------------------------
/adaptive-retrieval.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "# Using adaptive retrieval on the synthetic dataset"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 1,
14 | "metadata": {},
15 | "outputs": [],
16 | "source": [
17 | "import pandas as pd\n",
18 | "import numpy as np\n",
19 | "import matplotlib.pyplot as plt\n",
20 | "import seaborn as sns\n",
21 | "from statsmodels.stats.proportion import proportion_confint\n",
22 | "from scipy.stats import pearsonr\n",
23 | "import os\n",
24 | "\n",
25 | "seed = 633\n",
26 | "np.random.seed(seed)\n",
27 | "import random\n",
28 | "random.seed(seed)"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 16,
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "# take the results of parametric and nonparametric-augmented systems \n",
38 | "# and compute how well adaptive retrieval would perform\n",
39 | "do_plot = False\n",
40 | "n_boot = 100\n",
41 | "parametric_path = \"\" # ADD PATH TO VANILLA RESULTS\n",
42 | "nonparametric_path = \"\" # ADD PATH TO RETRIEVAL-AUGMENTED RESULTS\n",
43 | "def clean(df):\n",
44 | " return df[~df[\"s_pop\"].isna() & (df[\"s_pop\"] >= 0)]\n",
45 | "sample = clean(pd.read_csv(parametric_path))\n",
46 | "sample_ret = clean(pd.read_csv(nonparametric_path))\n",
47 | "sample = sample.sort_values(\"question\").reset_index(drop=True)\n",
48 | "sample_ret = sample_ret.sort_values(\"question\").reset_index(drop=True)"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": 17,
54 | "metadata": {},
55 | "outputs": [
56 | {
57 | "name": "stdout",
58 | "output_type": "stream",
59 | "text": [
60 | "prop adaptive: 1.0\n",
61 | "prop retrieval: 0.7981889543033361\n",
62 | "parametric knowledge only: 0.17163442668909443\n",
63 | "retrieval augmented accuracy: 0.2537706756377908\n",
64 | "hybrid accuracy: 0.2706980656013457 pm 0.000232407510793975\n",
65 | "hybrid accuracy on train: 0.2719317757009346\n",
66 | "overall accuracy gain: 0.016927389963554862\n",
67 | "overall accuracy: 0.2706980656013457\n"
68 | ]
69 | }
70 | ],
71 | "source": [
72 | "# suppress warnings\n",
73 | "import warnings\n",
74 | "warnings.filterwarnings(\"ignore\")\n",
75 | "props = sample.prop.unique()\n",
76 | "\n",
77 | "test_ret_accs_all = []\n",
78 | "test_param_accs_all = []\n",
79 | "test_hybrid_accs_all = []\n",
80 | "test_count_rets = []\n",
81 | "test_count_params = []\n",
82 | "train_hybrid_accs_all = []\n",
83 | "\n",
84 | "for boot in range(1 if do_plot else n_boot):\n",
85 | " split_proportion = 0.75\n",
86 | " split_mask = [\"train\"] * round(len(sample) * split_proportion) + [\"test\"] * round(len(sample) * (1 - split_proportion))\n",
87 | " np.random.shuffle(split_mask)\n",
88 | " sample[\"split\"] = split_mask\n",
89 | " sample_ret[\"split\"] = split_mask\n",
90 | "\n",
91 | " test_ret_accs = []\n",
92 | " test_param_accs = []\n",
93 | " test_hybrid_accs = []\n",
94 | " train_hybrid_accs = []\n",
95 | " test_count_rets.append(0)\n",
96 | " test_count_params.append(0)\n",
97 | " test_sizes = []\n",
98 | " train_sizes = []\n",
99 | " train_threshs = dict()\n",
100 | " plot_title = \"\"\n",
101 | " if do_plot:\n",
102 | " plt.figure(dpi=200, figsize=(30, 30))\n",
103 | " for i, prop in enumerate(props):\n",
104 | "\n",
105 | " if do_plot:\n",
106 | " plt.subplot(4, 4, i+1)\n",
107 | " cluster_sample = sample[sample.prop == prop].copy()\n",
108 | " cluster_sample_ret_with_pop = sample_ret[sample_ret.prop == prop].copy()\n",
109 | "\n",
110 | " log_pop = np.log(cluster_sample[\"s_pop\"].values)\n",
111 | " cluster_sample[\"log_pop\"] = log_pop\n",
112 | " cluster_sample_ret_with_pop[\"log_pop\"] = log_pop\n",
113 | " ser = log_pop\n",
114 | " _, bin_edges = np.histogram(ser)\n",
115 | "\n",
116 | " c = cluster_sample.is_correct.values\n",
117 | " counts_c, _ = np.histogram(ser[c], bins=bin_edges)\n",
118 | " counts_inc, _ = np.histogram(ser[~c], bins=bin_edges)\n",
119 | " total = counts_c + counts_inc\n",
120 | "\n",
121 | " c_ret = cluster_sample_ret_with_pop.is_correct.values\n",
122 | " counts_c_ret, _ = np.histogram(ser[c_ret], bins=bin_edges)\n",
123 | " counts_inc_ret, _ = np.histogram(ser[~c_ret], bins=bin_edges)\n",
124 | " total_ret = counts_c_ret + counts_inc_ret\n",
125 | "\n",
126 | " width = 0.4*(bin_edges[1] - bin_edges[0])\n",
127 | " thresh_idx = np.argmax(list((sum(counts_c_ret[:i]) + sum(counts_c[i:])) / sum(total_ret) for i in range(len(total_ret) + 1)))\n",
128 | "\n",
129 | " # plt.bar(bin_edges[:-1] - 0.5 * width, counts_c / total, width=width, alpha=0.9, label=\"parametric knowledge (vanilla)\", align='edge')\n",
130 | " # plt.bar(bin_edges[:-1] + 0.5 * width, counts_c_ret / total_ret, width=width, alpha=0.9, label=\"retrieval augmented (BM25)\", align='edge')\n",
131 | " # lo, hi = proportion_confint(counts_c, total, alpha=0.05, method='wilson')\n",
132 | " # plt.errorbar(bin_edges[:-1], counts_c / total, yerr=[counts_c / total - lo, hi - counts_c / total], fmt='none', ecolor='black', elinewidth=1, capsize=2)\n",
133 | " # lo, hi = proportion_confint(counts_c_ret, total_ret, alpha=0.05, method='wilson')\n",
134 | " # plt.errorbar(bin_edges[:-1] + width, counts_c_ret / total_ret, yerr=[counts_c_ret / total_ret - lo, hi - counts_c_ret / total_ret], fmt='none', ecolor='black', elinewidth=1, capsize=2)\n",
135 | " # # plt.errorbar(bin_edges[:-1], counts_c / total, , fmt='none', ecolor='black', capsize=2)\n",
136 | " # # plt.errorbar(bin_edges[:-1], counts_c_ret / total_ret, yerr=wilson(counts_c_ret / total_ret, total_ret), fmt='none', ecolor='black', capsize=2)\n",
137 | " # plt.axvline(x=bin_edges[thresh_idx] + 0.7 * (bin_edges[1] - bin_edges[0]), color='red', linestyle='--', label=\"threshold\")\n",
138 | "\n",
139 | " param_acc = sum(counts_c) / sum(total)\n",
140 | " ret_acc = sum(counts_c_ret) / sum(total_ret)\n",
141 | " hybrid_acc = (sum(counts_c_ret[:thresh_idx]) + sum(counts_c[thresh_idx:])) / (sum(total_ret[:thresh_idx]) + sum(total[thresh_idx:]))\n",
142 | " ret_acc_gain = hybrid_acc - ret_acc\n",
143 | " param_acc_gain = hybrid_acc - param_acc\n",
144 | " \n",
145 | " # let the optimal threshold be the one that maximizes the hybrid accuracy\n",
146 | " train_idxs = cluster_sample.split.values == \"train\"\n",
147 | " test_idxs = cluster_sample.split.values == \"test\"\n",
148 | " train_ser = ser[train_idxs]\n",
149 | " test_ser = ser[test_idxs]\n",
150 | " train_c = c[train_idxs]\n",
151 | " test_c = c[test_idxs]\n",
152 | " train_c_ret = c_ret[train_idxs]\n",
153 | " test_c_ret = c_ret[test_idxs]\n",
154 | " train_counts_c, _ = np.histogram(train_ser[train_c], bins=bin_edges)\n",
155 | " train_counts_inc, _ = np.histogram(train_ser[~train_c], bins=bin_edges)\n",
156 | " train_total = train_counts_c + train_counts_inc\n",
157 | " train_counts_c_ret, _ = np.histogram(train_ser[train_c_ret], bins=bin_edges)\n",
158 | " train_counts_inc_ret, _ = np.histogram(train_ser[~train_c_ret], bins=bin_edges)\n",
159 | " train_total_ret = train_counts_c_ret + train_counts_inc_ret\n",
160 | " test_counts_c, _ = np.histogram(test_ser[test_c], bins=bin_edges)\n",
161 | " test_counts_inc, _ = np.histogram(test_ser[~test_c], bins=bin_edges)\n",
162 | " test_total = test_counts_c + test_counts_inc\n",
163 | " test_counts_c_ret, _ = np.histogram(test_ser[test_c_ret], bins=bin_edges)\n",
164 | " test_counts_inc_ret, _ = np.histogram(test_ser[~test_c_ret], bins=bin_edges)\n",
165 | " test_total_ret = test_counts_c_ret + test_counts_inc_ret\n",
166 | " \n",
167 | " # find the optimal threshold\n",
168 | " train_thresh_idx = np.argmax(list((sum(train_counts_c_ret[:i]) + sum(train_counts_c[i:])) / (sum(train_total_ret[:i]) + sum(train_total[i:])) for i in range(len(train_total) + 1)))\n",
169 | " train_thresh = bin_edges[train_thresh_idx] - 0.5 * (bin_edges[1] - bin_edges[0])\n",
170 | " train_threshs[prop] = train_thresh\n",
171 | "\n",
172 | " # calculate the accuracy on the test set\n",
173 | " test_param_acc = sum(test_counts_c) / sum(test_total)\n",
174 | " test_ret_acc = sum(test_counts_c_ret) / sum(test_total_ret)\n",
175 | " test_hybrid_acc = (sum(test_counts_c_ret[:train_thresh_idx]) + sum(test_counts_c[train_thresh_idx:])) / (sum(test_total_ret[:train_thresh_idx]) + sum(test_total[train_thresh_idx:]))\n",
176 | " test_sizes.append(sum(test_total))\n",
177 | " test_count_rets[-1] += sum(test_total_ret[:train_thresh_idx])\n",
178 | " test_count_params[-1] += sum(test_total[train_thresh_idx:])\n",
179 | " test_ret_accs.append(test_ret_acc)\n",
180 | " test_param_accs.append(test_param_acc)\n",
181 | " test_hybrid_accs.append(test_hybrid_acc)\n",
182 | " train_hybrid_accs.append((sum(train_counts_c_ret[:train_thresh_idx]) + sum(train_counts_c[train_thresh_idx:])) / (sum(train_total_ret[:train_thresh_idx]) + sum(train_total[train_thresh_idx:])))\n",
183 | " train_sizes.append(sum(train_total))\n",
184 | "\n",
185 | " if do_plot:\n",
186 | " plt.bar(bin_edges[:-1] - 0.5 * width, test_counts_c / test_total, width=width, alpha=0.9, label=\"parametric knowledge (vanilla)\", align='edge', hatch='//')\n",
187 | " plt.bar(bin_edges[:-1] + 0.5 * width, test_counts_c_ret / test_total_ret, width=width, alpha=0.9, label=\"retrieval augmented (BM25)\", align='edge', hatch='//')\n",
188 | " lo, hi = proportion_confint(test_counts_c, test_total, alpha=0.05, method='wilson')\n",
189 | " plt.errorbar(bin_edges[:-1], test_counts_c / test_total, yerr=[test_counts_c / test_total - lo, hi - test_counts_c / test_total], fmt='none', ecolor='black', elinewidth=1, capsize=2)\n",
190 | " lo, hi = proportion_confint(test_counts_c_ret, test_total_ret, alpha=0.05, method='wilson')\n",
191 | " plt.errorbar(bin_edges[:-1] + width, test_counts_c_ret / test_total_ret, yerr=[test_counts_c_ret / test_total_ret - lo, hi - test_counts_c_ret / test_total_ret], fmt='none', ecolor='black', elinewidth=1, capsize=2)\n",
192 | " plt.axvline(x=bin_edges[train_thresh_idx] - 0.8 * width, color='red', linestyle='--', label=\"threshold\")\n",
193 | "\n",
194 | " print(f\"Threshold for {prop}:\", train_thresh)\n",
195 | " print(\"Parametric knowledge only:\", param_acc)\n",
196 | " print(\"Retrieval augmented accuracy:\", ret_acc)\n",
197 | " print(f\"New accuracy with thresh={thresh_idx}:\", hybrid_acc)\n",
198 | " print()\n",
199 | " plt.title(f\"{prop}\")\n",
200 | " plt.ylim([0,1.01])\n",
201 | " if do_plot:\n",
202 | " plt.xlabel(\"log(s_pop)\")\n",
203 | " plt.ylabel(\"proportion correct\")\n",
204 | " plt.legend()\n",
205 | " plt.tight_layout()\n",
206 | " plt.show()\n",
207 | "\n",
208 | " # take the weighted mean by test_size\n",
209 | " test_ret_accs_all.append(np.average(test_ret_accs, weights=test_sizes))\n",
210 | " test_param_accs_all.append(np.average(test_param_accs, weights=test_sizes))\n",
211 | " test_hybrid_accs_all.append(np.average(test_hybrid_accs, weights=test_sizes))\n",
212 | " train_hybrid_accs_all.append(np.average(train_hybrid_accs, weights=train_sizes))\n",
213 | "\n",
214 | "test_size = split_mask.count(\"test\")\n",
215 | "prop_adaptive = (np.mean(test_count_rets) + np.mean(test_count_params)) / test_size\n",
216 | "param_acc = np.mean(test_param_accs_all)\n",
217 | "ret_acc = np.mean(test_ret_accs_all)\n",
218 | "hybrid_acc = np.mean(test_hybrid_accs_all)\n",
219 | "train_hybrid_acc = np.mean(train_hybrid_accs_all)\n",
220 | "overall_hybrid_acc = np.mean(test_hybrid_accs_all) * prop_adaptive + max(np.mean(test_param_accs_all), np.mean(test_ret_accs_all)) * (1 - prop_adaptive)\n",
221 | "sem_test_hybrid_acc = 2 * np.std(test_hybrid_accs_all) / np.sqrt(test_size)\n",
222 | "print(\"prop adaptive:\", prop_adaptive)\n",
223 | "print(\"prop retrieval:\", np.mean(test_count_rets) / test_size)\n",
224 | "print(\"parametric knowledge only:\", param_acc)\n",
225 | "print(\"retrieval augmented accuracy:\", ret_acc)\n",
226 | "print(\"hybrid accuracy:\", hybrid_acc, \"pm\", sem_test_hybrid_acc)\n",
227 | "print(\"hybrid accuracy on train:\", train_hybrid_acc)\n",
228 | "print(\"overall accuracy gain:\", overall_hybrid_acc - max(param_acc, ret_acc))\n",
229 | "print(\"overall accuracy:\", overall_hybrid_acc)"
230 | ]
231 | }
232 | ],
233 | "metadata": {
234 | "kernelspec": {
235 | "display_name": "Python 3",
236 | "language": "python",
237 | "name": "python3"
238 | },
239 | "language_info": {
240 | "codemirror_mode": {
241 | "name": "ipython",
242 | "version": 3
243 | },
244 | "file_extension": ".py",
245 | "mimetype": "text/x-python",
246 | "name": "python",
247 | "nbconvert_exporter": "python",
248 | "pygments_lexer": "ipython3",
249 | "version": "3.10.3"
250 | },
251 | "orig_nbformat": 4,
252 | "vscode": {
253 | "interpreter": {
254 | "hash": "949777d72b0d2535278d3dc13498b2535136f6dfe0678499012e853ee9abcab1"
255 | }
256 | }
257 | },
258 | "nbformat": 4,
259 | "nbformat_minor": 2
260 | }
261 |
--------------------------------------------------------------------------------
/run_model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: UTF-8 -*-
3 |
4 | import random
5 | import torch
6 | import os
7 | import numpy as np
8 | import pandas as pd
9 | import matplotlib.pyplot as plt
10 | import seaborn as sns
11 | import openai
12 | from tqdm import tqdm
13 | import json
14 | import argparse
15 | sns.set_theme()
16 |
17 |
18 | seed = 633
19 |
20 | torch.backends.cudnn.deterministic = True
21 | random.seed(seed)
22 | np.random.seed(seed)
23 | torch.manual_seed(seed)
24 | torch.cuda.manual_seed_all(seed)
25 | print('Cuda:', torch.cuda.is_available())
26 | print('pwd', os.getcwd())
27 |
28 | from transformers import AutoTokenizer, AutoModelForCausalLM
29 | from util_clm import convert_model_to_int8_on_gpu
30 |
31 |
32 |
33 | import jsonlines
34 |
35 | def load_jsonlines(file):
36 | with jsonlines.open(file, 'r') as jsonl_f:
37 | lst = [obj for obj in jsonl_f]
38 | return lst
39 |
40 | q_templates = {
41 | 22: "What is {}'s occupation?",
42 | 218: "In what city was {} born?",
43 | 91: "What genre is {}?",
44 | 257: "Who is the father of {}?",
45 | 182: "In what country is {}?",
46 | 164: "Who was the producer of {}?",
47 | 526: "Who was the director of {}?",
48 | 97: "What is {} the capital of?",
49 | 533: "Who was the screenwriter for {}?",
50 | 639: "Who was the composer of {}?",
51 | 472: "What color is {}?",
52 | 106: "What is the religion of {}?",
53 | 560: "What sport does {} play?",
54 | 484: "Who is the author of {}?",
55 | 292: "Who is the mother of {}?",
56 | 422: "What is the capital of {}?"
57 | }
58 | completion_template = "Q: {} A:" # "{}" # "Query: {}\nResult:" # "Q: {} A:" # "{} The answer is"
59 | genread_template = "Generate a background document from Wikipedia to answer the given question. {}" # This prompt comes from the GenRead paper
60 |
61 | def call_request(prompt, model, tokenizer, max_new_tokens=15):
62 | max_inpt_tokens = tokenizer.model_max_length
63 |     if len(prompt) > tokenizer.model_max_length:  # conservative: character count upper-bounds token count, so over-length prompts are never missed
64 | inpts = tokenizer(prompt, return_tensors="pt")
65 | new_prompt = tokenizer.decode(inpts.input_ids[0, -(max_inpt_tokens - max_new_tokens):])
66 | else:
67 | new_prompt = prompt
68 |
69 |     # try to get a response from the model multiple times if there's a timeout
70 | for i in range(5):
71 | try:
72 | if i > 0:
73 | print("Retrying request")
74 | response = openai.Completion.create(model=model, prompt=new_prompt, temperature=0.0, max_tokens=max_new_tokens, logprobs=5, top_p=1,frequency_penalty=0.0,presence_penalty=0.0)
75 | break
76 | except Exception as e:
77 | print(e)
78 | print("Timeout, trying again")
79 |
80 | pred = response["choices"][0]["text"]
81 | if pred.startswith("\n\n"):
82 | pred = pred[2:]
83 | pred = pred.split("\n")[0]
84 | return pred, response.to_dict_recursive()
85 |
86 | def call_model(prompt, model, tokenizer, device, max_new_tokens=15, model_max_length=None):
87 | max_inpt_tokens = tokenizer.model_max_length if model_max_length is None else model_max_length
88 | inpts = tokenizer(prompt, return_tensors="pt").to(device)
89 | gen = model.generate(input_ids=inpts.input_ids[:, -(max_inpt_tokens - max_new_tokens):], attention_mask=inpts.attention_mask[:, -(max_inpt_tokens - max_new_tokens):], pad_token_id=tokenizer.eos_token_id, max_new_tokens=max_new_tokens, num_beams=1, do_sample=False)
90 | text = tokenizer.decode(gen[0])
91 | actual_prompt = tokenizer.decode(inpts.input_ids[0, -(max_inpt_tokens - max_new_tokens):])
92 | pred = text[len(actual_prompt):]
93 | if pred.startswith("\n\n"):
94 | pred = pred[2:]
95 | pred = pred.split("\n")[0]
96 | return pred, text
97 |
98 | def clip_paragraph(text, eval_method):
99 | if eval_method in ["BM25", "genread"]:
100 | return text
101 | split = text.split(". ")
102 | return ". ".join(split[:-1]) + "."
103 |
104 | def get_few_shot_text_with_retrieval(row, retrieval_dict, eval_method):
105 | if eval_method == "vanilla":
106 | return completion_template.format(row.question) + " " + row.obj
107 | # retrieval_dict[row.id]["ctxs"][0]
108 | if row.question.replace("?", "").lower() not in retrieval_dict:
109 | print("missing retrieval")
110 | return completion_template.format(row.question) + " " + row.obj
111 | else:
112 | retrieval = retrieval_dict[row.question.replace("?", "").lower()]["ctxs"][0]
113 | retrieved_text = clip_paragraph(retrieval["text"], eval_method)
114 | return retrieved_text + "\n\n" + completion_template.format(row.question) + " " + row.obj
115 |
116 | def get_few_shot_text(row, eval_method):
117 | return completion_template.format(row.question) + " " + row.obj
118 |
119 | def get_genread_passage(question, genread_template, generate_function, max_new_tokens=150):
120 | prompt = genread_template.format(question)
121 | return generate_function(prompt, max_new_tokens=max_new_tokens)[0]
122 |
123 | def get_few_shot_examples_genread(knowledge, generate_function, n_examples, genread_template, is_templatedQA, max_new_tokens=150):
124 | if is_templatedQA:
125 | few_shot_examples = dict()
126 | all_pids = list(q_templates.keys())
127 | examples_per_template = n_examples // (len(q_templates) - 1)
128 | for pid in all_pids:
129 | for row2 in knowledge[knowledge.prop_id == pid].sample(n=examples_per_template).iloc:
130 | if pid not in few_shot_examples:
131 | few_shot_examples[pid] = []
132 | generation = get_genread_passage(row2.question, genread_template, generate_function, max_new_tokens=max_new_tokens)
133 | few_shot_examples[pid].append(get_few_shot_text_with_retrieval(row2, {row2.question: {"ctxs": [{"id": -1, "text": generation}]}}, "genread"))
134 | else:
135 | few_shot_examples = []
136 | for row2 in knowledge.sample(n=n_examples + 1).iloc:
137 | generation = get_genread_passage(row2.question, genread_template, generate_function, max_new_tokens=max_new_tokens)
138 | few_shot_examples.append(get_few_shot_text_with_retrieval(row2, {row2.question: {"ctxs": [{"id": -1, "text": generation}]}}, "genread"))
139 |
140 | return few_shot_examples
141 |
142 |
143 | def main():
144 | parser = argparse.ArgumentParser()
145 | parser.add_argument('--model_name', type=str)
146 | parser.add_argument('--input_file', type=str)
147 | parser.add_argument('--alias', type=str)
148 | parser.add_argument('--n_examples', type=int, default=15)
149 | parser.add_argument('--eval_method', type=str, default="vanilla", choices=["vanilla", "BM25", "contriever", "genread"])
150 | parser.add_argument('--ret_path', type=str, default=None, required=False, help="path to retrieved documents jsonl")
151 | parser.add_argument('--device', type=str, default="cuda")
152 | parser.add_argument('--max_new_tokens', type=int, default=15)
153 | parser.add_argument('--sample', type=int, default=0, help="if 0, use all examples")
154 | parser.add_argument('--continue_from', type=str, help="path to previous results file")
155 | parser.add_argument('--int8bit', action="store_true")
156 | parser.add_argument('--parallel', type=str, help="string of format 'i.n_workers' where i is the index of the worker")
157 |
158 | args = parser.parse_args()
159 |
160 | use_gpt3 = args.model_name in {"text-davinci-003", "text-davinci-002", "text-curie-001", "text-babbage-001", "text-ada-001"}
161 | if use_gpt3:
162 | with open("../../openAIkey.txt") as f:
163 | openai.api_key = f.read()[:-1]
164 | tokenizer = AutoTokenizer.from_pretrained("gpt2")
165 | generate = lambda prompt, max_new_tokens: call_request(prompt, args.model_name, tokenizer, max_new_tokens=max_new_tokens)
166 | else:
167 | gpt = args.model_name
168 | device = args.device
169 | tokenizer = AutoTokenizer.from_pretrained(gpt)
170 | tokenizer.pad_token = tokenizer.eos_token
171 | tokenizer.pad_token_id = tokenizer.eos_token_id
172 | if args.int8bit:
173 | model = convert_model_to_int8_on_gpu(AutoModelForCausalLM.from_pretrained(gpt), device)
174 | else:
175 | model = AutoModelForCausalLM.from_pretrained(gpt).eval().to(device)
176 | if "opt" in args.model_name or args.model_name == "EleutherAI/gpt-neox-20b":
177 | generate = lambda prompt, max_new_tokens: call_model(prompt, model=model, tokenizer=tokenizer, device=device, max_new_tokens=max_new_tokens, model_max_length=2048)
178 | else:
179 | generate = lambda prompt, max_new_tokens: call_model(prompt, model=model, tokenizer=tokenizer, device=device, max_new_tokens=max_new_tokens)
180 | input_path = args.input_file
181 | knowledge = pd.read_csv(input_path, sep="\t")
182 |
183 | if args.continue_from is not None:
184 | results = pd.read_csv(args.continue_from, sep="\t")
185 | knowledge = knowledge[~knowledge.id.isin(results.id)]
186 | n = len(knowledge) if args.sample == 0 else args.sample
187 | sample = knowledge.sample(n=n, replace=False)
188 | if args.parallel is not None:
189 | worker_num, n_workers = map(int, args.parallel.split("."))
190 | sample = sample.iloc[worker_num::n_workers]
191 |
192 | n_examples = args.n_examples
193 | is_templatedQA = True
194 | examples_per_template = n_examples // (len(q_templates) - 1)
195 |
196 | preds = []
197 | prompts =[]
198 | accuracy = []
199 | responses = []
200 | if args.eval_method in ["BM25", "contriever"]:
201 | has_answer = []
202 | retrieval_ids = []
203 | with open(args.ret_path) as f:
204 | retrieval_dict = {json.loads(s)["question"]: json.loads(s) for s in f.readlines()}
205 | # print(retrieval_dict)
206 | if args.eval_method == "genread":
207 | genread_few_shot_examples = get_few_shot_examples_genread(knowledge, generate, n_examples, genread_template, is_templatedQA, max_new_tokens=150)
208 | has_answer = []
209 | gen_passages = []
210 |
211 | # main loop
212 | for row in tqdm(sample.iloc, total=n):
213 |
214 | # get few shot examples text
215 | if n_examples == 0:
216 | few_shot_examples_text = ""
217 | else:
218 | few_shot_examples = []
219 | if args.eval_method == "genread":
220 | if is_templatedQA:
221 | other_pids = list(q_templates.keys())
222 | other_pids.remove(row.prop_id)
223 | few_shot_examples = []
224 | for pid in other_pids:
225 | few_shot_examples.extend(random.sample(genread_few_shot_examples[pid], examples_per_template))
226 | else:
227 | few_shot_examples = random.sample([ex for ex in genread_few_shot_examples if row.question not in ex], n_examples)
228 | else:
229 | if is_templatedQA:
230 | other_pids = list(q_templates.keys())
231 | other_pids.remove(row.prop_id)
232 | for pid in other_pids:
233 | for row2 in knowledge[knowledge.prop_id == pid].sample(n=examples_per_template).iloc:
234 | few_shot_examples.append(get_few_shot_text_with_retrieval(row2, retrieval_dict, args.eval_method) if args.eval_method in ["BM25", "contriever"] else get_few_shot_text(row2, args.eval_method))
235 | else:
236 | for row2 in knowledge[knowledge.question != row.question].sample(n=n_examples).iloc:
237 | few_shot_examples.append(get_few_shot_text_with_retrieval(row2, retrieval_dict, args.eval_method) if args.eval_method in ["BM25", "contriever"] else get_few_shot_text(row2, args.eval_method))
238 |
239 |
240 | np.random.shuffle(few_shot_examples)
241 | few_shot_examples_text = "\n\n".join(few_shot_examples) + "\n\n"
242 |
243 | # get prompt
244 | if args.eval_method == "vanilla":
245 | prompt = few_shot_examples_text + completion_template.format(row.question)
246 | elif args.eval_method in ["BM25", "contriever"]:
247 | query = row.question
248 | try:
249 | retrieval = retrieval_dict[query]["ctxs"][0] # retrieval_dict[row.id]["ctxs"][0]
250 |             except (KeyError, IndexError):
251 |                 # fall back to an empty context when the question is missing from the retrieval file
252 |                 print("No retrieval for", query, " Example query:", list(retrieval_dict.keys())[0])
253 | retrieval = {"text": "", "id": np.nan, "hasanswer": False}
254 | retrieved_text = clip_paragraph(retrieval["text"], eval_method=args.eval_method)
255 | retrieval_id = retrieval["id"]
256 | prompt = few_shot_examples_text + retrieved_text + "\n\n" + completion_template.format(row.question)
257 | has_answer.append(retrieval["hasanswer"])
258 | retrieval_ids.append(retrieval_id)
259 | elif args.eval_method == "genread":
260 | generation = get_genread_passage(row.question, genread_template, generate, max_new_tokens=150)
261 | prompt = few_shot_examples_text + generation + "\n\n" + completion_template.format(row.question)
262 | gen_passages.append(generation)
263 |
264 | # generate response
265 | pred, response = generate(prompt, max_new_tokens=args.max_new_tokens)
266 | prompts.append(prompt)
267 | preds.append(pred)
268 | responses.append(response)
269 |
270 | # compute accuracy
271 | possible_answers = json.loads(row.possible_answers)
272 | is_correct = False
273 | genread_has_answer = False
274 | for pa in possible_answers:
275 | if pa in pred or pa.lower() in pred or pa.capitalize() in pred:
276 | is_correct = True
277 | if args.eval_method == "genread" and pa in response or pa.lower() in response or pa.capitalize() in response:
278 | genread_has_answer = True
279 | accuracy.append(is_correct)
280 | if args.eval_method == "genread":
281 | has_answer.append(genread_has_answer)
282 |
283 | # save results intermittently
284 | if len(preds) % 100 == 0:
285 | temp_sample = sample.iloc[:len(preds)].copy()
286 | temp_sample["pred"] = preds
287 | temp_sample["prompt"] = prompts
288 | temp_sample["generation"] = responses
289 | temp_sample["is_correct"] = accuracy
290 | if args.eval_method in ["BM25", "contriever"]:
291 | temp_sample["has_answer"] = has_answer
292 | temp_sample["retrieval_id"] = retrieval_ids
293 | if args.eval_method == "genread":
294 | temp_sample["has_answer"] = has_answer
295 | temp_sample["gen_passage"] = gen_passages
296 | model_name_alias = args.model_name.replace("/","_")
297 | if not os.path.exists(f"results/temp/"):
298 | os.makedirs(f"results/temp/")
299 | worker_str = "" if args.parallel is None else f"-worker={args.parallel}"
300 | output_path = f"results/temp/model={model_name_alias}-input={args.alias}-method={args.eval_method}-shots={n_examples}-n={len(temp_sample)}{'_int8bit' if args.int8bit is True else ''}{worker_str}.csv"
301 | temp_sample.to_csv(output_path, index=False)
302 |
303 | sample["is_correct"] = accuracy
304 | sample["prompt"] = prompts
305 | sample["pred"] = preds
306 | sample["generation"] = responses
307 | if args.eval_method in ["BM25", "contriever"]:
308 | sample["has_answer"] = has_answer
309 | sample["retrieval_id"] = retrieval_ids
310 | if args.eval_method == "genread":
311 | sample["has_answer"] = has_answer
312 | sample["gen_passage"] = gen_passages
313 |
314 | print(sample.is_correct.mean())
315 | model_name_alias = args.model_name.replace("/","_")
316 | worker_str = "" if args.parallel is None else f"-worker={args.parallel}"
317 | sample.to_csv(f"results/model={model_name_alias}-input={args.alias}-method={args.eval_method}-shots={n_examples}-n={len(sample)}{'_int8bit' if args.int8bit is True else ''}{worker_str}.csv")
318 |
319 |
320 | if __name__ == "__main__":
321 | main()
322 |
--------------------------------------------------------------------------------