├── .gitmodules
├── LICENSE
├── README.md
├── ablations
│   ├── README.md
│   └── T0_variants
│       ├── CB_variants.py
│       ├── RTE_variants.py
│       ├── WIC_variants.py
│       ├── WSC_variants.py
│       ├── decomposition.py
│       ├── run.sh
│       └── utils.py
├── boosting
│   ├── binary_deps.py
│   ├── make_pgm.py
│   ├── methods.py
│   ├── pgm.py
│   ├── run_ws.py
│   └── utils.py
├── diagnostics
│   ├── README.md
│   ├── data.zip
│   └── run_diagnostics.py
├── download_p3.py
├── imgs
│   ├── crfm.png
│   ├── decomp.png
│   ├── numbers_station.png
│   └── snorkel.png
├── requirements.txt
└── tasks
    ├── AGNews_final.py
    ├── ANLIR1_final.py
    ├── ANLIR2_final.py
    ├── ANLIR3_final.py
    ├── Amazon_final.py
    ├── BoolQ_final.py
    ├── CB_final.py
    ├── COPA_final.py
    ├── DBPedia_final.py
    ├── MultiRC_final.py
    ├── NQ_final.py
    ├── RTE_final.py
    ├── ReCoRD_final.py
    ├── RealtimeQA_final.py
    ├── SST2_final.py
    ├── StoryCloze_final.py
    ├── WIC_final.py
    ├── WSC_final.py
    ├── decomposition.py
    ├── drop_final.py
    ├── utils.py
    └── webq_final.py
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "metal-ama"]
2 | path = metal-ama
3 | url = https://github.com/mayeechen/metal-ama.git
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Ask Me Anything: A simple strategy for prompting language models
2 |
3 | 
4 | [](https://together.xyz/)
5 |
6 | This repository contains code for the Ask Me Anything (AMA) prompt-aggregation strategy. The end-to-end AMA approach includes (1) recursively using the language model to transform the task format and prompt and (2) aggregating the predictions of multiple prompts using weak supervision. We include code for both components and pointers to the publicly downloadable datasets. See our [paper](https://arxiv.org/abs/2210.02441) for more details.
7 |
8 | 
9 |
10 |
11 |
12 | ## Table of Contents
13 | - [Setup](#setup)
14 | - [Data](#getting-the-data)
15 | - [Running models](#running-models)
16 | - [Running experiments](#experiments)
17 | - [Repository Structure](#overall-repository-structure)
18 | - [Citation](#citation)
19 |
20 |
21 | ## Setup
22 |
23 | ### Installation
24 | Here we will set up the AMA code (prompting models for tasks), the weak supervision code (aggregating predictions), and [Manifest](https://github.com/HazyResearch/manifest/) (tooling for easily loading and running the models).
25 |
26 | We encourage the use of conda environments:
27 | ```
28 | conda create --name ama python=3.8
29 | conda activate ama
30 | ```
31 |
32 | Clone as follows:
33 | ```bash
34 | # Ask Me Anything code
35 | git clone git@github.com:HazyResearch/ama_prompting.git
36 | cd ama_prompting
37 | pip install -r requirements.txt
38 |
39 | # Weak supervision code
40 | cd metal-ama
41 | git submodule init
42 | git submodule update
43 | pip install -e .
44 |
45 | # Manifest
46 | git clone git@github.com:HazyResearch/manifest.git
47 | cd manifest
48 | pip install -e .
49 | ```
50 |
51 |
52 | ### Getting the data
53 | We assume all data lives in the directory given by the ```AMA_DATA``` environment variable, which defaults to ```/home/data```. To change this, run
54 | ```bash
55 | export AMA_DATA=<path>
56 | ```
57 |
58 | Please follow the instructions below to download all necessary data for experiments.
59 |
60 | 1. Download the PromptSource (P3) dataset from Hugging Face at https://huggingface.co/datasets/bigscience/P3.
61 | ```bash
62 | cd $AMA_DATA
63 | git lfs install
64 | git clone https://huggingface.co/datasets/bigscience/P3
65 | ```
66 | Then run [ama_prompting/download_p3.py](./download_p3.py). We use the GPT3-Style prompts in the few-shot baseline for each benchmark.
67 |
68 | 2. We downloaded the remaining tasks from the following sources:
69 | * [AGNews, DBPedia, and SST2](https://github.com/tonyzhaozh/few-shot-learning)
70 | * [Amazon Products](https://github.com/allenai/flex/blob/75d6d1cea66df2c8a7e3d429c6af5008ccf1544b/fewshot/hf_datasets_scripts/amazon/amazon.py)
71 | * [Natural Questions and WebQs](https://github.com/facebookresearch/FiD)
72 | * [RealTimeQA](https://github.com/realtimeqa/realtimeqa_public/tree/main/past/2022) (GCS files from June 17th - July 22, 2022)
73 | * [ReCoRD](https://sheng-z.github.io/ReCoRD-explorer/)
74 | * [StoryCloze](http://goo.gl/forms/aQz39sdDrO)
75 |
76 | ### Running models
77 | We run inference on models using a tool called [Manifest](https://github.com/HazyResearch/manifest). This tool is useful because it caches your inference results and does not require reloading the model for each new run you launch. To load the EleutherAI GPT-J-6B model, run the following in a tmux session:
78 | ```bash
79 | python3 manifest/manifest/api/app.py \
80 | --model_type huggingface \
81 | --model_name_or_path EleutherAI/gpt-j-6B \
82 | --device 0
83 | ```
84 | It will take a few minutes for large models to load! To use a different model, replace ```EleutherAI/gpt-j-6B``` with the model name. See the Manifest repo for more information on loading other models.
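
Once a model server is running, the task scripts talk to it through Manifest's Python client. As a minimal sketch of what that looks like (it mirrors `get_manifest_session` in this repo's `utils.py` files; the prompt is arbitrary and the URL is the default address of the server started above):

```python
from manifest import Manifest

# Connect to the local model server; responses are cached in SQLite, so
# re-running an experiment does not re-query the model.
manifest = Manifest(
    client_name="huggingface",
    client_connection="http://127.0.0.1:5000",
    cache_name="sqlite",
    cache_connection="../ama_logs/manifest_cache.sqlite",
)
print(manifest.run("Q: Is the sky blue?\nA:", max_tokens=10))
```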
85 |
86 |
87 | ## Experiments
88 |
89 | ### Collecting the prompting predictions
90 |
91 | To run a single task such as the Recognizing Textual Entailment (RTE) SuperGLUE benchmark, you can use the following steps.
92 |
93 | 1. Load a Manifest model using the above command
94 |
95 | 2. Run the following command. This will run the zero-shot baseline (```run_zeroshot = 1```), few-shot baseline (```run_fewshot = 1```) with $k$ in-context demonstrations (```k_shot = 3```), and the AMA baseline (```run_decomp = 1```). In AMA, we aggregate the predictions of multiple prompts-per-input. The number of prompts over which to aggregate is specified by ```num_boost```.
96 |
97 | ```bash
98 | python3 tasks/RTE_final.py \
99 | --run_zeroshot 1 \
100 | --run_fewshot 1 \
101 | --run_decomp 1 \
102 | --num_boost 5 \
103 | --k_shot 3 \
104 | --output_metrics_file ../ama_logs/metrics.json \
105 | --cache_connection ../ama_logs/manifest_cache.sqlite \
106 | --save_dir ../ama_logs/ama_final_runs
107 | ```
108 |
109 | Please see the argparse in ```tasks/decomposition.py``` for other run options; for instance, to control Manifest's caching behavior.
110 |
111 | 3. The results of all baselines will be saved in `ama_final_runs/<task_name>` (e.g., `<task_name>` is `super_glue_rte`, as set in the `RTE_final.py` main function), and all performance metrics will be written to `metrics.json`. The output appears as follows:
112 |
113 | ```
114 | Saving to ../ama_logs/ama_final_runs/super_glue_rte/EleutherAI_gpt-j-6B_decomposed_10052022.json
115 | Saving to ../ama_logs/ama_final_runs/super_glue_rte/EleutherAI_gpt-j-6B_decomposed_10052022_train.json
116 | Accuracy Few Shot 0.5884476534296029
117 | Accuracy by Boost Set Decomposed [0.592057761732852, 0.6209386281588448, 0.5848375451263538, 0.6678700361010831, 0.6173285198555957]
118 | Accuracy by Boost Set Decomposed Average 0.6166064981949458
119 | Accuracy Boost Decomposed 0.6642599277978339
120 | Saved metrics to ../ama_logs/metrics.json
121 | Saved final data to ../ama_logs/ama_final_runs/super_glue_rte
122 | ```
123 |
124 | For the AMA baseline, which consists of ```num_boost``` prompt-chains, the metrics include the individual prompt-chain accuracies over the dataset ("Accuracy by Boost Set Decomposed"), average score ("Accuracy by Boost Set Decomposed Average"), and majority vote result ("Accuracy Boost Decomposed").
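
The majority vote is computed per example over the ```num_boost``` prompt-chain predictions. Conceptually it reduces to the following sketch (illustrative only, not the repo's exact implementation; ties here break by first-seen order):

```python
from collections import Counter

def majority_vote(preds_for_one_example):
    # e.g., with num_boost = 3: ["True", "False", "True"] -> "True"
    return Counter(preds_for_one_example).most_common(1)[0][0]
```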
125 |
126 | ### Running weak supervision
127 |
128 | 4. Next, we aggregate the predictions with weak supervision (WS). To run the WS algorithm on the predictions saved in `ama_final_runs/super_glue_rte`, use the following command. By default, we assume the date of the log file is today; you can change it with the `--override_date` flag.
129 |
130 | ```bash
131 | python3 boosting/run_ws.py \
132 | --task_name super_glue_rte \
133 | --data_dir ../ama_logs/ama_final_runs \
134 | --model_prefix EleutherAI_gpt-j-6B \
135 | --override_date 10052022
136 | ```
137 |
138 | The output will include the following results:
139 |
140 | ```
141 | # The code will first output results without modelling dependencies.
142 |
143 | Trained Label Model Metrics (No deps):
144 | Accuracy: 0.650
145 | Precision: 0.724
146 | Recall: 0.420
147 | F1: 0.531
148 |
149 | # For this task, the WS algorithm identifies a dependency between prompts 0 and 2. Next the code outputs results after modelling dependencies, if dependencies are recovered above.
150 |
151 | Trained Label Model Metrics (with deps):
152 | Accuracy: 0.751
153 | Precision: 0.758
154 | Recall: 0.695
155 | F1: 0.725
156 |
157 |
158 | # Conditional entropy metric discussed in the paper
159 |
160 | H(Y | WS output): 0.5602824867598865
161 | ```
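
Here `H(Y | WS output)` is the conditional entropy of the gold label $Y$ given the aggregated WS prediction $\hat{Y}$, i.e., $H(Y \mid \hat{Y}) = -\sum_{y,\hat{y}} p(y,\hat{y}) \log p(y \mid \hat{y})$. Lower is better: it measures how much uncertainty about the label remains after seeing the WS output (see the paper for the exact setup).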
162 |
163 | For this task, we report accuracy, the metric used by [Brown et al., 2020](https://arxiv.org/pdf/2005.14165.pdf).
164 |
165 |
166 | ## Overall repository structure
167 | ```
168 | tasks/ code for running inference on tasks
169 | diagnostics/ contains the diagnostic tasks
170 | boosting/ code for running weak supervision
171 | metal-ama/ weak supervision algorithm
172 | manifest/ code for loading and using models
173 | /home/data/ default location for benchmarks
174 | ```
175 |
176 |
177 | ## Citation
178 | If you use this codebase, or otherwise find our work valuable, please cite:
179 | ```
180 | @article{arora2022ama,
181 | title={Ask Me Anything: A simple strategy for prompting language models},
182 | author={Arora, Simran and Narayan, Avanika and Chen, Mayee F. and Orr, Laurel and Guha, Neel and Bhatia, Kush and Chami, Ines and Sala, Frederic and R\'e, Christopher},
183 | journal={arXiv:2210.02441},
184 | year={2022}
185 | }
186 | ```
187 |
188 | Please also cite [Snorkel MeTaL](https://github.com/HazyResearch/metal), [bigscience P3](https://huggingface.co/datasets/bigscience/P3), and the benchmark authors.
189 |
190 | ## Acknowledgements
191 |
192 | We are very grateful to the following organizations for the resources that made this work possible: [Together Computer](https://together.xyz/), [Numbers Station](https://numbersstation.ai/), [Snorkel](https://snorkel.ai/), [Stanford Center for Research on Foundation Models](https://crfm.stanford.edu/) and [Stanford HAI](https://hai.stanford.edu/).
193 |
194 |
195 |
196 |
197 |
198 |
199 |
200 |
--------------------------------------------------------------------------------
/ablations/README.md:
--------------------------------------------------------------------------------
1 |
2 | This directory includes code for additional paper experiments beyond the main benchmark results.
3 |
4 | In ```T0_variants```, we include code for running aggregation over PromptSource P3 prompts. To run this, first load the T0 model using Manifest. For example, a T0 checkpoint can be served the same way as the models in the top-level README (```bigscience/T0pp``` below is just one option; swap in whichever T0 variant you use):
5 |
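```bash
python3 manifest/manifest/api/app.py \
    --model_type huggingface \
    --model_name_or_path bigscience/T0pp \
    --device 0
```

Then launch via the script:
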
6 | ```bash
7 | cd T0_variants
8 | bash run.sh
9 | ```
--------------------------------------------------------------------------------
/ablations/T0_variants/CB_variants.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | import os
4 | from tqdm.auto import tqdm
5 | import pandas as pd
6 |
7 | from sklearn.metrics import classification_report
8 | from decomposition import Decomposition, get_args, DATA_DIR
9 | from utils import get_response, InputOutputPrompt
10 | from collections import defaultdict, Counter
11 |
12 | class CBDecomp(Decomposition):
13 | def __init__(self, task_name, data_dir, val_split="validation"):
14 | super().__init__(task_name, data_dir, val_split)
15 |
16 | def get_boost_decomp_examples(self, data_train, boost_id):
17 | lst = [
18 | 'super_glue_cb_claim_true_false_inconclusive',
19 | 'super_glue_cb_does_this_imply',
20 | 'super_glue_cb_always_sometimes_never',
21 | 'super_glue_cb_does_it_follow_that',
22 | 'super_glue_cb_guaranteed_true',
23 | 'super_glue_cb_take_the_following_as_truth',
24 | 'super_glue_cb_justified_in_saying',
25 | 'super_glue_cb_should_assume',
26 | 'super_glue_cb_GPT_3_style',
27 | 'super_glue_cb_can_we_infer',
28 | 'super_glue_cb_consider_always_sometimes_never',
29 | 'super_glue_cb_guaranteed_possible_impossible',
30 | 'super_glue_cb_MNLI_crowdsource',
31 | 'super_glue_cb_based_on_the_previous_passage',
32 | 'super_glue_cb_must_be_true'
33 | ]
34 | file_path = lst[boost_id]
35 | print(f"FILE PATH: {file_path}")
36 | train_data = pd.read_feather(f"{DATA_DIR}/P3/data_feather/{file_path}/train.feather")
37 | val_data = pd.read_feather(f"{DATA_DIR}/P3/data_feather/{file_path}/validation.feather")
38 | return [
39 | train_data,
40 | val_data
41 | ]
42 |
43 | def zero_few_baseline(
44 | self,
45 | test_data,
46 | few_shot_df,
47 | manifest,
48 | overwrite_manifest,
49 | do_few_shot=True,
50 | ):
51 | expt_log = {}
52 | preds = []
53 | labels = []
54 |
55 | labels_names = set(test_data["targets_pretokenized"])
56 | labels_names = [l.lower().strip() for l in labels_names]
57 |
58 | for i, (ind, row) in tqdm(
59 | enumerate(test_data.iterrows()), total=len(test_data)
60 | ):
61 | if ind in expt_log:
62 |                 pred = expt_log[ind]["pred"]
63 |                 gold = expt_log[ind]["gold"]
64 | else:
65 | text = row["inputs_pretokenized"]
66 | gold = row["targets_pretokenized"]
67 |
68 | icl_str = ""
69 | if do_few_shot:
70 | for s_ind, s_row in few_shot_df.iterrows():
71 | icl_str += f"{s_row['inputs_pretokenized']} {s_row['targets_pretokenized']}\n\n\n"
72 |
73 | text = row["inputs_pretokenized"]
74 | text = text.replace("True, False, or Neither?", "")
75 | text = text + ". True, False, or Neither?"
76 | gold = row["targets_pretokenized"]
77 | prompt = f"{icl_str}{{text:}}"
78 | pmp = prompt.format(text=text)
79 | if i == 0:
80 | print(pmp)
81 |
82 | raw_answer = get_response(
83 | pmp,
84 | manifest,
85 | overwrite=bool(overwrite_manifest),
86 | max_toks=30,
87 | )
88 | answer = raw_answer.strip().lower()
89 | answer = answer.split("\n")
90 | answer = [a for a in answer if a]
91 | answer = [
92 | a
93 | for a in answer
94 | if any(l.lower() in a.lower() for l in labels_names)
95 | ]
96 | if answer:
97 | answer = answer[0]
98 | else:
99 | answer = ""
100 | answer = "".join(
101 | [a for a in answer if a not in [".", ",", "?", ";", ":", "'", '"']]
102 | )
103 | is_yes = "true" in answer.split()
104 | is_no = "false" in answer.split()
105 | is_maybe = "neither" in answer.split() or "unknown" in answer.split() or "maybe" in answer.split()
106 | pred = "Neither"
107 | if is_yes and (not is_maybe and not is_no):
108 | pred = "True"
109 | if is_no and (not is_maybe and not is_yes):
110 | pred = "False"
111 | if is_maybe and (not is_no and not is_yes):
112 | pred = "Neither"
113 | entry = {
114 | "ind": ind,
115 | "example": text,
116 | "base_prompt": pmp,
117 | "raw_answer": raw_answer,
118 | "pred": pred,
119 | "gold": gold,
120 | }
121 | expt_log[ind] = entry
122 |
123 | preds.append(pred)
124 | labels.append(gold)
125 |
126 | report = classification_report(labels, preds, output_dict=True)
127 | return expt_log, report["accuracy"]
128 |
129 | def run_decomposed_prompt(
130 | self, test_data, boost_data_train, boost_dfs, manifest, overwrite_manifest
131 | ):
132 | expt_log, all_boost_preds, labels = self._run_decomp_single_data(test_data, boost_dfs, manifest, overwrite_manifest, do_train=0)
133 | expt_log_train, all_boost_train_preds, train_labels = self._run_decomp_single_data(boost_data_train, boost_dfs, manifest, overwrite_manifest, do_train=1)
134 | # Do WS
135 | preds = self.merge_boosted_preds(all_boost_preds, all_boost_train_preds, train_labels, expt_log, expt_log_train)
136 | # Get accuracies across all boost sets
137 | individual_accuracies = []
138 | for i in range(len(all_boost_preds[0])):
139 | individual_accuracies.append(classification_report(labels, [p[i] for p in all_boost_preds], output_dict=True)["accuracy"])
140 | report = classification_report(labels, preds, output_dict=True)
141 | return expt_log, expt_log_train, report["accuracy"], individual_accuracies
142 |
143 | def _run_decomp_single_data(self, test_data, boost_dfs, manifest, overwrite_manifest, do_train=-1):
144 | expt_log = {}
145 | all_boost_preds = []
146 | labels = []
147 |
148 | label_name_to_maping = {
149 | "always": "true",
150 | "never": "false",
151 | "sometimes": 'neither',
152 | "true": "true",
153 | "false": "false",
154 | "neither": 'neither',
155 | "no": "false",
156 | "yes": "true",
157 | "maybe": "neither",
158 | "unknown": "neither",
159 | "inconclusive": "neither",
160 | "impossible": "false",
161 | "possible": "neither",
162 | "guaranteed": "true",
163 | }
164 |
165 | prompts_across_boost = defaultdict(list)
166 | preds_across_boost = defaultdict(list)
167 | for boost_num, boost_examples in enumerate(boost_dfs):
168 |             if do_train == 1:
169 |                 data = boost_examples[0].iloc[:1]
170 |             elif do_train == 0:
171 |                 data = boost_examples[1]
172 |             else:
173 |                 raise ValueError("Unsupported value for do_train.")
174 |
175 | for i, (ind, row) in tqdm(enumerate(data.iterrows()), total=len(data)):
176 | input = row["inputs_pretokenized"]
177 | gold = row["targets_pretokenized"]
178 | all_prompts = []
179 | raw_answer = get_response(
180 | input,
181 | manifest,
182 | overwrite=bool(overwrite_manifest),
183 | max_toks=30,
184 | )
185 | all_prompts.append(input)
186 | answer = raw_answer.lower()
187 | if answer not in label_name_to_maping:
188 | pred = 'neither'
189 | print("BAD ANSWER", answer)
190 | else:
191 | pred = label_name_to_maping[answer]
192 | prompts_across_boost[i].append(all_prompts)
193 | preds_across_boost[i].append(pred)
194 | if gold.lower() not in label_name_to_maping:
195 | import pdb;
196 | pdb.set_trace()
197 |
198 | for i, (ind, row) in enumerate(data.iterrows()):
199 | label = row["targets_pretokenized"].lower()
200 | entry = {
201 | "ind": ind,
202 | "prompts": prompts_across_boost[i],
203 | "preds_boost": preds_across_boost[i],
204 | "example": row['inputs_pretokenized'],
205 | "gold": label,
206 | }
207 | expt_log[ind] = entry
208 |         all_boost_preds.append(preds_across_boost[i])
209 | labels.append(label_name_to_maping[label])
210 | return expt_log, all_boost_preds, labels
211 |
212 |
213 | def main():
214 | args = get_args()
215 | args.num_boost = 10
216 | task_name = "cb_t0_variants"
217 | data_dir = f"{DATA_DIR}/P3/data_feather/super_glue_cb_GPT_3_style/"
218 | decomp = CBDecomp(task_name, data_dir)
219 | decomp.run(args)
220 |
221 |
222 | if __name__ == "__main__":
223 | main()
224 |
--------------------------------------------------------------------------------
/ablations/T0_variants/RTE_variants.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | import os
4 | from tqdm.auto import tqdm
5 | import pandas as pd
6 |
7 | from sklearn.metrics import classification_report
8 | from decomposition import Decomposition, get_args, DATA_DIR
9 | from utils import get_response, InputOutputPrompt
10 | from collections import defaultdict, Counter
11 |
12 | class RTEDecomp(Decomposition):
13 | def __init__(self, task_name, data_dir, val_split="validation"):
14 | super().__init__(task_name, data_dir, val_split)
15 |
16 | def get_boost_decomp_examples(self, data_train, boost_id):
17 | lst = [
18 | 'super_glue_rte_GPT_3_style',
19 | 'super_glue_rte_does_it_follow_that',
20 | 'super_glue_rte_MNLI_crowdsource',
21 | 'super_glue_rte_can_we_infer',
22 | 'super_glue_rte_justified_in_saying',
23 | 'super_glue_rte_does_this_imply',
24 | 'super_glue_rte_guaranteed_true',
25 | 'super_glue_rte_based_on_the_previous_passage',
26 | 'super_glue_rte_should_assume',
27 | 'super_glue_rte_must_be_true'
28 | ]
29 | file_path = lst[boost_id]
30 | print(f"FILE PATH: {file_path}")
31 | train_data = pd.read_feather(f"{DATA_DIR}/P3/data_feather/{file_path}/train.feather")
32 | val_data = pd.read_feather(f"{DATA_DIR}/P3/data_feather/{file_path}/validation.feather")
33 | return [
34 | train_data,
35 | val_data
36 | ]
37 |
38 | def zero_few_baseline(
39 | self,
40 | test_data,
41 | few_shot_df,
42 | manifest,
43 | overwrite_manifest,
44 | do_few_shot=True,
45 | ):
46 | expt_log = {}
47 | preds = []
48 | labels = []
49 |
50 | labels_names = set(test_data["targets_pretokenized"])
51 | labels_names = [l.lower().strip() for l in labels_names]
52 |
53 | for i, (ind, row) in tqdm(
54 | enumerate(test_data.iterrows()), total=len(test_data)
55 | ):
56 | if ind in expt_log:
57 |                 pred = expt_log[ind]["pred"]
58 |                 gold = expt_log[ind]["gold"]
59 | else:
60 | text = row["inputs_pretokenized"]
61 | gold = row["targets_pretokenized"]
62 |
63 | icl_str = ""
64 | if do_few_shot:
65 | for s_ind, s_row in few_shot_df.iterrows():
66 | icl_str += f"{s_row['inputs_pretokenized']} {s_row['targets_pretokenized']}\n\n\n"
67 |
68 | text = row["inputs_pretokenized"]
69 | text = text.replace("True, False, or Neither?", "")
70 | text = text + ". True, False, or Neither?"
71 | gold = row["targets_pretokenized"]
72 | prompt = f"{icl_str}{{text:}}"
73 | pmp = prompt.format(text=text)
74 | if i == 0:
75 | print(pmp)
76 |
77 | raw_answer = get_response(
78 | pmp,
79 | manifest,
80 | overwrite=bool(overwrite_manifest),
81 | max_toks=30,
82 | )
83 | answer = raw_answer.strip().lower()
84 | answer = answer.split("\n")
85 | answer = [a for a in answer if a]
86 | answer = [
87 | a
88 | for a in answer
89 | if any(l.lower() in a.lower() for l in labels_names)
90 | ]
91 | if answer:
92 | answer = answer[0]
93 | else:
94 | answer = ""
95 | answer = "".join(
96 | [a for a in answer if a not in [".", ",", "?", ";", ":", "'", '"']]
97 | )
98 | is_yes = "true" in answer.split()
99 | is_no = "false" in answer.split()
100 | is_maybe = "neither" in answer.split() or "unknown" in answer.split() or "maybe" in answer.split()
101 | pred = "Neither"
102 | if is_yes and (not is_maybe and not is_no):
103 | pred = "True"
104 | if is_no and (not is_maybe and not is_yes):
105 | pred = "False"
106 | if is_maybe and (not is_no and not is_yes):
107 | pred = "Neither"
108 | entry = {
109 | "ind": ind,
110 | "example": text,
111 | "base_prompt": pmp,
112 | "raw_answer": raw_answer,
113 | "pred": pred,
114 | "gold": gold,
115 | }
116 | expt_log[ind] = entry
117 |
118 | preds.append(pred)
119 | labels.append(gold)
120 |
121 | report = classification_report(labels, preds, output_dict=True)
122 | return expt_log, report["accuracy"]
123 |
124 | def run_decomposed_prompt(
125 | self, test_data, boost_data_train, boost_dfs, manifest, overwrite_manifest
126 | ):
127 | expt_log, all_boost_preds, labels = self._run_decomp_single_data(test_data, boost_dfs, manifest, overwrite_manifest, do_train=0)
128 | expt_log_train, all_boost_train_preds, train_labels = self._run_decomp_single_data(boost_data_train, boost_dfs, manifest, overwrite_manifest, do_train=1)
129 | # Do WS
130 | preds = self.merge_boosted_preds(all_boost_preds, all_boost_train_preds, train_labels, expt_log, expt_log_train)
131 | # Get accuracies across all boost sets
132 | individual_accuracies = []
133 | for i in range(len(all_boost_preds[0])):
134 | individual_accuracies.append(classification_report(labels, [p[i] for p in all_boost_preds], output_dict=True)["accuracy"])
135 | report = classification_report(labels, preds, output_dict=True)
136 | return expt_log, expt_log_train, report["accuracy"], individual_accuracies
137 |
138 | def _run_decomp_single_data(self, test_data, boost_dfs, manifest, overwrite_manifest, do_train=-1):
139 | expt_log = {}
140 | all_boost_preds = []
141 | labels = []
142 |
143 | label_name_to_maping = {
144 | "always": "true",
145 | "never": "false",
146 | "true": "true",
147 | "false": "false",
148 | "no": "false",
149 | "yes": "true",
150 | "impossible": "false",
151 | "guaranteed": "true",
152 | }
153 |
154 | prompts_across_boost = defaultdict(list)
155 | preds_across_boost = defaultdict(list)
156 | for boost_num, boost_examples in enumerate(boost_dfs):
157 |             if do_train == 1:
158 |                 data = boost_examples[0].iloc[:1000]
159 |             elif do_train == 0:
160 |                 data = boost_examples[1]
161 |             else:
162 |                 raise ValueError("Unsupported value for do_train.")
163 | for i, (ind, row) in tqdm(enumerate(data.iterrows()), total=len(data)):
164 | input = row["inputs_pretokenized"]
165 | gold = row["targets_pretokenized"]
166 | all_prompts = []
167 | raw_answer = get_response(
168 | input,
169 | manifest,
170 | overwrite=bool(overwrite_manifest),
171 | max_toks=30,
172 | )
173 | all_prompts.append(input)
174 | answer = raw_answer.lower()
175 | if answer not in label_name_to_maping:
176 | pred = 'neither'
177 | print("BAD ANSWER", answer)
178 | else:
179 | pred = label_name_to_maping[answer]
180 | prompts_across_boost[i].append(all_prompts)
181 | preds_across_boost[i].append(pred)
182 | if gold.lower() not in label_name_to_maping:
183 | import pdb;
184 | pdb.set_trace()
185 |
186 | for i, (ind, row) in enumerate(data.iterrows()):
187 | label = row["targets_pretokenized"].lower()
188 | entry = {
189 | "ind": ind,
190 | "prompts": prompts_across_boost[i],
191 | "preds_boost": preds_across_boost[i],
192 | "example": row['inputs_pretokenized'],
193 | "gold": label,
194 | }
195 | expt_log[ind] = entry
196 |         all_boost_preds.append(preds_across_boost[i])
197 | labels.append(label_name_to_maping[label])
198 | return expt_log, all_boost_preds, labels
199 |
200 |
201 | def main():
202 | args = get_args()
203 | args.num_boost = 10
204 | task_name = "rte_t0_variants"
205 | data_dir = f"{DATA_DIR}/P3/data_feather/super_glue_rte_GPT_3_style/"
206 | decomp = RTEDecomp(task_name, data_dir)
207 | decomp.run(args)
208 |
209 |
210 | if __name__ == "__main__":
211 | main()
212 |
--------------------------------------------------------------------------------
/ablations/T0_variants/WIC_variants.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | import os
4 | from tqdm.auto import tqdm
5 | import pandas as pd
6 |
7 | from sklearn.metrics import classification_report
8 | from decomposition import Decomposition, get_args, DATA_DIR
9 | from utils import get_response, InputOutputPrompt
10 | from collections import defaultdict, Counter
11 |
12 | class WICDecomp(Decomposition):
13 | def __init__(self, task_name, data_dir, val_split="validation"):
14 | super().__init__(task_name, data_dir, val_split)
15 |
16 | def get_boost_decomp_examples(self, data_train, boost_id):
17 | lst = [
18 | "super_glue_wic_affirmation_true_or_false",
19 | "super_glue_wic_GPT_3_prompt",
20 | "super_glue_wic_GPT_3_prompt_with_label",
21 | "super_glue_wic_grammar_homework",
22 | "super_glue_wic_polysemous",
23 | "super_glue_wic_question_context",
24 | "super_glue_wic_question_context_meaning",
25 | "super_glue_wic_question_context_meaning_with_label",
26 | "super_glue_wic_same_sense",
27 | "super_glue_wic_similar_sense",
28 | ]
29 | file_path = lst[boost_id]
30 | print(f"FILE PATH: {file_path}")
31 | train_data = pd.read_feather(f"{DATA_DIR}/P3/data_feather/{file_path}/train.feather")
32 | val_data = pd.read_feather(f"{DATA_DIR}/P3/data_feather/{file_path}/validation.feather")
33 | return [
34 | train_data,
35 | val_data
36 | ]
37 |
38 | def zero_few_baseline(
39 | self,
40 | test_data,
41 | few_shot_df,
42 | manifest,
43 | overwrite_manifest,
44 | do_few_shot=True,
45 | ):
46 | expt_log = {}
47 | preds = []
48 | labels = []
49 |
50 | labels_names = set(test_data["targets_pretokenized"])
51 | labels_names = [l.lower().strip() for l in labels_names]
52 |
53 | for i, (ind, row) in tqdm(
54 | enumerate(test_data.iterrows()), total=len(test_data)
55 | ):
56 | if ind in expt_log:
57 |                 pred = expt_log[ind]["pred"]
58 |                 gold = expt_log[ind]["gold"]
59 | else:
60 | text = row["inputs_pretokenized"]
61 | gold = row["targets_pretokenized"]
62 |
63 | icl_str = ""
64 | if do_few_shot:
65 | for s_ind, s_row in few_shot_df.iterrows():
66 | icl_str += f"{s_row['inputs_pretokenized']} {s_row['targets_pretokenized']}\n\n\n"
67 |
68 | text = row["inputs_pretokenized"]
69 | gold = row["targets_pretokenized"]
70 | prompt = f"{icl_str}{{text:}}"
71 | pmp = prompt.format(text=text)
72 | if i == 0:
73 | print(pmp)
74 |
75 | raw_answer = get_response(
76 | pmp,
77 | manifest,
78 | overwrite=bool(overwrite_manifest),
79 | max_toks=30,
80 | )
81 | answer = raw_answer.strip().lower()
82 | answer = answer.split("\n")
83 | answer = [a for a in answer if a]
84 | answer = [
85 | a
86 | for a in answer
87 | if any(l.lower() in a.lower() for l in labels_names)
88 | ]
89 | if answer:
90 | answer = answer[0]
91 | else:
92 | answer = ""
93 | answer = "".join(
94 | [a for a in answer if a not in [".", ",", "?", ";", ":", "'", '"']]
95 | )
96 | is_yes = "yes" in answer.split()
97 | is_no = "no" in answer.split()
98 | pred = ""
99 | if is_yes and (not is_no):
100 | pred = "yes"
101 | if is_no and (not is_yes):
102 | pred = "no"
103 | entry = {
104 | "ind": ind,
105 | "example": text,
106 | "base_prompt": pmp,
107 | "raw_answer": raw_answer,
108 | "pred": pred,
109 | "gold": gold.strip().lower(),
110 | }
111 | expt_log[ind] = entry
112 |
113 | preds.append(pred)
114 | labels.append(gold)
115 |
116 | report = classification_report(labels, preds, output_dict=True)
117 | return expt_log, report["accuracy"]
118 |
119 | def run_decomposed_prompt(
120 | self, test_data, boost_data_train, boost_dfs, manifest, overwrite_manifest
121 | ):
122 | expt_log, all_boost_preds, labels = self._run_decomp_single_data(test_data, boost_dfs, manifest, overwrite_manifest, do_train=0)
123 | expt_log_train, all_boost_train_preds, train_labels = self._run_decomp_single_data(boost_data_train, boost_dfs, manifest, overwrite_manifest, do_train=1)
124 | # Do WS
125 | preds = self.merge_boosted_preds(all_boost_preds, all_boost_train_preds, train_labels, expt_log, expt_log_train)
126 | # Get accuracies across all boost sets
127 | individual_accuracies = []
128 | for i in range(len(all_boost_preds[0])):
129 | individual_accuracies.append(classification_report(labels, [p[i] for p in all_boost_preds], output_dict=True)["accuracy"])
130 | report = classification_report(labels, preds, output_dict=True)
131 | return expt_log, expt_log_train, report["accuracy"], individual_accuracies
132 |
133 | def _run_decomp_single_data(self, test_data, boost_dfs, manifest, overwrite_manifest, do_train=-1):
134 | expt_log = {}
135 | all_boost_preds = []
136 | labels = []
137 |
138 | label_name_to_maping = {
139 | "always": "true",
140 | "never": "false",
141 | "true": "true",
142 | "false": "false",
143 | "no": "false",
144 | "yes": "true",
145 | "impossible": "false",
146 | "guaranteed": "true",
147 | }
148 |
149 | prompts_across_boost = defaultdict(list)
150 | preds_across_boost = defaultdict(list)
151 | for boost_num, boost_examples in enumerate(boost_dfs):
152 |             if do_train == 1:
153 |                 data = boost_examples[0].iloc[:1]
154 |             elif do_train == 0:
155 |                 data = boost_examples[1]
156 |             else:
157 |                 raise ValueError("Unsupported value for do_train.")
158 |
159 | for i, (ind, row) in tqdm(enumerate(data.iterrows()), total=len(data)):
160 | input = row["inputs_pretokenized"]
161 | gold = row["targets_pretokenized"]
162 | all_prompts = []
163 | raw_answer = get_response(
164 | input,
165 | manifest,
166 | overwrite=bool(overwrite_manifest),
167 | max_toks=30,
168 | )
169 | all_prompts.append(input)
170 | answer = raw_answer.lower()
171 | if answer not in label_name_to_maping:
172 | pred = ''
173 | print("BAD ANSWER", answer)
174 | else:
175 | pred = label_name_to_maping[answer]
176 | prompts_across_boost[i].append(all_prompts)
177 | preds_across_boost[i].append(pred)
178 | if gold.lower() not in label_name_to_maping:
179 | import pdb;
180 | pdb.set_trace()
181 |
182 | for i, (ind, row) in enumerate(data.iterrows()):
183 | label = row["targets_pretokenized"].lower()
184 | entry = {
185 | "ind": ind,
186 | "prompts": prompts_across_boost[i],
187 | "preds_boost": preds_across_boost[i],
188 | "example": row['inputs_pretokenized'],
189 | "gold": label,
190 | }
191 | expt_log[ind] = entry
192 |         all_boost_preds.append(preds_across_boost[i])
193 | labels.append(label_name_to_maping[label])
194 | return expt_log, all_boost_preds, labels
195 |
196 |
197 | def main():
198 | args = get_args()
199 | args.num_boost = 9
200 | task_name = "wic_t0_variants"
201 | data_dir = f"{DATA_DIR}/P3/data_feather/super_glue_wic_GPT_3_prompt/"
202 | decomp = WICDecomp(task_name, data_dir)
203 | decomp.run(args)
204 |
205 |
206 | if __name__ == "__main__":
207 | main()
208 |
--------------------------------------------------------------------------------
/ablations/T0_variants/WSC_variants.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | import os
4 | from tqdm.auto import tqdm
5 | import pandas as pd
6 |
7 | from sklearn.metrics import classification_report
8 | from decomposition import Decomposition, get_args, DATA_DIR
9 | from utils import get_response, InputOutputPrompt
10 | from collections import defaultdict, Counter
11 |
12 | class WSCDecomp(Decomposition):
13 | def __init__(self, task_name, data_dir, val_split="validation"):
14 | super().__init__(task_name, data_dir, val_split)
15 |
16 | def get_boost_decomp_examples(self, data_train, boost_id):
17 | lst = [
18 | 'super_glue_wsc.fixed_GPT_3_Style',
19 | 'super_glue_wsc.fixed_Who_or_what_is_are',
20 | 'super_glue_wsc.fixed_p_is_are_r',
21 | 'super_glue_wsc.fixed_by_p_they_mean',
22 | 'super_glue_wsc.fixed_replaced_with',
23 | 'super_glue_wsc.fixed_in_other_words',
24 | 'super_glue_wsc.fixed_I_think_they_mean',
25 | 'super_glue_wsc.fixed_does_p_stand_for',
26 | 'super_glue_wsc.fixed_does_the_pronoun_refer_to',
27 | 'super_glue_wsc.fixed_the_pronoun_refers_to'
28 | ]
29 | file_path = lst[boost_id]
30 | print(f"FILE PATH: {file_path}")
31 | train_data = pd.read_feather(f"{DATA_DIR}/P3/data_feather/{file_path}/train.feather")
32 | val_data = pd.read_feather(f"{DATA_DIR}/P3/data_feather/{file_path}/validation.feather")
33 | return [
34 | train_data,
35 | val_data
36 | ]
37 |
38 | def zero_few_baseline(
39 | self,
40 | test_data,
41 | few_shot_df,
42 | manifest,
43 | overwrite_manifest,
44 | do_few_shot=True,
45 | ):
46 | expt_log = {}
47 | preds = []
48 | labels = []
49 |
50 | labels_names = set(test_data["targets_pretokenized"])
51 | labels_names = [l.lower().strip() for l in labels_names]
52 |
53 | for i, (ind, row) in tqdm(
54 | enumerate(test_data.iterrows()), total=len(test_data)
55 | ):
56 | if ind in expt_log:
57 |                 pred = expt_log[ind]["pred"]
58 |                 gold = expt_log[ind]["gold"]
59 | else:
60 | text = row["inputs_pretokenized"]
61 | gold = row["targets_pretokenized"]
62 |
63 | icl_str = ""
64 | if do_few_shot:
65 | for s_ind, s_row in few_shot_df.iterrows():
66 | icl_str += f"{s_row['inputs_pretokenized']} {s_row['targets_pretokenized']}\n\n\n"
67 |
68 | text = row["inputs_pretokenized"]
69 | text = text.replace("True, False, or Neither?", "")
70 | text = text + ". True, False, or Neither?"
71 | gold = row["targets_pretokenized"]
72 | prompt = f"{icl_str}{{text:}}"
73 | pmp = prompt.format(text=text)
74 | if i == 0:
75 | print(pmp)
76 |
77 | raw_answer = get_response(
78 | pmp,
79 | manifest,
80 | overwrite=bool(overwrite_manifest),
81 | max_toks=30,
82 | )
83 | answer = raw_answer.strip().lower()
84 | answer = answer.split("\n")
85 | answer = [a for a in answer if a]
86 | answer = [
87 | a
88 | for a in answer
89 | if any(l.lower() in a.lower() for l in labels_names)
90 | ]
91 | if answer:
92 | answer = answer[0]
93 | else:
94 | answer = ""
95 | answer = "".join(
96 | [a for a in answer if a not in [".", ",", "?", ";", ":", "'", '"']]
97 | )
98 | is_yes = "true" in answer.split()
99 | is_no = "false" in answer.split()
100 | is_maybe = "neither" in answer.split() or "unknown" in answer.split() or "maybe" in answer.split()
101 | pred = "Neither"
102 | if is_yes and (not is_maybe and not is_no):
103 | pred = "True"
104 | if is_no and (not is_maybe and not is_yes):
105 | pred = "False"
106 | if is_maybe and (not is_no and not is_yes):
107 | pred = "Neither"
108 | entry = {
109 | "ind": ind,
110 | "example": text,
111 | "base_prompt": pmp,
112 | "raw_answer": raw_answer,
113 | "pred": pred,
114 | "gold": gold,
115 | }
116 | expt_log[ind] = entry
117 |
118 | preds.append(pred)
119 | labels.append(gold)
120 |
121 | report = classification_report(labels, preds, output_dict=True)
122 | return expt_log, report["accuracy"]
123 |
124 | def run_decomposed_prompt(
125 | self, test_data, boost_data_train, boost_dfs, manifest, overwrite_manifest
126 | ):
127 | expt_log, all_boost_preds, labels = self._run_decomp_single_data(test_data, boost_dfs, manifest, overwrite_manifest, do_train=0)
128 | expt_log_train, all_boost_train_preds, train_labels = self._run_decomp_single_data(boost_data_train, boost_dfs, manifest, overwrite_manifest, do_train=1)
129 | # Do WS
130 | preds = self.merge_boosted_preds(all_boost_preds, all_boost_train_preds, train_labels, expt_log, expt_log_train)
131 | # Get accuracies across all boost sets
132 | individual_accuracies = []
133 | for i in range(len(all_boost_preds[0])):
134 | individual_accuracies.append(classification_report(labels, [p[i] for p in all_boost_preds], output_dict=True)["accuracy"])
135 | report = classification_report(labels, preds, output_dict=True)
136 | return expt_log, expt_log_train, report["accuracy"], individual_accuracies
137 |
138 | def _run_decomp_single_data(self, test_data, boost_dfs, manifest, overwrite_manifest, do_train=-1):
139 | expt_log = {}
140 | all_boost_preds = []
141 | labels = []
142 |
143 | label_name_to_maping = {
144 | "always": "true",
145 | "never": "false",
146 | "sometimes": 'neither',
147 | "true": "true",
148 | "false": "false",
149 | "neither": 'neither',
150 | "no": "false",
151 | "yes": "true",
152 | "maybe": "neither",
153 | "unknown": "neither",
154 | "inconclusive": "neither",
155 | "impossible": "false",
156 | "possible": "neither",
157 | "guaranteed": "true",
158 | }
159 |
160 | prompts_across_boost = defaultdict(list)
161 | preds_across_boost = defaultdict(list)
162 | for boost_num, boost_examples in enumerate(boost_dfs):
163 |             if do_train == 1:
164 |                 data = boost_examples[0].iloc[:1000]
165 |             elif do_train == 0:
166 |                 data = boost_examples[1]
167 |             else:
168 |                 raise ValueError("Unsupported value for do_train.")
169 |
170 | for i, (ind, row) in tqdm(enumerate(data.iterrows()), total=len(data)):
171 | input = row["inputs_pretokenized"]
172 | gold = row["targets_pretokenized"]
173 | all_prompts = []
174 | raw_answer = get_response(
175 | input,
176 | manifest,
177 | overwrite=bool(overwrite_manifest),
178 | max_toks=30,
179 | )
180 | all_prompts.append(input)
181 | answer = raw_answer.lower()
182 | if answer not in label_name_to_maping:
183 | pred = 'neither'
184 | print(answer)
185 | else:
186 | pred = label_name_to_maping[answer]
187 | prompts_across_boost[i].append(all_prompts)
188 | preds_across_boost[i].append(pred)
189 | if gold.lower() not in label_name_to_maping:
190 | import pdb;
191 | pdb.set_trace()
192 |
193 | for i, (ind, row) in enumerate(data.iterrows()):
194 | label = row["targets_pretokenized"].lower()
195 | entry = {
196 | "ind": ind,
197 | "prompts": prompts_across_boost[i],
198 | "preds_boost": preds_across_boost[i],
199 | "example": row['inputs_pretokenized'],
200 | "gold": label,
201 | }
202 | expt_log[ind] = entry
203 |         all_boost_preds.append(preds_across_boost[i])
204 | labels.append(label_name_to_maping[label])
205 | return expt_log, all_boost_preds, labels
206 |
207 |
208 | def main():
209 | args = get_args()
210 | args.num_boost = 10
211 | task_name = "wsc_t0_variants"
212 | data_dir = f"{DATA_DIR}/P3/data_feather/super_glue_wsc.fixed_GPT_3_Style/"
213 |     decomp = WSCDecomp(task_name, data_dir)
214 | decomp.run(args)
215 |
216 |
217 | if __name__ == "__main__":
218 | main()
219 |
--------------------------------------------------------------------------------
/ablations/T0_variants/run.sh:
--------------------------------------------------------------------------------
1 |
2 | for task in WSC_variants; do
3 | python3 ${task}.py \
4 | --run_zeroshot 0 \
5 | --run_fewshot 0 \
6 | --run_decomp 1 \
7 | --num_boost 10 \
8 | --k_shot 3 \
9 | --output_metrics_file ../ama_logs/metrics.json \
10 | --save_dir ../ama_logs/ama_final_runs \
11 | --overwrite_data 1 \
12 | --overwrite_boost_exs 1 \
13 | --num_run -1 \
14 | --client_connection http://127.0.0.1:5000
15 | done;
--------------------------------------------------------------------------------
/ablations/T0_variants/utils.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from collections import Counter
3 | import json
4 | from datasets import load_dataset
5 | import re
6 | import pandas as pd
7 | from typing import Callable, List
8 | from manifest import Manifest
9 |
10 |
11 | class InputOutputPrompt:
12 | def __init__(self,
13 | input_formatter: Callable,
14 | output_formatter: Callable,
15 | required_keys: List,
16 | input_output_sep: str = "\n",
17 | example_sep: str = "\n\n",
18 | instruction: str = ""
19 | ):
20 | self.input_formatter = input_formatter
21 | self.output_formatter = output_formatter
22 | self.required_keys = required_keys
23 | self.input_output_sep = input_output_sep
24 | self.example_sep = example_sep
25 | self.instruction = instruction
26 |
27 | def __call__(self, input_output_pairs: pd.DataFrame):
28 | examples = []
29 | for _, example in input_output_pairs.iterrows():
30 | examples.append(f"{self.input_formatter(example)}{self.input_output_sep}{self.output_formatter(example)}")
31 | if examples:
32 | input_str = self.example_sep.join(examples)
33 | res = f"{self.instruction}{input_str}"
34 | else:
35 |             res = self.instruction.rstrip()
36 | return res
37 |
38 | def __repr__(self):
39 | dummy_ex = pd.DataFrame([{k: f"<{k.upper()}>" for k in self.required_keys}])
40 | st = self(dummy_ex)
41 | return st
42 |
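
def _example_input_output_prompt() -> str:
    """Illustrative sketch of InputOutputPrompt usage, added for exposition.

    The 'passage'/'answer' columns are hypothetical, not taken from the AMA tasks.
    """
    demo = InputOutputPrompt(
        input_formatter=lambda ex: f"Passage: {ex['passage']}",
        output_formatter=lambda ex: f"Answer: {ex['answer']}",
        required_keys=["passage", "answer"],
        instruction="Answer the question about the passage.\n\n",
    )
    # Renders the instruction followed by "Passage: ...\nAnswer: ..." examples.
    return demo(pd.DataFrame([{"passage": "The sky is blue.", "answer": "blue"}]))
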
43 |
44 | def prefix_formatter(ex_keys: List[str], prefix: str, error_on_empty: bool = True) -> Callable:
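    # Builds a formatter that renders the first key of ex_keys present in the
    # example as "<prefix> <value>", e.g. prefix_formatter(["question"], "Question:").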
45 | def full_prefix_formatter(ex: pd.Series):
46 | for k in ex_keys:
47 | if k in ex:
48 | return f"{prefix} {getattr(ex, k)}"
49 | if error_on_empty:
50 | raise ValueError(f"Example {ex} has no value for any of the keys {ex_keys}")
51 | else:
52 | return f"{prefix}"
53 | return full_prefix_formatter
54 |
55 |
56 | def get_manifest_session(
57 | client_name="huggingface",
58 | client_engine=None,
59 | client_connection="http://127.0.0.1:5000",
60 | cache_connection=None,
61 | temperature=0,
62 | top_p=1.0,
63 | ):
64 | if client_name == "huggingface" and temperature == 0:
65 | params = {
66 | "temperature": 0.001,
67 | "do_sample": False,
68 | "top_p": top_p,
69 | }
70 | elif client_name in {"openai", "ai21"}:
71 | params = {
72 | "temperature": temperature,
73 | "top_p": top_p,
74 | "engine": client_engine,
75 | }
76 | else:
77 | raise ValueError(f"{client_name} is not a valid client name")
78 | manifest = Manifest(
79 | client_name=client_name,
80 | client_connection=client_connection,
81 | cache_name="sqlite",
82 | cache_connection=cache_connection,
83 | session_id=None,
84 | **params,
85 | )
86 | params = manifest.client.get_model_params()
87 | model_name = params["model_name"]
88 | if "engine" in params:
89 | model_name += f"_{params['engine']}"
90 | return manifest, model_name
91 |
92 |
93 | def get_response(
94 | prompt,
95 | manifest,
96 | overwrite=False,
97 | max_toks=10,
98 | stop_token=None,
99 | gold_choices=[],
100 | verbose=False,
101 | ):
102 | prompt = prompt.strip()
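    # If gold_choices are given, Manifest scores each candidate continuation and
    # returns the most likely one along with its log-probability; otherwise the
    # model generates freely up to max_toks.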
103 | if gold_choices:
104 | gold_choices = [" " + g.strip() for g in gold_choices]
105 | response_obj = manifest.run(
106 | prompt, gold_choices=gold_choices, overwrite_cache=overwrite, return_response=True
107 | )
108 | response_obj = response_obj.get_json_response()["choices"][0]
109 | log_prob = response_obj["text_logprob"]
110 | response = response_obj["text"]
111 | else:
112 | response = manifest.run(
113 | prompt,
114 | max_tokens=max_toks,
115 | stop_token=stop_token,
116 | overwrite_cache=overwrite,
117 | )
118 | log_prob = None
119 | if verbose:
120 | print("\n***Prompt***\n", prompt)
121 | print("\n***Response***\n", response)
122 | if log_prob:
123 | return response, log_prob
124 | return response
125 |
126 | def load_hf_data(save_dir, task_name, val_split, hf_name, overwrite_data):
127 | save_data = Path(f"{save_dir}/{task_name}/data.feather")
128 | if not save_data.exists() or overwrite_data:
129 | dataset = load_dataset(hf_name)
130 | test_data = dataset[val_split].to_pandas()
131 | test_data.to_feather(f"{save_data}")
132 | else:
133 | print(f"Reading test data from {save_data}")
134 | test_data = pd.read_feather(f"{save_data}")
135 |
136 | save_data_train = Path(f"{save_dir}/{task_name}/train_data.feather")
137 | if not save_data_train.exists() or overwrite_data:
138 | dataset = load_dataset(hf_name)
139 | train_data = dataset["train"].to_pandas()
140 | train_data.to_feather(f"{save_data_train}")
141 | else:
142 | print(f"Reading train data from {save_data_train}")
143 | train_data = pd.read_feather(f"{save_data_train}")
144 |
145 | print(f"Test Data Size: {len(test_data)}")
146 | print(f"Train Data Size: {len(train_data)}")
147 | return test_data, train_data
148 |
149 | def save_log(task_name, expt_name, log, final_run_dir):
150 | final_run_dir = Path(final_run_dir)
151 | output_fpath = final_run_dir / task_name
152 | output_fpath.mkdir(parents=True, exist_ok=True)
153 |
154 | print("Saving to", output_fpath / f"{expt_name}.json")
155 | assert all(a in list(log.values())[0].keys() for a in ["ind","example","pred","gold"])
156 | with open(output_fpath / f"{expt_name}.json", "w") as f:
157 | json.dump(log, f)
158 |
159 | def text_f1(preds, golds):
160 | """Compute average F1 of text spans.
161 | Taken from Squad without prob threshold for no answer.
162 | """
163 | total_f1 = 0
164 | for pred, gold in zip(preds, golds):
165 | pred_toks = pred.split()
166 | gold_toks = gold.split()
167 | common = Counter(pred_toks) & Counter(gold_toks)
168 | num_same = sum(common.values())
169 | if len(gold_toks) == 0 or len(pred_toks) == 0:
170 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
171 | total_f1 += int(gold_toks == pred_toks)
172 | elif num_same == 0:
173 | total_f1 += 0
174 | else:
175 | precision = 1.0 * num_same / len(pred_toks)
176 | recall = 1.0 * num_same / len(gold_toks)
177 | f1 = (2 * precision * recall) / (precision + recall)
178 | total_f1 += f1
179 | f1_avg = total_f1 / len(golds)
180 | return f1_avg
181 |
182 | def accuracy_span_overlap(preds, golds):
183 | correct = 0
184 | for pred, gold in zip(preds, golds):
185 | found = False
186 | for p in pred:
187 | for g in gold:
188 | if len(p) < len(g):
189 | if p.lower() in g.lower():
190 | found = True
191 | break
192 | else:
193 | if g.lower() in p.lower():
194 | found = True
195 | break
196 | if found: correct += 1
197 | return correct / len(preds)
198 |
199 |
200 |
--------------------------------------------------------------------------------
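The helpers above are thin wrappers around a Manifest client: `get_manifest_session` builds the session and resolves the model name, and `get_response` either generates freely or scores a fixed set of answer choices. A minimal usage sketch, assuming a HuggingFace model is being served at the default `client_connection` and that the helpers above are in scope:

```python
# Sketch: query a running Manifest server with the helpers above
# (assumes get_manifest_session / get_response / text_f1 are in scope).
manifest, model_name = get_manifest_session()

# Free-form generation returns just the text.
answer = get_response("Q: Is the sky blue?\nA:", manifest, max_toks=5)

# With gold_choices, Manifest scores the choices and the helper
# returns (text, log_prob) instead.
answer, log_prob = get_response(
    "Q: Is the sky blue?\nA:", manifest, gold_choices=["yes", "no"]
)
print(model_name, answer, log_prob)

# Worked check of text_f1: "the sun" vs "the bright sun" shares 2 tokens,
# so precision = 2/2 and recall = 2/3, giving F1 = 0.8; the exact match
# contributes 1.0, for an average of 0.9.
assert abs(text_f1(["the sun", "ok"], ["the bright sun", "ok"]) - 0.9) < 1e-9
```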
/boosting/pgm.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import itertools
3 | import matplotlib.pyplot as plt
4 | import scipy.stats
5 |
6 |
7 |
8 | class Ising():
9 |
10 |
11 |     def __init__(self, m, potentials, thetas=None, vals=(-1, 1)) -> None:
12 | self.m = m
13 | self.v = m + 1 # total number of vertices
14 | self.potentials = potentials
15 |
16 | self.vals = vals
17 | #TODO support values in 0, 1
18 |
19 | if thetas is not None:
20 | assert len(thetas) >= len(potentials), f"Need to specify at least {len(potentials)} theta parameters."
21 | self.thetas = thetas
22 | else:
23 | self.thetas = np.random.rand(len(potentials))
24 |
25 | self.support = np.array(list(map(list, itertools.product(vals, repeat=self.v))))
26 |
27 | self._make_pdf()
28 | self._make_cdf()
29 |
30 | self._get_means()
31 | self._get_balance()
32 | self._get_accs()
33 |
34 | def _exponential_family(self, labels):
35 | x = 0.0
36 | for i in range(len(self.potentials)):
37 | x += self.thetas[i] * labels[self.potentials[i]].prod()
38 |
39 | return np.exp(x)
40 |
41 | def _make_pdf(self):
42 | p = np.zeros(len(self.support))
43 | for i, labels in enumerate(self.support):
44 | p[i] = self._exponential_family(labels)
45 |
46 | self.z = sum(p)
47 |
48 | self.pdf = p/self.z
49 |
50 | def _make_cdf(self):
51 | self.cdf = np.cumsum(self.pdf)
52 |
53 |
54 |
55 |
56 | def joint_p(self, C, values):
57 | p = 0.0
58 | for k, labels in enumerate(self.support):
59 | flag = True
60 | for i in range(len(C)):
61 | prod = labels[C[i]].prod()
62 | if prod != values[i]:
63 | flag = False
64 |
65 |             if flag:
66 | p += self.pdf[k]
67 |
68 | return p
69 |
70 | def expectation(self, C):
71 | return self.vals[0] * self.joint_p(C, self.vals[0] * np.ones(len(C))) + self.vals[1] * self.joint_p(C, self.vals[1] * np.ones(len(C)))
72 |
73 | def _get_means(self):
74 | self.means = np.zeros(self.m)
75 | for k in range(self.m):
76 | self.means[k] = self.expectation([[k]])
77 |
78 |
79 | def _get_balance(self):
80 | self.balance = self.joint_p([[self.m]], [1])
81 |
82 | # def _get_covariance(self):
83 |
84 | def _get_accs(self):
85 | """
86 | self.accs[k, i, j] = Pr(lf_k = j | y = i) (i, j scaled to -1, 1 if needed)
87 | """
88 | self.accs = np.zeros((self.m, 2, 2))
89 | for k in range(self.m):
90 | self.accs[k, 1, 1] = self.joint_p([[k], [self.m]], [self.vals[1], self.vals[1]]) / self.balance
91 | self.accs[k, 0, 0] = self.joint_p([[k], [self.m]], [self.vals[0], self.vals[0]]) / (1 - self.balance)
92 | self.accs[k, 1, 0] = 1 - self.accs[k, 1, 1]
93 | self.accs[k, 0, 1] = 1 - self.accs[k, 0, 0]
94 |
95 |
96 | def sample(self):
97 | r = np.random.random_sample()
98 | smaller = np.where(self.cdf < r)[0]
99 | if len(smaller) == 0:
100 | i = 0
101 | else:
102 | i = smaller.max() + 1
103 |
104 | return self.support[i]
105 |
106 | def make_data(self, n, has_label = True):
107 | L = np.zeros((n, self.m))
108 | gold = np.zeros(n)
109 | for i in range(n):
110 | l = self.sample()
111 | L[i, :] = l[:self.m]
112 |
113 | if has_label:
114 | gold[i] = l[self.m]
115 |
116 | return L.astype(int), gold.astype(int)
117 |
118 |
119 |
120 | def est_accs(m, vote, gold):
121 | # compute pr(lf | y) accuracies. Each prompt has 4 values (2x2)
122 | # we need to do this on the train/dev set
123 | classes = [0, 1]
124 | gold_idxs = [np.where(gold == -1)[0], np.where(gold == 1)[0]]
125 |
126 | accs = np.zeros((m, 2, 2)) # [i, j, k] = Pr(prompt_i = j| y = k)
127 | for p in range(m):
128 | for i in classes:
129 | for j in classes:
130 | accs[p, i, j] = len(np.where(vote[gold_idxs[i], p] == 2*j-1)[0]) / len(gold_idxs[i])
131 |
132 | return accs
133 |
134 | def est_balance(gold, n):
135 | return len(np.where(gold == 1)[0]) / n
136 |
137 | # Pr(lf votes, y)
138 | def get_cond_probs(m, votes, y, accs, balance):
139 | pr_y = balance if y == 1 else 1 - balance
140 | prod = pr_y
141 | for i in range(m):
142 | prod *= accs[i, y, int(0.5*(votes[i] + 1))] # this assumes everything is independent
143 | return prod
144 |
145 | # Pr(y = 1 | lf votes)
146 | def get_probs(m, votes, accs, balance):
147 | pos = get_cond_probs(m, votes, 1, accs, balance)
148 | neg = get_cond_probs(m, votes, 0, accs, balance)
149 |
150 | if pos == 0:
151 | return 0
152 | else:
153 | return pos / (pos + neg)
154 |
155 |
156 | def pick_best_prompt(m, vote, gold, n):
157 | # overall accuracies Pr(lf_p = y) on test (we don't know these)
158 | overall_train_acc = np.zeros(m)
159 | for i in range(m):
160 | overall_train_acc[i] = len(np.where((vote[:, i] == gold) == True)[0])/n
161 |
162 | return overall_train_acc.argmax()
163 |
164 |
165 | def main():
166 |
167 | # number of weak labels
168 | m = 5
169 |
170 | # total number of vertices
171 | v = m + 1
172 |
173 | # randomly parametrize exponential family to determine accuracies and correlations
174 | #theta = np.random.rand()
175 | #theta_cliques = (np.random.randint(0, 2, 5)*2 - 1)*theta
176 | #theta = np.random.rand()
177 | #theta_cliques = [1, 1, 1, 1, 1, 1, 1]
178 | thetas = np.random.rand(30)
179 |
180 | # all conditionally independent
181 | potentials = [[5], [0], [1], [4], [0, 5], [1, 5], [2, 5], [3, 5], [4, 5]]
182 |
183 | pgm = Ising(m, potentials, thetas)
184 |
185 | n_train = 10000
186 | vote_train, gold_train = pgm.make_data(n_train)
187 |
188 | n_test = 1000
189 | vote_test, gold_test = pgm.make_data(n_test)
190 |
191 | accs = est_accs(m, vote_train, gold_train)
192 | balance = est_balance(gold_train, n_train)
193 |
194 | nb_output = np.zeros(n_test) # naive bayes
195 | mv_output = np.zeros(n_test)
196 |
197 | nb_err = 0
198 | mv_err = 0
199 |
200 | for i in range(n_test):
201 | nb_output[i] = 2*np.round(get_probs(m, vote_test[i], accs, balance))-1
202 | if nb_output[i] != gold_test[i]:
203 | nb_err += 1
204 |
205 |
206 |         # majority vote; exact ties are broken uniformly at random
207 |         if len(np.where(vote_test[i] == 1)[0]) > m / 2:
208 |             mv_output[i] = 1
209 |         elif len(np.where(vote_test[i] == 1)[0]) < m / 2:
210 |             mv_output[i] = -1
211 |         else:
212 |             mv_output[i] = 2*np.random.randint(0, 2)-1
213 |
214 | if mv_output[i] != gold_test[i]:
215 | mv_err += 1
216 |
217 | nb_acc = 1 - (nb_err / n_test)
218 | mv_acc = 1 - (mv_err / n_test)
219 | #fs_acc = 1 - (fs_err / n_test)
220 |
221 | best_prompt = pick_best_prompt(m, vote_train, gold_train, n_train)
222 |
223 | best_prompt_acc = len(np.where((vote_test[:, best_prompt] == gold_test) == True)[0]) / n_test
224 |
225 | print(f"Naive bayes: {nb_acc}")
226 | print(f"Best prompt: {best_prompt_acc}")
227 | print(f"Majority vote: {mv_acc}")
228 |
229 |
230 | if __name__ == "__main__":
231 | main()
--------------------------------------------------------------------------------
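pgm.py doubles as a data generator and a reference aggregator: the `Ising` class samples weak-label votes jointly with the gold label, while `est_accs`/`est_balance`/`get_probs` implement the naive-Bayes combination used in `main()`. A minimal sketch with three conditionally independent labelers; the clique layout mirrors `main()`, and the sizes are illustrative:

```python
import numpy as np

from pgm import Ising, est_accs, est_balance, get_probs  # assumes boosting/ is on the path

m = 3                                        # three weak labelers; vertex m is the label y
potentials = [[m], [0, m], [1, m], [2, m]]   # label prior plus one accuracy clique per labeler
pgm = Ising(m, potentials, thetas=np.random.rand(len(potentials)))

n = 500
votes, gold = pgm.make_data(n)               # votes and labels take values in {-1, 1}
accs = est_accs(m, votes, gold)              # accs[k, i, j] ~= Pr(lf_k = j | y = i)
balance = est_balance(gold, n)               # ~= Pr(y = 1)

posteriors = [get_probs(m, votes[i], accs, balance) for i in range(5)]
print(posteriors)                            # Pr(y = 1 | votes) for the first five points
```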
/boosting/run_ws.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import json
4 | import os
5 | import cvxpy as cp
6 | import scipy as sp
7 | import datetime
8 | from methods import Aggregator
9 | from metal.label_model import LabelModel
10 |
11 |
12 | def get_args():
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument("--task_name", type=str, required=True)
15 | parser.add_argument("--data_dir", type=str, default="../ama_logs/ama_final_runs")
16 | parser.add_argument("--symmetric", action="store_true")
17 | parser.add_argument("--model_prefix", type=str, default="EleutherAI_gpt-j-6B")
18 | parser.add_argument("--override_date", type=str, default=None)
19 |
20 | args = parser.parse_args()
21 | return args
22 |
23 | def get_data(task_name, data_dir, model_prefix, override_date=None):
24 | """
25 | Load in dataset from task_name depending on where files are saved.
26 | """
27 | task_dir = os.path.join(data_dir, task_name)
28 | date = datetime.datetime.today().strftime("%m%d%Y") if not override_date else override_date
29 | print(f"Loading runs from {date}")
30 | dpath = os.path.join(task_dir, f"{model_prefix}_decomposed_{date}.json")
31 | train_dpath = os.path.join(task_dir, f"{model_prefix}_decomposed_{date}_train.json")
32 |
33 | print(dpath)
34 | print(train_dpath)
35 |
36 | if task_name in ["super_glue_rte"]:
37 | label_name_to_int = {"True": 1, "False": 0}
38 | elif task_name == "amazon":
39 | label_name_to_int = {'amazon instant video':0,
40 | 'books':1,
41 | 'clothing shoes and jewelry':2,
42 | 'electronics':3,
43 | 'kindle store':4,
44 | 'movies and tv':5,
45 | 'musical instruments':6,
46 | 'office products':7,
47 | 'tools and home improvement':8}
48 |
49 | elif task_name == 'wic':
50 | label_name_to_int = {"Yes": 1, "No": 0}
51 | elif task_name == 'super_glue_wsc':
52 | label_name_to_int = {"True": 1, "False": 0}
53 | elif task_name == "sst2":
54 | label_name_to_int = {"positive": 1, "negative": 0, "neutral": 0}
55 | elif task_name == "super_glue_boolq":
56 | label_name_to_int = {"true": 1, "false": 0}
57 | elif task_name in ["story_cloze", "story_cloze_v2", "story_cloze_v3"]:
58 | label_name_to_int = {"2": 1, "1": 0}
59 | elif task_name in ["anli_r1", "anli_r2", "anli_r3"]:
60 | label_name_to_int = {"true": 1, "false": 0, "neither": 2}
61 | elif task_name == "MR" or task_name == "mr":
62 | label_name_to_int = {"positive": 1, "negative": 0}
63 | elif task_name == "multirc":
64 | label_name_to_int = {"yes": 1, "no": 0}
65 | elif task_name == "super_glue_cb":
66 | label_name_to_int = {"true": 0, "false": 2, "neither": 1}
67 | elif task_name == "super_glue_copa":
68 | label_name_to_int = {"1": 1, "2": 0}
69 | elif task_name == "drop":
70 | label_name_to_int = {"true": 1, "false": 0}
71 | elif task_name == "super_glue_record":
72 | label_name_to_int = {"true": 1, "false": 0}
73 | elif task_name == "ag_news":
74 | label_name_to_int = {"world news": 0, "sports": 1, "business": 2, "technology and science": 3}
75 | elif task_name == "dbpedia":
76 | label_name_to_int = {"company": 0, "educational institution": 1, "artist": 2, "athlete": 3,
77 | "office holder": 4, "mean of transportation": 5, "building": 6, "natural place": 7,
78 | "village": 8, "animal": 9, "plant": 10, "album": 11, "film": 12, "written work": 13}
79 | else:
80 | raise ValueError("Unsupported task!")
81 |
82 | test_data = json.load(open(dpath))
83 | train_data = json.load(open(train_dpath))
84 |
85 | print(train_data['0']['example'])
86 | print(train_data['0']['gold'])
87 |
88 | n_test = len(test_data)
89 | n_train = len(train_data)
90 |
91 | m = len(test_data['0']['preds_boost'])
92 |
93 | test_votes = np.zeros((n_test, m))
94 | test_gold = np.zeros(n_test)
95 | for i in range(n_test):
96 |
97 | test_votes[i] = np.array([label_name_to_int[ans] if ans in label_name_to_int else -1 for ans in test_data[str(i)]['preds_boost']])
98 | test_gold[i] = label_name_to_int[str(test_data[str(i)]['gold'])]
99 |
100 | test_votes = test_votes.astype(int)
101 | test_gold = test_gold.astype(int)
102 |
103 | train_votes = np.zeros((n_train, m))
104 | train_gold = np.zeros(n_train)
105 | for i in range(n_train):
106 | train_votes[i] = np.array([label_name_to_int[ans] if ans in label_name_to_int else -1 for ans in train_data[str(i)]['preds_boost']])
107 | train_gold[i] = label_name_to_int[str(train_data[str(i)]['gold'])]
108 |
109 | train_votes = train_votes.astype(int)
110 | train_gold = train_gold.astype(int)
111 |
112 | return train_votes, train_gold, test_votes, test_gold
113 |
114 |
115 |
116 | def get_top_deps_from_inverse_sig(J, k):
117 | m = J.shape[0]
118 | deps = []
119 | sorted_idxs = np.argsort(np.abs(J), axis=None)
120 | n = m*m
121 | idxs = sorted_idxs[-k:]
122 | for idx in idxs:
123 | i = int(np.floor(idx / m))
124 | j = idx % m
125 | if (j, i) in deps:
126 | continue
127 | deps.append((i, j))
128 | return deps
129 |
130 | def learn_structure(L):
131 | m = L.shape[1]
132 | n = float(np.shape(L)[0])
133 | sigma_O = (np.dot(L.T,L))/(n-1) - np.outer(np.mean(L,axis=0), np.mean(L,axis=0))
134 |
135 |     # symmetrize the covariance estimate for numerical stability
136 | O = 1/2*(sigma_O+sigma_O.T)
137 | O_root = np.real(sp.linalg.sqrtm(O))
138 |
139 | # low-rank matrix
140 | L_cvx = cp.Variable([m,m], PSD=True)
141 |
142 | # sparse matrix
143 | S = cp.Variable([m,m], PSD=True)
144 |
145 | # S-L matrix
146 | R = cp.Variable([m,m], PSD=True)
147 |
148 | #reg params
149 | lam = 1/np.sqrt(m)
150 | gamma = 1e-8
151 |
152 | objective = cp.Minimize(0.5*(cp.norm(R @ O_root, 'fro')**2) - cp.trace(R) + lam*(gamma*cp.pnorm(S,1) + cp.norm(L_cvx, "nuc")))
153 | constraints = [R == S - L_cvx, L_cvx>>0]
154 |
155 | prob = cp.Problem(objective, constraints)
156 | result = prob.solve(verbose=False, solver=cp.SCS)
157 | opt_error = prob.value
158 |
159 | #extract dependencies
160 | J_hat = S.value
161 |
162 |     if J_hat is None:
163 |         raise ValueError("CVXPY failed to solve the structure learning problem; fall back to the label model without dependencies.")
164 |
165 | for i in range(m):
166 | J_hat[i, i] = 0
167 | return J_hat
168 |
169 |
170 | def learn_structure_multiclass(L, k):
171 | m = L.shape[1]
172 | J_hats = np.zeros((k, m, m))
173 | for c in range(k):
174 |
175 | all_votes_c = np.where(L == c, 1, 0)
176 | J_hats[c] = learn_structure(all_votes_c)
177 |
178 | return J_hats
179 |
180 | def get_min_off_diagonal(J_hat):
181 | J_hat_copy = J_hat.copy()
182 | for i in range(len(J_hat_copy)):
183 | J_hat_copy[i, i] = np.inf
184 | return np.abs(J_hat_copy).min()
185 |
186 | def main():
187 | args = get_args()
188 | task_name = args.task_name
189 | data_dir = args.data_dir
190 | symmetric = args.symmetric
191 |
192 | train_votes, train_gold, test_votes, test_gold = get_data(task_name, data_dir, args.model_prefix, args.override_date)
193 | classes = np.sort(np.unique(test_gold))
194 | vote_classes = np.sort(np.unique(test_votes))
195 | n_train, m = train_votes.shape
196 | n_test = len(test_votes)
197 | k = len(classes)
198 | abstains = len(vote_classes) == len(classes) + 1
199 | print(f"Abstains: {abstains}")
200 |
201 | m = test_votes.shape[1]
202 |
203 |     all_votes = np.concatenate((train_votes, test_votes))
204 |
205 | label_model = LabelModel(k=k, seed=123)
206 |
207 | # scale to 0, 1, 2 (0 is abstain)
208 | test_votes_scaled = (test_votes + np.ones((n_test, m))).astype(int)
209 | test_gold_scaled = (test_gold + np.ones(n_test)).astype(int)
210 |
211 | train_votes_scaled = (train_votes + np.ones((n_train, m))).astype(int)
212 | train_gold_scaled = (train_gold + np.ones(n_train)).astype(int)
213 |
214 | all_votes_scaled = np.concatenate((train_votes_scaled, test_votes_scaled))
215 |
216 | label_model.train_model(all_votes_scaled, Y_dev=train_gold_scaled, abstains=abstains, symmetric=False, n_epochs=10000, log_train_every=50, lr=0.00001)
217 |
218 |
219 | print('Trained Label Model Metrics (No deps):')
220 | scores = label_model.score((test_votes_scaled, test_gold_scaled), metric=['accuracy','precision', 'recall', 'f1'])
221 | print(scores)
222 |
223 | all_votes_no_abstains = np.where(all_votes == -1, 0, all_votes)
224 |
225 | if len(classes) == 2:
226 | J_hat = learn_structure(all_votes_no_abstains)
227 | else:
228 | J_hats = learn_structure_multiclass(all_votes_no_abstains, len(classes))
229 | J_hat = J_hats.mean(axis=0)
230 |
231 |     # if all off-diagonal entries of J are large, everything looks connected and structure learning is uninformative; skip modeling dependencies in that case
232 | min_entry = get_min_off_diagonal(J_hat)
233 | if min_entry < 1:
234 | deps = get_top_deps_from_inverse_sig(J_hat, 1)
235 | print("Recovered dependencies: ", deps)
236 |
237 | label_model.train_model(all_votes_scaled, Y_dev=train_gold_scaled, abstains=abstains, symmetric=symmetric, n_epochs=80000, log_train_every=50, lr=0.000001, deps=deps)
238 | print('Trained Label Model Metrics (with deps):')
239 | scores = label_model.score((test_votes_scaled, test_gold_scaled), metric=['accuracy', 'precision', 'recall', 'f1'])
240 | print(scores)
241 |
242 |     try:
243 |         lm_probs = label_model.predict_proba(test_votes_scaled)
244 |         agg = Aggregator(test_votes, test_gold, test_votes, test_gold, abstains, classes=[0, 1])
245 |         print("H(Y | WS output):")
246 |         print(agg.conditional_entropy_singleton(lm_probs, test_gold))
247 |     except Exception:
248 |         print("Failed to compute conditional entropy H(Y | WS output).")
249 |
250 |
251 |
252 | if __name__ == "__main__":
253 | main()
254 |
--------------------------------------------------------------------------------
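`learn_structure` fits a sparse-plus-low-rank decomposition of the votes' inverse covariance with CVXPY; large off-diagonal entries of the sparse part `S` flag dependent prompt pairs, which `get_top_deps_from_inverse_sig` then extracts. A synthetic sketch with one planted dependency (accuracy and copy rates are illustrative; requires cvxpy with the SCS solver):

```python
import numpy as np

from run_ws import learn_structure, get_top_deps_from_inverse_sig  # assumes boosting/ is on the path

rng = np.random.default_rng(0)
n = 2000
y = rng.choice([-1, 1], size=n)

def flip(v, keep_prob):
    # keep each entry of v with probability keep_prob, otherwise negate it
    return np.where(rng.random(n) < keep_prob, v, -v)

l0 = flip(y, 0.8)    # 80%-accurate labeler
l1 = flip(l0, 0.9)   # mostly copies l0, planting the dependency (0, 1)
l2 = flip(y, 0.75)   # independent labeler

J_hat = learn_structure(np.column_stack([l0, l1, l2]))
print(get_top_deps_from_inverse_sig(J_hat, 1))  # should surface the planted pair, up to ordering
```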
/boosting/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import itertools
3 |
4 | def get_probabilties(num_lfs, num_examples, predictions, label_name_to_int):
5 |
6 | lf_array = np.zeros((num_lfs, num_examples))
7 | golds = []
8 |
9 | # Collect golds and preds
10 | for i, (k, item) in enumerate(predictions.items()):
11 | preds = item['chosen_answers_lst']
12 | preds_mapped = []
13 | for p in preds:
14 | if p in label_name_to_int:
15 | preds_mapped.append(label_name_to_int[p])
16 | else:
17 | preds_mapped.append(0)
18 | preds = preds_mapped.copy()
19 | for lf_num, p in zip(range(num_lfs), preds):
20 | lf_array[lf_num][i] = p
21 | gold = label_name_to_int[item['gold']]
22 | golds.append(gold)
23 | golds = np.array(golds)
24 | neg_indices, pos_indices = [np.where(golds == -1)[0], np.where(golds == 1)[0]]
25 | indices = {
26 | -1: neg_indices,
27 | 1: pos_indices
28 | }
29 |
30 | # [i, j, k] = Pr(prompt_i = j| y = k)
31 | # Accuracies
32 | lf_accuracies = []
33 | for i in range(num_lfs):
34 | lf_accuracies.append(np.sum(golds == np.array(lf_array[i]))/num_examples)
35 | print(f"LF Accs: {lf_accuracies}")
36 |
37 | # [i, j, k] = Pr(prompt_i = j| y = k)
38 | classes = label_name_to_int.values()
39 | accs = np.zeros((num_lfs, len(classes), len(classes)))
40 | for p in range(num_lfs):
41 | for i in classes:
42 | for j in classes:
43 | j_idx = j
44 | if j == -1:
45 | j_idx = 0
46 | i_idx = i
47 | if i == -1:
48 | i_idx = 0
49 | accs[p, i_idx, j_idx] = len(np.where(lf_array[p, indices[i]] == j)[0]) / len(indices[i])
50 |
51 | # Compute probabilities
52 | pos_probs = []
53 | for i in range(num_lfs):
54 | sub_preds = lf_array[i][pos_indices]
55 | sub_golds = golds[pos_indices]
56 | pos_probs.append(np.sum(sub_golds == np.array(sub_preds))/len(pos_indices))
57 | print(f"Pos Probs: {pos_probs}")
58 |
59 | neg_probs = []
60 | for i in range(num_lfs):
61 | sub_preds = lf_array[i][neg_indices]
62 | sub_golds = golds[neg_indices]
63 | neg_probs.append(np.sum(sub_golds == np.array(sub_preds))/len(neg_indices))
64 | print(f"Neg Probs: {neg_probs}\n\n")
65 |
66 | return lf_accuracies, accs, pos_probs, neg_probs, golds, indices
67 |
68 |
69 | """ Independence Assumption: take the product of probabilities as p(L1, L2, ..., LK | y) """
70 |
71 | # Pr(lf votes, y): joint probability of the votes and the label
72 | def get_cond_probs(votes, y, indices_train, golds_train, accs_train, num_lfs_test):
73 | prop_pos = len(indices_train[1])/len(golds_train)
74 | pr_y = prop_pos if y == 1 else 1 - prop_pos
75 | prod = pr_y
76 |     if y == -1:
77 |         y = 0
78 |     for i in range(num_lfs_test):
79 |         prod *= accs_train[i, y, votes[i]]  # assumes the LFs are conditionally independent given y
80 | return prod
81 |
82 | # Pr(y = 1 | lf votes)
83 | def get_probs(votes, indices_train, golds_train, acc_train, num_lfs_test):
84 | votes = [max(v, 0) for v in votes]
85 | numerator = get_cond_probs(votes, 1, indices_train, golds_train, acc_train, num_lfs_test)
86 | denominator = numerator + get_cond_probs(votes, -1, indices_train, golds_train, acc_train, num_lfs_test)
87 | return numerator / denominator
88 |
89 |
90 | def get_nb_accuracy(num_examples_test, num_lfs_test, predictions_test, label_name_to_int, golds_test, indices_train, golds_train, accs_train):
91 | output = np.zeros(num_examples_test)
92 | errors = 0
93 | for i, (k, item) in enumerate(predictions_test.items()):
94 | votes = item['chosen_answers_lst']
95 | votes_mapped = []
96 | for v in votes:
97 | if v in label_name_to_int:
98 | votes_mapped.append(label_name_to_int[v])
99 | else:
100 | votes_mapped.append(0)
101 | votes = votes_mapped.copy()
102 | probs = np.round(get_probs(votes, indices_train, golds_train, accs_train, num_lfs_test))
103 | output[i] = probs
104 |
105 | # Mean squared error
106 | g = golds_test[i]
107 | if golds_test[i] == -1:
108 | g = 0
109 | error = np.abs(output[i] - g)**2
110 | errors += error
111 | accuracy = 1 - (errors / num_examples_test)
112 | return accuracy, output
113 |
114 |
115 | def estimate_matrix(m, n, L):
116 | E_prod = np.zeros((m, m))
117 | l_avg = np.zeros(m)
118 | for i in range(n):
119 | l = L[i, :]
120 | l_avg += l
121 | E_prod += np.outer(l, l)
122 |
123 | l_avg = l_avg/n
124 | E_prod = E_prod/n
125 |
126 | cov = E_prod - np.outer(l_avg, l_avg)
127 |
128 | return (E_prod, cov, l_avg)
129 |
130 |
131 | def get_vote_vectors(num_samples, num_lfs, predictions, label_name_to_int):
132 | vectors = np.zeros((num_samples, num_lfs+1), float)
133 | vectors_no_y = np.zeros((num_samples, num_lfs), float)
134 | labels_vector = np.zeros((num_samples, 1), float)
135 | for i, p in enumerate(predictions.values()):
136 | votes = p['chosen_answers_lst']
137 | votes_mapped = []
138 | for v in votes:
139 | if v in label_name_to_int:
140 | votes_mapped.append(label_name_to_int[v])
141 | else:
142 | votes_mapped.append(0)
143 | votes = votes_mapped.copy()
144 | # votes = [max(v, 0) for v in votes]
145 | gold = p['gold']
146 | gold = label_name_to_int[gold]
147 | vectors_no_y[i] = np.array(votes)
148 | vectors[i] = np.array(votes + [gold]) #- lf_accuracies_train
149 | labels_vector[i] = np.array([gold])
150 | print(f"Shape: {vectors.shape}")
151 | print(f"Sample: {vectors[0]}")
152 |
153 | return vectors, vectors_no_y, labels_vector
154 |
155 | def get_feature_vector(vote_vectors, include_pairwise=False, include_singletons=True):
156 | feature_vectors = []
157 | for votes in vote_vectors:
158 | if include_singletons:
159 | feature_vector = list(votes[:])
160 | else:
161 | feature_vector = []
162 | if include_pairwise:
163 | for subset in itertools.combinations(votes[:], 2):
164 | feature_vector.append(subset[0] * subset[1])
165 | feature_vectors.append(feature_vector)
166 | X = np.matrix(feature_vectors)
167 | return X
--------------------------------------------------------------------------------
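`get_feature_vector` expands each vote vector into singleton and pairwise-product features, which is handy for fitting a simple discriminative combiner over the prompt outputs. A small sketch (the vote values are illustrative):

```python
import numpy as np

from utils import get_feature_vector  # assumes boosting/utils.py is on the path

vote_vectors = np.array([[1, -1,  1],
                         [1,  1, -1]])
X = get_feature_vector(vote_vectors, include_pairwise=True)
# Each row is [v0, v1, v2, v0*v1, v0*v2, v1*v2]:
print(X)
```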
/diagnostics/README.md:
--------------------------------------------------------------------------------
1 | This folder contains the diagnostic tasks for Ask Me Anything (AMA) as introduced in the main paper. The diagnostic tasks are aligned with the types of reasoning steps involved in the end-to-end AMA strategy: question generation, text generation, extraction (from context), and selection (from listed answer choices). Given a new language model, the diagnostic may help anticipate whether AMA will provide lift.
2 |
3 |
4 | With a running Manifest session, you can run the diagnostic as follows:
5 |
6 | ```
7 | unzip data.zip
8 | python run_diagnostics.py
9 | ```
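10 |
11 | To serve a HuggingFace model locally for the diagnostic, you can point a Manifest server at it, for example (exact flags may vary across Manifest versions):
12 |
13 | ```
14 | python3 -m manifest.api.app \
15 |     --model_type huggingface \
16 |     --model_name_or_path EleutherAI/gpt-j-6B \
17 |     --device 0
18 | ```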
--------------------------------------------------------------------------------
/diagnostics/data.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HazyResearch/ama_prompting/460843d93a9e4bf2115eb35e4a02e82b6b67feac/diagnostics/data.zip
--------------------------------------------------------------------------------
/diagnostics/run_diagnostics.py:
--------------------------------------------------------------------------------
1 | """ Running and Scoring the AMA Diagnostics """
2 |
3 | import os
4 | import json
5 | from collections import Counter
6 | from tqdm import tqdm
7 | import openai
8 | from manifest import Manifest
9 | openai.api_key = "" # Find this on the OpenAI Website
10 |
11 | from datasets import load_metric
12 | rouge = load_metric("rouge")
13 |
14 | ######################### HELPER FUNCTIONS #########################
15 |
16 | def rougeL(preds, labels):
17 | return rouge.compute(predictions=preds, references=labels)['rougeL'].mid.fmeasure
18 |
19 |
20 | def text_f1(preds=(), labels=()):
21 | """Compute average F1 of text spans.
22 |
23 | Taken from Squad without prob threshold for no answer.
24 | """
25 | total_f1 = 0
26 | for pred, gold in zip(preds, labels):
27 | pred_toks = pred.split()
28 | gold_toks = gold.split()
29 | common = Counter(pred_toks) & Counter(gold_toks)
30 | num_same = sum(common.values())
31 | if len(gold_toks) == 0 or len(pred_toks) == 0:
32 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
33 | total_f1 += int(gold_toks == pred_toks)
34 | elif num_same == 0:
35 | total_f1 += 0
36 | else:
37 | precision = 1.0 * num_same / len(pred_toks)
38 | recall = 1.0 * num_same / len(gold_toks)
39 | f1 = (2 * precision * recall) / (precision + recall)
40 | total_f1 += f1
41 | f1_avg = total_f1 / len(labels)
42 | return f1_avg
43 |
44 | def get_response(manifest, prompt, max_toks=10, temperature = 0, gold_choices=[], model_name="manifest", engine="text-davinci-002", logit_bias={}):
45 | prompt = prompt.strip()
46 | if model_name == 'openai':
47 | if logit_bias:
48 | completion = openai.Completion.create(engine=engine, prompt=prompt, temperature=temperature, top_p=1, max_tokens=max_toks, logprobs=5, logit_bias=logit_bias)
49 | else:
50 | completion = openai.Completion.create(engine=engine, prompt=prompt, temperature=temperature, top_p=1, max_tokens=max_toks, logprobs=5)
51 | response = completion.choices[0].text
52 |
53 | if model_name == "manifest":
54 | if gold_choices:
55 | max_len = max([len(g.split()) for g in gold_choices])
56 | else:
57 | max_len = 0
58 |         max_token_args = ({"max_tokens": min(max_toks, 8 * max_len)}
59 |             if gold_choices
60 |             else {}
61 |         )
62 | if gold_choices:
63 | response = manifest.run(prompt, gold_choices=gold_choices,overwrite_cache=False,**max_token_args,)
64 | else:
65 | response = manifest.run(prompt, max_tokens=max_toks, overwrite_cache=False)
66 |
67 | return response
68 |
69 | def get_manifest_session(
70 | client_name="huggingface",
71 | client_connection="http://127.0.0.1:5000",
72 | cache_connection=None,
73 | temperature=0,
74 | top_p=1.0,
75 | ):
76 |     if client_name == "huggingface":
77 |         params = {
78 |             "temperature": max(temperature, 0.001),  # HF generation needs a positive temperature
79 |             "do_sample": temperature > 0,
80 |             "top_p": top_p,
81 |         }
82 | elif client_name in {"openai", "ai21"}:
83 | params = {
84 | "temperature": temperature,
85 | "top_p": top_p,
86 | }
87 | else:
88 | raise ValueError(f"{client_name} is not a valid client name")
89 | manifest = Manifest(
90 | client_name=client_name,
91 | client_connection=client_connection,
92 | cache_name="sqlite",
93 | cache_connection=cache_connection,
94 | **params,
95 | )
96 | model_name = manifest.client.get_model_params()["model_name"]
97 | return manifest, model_name
98 |
99 | ######################### SCORE FUNCTIONS #########################
100 |
101 |
102 | """ A "blah" maps to category: blah """
103 | def selection_easy(dataset, manifest):
104 | preds = []
105 | for i, (ind, row) in tqdm(enumerate(dataset.items())):
106 | prefix = row['input']
107 | pred = get_response(manifest, prefix, max_toks=50)
108 | pred = [p for p in pred.split("\n") if p][0].strip()
109 | preds.append(pred)
110 |
111 | if i == 0:
112 | print(prefix)
113 | print(f"PRED: {pred}")
114 |
115 |     labels = [l.strip() for l in row['labels']]  # label set is assumed identical across rows
116 |     valid_preds = [p for p in preds if p in labels]
117 |     return len(valid_preds) / len(dataset)
118 |
119 |
120 | """ Does the model pick one of the given choices? """
121 | def selection_hard(dataset, manifest):
122 | preds = []
123 | for i, (ind, row) in tqdm(enumerate(dataset.items())):
124 | prefix = row['input']
125 | pred = get_response(manifest, prefix, max_toks=50)
126 | pred = [p for p in pred.split("\n")][0]
127 | preds.append(pred)
128 |
129 | if i == 0:
130 | print(prefix)
131 | print(f"PRED: {pred}")
132 |
133 | valid = 0
134 | for (ind, row), pred in zip(dataset.items(), preds):
135 | choices = row['output']
136 | if pred.lower().strip(".") in [c.lower().strip(".") for c in choices]:
137 | valid += 1
138 | return valid/len(dataset)
139 |
140 |
141 | """ Does the model generate three choices? """
142 | def text_generation(dataset, manifest):
143 | preds = []
144 | for i, (ind, row) in tqdm(enumerate(dataset.items())):
145 | prefix = row['input']
146 | pred = get_response(manifest, prefix, max_toks=50)
147 | pred = pred.split("\n\n")[0]
148 | pred = pred.split("\n")
149 | pred = list(set([a.replace("- ", "").strip() for a in pred]))
150 | preds.append(pred)
151 |
152 | if i == 0:
153 | print(prefix)
154 | print(f"PRED: {pred}")
155 |
156 | valid = 0
157 | for pred in preds:
158 | if len(pred) == 2:
159 | valid += 1
160 | return valid/len(dataset)
161 |
162 |
163 | """ Does the model faithfully transform the statement to a question? """
164 | def question_generation(dataset, manifest):
165 | preds = []
166 | for i, (ind, row) in tqdm(enumerate(dataset.items())):
167 | prefix = row['input']
168 | pred = get_response(manifest, prefix, max_toks=50)
169 | pred = [p for p in pred.split("\n")][0]
170 | preds.append(pred)
171 |
172 | if i == 0:
173 | print(prefix)
174 | print(f"PRED: {pred}")
175 |
176 | outputs = [row['output'] for ind, row in dataset.items()]
177 | score = rougeL(preds=preds, labels = outputs)
178 | return score
179 |
180 |
181 | """ Does the model faithfully choose the sentence with the entity name? """
182 | """ Does the model faithfully answer given a keyword for extraction? """
183 | def extraction(dataset, manifest):
184 | preds = []
185 | for i, (ind, row) in tqdm(enumerate(dataset.items())):
186 | prefix = row['input']
187 | pred = get_response(manifest, prefix, max_toks=50)
188 | pred = [p for p in pred.split("\n")][0]
189 | preds.append(pred)
190 |
191 | if i == 0:
192 | print(prefix)
193 | print(f"PRED: {pred}")
194 |
195 | outputs = [row['output'] for ind, row in dataset.items()]
196 | score = text_f1(preds=preds, labels = outputs)
197 | return score
198 |
199 |
200 | def main():
201 | data_dir = "data/"
202 | synthetics = {
203 | 'selection_easy':selection_easy,
204 | 'selection_hard':selection_hard,
205 | 'extraction':extraction,
206 | "text_generation": text_generation,
207 | "question_generation": question_generation
208 | }
209 |
210 | manifest, model_name = get_manifest_session()
211 |
212 | synthetic_scores = {}
213 | for synthetic, function in synthetics.items():
214 | print(f"RUNNING {synthetic}")
215 | with open(f"{data_dir}/{synthetic}.json") as f:
216 | dataset = json.load(f)
217 | score = function(dataset, manifest)
218 | synthetic_scores[synthetic] = score
219 | print(f"SCORE: {score}")
220 |
221 | print(synthetic_scores)
222 | model_name = model_name.replace("/", "_")
223 |
224 | if not os.path.exists(f"results/"):
225 | os.makedirs(f"results/")
226 |
227 | with open(f"results/{model_name}_results.json", "w") as f:
228 | json.dump(synthetic_scores, f)
229 |
230 | print(f"Saved to: results/{model_name}_results.json")
231 |
232 | if __name__ == "__main__":
233 | main()
234 |
--------------------------------------------------------------------------------
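The scorers above index each dataset row by `input` plus either `output` or `labels`, depending on the task. A hedged sketch of the JSON layout that usage implies (field values here are invented for illustration; the real files are in `data.zip`):

```python
# Schema inferred from how run_diagnostics.py indexes each row; the
# values below are illustrative, not taken from data.zip.
example_dataset = {
    "0": {
        "input": "Statement: the sky is blue\nQuestion:",  # prompt prefix
        "output": "Is the sky blue?",  # gold text (or a list of choices for selection_hard)
        "labels": ["yes", "no"],       # valid category strings for selection_easy
    }
}
```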
/download_p3.py:
--------------------------------------------------------------------------------
1 | import datasets
2 | import tensorflow as tf
3 | import pandas as pd
4 | from pathlib import Path
5 | import json
6 | from tqdm.auto import tqdm
7 | import os
8 |
9 | DATA_DIR = os.environ.get("AMA_DATA", "/home/data")
10 |
11 | # Download P3 github data from HF website
12 | '''
13 | git lfs install
14 | git clone https://huggingface.co/datasets/bigscience/P3
15 | '''
16 |
17 | import sys
18 | sys.path.append(f"{DATA_DIR}/P3")
19 | # From P3 github
20 | from tasks_splits_and_features import DATA_SPLITS_SIZES
21 |
22 | SPLITS = ["train", "test", "validation"]
23 | _FEAT_MAPPING_FUNCTIONS = {
24 | "answer_choices": lambda x: [choice.decode("utf-8") for choice in x],
25 | "inputs": lambda x: x.tolist(),
26 | "inputs_pretokenized": lambda x: x.decode("utf-8"),
27 | "targets": lambda x: x.tolist(),
28 | "targets_pretokenized": lambda x: x.decode("utf-8"),
29 | "idx": lambda x: x.tolist(),
30 | "weight": lambda x: float(x),
31 | "is_correct": lambda x: x,
32 | }
33 | def _feature_config(shape, dtype):
34 | if dtype in ("int32", "bool"):
35 | # int32 and bool are stored as int64 in the tf.train.Example protobuf.
36 | dtype = "int64"
37 | if shape and shape[0] is None:
38 | return tf.io.FixedLenSequenceFeature(
39 | shape[1:], dtype, allow_missing=True
40 | )
41 | return tf.io.FixedLenFeature(shape, dtype)
42 |
43 | @tf.autograph.experimental.do_not_convert
44 | def extract_dataset(data_path, subfolder, sizes, processes=10, splits=SPLITS):
45 | datasets = {}
46 | for split in splits:
47 | if not (data_path / subfolder / f"info.{split}.json").exists():
48 | continue
49 | features_dict = json.load(open(data_path / subfolder / f"info.{split}.json"))
50 | if "features" not in features_dict:
51 | features_dict = json.load(open(data_path / subfolder / f"info.train.json"))
52 | if "features" not in features_dict:
53 | continue
54 | features_dict = features_dict["features"]
55 | tfrecord = str(data_path / subfolder / f"{split}.tfrecord-00000-of-00001")
56 | feature_description = {
57 | feat: _feature_config(**desc) for feat, desc in features_dict.items()
58 | }
59 | ds = tf.data.TFRecordDataset(tf.io.gfile.glob([tfrecord])) # TODO -> handle multiple shards
60 | ds = ds.map(
61 | lambda pb: tf.io.parse_single_example(pb, feature_description),
62 | num_parallel_calls=processes
63 | )
64 | # Cast features back to the types from the info JSON since some features
65 | # must be cast for storage (e.g., int32 is stored as int64).
66 | ds = ds.map(
67 | lambda x: {k: tf.cast(v, features_dict[k]["dtype"]) for k, v in x.items()},
68 | num_parallel_calls=processes
69 | )
70 | res = []
71 | for ex in tqdm(ds.as_numpy_iterator(), total=sizes.get(split, 10000)):
72 | ex_dict = {}
73 | for feat_name, feat_value in ex.items():
74 | ex_dict[feat_name] = _FEAT_MAPPING_FUNCTIONS[feat_name](feat_value)
75 | res.append(ex_dict)
76 | if len(res) > 0:
77 | df = pd.DataFrame.from_records(res)
78 | datasets[split] = df
79 | return datasets
80 |
81 | data_path = Path(f"{DATA_DIR}/P3/data")
82 | output_data_path = Path(f"{DATA_DIR}/P3/data_feather")
83 | for subfolder in data_path.iterdir():
84 | print(subfolder)
85 | subfolder = subfolder.name
86 | out_path = output_data_path / subfolder
87 | out_path.mkdir(parents=True, exist_ok=True)
88 | splits_to_pass = []
89 | for split in SPLITS:
90 | if not (out_path / f"{split}.feather").exists():
91 | splits_to_pass.append(split)
92 | datasets = extract_dataset(data_path, subfolder, sizes=DATA_SPLITS_SIZES[str(subfolder)], splits=splits_to_pass)
93 | for split, df in datasets.items():
94 | df.to_feather(out_path / f"{split}.feather")
--------------------------------------------------------------------------------
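After conversion, each P3 prompt-template subfolder holds one feather file per split, which the task scripts read back with pandas (tasks/CB_final.py, for instance, points at `super_glue_cb_GPT_3_style`). A quick sanity check, assuming that subfolder has been converted:

```python
import os

import pandas as pd

DATA_DIR = os.environ.get("AMA_DATA", "/home/data")
df = pd.read_feather(f"{DATA_DIR}/P3/data_feather/super_glue_cb_GPT_3_style/validation.feather")
# Expect prompt-formatted columns such as inputs_pretokenized / targets_pretokenized.
print(len(df), df.columns.tolist())
```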
/imgs/crfm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HazyResearch/ama_prompting/460843d93a9e4bf2115eb35e4a02e82b6b67feac/imgs/crfm.png
--------------------------------------------------------------------------------
/imgs/decomp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HazyResearch/ama_prompting/460843d93a9e4bf2115eb35e4a02e82b6b67feac/imgs/decomp.png
--------------------------------------------------------------------------------
/imgs/numbers_station.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HazyResearch/ama_prompting/460843d93a9e4bf2115eb35e4a02e82b6b67feac/imgs/numbers_station.png
--------------------------------------------------------------------------------
/imgs/snorkel.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HazyResearch/ama_prompting/460843d93a9e4bf2115eb35e4a02e82b6b67feac/imgs/snorkel.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | cvxpy==1.2.1
2 | datasets==1.4.0
3 | dill==0.3.5.1
4 | GPUtil==1.4.0
5 | ipdb==0.13.9
6 | ipython==8.5.0
7 | matplotlib==3.5.2
8 | more_itertools==8.14.0
9 | networkx==2.8.5
10 | nltk==3.7
11 | numpy==1.23.3
12 | pandas==1.4.2
13 | scikit_learn==1.1.2
14 | scipy==1.8.1
15 | setuptools==59.5.0
16 | snorkel==0.9.9
17 | tensorboardX==2.5.1
18 | tensorflow==2.9.1
19 | tools==0.1.9
20 | torch==1.12.1
21 | torchtext==0.13.1
22 | tqdm
23 |
--------------------------------------------------------------------------------
/tasks/CB_final.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | from tqdm.auto import tqdm
4 | import pandas as pd
5 |
6 | from sklearn.metrics import classification_report
7 | from decomposition import Decomposition, get_args, DATA_DIR
8 | from utils import get_response, InputOutputPrompt
9 |
10 | questioner = InputOutputPrompt(
11 | input_formatter=lambda x: f"Statement: {x['statement']}",
12 | output_formatter=lambda x: f"Question: {x['question']}",
13 | required_keys=["statement", "question"],
14 | input_output_sep="\n",
15 | example_sep="\n\n",
16 | instruction="Rewrite the statement as a yes/no question.\n\n"
17 | )
18 |
19 | questioner_examples = [
20 | pd.DataFrame([
21 | {
22 | "statement": "most of the light comes from the sun",
23 | "question": "Does most of the light come from the sun?"
24 | },
25 | {
26 | "statement": "the test was not",
27 | "question": "Was the test hard?"
28 | },
29 | {
30 | "statement": "it is a good idea to buy your parents gifts",
31 | "question": "Is it a good idea to buy your parents gifts?"
32 | },
33 | {
34 | "statement": "the balloon popped",
35 | "question": "Did the balloon pop?"
36 | },
37 | {
38 | "statement": "The father and son went camping to California.",
39 | "question": "Did the father and son go camping?"
40 | }
41 | ]),
42 | pd.DataFrame([
43 | {
44 | "statement": "most of the light comes from the sun",
45 | "question": "Does most of the light come from the sun?"
46 | },
47 | {
48 | "statement": "the test was not",
49 | "question": "Was the test hard?"
50 | },
51 | {
52 | "statement": "it is a good idea to buy your parents gifts",
53 | "question": "Is it a good idea to buy your parents gifts?"
54 | },
55 | {
56 | "statement": "the balloon popped",
57 | "question": "Did the balloon pop?"
58 | },
59 | {
60 | "statement": "The father and son went camping to California.",
61 | "question": "Did the father and son go camping?"
62 | }
63 | ]),
64 | pd.DataFrame([
65 | {
66 | "statement": "she prefers kittens over puppies.",
67 | "question": "Does she prefer kittens over puppies?",
68 | },
69 | {
70 | "statement": "Max and his wife went on a trip to Europe",
71 | "question": "Did Max and his wife go on a trip to Europe?",
72 | },
73 | {
74 | "statement": "jared was born during the war in 1942.",
75 | "question": "Was Jared born during a war in 1942?",
76 | },
77 | {
78 | "statement": "it took jenna 7 attempts to solve the problem",
79 | "question": "Did it take Jenna 7 attempts to solve the problem?",
80 | },
81 | ]),
82 | ]
83 |
84 | openended_qa = InputOutputPrompt(
85 | input_formatter=lambda x: f"Passage: {x['passage']}\nQuestion: {x['question']}",
86 | output_formatter=lambda x: f"Answer: {x['answer']}",
87 | required_keys=["passage", "question", "answer"],
88 | input_output_sep="\n",
89 | example_sep="\n\n",
90 | instruction="Provide the answer to the question from the passage.\n\n"
91 | )
92 |
93 | openended_qa_examples = [
94 | pd.DataFrame([
95 | {
96 | "passage": "When Judy and Jack went to school, they got in trouble with their teacher for being late. I didn't think it was very fair.",
97 | "question": "Did she think it was fair?",
98 | "answer": "No"
99 | },
100 | {
101 | "passage": "If inflation is occurring, leading to higher prices for basic necessities such as gas by 2 dollars. Do you think that inflation is good for society?",
102 | "question": "Is inflation good for society?",
103 | "answer": "Maybe"
104 | },
105 | {
106 | "passage": "Put yourself out there. The more time you spend dating and socializing, the more likely you will find a boyfriend you like.",
107 | "question": "Does socializing help you find a boyfriend?",
108 | "answer": "Yes"
109 | },
110 | ]),
111 | pd.DataFrame([
112 | {
113 | "passage": "Jack recommends his least favorite books of the year to his followers. The least favorite book this year was Harry Potter and the 7 Rings.",
114 | "question": "What book does Jack dislike?",
115 | "answer": "Jack does not like Harry Potter and the 7 Rings."
116 | },
117 | {
118 | "passage": "When Judy and Jack went to school, they got in trouble with their teacher for being late. I didn't think it was very fair.",
119 | "question": "Did she think it was fair?",
120 | "answer": "No, she didn't think it was very fair."
121 | },
122 | {
123 | "passage": "If inflation is occurring, leading to higher prices for basic necessities such as gas by 2 dollars. Do you think that inflation is good for society?",
124 | "question": "Is inflation good for society?",
125 | "answer": "Hmm. Do you think so?"
126 | },
127 | {
128 | "passage": "Put yourself out there. The more time you spend dating and socializing, the more likely you will find a boyfriend you like.",
129 | "question": "Does socializing help you find a boyfriend?",
130 | "answer": "Yes, it helps you find a boyfriend."
131 | }
132 | ]),
133 | pd.DataFrame([
134 | {
135 | "passage": "Anna's mother always told her to be confident even if she feels nervous on the inside",
136 | "question": "Does Anna always feel nervous on the inside?",
137 | "answer": "Unknown"
138 | },
139 | {
140 | "passage": "Max and Jeff were extremely competitive at soccer, but Max was a lot better.",
141 | "question": "Was Jeff better than Max at soccer?",
142 | "answer": "No, Max was a lot better"
143 | },
144 | {
145 | "passage": "When Judy and Jack went to school, they got in trouble with their teacher for being late. I didn't think it was very fair.",
146 | "question": "Did she think it was fair?",
147 | "answer": "No, she didn't think it was very fair."
148 | },
149 | {
150 | "passage": "The FSP conference took place last week in Spain and representatives from 21 countries attended.",
151 | "question": "Did representatives from more than 20 countries attend FSP?",
152 | "answer": "Yes"
153 | },
154 | ]),
155 | ]
156 |
157 | class CBDecomp(Decomposition):
158 | def __init__(self, task_name, data_dir, val_split="validation"):
159 | super().__init__(task_name, data_dir, val_split)
160 |
161 | def get_boost_decomp_examples(self, data_train, boost_id):
162 | return [
163 | questioner_examples[boost_id],
164 | openended_qa_examples[boost_id],
165 | ]
166 |
167 | def zero_few_baseline(
168 | self,
169 | test_data,
170 | few_shot_df,
171 | manifest,
172 | overwrite_manifest,
173 | do_few_shot=True,
174 | ):
175 | expt_log = {}
176 | preds = []
177 | labels = []
178 |
179 | labels_names = set(test_data["targets_pretokenized"])
180 | labels_names = [l.lower().strip() for l in labels_names]
181 |
182 | for i, (ind, row) in tqdm(
183 | enumerate(test_data.iterrows()), total=len(test_data)
184 | ):
185 |             if ind in expt_log:
186 |                 pred = expt_log[ind]["pred"]
187 |                 gold = expt_log[ind]["gold"]
188 | else:
189 | text = row["inputs_pretokenized"]
190 | gold = row["targets_pretokenized"]
191 |
192 | icl_str = ""
193 | if do_few_shot:
194 | for s_ind, s_row in few_shot_df.iterrows():
195 | icl_str += f"{s_row['inputs_pretokenized']} {s_row['targets_pretokenized']}\n\n\n"
196 |
197 | text = row["inputs_pretokenized"]
198 | text = text.replace("True, False, or Neither?", "")
199 | text = text + ". True, False, or Neither?"
200 | gold = row["targets_pretokenized"]
201 | prompt = f"{icl_str}{{text:}}"
202 | pmp = prompt.format(text=text)
203 | if i == 0:
204 | print(pmp)
205 |
206 | raw_answer = get_response(
207 | pmp,
208 | manifest,
209 | overwrite=bool(overwrite_manifest),
210 | max_toks=30,
211 | )
212 | answer = raw_answer.strip().lower()
213 | answer = answer.split("\n")
214 | answer = [a for a in answer if a]
215 | answer = [
216 | a
217 | for a in answer
218 | if any(l.lower() in a.lower() for l in labels_names)
219 | ]
220 | if answer:
221 | answer = answer[0]
222 | else:
223 | answer = ""
224 | answer = "".join(
225 | [a for a in answer if a not in [".", ",", "?", ";", ":", "'", '"']]
226 | )
227 | is_yes = "true" in answer.split()
228 | is_no = "false" in answer.split()
229 | is_maybe = "neither" in answer.split() or "unknown" in answer.split() or "maybe" in answer.split()
230 | pred = "Neither"
231 | if is_yes and (not is_maybe and not is_no):
232 | pred = "True"
233 | if is_no and (not is_maybe and not is_yes):
234 | pred = "False"
235 | if is_maybe and (not is_no and not is_yes):
236 | pred = "Neither"
237 | entry = {
238 | "ind": ind,
239 | "example": text,
240 | "base_prompt": pmp,
241 | "raw_answer": raw_answer,
242 | "pred": pred,
243 | "gold": gold,
244 | }
245 | expt_log[ind] = entry
246 |
247 | preds.append(pred)
248 | labels.append(gold)
249 |
250 | report = classification_report(labels, preds, output_dict=True)
251 | return expt_log, report["accuracy"]
252 |
253 | def get_question(self, statement, prompt, boost_ex, manifest, overwrite_manifest):
254 | prompt_suffix = prompt(boost_ex)
255 |         question_prompt = f"\n{prompt_suffix}\n\nStatement: {{statement:}}\nQuestion:"
256 |         question_pmp = question_prompt.format(statement=statement)
257 | answer = get_response(
258 | question_pmp,
259 | manifest,
260 | overwrite=bool(overwrite_manifest),
261 | max_toks=50,
262 | )
263 | answer = answer.split("\n")[0]
264 | return answer, question_pmp
265 |
266 | def open_qa(self, question, passage, prompt, boost_ex, manifest, overwrite_manifest):
267 | prompt_suffix = prompt(boost_ex)
268 | qa_prompt = f"\n{prompt_suffix}\n\nPassage: {{passage:}}\nQuestion: {question}\nAnswer:"
269 | qa_pmp = qa_prompt.format(passage=passage)
270 | answer = get_response(
271 | qa_pmp,
272 | manifest,
273 | overwrite=bool(overwrite_manifest),
274 | max_toks=50,
275 | )
276 | answer = answer.split("\n")[0]
277 | answer = (
278 | answer.replace("A: ", "")
279 | .replace("B: ", "")
280 | .replace("Answer: ", "")
281 | .replace(", ", " ")
282 | )
283 | return answer, qa_pmp
284 |
285 | def run_decomposed_prompt(
286 | self, test_data, boost_data_train, boost_dfs, manifest, overwrite_manifest
287 | ):
288 | expt_log, all_boost_preds, labels = self._run_decomp_single_data(test_data, boost_dfs, manifest, overwrite_manifest)
289 | expt_log_train, all_boost_train_preds, train_labels = self._run_decomp_single_data(boost_data_train, boost_dfs, manifest, overwrite_manifest, run_limit=-1)
290 | # Do WS
291 | preds = self.merge_boosted_preds(all_boost_preds, all_boost_train_preds, train_labels, expt_log, expt_log_train)
292 | # Get accuracies across all boost sets
293 | individual_accuracies = []
294 | for i in range(len(all_boost_preds[0])):
295 | individual_accuracies.append(classification_report(labels, [p[i] for p in all_boost_preds], output_dict=True)["accuracy"])
296 | report = classification_report(labels, preds, output_dict=True)
297 | return expt_log, expt_log_train, report["accuracy"], individual_accuracies
298 |
299 | def _run_decomp_single_data(self, test_data, boost_dfs, manifest, overwrite_manifest, run_limit=-1):
300 | expt_log = {}
301 | all_boost_preds = []
302 | labels = []
303 |
304 | for i, (ind, row) in tqdm(
305 | enumerate(test_data.iterrows()), total=len(test_data)
306 | ):
307 | suffix = "True, False, or Neither?"
308 | input = row["inputs_pretokenized"]
309 | passage = input.split("\nQuestion: ")[0]
310 | statement = (
311 | input.split("\nQuestion: ")[-1]
312 | .replace(suffix, "")
313 | .replace('"', "")
314 | .strip()
315 | )
316 |
317 | if i == run_limit:
318 | break
319 |
320 | gold = row["targets_pretokenized"]
321 | gold = gold.lower()
322 | prompts_across_boost = []
323 | preds_across_boost = []
324 | for boost_examples in boost_dfs:
325 | all_prompts = []
326 | question, question_final_prompt = self.get_question(
327 | statement, questioner, boost_examples[0], manifest, overwrite_manifest
328 | )
329 | open_answer, answer_final_prompt = self.open_qa(
330 | question, passage, openended_qa, boost_examples[1], manifest, overwrite_manifest
331 | )
332 | all_prompts.append(question_final_prompt)
333 | all_prompts.append(answer_final_prompt)
334 | if i == 0:
335 | print("\n".join(all_prompts))
336 | if "Yes" in open_answer.split():
337 | answer = "True"
338 | elif "No" in open_answer.split():
339 | answer = "False"
340 | else:
341 | answer = "Neither"
342 |
343 | answer = answer.lower()
344 |
345 | is_yes = "yes" in answer.split() or "true" in answer.split()
346 | is_no = "no" in answer.split() or "false" in answer.split()
347 | is_maybe = "neither" in answer.split() or "maybe" in answer.split() or "unknown" in answer.split()
348 | pred = "neither"
349 | if is_yes and (not is_maybe and not is_no):
350 | pred = "true"
351 | if is_no and (not is_maybe and not is_yes):
352 | pred = "false"
353 | if is_maybe and (not is_no and not is_yes):
354 | pred = "neither"
355 | prompts_across_boost.append(all_prompts)
356 | preds_across_boost.append(pred)
357 |
358 | entry = {
359 | "ind": ind,
360 | "prompts": prompts_across_boost,
361 | "preds_boost": preds_across_boost,
362 | "example": input,
363 | "gold": gold,
364 | }
365 | expt_log[ind] = entry
366 | all_boost_preds.append(preds_across_boost)
367 | labels.append(gold)
368 | return expt_log, all_boost_preds, labels
369 |
370 |
371 | def main():
372 | args = get_args()
373 | args.num_boost = 3
374 | task_name = "super_glue_cb"
375 | data_dir = f"{DATA_DIR}/P3/data_feather/super_glue_cb_GPT_3_style/"
376 | decomp = CBDecomp(task_name, data_dir)
377 | decomp.run(args)
378 |
379 |
380 | if __name__ == "__main__":
381 | main()
382 |
--------------------------------------------------------------------------------
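CB_final.py decomposes CB into two steps: `questioner` rewrites the statement as a yes/no question, and `openended_qa` answers it against the passage. Since an `InputOutputPrompt` is called on an example DataFrame to render the few-shot prefix (as in `get_question` above), the rendered prompt can be inspected directly; a small sketch, assuming `tasks/` is on the path:

```python
from CB_final import questioner, questioner_examples

# Render the few-shot prefix for the first boost set: the instruction,
# then each example as a "Statement: ...\nQuestion: ..." block.
prefix = questioner(questioner_examples[0])
print(prefix)
# The model is then asked to complete a query of the same shape, e.g.:
print("Statement: the sky is not blue\nQuestion:")
```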
/tasks/DBPedia_final.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | import pandas as pd
4 | import numpy as np
5 |
6 | from tqdm.auto import tqdm
7 |
8 | from sklearn.metrics import classification_report
9 | from decomposition import Decomposition, get_args, DATA_DIR
10 | from utils import get_response, InputOutputPrompt
11 |
12 | summarize = InputOutputPrompt(
13 | input_formatter=lambda x: f"Passage: {x['passage']}",
14 | output_formatter=lambda x: f"Summarize: the passage \"Passage\": {x['summary']}",
15 | required_keys=["passage", "summary"],
16 | input_output_sep="\n",
17 | example_sep="\n\n",
18 | instruction="Summarize the passage.\n\n\"Categories\":\n- company\n- educational institution\n- artist\n- athlete\n- office holder\n- mean of transportation\n- building\n- natural place\n- village\n- animal\n- plant\n- album\n- film \n- written work\n\n"
19 | )
20 |
21 | summarize_examples = [
22 | pd.DataFrame([
23 | {
24 | "passage": "Personality and Mental Health - Personality and Mental Health: Multidisciplinary Studies from Personality Dysfunction to Criminal Behaviour is a quarterly peer-reviewed academic journal published by Wiley-Blackwell on behalf of the Centre for Health and Justice.",
25 | "summary": "The passage is about a journal."
26 | },
27 | {
28 | "passage": "RNLB Mona (ON 775) - RNLB Mona (ON 775) was a Watson Class lifeboat based at Broughty Ferry in Scotland that capsized during a rescue attempt with the loss of her entire crew of eight men. The Mona was built in 1935 and in her time saved 118 lives.",
29 | "summary": "The passage is about a lifeboat."
30 | },
31 | {
32 | "passage": "Sayonara mo Ienakatta Natsu - Sayonara mo Ienakatta Natsu (さよならも言えなかった夏) is an album by Mikuni Shimokawa released on July 4 2007 by Pony Canyon.This album consists of eleven songs; several new songs and some songs which were previously released as singles.",
33 | "summary": "The passage is about a album."
34 | }
35 | ]),
36 | pd.DataFrame([
37 | {
38 | "passage": "Personality and Mental Health - Personality and Mental Health: Multidisciplinary Studies from Personality Dysfunction to Criminal Behaviour is a quarterly peer-reviewed academic journal published by Wiley-Blackwell on behalf of the Centre for Health and Justice.",
39 | "summary": "The passage is about a journal."
40 | },
41 | {
42 | "passage": "Sayonara mo Ienakatta Natsu - Sayonara mo Ienakatta Natsu (さよならも言えなかった夏) is an album by Mikuni Shimokawa released on July 4 2007 by Pony Canyon.This album consists of eleven songs; several new songs and some songs which were previously released as singles.",
43 | "summary": "The passage is about a album."
44 | },
45 | {
46 | "passage": "RNLB Mona (ON 775) - RNLB Mona (ON 775) was a Watson Class lifeboat based at Broughty Ferry in Scotland that capsized during a rescue attempt with the loss of her entire crew of eight men. The Mona was built in 1935 and in her time saved 118 lives.",
47 | "summary": "The passage is about a lifeboat."
48 | },
49 | ]),
50 | pd.DataFrame([
51 | {
52 | "passage": "Sayonara mo Ienakatta Natsu - Sayonara mo Ienakatta Natsu (さよならも言えなかった夏) is an album by Mikuni Shimokawa released on July 4 2007 by Pony Canyon.This album consists of eleven songs; several new songs and some songs which were previously released as singles.",
53 | "summary": "The passage is about a album."
54 | },
55 | {
56 | "passage": "Personality and Mental Health - Personality and Mental Health: Multidisciplinary Studies from Personality Dysfunction to Criminal Behaviour is a quarterly peer-reviewed academic journal published by Wiley-Blackwell on behalf of the Centre for Health and Justice.",
57 | "summary": "The passage is about a journal."
58 | },
59 | {
60 | "passage": "RNLB Mona (ON 775) - RNLB Mona (ON 775) was a Watson Class lifeboat based at Broughty Ferry in Scotland that capsized during a rescue attempt with the loss of her entire crew of eight men. The Mona was built in 1935 and in her time saved 118 lives.",
61 | "summary": "The passage is about a lifeboat."
62 | },
63 | ])
64 | ]
65 |
66 | categorize = InputOutputPrompt(
67 | input_formatter=lambda x: f"Passage: {x['passage']}\nSummary: {x['summary']}",
68 | output_formatter=lambda x: f"The summary \"Summary\" fits \"Category\": {x['category']}",
69 | required_keys=["passage", "summary", "category"],
70 | input_output_sep="\n",
71 | example_sep="\n\n",
72 | instruction="Pick one category for the following text.\n\n\"Categories\":\n- company\n- educational institution\n- artist\n- athlete\n- office holder\n- mean of transportation\n- building\n- natural place\n- village\n- animal\n- plant\n- album\n- film\n- written work\n\n"
73 | )
74 | categorize_examples = [
75 | pd.DataFrame([
76 | {
77 | "passage": "Personality and Mental Health - Personality and Mental Health: Multidisciplinary Studies from Personality Dysfunction to Criminal Behaviour is a quarterly peer-reviewed academic journal published by Wiley-Blackwell on behalf of the Centre for Health and Justice.",
78 | "summary": "The passage is about a journal.",
79 | "category": "written work"
80 | },
81 | {
82 | "passage": "RNLB Mona (ON 775) - RNLB Mona (ON 775) was a Watson Class lifeboat based at Broughty Ferry in Scotland that capsized during a rescue attempt with the loss of her entire crew of eight men. The Mona was built in 1935 and in her time saved 118 lives.",
83 | "summary": "The passage is about a lifeboat.",
84 | "category": "mean of transportation"
85 | },
86 | {
87 | "passage": "Sayonara mo Ienakatta Natsu - Sayonara mo Ienakatta Natsu (さよならも言えなかった夏) is an album by Mikuni Shimokawa released on July 4 2007 by Pony Canyon.This album consists of eleven songs; several new songs and some songs which were previously released as singles.",
88 | "summary": "The passage is about an album.",
89 | "category": "album"
90 | }
91 | ]),
92 | pd.DataFrame([
93 | {
94 | "passage": "Personality and Mental Health - Personality and Mental Health: Multidisciplinary Studies from Personality Dysfunction to Criminal Behaviour is a quarterly peer-reviewed academic journal published by Wiley-Blackwell on behalf of the Centre for Health and Justice.",
95 | "summary": "The passage is about a journal.",
96 | "category": "written work"
97 | },
98 | {
99 | "passage": "Sayonara mo Ienakatta Natsu - Sayonara mo Ienakatta Natsu (さよならも言えなかった夏) is an album by Mikuni Shimokawa released on July 4 2007 by Pony Canyon.This album consists of eleven songs; several new songs and some songs which were previously released as singles.",
100 | "summary": "The passage is about an album.",
101 | "category": "album"
102 | },
103 | {
104 | "passage": "RNLB Mona (ON 775) - RNLB Mona (ON 775) was a Watson Class lifeboat based at Broughty Ferry in Scotland that capsized during a rescue attempt with the loss of her entire crew of eight men. The Mona was built in 1935 and in her time saved 118 lives.",
105 | "summary": "The passage is about a lifeboat.",
106 | "category": "mean of transportation"
107 | },
108 | ]),
109 | pd.DataFrame([
110 | {
111 | "passage": "Sayonara mo Ienakatta Natsu - Sayonara mo Ienakatta Natsu (さよならも言えなかった夏) is an album by Mikuni Shimokawa released on July 4 2007 by Pony Canyon.This album consists of eleven songs; several new songs and some songs which were previously released as singles.",
112 | "summary": "The passage is about an album.",
113 | "category": "album"
114 | },
115 | {
116 | "passage": "Personality and Mental Health - Personality and Mental Health: Multidisciplinary Studies from Personality Dysfunction to Criminal Behaviour is a quarterly peer-reviewed academic journal published by Wiley-Blackwell on behalf of the Centre for Health and Justice.",
117 | "summary": "The passage is about a journal.",
118 | "category": "written work"
119 | },
120 | {
121 | "passage": "RNLB Mona (ON 775) - RNLB Mona (ON 775) was a Watson Class lifeboat based at Broughty Ferry in Scotland that capsized during a rescue attempt with the loss of her entire crew of eight men. The Mona was built in 1935 and in her time saved 118 lives.",
122 | "summary": "The passage is about a lifeboat.",
123 | "category": "mean of transportation"
124 | },
125 | ])
126 | ]
127 | description_zeroshot="""
128 | Pick the correct category for the passage.
129 | Categories:
130 | - company
131 | - educational institution
132 | - artist
133 | - athlete
134 | - office holder
135 | - mean of transportation
136 | - building
137 | - natural place
138 | - village
139 | - animal
140 | - plant
141 | - album
142 | - film
143 | - written work"""
144 |
145 | class DBPediaDecomp(Decomposition):
146 | def __init__(self, task_name, data_dir, val_split="validation"):
147 | super().__init__(task_name, data_dir, val_split)
148 |
149 | def get_boost_decomp_examples(self, train_data, boost_id):
150 | return [
151 | summarize_examples[boost_id],
152 | categorize_examples[boost_id],
153 | ]
154 |
155 | def get_few_shot_examples(self, train_data, k_shot):
156 | """Get few shot examples"""
157 | labels = sorted(set(train_data["targets_pretokenized"]))
158 | num_per_class = int(np.ceil(k_shot / len(labels)))
159 | print(f"Selecting {num_per_class} examples per class.")
160 |
161 | dfs = []
162 | total_in_context = 0
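163 | # If adding num_per_class more examples would exceed k_shot, shrink the
164 | # per-class count so the in-context budget is never exceeded.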
163 | for label in labels:
164 | while num_per_class + total_in_context > k_shot:
165 | num_per_class -= 1
166 | sub_df = train_data[train_data["targets_pretokenized"] == label].sample(
167 | num_per_class
168 | )
169 | dfs.append(sub_df)
170 | total_in_context += num_per_class
171 | if total_in_context == k_shot:
172 | break
173 | mini_df = pd.concat(dfs)
174 | return mini_df
175 |
176 | def zero_few_baseline(
177 | self,
178 | test_data,
179 | few_shot_df,
180 | manifest,
181 | overwrite_manifest,
182 | do_few_shot=True,
183 | ):
184 | expt_log = {}
185 | preds = []
186 | labels = []
187 |
188 | for i, (ind, row) in tqdm(
189 | enumerate(test_data.iterrows()), total=len(test_data)
190 | ):
191 | input_text = row["inputs_pretokenized"]
192 | body = input_text.split("written work. ")[-1]
193 | gold = row["targets_pretokenized"].strip().lower()
194 | icl_str = ""
195 | title = description_zeroshot
196 | if do_few_shot:
197 | for s_ind, s_row in few_shot_df.iterrows():
198 | s_input = s_row.inputs_pretokenized
199 | s_body = s_input.split("written work. ")[-1]
200 | s_title = s_input.split("written work. ")[0] + "written work."
201 | s_output = s_row.targets_pretokenized.strip()
202 | icl_str += f"Passage: {s_body}\nCategory: {s_output}\n\n"
203 |
204 | icl_str = f"{title}\n\n{icl_str}"
205 | prompt = f"{icl_str}Passage: {{body:}}\nCategory:"
206 | pmp = prompt.format(body=body)
207 | if i == 0:
208 | print(pmp)
209 |
210 | output = get_response(
211 | pmp,
212 | manifest,
213 | overwrite=bool(overwrite_manifest),
214 | max_toks=10,
215 | stop_token="\n\n",
216 | )
217 | pred = output.strip().lower()
218 | entry = {
219 | "ind": ind,
220 | "example": input_text,
221 | "base_prompt": pmp,
222 | "pred": pred,
223 | "gold": gold,
224 | }
225 | expt_log[ind] = entry
226 |
227 | preds.append(pred)
228 | labels.append(gold)
229 |
230 | report = classification_report(labels, preds, output_dict=True)
231 | return expt_log, report["accuracy"]
232 |
233 | def run_decomposed_prompt(
234 | self, test_data, boost_data_train, boost_dfs, manifest, overwrite_manifest
235 | ):
236 | expt_log, all_boost_preds, labels = self._run_decomp_single_data(test_data, boost_dfs, manifest, overwrite_manifest)
237 | expt_log_train, all_boost_train_preds, train_labels = self._run_decomp_single_data(boost_data_train, boost_dfs, manifest, overwrite_manifest, run_limit=1)
238 | # Do WS
239 | preds = self.merge_boosted_preds(all_boost_preds, all_boost_train_preds, train_labels, expt_log, expt_log_train)
240 | # Get accuracies across all boost sets
241 | individual_accuracies = []
242 | for i in range(len(all_boost_preds[0])):
243 | individual_accuracies.append(classification_report(labels, [p[i] for p in all_boost_preds], output_dict=True)["accuracy"])
244 | report = classification_report(labels, preds, output_dict=True)
245 | return expt_log, expt_log_train, report["accuracy"], individual_accuracies
246 |
247 | def _run_decomp_single_data(self, test_data, boost_dfs, manifest, overwrite_manifest, run_limit=-1):
248 | expt_log = {}
249 | all_boost_preds = []
250 | labels = []
251 |
252 | for i, (ind, row) in tqdm(
253 | enumerate(test_data.iterrows()), total=len(test_data)
254 | ):
255 | if i == run_limit:
256 | break
257 |
258 | text = (
259 | row.inputs_pretokenized.strip("\n").split("written work.")[-1].strip()
260 | )
261 | gold = row.targets_pretokenized.strip().lower()
262 |
263 | prompts_across_boost = []
264 | preds_across_boost = []
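265 | # Two-step decomposition per boost set: (1) summarize the passage,
266 | # (2) map the summary onto one of the fixed DBpedia categories.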
265 | for boost_examples in boost_dfs:
266 | all_prompts = []
267 | prompt_suffix = summarize(boost_examples[0])
268 | summarize_prompt = f"{prompt_suffix}\n\nPassage: {{text:}}\nSummarize: the passage \"Passage\":"
269 | summarize_pmp = summarize_prompt.format(text=text)
270 | output = get_response(
271 | summarize_pmp,
272 | manifest,
273 | overwrite=bool(overwrite_manifest),
274 | max_toks=50,
275 | )
276 | summary = output.split("\n")[0].split(":")[-1].strip("\n")
277 | prompt_suffix = categorize(boost_examples[1])
278 | category_prompt = f"{prompt_suffix}\n\nPassage: {{text:}}\nSummary: {{summary:}}\nThe summary \"Summary\" fits \"Category\":"
279 | category_pmp = category_prompt.format(text=text, summary=summary)
280 | output = get_response(
281 | category_pmp,
282 | manifest,
283 | overwrite=bool(overwrite_manifest),
284 | max_toks=15,
285 | )
286 | pred = output.split("\n")[0].strip().lower()
287 |
288 | all_prompts.append(summarize_pmp)
289 | all_prompts.append(category_pmp)
290 | if i == 0:
291 | print(summarize_pmp)
292 | print(category_pmp)
293 | prompts_across_boost.append(all_prompts)
294 | preds_across_boost.append(pred)
295 |
296 | entry = {
297 | "ind": ind,
298 | "example": text,
299 | "prompts": prompts_across_boost,
300 | "preds_boost": preds_across_boost,
301 | "gold": gold,
302 | }
303 | expt_log[ind] = entry
304 | all_boost_preds.append(preds_across_boost)
305 | labels.append(gold)
306 | return expt_log, all_boost_preds, labels
307 |
308 |
309 | def main():
310 | args = get_args()
311 | task_name = "dbpedia"
312 | data_dir = (
313 | f"{DATA_DIR}/P3/data_feather/dbpedia_14_pick_one_category_for_the_following_text"
314 | )
315 | decomp = DBPediaDecomp(task_name, data_dir, val_split="test")
316 | decomp.run(args)
317 |
318 |
319 | if __name__ == "__main__":
320 | main()
321 |
--------------------------------------------------------------------------------
/tasks/NQ_final.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | import pandas as pd
4 | import numpy as np
5 | import ast
6 |
7 | from tqdm.auto import tqdm
8 | from pathlib import Path
9 | from nltk.corpus import stopwords
10 | from decomposition import Decomposition, get_args, DATA_DIR
11 | from utils import get_response, InputOutputPrompt, accuracy_span_overlap
12 |
13 | stops = set(stopwords.words("english"))
14 |
15 | extract = InputOutputPrompt(
16 | input_formatter=lambda x: f"Question: {x['question']}",
17 | output_formatter=lambda x: f"Answer: {x['answer']}",
18 | required_keys=["question", "answer"],
19 | input_output_sep="\n",
20 | example_sep="\n\n",
21 | instruction="Produce distinct questions.\n\n"
22 | )
23 |
24 | more_info_examples = [
25 | pd.DataFrame([
26 | {
27 | "question": "who plays Carrie Bradshaw in sex and the city?",
28 | "answer": "Caroline \"Carrie\" Bradshaw is a fictional character from the HBO franchise Sex and the City, portrayed by Sarah Jessica Parker."
29 | },
30 | {
31 | "question": "what are the elements in air?",
32 | "answer": "By mole fraction (i.e., by number of molecules), dry air contains 78.08% nitrogen, 20.95% oxygen, 0.93% argon, 0.04% carbon dioxide, and small amounts of other gases"
33 | },
34 | {
35 | "question": "what is HP company?",
36 | "answer": "HP Inc. is an American multinational information technology company headquartered in Palo Alto, California, that develops personal computers (PCs)"
37 | },
38 | {
39 | "question": "when was the last season of FRIENDS released?",
40 | "answer": "The series finale aired on May 6, 2004, and was watched by around 52.5 million American viewers, making it the fifth-most-watched series finale in television history"
41 | }
42 | ]),
43 | ]
44 |
45 |
46 | answer = InputOutputPrompt(
47 | input_formatter=lambda x: f"Context: {x['context']}\nQuestion: {x['question']}",
48 | output_formatter=lambda x: f"Answer: {x['answer']}",
49 | required_keys=["context", "question", "answer"],
50 | input_output_sep="\n",
51 | example_sep="\n\n",
52 | instruction="Answer the question.\n\n"
53 | )
54 |
55 | answer_question = [
56 | pd.DataFrame([
57 | {
58 | 'context': 'The nearest airport to Palm Springs is Indio/Palm Springs (PSP) Airport which is 2.1 miles away. ',
59 | 'question': 'what airport is closest to palm springs?',
60 | 'answer': 'Palm Springs International Airport'
61 | },
62 | {
63 | 'context': 'Martin Luther King earned his Bachelor of Divinity degree from Crozer Theological Seminary, followed by a doctorate in Systematic Theology from Boston University.',
64 | 'question': 'what degree did martin luther king get?',
65 | 'answer': 'Bachelor of Divinity'
66 | },
67 | {
68 | 'context': 'The Niger river runs in a crescent through Libya, Mali, Niger, on the border with Benin and then through Nigeria.',
69 | 'question': 'what countries does the niger river flow through?',
70 | 'answer': 'Libya'
71 | },
72 | {
73 | 'context': 'Puerto Rico is a territory of the United States and uses the U.S. dollar. ',
74 | 'question': 'what type of currency is used in puerto rico?',
75 | 'answer': 'United States dollar'
76 | },
77 | {
78 | 'context': 'Kitt was voiced most often by William Daniels.',
79 | 'question': 'who played kitt in knight rider?',
80 | 'answer': 'William Daniels'
81 | }
82 | ]),
83 | pd.DataFrame([
84 | {
85 | 'context': 'leonardo da vinci invented the parachute, the helicopter, double hull, an armored fighting vehicle,',
86 | 'question': 'what inventions did leonardo da vinci made?',
87 | 'answer': 'Double hull'
88 | },
89 | {
90 | 'context': "The French franc (F) was the national currency of France prior to France's adoption of the euro (EUR) in January 2002.",
91 | 'question': 'what currency is used in france before euro?',
92 | 'answer': 'French franc'
93 | },
94 | {
95 | 'context': 'The Isthmus of Panama, contains the country of Panama and the panama canal.',
96 | 'question': 'where is isthmus of panama located?',
97 | 'answer': 'Costa Rica'
98 | },
99 | {
100 | 'context': 'Hurricane Irene was a large and destructive tropical cyclone which affected much of the Caribbean and East Coast',
101 | 'question': 'where did hurricane irene?',
102 | 'answer': 'Eastern United States'
103 | },
104 | {
105 | 'context': 'Rihanna acted in This is the End and Battleship.',
106 | 'question': 'what movie did rihanna play in?',
107 | 'answer': 'This Is the End'
108 | }
109 | ]),
110 | pd.DataFrame([
111 | {
112 | 'context': 'St Vincent de Paul is buried in the 6th arrondissement of Paris.',
113 | 'question': 'where is st vincent de paul buried?',
114 | 'answer': 'Paris'
115 | },
116 | {
117 | 'context': 'Thomas Luther "Luke" Bryan (born July 17, 1976) is an American country singer and songwriter from Leesburg.',
118 | 'question': 'where is luke bryan from?',
119 | 'answer': 'Leesburg'
120 | },
121 | {
122 | 'context': "Klum and Seal got married on 10 May 2005 on a beach in Mexico near Seal's home on Costa Careyes. ",
123 | 'question': 'where did heidi klum and seal get married?',
124 | 'answer': 'Mexico'},
125 | {
126 | 'context': 'Tarantino starred in pulp fiction, grindhouse and others.',
127 | 'question': 'what movies did quentin tarantino star in?',
128 | 'answer': 'Grindhouse'
129 | },
130 | {
131 | 'context': 'Countries that are sometimes considered to be entirely or partially part of the Balkans are Croatia, Serbia, Lake Prespa.',
132 | 'question': 'what country is located in the balkan peninsula?',
133 | 'answer': 'Lake Prespa'
134 | }
135 | ])
136 |
137 | ]
138 |
139 | prefix_select_zeroshot = """Answer the question.\n\n"""
140 |
141 |
142 | class NQDecomp(Decomposition):
143 | def __init__(self, task_name, data_dir, val_split="validation"):
144 | super().__init__(task_name, data_dir, val_split)
145 |
146 | def get_boost_decomp_examples(self, train_data, boost_id):
147 | return [
148 | more_info_examples[0],
149 | answer_question[boost_id],
150 | ]
151 |
152 | def read_data(self, save_dir, overwrite_data):
153 | save_data = Path(f"{save_dir}/{self.task_name}/data.feather")
154 | if not save_data.exists() or overwrite_data:
155 | test_data = pd.read_csv(f"{self.data_dir}/{self.task_name}/nq-test.qa.csv", sep="\t", header=None, names=["question", "answers"])
156 | test_data.to_feather(f"{save_data}")
157 | else:
158 | print(f"Reading test data from {save_data}")
159 | test_data = pd.read_feather(save_data)
160 |
161 | save_data = Path(f"{self.data_dir}/{self.task_name}/train_data.feather")
162 | if not save_data.exists() or overwrite_data:
163 | train_data = pd.read_json(f"{self.data_dir}/{self.task_name}/biencoder-nq-dev.json")
164 | train_data.to_feather(f"{save_data}")
165 | else:
166 | print(f"Reading train data from {save_data}")
167 | train_data = pd.read_feather(save_data)
168 | print(f"Test Data Size: {len(test_data)}")
169 | print(f"Train Data Size: {len(train_data)}")
170 | return test_data, train_data
171 |
172 | def get_few_shot_examples(self, train_data, k_shot):
173 | """Get few shot examples"""
174 | labels = train_data.question.tolist()
175 |
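176 | # Each distinct question is treated as its own "class" here, so in
177 | # practice this samples k_shot distinct questions from the training set.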
176 | num_per_class = int(np.ceil(k_shot / len(labels)))
177 | print(f"Selecting {num_per_class} examples per class.")
178 |
179 | dfs = []
180 | total_in_context = 0
181 | for label in labels:
182 | while num_per_class + total_in_context > k_shot:
183 | num_per_class -= 1
184 | sub_df = train_data[train_data["question"] == label].sample(num_per_class)
185 | dfs.append(sub_df)
186 | total_in_context += num_per_class
187 | if total_in_context == k_shot:
188 | break
189 | mini_df = pd.concat(dfs)
190 | return mini_df
191 |
192 | def zero_few_baseline(
193 | self,
194 | test_data,
195 | few_shot_df,
196 | manifest,
197 | overwrite_manifest,
198 | prompt_suffix="",
199 | do_few_shot=True,
200 | ):
201 | expt_log = {}
202 | preds = []
203 | labels = []
204 |
205 | for i, (ind, row) in tqdm(
206 | enumerate(test_data.iterrows()), total=len(test_data)
207 | ):
208 | question = row.question
209 | if isinstance(row.answers, str):
210 | label = ast.literal_eval(row.answers)
211 | else:
212 | label = row.answers.tolist()
213 | gold = label
214 |
215 | icl_str = ""
216 | if do_few_shot:
217 | for s_ind, s_row in few_shot_df.iterrows():
218 | s_question = s_row.question
219 | # answers may be stored as a stringified list or as an array
220 | if isinstance(s_row.answers, str):
221 | label = ast.literal_eval(s_row.answers)
222 | else:
223 | label = s_row.answers.tolist()
224 | icl_str += f"Question: {s_question}\nAnswer: {label[0]}\n\n"
225 |
226 | prompt = (
227 | icl_str
228 | + "Question: {question:}"
229 | + prompt_suffix
230 | + "\nAnswer:"
231 | )
232 |
233 | if i == 0:
234 | print(prompt.format(question=question))
235 | prompt = prompt.format(question=question)
236 | raw_answer = get_response(
237 | prompt,
238 | manifest,
239 | overwrite=bool(overwrite_manifest),
240 | max_toks=20,
241 | stop_token="\n\n",
242 | )
243 | pred = raw_answer.strip("\n").strip().lower()
244 | entry = {
245 | "ind": ind,
246 | "example": question,
247 | "base_prompt": prompt,
248 | "raw_answer": raw_answer,
249 | "pred": pred,
250 | "gold": gold,
251 | }
252 | expt_log[ind] = entry
253 |
254 | preds.append([pred])
255 | labels.append(gold)
256 | metric = accuracy_span_overlap(preds=preds, golds=labels)
257 | return expt_log, metric
258 |
259 | def run_decomposed_prompt(
260 | self, test_data, boost_data_train, boost_dfs, manifest, overwrite_manifest
261 | ):
262 | expt_log, all_boost_preds, labels = self._run_decomp_single_data(test_data, boost_dfs, manifest, overwrite_manifest)
263 | expt_log_train, all_boost_train_preds, train_labels = self._run_decomp_single_data(boost_data_train, boost_dfs, manifest, overwrite_manifest, run_limit=1000)
264 |
265 | # Do WS
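266 | # Each boosted prediction is a (generated_context, answer) tuple; weak
267 | # supervision votes over the answer component (index 1) only.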
266 | boost_test, boost_train = [], []
267 | for p in all_boost_preds:
268 | samples = [lf[1] for lf in p]
269 | boost_test.append(samples)
270 | for p in all_boost_train_preds:
271 | samples = [lf[1] for lf in p]
272 | boost_train.append(samples)
273 |
274 | preds = self.merge_boosted_preds(boost_test, boost_train, train_labels, expt_log, expt_log_train)
275 | preds = [(x,y) for x,y in zip([p[0][0] for p in all_boost_preds], preds)]
276 |
277 | # Get accuracies across all boost sets
278 | individual_accuracies = []
279 | for i in range(len(all_boost_preds[0])):
280 | individual_accuracies.append(accuracy_span_overlap(preds=[p[i] for p in all_boost_preds], golds=labels))
281 |
282 | metric = accuracy_span_overlap(preds=preds, golds=labels)
283 | return expt_log, expt_log_train, metric, individual_accuracies
284 |
285 | def _run_decomp_single_data(self, test_data, boost_dfs, manifest, overwrite_manifest, run_limit=-1):
286 | expt_log = {}
287 | all_boost_preds = []
288 | labels = []
289 |
290 |
291 | for i, (ind, row) in tqdm(
292 | enumerate(test_data.iterrows()), total=len(test_data)
293 | ):
294 | if i == run_limit:
295 | break
296 |
297 | question = row.question
298 | if isinstance(row.answers, str):
299 | label = ast.literal_eval(row.answers)
300 | else:
301 | label = row.answers.tolist()
302 |
303 | gold = label
304 | prompts_across_boost = []
305 | preds_across_boost = []
306 |
307 | # extract context
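308 | # Closed-book step: the model generates background text for the question,
309 | # which is then fed back in as the context for open-book QA below.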
308 | prompt_suffix = extract(boost_dfs[0][0])
309 | prompt = (
310 | prompt_suffix + "\n\nQuestion: {question:}\nAnswer:"
311 | )
312 | more_info_answer = get_response(
313 | prompt.format(question=question),
314 | manifest,
315 | overwrite=bool(overwrite_manifest),
316 | max_toks=20,
317 | stop_token="\n\n",
318 | )
319 |
320 | for boost_examples in boost_dfs:
321 | all_prompts = []
322 | prompt_suffix = answer(boost_examples[1])
323 | prompt = (
324 | prompt_suffix + "\n\nContext: {text:}\nQuestion: {question:}\nAnswer:"
325 | )
326 | all_prompts.append(prompt)
327 | if i == 0:
328 | print(prompt.format(text=more_info_answer, question=question))
329 |
330 | raw_answer = get_response(
331 | prompt.format(text=more_info_answer, question=question),
332 | manifest,
333 | overwrite=bool(overwrite_manifest),
334 | max_toks=20,
335 | stop_token="\n\n",
336 | )
337 | pred = raw_answer.split("\n")[0].strip().lower()
338 | prompts_across_boost.append(all_prompts)
339 | preds_across_boost.append((more_info_answer, pred))
340 |
341 | entry = {
342 | "ind": ind,
343 | "example": question,
344 | "prompts": prompts_across_boost,
345 | "preds_boost": preds_across_boost,
346 | "gold": gold,
347 | }
348 | expt_log[ind] = entry
349 | all_boost_preds.append(preds_across_boost)
350 | labels.append(gold)
351 | return expt_log, all_boost_preds, labels
352 |
353 |
354 | def main():
355 | args = get_args()
356 | task_name = "NQ"
357 | data_dir = f"{DATA_DIR}/NQ"
358 | decomp = NQDecomp(task_name, data_dir)
359 | decomp.run(args)
360 |
361 |
362 | if __name__ == "__main__":
363 | main()
364 |
--------------------------------------------------------------------------------
/tasks/SST2_final.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | from tqdm.auto import tqdm
4 | from pathlib import Path
5 | import pandas as pd
6 | import numpy as np
7 | import random
8 |
9 | from sklearn.metrics import classification_report
10 | from decomposition import Decomposition, get_args, DATA_DIR
11 | from utils import get_response
12 |
13 | def format_data(lines):
14 | """from lines in dataset to two lists of sentences and labels respectively"""
15 | def process_raw_data_sst(lines):
16 | labels = []
17 | sentences = []
18 | for line in lines:
19 | labels.append(int(line[0]))
20 | sentences.append(line[2:].strip())
21 | return sentences, labels
22 |
23 | train_sentences, train_labels = process_raw_data_sst(lines)
24 | return train_sentences, train_labels
25 |
26 |
27 | class SST2Decomp(Decomposition):
28 | def __init__(self, task_name, data_dir, val_split="validation"):
29 | super().__init__(task_name, data_dir, val_split)
30 |
31 | def get_few_shot_examples(self, train_data, k_shot):
32 | """Get few shot examples"""
33 | labels = [0, 1]
34 | num_per_class = int(np.ceil(k_shot / len(labels)))
35 | print(f"Selecting {num_per_class} examples per class.")
36 |
37 | dfs = []
38 | total_in_context = 0
39 | for label in labels:
40 | while num_per_class + total_in_context > k_shot:
41 | num_per_class -= 1
42 | sub_df = train_data[train_data["label"] == label].sample(num_per_class)
43 | dfs.append(sub_df)
44 | total_in_context += num_per_class
45 | if total_in_context == k_shot:
46 | break
47 | mini_df = pd.concat(dfs)
48 | return mini_df
49 |
50 | def read_data(self, save_dir, overwrite_data):
51 | save_data = Path(f"{save_dir}/{self.task_name}/data.feather")
52 | if not save_data.exists() or overwrite_data:
53 | with open(f"{self.data_dir}/stsa.binary.test", "r") as f:
54 | test_lines = f.readlines()
55 | test_sentences, test_labels = format_data(test_lines)
56 | test_data = pd.DataFrame({
57 | 'sentence': test_sentences,
58 | 'label': test_labels,
59 | })
60 | test_data.to_feather(f"{save_data}")
61 | else:
62 | print(f"Reading test data from {save_data}")
63 | test_data = pd.read_feather(save_data)
64 |
65 | save_data = Path(f"{save_dir}/{self.task_name}/train_data.feather")
66 | if not save_data.exists() or overwrite_data:
67 | with open(f"{self.data_dir}/stsa.binary.train", "r") as f:
68 | train_lines = f.readlines()
69 | train_sentences, train_labels = format_data(train_lines)
70 | train_data = pd.DataFrame({
71 | 'sentence': train_sentences,
72 | 'label': train_labels,
73 | })
74 | train_data.to_feather(f"{save_data}")
75 | else:
76 | print(f"Reading train data from {save_data}")
77 | train_data = pd.read_feather(save_data)
78 | print(f"Test Data Size: {len(test_data)}")
79 | print(f"Train Data Size: {len(train_data)}")
80 | return test_data, train_data
81 |
82 |
83 | def get_boost_decomp_examples(self, train_data, boost_id):
84 | seed = [1, 2, 3][boost_id]
85 | k_shot = 16
86 | random.seed(seed)
87 | np.random.seed(seed)
88 |
89 | data_train = pd.DataFrame(train_data)
90 | labels = set(data_train["label"])
91 | num_per_class = int(np.ceil(k_shot / len(labels)))
92 |
93 | dfs = []
94 | total_in_context = 0
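95 | # Boost sets differ by seed: odd seeds sample each class directly, even
96 | # seeds sample the complement class, yielding different draws per set.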
95 | for label in labels:
96 | while num_per_class + total_in_context > k_shot:
97 | num_per_class -= 1
98 |
99 | if seed % 2 == 1:
100 | sub_df = data_train[data_train["label"] == label].sample(num_per_class, random_state=seed)
101 | elif seed % 2 == 0:
102 | sub_df = data_train[data_train["label"] != label].sample(num_per_class, random_state=seed)
103 | dfs.append(sub_df)
104 | total_in_context += num_per_class
105 | if total_in_context == k_shot:
106 | break
107 |
108 | booster_df = pd.concat(dfs).sample(frac=1, random_state=0)
109 | print(f"Selected: {len(booster_df)} in context examples.")
110 | return [
111 | booster_df
112 | ]
113 |
114 | def zero_few_baseline(
115 | self,
116 | test_data,
117 | few_shot_df,
118 | manifest,
119 | overwrite_manifest,
120 | do_few_shot=True,
121 | ):
122 | expt_log = {}
123 | preds = []
124 | labels = []
125 |
126 | for i, (ind, row) in tqdm(
127 | enumerate(test_data.iterrows()), total=len(test_data)
128 | ):
129 | if ind in expt_log:
130 | pred = expt_log[ind]["pred"]
131 | gold_str = expt_log[ind]["gold"]
132 | else:
133 | sentence = row["sentence"]
134 | label = row["label"]
135 |
136 | icl_str = ""
137 | if do_few_shot:
138 | for s_ind, s_row in few_shot_df.iterrows():
139 | if s_row["label"] == 0:
140 | demo_label = "negative"
141 | else:
142 | demo_label = "positive"
143 | icl = f"Text: {s_row['sentence']}\nSentiment: {demo_label}"
144 | icl_str += f"{icl}\n\n"
145 |
146 | description = "For each snippet of text, label the sentiment of the text as positive or negative."
147 | prompt = f"{description}\n\n{icl_str}Text: {{sentence:}}\nSentiment:"
148 | pmp = prompt.format(sentence=sentence)
149 | if i == 0:
150 | print(pmp)
151 |
152 | pred = get_response(
153 | pmp,
154 | manifest,
155 | overwrite=bool(overwrite_manifest),
156 | max_toks=50,
157 | )
158 | pred = (
159 | pred.replace(".", "")
160 | .replace(",", "")
161 | .replace("Label: ", "")
162 | .replace("Sentiment: ", "")
163 | )
164 | pred = [p for p in pred.split("\n") if p]
165 | is_pos = "positive" in pred
166 | is_neg = "negative" in pred
167 | if is_pos and not is_neg:
168 | pred = "positive"
169 | elif is_neg and not is_pos:
170 | pred = "negative"
171 | else:
172 | pred = ""
173 |
174 | if label == 1:
175 | gold_str = "positive"
176 | else:
177 | gold_str = "negative"
178 |
179 | entry = {
180 | "gold": gold_str,
181 | "pred": pred,
182 | "base_prompt": pmp,
183 | "ind": ind,
184 | "example": sentence,
185 | }
186 | expt_log[ind] = entry
187 |
188 | preds.append(pred)
189 | labels.append(gold_str)
190 |
191 | report = classification_report(labels, preds, output_dict=True)
192 | return expt_log, report["accuracy"]
193 |
194 | def run_decomposed_prompt(
195 | self, test_data, boost_data_train, boost_dfs, manifest, overwrite_manifest
196 | ):
197 | expt_log, all_boost_preds, labels = self._run_decomp_single_data(test_data, boost_dfs, manifest, overwrite_manifest, run_limit=-1)
198 | expt_log_train, all_boost_train_preds, train_labels = self._run_decomp_single_data(boost_data_train, boost_dfs, manifest, overwrite_manifest, run_limit=1000)
199 | # Do WS
200 | preds = self.merge_boosted_preds(all_boost_preds, all_boost_train_preds, train_labels, expt_log, expt_log_train)
201 | # Get accuracies across all boost sets
202 | individual_accuracies = []
203 | for i in range(len(all_boost_preds[0])):
204 | individual_accuracies.append(classification_report(labels, [p[i] for p in all_boost_preds], output_dict=True)["accuracy"])
205 | report = classification_report(labels, preds, output_dict=True)
206 | return expt_log, expt_log_train, report["accuracy"], individual_accuracies
207 |
208 | def _run_decomp_single_data(self, test_data, boost_dfs, manifest, overwrite_manifest, run_limit=-1):
209 | expt_log = {}
210 | all_boost_preds = []
211 | labels = []
212 |
213 | for i, (ind, row) in tqdm(
214 | enumerate(test_data.iterrows()), total=len(test_data)
215 | ):
216 | sentence = row["sentence"]
217 | label = row["label"]
218 |
219 | if i == run_limit:
220 | break
221 |
222 | prompts_across_boost = []
223 | preds_across_boost = []
224 | for boost_examples in boost_dfs:
225 | all_prompts = []
226 | icl_str = ""
227 | for s_ind, s_row in boost_examples[0].iterrows():
228 | if s_row["label"] == 0:
229 | demo_label = "negative"
230 | else:
231 | demo_label = "positive"
232 | icl = f"Text: {s_row['sentence']}\nSentiment: {demo_label}"
233 | icl_str += f"{icl}\n\n"
234 |
235 | description = "For each snippet of text, label the sentiment of the text as positive or negative."
236 | prompt = f"{description}\n\n{icl_str}Text: {{sentence:}}\nSentiment:"
237 | pmp = prompt.format(sentence=sentence)
238 | all_prompts.append(pmp)
239 | if i == 0:
240 | print(pmp)
241 | pred = get_response(
242 | pmp,
243 | manifest,
244 | overwrite=bool(overwrite_manifest),
245 | max_toks=5,
246 | )
247 | pred = pred.replace(".", "").replace(",", "").replace("Label: ", "")
248 | pred = [p for p in pred.split("\n") if p]
249 | if pred:
250 | pred = pred[0]
251 | else:
252 | pred = ""
253 | prompts_across_boost.append(all_prompts)
254 | preds_across_boost.append(pred)
255 |
256 | if label == 1:
257 | gold_str = "positive"
258 | else:
259 | gold_str = "negative"
260 |
261 | entry = {
262 | "gold": gold_str,
263 | "prompts": prompts_across_boost,
264 | "preds_boost": preds_across_boost,
265 | "example": sentence,
266 | "ind": i,
267 | }
268 | expt_log[i] = entry
269 | all_boost_preds.append(preds_across_boost)
270 | labels.append(gold_str)
271 |
272 | return expt_log, all_boost_preds, labels
273 |
274 |
275 | def main():
276 | args = get_args()
277 | args.num_boost = 3
278 | task_name = "sst2"
279 | data_dir = f"{DATA_DIR}/sst2/"
280 | if not Path(data_dir).exists():
281 | raise ValueError(
282 | f"Data directory {data_dir} does not exist. Download SST-2 from https://github.com/tonyzhaozh/few-shot-learning.")
283 | decomp = SST2Decomp(task_name, data_dir)
284 | decomp.run(args)
285 |
286 |
287 | if __name__ == "__main__":
288 | main()
289 |
--------------------------------------------------------------------------------
/tasks/WIC_final.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | from tqdm.auto import tqdm
4 | import pandas as pd
5 |
6 | from sklearn.metrics import classification_report
7 | from decomposition import Decomposition, get_args, DATA_DIR
8 | from utils import get_response, InputOutputPrompt
9 |
10 | synonym = InputOutputPrompt(
11 | input_formatter=lambda x: f"{x['passage']}",
12 | output_formatter=lambda x: f"- {x['answer']}",
13 | required_keys=["passage", "answer"],
14 | input_output_sep="\n",
15 | example_sep="\n\n",
16 | instruction="Give synonyms of the word in the sentence.\n\n"
17 | )
18 | synonym_examples = [
19 | pd.DataFrame([
20 | {
21 | "passage": "In \"She heard the sound of voices in the hall.\", synonyms for the word \"sound\" are:",
22 | "answer": "noise",
23 | },
24 | {
25 | "passage": "In \"Enter the secret code.\", synonyms for the word \"code\" are:",
26 | "answer": "password",
27 | },
28 | {
29 | "passage": "In \"She acted in a play on Broadway\", synonyms for the word \"play\" are:",
30 | "answer": "show",
31 | },
32 | ]),
33 | pd.DataFrame([
34 | {
35 | "passage": "In \"She rode around the park on her cycle.\", synonyms for the word \"cycle\" are:",
36 | "answer": "bicycle",
37 | },
38 | {
39 | "passage": "In \"Don't keep everything bottled up.\", synonyms for the word \"bottled\" are:",
40 | "answer": "trapped inside",
41 | },
42 | {
43 | "passage": "In \"The present is like no other time.\", synonyms for the word \"present\" are:",
44 | "answer": "current moment",
45 | },
46 | ]),
47 | pd.DataFrame([
48 | {
49 | "passage": "In \"The movie was awful.\", synonyms for the word \"awful\" are:",
50 | "answer": "bad and terrible",
51 | },
52 | {
53 | "passage": "In \"She is so beautiful.\", synonyms for the word \"beautiful\" are:",
54 | "answer": "pretty and gorgeous",
55 | },
56 | {
57 | "passage": "In \"It was quite cool out so she wore a jacket\", synonyms for the word \"cool\" are:",
58 | "answer": "cold and chilly",
59 | },
60 | ]),
61 | pd.DataFrame([
62 | {
63 | "passage": "In \"There are so many flies near the food.\", synonyms for the word \"flies\" are:",
64 | "answer": "bugs",
65 | },
66 | {
67 | "passage": "In \"Eat your noodles with a fork.\", synonyms for the word \"fork\" are:",
68 | "answer": "utensils",
69 | },
70 | {
71 | "passage": "In \"She and her husband went on a trip.\", synonyms for the word \"trip\" are:",
72 | "answer": "vacation",
73 | },
74 | ]),
75 | pd.DataFrame([
76 | {
77 | "passage": "In \"It was definitely a cry for help.\", synonyms for the word \"cry\" are:",
78 | "answer": "call",
79 | },
80 | {
81 | "passage": "In \"I watch all my students as they take their exams.\", synonyms for the word \"watch\" are:",
82 | "answer": "look at",
83 | },
84 | {
85 | "passage": "In \"The beginning of the book was fine, but the end was terrible.\", synonyms for the word \"beginning\" are:",
86 | "answer": "start",
87 | },
88 | ])
89 | ]
90 |
91 | description = InputOutputPrompt(
92 | input_formatter=lambda x: f"Choices\n{x['choices']}\nFill the [MASK] with the correct \"Choice\": {x['sentence']}",
93 | output_formatter=lambda x: f"[MASK] is \"Choice\": {x['answer']}\n",
94 | required_keys=["choices", "sentence", "answer"],
95 | input_output_sep="\n",
96 | example_sep="\n\n",
97 | instruction="Select the correct choice for the blank.\n\n"
98 | )
99 | description_examples = [
100 | pd.DataFrame([
101 | {
102 | "choices": "1. noise\n2. in good condition",
103 | "sentence": "She heard the [MASK] of voices in the hall.",
104 | "answer": "noise",
105 | },
106 | {
107 | "choices": "1. not heavy\n2. sun rays",
108 | "sentence": "The [MASK] shined through the window.",
109 | "answer": "sun rays",
110 | },
111 | {
112 | "choices": "1. widespread\n2. commander of an army",
113 | "sentence": "The book is of [MASK] interest.",
114 | "answer": "widespread",
115 | },
116 | ])
117 | ]
118 |
119 | class WICDecomp(Decomposition):
120 | def __init__(self, task_name, data_dir, val_split="validation"):
121 | super().__init__(task_name, data_dir, val_split)
122 |
123 | def get_boost_decomp_examples(self, train_data, boost_id):
124 | return [
125 | synonym_examples[boost_id],
126 | ]
127 |
128 | def zero_few_baseline(
129 | self,
130 | test_data,
131 | few_shot_df,
132 | manifest,
133 | overwrite_manifest,
134 | do_few_shot=True,
135 | ):
136 | expt_log = {}
137 | preds = []
138 | labels = []
139 |
140 | labels_names = set(test_data["targets_pretokenized"])
141 | labels_names = [l.lower().strip() for l in labels_names]
142 |
143 | for i, (ind, row) in tqdm(
144 | enumerate(test_data.iterrows()), total=len(test_data)
145 | ):
146 | if ind in expt_log:
147 | pred = expt_log[ind]["pred"]
148 | gold = expt_log[ind]["gold"]
149 | else:
150 | text = row["inputs_pretokenized"]
151 | gold = row["targets_pretokenized"]
152 |
153 | icl_str = ""
154 | if do_few_shot:
155 | for s_ind, s_row in few_shot_df.iterrows():
156 | icl_str += f"{s_row['inputs_pretokenized']} {s_row['targets_pretokenized']}\n\n"
157 |
158 | prompt = f"{icl_str}{{text:}}"
159 | pmp = prompt.format(text=text)
160 | if i == 0:
161 | print(pmp)
162 |
163 | raw_answer = get_response(
164 | pmp,
165 | manifest,
166 | overwrite=bool(overwrite_manifest),
167 | max_toks=30,
168 | )
169 | answer = raw_answer.strip().lower()
170 | answer = answer.split("\n")
171 | answer = [a for a in answer if a]
172 | answer = [
173 | a
174 | for a in answer
175 | if any(l.lower() in a.lower() for l in labels_names)
176 | ]
177 | if answer:
178 | answer = answer[0]
179 | answer = "".join(
180 | [a for a in answer if a not in [".", ",", "?", ";", ":", "'", '"']]
181 | )
182 |
183 | is_yes = "yes" in answer.split()
184 | is_no = "no" in answer.split()
185 | pred = ""
186 | if is_yes and (not is_no):
187 | pred = "yes"
188 | if is_no and (not is_yes):
189 | pred = "no"
190 |
191 | gold = gold.strip().lower()
192 | pred = pred.strip().lower()
193 | entry = {
194 | "ind": ind,
195 | "example": text,
196 | "base_prompt": pmp,
197 | "raw_answer": raw_answer,
198 | "pred": pred,
199 | "gold": gold,
200 | }
201 | expt_log[ind] = entry
202 |
203 | preds.append(pred)
204 | labels.append(gold)
205 |
206 | report = classification_report(labels, preds, output_dict=True)
207 | return expt_log, report["accuracy"]
208 |
209 | def get_parts(self, text):
210 | parts = text.split("\n")
211 | sent1 = parts[0]
212 | sent2 = parts[1]
213 | word = parts[2].split("Question: Is the word ")[-1].split()[0]
214 | return word, sent1, sent2
215 |
216 | def clean_sentence(self, sentence):
217 | sentence = sentence.replace("2. ", "")
218 | sentence = sentence.replace("3. ", "")
219 | sentence = sentence.replace("\n\n", "")
220 | sentence = sentence.replace("A:", "")
221 | sentence = sentence.strip()
222 | sentence = sentence.split(".")[0]
223 | sentence = sentence.split("\n")[0]
224 | return sentence
225 |
226 | def get_sentences(
227 | self, all_constructors, all_boost_exs, sent, word, manifest, overwrite_manifest
228 | ):
229 | synonym = all_constructors[0]
230 |
231 | all_prompts = []
232 | # synonyms
233 | prompt_suffix = synonym(all_boost_exs[0])
234 | prompt_combined = f'{prompt_suffix}\n\nIn "{{sent:}}", synonyms for the word \"{{word:}}\" are: '
235 | all_prompts.append(prompt_combined.format(sent=sent, word=word))
236 | synonyms = get_response(
237 | prompt_combined.format(sent=sent, word=word),
238 | manifest,
239 | overwrite=bool(overwrite_manifest),
240 | max_toks=10,
241 | )
242 | synonyms = synonyms.replace("- ", "").split("\n")
243 | synonyms = ", ".join([a for a in synonyms if a][0:1])
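244 | # keep only the first non-empty synonym line from the model output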
244 |
245 | # quote the target word in the sentence (kept for reference; sentence generation is skipped and an empty list is returned)
246 | quoted_sent = sent.split()
247 | sent = []
248 | for tok in quoted_sent:
249 | if tok.lower() == word.strip('"').lower():
250 | sent.append(f'"{tok}"')
251 | else:
252 | sent.append(tok)
253 | if sent:
254 | sent = " ".join(sent)
255 | else:
256 | sent = " ".join(quoted_sent)
257 |
258 | combined_definition = f"{synonyms}"
259 | sentences = []
260 | return combined_definition, sentences, all_prompts
261 |
262 | def pairwise_comparisons(
263 | self,
264 | description_constructor,
265 | boost_exs,
266 | def1,
267 | sentences_lst1,
268 | def2,
269 | sentences_lst2,
270 | word,
271 | manifest,
272 | overwrite_manifest,
273 | ):
274 | all_prompts = []
275 | # reconcile the result
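276 | # The two occurrences are judged to share a sense iff their synonym
277 | # strings match exactly; no additional model call is made at this step.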
276 | answer = ""
277 | if def1.strip() != def2.strip():
278 | answer = "No"
279 | else:
280 | answer = "Yes"
281 | return answer, all_prompts
282 |
283 | def run_decomposed_prompt(
284 | self, test_data, boost_data_train, boost_dfs, manifest, overwrite_manifest
285 | ):
286 | expt_log, all_boost_preds, labels = self._run_decomp_single_data(test_data, boost_dfs, manifest, overwrite_manifest)
287 | expt_log_train, all_boost_train_preds, train_labels = self._run_decomp_single_data(boost_data_train, boost_dfs, manifest, overwrite_manifest, run_limit=1000)
288 | # Do WS
289 | preds = self.merge_boosted_preds(all_boost_preds, all_boost_train_preds, train_labels, expt_log, expt_log_train)
290 | # Get accuracies across all boost sets
291 | individual_accuracies = []
292 | for i in range(len(all_boost_preds[0])):
293 | individual_accuracies.append(classification_report(labels, [p[i] for p in all_boost_preds], output_dict=True)["accuracy"])
294 | report = classification_report(labels, preds, output_dict=True)
295 | return expt_log, expt_log_train, report["accuracy"], individual_accuracies
296 |
297 | def _run_decomp_single_data(self, test_data, boost_dfs, manifest, overwrite_manifest, run_limit=-1):
298 | expt_log = {}
299 | all_boost_preds = []
300 | labels = []
301 |
302 | for i, (ind, row) in tqdm(
303 | enumerate(test_data.iterrows()), total=len(test_data)
304 | ):
305 | text = row["inputs_pretokenized"]
306 | gold = row["targets_pretokenized"]
307 |
308 | word, sent1, sent2 = self.get_parts(text)
309 |
310 | if i == run_limit:
311 | break
312 |
313 | prompts_across_boost = []
314 | preds_across_boost = []
315 | for boost_examples in boost_dfs:
316 | all_prompts = []
317 |
318 | def1, sentences_lst1, lst1_prompts = self.get_sentences(
319 | [synonym], [boost_examples[0]], sent1, word, manifest, overwrite_manifest
320 | )
321 | def2, sentences_lst2, lst2_prompts = self.get_sentences(
322 | [synonym], [boost_examples[0]], sent2, word, manifest, overwrite_manifest
323 | )
324 |
325 | pred, pred_prompts = self.pairwise_comparisons(
326 | description,
327 | boost_examples[-1],
328 | def1,
329 | sentences_lst1,
330 | def2,
331 | sentences_lst2,
332 | word,
333 | manifest,
334 | overwrite_manifest,
335 | )
336 | all_prompts = lst1_prompts + lst2_prompts + pred_prompts
337 | if i == 0:
338 | print("\n".join(all_prompts))
339 | prompts_across_boost.append(all_prompts)
340 | preds_across_boost.append(pred)
341 | entry = {
342 | "ind": ind,
343 | "example": text,
344 | "word": word,
345 | "prompts": prompts_across_boost,
346 | "preds_boost": preds_across_boost,
347 | "sent1": sent1,
348 | "sent2": sent2,
349 | "def1": def1,
350 | "def2": def2,
351 | "gold": gold,
352 | "sentences_lst1": sentences_lst1,
353 | "sentences_lst2": sentences_lst2,
354 | }
355 | expt_log[ind] = entry
356 | all_boost_preds.append(preds_across_boost)
357 | labels.append(gold)
358 | return expt_log, all_boost_preds, labels
359 |
360 |
361 | def main():
362 | args = get_args()
363 | args.num_boost = 5
364 | task_name = "super_glue_wic"
365 | data_dir = f"{DATA_DIR}/P3/data_feather/super_glue_wic_GPT_3_prompt/"
366 | decomp = WICDecomp(task_name, data_dir)
367 | decomp.run(args)
368 |
369 |
370 | if __name__ == "__main__":
371 | main()
372 |
--------------------------------------------------------------------------------
/tasks/decomposition.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | from pathlib import Path
4 |
5 | import argparse
6 | from collections import Counter
7 | import pandas as pd
8 | import json
9 | import numpy as np
10 | import datetime
11 | import os
12 | import random
13 |
14 | from utils import save_log, get_manifest_session
15 |
16 | DATA_DIR = os.environ.get("AMA_DATA", "/home/data")
17 |
18 | def get_args():
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument("--seed", type=int, default=3)
21 | parser.add_argument(
22 | "--num_run", type=int, default=-1, help="Number of rows of test data to run"
23 | )
24 | parser.add_argument(
25 | "--k_shot", type=int, default=3, help="Number of few shot"
26 | )
27 | parser.add_argument(
28 | "--num_boost", type=int, default=3, help="Number of few shot sets to boost over")
29 | parser.add_argument(
30 | "--boost_train_examples", type=int, default=1000, help="Number of training examples to run through for boosting"
31 | )
32 | parser.add_argument(
33 | "--output_metrics_file", type=str, default="decomposition_metrics.json", help="Output file for all metrics."
34 | )
35 | parser.add_argument(
36 | "--save_dir", type=str, default="/home/final_runs/", help="Data directory"
37 | )
38 | parser.add_argument(
39 | "--run_decomp",
40 | type=int,
41 | default=1,
42 | help="Run decomp",
43 | choices=[0, 1],
44 | )
45 | parser.add_argument(
46 | "--run_zeroshot",
47 | type=int,
48 | default=1,
49 | help="Run zeroshot",
50 | choices=[0, 1],
51 | )
52 | parser.add_argument(
53 | "--run_fewshot",
54 | type=int,
55 | default=1,
56 | help="Run fewshot",
57 | choices=[0, 1],
58 | )
59 | parser.add_argument(
60 | "--run_zeroshot_decomp",
61 | type=int,
62 | default=0,
63 | help="Run zero shot decomp",
64 | choices=[0, 1],
65 | )
66 | parser.add_argument(
67 | "--overwrite_boost_exs",
68 | type=int,
69 | default=0,
70 | help="Overwrite boost examples",
71 | choices=[0, 1],
72 | )
73 | parser.add_argument(
74 | "--overwrite_data",
75 | type=int,
76 | default=0,
77 | help="Overwrite saved data examples",
78 | choices=[0, 1],
79 | )
80 | # Manifest
81 | parser.add_argument(
82 | "--client_name",
83 | type=str,
84 | default="huggingface",
85 | help="Client name manifest",
86 | choices=["huggingface", "openai", "ai21"],
87 | )
88 | parser.add_argument(
89 | "--client_engine",
90 | type=str,
91 | default=None,
92 | help="Client engine manifest. Only used for openai/ai21",
93 | choices=["davinci"],
94 | )
95 | parser.add_argument(
96 | "--client_connection",
97 | type=str,
98 | default="http://127.0.0.1:5000",
99 | help="Client connection str",
100 | )
101 | parser.add_argument(
102 | "--cache_connection",
103 | type=str,
104 | default="/home/manifest/final_runs.sqlite",
105 | help="Cache connection str",
106 | )
107 | parser.add_argument(
108 | "--overwrite_manifest",
109 | type=int,
110 | default=0,
111 | help="Overwrite manifest",
112 | choices=[0, 1],
113 | )
114 | return parser.parse_args()
115 |
116 | class Decomposition:
117 | def __init__(self, task_name, data_dir, val_split="validation"):
118 | self.task_name = task_name
119 | self.data_dir = data_dir
120 | self.val_split = val_split
121 |
122 | def read_data(self, save_dir, overwrite_data):
123 | save_data = Path(f"{save_dir}/{self.task_name}/data.feather")
124 | if not save_data.exists() or overwrite_data:
125 | test_data = pd.read_feather(f"{self.data_dir}/{self.val_split}.feather")
126 | test_data.to_feather(f"{save_data}")
127 | else:
128 | print(f"Reading test data from {save_data}")
129 | test_data = pd.read_feather(save_data)
130 |
131 | save_data = Path(f"{save_dir}/{self.task_name}/train_data.feather")
132 | if not save_data.exists() or overwrite_data:
133 | train_data = pd.read_feather(f"{self.data_dir}/train.feather")
134 | train_data.to_feather(f"{save_data}")
134 | else:
135 | print(f"Reading train data from {save_data}")
136 | train_data = pd.read_feather(save_data)
137 | print(f"Test Data Size: {len(test_data)}")
138 | print(f"Train Data Size: {len(train_data)}")
139 | return test_data, train_data
140 |
141 | def get_few_shot_examples(self, train_data, k_shot):
142 | """Get few shot examples"""
143 | return train_data.sample(k_shot)
144 |
145 | def get_boost_decomp_examples(self, train_data, boost_i=0):
146 | """Get boost examples"""
147 | raise NotImplementedError()
148 |
149 | def zero_few_baseline(
150 | self, test_data, few_shot_df, manifest, overwrite_manifest, do_few_shot=True
151 | ):
152 | """Zero and few shot baseline"""
153 | raise NotImplementedError()
154 |
155 | def run_decomposed_prompt(
156 | self, test_data, boost_data_train, boost_dfs, manifest, overwrite_manifest
157 | ):
158 | """Decomposition run"""
159 | raise NotImplementedError()
160 |
161 | def merge_boosted_preds(self, boosted_preds, all_boost_train_preds, train_labels, exp_log, expt_log_train, indecisive_ans=None):
162 | """Merge boosted preds"""
163 | if isinstance(boosted_preds, list):
164 | boosted_preds = np.array(boosted_preds)
165 | if isinstance(all_boost_train_preds, list):
166 | all_boost_train_preds = np.array(all_boost_train_preds)
167 | if isinstance(train_labels, list):
168 | train_labels = np.array(train_labels)
169 |
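170 | # When the label space is a known binary set (yes/no, true/false,
171 | # positive/negative), map string predictions to signed integers so the
172 | # majority vote below treats abstentions and ties uniformly.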
170 | uniq = np.unique(boosted_preds)
171 | pred_map = {}
172 | if "yes" in uniq:
173 | pred_map = {"yes": 1, "no": -1, "neither": 0}
174 | elif "true" in uniq:
175 | pred_map = {"true": 1, "false": -1, "neither": 0}
176 | elif "positive" in uniq:
177 | pred_map = {"positive": 1, "negative": -1, "neutral": 0}
178 | pred_map_inv = {v:k for k,v in pred_map.items()}
179 | use_pred_map = False
180 | if all(p.lower() in pred_map for p in uniq):
181 | use_pred_map = True
182 | if use_pred_map:
183 | # Cast to integers
184 | boosted_preds = np.array([[pred_map[p.lower()] for p in preds] for preds in boosted_preds])
185 | all_boost_train_preds = np.array(
186 | [[pred_map[p.lower()] for p in preds] for preds in all_boost_train_preds]
187 | )
188 | train_labels = np.array([pred_map[p.lower()] for p in train_labels])
189 | if indecisive_ans:
190 | indecisive_ans = pred_map[indecisive_ans.lower()]
191 |
192 | # Take majority vote
193 | preds_test = []
194 | for i, voter_preds in enumerate(boosted_preds):
195 | most_common = Counter(voter_preds).most_common(1)[0]
196 | if indecisive_ans and len(voter_preds) > 1 and most_common[1] == 1:
197 | majority_vote_pred = indecisive_ans
198 | else:
199 | majority_vote_pred = most_common[0]
200 | if use_pred_map:
201 | majority_vote_pred = pred_map_inv[majority_vote_pred]
202 | preds_test.append(majority_vote_pred)
203 | exp_log[i]["pred"] = majority_vote_pred
204 |
205 | # Take majority vote
206 | preds_train = []
207 | for i, voter_preds in enumerate(all_boost_train_preds):
208 | most_common = Counter(voter_preds).most_common(1)[0]
209 | if indecisive_ans and len(voter_preds) > 1 and most_common[1] == 1:
210 | majority_vote_pred = indecisive_ans
211 | else:
212 | majority_vote_pred = most_common[0]
213 | if use_pred_map:
214 | majority_vote_pred = pred_map_inv[majority_vote_pred]
215 | preds_train.append(majority_vote_pred)
216 | expt_log_train[i]["pred"] = majority_vote_pred
217 | return preds_test
218 |
219 | def run(self, args):
220 | print(json.dumps(vars(args), indent=4))
221 |
222 | random.seed(args.seed)
223 | np.random.seed(args.seed)
224 | save_path = Path(f"{args.save_dir}/{self.task_name}")
225 | save_path.mkdir(parents=True, exist_ok=True)
226 | data_test, data_train = self.read_data(args.save_dir, bool(args.overwrite_data))
227 | # Subsample train for boost exps
228 | if args.boost_train_examples >= 0:
229 | boost_data_train = data_train.head(min(len(data_train), args.boost_train_examples))
230 | else:
231 | boost_data_train = data_train
232 | # Reset indexes for enumerations
233 | boost_data_train = boost_data_train.reset_index(drop=True)
234 | data_test = data_test.reset_index(drop=True)
235 | data_train = data_train.reset_index(drop=True)
236 | num_run = (
237 | min(args.num_run, len(data_test)) if args.num_run > 0 else len(data_test)
238 | )
239 | save_results = True
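240 | # Logs are only persisted for full runs; subsampled runs skip saving.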
240 | if num_run != len(data_test):
241 | print("Using {} rows".format(num_run))
242 | data_test = data_test.iloc[:num_run]
243 | save_results = False
244 |
245 | runner, model_name = get_manifest_session(
246 | client_name=args.client_name,
247 | client_engine=args.client_engine,
248 | client_connection=args.client_connection,
249 | cache_connection=args.cache_connection,
250 | )
251 |
252 | model_name = model_name.replace("/", "_")
253 | print("Model name:", model_name)
254 |
255 | # Read in few shot examples
256 | few_shot_path = save_path / f"{args.k_shot}_shot_examples.feather"
257 | if bool(args.overwrite_data) or not few_shot_path.exists():
258 | mini_df = self.get_few_shot_examples(data_train, args.k_shot)
259 | mini_df.reset_index().to_feather(few_shot_path)
260 | else:
261 | print(f"Reading few shot examples from {few_shot_path}")
262 | mini_df = pd.read_feather(few_shot_path)
263 |
264 | # Read in few shot decomp examples - one data frame per decomp step
265 | boost_examples = []
266 | for i in range(args.num_boost):
267 | boost_examples_per_step = []
268 | # Get all steps
269 | boost_examples_paths = list(save_path.glob(f"boost_examples_{i}_step*.feather"))
270 | if bool(args.overwrite_boost_exs) or not boost_examples_paths or not all(p.exists() for p in boost_examples_paths):
271 | boost_df_steps = self.get_boost_decomp_examples(data_train, boost_id=i)
272 | if not isinstance(boost_df_steps, list) or not isinstance(
273 | boost_df_steps[0], pd.DataFrame
274 | ):
275 | raise ValueError("Must return list of dataframes, one per step")
276 | for step, boost_df in enumerate(boost_df_steps):
277 | boost_df.reset_index().to_feather(save_path / f"boost_examples_{i}_step{step}.feather")
278 | print("Saving boost examples to", save_path / f"boost_examples_{i}_step{step}.feather")
279 | boost_examples_per_step.append(boost_df)
280 | else:
281 | for boost_examples_p in sorted(boost_examples_paths):
282 | print(f"Reading boost examples from {boost_examples_p}")
283 | boost_examples_per_step.append(pd.read_feather(boost_examples_p))
284 | boost_examples.append(boost_examples_per_step)
285 |
286 | today = datetime.datetime.today().strftime("%m%d%Y")
287 |
288 | # Default metrics
289 | metric_zero = -1.0
290 | metric_few = -1.0
291 | metric_decomposed = -1.0
292 | metric_decomposed_by_boost = []
293 | metric_zeroshot_decomposed = -1.0
294 |
295 | if bool(args.run_zeroshot):
296 | # Zero Shot
297 | run_name = f"{model_name}_0shot"
298 | exp_zero, metric_zero = self.zero_few_baseline(
299 | test_data=data_test,
300 | few_shot_df=mini_df,
301 | manifest=runner,
302 | overwrite_manifest=args.overwrite_manifest,
303 | do_few_shot=False,
304 | )
305 | if save_results:
306 | save_log(self.task_name, run_name, exp_zero, args.save_dir)
307 |
308 | if bool(args.run_fewshot):
309 | # Few Shot
310 | run_name = f"{model_name}_{args.k_shot}shot"
311 | exp_few, metric_few = self.zero_few_baseline(
312 | test_data=data_test,
313 | few_shot_df=mini_df,
314 | manifest=runner,
315 | overwrite_manifest=args.overwrite_manifest,
316 | do_few_shot=True,
317 | )
318 | if save_results:
319 | save_log(self.task_name, run_name, exp_few, args.save_dir)
320 |
321 | if bool(args.run_decomp):
322 | # Decomp
323 | run_name = f"{model_name}_decomposed_{today}"
324 | exp_decomposed, exp_decomposed_train, metric_decomposed, metric_decomposed_by_boost = self.run_decomposed_prompt(
325 | test_data=data_test, boost_data_train=boost_data_train, boost_dfs=boost_examples, manifest=runner, overwrite_manifest=args.overwrite_manifest
326 | )
327 | if save_results:
328 | save_log(
329 | self.task_name,
330 | run_name,
331 | exp_decomposed,
332 | args.save_dir
333 | )
334 | if exp_decomposed_train:
335 | save_log(
336 | self.task_name,
337 | f"{run_name}_train",
338 | exp_decomposed_train,
339 | args.save_dir
340 | )
341 |
342 | # Zero shot decomp
343 | exp_zeroshot_decomposed = []
344 | if bool(args.run_zeroshot_decomp):
345 | run_name = f"{model_name}_decomposed_0shot_{today}"
346 | (
347 | exp_zeroshot_decomposed,
348 | exp_zeroshot_decomposed_train,
349 | metric_zeroshot_decomposed,
350 | _,
351 | ) = self.run_decomposed_prompt(
352 | test_data=data_test, boost_data_train=boost_data_train, boost_dfs=[[pd.DataFrame() for _ in range(len(boost_examples[0]))]], manifest=runner, overwrite_manifest=args.overwrite_manifest,
353 | )
354 | if save_results and len(exp_zeroshot_decomposed) > 0:
355 | save_log(
356 | self.task_name,
357 | run_name,
358 | exp_zeroshot_decomposed,
359 | args.save_dir,
360 | )
361 | if exp_zeroshot_decomposed_train:
362 | save_log(
363 | self.task_name,
364 | f"{run_name}_train",
365 | exp_zeroshot_decomposed_train,
366 | args.save_dir,
367 | )
368 | print("Accuracy Zero Shot", metric_zero)
369 | print("Accuracy Few Shot", metric_few)
370 | if len(metric_decomposed_by_boost) > 0:
371 | print("Accuracy by Boost Set Decomposed", metric_decomposed_by_boost)
372 | print("Accuracy by Boost Set Decomposed Average", np.mean(metric_decomposed_by_boost))
373 | print("Accuracy Boost Decomposed", metric_decomposed)
374 | if len(exp_zeroshot_decomposed) > 0:
375 | print("Accuracy Zero Shot Decomposed", metric_zeroshot_decomposed)
376 |
377 | metrics = {
378 | "model_name": model_name,
379 | "task_name": self.task_name,
380 | "today": today,
381 | "zero_shot": metric_zero,
382 | "few_shot": metric_few,
383 | "decomposed": metric_decomposed,
384 | "decomposed_by_boost": metric_decomposed_by_boost,
385 | "decomposed_by_boost_avg": np.mean(metric_decomposed_by_boost),
386 | "zero_shot_decomposed": metric_zeroshot_decomposed,
387 | }
388 | output_metrics = Path(args.output_metrics_file)
389 | output_metrics.parent.mkdir(parents=True, exist_ok=True)
390 | with open(output_metrics, "a") as f:
391 | f.write(json.dumps(metrics) + "\n")
392 | print(f"Saved metrics to {output_metrics}")
393 | print(f"Saved final data to", Path(args.save_dir) / self.task_name)
394 |
--------------------------------------------------------------------------------
/tasks/drop_final.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | from pathlib import Path
4 | import pandas as pd
5 | import numpy as np
6 | from tqdm.auto import tqdm
7 | from datasets import load_dataset
8 |
9 | from decomposition import Decomposition, get_args
10 | from utils import get_response, text_f1, InputOutputPrompt, load_hf_data
11 |
12 | extract = InputOutputPrompt(
13 | input_formatter=lambda x: f"Context: {x['context']}\nQuestion: {x['question']}",
14 | output_formatter=lambda x: f"Answer: {x['answer']}",
15 | required_keys=["context", "question", "answer"],
16 | input_output_sep="\n",
17 | example_sep="\n\n",
18 | instruction="Answer the question. If there is no evidence in the context, return \"Unknown\".\n\n"
19 | )
20 |
21 | extract_examples = [
22 | pd.DataFrame([
23 | {
24 | "context": "According to Biraben, the plague was present somewhere in Europe in every year between 1346 and 1671",
25 | "question": "Where was the plague present?",
26 | "answer": "somewhere in Europe"
27 | },
28 | {
29 | "context": "Policies aiming at controlling unemployment and in particular at reducing its inequality-associated effects support economic growth.",
30 | "question": "What's one factor in increasing self-esteem?",
31 | "answer": "Unknown"
32 | },
33 | {
34 | "context": "The term \"matter\" is used throughout physics in a bewildering variety of contexts: for example, one refers to \"condensed matter physics\", \"elementary matter\", \"partonic\" matter, \"dark\" matter, \"anti\"-matter, \"strange\" matter, and \"nuclear\" matter.",
35 | "question": "What is another name for anti-matter?",
36 | "answer": "Unknown"
37 | }
38 | ]),
39 | pd.DataFrame([
40 | {
41 | "context": "According to Biraben, the plague was present somewhere in Europe in every year between 1346 and 1671",
42 | "question": "Where was the plague present?",
43 | "answer": "somewhere in Europe"
44 | },
45 | {
46 | "context": "The term \"matter\" is used throughout physics in a bewildering variety of contexts: for example, one refers to \"condensed matter physics\", \"elementary matter\", \"partonic\" matter, \"dark\" matter, \"anti\"-matter, \"strange\" matter, and \"nuclear\" matter.",
47 | "question": "What is another name for anti-matter?",
48 | "answer": "Unknown"
49 | },
50 | {
51 | "context": "Policies aiming at controlling unemployment and in particular at reducing its inequality-associated effects support economic growth.",
52 | "question": "What's one factor in increasing self-esteem?",
53 | "answer": "Unknown"
54 | },
55 | ]),
56 | pd.DataFrame([
57 | {
58 | "context": "Policies aiming at controlling unemployment and in particular at reducing its inequality-associated effects support economic growth.",
59 | "question": "What's one factor in increasing self-esteem?",
60 | "answer": "Unknown"
61 | },
62 | {
63 | "context": "According to Biraben, the plague was present somewhere in Europe in every year between 1346 and 1671",
64 | "question": "Where was the plague present?",
65 | "answer": "somewhere in Europe"
66 | },
67 | {
68 | "context": "The term \"matter\" is used throughout physics in a bewildering variety of contexts: for example, one refers to \"condensed matter physics\", \"elementary matter\", \"partonic\" matter, \"dark\" matter, \"anti\"-matter, \"strange\" matter, and \"nuclear\" matter.",
69 | "question": "What is another name for anti-matter?",
70 | "answer": "Unknown"
71 | }
72 | ]),
73 |
74 | ]
75 |
76 | prefix_select_zeroshot = """Answer the question. If there is no evidence in the context, return "Unknown".\n\n"""
77 |
78 |
79 | class DropDecomp(Decomposition):
80 | def __init__(self, task_name, data_dir, val_split="validation"):
81 | super().__init__(task_name, data_dir, val_split)
82 |
83 | def read_data(self, save_dir, overwrite_data):
84 | return load_hf_data(save_dir, self.task_name, self.val_split, "drop", overwrite_data)
85 |
86 | def get_boost_decomp_examples(self, train_data, boost_id):
87 | return [
88 | extract_examples[boost_id],
89 | ]
90 |
91 | def get_few_shot_examples(self, train_data, k_shot):
92 | """Get few shot examples"""
93 | labels = []
94 | for x in train_data.answers_spans:
95 | if len(x['spans']) == 0:
96 | labels.append("unknown")
97 | else:
98 | labels.append(x['spans'][0])
99 | train_data['expanded_labels'] = labels
100 | labels = ["unknown"] + list(sorted(set(labels)))
101 |
102 | num_per_class = int(np.ceil(k_shot / len(labels)))
103 | print(f"Selecting {num_per_class} examples per class.")
104 |
105 | dfs = []
106 | total_in_context = 0
107 | for label in labels:
108 | while num_per_class + total_in_context > k_shot:
109 | num_per_class -= 1
110 | sub_df = train_data[train_data["expanded_labels"] == label].sample(num_per_class)
111 | dfs.append(sub_df)
112 | total_in_context += num_per_class
113 | if total_in_context == k_shot:
114 | break
115 | mini_df = pd.concat(dfs)
116 | return mini_df
117 |
118 | def zero_few_baseline(
119 | self,
120 | test_data,
121 | few_shot_df,
122 | manifest,
123 | overwrite_manifest,
124 | do_few_shot=True,
125 | ):
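    | """Run the zero-/few-shot QA baseline over test_data and return (experiment log, text F1)."""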
126 | expt_log = {}
127 | preds = []
128 | labels = []
129 |
130 | for i, (ind, row) in tqdm(
131 | enumerate(test_data.iterrows()), total=len(test_data)
132 | ):
133 | if ind in expt_log:
134 | entry = expt_log[ind]
135 | pred = entry["pred"]
136 | gold = entry["gold"]
137 |
138 | else:
139 | text = row.passage
140 | question = row.question
141 | if len(row.answers_spans["spans"]) == 0:
142 | label = "unknown"
143 | else:
144 | label = row.answers_spans["spans"][0]
145 | gold = label.lower()
146 |
147 | icl_str = ""
148 | if do_few_shot:
149 | for s_ind, s_row in few_shot_df.iterrows():
150 | s_text = s_row.passage
151 | s_question = s_row.question
152 | if len(s_row.answers_spans["spans"]) == 0:
153 | label = "unknown"
154 | else:
155 | label = s_row.answers_spans["spans"][0]
156 | icl_str += f"Passage: {s_text}\nQuestion: {s_question}\nAnswer: {label}\n\n"
157 |
158 | prompt = (
159 | icl_str
160 | + "Passage: {text:}\nQuestion: {question:}"
161 | + "\nAnswer:"
162 | )
163 | pmp = prompt.format(text=text, question=question)
164 | if i == 0:
165 | print(pmp)
166 |
167 | raw_answer = get_response(
168 | pmp,
169 | manifest,
170 | overwrite=bool(overwrite_manifest),
171 | max_toks=10,
172 | stop_token="\n\n",
173 | )
174 | pred = raw_answer.strip("\n").strip().lower()
175 | entry = {
176 | "ind": ind,
177 | "example": text,
178 | "base_prompt": pmp,
179 | "raw_answer": raw_answer,
180 | "pred": pred,
181 | "gold": gold,
182 | }
183 | expt_log[ind] = entry
184 |
185 | preds.append(pred)
186 | labels.append(gold)
187 | metric = text_f1(preds=preds, golds=labels)
188 | return expt_log, metric
189 |
190 | def run_decomposed_prompt(
191 | self, test_data, boost_data_train, boost_dfs, manifest, overwrite_manifest
192 | ):
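    | """Run the one-step decomposed prompt on test and train data, merge the boosted predictions with weak supervision, and return logs plus overall and per-boost-set text F1."""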
193 | expt_log, all_boost_preds, labels = self._run_decomp_single_data(test_data, boost_dfs, manifest, overwrite_manifest)
194 | expt_log_train, all_boost_train_preds, train_labels = self._run_decomp_single_data(boost_data_train, boost_dfs, manifest, overwrite_manifest, run_limit=1)
195 | # Combine the boosted predictions with weak supervision (WS)
196 | preds = self.merge_boosted_preds(all_boost_preds, all_boost_train_preds, train_labels, expt_log, expt_log_train)
197 | # Get accuracies across all boost sets
198 | individual_accuracies = []
199 | for i in range(len(all_boost_preds[0])):
200 | individual_accuracies.append(text_f1(preds=[p[i] for p in all_boost_preds], golds=labels))
201 | metric = text_f1(preds=preds, golds=labels)
202 | return expt_log, expt_log_train, metric, individual_accuracies
203 |
204 | def _run_decomp_single_data(self, test_data, boost_dfs, manifest, overwrite_manifest, run_limit=-1):
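    | """Prompt once per boost set for each example; run_limit caps the number of examples processed (-1 means no cap)."""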
205 | expt_log = {}
206 | all_boost_preds = []
207 | labels = []
208 |
209 | for i, (ind, row) in tqdm(
210 | enumerate(test_data.iterrows()), total=len(test_data)
211 | ):
212 | if i == run_limit:
213 | break
214 |
215 | text = row.passage
216 | question = row.question
217 | if len(row.answers_spans["spans"]) == 0:
218 | label = "unknown"
219 | else:
220 | label = row.answers_spans["spans"][0]
221 | gold = label.lower()
222 | prompts_across_boost = []
223 | preds_across_boost = []
224 | for boost_examples in boost_dfs:
225 | prompt_suffix = extract(boost_examples[0])
226 | prompt = (
227 | prompt_suffix + "\n\nContext: {text:}\nQuestion: {question:}\nAnswer:"
228 | )
229 | pmp = prompt.format(text=text, question=question)
230 | if i == 0:
231 | print(pmp)
232 |
233 | raw_answer = get_response(
234 | pmp,
235 | manifest,
236 | overwrite=bool(overwrite_manifest),
237 | max_toks=50,
238 | stop_token="\n\n",
239 | )
240 | pred = raw_answer.split("\n")[0].replace('"', "").strip().lower()
241 | # Single-element prompt list, since this is a one-step decomposition
242 | prompts_across_boost.append([pmp])
243 | preds_across_boost.append(pred)
244 |
245 | entry = {
246 | "ind": ind,
247 | "example": text,
248 | "prompts": prompts_across_boost,
249 | "preds_boost": preds_across_boost,
250 | "gold": gold,
251 | }
252 | expt_log[ind] = entry
253 | all_boost_preds.append(preds_across_boost)
254 | labels.append(gold)
255 | return expt_log, all_boost_preds, labels
256 |
257 |
258 | def main():
259 | args = get_args()
260 | task_name = "drop"
261 | data_dir = "drop"
262 | drop = DropDecomp(task_name, data_dir)
263 | drop.run(args)
264 |
265 |
266 | if __name__ == "__main__":
267 | main()
268 |
--------------------------------------------------------------------------------
/tasks/utils.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from collections import Counter
3 | import json
4 | from datasets import load_dataset
5 | import re
6 | import pandas as pd
7 | from typing import Callable, List
8 | from manifest import Manifest
9 |
10 |
11 | class InputOutputPrompt:
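    | """Render a DataFrame of in-context examples into a single prompt string:
    | the instruction, then each row formatted as input/output text, with
    | examples joined by example_sep. A minimal usage sketch (hypothetical
    | data, not from this repo):
    |
    |     qa = InputOutputPrompt(
    |         input_formatter=lambda x: f"Q: {x['q']}",
    |         output_formatter=lambda x: f"A: {x['a']}",
    |         required_keys=["q", "a"],
    |     )
    |     qa(pd.DataFrame([{"q": "1+1?", "a": "2"}]))
    |     # -> "Q: 1+1?" and "A: 2" joined by input_output_sep
    | """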
12 | def __init__(self,
13 | input_formatter: Callable,
14 | output_formatter: Callable,
15 | required_keys: List,
16 | input_output_sep: str = "\n",
17 | example_sep: str = "\n\n",
18 | instruction: str = ""
19 | ):
20 | self.input_formatter = input_formatter
21 | self.output_formatter = output_formatter
22 | self.required_keys = required_keys
23 | self.input_output_sep = input_output_sep
24 | self.example_sep = example_sep
25 | self.instruction = instruction
26 |
27 | def __call__(self, input_output_pairs: pd.DataFrame):
28 | examples = []
29 | for _, example in input_output_pairs.iterrows():
30 | examples.append(f"{self.input_formatter(example)}{self.input_output_sep}{self.output_formatter(example)}")
31 | if examples:
32 | input_str = self.example_sep.join(examples)
33 | res = f"{self.instruction}{input_str}"
34 | else:
35 | res = f"{self.instruction}".rstrip()
36 | return res
37 |
38 | def __repr__(self):
39 | dummy_ex = pd.DataFrame([{k: f"<{k.upper()}>" for k in self.required_keys}])
40 | st = self(dummy_ex)
41 | return st
42 |
43 |
44 | def prefix_formatter(ex_keys: List[str], prefix: str, error_on_empty: bool = True) -> str:
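    | """Build a formatter that renders the first key of ex_keys present in an example as "<prefix> <value>"."""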
45 | def full_prefix_formatter(ex: pd.Series):
46 | for k in ex_keys:
47 | if k in ex:
48 | return f"{prefix} {getattr(ex, k)}"
49 | if error_on_empty:
50 | raise ValueError(f"Example {ex} has no value for any of the keys {ex_keys}")
51 | else:
52 | return f"{prefix}"
53 | return full_prefix_formatter
54 |
55 |
56 | def get_manifest_session(
57 | client_name="huggingface",
58 | client_engine=None,
59 | client_connection="http://127.0.0.1:5000",
60 | cache_connection=None,
61 | temperature=0,
62 | top_p=1.0,
63 | ):
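    | """Create a Manifest session for the given client and return (manifest, model_name)."""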
64 | if client_name == "huggingface":
65 | params = {
66 | "temperature": temperature if temperature > 0 else 0.001,  # HF client needs a strictly positive temperature
67 | "do_sample": temperature > 0,  # greedy decoding when temperature == 0
68 | "top_p": top_p,
69 | }
70 | elif client_name in {"openai", "ai21"}:
71 | params = {
72 | "temperature": temperature,
73 | "top_p": top_p,
74 | "engine": client_engine,
75 | }
76 | else:
77 | raise ValueError(f"{client_name} is not a valid client name")
78 | manifest = Manifest(
79 | client_name=client_name,
80 | client_connection=client_connection,
81 | cache_name="sqlite",
82 | cache_connection=cache_connection,
83 | session_id=None,
84 | **params,
85 | )
86 | params = manifest.client.get_model_params()
87 | model_name = params["model_name"]
88 | if "engine" in params:
89 | model_name += f"_{params['engine']}"
90 | return manifest, model_name
91 |
92 |
93 | def get_response(
94 | prompt,
95 | manifest,
96 | overwrite=False,
97 | max_toks=10,
98 | stop_token=None,
99 | gold_choices=[],
100 | verbose=False,
101 | ):
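    | """Query the model through Manifest. With gold_choices, return the selected choice and its log-probability; otherwise return the generated text."""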
102 | prompt = prompt.strip()
103 | if gold_choices:
104 | gold_choices = [" " + g.strip() for g in gold_choices]
105 | response_obj = manifest.run(
106 | prompt, gold_choices=gold_choices, overwrite_cache=overwrite, return_response=True
107 | )
108 | response_obj = response_obj.get_json_response()["choices"][0]
109 | log_prob = response_obj["text_logprob"]
110 | response = response_obj["text"]
111 | else:
112 | response = manifest.run(
113 | prompt,
114 | max_tokens=max_toks,
115 | stop_token=stop_token,
116 | overwrite_cache=overwrite,
117 | )
118 | log_prob = None
119 | if verbose:
120 | print("\n***Prompt***\n", prompt)
121 | print("\n***Response***\n", response)
122 | if log_prob is not None:
123 | return response, log_prob
124 | return response
125 |
126 | def load_hf_data(save_dir, task_name, val_split, hf_name, overwrite_data):
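    | """Load the val_split and train splits of a Hugging Face dataset, caching each as a feather file under save_dir/task_name."""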
127 | save_data = Path(f"{save_dir}/{task_name}/data.feather")
    | save_data.parent.mkdir(parents=True, exist_ok=True)
128 | if not save_data.exists() or overwrite_data:
129 | dataset = load_dataset(hf_name)
130 | test_data = dataset[val_split].to_pandas()
131 | test_data.to_feather(save_data)
132 | else:
133 | print(f"Reading test data from {save_data}")
134 | test_data = pd.read_feather(save_data)
135 |
136 | save_data_train = Path(f"{save_dir}/{task_name}/train_data.feather")
137 | if not save_data_train.exists() or overwrite_data:
138 | dataset = load_dataset(hf_name)
139 | train_data = dataset["train"].to_pandas()
140 | train_data.to_feather(save_data_train)
141 | else:
142 | print(f"Reading train data from {save_data_train}")
143 | train_data = pd.read_feather(save_data_train)
144 |
145 | print(f"Test Data Size: {len(test_data)}")
146 | print(f"Train Data Size: {len(train_data)}")
147 | return test_data, train_data
148 |
149 | def save_log(task_name, expt_name, log, final_run_dir):
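    | """Write an experiment log to <final_run_dir>/<task_name>/<expt_name>.json, checking that entries carry the required fields."""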
150 | final_run_dir = Path(final_run_dir)
151 | output_fpath = final_run_dir / task_name
152 | output_fpath.mkdir(parents=True, exist_ok=True)
153 |
154 | print("Saving to", output_fpath / f"{expt_name}.json")
155 | assert all(a in list(log.values())[0].keys() for a in ["ind","example","pred","gold"])
156 | with open(output_fpath / f"{expt_name}.json", "w") as f:
157 | json.dump(log, f)
158 |
159 | def text_f1(preds, golds):
160 | """Compute average F1 of text spans.
161 | Taken from Squad without prob threshold for no answer.
162 | """
163 | total_f1 = 0
164 | for pred, gold in zip(preds, golds):
165 | pred_toks = pred.split()
166 | gold_toks = gold.split()
167 | common = Counter(pred_toks) & Counter(gold_toks)
168 | num_same = sum(common.values())
169 | if len(gold_toks) == 0 or len(pred_toks) == 0:
170 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
171 | total_f1 += int(gold_toks == pred_toks)
172 | elif num_same == 0:
173 | total_f1 += 0
174 | else:
175 | precision = 1.0 * num_same / len(pred_toks)
176 | recall = 1.0 * num_same / len(gold_toks)
177 | f1 = (2 * precision * recall) / (precision + recall)
178 | total_f1 += f1
179 | f1_avg = total_f1 / len(golds)
180 | return f1_avg
181 |
182 | def accuracy_span_overlap(preds, golds):
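    | """Fraction of examples where some predicted span overlaps a gold span (the shorter string is contained in the longer, case-insensitively)."""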
183 | correct = 0
184 | for pred, gold in zip(preds, golds):
185 | found = False
186 | for p in pred:
187 | for g in gold:
188 | if len(p) < len(g):
189 | if p.lower() in g.lower():
190 | found = True
191 | break
192 | else:
193 | if g.lower() in p.lower():
194 | found = True
195 | break
196 | if found: correct += 1
197 | return correct / len(preds)
198 |
199 |
200 |
--------------------------------------------------------------------------------
/tasks/webq_final.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | import pandas as pd
4 | import numpy as np
5 | import ast
6 | from tqdm.auto import tqdm
7 | from decomposition import Decomposition, get_args
8 | from utils import get_response, InputOutputPrompt, accuracy_span_overlap, load_hf_data
9 |
10 | extract = InputOutputPrompt(
11 | input_formatter=lambda x: f"Question: {x['question']}",
12 | output_formatter=lambda x: f"Answer: {x['answer']}",
13 | required_keys=["question", "answer"],
14 | input_output_sep="\n",
15 | example_sep="\n\n",
16 | instruction="Produce distinct questions.\n\n"
17 | )
18 |
19 | more_info_examples = [
20 | pd.DataFrame([
21 | {
22 | "question": "who plays Carrie Bradshaw in sex and the city?",
23 | "answer": "Caroline \"Carrie\" Bradshaw is a fictional character from the HBO franchise Sex and the City, portrayed by Sarah Jessica Parker."
24 | },
25 | {
26 | "question": "what are the elements in air?",
27 | "answer": "By mole fraction (i.e., by number of molecules), dry air contains 78.08% nitrogen, 20.95% oxygen, 0.93% argon, 0.04% carbon dioxide, and small amounts of other gases"
28 | },
29 | {
30 | "question": "what is HP company?",
31 | "answer": "HP Inc. is an American multinational information technology company headquartered in Palo Alto, California, that develops personal computers (PCs)"
32 | },
33 | {
34 | "question": "when was the last season of FRIENDS released?",
35 | "answer": "The series finale aired on May 6, 2004, and was watched by around 52.5 million American viewers, making it the fifth-most-watched series finale in television history"
36 | }
37 | ]),
38 |
39 | ]
40 |
41 | answer = InputOutputPrompt(
42 | input_formatter=lambda x: f"Context: {x['context']}\nQuestion: {x['question']}",
43 | output_formatter=lambda x: f"Answer: {x['answer']}",
44 | required_keys=["context", "question", "answer"],
45 | input_output_sep="\n",
46 | example_sep="\n\n",
47 | instruction="Answer the question.\n\n"
48 | )
49 |
50 | answer_question = [
51 | pd.DataFrame([
52 | {
53 | 'context': 'The nearest airport to Palm Springs is Indio/Palm Springs (PSP) Airport which is 2.1 miles away. ',
54 | 'question': 'what airport is closest to palm springs?',
55 | 'answer': 'Palm Springs International Airport'
56 | },
57 | {
58 | 'context': 'Martin Luther King earned his Bachelor of Divinity degree from Crozer Theological Seminary, followed by a doctorate in Systematic Theology from Boston University.',
59 | 'question': 'what degree did martin luther king get?',
60 | 'answer': 'Bachelor of Divinity'
61 | },
62 | {
63 | 'context': 'The Niger river runs in a crescent through Libya, Mali, Niger, on the border with Benin and then through Nigeria.',
64 | 'question': 'what countries does the niger river flow through?',
65 | 'answer': 'Libya'
66 | },
67 | {
68 | 'context': 'Puerto Rico is a territory of the United States and uses the U.S. dollar. ',
69 | 'question': 'what type of currency is used in puerto rico?',
70 | 'answer': 'United States dollar'
71 | },
72 | {
73 | 'context': 'kitt was voice most often by William daniels.',
74 | 'question': 'who played kitt in knight rider?',
75 | 'answer': 'William Daniels'
76 | }
77 | ]),
78 | pd.DataFrame([
79 | {
80 | 'context': 'leonardo da vinci invented the parachute, the helicopter, double hull, an armored fighting vehicle,',
81 | 'question': 'what inventions did leonardo da vinci made?',
82 | 'answer': 'Double hull'
83 | },
84 | {
85 | 'context': "The French franc (F) was the national currency of France prior to France's adoption of the euro (EUR) in January 2002.",
86 | 'question': 'what currency is used in france before euro?',
87 | 'answer': 'French franc'
88 | },
89 | {
90 | 'context': 'The Isthmus of Panama, contains the country of Panama and the panama canal.',
91 | 'question': 'where is isthmus of panama located?',
92 | 'answer': 'Costa Rica'
93 | },
94 | {
95 | 'context': 'Hurricane Irene was a large and destructive tropical cyclone which affected much of the Caribbean and East Coast',
96 | 'question': 'where did hurricane irene?',
97 | 'answer': 'Eastern United States'
98 | },
99 | {
100 | 'context': 'Rihanna acted in This is the End and Battleship.',
101 | 'question': 'what movie did rihanna play in?',
102 | 'answer': 'This Is the End'
103 | }
104 | ]),
105 | pd.DataFrame([
106 | {
107 | 'context': 'st vincent de paul is buried in the 6th arrondisment of Paris.',
108 | 'question': 'where is st vincent de paul buried?',
109 | 'answer': 'Paris'
110 | },
111 | {
112 | 'context': 'Thomas Luther "Luke" Bryan (born July 17, 1976) is an American country singer and songwriter from Leesburg.',
113 | 'question': 'where is luke bryan from?',
114 | 'answer': 'Leesburg'
115 | },
116 | {
117 | 'context': "Klum and Seal got married on 10 May 2005 on a beach in Mexico near Seal's home on Costa Careyes. ",
118 | 'question': 'where did heidi klum and seal get married?',
119 | 'answer': 'Mexico'},
120 | {
121 | 'context': 'Tarantino starred in pulp fiction, grindhouse and others.',
122 | 'question': 'what movies did quentin tarantino star in?',
123 | 'answer': 'Grindhouse'
124 | },
125 | {
126 | 'context': 'Countries that are sometimes considered to be entirely or partially part of the Balkans are Croatia, Serbia, Lake Prespa.',
127 | 'question': 'what country is located in the balkan peninsula?',
128 | 'answer': 'Lake Prespa'
129 | }
130 | ])
131 |
132 | ]
133 |
134 | prefix_select_zeroshot = """Answer the question.\n\n"""
135 |
136 |
137 | class WebQDecomp(Decomposition):
138 | def __init__(self, task_name, data_dir, val_split="validation"):
139 | super().__init__(task_name, data_dir, val_split)
140 |
141 | def read_data(self, save_dir, overwrite_data):
142 | return load_hf_data(save_dir, self.task_name, self.val_split, "web_questions", overwrite_data)
143 |
144 | def get_boost_decomp_examples(self, train_data, boost_id):
145 | return [
146 | more_info_examples[0],
147 | answer_question[boost_id],
148 | ]
149 |
150 | def zero_few_baseline(
151 | self,
152 | test_data,
153 | few_shot_df,
154 | manifest,
155 | overwrite_manifest,
156 | prompt_suffix="",
157 | do_few_shot=True,
158 | ):
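    | """Run the zero-/few-shot open-domain QA baseline and return (experiment log, span-overlap accuracy)."""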
159 | expt_log = {}
160 | preds = []
161 | labels = []
162 |
163 | for i, (ind, row) in tqdm(
164 | enumerate(test_data.iterrows()), total=len(test_data)
165 | ):
166 | if ind in expt_log:
167 | entry = expt_log[ind]
168 | pred = entry["pred"]
169 | gold = entry["gold"]
170 |
171 | else:
172 | question = row.question
173 | if isinstance(row.answers, str):
174 | label = ast.literal_eval(row.answers)
175 | else:
176 | label = row.answers.tolist()
177 | gold = label
178 |
179 | icl_str = ""
180 | if do_few_shot:
181 | for s_ind, s_row in few_shot_df.iterrows():
182 | s_question = s_row.question
183 | if isinstance(s_row.answers, str):
184 | label = ast.literal_eval(s_row.answers)
185 | else:
186 | label = s_row.answers.tolist()
187 | icl_str += f"Question: {s_question}\nAnswer: {label[0]}\n\n"
188 |
189 | prompt = (
190 | icl_str
191 | + "Question: {question:}"
192 | + prompt_suffix
193 | + "\nAnswer:"
194 | )
195 |
196 | prompt = prompt.format(question=question)
197 | if i == 0:
198 | print(prompt)
199 | raw_answer = get_response(
200 | prompt,
201 | manifest,
202 | overwrite=bool(overwrite_manifest),
203 | max_toks=20,
204 | stop_token="\n\n",
205 | )
206 | pred = raw_answer.strip("\n").strip().lower()
207 | entry = {
208 | "ind": ind,
209 | "example": question,
210 | "base_prompt": prompt,
211 | "raw_answer": raw_answer,
212 | "pred": pred,
213 | "gold": gold,
214 | }
215 | expt_log[ind] = entry
216 |
217 | preds.append([pred])
218 | labels.append(gold)
219 | metric = accuracy_span_overlap(preds=preds, golds=labels)
220 | return expt_log, metric
221 |
222 | def run_decomposed_prompt(
223 | self, test_data, boost_data_train, boost_dfs, manifest, overwrite_manifest
224 | ):
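    | """Run the two-step decomposition on test and train data, merge the boosted answers with weak supervision, and return logs plus overall and per-boost-set span-overlap accuracy."""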
225 | expt_log, all_boost_preds, labels = self._run_decomp_single_data(test_data, boost_dfs, manifest, overwrite_manifest)
226 | expt_log_train, all_boost_train_preds, train_labels = self._run_decomp_single_data(boost_data_train, boost_dfs, manifest, overwrite_manifest, run_limit=1000)
227 |
228 | # Combine the boosted answers (second-step outputs) with weak supervision (WS)
229 | boost_test, boost_train = [], []
230 | for p in all_boost_preds:
231 | samples = [lf[1] for lf in p]
232 | boost_test.append(samples)
233 | for p in all_boost_train_preds:
234 | samples = [lf[1] for lf in p]
235 | boost_train.append(samples)
236 |
237 | preds = self.merge_boosted_preds(boost_test, boost_train, train_labels, expt_log, expt_log_train)
238 | preds = [(x, y) for x, y in zip([p[0][0] for p in all_boost_preds], preds)]
239 |
240 | # Get accuracies across all boost sets
241 | individual_accuracies = []
242 | for i in range(len(all_boost_preds[0])):
243 | individual_accuracies.append(accuracy_span_overlap(preds=[p[i] for p in all_boost_preds], golds=labels))
244 |
245 | metric = accuracy_span_overlap(preds=preds, golds=labels)
246 | return expt_log, expt_log_train, metric, individual_accuracies
247 |
248 | def _run_decomp_single_data(self, test_data, boost_dfs, manifest, overwrite_manifest, run_limit=-1):
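    | """For each question, first generate supporting context, then answer against that context once per boost set; run_limit caps the number of examples (-1 means no cap)."""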
249 | expt_log = {}
250 | all_boost_preds = []
251 | labels = []
252 |
253 | for i, (ind, row) in tqdm(
254 | enumerate(test_data.iterrows()), total=len(test_data)
255 | ):
256 | if i == run_limit:
257 | break
258 |
259 | question = row.question
260 | if isinstance(row.answers, str):
261 | label = ast.literal_eval(row.answers)
262 | else:
263 | label = row.answers.tolist()
264 |
265 | gold = label
266 | prompts_across_boost = []
267 | preds_across_boost = []
268 |
269 | # Step 1: generate supporting context for the question
270 | prompt_suffix = extract(boost_dfs[0][0])
271 | prompt = (
272 | prompt_suffix + "\n\nQuestion: {question:}\nAnswer:"
273 | )
274 | more_info_answer = get_response(
275 | prompt.format(question=question),
276 | manifest,
277 | overwrite=bool(overwrite_manifest),
278 | max_toks=20,
279 | stop_token="\n\n",
280 | )
281 |
282 |
283 | for boost_examples in boost_dfs:
284 | all_prompts = []
285 | prompt_suffix = answer(boost_examples[1])
286 | prompt = (
287 | prompt_suffix + "\n\nContext: {text:}\nQuestion: {question:}\nAnswer:"
288 | )
289 | if i == 0:
290 | print(prompt.format(text=more_info_answer, question=question))
291 |
292 | all_prompts.append(prompt.format(text=more_info_answer, question=question))
293 | raw_answer = get_response(
294 | prompt.format(text=more_info_answer, question=question),
295 | manifest,
296 | overwrite=bool(overwrite_manifest),
297 | max_toks=20,
298 | stop_token="\n\n",
299 | )
300 | pred = raw_answer.split("\n")[0].strip().lower()
301 | prompts_across_boost.append(all_prompts)
302 | preds_across_boost.append((more_info_answer, pred))
303 |
304 | entry = {
305 | "ind": ind,
306 | "example": question,
307 | "prompts": prompts_across_boost,
308 | "preds_boost": preds_across_boost,
309 | "gold": gold,
310 | }
311 | expt_log[ind] = entry
312 | all_boost_preds.append(preds_across_boost)
313 | labels.append(gold)
314 | return expt_log, all_boost_preds, labels
315 |
316 |
317 | def main():
318 | args = get_args()
319 | task_name = "webq"
320 | data_dir = "webq"
321 | webq = WebQDecomp(task_name, data_dir, val_split="test")
322 | webq.run(args)
323 |
324 |
325 | if __name__ == "__main__":
326 | main()
327 |
--------------------------------------------------------------------------------