├── .gitmodules ├── LICENSE ├── README.md ├── ablations ├── README.md └── T0_variants │ ├── CB_variants.py │ ├── RTE_variants.py │ ├── WIC_variants.py │ ├── WSC_variants.py │ ├── decomposition.py │ ├── run.sh │ └── utils.py ├── boosting ├── binary_deps.py ├── make_pgm.py ├── methods.py ├── pgm.py ├── run_ws.py └── utils.py ├── diagnostics ├── README.md ├── data.zip └── run_diagnostics.py ├── download_p3.py ├── imgs ├── crfm.png ├── decomp.png ├── numbers_station.png └── snorkel.png ├── requirements.txt └── tasks ├── AGNews_final.py ├── ANLIR1_final.py ├── ANLIR2_final.py ├── ANLIR3_final.py ├── Amazon_final.py ├── BoolQ_final.py ├── CB_final.py ├── COPA_final.py ├── DBPedia_final.py ├── MultiRC_final.py ├── NQ_final.py ├── RTE_final.py ├── ReCoRD_final.py ├── RealtimeQA_final.py ├── SST2_final.py ├── StoryCloze_final.py ├── WIC_final.py ├── WSC_final.py ├── decomposition.py ├── drop_final.py ├── utils.py └── webq_final.py /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "metal-ama"] 2 | path = metal-ama 3 | url = https://github.com/mayeechen/metal-ama.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Ask Me Anything: A simple strategy for prompting language models 2 | 3 | ![GitHub](https://img.shields.io/github/license/HazyResearch/ama_prompting) 4 | [![Together AI](https://together.xyz/assets/images/ai_platform.svg)](https://together.xyz/) 5 | 6 | This repository contains code for the Ask Me Anything (AMA) prompt-aggregation strategy. 
The end-to-end AMA approach includes (1) recursively using the language model to transform the task format and prompt, and (2) aggregating the predictions of multiple prompts using weak supervision. We include code for both components and pointers to the publicly downloadable datasets. See our [paper](https://arxiv.org/abs/2210.02441) for more details. 7 | 8 |
<div align="center"><img src="imgs/decomp.png" height="350"></div>
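At a glance, each AMA prompt-chain uses the LM itself to reformat the input (e.g., statement → question) and then answer the reformatted question; the predictions of several such chains are aggregated into a final vote. Below is a minimal conceptual sketch of this flow, not the repository's actual code: `lm` is a hypothetical text-completion function standing in for a Manifest-backed model, and the prompt strings are illustrative only.

```python
# Minimal conceptual sketch of AMA; `lm` is a hypothetical completion function.
def prompt_chain(lm, context, statement):
    # (1) task reformatting: turn the statement into a question
    question = lm(f"Rewrite the statement as a question.\nStatement: {statement}\nQuestion:")
    # (2) answer the generated question against the context
    answer = lm(f"Context: {context}\nQuestion: {question}\nAnswer:")
    return "True" if answer.strip().lower().startswith("yes") else "False"

def ama_predict(lm, context, statement, chains):
    # each chain varies its prompts and in-context demonstrations
    votes = [chain(lm, context, statement) for chain in chains]
    return max(set(votes), key=votes.count)  # simple majority vote; the paper uses weak supervision
```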
9 | 10 | 11 | 12 | ## Table of Contents 13 | - [Setup](#setup) 14 | - [Data](#getting-the-data) 15 | - [Running models](#models) 16 | - [Running experiments](#experiments) 17 | - [Repository Structure](#overall-repository-structure) 18 | - [Citation](#citation) 19 | 20 | 21 | ## Setup 22 | 23 | ### Installation 24 | Here we will setup the AMA code (prompting models for tasks), weak supervision code (aggregating predictions), and [Manifest](https://github.com/HazyResearch/manifest/) code (tooling for easily loading and running the models). 25 | 26 | We encourage the use of conda environments: 27 | ``` 28 | conda create --name ama python=3.8 29 | conda activate ama 30 | ``` 31 | 32 | Clone as follows: 33 | ```bash 34 | # Ask Me Anything code 35 | git clone git@github.com:HazyResearch/ama_prompting.git 36 | cd ama_prompting 37 | pip install -r requirements.txt 38 | 39 | # Weak supervision code 40 | cd metal-ama 41 | git submodule init 42 | git submodule update 43 | pip install -e . 44 | 45 | # Manifest 46 | git clone git@github.com:HazyResearch/manifest.git 47 | cd manifest 48 | pip install -e . 49 | ``` 50 | 51 | 52 | ### Getting the data 53 | We assume all data lives in the ```AMA_DATA``` environment variable. By default, this is set to ```/home/data```. To change this, run 54 | ```bash 55 | export AMA_DATA= 56 | ``` 57 | 58 | Please follow the instructions below to download all necessary data for experiments. 59 | 60 | 1. Download the PromptSource (P3) dataset from Hugging Face at https://huggingface.co/datasets/bigscience/P3. 61 | ```bash 62 | cd $AMA_DATA 63 | git lfs install 64 | git clone https://huggingface.co/datasets/bigscience/P3 65 | ``` 66 | Then run [ama_prompting/download_p3.py](./download_p3.py). We use the GPT3-Style prompts in the few-shot baseline for each benchmark. 67 | 68 | 2. We downloaded the remaining tasks from the following sources: 69 | * [AGNews, DBPedia, and SST2](https://github.com/tonyzhaozh/few-shot-learning) 70 | * [Amazon Products](https://github.com/allenai/flex/blob/75d6d1cea66df2c8a7e3d429c6af5008ccf1544b/fewshot/hf_datasets_scripts/amazon/amazon.py) 71 | * [Natural Questions and WebQs](https://github.com/facebookresearch/FiD) 72 | * [RealTimeQA](https://github.com/realtimeqa/realtimeqa_public/tree/main/past/2022) (GCS files from June 17th - July 22, 2022) 73 | * [ReCoRD](https://sheng-z.github.io/ReCoRD-explorer/) 74 | * [StoryCloze](http://goo.gl/forms/aQz39sdDrO) 75 | 76 | ### Running models 77 | We run inference on models using a tool called [Manifest](https://github.com/HazyResearch/manifest). This tool is useful because it caches your inference results and does not require reloading the model for each new run you launch. To load the EleutherAI GPT-j-6B model, in a Tmux session, run: 78 | ```bash 79 | python3 manifest/manifest/api/app.py \ 80 | --model_type huggingface \ 81 | --model_name_or_path EleutherAI/gpt-j-6B \ 82 | --device 0 83 | ``` 84 | It will take a few minutes for large models to load! To use a different model, replace ```EleutherAI/gpt-j-6B``` with the model name. See the Manifest repo for more information on loading other models. 85 | 86 | 87 | ## Experiments 88 | 89 | ### Collecting the prompting predictions 90 | 91 | To run a single task such as the Recognizing Textual Entailment (RTE) SuperGLUE benchmark, you can use the following steps. 92 | 93 | 1. Load a Manifest model using the above command 94 | 95 | 2. Run the following command. 
This will run the zero-shot baseline (```run_zeroshot = 1```), few-shot baseline (```run_fewshot = 1```) with $k$ in-context demonstrations (```k_shot = 3```), and the AMA baseline (```run_decomp = 1```). In AMA, we aggregate the predictions of multiple prompts-per-input. The number of prompts over which to aggregate is specified by ```num_boost```. 96 | 97 | ```bash 98 | python3 tasks/RTE_final.py \ 99 | --run_zeroshot 1 \ 100 | --run_fewshot 1 \ 101 | --run_decomp 1 \ 102 | --num_boost 5 \ 103 | --k_shot 3 \ 104 | --output_metrics_file ../ama_logs/metrics.json \ 105 | --cache_connection ../ama_logs/manifest_cache.sqlite \ 106 | --save_dir ../ama_logs/ama_final_runs 107 | ``` 108 | 109 | Please see the argparse in ```tasks/decomposition.py``` for other run options; for instance, to control Manifest's caching behavior. 110 | 111 | 3. The results of all baselines will be saved in `ama_final_runs/` (e.g., `` is `super_glue_rte` as seen in the `RTE_final.py` main function) and output all performance metrics to `metrics.json`. The output appears as follows: 112 | 113 | ``` 114 | Saving to ../ama_logs/ama_final_runs/super_glue_rte/EleutherAI_gpt-j-6B_decomposed_10052022.json 115 | Saving to ../ama_logs/ama_final_runs/super_glue_rte/EleutherAI_gpt-j-6B_decomposed_10052022_train.json 116 | Accuracy Few Shot 0.5884476534296029 117 | Accuracy by Boost Set Decomposed [0.592057761732852, 0.6209386281588448, 0.5848375451263538, 0.6678700361010831, 0.6173285198555957] 118 | Accuracy by Boost Set Decomposed Average 0.6166064981949458 119 | Accuracy Boost Decomposed 0.6642599277978339 120 | Saved metrics to ../ama_logs/metrics.json 121 | Saved final data to ../ama_logs/ama_final_runs/super_glue_rte 122 | ``` 123 | 124 | For the AMA baseline, which consists of ```num_boost``` prompt-chains, the metrics include the individual prompt-chain accuracies over the dataset ("Accuracy by Boost Set Decomposed"), average score ("Accuracy by Boost Set Decomposed Average"), and majority vote result ("Accuracy Boost Decomposed"). 125 | 126 | ### Running weak supervision 127 | 128 | 4. Next we aggregate over the predictions with weak supervision (WS). In order to run the WS algorithm on the predictions which were saved down in `ama_final_runs/super_glue_rte`, use the following command. By default, we assume the date of the log file is today. You can change it with the `--override_date` command. 129 | 130 | ```bash 131 | python3 boosting/run_ws.py \ 132 | --task_name super_glue_rte \ 133 | --data_dir ../ama_logs/ama_final_runs \ 134 | --model_prefix EleutherAI_gpt-j-6B \ 135 | --override_date 10052022 136 | ``` 137 | 138 | The output will include the following results: 139 | 140 | ``` 141 | # The code will first output results without modelling dependencies. 142 | 143 | Trained Label Model Metrics (No deps): 144 | Accuracy: 0.650 145 | Precision: 0.724 146 | Recall: 0.420 147 | F1: 0.531 148 | 149 | # For this task, the WS algorithm identifies a dependency between prompts 0 and 2. Next the code outputs results after modelling dependencies, if dependencies are recovered above. 150 | 151 | Trained Label Model Metrics (with deps): 152 | Accuracy: 0.751 153 | Precision: 0.758 154 | Recall: 0.695 155 | F1: 0.725 156 | 157 | 158 | # Conditional entropy metric discussed in the paper 159 | 160 | H(Y | WS output): 0.5602824867598865 161 | ``` 162 | 163 | For this task, [Brown et al., 2020](https://arxiv.org/pdf/2005.14165.pdf) reports accuracy metrics. 
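For intuition, the "Accuracy Boost Decomposed" metric above is a simple majority vote over the ```num_boost``` per-chain predictions for each example. A minimal sketch with made-up predictions (not the repository's exact code):

```python
from collections import Counter

def majority_vote(chain_preds):
    """chain_preds: the num_boost predictions for one example."""
    return Counter(chain_preds).most_common(1)[0][0]

# toy example: three prompt-chains, two RTE examples
all_boost_preds = [["True", "True", "False"], ["False", "False", "True"]]
labels = ["True", "False"]
final_preds = [majority_vote(p) for p in all_boost_preds]
accuracy = sum(p == l for p, l in zip(final_preds, labels)) / len(labels)
print(final_preds, accuracy)  # ['True', 'False'] 1.0
```

The weak-supervision step above replaces this unweighted vote with a label model that learns per-prompt accuracies and, when recovered, dependencies between prompts.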
164 | 165 | 166 | ## Overall repository structure 167 | ``` 168 | tasks/ code for running inference on tasks 169 | diagnostics/ contains the diagnostic tasks 170 | boosting/ code for running weak supervision 171 | metal-ama/ weak supervision algorithm 172 | manifest/ code for loading and using models 173 | /home/data/ default location for benchmarks 174 | ``` 175 | 176 | 177 | ## Citation 178 | If you use this codebase, or otherwise found our work valuable, please cite: 179 | ``` 180 | @article{arora2022ama, 181 | title={Ask Me Anything: A simple strategy for prompting language models}, 182 | author={Arora, Simran and Narayan, Avanika and Chen, Mayee F. and Orr, Laurel and Guha, Neel and Bhatia, Kush and Chami, Ines and Sala, Frederic and R\'e, Christopher}, 183 | journal={arXiv:2210.02441}, 184 | year={2022} 185 | } 186 | ``` 187 | 188 | As well as [Snorkel MeTaL](https://github.com/HazyResearch/metal), [bigscience P3](https://huggingface.co/datasets/bigscience/P3), and the benchmark authors. 189 | 190 | ## Acknowledgements 191 | 192 | We are very grateful to the following organizations for the resources that made this work possible: [Together Computer](https://together.xyz/), [Numbers Station](https://numbersstation.ai/), [Snorkel](https://snorkel.ai/), [Stanford Center for Research on Foundation Models](https://crfm.stanford.edu/) and [Stanford HAI](https://hai.stanford.edu/). 193 | 194 |
<p align="center"><img src="imgs/numbers_station.png" height="80"> <img src="imgs/snorkel.png" height="80"> <img src="imgs/crfm.png" height="80"></p>
195 | 196 | 197 | 198 |

199 | 200 | -------------------------------------------------------------------------------- /ablations/README.md: -------------------------------------------------------------------------------- 1 | 2 | This directory includes code for additional paper experiments beyond the main benchmark results. 3 | 4 | In ```T0_variants```, we include code for running aggregation over PromptSource P3 prompts. To run this, load the T0 model using manifest and you can launch via the script: 5 | 6 | ```bash 7 | cd T0_variants 8 | bash run.sh 9 | ``` -------------------------------------------------------------------------------- /ablations/T0_variants/CB_variants.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | import os 4 | from tqdm.auto import tqdm 5 | import pandas as pd 6 | 7 | from sklearn.metrics import classification_report 8 | from decomposition import Decomposition, get_args, DATA_DIR 9 | from utils import get_response, InputOutputPrompt 10 | from collections import defaultdict, Counter 11 | 12 | class CBDecomp(Decomposition): 13 | def __init__(self, task_name, data_dir, val_split="validation"): 14 | super().__init__(task_name, data_dir, val_split) 15 | 16 | def get_boost_decomp_examples(self, data_train, boost_id): 17 | lst = [ 18 | 'super_glue_cb_claim_true_false_inconclusive', 19 | 'super_glue_cb_does_this_imply', 20 | 'super_glue_cb_always_sometimes_never', 21 | 'super_glue_cb_does_it_follow_that', 22 | 'super_glue_cb_guaranteed_true', 23 | 'super_glue_cb_take_the_following_as_truth', 24 | 'super_glue_cb_justified_in_saying', 25 | 'super_glue_cb_should_assume', 26 | 'super_glue_cb_GPT_3_style', 27 | 'super_glue_cb_can_we_infer', 28 | 'super_glue_cb_consider_always_sometimes_never', 29 | 'super_glue_cb_guaranteed_possible_impossible', 30 | 'super_glue_cb_MNLI_crowdsource', 31 | 'super_glue_cb_based_on_the_previous_passage', 32 | 'super_glue_cb_must_be_true' 33 | ] 34 | file_path = lst[boost_id] 35 | print(f"FILE PATH: {file_path}") 36 | train_data = pd.read_feather(f"{DATA_DIR}/P3/data_feather/{file_path}/train.feather") 37 | val_data = pd.read_feather(f"{DATA_DIR}/P3/data_feather/{file_path}/validation.feather") 38 | return [ 39 | train_data, 40 | val_data 41 | ] 42 | 43 | def zero_few_baseline( 44 | self, 45 | test_data, 46 | few_shot_df, 47 | manifest, 48 | overwrite_manifest, 49 | do_few_shot=True, 50 | ): 51 | expt_log = {} 52 | preds = [] 53 | labels = [] 54 | 55 | labels_names = set(test_data["targets_pretokenized"]) 56 | labels_names = [l.lower().strip() for l in labels_names] 57 | 58 | for i, (ind, row) in tqdm( 59 | enumerate(test_data.iterrows()), total=len(test_data) 60 | ): 61 | if ind in expt_log: 62 | pred = entry["pred"] 63 | gold = entry["gold"] 64 | else: 65 | text = row["inputs_pretokenized"] 66 | gold = row["targets_pretokenized"] 67 | 68 | icl_str = "" 69 | if do_few_shot: 70 | for s_ind, s_row in few_shot_df.iterrows(): 71 | icl_str += f"{s_row['inputs_pretokenized']} {s_row['targets_pretokenized']}\n\n\n" 72 | 73 | text = row["inputs_pretokenized"] 74 | text = text.replace("True, False, or Neither?", "") 75 | text = text + ". True, False, or Neither?" 
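                # `text` now carries a single canonical "True, False, or Neither?" suffix
                # (any copy embedded in the raw P3 input was stripped above)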
76 | gold = row["targets_pretokenized"] 77 | prompt = f"{icl_str}{{text:}}" 78 | pmp = prompt.format(text=text) 79 | if i == 0: 80 | print(pmp) 81 | 82 | raw_answer = get_response( 83 | pmp, 84 | manifest, 85 | overwrite=bool(overwrite_manifest), 86 | max_toks=30, 87 | ) 88 | answer = raw_answer.strip().lower() 89 | answer = answer.split("\n") 90 | answer = [a for a in answer if a] 91 | answer = [ 92 | a 93 | for a in answer 94 | if any(l.lower() in a.lower() for l in labels_names) 95 | ] 96 | if answer: 97 | answer = answer[0] 98 | else: 99 | answer = "" 100 | answer = "".join( 101 | [a for a in answer if a not in [".", ",", "?", ";", ":", "'", '"']] 102 | ) 103 | is_yes = "true" in answer.split() 104 | is_no = "false" in answer.split() 105 | is_maybe = "neither" in answer.split() or "unknown" in answer.split() or "maybe" in answer.split() 106 | pred = "Neither" 107 | if is_yes and (not is_maybe and not is_no): 108 | pred = "True" 109 | if is_no and (not is_maybe and not is_yes): 110 | pred = "False" 111 | if is_maybe and (not is_no and not is_yes): 112 | pred = "Neither" 113 | entry = { 114 | "ind": ind, 115 | "example": text, 116 | "base_prompt": pmp, 117 | "raw_answer": raw_answer, 118 | "pred": pred, 119 | "gold": gold, 120 | } 121 | expt_log[ind] = entry 122 | 123 | preds.append(pred) 124 | labels.append(gold) 125 | 126 | report = classification_report(labels, preds, output_dict=True) 127 | return expt_log, report["accuracy"] 128 | 129 | def run_decomposed_prompt( 130 | self, test_data, boost_data_train, boost_dfs, manifest, overwrite_manifest 131 | ): 132 | expt_log, all_boost_preds, labels = self._run_decomp_single_data(test_data, boost_dfs, manifest, overwrite_manifest, do_train=0) 133 | expt_log_train, all_boost_train_preds, train_labels = self._run_decomp_single_data(boost_data_train, boost_dfs, manifest, overwrite_manifest, do_train=1) 134 | # Do WS 135 | preds = self.merge_boosted_preds(all_boost_preds, all_boost_train_preds, train_labels, expt_log, expt_log_train) 136 | # Get accuracies across all boost sets 137 | individual_accuracies = [] 138 | for i in range(len(all_boost_preds[0])): 139 | individual_accuracies.append(classification_report(labels, [p[i] for p in all_boost_preds], output_dict=True)["accuracy"]) 140 | report = classification_report(labels, preds, output_dict=True) 141 | return expt_log, expt_log_train, report["accuracy"], individual_accuracies 142 | 143 | def _run_decomp_single_data(self, test_data, boost_dfs, manifest, overwrite_manifest, do_train=-1): 144 | expt_log = {} 145 | all_boost_preds = [] 146 | labels = [] 147 | 148 | label_name_to_maping = { 149 | "always": "true", 150 | "never": "false", 151 | "sometimes": 'neither', 152 | "true": "true", 153 | "false": "false", 154 | "neither": 'neither', 155 | "no": "false", 156 | "yes": "true", 157 | "maybe": "neither", 158 | "unknown": "neither", 159 | "inconclusive": "neither", 160 | "impossible": "false", 161 | "possible": "neither", 162 | "guaranteed": "true", 163 | } 164 | 165 | prompts_across_boost = defaultdict(list) 166 | preds_across_boost = defaultdict(list) 167 | for boost_num, boost_examples in enumerate(boost_dfs): 168 | if do_train: 169 | data = boost_examples[0].iloc[:1] 170 | elif not do_train: 171 | data = boost_examples[1] 172 | else: 173 | raise ValueError("Unsupported value for do train.") 174 | 175 | for i, (ind, row) in tqdm(enumerate(data.iterrows()), total=len(data)): 176 | input = row["inputs_pretokenized"] 177 | gold = row["targets_pretokenized"] 178 | all_prompts = [] 179 | 
raw_answer = get_response( 180 | input, 181 | manifest, 182 | overwrite=bool(overwrite_manifest), 183 | max_toks=30, 184 | ) 185 | all_prompts.append(input) 186 | answer = raw_answer.lower() 187 | if answer not in label_name_to_maping: 188 | pred = 'neither' 189 | print("BAD ANSWER", answer) 190 | else: 191 | pred = label_name_to_maping[answer] 192 | prompts_across_boost[i].append(all_prompts) 193 | preds_across_boost[i].append(pred) 194 | if gold.lower() not in label_name_to_maping: 195 | import pdb; 196 | pdb.set_trace() 197 | 198 | for i, (ind, row) in enumerate(data.iterrows()): 199 | label = row["targets_pretokenized"].lower() 200 | entry = { 201 | "ind": ind, 202 | "prompts": prompts_across_boost[i], 203 | "preds_boost": preds_across_boost[i], 204 | "example": row['inputs_pretokenized'], 205 | "gold": label, 206 | } 207 | expt_log[ind] = entry 208 | all_boost_preds.append(preds_across_boost[ind]) 209 | labels.append(label_name_to_maping[label]) 210 | return expt_log, all_boost_preds, labels 211 | 212 | 213 | def main(): 214 | args = get_args() 215 | args.num_boost = 10 216 | task_name = "cb_t0_variants" 217 | data_dir = f"{DATA_DIR}/P3/data_feather/super_glue_cb_GPT_3_style/" 218 | decomp = CBDecomp(task_name, data_dir) 219 | decomp.run(args) 220 | 221 | 222 | if __name__ == "__main__": 223 | main() 224 | -------------------------------------------------------------------------------- /ablations/T0_variants/RTE_variants.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | import os 4 | from tqdm.auto import tqdm 5 | import pandas as pd 6 | 7 | from sklearn.metrics import classification_report 8 | from decomposition import Decomposition, get_args, DATA_DIR 9 | from utils import get_response, InputOutputPrompt 10 | from collections import defaultdict, Counter 11 | 12 | class RTEDecomp(Decomposition): 13 | def __init__(self, task_name, data_dir, val_split="validation"): 14 | super().__init__(task_name, data_dir, val_split) 15 | 16 | def get_boost_decomp_examples(self, data_train, boost_id): 17 | lst = [ 18 | 'super_glue_rte_GPT_3_style', 19 | 'super_glue_rte_does_it_follow_that', 20 | 'super_glue_rte_MNLI_crowdsource', 21 | 'super_glue_rte_can_we_infer', 22 | 'super_glue_rte_justified_in_saying', 23 | 'super_glue_rte_does_this_imply', 24 | 'super_glue_rte_guaranteed_true', 25 | 'super_glue_rte_based_on_the_previous_passage', 26 | 'super_glue_rte_should_assume', 27 | 'super_glue_rte_must_be_true' 28 | ] 29 | file_path = lst[boost_id] 30 | print(f"FILE PATH: {file_path}") 31 | train_data = pd.read_feather(f"{DATA_DIR}/P3/data_feather/{file_path}/train.feather") 32 | val_data = pd.read_feather(f"{DATA_DIR}/P3/data_feather/{file_path}/validation.feather") 33 | return [ 34 | train_data, 35 | val_data 36 | ] 37 | 38 | def zero_few_baseline( 39 | self, 40 | test_data, 41 | few_shot_df, 42 | manifest, 43 | overwrite_manifest, 44 | do_few_shot=True, 45 | ): 46 | expt_log = {} 47 | preds = [] 48 | labels = [] 49 | 50 | labels_names = set(test_data["targets_pretokenized"]) 51 | labels_names = [l.lower().strip() for l in labels_names] 52 | 53 | for i, (ind, row) in tqdm( 54 | enumerate(test_data.iterrows()), total=len(test_data) 55 | ): 56 | if ind in expt_log: 57 | pred = entry["pred"] 58 | gold = entry["gold"] 59 | else: 60 | text = row["inputs_pretokenized"] 61 | gold = row["targets_pretokenized"] 62 | 63 | icl_str = "" 64 | if do_few_shot: 65 | for s_ind, s_row in few_shot_df.iterrows(): 66 | icl_str += 
f"{s_row['inputs_pretokenized']} {s_row['targets_pretokenized']}\n\n\n" 67 | 68 | text = row["inputs_pretokenized"] 69 | text = text.replace("True, False, or Neither?", "") 70 | text = text + ". True, False, or Neither?" 71 | gold = row["targets_pretokenized"] 72 | prompt = f"{icl_str}{{text:}}" 73 | pmp = prompt.format(text=text) 74 | if i == 0: 75 | print(pmp) 76 | 77 | raw_answer = get_response( 78 | pmp, 79 | manifest, 80 | overwrite=bool(overwrite_manifest), 81 | max_toks=30, 82 | ) 83 | answer = raw_answer.strip().lower() 84 | answer = answer.split("\n") 85 | answer = [a for a in answer if a] 86 | answer = [ 87 | a 88 | for a in answer 89 | if any(l.lower() in a.lower() for l in labels_names) 90 | ] 91 | if answer: 92 | answer = answer[0] 93 | else: 94 | answer = "" 95 | answer = "".join( 96 | [a for a in answer if a not in [".", ",", "?", ";", ":", "'", '"']] 97 | ) 98 | is_yes = "true" in answer.split() 99 | is_no = "false" in answer.split() 100 | is_maybe = "neither" in answer.split() or "unknown" in answer.split() or "maybe" in answer.split() 101 | pred = "Neither" 102 | if is_yes and (not is_maybe and not is_no): 103 | pred = "True" 104 | if is_no and (not is_maybe and not is_yes): 105 | pred = "False" 106 | if is_maybe and (not is_no and not is_yes): 107 | pred = "Neither" 108 | entry = { 109 | "ind": ind, 110 | "example": text, 111 | "base_prompt": pmp, 112 | "raw_answer": raw_answer, 113 | "pred": pred, 114 | "gold": gold, 115 | } 116 | expt_log[ind] = entry 117 | 118 | preds.append(pred) 119 | labels.append(gold) 120 | 121 | report = classification_report(labels, preds, output_dict=True) 122 | return expt_log, report["accuracy"] 123 | 124 | def run_decomposed_prompt( 125 | self, test_data, boost_data_train, boost_dfs, manifest, overwrite_manifest 126 | ): 127 | expt_log, all_boost_preds, labels = self._run_decomp_single_data(test_data, boost_dfs, manifest, overwrite_manifest, do_train=0) 128 | expt_log_train, all_boost_train_preds, train_labels = self._run_decomp_single_data(boost_data_train, boost_dfs, manifest, overwrite_manifest, do_train=1) 129 | # Do WS 130 | preds = self.merge_boosted_preds(all_boost_preds, all_boost_train_preds, train_labels, expt_log, expt_log_train) 131 | # Get accuracies across all boost sets 132 | individual_accuracies = [] 133 | for i in range(len(all_boost_preds[0])): 134 | individual_accuracies.append(classification_report(labels, [p[i] for p in all_boost_preds], output_dict=True)["accuracy"]) 135 | report = classification_report(labels, preds, output_dict=True) 136 | return expt_log, expt_log_train, report["accuracy"], individual_accuracies 137 | 138 | def _run_decomp_single_data(self, test_data, boost_dfs, manifest, overwrite_manifest, do_train=-1): 139 | expt_log = {} 140 | all_boost_preds = [] 141 | labels = [] 142 | 143 | label_name_to_maping = { 144 | "always": "true", 145 | "never": "false", 146 | "true": "true", 147 | "false": "false", 148 | "no": "false", 149 | "yes": "true", 150 | "impossible": "false", 151 | "guaranteed": "true", 152 | } 153 | 154 | prompts_across_boost = defaultdict(list) 155 | preds_across_boost = defaultdict(list) 156 | for boost_num, boost_examples in enumerate(boost_dfs): 157 | if do_train: 158 | data = boost_examples[0].iloc[:1000] 159 | elif not do_train: 160 | data = boost_examples[1] 161 | else: 162 | raise ValueError("Unsupported value for do train.") 163 | for i, (ind, row) in tqdm(enumerate(data.iterrows()), total=len(data)): 164 | input = row["inputs_pretokenized"] 165 | gold = 
row["targets_pretokenized"] 166 | all_prompts = [] 167 | raw_answer = get_response( 168 | input, 169 | manifest, 170 | overwrite=bool(overwrite_manifest), 171 | max_toks=30, 172 | ) 173 | all_prompts.append(input) 174 | answer = raw_answer.lower() 175 | if answer not in label_name_to_maping: 176 | pred = 'neither' 177 | print("BAD ANSWER", answer) 178 | else: 179 | pred = label_name_to_maping[answer] 180 | prompts_across_boost[i].append(all_prompts) 181 | preds_across_boost[i].append(pred) 182 | if gold.lower() not in label_name_to_maping: 183 | import pdb; 184 | pdb.set_trace() 185 | 186 | for i, (ind, row) in enumerate(data.iterrows()): 187 | label = row["targets_pretokenized"].lower() 188 | entry = { 189 | "ind": ind, 190 | "prompts": prompts_across_boost[i], 191 | "preds_boost": preds_across_boost[i], 192 | "example": row['inputs_pretokenized'], 193 | "gold": label, 194 | } 195 | expt_log[ind] = entry 196 | all_boost_preds.append(preds_across_boost[ind]) 197 | labels.append(label_name_to_maping[label]) 198 | return expt_log, all_boost_preds, labels 199 | 200 | 201 | def main(): 202 | args = get_args() 203 | args.num_boost = 10 204 | task_name = "rte_t0_variants" 205 | data_dir = f"{DATA_DIR}/P3/data_feather/super_glue_rte_GPT_3_style/" 206 | decomp = RTEDecomp(task_name, data_dir) 207 | decomp.run(args) 208 | 209 | 210 | if __name__ == "__main__": 211 | main() 212 | -------------------------------------------------------------------------------- /ablations/T0_variants/WIC_variants.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | import os 4 | from tqdm.auto import tqdm 5 | import pandas as pd 6 | 7 | from sklearn.metrics import classification_report 8 | from decomposition import Decomposition, get_args, DATA_DIR 9 | from utils import get_response, InputOutputPrompt 10 | from collections import defaultdict, Counter 11 | 12 | class WICDecomp(Decomposition): 13 | def __init__(self, task_name, data_dir, val_split="validation"): 14 | super().__init__(task_name, data_dir, val_split) 15 | 16 | def get_boost_decomp_examples(self, data_train, boost_id): 17 | lst = [ 18 | "super_glue_wic_affirmation_true_or_false", 19 | "super_glue_wic_GPT_3_prompt", 20 | "super_glue_wic_GPT_3_prompt_with_label", 21 | "super_glue_wic_grammar_homework", 22 | "super_glue_wic_polysemous", 23 | "super_glue_wic_question_context", 24 | "super_glue_wic_question_context_meaning", 25 | "super_glue_wic_question_context_meaning_with_label", 26 | "super_glue_wic_same_sense", 27 | "super_glue_wic_similar_sense", 28 | ] 29 | file_path = lst[boost_id] 30 | print(f"FILE PATH: {file_path}") 31 | train_data = pd.read_feather(f"{DATA_DIR}/P3/data_feather/{file_path}/train.feather") 32 | val_data = pd.read_feather(f"{DATA_DIR}/P3/data_feather/{file_path}/validation.feather") 33 | return [ 34 | train_data, 35 | val_data 36 | ] 37 | 38 | def zero_few_baseline( 39 | self, 40 | test_data, 41 | few_shot_df, 42 | manifest, 43 | overwrite_manifest, 44 | do_few_shot=True, 45 | ): 46 | expt_log = {} 47 | preds = [] 48 | labels = [] 49 | 50 | labels_names = set(test_data["targets_pretokenized"]) 51 | labels_names = [l.lower().strip() for l in labels_names] 52 | 53 | for i, (ind, row) in tqdm( 54 | enumerate(test_data.iterrows()), total=len(test_data) 55 | ): 56 | if ind in expt_log: 57 | pred = entry["pred"] 58 | gold = entry["gold"] 59 | else: 60 | text = row["inputs_pretokenized"] 61 | gold = row["targets_pretokenized"] 62 | 63 | icl_str = "" 64 | if 
do_few_shot: 65 | for s_ind, s_row in few_shot_df.iterrows(): 66 | icl_str += f"{s_row['inputs_pretokenized']} {s_row['targets_pretokenized']}\n\n\n" 67 | 68 | text = row["inputs_pretokenized"] 69 | gold = row["targets_pretokenized"] 70 | prompt = f"{icl_str}{{text:}}" 71 | pmp = prompt.format(text=text) 72 | if i == 0: 73 | print(pmp) 74 | 75 | raw_answer = get_response( 76 | pmp, 77 | manifest, 78 | overwrite=bool(overwrite_manifest), 79 | max_toks=30, 80 | ) 81 | answer = raw_answer.strip().lower() 82 | answer = answer.split("\n") 83 | answer = [a for a in answer if a] 84 | answer = [ 85 | a 86 | for a in answer 87 | if any(l.lower() in a.lower() for l in labels_names) 88 | ] 89 | if answer: 90 | answer = answer[0] 91 | else: 92 | answer = "" 93 | answer = "".join( 94 | [a for a in answer if a not in [".", ",", "?", ";", ":", "'", '"']] 95 | ) 96 | is_yes = "yes" in answer.split() 97 | is_no = "no" in answer.split() 98 | pred = "" 99 | if is_yes and (not is_no): 100 | pred = "yes" 101 | if is_no and (not is_yes): 102 | pred = "no" 103 | entry = { 104 | "ind": ind, 105 | "example": text, 106 | "base_prompt": pmp, 107 | "raw_answer": raw_answer, 108 | "pred": pred, 109 | "gold": gold.strip().lower(), 110 | } 111 | expt_log[ind] = entry 112 | 113 | preds.append(pred) 114 | labels.append(gold) 115 | 116 | report = classification_report(labels, preds, output_dict=True) 117 | return expt_log, report["accuracy"] 118 | 119 | def run_decomposed_prompt( 120 | self, test_data, boost_data_train, boost_dfs, manifest, overwrite_manifest 121 | ): 122 | expt_log, all_boost_preds, labels = self._run_decomp_single_data(test_data, boost_dfs, manifest, overwrite_manifest, do_train=0) 123 | expt_log_train, all_boost_train_preds, train_labels = self._run_decomp_single_data(boost_data_train, boost_dfs, manifest, overwrite_manifest, do_train=1) 124 | # Do WS 125 | preds = self.merge_boosted_preds(all_boost_preds, all_boost_train_preds, train_labels, expt_log, expt_log_train) 126 | # Get accuracies across all boost sets 127 | individual_accuracies = [] 128 | for i in range(len(all_boost_preds[0])): 129 | individual_accuracies.append(classification_report(labels, [p[i] for p in all_boost_preds], output_dict=True)["accuracy"]) 130 | report = classification_report(labels, preds, output_dict=True) 131 | return expt_log, expt_log_train, report["accuracy"], individual_accuracies 132 | 133 | def _run_decomp_single_data(self, test_data, boost_dfs, manifest, overwrite_manifest, do_train=-1): 134 | expt_log = {} 135 | all_boost_preds = [] 136 | labels = [] 137 | 138 | label_name_to_maping = { 139 | "always": "true", 140 | "never": "false", 141 | "true": "true", 142 | "false": "false", 143 | "no": "false", 144 | "yes": "true", 145 | "impossible": "false", 146 | "guaranteed": "true", 147 | } 148 | 149 | prompts_across_boost = defaultdict(list) 150 | preds_across_boost = defaultdict(list) 151 | for boost_num, boost_examples in enumerate(boost_dfs): 152 | if do_train: 153 | data = boost_examples[0].iloc[:1] 154 | elif not do_train: 155 | data = boost_examples[1] 156 | else: 157 | raise ValueError("Unsupported value for do train.") 158 | 159 | for i, (ind, row) in tqdm(enumerate(data.iterrows()), total=len(data)): 160 | input = row["inputs_pretokenized"] 161 | gold = row["targets_pretokenized"] 162 | all_prompts = [] 163 | raw_answer = get_response( 164 | input, 165 | manifest, 166 | overwrite=bool(overwrite_manifest), 167 | max_toks=30, 168 | ) 169 | all_prompts.append(input) 170 | answer = raw_answer.lower() 171 | if answer 
not in label_name_to_maping: 172 | pred = '' 173 | print("BAD ANSWER", answer) 174 | else: 175 | pred = label_name_to_maping[answer] 176 | prompts_across_boost[i].append(all_prompts) 177 | preds_across_boost[i].append(pred) 178 | if gold.lower() not in label_name_to_maping: 179 | import pdb; 180 | pdb.set_trace() 181 | 182 | for i, (ind, row) in enumerate(data.iterrows()): 183 | label = row["targets_pretokenized"].lower() 184 | entry = { 185 | "ind": ind, 186 | "prompts": prompts_across_boost[i], 187 | "preds_boost": preds_across_boost[i], 188 | "example": row['inputs_pretokenized'], 189 | "gold": label, 190 | } 191 | expt_log[ind] = entry 192 | all_boost_preds.append(preds_across_boost[ind]) 193 | labels.append(label_name_to_maping[label]) 194 | return expt_log, all_boost_preds, labels 195 | 196 | 197 | def main(): 198 | args = get_args() 199 | args.num_boost = 9 200 | task_name = "wic_t0_variants" 201 | data_dir = f"{DATA_DIR}/P3/data_feather/super_glue_wic_GPT_3_prompt/" 202 | decomp = WICDecomp(task_name, data_dir) 203 | decomp.run(args) 204 | 205 | 206 | if __name__ == "__main__": 207 | main() 208 | -------------------------------------------------------------------------------- /ablations/T0_variants/WSC_variants.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | import os 4 | from tqdm.auto import tqdm 5 | import pandas as pd 6 | 7 | from sklearn.metrics import classification_report 8 | from decomposition import Decomposition, get_args, DATA_DIR 9 | from utils import get_response, InputOutputPrompt 10 | from collections import defaultdict, Counter 11 | 12 | class RTEDecomp(Decomposition): 13 | def __init__(self, task_name, data_dir, val_split="validation"): 14 | super().__init__(task_name, data_dir, val_split) 15 | 16 | def get_boost_decomp_examples(self, data_train, boost_id): 17 | lst = [ 18 | 'super_glue_wsc.fixed_GPT_3_Style', 19 | 'super_glue_wsc.fixed_Who_or_what_is_are', 20 | 'super_glue_wsc.fixed_p_is_are_r', 21 | 'super_glue_wsc.fixed_by_p_they_mean', 22 | 'super_glue_wsc.fixed_replaced_with', 23 | 'super_glue_wsc.fixed_in_other_words', 24 | 'super_glue_wsc.fixed_I_think_they_mean', 25 | 'super_glue_wsc.fixed_does_p_stand_for', 26 | 'super_glue_wsc.fixed_does_the_pronoun_refer_to', 27 | 'super_glue_wsc.fixed_the_pronoun_refers_to' 28 | ] 29 | file_path = lst[boost_id] 30 | print(f"FILE PATH: {file_path}") 31 | train_data = pd.read_feather(f"{DATA_DIR}/P3/data_feather/{file_path}/train.feather") 32 | val_data = pd.read_feather(f"{DATA_DIR}/P3/data_feather/{file_path}/validation.feather") 33 | return [ 34 | train_data, 35 | val_data 36 | ] 37 | 38 | def zero_few_baseline( 39 | self, 40 | test_data, 41 | few_shot_df, 42 | manifest, 43 | overwrite_manifest, 44 | do_few_shot=True, 45 | ): 46 | expt_log = {} 47 | preds = [] 48 | labels = [] 49 | 50 | labels_names = set(test_data["targets_pretokenized"]) 51 | labels_names = [l.lower().strip() for l in labels_names] 52 | 53 | for i, (ind, row) in tqdm( 54 | enumerate(test_data.iterrows()), total=len(test_data) 55 | ): 56 | if ind in expt_log: 57 | pred = entry["pred"] 58 | gold = entry["gold"] 59 | else: 60 | text = row["inputs_pretokenized"] 61 | gold = row["targets_pretokenized"] 62 | 63 | icl_str = "" 64 | if do_few_shot: 65 | for s_ind, s_row in few_shot_df.iterrows(): 66 | icl_str += f"{s_row['inputs_pretokenized']} {s_row['targets_pretokenized']}\n\n\n" 67 | 68 | text = row["inputs_pretokenized"] 69 | text = text.replace("True, False, or 
Neither?", "") 70 | text = text + ". True, False, or Neither?" 71 | gold = row["targets_pretokenized"] 72 | prompt = f"{icl_str}{{text:}}" 73 | pmp = prompt.format(text=text) 74 | if i == 0: 75 | print(pmp) 76 | 77 | raw_answer = get_response( 78 | pmp, 79 | manifest, 80 | overwrite=bool(overwrite_manifest), 81 | max_toks=30, 82 | ) 83 | answer = raw_answer.strip().lower() 84 | answer = answer.split("\n") 85 | answer = [a for a in answer if a] 86 | answer = [ 87 | a 88 | for a in answer 89 | if any(l.lower() in a.lower() for l in labels_names) 90 | ] 91 | if answer: 92 | answer = answer[0] 93 | else: 94 | answer = "" 95 | answer = "".join( 96 | [a for a in answer if a not in [".", ",", "?", ";", ":", "'", '"']] 97 | ) 98 | is_yes = "true" in answer.split() 99 | is_no = "false" in answer.split() 100 | is_maybe = "neither" in answer.split() or "unknown" in answer.split() or "maybe" in answer.split() 101 | pred = "Neither" 102 | if is_yes and (not is_maybe and not is_no): 103 | pred = "True" 104 | if is_no and (not is_maybe and not is_yes): 105 | pred = "False" 106 | if is_maybe and (not is_no and not is_yes): 107 | pred = "Neither" 108 | entry = { 109 | "ind": ind, 110 | "example": text, 111 | "base_prompt": pmp, 112 | "raw_answer": raw_answer, 113 | "pred": pred, 114 | "gold": gold, 115 | } 116 | expt_log[ind] = entry 117 | 118 | preds.append(pred) 119 | labels.append(gold) 120 | 121 | report = classification_report(labels, preds, output_dict=True) 122 | return expt_log, report["accuracy"] 123 | 124 | def run_decomposed_prompt( 125 | self, test_data, boost_data_train, boost_dfs, manifest, overwrite_manifest 126 | ): 127 | expt_log, all_boost_preds, labels = self._run_decomp_single_data(test_data, boost_dfs, manifest, overwrite_manifest, do_train=0) 128 | expt_log_train, all_boost_train_preds, train_labels = self._run_decomp_single_data(boost_data_train, boost_dfs, manifest, overwrite_manifest, do_train=1) 129 | # Do WS 130 | preds = self.merge_boosted_preds(all_boost_preds, all_boost_train_preds, train_labels, expt_log, expt_log_train) 131 | # Get accuracies across all boost sets 132 | individual_accuracies = [] 133 | for i in range(len(all_boost_preds[0])): 134 | individual_accuracies.append(classification_report(labels, [p[i] for p in all_boost_preds], output_dict=True)["accuracy"]) 135 | report = classification_report(labels, preds, output_dict=True) 136 | return expt_log, expt_log_train, report["accuracy"], individual_accuracies 137 | 138 | def _run_decomp_single_data(self, test_data, boost_dfs, manifest, overwrite_manifest, do_train=-1): 139 | expt_log = {} 140 | all_boost_preds = [] 141 | labels = [] 142 | 143 | label_name_to_maping = { 144 | "always": "true", 145 | "never": "false", 146 | "sometimes": 'neither', 147 | "true": "true", 148 | "false": "false", 149 | "neither": 'neither', 150 | "no": "false", 151 | "yes": "true", 152 | "maybe": "neither", 153 | "unknown": "neither", 154 | "inconclusive": "neither", 155 | "impossible": "false", 156 | "possible": "neither", 157 | "guaranteed": "true", 158 | } 159 | 160 | prompts_across_boost = defaultdict(list) 161 | preds_across_boost = defaultdict(list) 162 | for boost_num, boost_examples in enumerate(boost_dfs): 163 | if do_train: 164 | data = boost_examples[0].iloc[:1000] 165 | elif not do_train: 166 | data = boost_examples[1] 167 | else: 168 | raise ValueError("Unsupported value for do train.") 169 | 170 | for i, (ind, row) in tqdm(enumerate(data.iterrows()), total=len(data)): 171 | input = row["inputs_pretokenized"] 172 | gold = 
row["targets_pretokenized"] 173 | all_prompts = [] 174 | raw_answer = get_response( 175 | input, 176 | manifest, 177 | overwrite=bool(overwrite_manifest), 178 | max_toks=30, 179 | ) 180 | all_prompts.append(input) 181 | answer = raw_answer.lower() 182 | if answer not in label_name_to_maping: 183 | pred = 'neither' 184 | print(answer) 185 | else: 186 | pred = label_name_to_maping[answer] 187 | prompts_across_boost[i].append(all_prompts) 188 | preds_across_boost[i].append(pred) 189 | if gold.lower() not in label_name_to_maping: 190 | import pdb; 191 | pdb.set_trace() 192 | 193 | for i, (ind, row) in enumerate(data.iterrows()): 194 | label = row["targets_pretokenized"].lower() 195 | entry = { 196 | "ind": ind, 197 | "prompts": prompts_across_boost[i], 198 | "preds_boost": preds_across_boost[i], 199 | "example": row['inputs_pretokenized'], 200 | "gold": label, 201 | } 202 | expt_log[ind] = entry 203 | all_boost_preds.append(preds_across_boost[ind]) 204 | labels.append(label_name_to_maping[label]) 205 | return expt_log, all_boost_preds, labels 206 | 207 | 208 | def main(): 209 | args = get_args() 210 | args.num_boost = 10 211 | task_name = "wsc_t0_variants" 212 | data_dir = f"{DATA_DIR}/P3/data_feather/super_glue_wsc.fixed_GPT_3_Style/" 213 | decomp = RTEDecomp(task_name, data_dir) 214 | decomp.run(args) 215 | 216 | 217 | if __name__ == "__main__": 218 | main() 219 | -------------------------------------------------------------------------------- /ablations/T0_variants/run.sh: -------------------------------------------------------------------------------- 1 | 2 | for task in WSC_variants; do 3 | python3 ${task}.py \ 4 | --run_zeroshot 0 \ 5 | --run_fewshot 0 \ 6 | --run_decomp 1 \ 7 | --num_boost 10 \ 8 | --k_shot 3 \ 9 | --output_metrics_file ../ama_logs/metrics.json \ 10 | --save_dir ../ama_logs/ama_final_runs \ 11 | --overwrite_data 1 \ 12 | --overwrite_boost_exs 1 \ 13 | --num_run -1 \ 14 | --client_connection http://127.0.0.1:5000 15 | done; -------------------------------------------------------------------------------- /ablations/T0_variants/utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from collections import Counter 3 | import json 4 | from datasets import load_dataset 5 | import re 6 | import pandas as pd 7 | from typing import Callable, List 8 | from manifest import Manifest 9 | 10 | 11 | class InputOutputPrompt: 12 | def __init__(self, 13 | input_formatter: Callable, 14 | output_formatter: Callable, 15 | required_keys: List, 16 | input_output_sep: str = "\n", 17 | example_sep: str = "\n\n", 18 | instruction: str = "" 19 | ): 20 | self.input_formatter = input_formatter 21 | self.output_formatter = output_formatter 22 | self.required_keys = required_keys 23 | self.input_output_sep = input_output_sep 24 | self.example_sep = example_sep 25 | self.instruction = instruction 26 | 27 | def __call__(self, input_output_pairs: pd.DataFrame): 28 | examples = [] 29 | for _, example in input_output_pairs.iterrows(): 30 | examples.append(f"{self.input_formatter(example)}{self.input_output_sep}{self.output_formatter(example)}") 31 | if examples: 32 | input_str = self.example_sep.join(examples) 33 | res = f"{self.instruction}{input_str}" 34 | else: 35 | res = f"{self.instruction}".rstrip() 36 | return res 37 | 38 | def __repr__(self): 39 | dummy_ex = pd.DataFrame([{k: f"<{k.upper()}>" for k in self.required_keys}]) 40 | st = self(dummy_ex) 41 | return st 42 | 43 | 44 | def prefix_formatter(ex_keys: List[str], prefix: str, 
error_on_empty: bool = True) -> str: 45 | def full_prefix_formatter(ex: pd.Series): 46 | for k in ex_keys: 47 | if k in ex: 48 | return f"{prefix} {getattr(ex, k)}" 49 | if error_on_empty: 50 | raise ValueError(f"Example {ex} has no value for any of the keys {ex_keys}") 51 | else: 52 | return f"{prefix}" 53 | return full_prefix_formatter 54 | 55 | 56 | def get_manifest_session( 57 | client_name="huggingface", 58 | client_engine=None, 59 | client_connection="http://127.0.0.1:5000", 60 | cache_connection=None, 61 | temperature=0, 62 | top_p=1.0, 63 | ): 64 | if client_name == "huggingface" and temperature == 0: 65 | params = { 66 | "temperature": 0.001, 67 | "do_sample": False, 68 | "top_p": top_p, 69 | } 70 | elif client_name in {"openai", "ai21"}: 71 | params = { 72 | "temperature": temperature, 73 | "top_p": top_p, 74 | "engine": client_engine, 75 | } 76 | else: 77 | raise ValueError(f"{client_name} is not a valid client name") 78 | manifest = Manifest( 79 | client_name=client_name, 80 | client_connection=client_connection, 81 | cache_name="sqlite", 82 | cache_connection=cache_connection, 83 | session_id=None, 84 | **params, 85 | ) 86 | params = manifest.client.get_model_params() 87 | model_name = params["model_name"] 88 | if "engine" in params: 89 | model_name += f"_{params['engine']}" 90 | return manifest, model_name 91 | 92 | 93 | def get_response( 94 | prompt, 95 | manifest, 96 | overwrite=False, 97 | max_toks=10, 98 | stop_token=None, 99 | gold_choices=[], 100 | verbose=False, 101 | ): 102 | prompt = prompt.strip() 103 | if gold_choices: 104 | gold_choices = [" " + g.strip() for g in gold_choices] 105 | response_obj = manifest.run( 106 | prompt, gold_choices=gold_choices, overwrite_cache=overwrite, return_response=True 107 | ) 108 | response_obj = response_obj.get_json_response()["choices"][0] 109 | log_prob = response_obj["text_logprob"] 110 | response = response_obj["text"] 111 | else: 112 | response = manifest.run( 113 | prompt, 114 | max_tokens=max_toks, 115 | stop_token=stop_token, 116 | overwrite_cache=overwrite, 117 | ) 118 | log_prob = None 119 | if verbose: 120 | print("\n***Prompt***\n", prompt) 121 | print("\n***Response***\n", response) 122 | if log_prob: 123 | return response, log_prob 124 | return response 125 | 126 | def load_hf_data(save_dir, task_name, val_split, hf_name, overwrite_data): 127 | save_data = Path(f"{save_dir}/{task_name}/data.feather") 128 | if not save_data.exists() or overwrite_data: 129 | dataset = load_dataset(hf_name) 130 | test_data = dataset[val_split].to_pandas() 131 | test_data.to_feather(f"{save_data}") 132 | else: 133 | print(f"Reading test data from {save_data}") 134 | test_data = pd.read_feather(f"{save_data}") 135 | 136 | save_data_train = Path(f"{save_dir}/{task_name}/train_data.feather") 137 | if not save_data_train.exists() or overwrite_data: 138 | dataset = load_dataset(hf_name) 139 | train_data = dataset["train"].to_pandas() 140 | train_data.to_feather(f"{save_data_train}") 141 | else: 142 | print(f"Reading train data from {save_data_train}") 143 | train_data = pd.read_feather(f"{save_data_train}") 144 | 145 | print(f"Test Data Size: {len(test_data)}") 146 | print(f"Train Data Size: {len(train_data)}") 147 | return test_data, train_data 148 | 149 | def save_log(task_name, expt_name, log, final_run_dir): 150 | final_run_dir = Path(final_run_dir) 151 | output_fpath = final_run_dir / task_name 152 | output_fpath.mkdir(parents=True, exist_ok=True) 153 | 154 | print("Saving to", output_fpath / f"{expt_name}.json") 155 | assert all(a in 
list(log.values())[0].keys() for a in ["ind","example","pred","gold"]) 156 | with open(output_fpath / f"{expt_name}.json", "w") as f: 157 | json.dump(log, f) 158 | 159 | def text_f1(preds, golds): 160 | """Compute average F1 of text spans. 161 | Taken from Squad without prob threshold for no answer. 162 | """ 163 | total_f1 = 0 164 | for pred, gold in zip(preds, golds): 165 | pred_toks = pred.split() 166 | gold_toks = gold.split() 167 | common = Counter(pred_toks) & Counter(gold_toks) 168 | num_same = sum(common.values()) 169 | if len(gold_toks) == 0 or len(pred_toks) == 0: 170 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 171 | total_f1 += int(gold_toks == pred_toks) 172 | elif num_same == 0: 173 | total_f1 += 0 174 | else: 175 | precision = 1.0 * num_same / len(pred_toks) 176 | recall = 1.0 * num_same / len(gold_toks) 177 | f1 = (2 * precision * recall) / (precision + recall) 178 | total_f1 += f1 179 | f1_avg = total_f1 / len(golds) 180 | return f1_avg 181 | 182 | def accuracy_span_overlap(preds, golds): 183 | correct = 0 184 | for pred, gold in zip(preds, golds): 185 | found = False 186 | for p in pred: 187 | for g in gold: 188 | if len(p) < len(g): 189 | if p.lower() in g.lower(): 190 | found = True 191 | break 192 | else: 193 | if g.lower() in p.lower(): 194 | found = True 195 | break 196 | if found: correct += 1 197 | return correct / len(preds) 198 | 199 | 200 | -------------------------------------------------------------------------------- /boosting/pgm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import itertools 3 | import matplotlib.pyplot as plt 4 | import scipy.stats 5 | 6 | 7 | 8 | class Ising(): 9 | 10 | 11 | def __init__(self, m, potentials, thetas = None, vals = [-1, 1], ) -> None: 12 | self.m = m 13 | self.v = m + 1 # total number of vertices 14 | self.potentials = potentials 15 | 16 | self.vals = vals 17 | #TODO support values in 0, 1 18 | 19 | if thetas is not None: 20 | assert len(thetas) >= len(potentials), f"Need to specify at least {len(potentials)} theta parameters." 
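        # Model being parametrized (a reading of _exponential_family and
        # _make_pdf below): the m weak labels plus the hidden label y (vertex
        # index m) form an Ising model with density
        #   Pr(l_1, ..., l_m, y) ∝ exp( sum_i thetas[i] * prod_{v in potentials[i]} l_v ),
        # so each entry of `potentials` names a clique and each theta is its
        # canonical parameter.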
21 | self.thetas = thetas 22 | else: 23 | self.thetas = np.random.rand(len(potentials)) 24 | 25 | self.support = np.array(list(map(list, itertools.product(vals, repeat=self.v)))) 26 | 27 | self._make_pdf() 28 | self._make_cdf() 29 | 30 | self._get_means() 31 | self._get_balance() 32 | self._get_accs() 33 | 34 | def _exponential_family(self, labels): 35 | x = 0.0 36 | for i in range(len(self.potentials)): 37 | x += self.thetas[i] * labels[self.potentials[i]].prod() 38 | 39 | return np.exp(x) 40 | 41 | def _make_pdf(self): 42 | p = np.zeros(len(self.support)) 43 | for i, labels in enumerate(self.support): 44 | p[i] = self._exponential_family(labels) 45 | 46 | self.z = sum(p) 47 | 48 | self.pdf = p/self.z 49 | 50 | def _make_cdf(self): 51 | self.cdf = np.cumsum(self.pdf) 52 | 53 | 54 | 55 | 56 | def joint_p(self, C, values): 57 | p = 0.0 58 | for k, labels in enumerate(self.support): 59 | flag = True 60 | for i in range(len(C)): 61 | prod = labels[C[i]].prod() 62 | if prod != values[i]: 63 | flag = False 64 | 65 | if flag == True: 66 | p += self.pdf[k] 67 | 68 | return p 69 | 70 | def expectation(self, C): 71 | return self.vals[0] * self.joint_p(C, self.vals[0] * np.ones(len(C))) + self.vals[1] * self.joint_p(C, self.vals[1] * np.ones(len(C))) 72 | 73 | def _get_means(self): 74 | self.means = np.zeros(self.m) 75 | for k in range(self.m): 76 | self.means[k] = self.expectation([[k]]) 77 | 78 | 79 | def _get_balance(self): 80 | self.balance = self.joint_p([[self.m]], [1]) 81 | 82 | # def _get_covariance(self): 83 | 84 | def _get_accs(self): 85 | """ 86 | self.accs[k, i, j] = Pr(lf_k = j | y = i) (i, j scaled to -1, 1 if needed) 87 | """ 88 | self.accs = np.zeros((self.m, 2, 2)) 89 | for k in range(self.m): 90 | self.accs[k, 1, 1] = self.joint_p([[k], [self.m]], [self.vals[1], self.vals[1]]) / self.balance 91 | self.accs[k, 0, 0] = self.joint_p([[k], [self.m]], [self.vals[0], self.vals[0]]) / (1 - self.balance) 92 | self.accs[k, 1, 0] = 1 - self.accs[k, 1, 1] 93 | self.accs[k, 0, 1] = 1 - self.accs[k, 0, 0] 94 | 95 | 96 | def sample(self): 97 | r = np.random.random_sample() 98 | smaller = np.where(self.cdf < r)[0] 99 | if len(smaller) == 0: 100 | i = 0 101 | else: 102 | i = smaller.max() + 1 103 | 104 | return self.support[i] 105 | 106 | def make_data(self, n, has_label = True): 107 | L = np.zeros((n, self.m)) 108 | gold = np.zeros(n) 109 | for i in range(n): 110 | l = self.sample() 111 | L[i, :] = l[:self.m] 112 | 113 | if has_label: 114 | gold[i] = l[self.m] 115 | 116 | return L.astype(int), gold.astype(int) 117 | 118 | 119 | 120 | def est_accs(m, vote, gold): 121 | # compute pr(lf | y) accuracies. 
Each prompt has 4 values (2x2) 122 | # we need to do this on the train/dev set 123 | classes = [0, 1] 124 | gold_idxs = [np.where(gold == -1)[0], np.where(gold == 1)[0]] 125 | 126 | accs = np.zeros((m, 2, 2)) # [i, j, k] = Pr(prompt_i = j| y = k) 127 | for p in range(m): 128 | for i in classes: 129 | for j in classes: 130 | accs[p, i, j] = len(np.where(vote[gold_idxs[i], p] == 2*j-1)[0]) / len(gold_idxs[i]) 131 | 132 | return accs 133 | 134 | def est_balance(gold, n): 135 | return len(np.where(gold == 1)[0]) / n 136 | 137 | # Pr(lf votes, y) 138 | def get_cond_probs(m, votes, y, accs, balance): 139 | pr_y = balance if y == 1 else 1 - balance 140 | prod = pr_y 141 | for i in range(m): 142 | prod *= accs[i, y, int(0.5*(votes[i] + 1))] # this assumes everything is independent 143 | return prod 144 | 145 | # Pr(y = 1 | lf votes) 146 | def get_probs(m, votes, accs, balance): 147 | pos = get_cond_probs(m, votes, 1, accs, balance) 148 | neg = get_cond_probs(m, votes, 0, accs, balance) 149 | 150 | if pos == 0: 151 | return 0 152 | else: 153 | return pos / (pos + neg) 154 | 155 | 156 | def pick_best_prompt(m, vote, gold, n): 157 | # overall accuracies Pr(lf_p = y), estimated on the train set since test accuracies are unknown at selection time 158 | overall_train_acc = np.zeros(m) 159 | for i in range(m): 160 | overall_train_acc[i] = len(np.where(vote[:, i] == gold)[0])/n 161 | 162 | return overall_train_acc.argmax() 163 | 164 | 165 | def main(): 166 | 167 | # number of weak labels 168 | m = 5 169 | 170 | # total number of vertices 171 | v = m + 1 172 | 173 | # randomly parametrize exponential family to determine accuracies and correlations 174 | #theta = np.random.rand() 175 | #theta_cliques = (np.random.randint(0, 2, 5)*2 - 1)*theta 176 | #theta = np.random.rand() 177 | #theta_cliques = [1, 1, 1, 1, 1, 1, 1] 178 | thetas = np.random.rand(30) 179 | 180 | # all conditionally independent 181 | potentials = [[5], [0], [1], [4], [0, 5], [1, 5], [2, 5], [3, 5], [4, 5]] 182 | 183 | pgm = Ising(m, potentials, thetas) 184 | 185 | n_train = 10000 186 | vote_train, gold_train = pgm.make_data(n_train) 187 | 188 | n_test = 1000 189 | vote_test, gold_test = pgm.make_data(n_test) 190 | 191 | accs = est_accs(m, vote_train, gold_train) 192 | balance = est_balance(gold_train, n_train) 193 | 194 | nb_output = np.zeros(n_test) # naive bayes 195 | mv_output = np.zeros(n_test) 196 | 197 | nb_err = 0 198 | mv_err = 0 199 | 200 | for i in range(n_test): 201 | nb_output[i] = 2*np.round(get_probs(m, vote_test[i], accs, balance))-1 202 | if nb_output[i] != gold_test[i]: 203 | nb_err += 1 204 | 205 | 206 | # note: play around with MV tie breaking strategy (strict majority votes 1, strict minority votes -1, exact ties are broken at random) 207 | if len(np.where(vote_test[i] == 1)[0]) > m / 2: 208 | mv_output[i] = 1 209 | elif len(np.where(vote_test[i] == 1)[0]) < m / 2: 210 | mv_output[i] = -1 211 | else: 212 | mv_output[i] = 2*np.random.randint(0, 2)-1 213 | 214 | if mv_output[i] != gold_test[i]: 215 | mv_err += 1 216 | 217 | nb_acc = 1 - (nb_err / n_test) 218 | mv_acc = 1 - (mv_err / n_test) 219 | #fs_acc = 1 - (fs_err / n_test) 220 | 221 | best_prompt = pick_best_prompt(m, vote_train, gold_train, n_train) 222 | 223 | best_prompt_acc = len(np.where(vote_test[:, best_prompt] == gold_test)[0]) / n_test 224 | 225 | print(f"Naive bayes: {nb_acc}") 226 | print(f"Best prompt: {best_prompt_acc}") 227 | print(f"Majority vote: {mv_acc}") 228 | 229 | 230 | if __name__ == "__main__": 231 | main() -------------------------------------------------------------------------------- /boosting/run_ws.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import json 4 | import os 5 | import cvxpy as cp 6 | import scipy as sp 7 | import datetime 8 | from methods import Aggregator 9 | from metal.label_model import LabelModel 10 | 11 | 12 | def get_args(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--task_name", type=str, required=True) 15 | parser.add_argument("--data_dir", type=str, default="../ama_logs/ama_final_runs") 16 | parser.add_argument("--symmetric", action="store_true") 17 | parser.add_argument("--model_prefix", type=str, default="EleutherAI_gpt-j-6B") 18 | parser.add_argument("--override_date", type=str, default=None) 19 | 20 | args = parser.parse_args() 21 | return args 22 | 23 | def get_data(task_name, data_dir, model_prefix, override_date=None): 24 | """ 25 | Load in dataset from task_name depending on where files are saved. 26 | """ 27 | task_dir = os.path.join(data_dir, task_name) 28 | date = datetime.datetime.today().strftime("%m%d%Y") if not override_date else override_date 29 | print(f"Loading runs from {date}") 30 | dpath = os.path.join(task_dir, f"{model_prefix}_decomposed_{date}.json") 31 | train_dpath = os.path.join(task_dir, f"{model_prefix}_decomposed_{date}_train.json") 32 | 33 | print(dpath) 34 | print(train_dpath) 35 | 36 | if task_name in ["super_glue_rte"]: 37 | label_name_to_int = {"True": 1, "False": 0} 38 | elif task_name == "amazon": 39 | label_name_to_int = {'amazon instant video':0, 40 | 'books':1, 41 | 'clothing shoes and jewelry':2, 42 | 'electronics':3, 43 | 'kindle store':4, 44 | 'movies and tv':5, 45 | 'musical instruments':6, 46 | 'office products':7, 47 | 'tools and home improvement':8} 48 | 49 | elif task_name == 'wic': 50 | label_name_to_int = {"Yes": 1, "No": 0} 51 | elif task_name == 'super_glue_wsc': 52 | label_name_to_int = {"True": 1, "False": 0} 53 | elif task_name == "sst2": 54 | label_name_to_int = {"positive": 1, "negative": 0, "neutral": 0} 55 | elif task_name == "super_glue_boolq": 56 | label_name_to_int = {"true": 1, "false": 0} 57 | elif task_name in ["story_cloze", "story_cloze_v2", "story_cloze_v3"]: 58 | label_name_to_int = {"2": 1, "1": 0} 59 | elif task_name in ["anli_r1", "anli_r2", "anli_r3"]: 60 | label_name_to_int = {"true": 1, "false": 0, "neither": 2} 61 | elif task_name == "MR" or task_name == "mr": 62 | label_name_to_int = {"positive": 1, "negative": 0} 63 | elif task_name == "multirc": 64 | label_name_to_int = {"yes": 1, "no": 0} 65 | elif task_name == "super_glue_cb": 66 | label_name_to_int = {"true": 0, "false": 2, "neither": 1} 67 | elif task_name == "super_glue_copa": 68 | label_name_to_int = {"1": 1, "2": 0} 69 | elif task_name == "drop": 70 | label_name_to_int = {"true": 1, "false": 0} 71 | elif task_name == "super_glue_record": 72 | label_name_to_int = {"true": 1, "false": 0} 73 | elif task_name == "ag_news": 74 | label_name_to_int = {"world news": 0, "sports": 1, "business": 2, "technology and science": 3} 75 | elif task_name == "dbpedia": 76 | label_name_to_int = {"company": 0, "educational institution": 1, "artist": 2, "athlete": 3, 77 | "office holder": 4, "mean of transportation": 5, "building": 6, "natural place": 7, 78 | "village": 8, "animal": 9, "plant": 10, "album": 11, "film": 12, "written work": 13} 79 | else: 80 | raise ValueError("Unsupported task!") 81 | 82 | test_data = json.load(open(dpath)) 83 | train_data = json.load(open(train_dpath)) 84 | 85 | print(train_data['0']['example']) 86 | 
print(train_data['0']['gold']) 87 | 88 | n_test = len(test_data) 89 | n_train = len(train_data) 90 | 91 | m = len(test_data['0']['preds_boost']) 92 | 93 | test_votes = np.zeros((n_test, m)) 94 | test_gold = np.zeros(n_test) 95 | for i in range(n_test): 96 | 97 | test_votes[i] = np.array([label_name_to_int[ans] if ans in label_name_to_int else -1 for ans in test_data[str(i)]['preds_boost']]) 98 | test_gold[i] = label_name_to_int[str(test_data[str(i)]['gold'])] 99 | 100 | test_votes = test_votes.astype(int) 101 | test_gold = test_gold.astype(int) 102 | 103 | train_votes = np.zeros((n_train, m)) 104 | train_gold = np.zeros(n_train) 105 | for i in range(n_train): 106 | train_votes[i] = np.array([label_name_to_int[ans] if ans in label_name_to_int else -1 for ans in train_data[str(i)]['preds_boost']]) 107 | train_gold[i] = label_name_to_int[str(train_data[str(i)]['gold'])] 108 | 109 | train_votes = train_votes.astype(int) 110 | train_gold = train_gold.astype(int) 111 | 112 | return train_votes, train_gold, test_votes, test_gold 113 | 114 | 115 | 116 | def get_top_deps_from_inverse_sig(J, k): 117 | m = J.shape[0] 118 | deps = [] 119 | sorted_idxs = np.argsort(np.abs(J), axis=None) 120 | n = m*m 121 | idxs = sorted_idxs[-k:] 122 | for idx in idxs: 123 | i = int(np.floor(idx / m)) 124 | j = idx % m 125 | if (j, i) in deps: 126 | continue 127 | deps.append((i, j)) 128 | return deps 129 | 130 | def learn_structure(L): 131 | m = L.shape[1] 132 | n = float(np.shape(L)[0]) 133 | sigma_O = (np.dot(L.T,L))/(n-1) - np.outer(np.mean(L,axis=0), np.mean(L,axis=0)) 134 | 135 | #bad code 136 | O = 1/2*(sigma_O+sigma_O.T) 137 | O_root = np.real(sp.linalg.sqrtm(O)) 138 | 139 | # low-rank matrix 140 | L_cvx = cp.Variable([m,m], PSD=True) 141 | 142 | # sparse matrix 143 | S = cp.Variable([m,m], PSD=True) 144 | 145 | # S-L matrix 146 | R = cp.Variable([m,m], PSD=True) 147 | 148 | #reg params 149 | lam = 1/np.sqrt(m) 150 | gamma = 1e-8 151 | 152 | objective = cp.Minimize(0.5*(cp.norm(R @ O_root, 'fro')**2) - cp.trace(R) + lam*(gamma*cp.pnorm(S,1) + cp.norm(L_cvx, "nuc"))) 153 | constraints = [R == S - L_cvx, L_cvx>>0] 154 | 155 | prob = cp.Problem(objective, constraints) 156 | result = prob.solve(verbose=False, solver=cp.SCS) 157 | opt_error = prob.value 158 | 159 | #extract dependencies 160 | J_hat = S.value 161 | 162 | if J_hat is None: 163 | raise ValueError("CVXPY failed to solve the structured learning problem, use result without dependencies.") 164 | 165 | for i in range(m): 166 | J_hat[i, i] = 0 167 | return J_hat 168 | 169 | 170 | def learn_structure_multiclass(L, k): 171 | m = L.shape[1] 172 | J_hats = np.zeros((k, m, m)) 173 | for c in range(k): 174 | 175 | all_votes_c = np.where(L == c, 1, 0) 176 | J_hats[c] = learn_structure(all_votes_c) 177 | 178 | return J_hats 179 | 180 | def get_min_off_diagonal(J_hat): 181 | J_hat_copy = J_hat.copy() 182 | for i in range(len(J_hat_copy)): 183 | J_hat_copy[i, i] = np.inf 184 | return np.abs(J_hat_copy).min() 185 | 186 | def main(): 187 | args = get_args() 188 | task_name = args.task_name 189 | data_dir = args.data_dir 190 | symmetric = args.symmetric 191 | 192 | train_votes, train_gold, test_votes, test_gold = get_data(task_name, data_dir, args.model_prefix, args.override_date) 193 | classes = np.sort(np.unique(test_gold)) 194 | vote_classes = np.sort(np.unique(test_votes)) 195 | n_train, m = train_votes.shape 196 | n_test = len(test_votes) 197 | k = len(classes) 198 | abstains = len(vote_classes) == len(classes) + 1 199 | print(f"Abstains: {abstains}") 200 | 201 | m 
= test_votes.shape[1] 202 | 203 | all_votes= np.concatenate((train_votes, test_votes)) 204 | 205 | label_model = LabelModel(k=k, seed=123) 206 | 207 | # scale to 0, 1, 2 (0 is abstain) 208 | test_votes_scaled = (test_votes + np.ones((n_test, m))).astype(int) 209 | test_gold_scaled = (test_gold + np.ones(n_test)).astype(int) 210 | 211 | train_votes_scaled = (train_votes + np.ones((n_train, m))).astype(int) 212 | train_gold_scaled = (train_gold + np.ones(n_train)).astype(int) 213 | 214 | all_votes_scaled = np.concatenate((train_votes_scaled, test_votes_scaled)) 215 | 216 | label_model.train_model(all_votes_scaled, Y_dev=train_gold_scaled, abstains=abstains, symmetric=False, n_epochs=10000, log_train_every=50, lr=0.00001) 217 | 218 | 219 | print('Trained Label Model Metrics (No deps):') 220 | scores = label_model.score((test_votes_scaled, test_gold_scaled), metric=['accuracy','precision', 'recall', 'f1']) 221 | print(scores) 222 | 223 | all_votes_no_abstains = np.where(all_votes == -1, 0, all_votes) 224 | 225 | if len(classes) == 2: 226 | J_hat = learn_structure(all_votes_no_abstains) 227 | else: 228 | J_hats = learn_structure_multiclass(all_votes_no_abstains, len(classes)) 229 | J_hat = J_hats.mean(axis=0) 230 | 231 | # if values in J are all too large, then everything is connected / structure learning isn't learning the right thing. Don't model deps then 232 | min_entry = get_min_off_diagonal(J_hat) 233 | if min_entry < 1: 234 | deps = get_top_deps_from_inverse_sig(J_hat, 1) 235 | print("Recovered dependencies: ", deps) 236 | 237 | label_model.train_model(all_votes_scaled, Y_dev=train_gold_scaled, abstains=abstains, symmetric=symmetric, n_epochs=80000, log_train_every=50, lr=0.000001, deps=deps) 238 | print('Trained Label Model Metrics (with deps):') 239 | scores = label_model.score((test_votes_scaled, test_gold_scaled), metric=['accuracy', 'precision', 'recall', 'f1']) 240 | print(scores) 241 | 242 | try: 243 | lm_probs = label_model.predict_proba(test_votes_scaled) 244 | agg = Aggregator(test_votes, test_gold, test_votes, test_gold, abstains, classes=[0, 1]) # 245 | print("H(Y | WS output):") 246 | print(agg.conditional_entropy_singleton(lm_probs, test_gold)) 247 | except: 248 | print(f"Failed to produce conditional entropy value: H(Y | WS output).") 249 | 250 | 251 | 252 | if __name__ == "__main__": 253 | main() 254 | -------------------------------------------------------------------------------- /boosting/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import itertools 3 | 4 | def get_probabilties(num_lfs, num_examples, predictions, label_name_to_int): 5 | 6 | lf_array = np.zeros((num_lfs, num_examples)) 7 | golds = [] 8 | 9 | # Collect golds and preds 10 | for i, (k, item) in enumerate(predictions.items()): 11 | preds = item['chosen_answers_lst'] 12 | preds_mapped = [] 13 | for p in preds: 14 | if p in label_name_to_int: 15 | preds_mapped.append(label_name_to_int[p]) 16 | else: 17 | preds_mapped.append(0) 18 | preds = preds_mapped.copy() 19 | for lf_num, p in zip(range(num_lfs), preds): 20 | lf_array[lf_num][i] = p 21 | gold = label_name_to_int[item['gold']] 22 | golds.append(gold) 23 | golds = np.array(golds) 24 | neg_indices, pos_indices = [np.where(golds == -1)[0], np.where(golds == 1)[0]] 25 | indices = { 26 | -1: neg_indices, 27 | 1: pos_indices 28 | } 29 | 30 | # [i, j, k] = Pr(prompt_i = j| y = k) 31 | # Accuracies 32 | lf_accuracies = [] 33 | for i in range(num_lfs): 34 | lf_accuracies.append(np.sum(golds == 
np.array(lf_array[i]))/num_examples) 35 | print(f"LF Accs: {lf_accuracies}") 36 | 37 | # [i, j, k] = Pr(prompt_i = j| y = k) 38 | classes = label_name_to_int.values() 39 | accs = np.zeros((num_lfs, len(classes), len(classes))) 40 | for p in range(num_lfs): 41 | for i in classes: 42 | for j in classes: 43 | j_idx = j 44 | if j == -1: 45 | j_idx = 0 46 | i_idx = i 47 | if i == -1: 48 | i_idx = 0 49 | accs[p, i_idx, j_idx] = len(np.where(lf_array[p, indices[i]] == j)[0]) / len(indices[i]) 50 | 51 | # Compute probabilities 52 | pos_probs = [] 53 | for i in range(num_lfs): 54 | sub_preds = lf_array[i][pos_indices] 55 | sub_golds = golds[pos_indices] 56 | pos_probs.append(np.sum(sub_golds == np.array(sub_preds))/len(pos_indices)) 57 | print(f"Pos Probs: {pos_probs}") 58 | 59 | neg_probs = [] 60 | for i in range(num_lfs): 61 | sub_preds = lf_array[i][neg_indices] 62 | sub_golds = golds[neg_indices] 63 | neg_probs.append(np.sum(sub_golds == np.array(sub_preds))/len(neg_indices)) 64 | print(f"Neg Probs: {neg_probs}\n\n") 65 | 66 | return lf_accuracies, accs, pos_probs, neg_probs, golds, indices 67 | 68 | 69 | """ Independence Assumption: take the product of probabilities as p(L1, L2, ..., LK | y) """ 70 | 71 | # Pr(y = 1 | lf votes) 72 | def get_cond_probs(votes, y, indices_train, golds_train, accs_train, num_lfs_test): 73 | prop_pos = len(indices_train[1])/len(golds_train) 74 | pr_y = prop_pos if y == 1 else 1 - prop_pos 75 | prod = pr_y 76 | for i in range(num_lfs_test): 77 | if y == -1: 78 | y = 0 79 | prod *= accs_train[i, y, votes[i]] 80 | return prod 81 | 82 | # Pr(y = 1 | lf votes) 83 | def get_probs(votes, indices_train, golds_train, acc_train, num_lfs_test): 84 | votes = [max(v, 0) for v in votes] 85 | numerator = get_cond_probs(votes, 1, indices_train, golds_train, acc_train, num_lfs_test) 86 | denominator = numerator + get_cond_probs(votes, -1, indices_train, golds_train, acc_train, num_lfs_test) 87 | return numerator / denominator 88 | 89 | 90 | def get_nb_accuracy(num_examples_test, num_lfs_test, predictions_test, label_name_to_int, golds_test, indices_train, golds_train, accs_train): 91 | output = np.zeros(num_examples_test) 92 | errors = 0 93 | for i, (k, item) in enumerate(predictions_test.items()): 94 | votes = item['chosen_answers_lst'] 95 | votes_mapped = [] 96 | for v in votes: 97 | if v in label_name_to_int: 98 | votes_mapped.append(label_name_to_int[v]) 99 | else: 100 | votes_mapped.append(0) 101 | votes = votes_mapped.copy() 102 | probs = np.round(get_probs(votes, indices_train, golds_train, accs_train, num_lfs_test)) 103 | output[i] = probs 104 | 105 | # Mean squared error 106 | g = golds_test[i] 107 | if golds_test[i] == -1: 108 | g = 0 109 | error = np.abs(output[i] - g)**2 110 | errors += error 111 | accuracy = 1 - (errors / num_examples_test) 112 | return accuracy, output 113 | 114 | 115 | def estimate_matrix(m, n, L): 116 | E_prod = np.zeros((m, m)) 117 | l_avg = np.zeros(m) 118 | for i in range(n): 119 | l = L[i, :] 120 | l_avg += l 121 | E_prod += np.outer(l, l) 122 | 123 | l_avg = l_avg/n 124 | E_prod = E_prod/n 125 | 126 | cov = E_prod - np.outer(l_avg, l_avg) 127 | 128 | return (E_prod, cov, l_avg) 129 | 130 | 131 | def get_vote_vectors(num_samples, num_lfs, predictions, label_name_to_int): 132 | vectors = np.zeros((num_samples, num_lfs+1), float) 133 | vectors_no_y = np.zeros((num_samples, num_lfs), float) 134 | labels_vector = np.zeros((num_samples, 1), float) 135 | for i, p in enumerate(predictions.values()): 136 | votes = p['chosen_answers_lst'] 137 | 
votes_mapped = [] 138 | for v in votes: 139 | if v in label_name_to_int: 140 | votes_mapped.append(label_name_to_int[v]) 141 | else: 142 | votes_mapped.append(0) 143 | votes = votes_mapped.copy() 144 | # votes = [max(v, 0) for v in votes] 145 | gold = p['gold'] 146 | gold = label_name_to_int[gold] 147 | vectors_no_y[i] = np.array(votes) 148 | vectors[i] = np.array(votes + [gold]) #- lf_accuracies_train 149 | labels_vector[i] = np.array([gold]) 150 | print(f"Shape: {vectors.shape}") 151 | print(f"Sample: {vectors[0]}") 152 | 153 | return vectors, vectors_no_y, labels_vector 154 | 155 | def get_feature_vector(vote_vectors, include_pairwise=False, include_singletons=True): 156 | feature_vectors = [] 157 | for votes in vote_vectors: 158 | if include_singletons: 159 | feature_vector = list(votes[:]) 160 | else: 161 | feature_vector = [] 162 | if include_pairwise: 163 | for subset in itertools.combinations(votes[:], 2): 164 | feature_vector.append(subset[0] * subset[1]) 165 | feature_vectors.append(feature_vector) 166 | X = np.matrix(feature_vectors) 167 | return X -------------------------------------------------------------------------------- /diagnostics/README.md: -------------------------------------------------------------------------------- 1 | This folder contains the diagnostic tasks for Ask Me Anything (AMA) as introduced in the main paper. The diagnostic tasks are aligned with the types of reasoning steps involved in the end-to-end AMA strategy, including question generation, text generation, extraction (from context), and selection (from listed answer choices). Given a new language model, this diagnostic may help anticipate whether AMA will provide lift. 2 | 3 | 4 | With a running manifest session, you can run the diagnostic as follows: 5 | 6 | ``` 7 | unzip data.zip 8 | python run_diagnostics.py 9 | ``` -------------------------------------------------------------------------------- /diagnostics/data.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/ama_prompting/460843d93a9e4bf2115eb35e4a02e82b6b67feac/diagnostics/data.zip -------------------------------------------------------------------------------- /diagnostics/run_diagnostics.py: -------------------------------------------------------------------------------- 1 | """ Running and Scoring the AMA Diagnostics """ 2 | 3 | import os 4 | import json 5 | from collections import Counter 6 | from tqdm import tqdm 7 | import openai 8 | from manifest import Manifest 9 | openai.api_key = "" # Find this on the OpenAI Website 10 | 11 | from datasets import load_metric 12 | rouge = load_metric("rouge") 13 | 14 | ######################### HELPER FUNCTIONS ######################### 15 | 16 | def rougeL(preds, labels): 17 | return rouge.compute(predictions=preds, references=labels)['rougeL'].mid.fmeasure 18 | 19 | 20 | def text_f1(preds=[], labels=[]): 21 | """Compute average F1 of text spans. 22 | 23 | Taken from Squad without prob threshold for no answer. 
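    Token-level F1 per example: precision = |overlap| / |pred tokens|,
    recall = |overlap| / |gold tokens|, and F1 is their harmonic mean,
    averaged over all (pred, gold) pairs.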
24 | """ 25 | total_f1 = 0 26 | for pred, gold in zip(preds, labels): 27 | pred_toks = pred.split() 28 | gold_toks = gold.split() 29 | common = Counter(pred_toks) & Counter(gold_toks) 30 | num_same = sum(common.values()) 31 | if len(gold_toks) == 0 or len(pred_toks) == 0: 32 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 33 | total_f1 += int(gold_toks == pred_toks) 34 | elif num_same == 0: 35 | total_f1 += 0 36 | else: 37 | precision = 1.0 * num_same / len(pred_toks) 38 | recall = 1.0 * num_same / len(gold_toks) 39 | f1 = (2 * precision * recall) / (precision + recall) 40 | total_f1 += f1 41 | f1_avg = total_f1 / len(labels) 42 | return f1_avg 43 | 44 | def get_response(manifest, prompt, max_toks=10, temperature = 0, gold_choices=[], model_name="manifest", engine="text-davinci-002", logit_bias={}): 45 | prompt = prompt.strip() 46 | if model_name == 'openai': 47 | if logit_bias: 48 | completion = openai.Completion.create(engine=engine, prompt=prompt, temperature=temperature, top_p=1, max_tokens=max_toks, logprobs=5, logit_bias=logit_bias) 49 | else: 50 | completion = openai.Completion.create(engine=engine, prompt=prompt, temperature=temperature, top_p=1, max_tokens=max_toks, logprobs=5) 51 | response = completion.choices[0].text 52 | 53 | if model_name == "manifest": 54 | if gold_choices: 55 | max_len = max([len(g.split()) for g in gold_choices]) 56 | else: 57 | max_len = 0 58 | max_token_args = ({"max_tokens": min(max_toks,8 * len(max_len),)} 59 | if gold_choices is None 60 | else {} 61 | ) 62 | if gold_choices: 63 | response = manifest.run(prompt, gold_choices=gold_choices,overwrite_cache=False,**max_token_args,) 64 | else: 65 | response = manifest.run(prompt, max_tokens=max_toks, overwrite_cache=False) 66 | 67 | return response 68 | 69 | def get_manifest_session( 70 | client_name="huggingface", 71 | client_connection="http://127.0.0.1:5000", 72 | cache_connection=None, 73 | temperature=0, 74 | top_p=1.0, 75 | ): 76 | if client_name == "huggingface" and temperature == 0: 77 | params = { 78 | "temperature": 0.001, 79 | "do_sample": False, 80 | "top_p": top_p, 81 | } 82 | elif client_name in {"openai", "ai21"}: 83 | params = { 84 | "temperature": temperature, 85 | "top_p": top_p, 86 | } 87 | else: 88 | raise ValueError(f"{client_name} is not a valid client name") 89 | manifest = Manifest( 90 | client_name=client_name, 91 | client_connection=client_connection, 92 | cache_name="sqlite", 93 | cache_connection=cache_connection, 94 | **params, 95 | ) 96 | model_name = manifest.client.get_model_params()["model_name"] 97 | return manifest, model_name 98 | 99 | ######################### SCORE FUNCTIONS ######################### 100 | 101 | 102 | """ A "blah" maps to category: blah """ 103 | def selection_easy(dataset, manifest): 104 | preds = [] 105 | for i, (ind, row) in tqdm(enumerate(dataset.items())): 106 | prefix = row['input'] 107 | pred = get_response(manifest, prefix, max_toks=50) 108 | pred = [p for p in pred.split("\n") if p][0].strip() 109 | preds.append(pred) 110 | 111 | if i == 0: 112 | print(prefix) 113 | print(f"PRED: {pred}") 114 | 115 | labels = [l.strip() for l in row['labels']] 116 | num_valid = [p for p in preds if p in labels] 117 | return len(num_valid)/len(dataset) 118 | 119 | 120 | """ Does the model pick one of the given choices? 
""" 121 | def selection_hard(dataset, manifest): 122 | preds = [] 123 | for i, (ind, row) in tqdm(enumerate(dataset.items())): 124 | prefix = row['input'] 125 | pred = get_response(manifest, prefix, max_toks=50) 126 | pred = [p for p in pred.split("\n")][0] 127 | preds.append(pred) 128 | 129 | if i == 0: 130 | print(prefix) 131 | print(f"PRED: {pred}") 132 | 133 | valid = 0 134 | for (ind, row), pred in zip(dataset.items(), preds): 135 | choices = row['output'] 136 | if pred.lower().strip(".") in [c.lower().strip(".") for c in choices]: 137 | valid += 1 138 | return valid/len(dataset) 139 | 140 | 141 | """ Does the model generate three choices? """ 142 | def text_generation(dataset, manifest): 143 | preds = [] 144 | for i, (ind, row) in tqdm(enumerate(dataset.items())): 145 | prefix = row['input'] 146 | pred = get_response(manifest, prefix, max_toks=50) 147 | pred = pred.split("\n\n")[0] 148 | pred = pred.split("\n") 149 | pred = list(set([a.replace("- ", "").strip() for a in pred])) 150 | preds.append(pred) 151 | 152 | if i == 0: 153 | print(prefix) 154 | print(f"PRED: {pred}") 155 | 156 | valid = 0 157 | for pred in preds: 158 | if len(pred) == 2: 159 | valid += 1 160 | return valid/len(dataset) 161 | 162 | 163 | """ Does the model faithfully transform the statement to a question? """ 164 | def question_generation(dataset, manifest): 165 | preds = [] 166 | for i, (ind, row) in tqdm(enumerate(dataset.items())): 167 | prefix = row['input'] 168 | pred = get_response(manifest, prefix, max_toks=50) 169 | pred = [p for p in pred.split("\n")][0] 170 | preds.append(pred) 171 | 172 | if i == 0: 173 | print(prefix) 174 | print(f"PRED: {pred}") 175 | 176 | outputs = [row['output'] for ind, row in dataset.items()] 177 | score = rougeL(preds=preds, labels = outputs) 178 | return score 179 | 180 | 181 | """ Does the model faithfully choose the sentence with the entity name? """ 182 | """ Does the model faithfully answer given a keyword for extraction? 
""" 183 | def extraction(dataset, manifest): 184 | preds = [] 185 | for i, (ind, row) in tqdm(enumerate(dataset.items())): 186 | prefix = row['input'] 187 | pred = get_response(manifest, prefix, max_toks=50) 188 | pred = [p for p in pred.split("\n")][0] 189 | preds.append(pred) 190 | 191 | if i == 0: 192 | print(prefix) 193 | print(f"PRED: {pred}") 194 | 195 | outputs = [row['output'] for ind, row in dataset.items()] 196 | score = text_f1(preds=preds, labels = outputs) 197 | return score 198 | 199 | 200 | def main(): 201 | data_dir = "data/" 202 | synthetics = { 203 | 'selection_easy':selection_easy, 204 | 'selection_hard':selection_hard, 205 | 'extraction':extraction, 206 | "text_generation": text_generation, 207 | "question_generation": question_generation 208 | } 209 | 210 | manifest, model_name = get_manifest_session() 211 | 212 | synthetic_scores = {} 213 | for synthetic, function in synthetics.items(): 214 | print(f"RUNNING {synthetic}") 215 | with open(f"{data_dir}/{synthetic}.json") as f: 216 | dataset = json.load(f) 217 | score = function(dataset, manifest) 218 | synthetic_scores[synthetic] = score 219 | print(f"SCORE: {score}") 220 | 221 | print(synthetic_scores) 222 | model_name = model_name.replace("/", "_") 223 | 224 | if not os.path.exists(f"results/"): 225 | os.makedirs(f"results/") 226 | 227 | with open(f"results/{model_name}_results.json", "w") as f: 228 | json.dump(synthetic_scores, f) 229 | 230 | print(f"Saved to: results/{model_name}_results.json") 231 | 232 | if __name__ == "__main__": 233 | main() 234 | -------------------------------------------------------------------------------- /download_p3.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import tensorflow as tf 3 | import pandas as pd 4 | from pathlib import Path 5 | import json 6 | from tqdm.auto import tqdm 7 | import os 8 | 9 | DATA_DIR = os.environ.get("AMA_DATA", "/home/data") 10 | 11 | # Download P3 github data from HF website 12 | ''' 13 | git lfs install 14 | git clone https://huggingface.co/datasets/bigscience/P3 15 | ''' 16 | 17 | import sys 18 | sys.path.append(f"{DATA_DIR}/P3") 19 | # From P3 github 20 | from tasks_splits_and_features import DATA_SPLITS_SIZES 21 | 22 | SPLITS = ["train", "test", "validation"] 23 | _FEAT_MAPPING_FUNCTIONS = { 24 | "answer_choices": lambda x: [choice.decode("utf-8") for choice in x], 25 | "inputs": lambda x: x.tolist(), 26 | "inputs_pretokenized": lambda x: x.decode("utf-8"), 27 | "targets": lambda x: x.tolist(), 28 | "targets_pretokenized": lambda x: x.decode("utf-8"), 29 | "idx": lambda x: x.tolist(), 30 | "weight": lambda x: float(x), 31 | "is_correct": lambda x: x, 32 | } 33 | def _feature_config(shape, dtype): 34 | if dtype in ("int32", "bool"): 35 | # int32 and bool are stored as int64 in the tf.train.Example protobuf. 
36 | dtype = "int64" 37 | if shape and shape[0] is None: 38 | return tf.io.FixedLenSequenceFeature( 39 | shape[1:], dtype, allow_missing=True 40 | ) 41 | return tf.io.FixedLenFeature(shape, dtype) 42 | 43 | @tf.autograph.experimental.do_not_convert 44 | def extract_dataset(data_path, subfolder, sizes, processes=10, splits=SPLITS): 45 | datasets = {} 46 | for split in splits: 47 | if not (data_path / subfolder / f"info.{split}.json").exists(): 48 | continue 49 | features_dict = json.load(open(data_path / subfolder / f"info.{split}.json")) 50 | if "features" not in features_dict: 51 | features_dict = json.load(open(data_path / subfolder / f"info.train.json")) 52 | if "features" not in features_dict: 53 | continue 54 | features_dict = features_dict["features"] 55 | tfrecord = str(data_path / subfolder / f"{split}.tfrecord-00000-of-00001") 56 | feature_description = { 57 | feat: _feature_config(**desc) for feat, desc in features_dict.items() 58 | } 59 | ds = tf.data.TFRecordDataset(tf.io.gfile.glob([tfrecord])) # TODO -> handle multiple shards 60 | ds = ds.map( 61 | lambda pb: tf.io.parse_single_example(pb, feature_description), 62 | num_parallel_calls=processes 63 | ) 64 | # Cast features back to the types from the info JSON since some features 65 | # must be cast for storage (e.g., int32 is stored as int64). 66 | ds = ds.map( 67 | lambda x: {k: tf.cast(v, features_dict[k]["dtype"]) for k, v in x.items()}, 68 | num_parallel_calls=processes 69 | ) 70 | res = [] 71 | for ex in tqdm(ds.as_numpy_iterator(), total=sizes.get(split, 10000)): 72 | ex_dict = {} 73 | for feat_name, feat_value in ex.items(): 74 | ex_dict[feat_name] = _FEAT_MAPPING_FUNCTIONS[feat_name](feat_value) 75 | res.append(ex_dict) 76 | if len(res) > 0: 77 | df = pd.DataFrame.from_records(res) 78 | datasets[split] = df 79 | return datasets 80 | 81 | data_path = Path(f"{DATA_DIR}/P3/data") 82 | output_data_path = Path(f"{DATA_DIR}/P3/data_feather") 83 | for subfolder in data_path.iterdir(): 84 | print(subfolder) 85 | subfolder = subfolder.name 86 | out_path = output_data_path / subfolder 87 | out_path.mkdir(parents=True, exist_ok=True) 88 | splits_to_pass = [] 89 | for split in SPLITS: 90 | if not (out_path / f"{split}.feather").exists(): 91 | splits_to_pass.append(split) 92 | datasets = extract_dataset(data_path, subfolder, sizes=DATA_SPLITS_SIZES[str(subfolder)], splits=splits_to_pass) 93 | for split, df in datasets.items(): 94 | df.to_feather(out_path / f"{split}.feather") -------------------------------------------------------------------------------- /imgs/crfm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/ama_prompting/460843d93a9e4bf2115eb35e4a02e82b6b67feac/imgs/crfm.png -------------------------------------------------------------------------------- /imgs/decomp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/ama_prompting/460843d93a9e4bf2115eb35e4a02e82b6b67feac/imgs/decomp.png -------------------------------------------------------------------------------- /imgs/numbers_station.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/ama_prompting/460843d93a9e4bf2115eb35e4a02e82b6b67feac/imgs/numbers_station.png -------------------------------------------------------------------------------- /imgs/snorkel.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/ama_prompting/460843d93a9e4bf2115eb35e4a02e82b6b67feac/imgs/snorkel.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cvxpy==1.2.1 2 | datasets==1.4.0 3 | dill==0.3.5.1 4 | GPUtil==1.4.0 5 | ipdb==0.13.9 6 | ipython==8.5.0 7 | matplotlib==3.5.2 8 | more_itertools==8.14.0 9 | networkx==2.8.5 10 | nltk==3.7 11 | numpy==1.23.3 12 | pandas==1.4.2 13 | scikit_learn==1.1.2 14 | scipy==1.8.1 15 | setuptools==59.5.0 16 | snorkel==0.9.9 17 | tensorboardX==2.5.1 18 | tensorflow==2.9.1 19 | tools==0.1.9 20 | torch==1.12.1 21 | torchtext==0.13.1 22 | tqdm 23 | -------------------------------------------------------------------------------- /tasks/CB_final.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | from tqdm.auto import tqdm 4 | import pandas as pd 5 | 6 | from sklearn.metrics import classification_report 7 | from decomposition import Decomposition, get_args, DATA_DIR 8 | from utils import get_response, InputOutputPrompt 9 | 10 | questioner = InputOutputPrompt( 11 | input_formatter=lambda x: f"Statement: {x['statement']}", 12 | output_formatter=lambda x: f"Question: {x['question']}", 13 | required_keys=["statement", "question"], 14 | input_output_sep="\n", 15 | example_sep="\n\n", 16 | instruction="Rewrite the statement as a yes/no question.\n\n" 17 | ) 18 | 19 | questioner_examples = [ 20 | pd.DataFrame([ 21 | { 22 | "statement": "most of the light comes from the sun", 23 | "question": "Does most of the light come from the sun?" 24 | }, 25 | { 26 | "statement": "the test was not", 27 | "question": "Was the test hard?" 28 | }, 29 | { 30 | "statement": "it is a good idea to buy your parents gifts", 31 | "question": "Is it a good idea to buy your parents gifts?" 32 | }, 33 | { 34 | "statement": "the balloon popped", 35 | "question": "Did the balloon pop?" 36 | }, 37 | { 38 | "statement": "The father and son went camping to California.", 39 | "question": "Did the father and son go camping?" 40 | } 41 | ]), 42 | pd.DataFrame([ 43 | { 44 | "statement": "most of the light comes from the sun", 45 | "question": "Does most of the light come from the sun?" 46 | }, 47 | { 48 | "statement": "the test was not", 49 | "question": "Was the test hard?" 50 | }, 51 | { 52 | "statement": "it is a good idea to buy your parents gifts", 53 | "question": "Is it a good idea to buy your parents gifts?" 54 | }, 55 | { 56 | "statement": "the balloon popped", 57 | "question": "Did the balloon pop?" 58 | }, 59 | { 60 | "statement": "The father and son went camping to California.", 61 | "question": "Did the father and son go camping?" 
62 | } 63 | ]), 64 | pd.DataFrame([ 65 | { 66 | "statement": "she prefers kittens over puppies.", 67 | "question": "Does she prefer kittens over puppies?", 68 | }, 69 | { 70 | "statement": "Max and his wife went on a trip to Europe", 71 | "question": "Did Max and his wife go on a trip to Europe?", 72 | }, 73 | { 74 | "statement": "jared was born during the war in 1942.", 75 | "question": "Was Jared born during a war in 1942?", 76 | }, 77 | { 78 | "statement": "it took jenna 7 attempts to solve the problem", 79 | "question": "Did it take Jenna 7 attempts to solve the problem?", 80 | }, 81 | ]), 82 | ] 83 | 84 | openended_qa = InputOutputPrompt( 85 | input_formatter=lambda x: f"Passage: {x['passage']}\nQuestion: {x['question']}", 86 | output_formatter=lambda x: f"Answer: {x['answer']}", 87 | required_keys=["passage", "question", "answer"], 88 | input_output_sep="\n", 89 | example_sep="\n\n", 90 | instruction="Provide the answer to the question from the passage.\n\n" 91 | ) 92 | 93 | openended_qa_examples = [ 94 | pd.DataFrame([ 95 | { 96 | "passage": "When Judy and Jack went to school, they got in trouble with their teacher for being late. I didn't think it was very fair.", 97 | "question": "Did she think it was fair?", 98 | "answer": "No" 99 | }, 100 | { 101 | "passage": "If inflation is occurring, leading to higher prices for basic necessities such as gas by 2 dollars. Do you think that inflation is good for society?", 102 | "question": "Is inflation good for society?", 103 | "answer": "Maybe" 104 | }, 105 | { 106 | "passage": "Put yourself out there. The more time you spend dating and socializing, the more likely you will find a boyfriend you like.", 107 | "question": "Does socializing help you find a boyfriend?", 108 | "answer": "Yes" 109 | }, 110 | ]), 111 | pd.DataFrame([ 112 | { 113 | "passage": "Jack recommends his least favorite books of the year to his followers. The least favorite book this year was Harry Potter and the 7 Rings.", 114 | "question": "What book does Jack dislike?", 115 | "answer": "Jack does not like Harry Potter and the 7 Rings." 116 | }, 117 | { 118 | "passage": "When Judy and Jack went to school, they got in trouble with their teacher for being late. I didn't think it was very fair.", 119 | "question": "Did she think it was fair?", 120 | "answer": "No, she didn't think it was very fair." 121 | }, 122 | { 123 | "passage": "If inflation is occurring, leading to higher prices for basic necessities such as gas by 2 dollars. Do you think that inflation is good for society?", 124 | "question": "Is inflation good for society?", 125 | "answer": "Hmm. Do you think so?" 126 | }, 127 | { 128 | "passage": "Put yourself out there. The more time you spend dating and socializing, the more likely you will find a boyfriend you like.", 129 | "question": "Does socializing help you find a boyfriend?", 130 | "answer": "Yes, it helps you find a boyfriend." 131 | } 132 | ]), 133 | pd.DataFrame([ 134 | { 135 | "passage": "Anna's mother always told her to be confident even if she feels nervous on the inside", 136 | "question": "Does Anna always feel nervous on the inside?", 137 | "answer": "Unknown" 138 | }, 139 | { 140 | "passage": "Max and Jeff were extremely competitive at soccer, but Max was a lot better.", 141 | "question": "Was Jeff better than Max at soccer?", 142 | "answer": "No, Max was a lot better" 143 | }, 144 | { 145 | "passage": "When Judy and Jack went to school, they got in trouble with their teacher for being late. 
I didn't think it was very fair.", 146 | "question": "Did she think it was fair?", 147 | "answer": "No, she didn't think it was very fair." 148 | }, 149 | { 150 | "passage": "The FSP conference took place last week in Spain and representatives from 21 countries attended.", 151 | "question": "Did representatives from more than 20 countries attend FSP?", 152 | "answer": "Yes" 153 | }, 154 | ]), 155 | ] 156 | 157 | class CBDecomp(Decomposition): 158 | def __init__(self, task_name, data_dir, val_split="validation"): 159 | super().__init__(task_name, data_dir, val_split) 160 | 161 | def get_boost_decomp_examples(self, data_train, boost_id): 162 | return [ 163 | questioner_examples[boost_id], 164 | openended_qa_examples[boost_id], 165 | ] 166 | 167 | def zero_few_baseline( 168 | self, 169 | test_data, 170 | few_shot_df, 171 | manifest, 172 | overwrite_manifest, 173 | do_few_shot=True, 174 | ): 175 | expt_log = {} 176 | preds = [] 177 | labels = [] 178 | 179 | labels_names = set(test_data["targets_pretokenized"]) 180 | labels_names = [l.lower().strip() for l in labels_names] 181 | 182 | for i, (ind, row) in tqdm( 183 | enumerate(test_data.iterrows()), total=len(test_data) 184 | ): 185 | if ind in expt_log: 186 | pred = entry["pred"] 187 | gold = entry["gold"] 188 | else: 189 | text = row["inputs_pretokenized"] 190 | gold = row["targets_pretokenized"] 191 | 192 | icl_str = "" 193 | if do_few_shot: 194 | for s_ind, s_row in few_shot_df.iterrows(): 195 | icl_str += f"{s_row['inputs_pretokenized']} {s_row['targets_pretokenized']}\n\n\n" 196 | 197 | text = row["inputs_pretokenized"] 198 | text = text.replace("True, False, or Neither?", "") 199 | text = text + ". True, False, or Neither?" 200 | gold = row["targets_pretokenized"] 201 | prompt = f"{icl_str}{{text:}}" 202 | pmp = prompt.format(text=text) 203 | if i == 0: 204 | print(pmp) 205 | 206 | raw_answer = get_response( 207 | pmp, 208 | manifest, 209 | overwrite=bool(overwrite_manifest), 210 | max_toks=30, 211 | ) 212 | answer = raw_answer.strip().lower() 213 | answer = answer.split("\n") 214 | answer = [a for a in answer if a] 215 | answer = [ 216 | a 217 | for a in answer 218 | if any(l.lower() in a.lower() for l in labels_names) 219 | ] 220 | if answer: 221 | answer = answer[0] 222 | else: 223 | answer = "" 224 | answer = "".join( 225 | [a for a in answer if a not in [".", ",", "?", ";", ":", "'", '"']] 226 | ) 227 | is_yes = "true" in answer.split() 228 | is_no = "false" in answer.split() 229 | is_maybe = "neither" in answer.split() or "unknown" in answer.split() or "maybe" in answer.split() 230 | pred = "Neither" 231 | if is_yes and (not is_maybe and not is_no): 232 | pred = "True" 233 | if is_no and (not is_maybe and not is_yes): 234 | pred = "False" 235 | if is_maybe and (not is_no and not is_yes): 236 | pred = "Neither" 237 | entry = { 238 | "ind": ind, 239 | "example": text, 240 | "base_prompt": pmp, 241 | "raw_answer": raw_answer, 242 | "pred": pred, 243 | "gold": gold, 244 | } 245 | expt_log[ind] = entry 246 | 247 | preds.append(pred) 248 | labels.append(gold) 249 | 250 | report = classification_report(labels, preds, output_dict=True) 251 | return expt_log, report["accuracy"] 252 | 253 | def get_question(self, statement, prompt, boost_ex, manifest, overwrite_manifest): 254 | prompt_suffix = prompt(boost_ex) 255 | quesiton_prompt = f"\n{prompt_suffix}\n\nStatement: {{statement:}}\nQuestion:" 256 | question_pmp = quesiton_prompt.format(statement=statement) 257 | answer = get_response( 258 | question_pmp, 259 | manifest, 260 | 
overwrite=bool(overwrite_manifest), 261 | max_toks=50, 262 | ) 263 | answer = answer.split("\n")[0] 264 | return answer, question_pmp 265 | 266 | def open_qa(self, question, passage, prompt, boost_ex, manifest, overwrite_manifest): 267 | prompt_suffix = prompt(boost_ex) 268 | qa_prompt = f"\n{prompt_suffix}\n\nPassage: {{passage:}}\nQuestion: {question}\nAnswer:" 269 | qa_pmp = qa_prompt.format(passage=passage) 270 | answer = get_response( 271 | qa_pmp, 272 | manifest, 273 | overwrite=bool(overwrite_manifest), 274 | max_toks=50, 275 | ) 276 | answer = answer.split("\n")[0] 277 | answer = ( 278 | answer.replace("A: ", "") 279 | .replace("B: ", "") 280 | .replace("Answer: ", "") 281 | .replace(", ", " ") 282 | ) 283 | return answer, qa_pmp 284 | 285 | def run_decomposed_prompt( 286 | self, test_data, boost_data_train, boost_dfs, manifest, overwrite_manifest 287 | ): 288 | expt_log, all_boost_preds, labels = self._run_decomp_single_data(test_data, boost_dfs, manifest, overwrite_manifest) 289 | expt_log_train, all_boost_train_preds, train_labels = self._run_decomp_single_data(boost_data_train, boost_dfs, manifest, overwrite_manifest, run_limit=-1) 290 | # Do WS 291 | preds = self.merge_boosted_preds(all_boost_preds, all_boost_train_preds, train_labels, expt_log, expt_log_train) 292 | # Get accuracies across all boost sets 293 | individual_accuracies = [] 294 | for i in range(len(all_boost_preds[0])): 295 | individual_accuracies.append(classification_report(labels, [p[i] for p in all_boost_preds], output_dict=True)["accuracy"]) 296 | report = classification_report(labels, preds, output_dict=True) 297 | return expt_log, expt_log_train, report["accuracy"], individual_accuracies 298 | 299 | def _run_decomp_single_data(self, test_data, boost_dfs, manifest, overwrite_manifest, run_limit=-1): 300 | expt_log = {} 301 | all_boost_preds = [] 302 | labels = [] 303 | 304 | for i, (ind, row) in tqdm( 305 | enumerate(test_data.iterrows()), total=len(test_data) 306 | ): 307 | suffix = "True, False, or Neither?" 
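            # Each P3 "GPT-3 style" input packs the premise and hypothesis into one
            # string ending in the suffix above; the splits below peel off the premise
            # (everything before "\nQuestion: ") and strip the suffix to recover the
            # bare statement that get_question rewrites as a yes/no question.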
308 | input = row["inputs_pretokenized"] 309 | passage = input.split("\nQuestion: ")[0] 310 | statement = ( 311 | input.split("\nQuestion: ")[-1] 312 | .replace(suffix, "") 313 | .replace('"', "") 314 | .strip() 315 | ) 316 | 317 | if i == run_limit: 318 | break 319 | 320 | gold = row["targets_pretokenized"] 321 | gold = gold.lower() 322 | prompts_across_boost = [] 323 | preds_across_boost = [] 324 | for boost_examples in boost_dfs: 325 | all_prompts = [] 326 | question, question_final_prompt = self.get_question( 327 | statement, questioner, boost_examples[0], manifest, overwrite_manifest 328 | ) 329 | open_answer, answer_final_prompt = self.open_qa( 330 | question, passage, openended_qa, boost_examples[1], manifest, overwrite_manifest 331 | ) 332 | all_prompts.append(question_final_prompt) 333 | all_prompts.append(answer_final_prompt) 334 | if i == 0: 335 | print("\n".join(all_prompts)) 336 | if "Yes" in open_answer.split(): 337 | answer = "True" 338 | elif "No" in open_answer.split(): 339 | answer = "False" 340 | else: 341 | answer = "Neither" 342 | 343 | answer = answer.lower() 344 | 345 | is_yes = "yes" in answer.split() or "true" in answer.split() 346 | is_no = "no" in answer.split() or "false" in answer.split() 347 | is_maybe = "neither" in answer.split() or "maybe" in answer.split() or "unknown" in answer.split() 348 | pred = "neither" 349 | if is_yes and (not is_maybe and not is_no): 350 | pred = "true" 351 | if is_no and (not is_maybe and not is_yes): 352 | pred = "false" 353 | if is_maybe and (not is_no and not is_yes): 354 | pred = "neither" 355 | prompts_across_boost.append(all_prompts) 356 | preds_across_boost.append(pred) 357 | 358 | entry = { 359 | "ind": ind, 360 | "prompts": prompts_across_boost, 361 | "preds_boost": preds_across_boost, 362 | "example": input, 363 | "gold": gold, 364 | } 365 | expt_log[ind] = entry 366 | all_boost_preds.append(preds_across_boost) 367 | labels.append(gold) 368 | return expt_log, all_boost_preds, labels 369 | 370 | 371 | def main(): 372 | args = get_args() 373 | args.num_boost = 3 374 | task_name = "super_glue_cb" 375 | data_dir = f"{DATA_DIR}/P3/data_feather/super_glue_cb_GPT_3_style/" 376 | decomp = CBDecomp(task_name, data_dir) 377 | decomp.run(args) 378 | 379 | 380 | if __name__ == "__main__": 381 | main() 382 | -------------------------------------------------------------------------------- /tasks/DBPedia_final.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | import pandas as pd 4 | import numpy as np 5 | 6 | from tqdm.auto import tqdm 7 | 8 | from sklearn.metrics import classification_report 9 | from decomposition import Decomposition, get_args, DATA_DIR 10 | from utils import get_response, InputOutputPrompt 11 | 12 | summarize = InputOutputPrompt( 13 | input_formatter=lambda x: f"Passage: {x['passage']}", 14 | output_formatter=lambda x: f"Summarize: the passage \"Passage\": {x['summary']}", 15 | required_keys=["passage", "summary"], 16 | input_output_sep="\n", 17 | example_sep="\n\n", 18 | instruction="Summarize the passage.\n\n\"Categories\":\n- company\n- educational institution\n- artist\n- athlete\n- office holder\n- mean of transportation\n- building\n- natural place\n- village\n- animal\n- plant\n- album\n- film \n- written work\n\n" 19 | ) 20 | 21 | summarize_examples = [ 22 | pd.DataFrame([ 23 | { 24 | "passage": "Personality and Mental Health - Personality and Mental Health: Multidisciplinary Studies from Personality Dysfunction to Criminal 
Behaviour is a quarterly peer-reviewed academic journal published by Wiley-Blackwell on behalf of the Centre for Health and Justice.", 25 | "summary": "The passage is about a journal." 26 | }, 27 | { 28 | "passage": "RNLB Mona (ON 775) - RNLB Mona (ON 775) was a Watson Class lifeboat based at Broughty Ferry in Scotland that capsized during a rescue attempt with the loss of her entire crew of eight men. The Mona was built in 1935 and in her time saved 118 lives.", 29 | "summary": "The passage is about a lifeboat." 30 | }, 31 | { 32 | "passage": "Sayonara mo Ienakatta Natsu - Sayonara mo Ienakatta Natsu (さよならも言えなかった夏) is an album by Mikuni Shimokawa released on July 4 2007 by Pony Canyon.This album consists of eleven songs; several new songs and some songs which were previously released as singles.", 33 | "summary": "The passage is about a album." 34 | } 35 | ]), 36 | pd.DataFrame([ 37 | { 38 | "passage": "Personality and Mental Health - Personality and Mental Health: Multidisciplinary Studies from Personality Dysfunction to Criminal Behaviour is a quarterly peer-reviewed academic journal published by Wiley-Blackwell on behalf of the Centre for Health and Justice.", 39 | "summary": "The passage is about a journal." 40 | }, 41 | { 42 | "passage": "Sayonara mo Ienakatta Natsu - Sayonara mo Ienakatta Natsu (さよならも言えなかった夏) is an album by Mikuni Shimokawa released on July 4 2007 by Pony Canyon.This album consists of eleven songs; several new songs and some songs which were previously released as singles.", 43 | "summary": "The passage is about a album." 44 | }, 45 | { 46 | "passage": "RNLB Mona (ON 775) - RNLB Mona (ON 775) was a Watson Class lifeboat based at Broughty Ferry in Scotland that capsized during a rescue attempt with the loss of her entire crew of eight men. The Mona was built in 1935 and in her time saved 118 lives.", 47 | "summary": "The passage is about a lifeboat." 48 | }, 49 | ]), 50 | pd.DataFrame([ 51 | { 52 | "passage": "Sayonara mo Ienakatta Natsu - Sayonara mo Ienakatta Natsu (さよならも言えなかった夏) is an album by Mikuni Shimokawa released on July 4 2007 by Pony Canyon.This album consists of eleven songs; several new songs and some songs which were previously released as singles.", 53 | "summary": "The passage is about a album." 54 | }, 55 | { 56 | "passage": "Personality and Mental Health - Personality and Mental Health: Multidisciplinary Studies from Personality Dysfunction to Criminal Behaviour is a quarterly peer-reviewed academic journal published by Wiley-Blackwell on behalf of the Centre for Health and Justice.", 57 | "summary": "The passage is about a journal." 58 | }, 59 | { 60 | "passage": "RNLB Mona (ON 775) - RNLB Mona (ON 775) was a Watson Class lifeboat based at Broughty Ferry in Scotland that capsized during a rescue attempt with the loss of her entire crew of eight men. The Mona was built in 1935 and in her time saved 118 lives.", 61 | "summary": "The passage is about a lifeboat." 
62 | }, 63 | ]) 64 | ] 65 | 66 | categorize = InputOutputPrompt( 67 | input_formatter=lambda x: f"Passage: {x['passage']}\nSummary: {x['summary']}", 68 | output_formatter=lambda x: f"The summary \"Summary\" fits \"Category\": {x['category']}", 69 | required_keys=["passage", "summary", "category"], 70 | input_output_sep="\n", 71 | example_sep="\n\n", 72 | instruction="Pick one category for the following text.\n\n\"Categories\":\n- company\n- educational institution\n- artist\n- athlete\n- office holder\n- mean of transportation\n- building\n- natural place\n- village\n- animal\n- plant\n- album\n- film\n- written work\n\n" 73 | ) 74 | categorize_examples = [ 75 | pd.DataFrame([ 76 | { 77 | "passage": "Personality and Mental Health - Personality and Mental Health: Multidisciplinary Studies from Personality Dysfunction to Criminal Behaviour is a quarterly peer-reviewed academic journal published by Wiley-Blackwell on behalf of the Centre for Health and Justice.", 78 | "summary": "The passage is about a journal.", 79 | "category": "written work" 80 | }, 81 | { 82 | "passage": "RNLB Mona (ON 775) - RNLB Mona (ON 775) was a Watson Class lifeboat based at Broughty Ferry in Scotland that capsized during a rescue attempt with the loss of her entire crew of eight men. The Mona was built in 1935 and in her time saved 118 lives.", 83 | "summary": "The passage is about a lifeboat.", 84 | "category": "mean of transportation" 85 | }, 86 | { 87 | "passage": "Sayonara mo Ienakatta Natsu - Sayonara mo Ienakatta Natsu (さよならも言えなかった夏) is an album by Mikuni Shimokawa released on July 4 2007 by Pony Canyon.This album consists of eleven songs; several new songs and some songs which were previously released as singles.", 88 | "summary": "The passage is about a album.", 89 | "category": "album" 90 | } 91 | ]), 92 | pd.DataFrame([ 93 | { 94 | "passage": "Personality and Mental Health - Personality and Mental Health: Multidisciplinary Studies from Personality Dysfunction to Criminal Behaviour is a quarterly peer-reviewed academic journal published by Wiley-Blackwell on behalf of the Centre for Health and Justice.", 95 | "summary": "The passage is about a journal.", 96 | "category": "written work" 97 | }, 98 | { 99 | "passage": "Sayonara mo Ienakatta Natsu - Sayonara mo Ienakatta Natsu (さよならも言えなかった夏) is an album by Mikuni Shimokawa released on July 4 2007 by Pony Canyon.This album consists of eleven songs; several new songs and some songs which were previously released as singles.", 100 | "summary": "The passage is about a album.", 101 | "category": "album" 102 | }, 103 | { 104 | "passage": "RNLB Mona (ON 775) - RNLB Mona (ON 775) was a Watson Class lifeboat based at Broughty Ferry in Scotland that capsized during a rescue attempt with the loss of her entire crew of eight men. 
The Mona was built in 1935 and in her time saved 118 lives.", 105 | "summary": "The passage is about a lifeboat.", 106 | "category": "mean of transportation" 107 | }, 108 | ]), 109 | pd.DataFrame([ 110 | { 111 | "passage": "Sayonara mo Ienakatta Natsu - Sayonara mo Ienakatta Natsu (さよならも言えなかった夏) is an album by Mikuni Shimokawa released on July 4 2007 by Pony Canyon. This album consists of eleven songs; several new songs and some songs which were previously released as singles.", 112 | "summary": "The passage is about an album.", 113 | "category": "album" 114 | }, 115 | { 116 | "passage": "Personality and Mental Health - Personality and Mental Health: Multidisciplinary Studies from Personality Dysfunction to Criminal Behaviour is a quarterly peer-reviewed academic journal published by Wiley-Blackwell on behalf of the Centre for Health and Justice.", 117 | "summary": "The passage is about a journal.", 118 | "category": "written work" 119 | }, 120 | { 121 | "passage": "RNLB Mona (ON 775) - RNLB Mona (ON 775) was a Watson Class lifeboat based at Broughty Ferry in Scotland that capsized during a rescue attempt with the loss of her entire crew of eight men. The Mona was built in 1935 and in her time saved 118 lives.", 122 | "summary": "The passage is about a lifeboat.", 123 | "category": "mean of transportation" 124 | }, 125 | ]) 126 | ] 127 | description_zeroshot=""" 128 | Pick the correct category for the passage. 129 | Categories: 130 | - company 131 | - educational institution 132 | - artist 133 | - athlete 134 | - office holder 135 | - mean of transportation 136 | - building 137 | - natural place 138 | - village 139 | - animal 140 | - plant 141 | - album 142 | - film 143 | - written work""" 144 | 145 | class DBPediaDecomp(Decomposition): 146 | def __init__(self, task_name, data_dir, val_split="validation"): 147 | super().__init__(task_name, data_dir, val_split) 148 | 149 | def get_boost_decomp_examples(self, train_data, boost_id): 150 | return [ 151 | summarize_examples[boost_id], 152 | categorize_examples[boost_id], 153 | ] 154 | 155 | def get_few_shot_examples(self, train_data, k_shot): 156 | """Get few shot examples""" 157 | labels = sorted(set(train_data["targets_pretokenized"])) 158 | num_per_class = int(np.ceil(k_shot / len(labels))) 159 | print(f"Selecting {num_per_class} examples per class.") 160 | 161 | dfs = [] 162 | total_in_context = 0 163 | for label in labels: 164 | while num_per_class + total_in_context > k_shot: 165 | num_per_class -= 1 166 | sub_df = train_data[train_data["targets_pretokenized"] == label].sample( 167 | num_per_class 168 | ) 169 | dfs.append(sub_df) 170 | total_in_context += num_per_class 171 | if total_in_context == k_shot: 172 | break 173 | mini_df = pd.concat(dfs) 174 | return mini_df 175 | 176 | def zero_few_baseline( 177 | self, 178 | test_data, 179 | few_shot_df, 180 | manifest, 181 | overwrite_manifest, 182 | do_few_shot=True, 183 | ): 184 | expt_log = {} 185 | preds = [] 186 | labels = [] 187 | 188 | for i, (ind, row) in tqdm( 189 | enumerate(test_data.iterrows()), total=len(test_data) 190 | ): 191 | input = row["inputs_pretokenized"] 192 | body = input.split("written work. ")[-1] 193 | gold = row["targets_pretokenized"].strip().lower() 194 | icl_str = "" 195 | title = description_zeroshot 196 | if do_few_shot: 197 | for s_ind, s_row in few_shot_df.iterrows(): 198 | s_input = s_row.inputs_pretokenized 199 | s_body = s_input.split("written work. ")[-1] 200 | s_title = s_input.split("written work. ")[0] + "written work."
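# Note: the P3 "pick_one_category" template ends its instruction with the last
# category name, "written work. ", so splitting on that phrase separates the
# instruction header (kept in s_title) from the passage body (kept in s_body).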
201 | s_output = s_row.targets_pretokenized.strip() 202 | icl_str += f"Passage: {s_body}\nCategory: {s_output}\n\n" 203 | 204 | icl_str = f"{title}\n\n{icl_str}" 205 | prompt = f"{icl_str}Passage: {{body:}}\nCategory:" 206 | pmp = prompt.format(body=body) 207 | if i == 0: 208 | print(pmp) 209 | 210 | output = get_response( 211 | pmp, 212 | manifest, 213 | overwrite=bool(overwrite_manifest), 214 | max_toks=10, 215 | stop_token="\n\n", 216 | ) 217 | pred = output.strip().lower() 218 | entry = { 219 | "ind": ind, 220 | "example": input, 221 | "base_prompt": pmp, 222 | "pred": pred, 223 | "gold": gold, 224 | } 225 | expt_log[ind] = entry 226 | 227 | preds.append(pred) 228 | labels.append(gold) 229 | 230 | report = classification_report(labels, preds, output_dict=True) 231 | return expt_log, report["accuracy"] 232 | 233 | def run_decomposed_prompt( 234 | self, test_data, boost_data_train, boost_dfs, manifest, overwrite_manifest 235 | ): 236 | expt_log, all_boost_preds, labels = self._run_decomp_single_data(test_data, boost_dfs, manifest, overwrite_manifest) 237 | expt_log_train, all_boost_train_preds, train_labels = self._run_decomp_single_data(boost_data_train, boost_dfs, manifest, overwrite_manifest, run_limit=1) 238 | # Do WS 239 | preds = self.merge_boosted_preds(all_boost_preds, all_boost_train_preds, train_labels, expt_log, expt_log_train) 240 | # Get accuracies across all boost sets 241 | individual_accuracies = [] 242 | for i in range(len(all_boost_preds[0])): 243 | individual_accuracies.append(classification_report(labels, [p[i] for p in all_boost_preds], output_dict=True)["accuracy"]) 244 | report = classification_report(labels, preds, output_dict=True) 245 | return expt_log, expt_log_train, report["accuracy"], individual_accuracies 246 | 247 | def _run_decomp_single_data(self, test_data, boost_dfs, manifest, overwrite_manifest, run_limit=-1): 248 | expt_log = {} 249 | all_boost_preds = [] 250 | labels = [] 251 | 252 | for i, (ind, row) in tqdm( 253 | enumerate(test_data.iterrows()), total=len(test_data) 254 | ): 255 | if i == run_limit: 256 | break 257 | 258 | text = ( 259 | row.inputs_pretokenized.strip("\n").split("written work.")[-1].strip() 260 | ) 261 | gold = row.targets_pretokenized.strip().lower() 262 | 263 | prompts_across_boost = [] 264 | preds_across_boost = [] 265 | for boost_examples in boost_dfs: 266 | all_prompts = [] 267 | prompt_suffix = summarize(boost_examples[0]) 268 | summarize_prompt = f"{prompt_suffix}\n\nPassage: {{text:}}\nSummarize: the passage \"Passage\":" 269 | summarize_pmp = summarize_prompt.format(text=text) 270 | output = get_response( 271 | summarize_pmp, 272 | manifest, 273 | overwrite=bool(overwrite_manifest), 274 | max_toks=50, 275 | ) 276 | summary = output.split("\n")[0].split(":")[-1].strip("\n") 277 | prompt_suffix = categorize(boost_examples[1]) 278 | category_prompt = f"{prompt_suffix}\n\nPassage: {{text:}}\nSummary: {{summary:}}\nThe summary \"Summary\" fits \"Category\":" 279 | category_pmp = category_prompt.format(text=text, summary=summary) 280 | output = get_response( 281 | category_pmp, 282 | manifest, 283 | overwrite=bool(overwrite_manifest), 284 | max_toks=15, 285 | ) 286 | pred = output.split("\n")[0].strip().lower() 287 | 288 | all_prompts.append(summarize_pmp) 289 | all_prompts.append(category_pmp) 290 | if i == 0: 291 | print(summarize_pmp) 292 | print(category_pmp) 293 | prompts_across_boost.append(all_prompts) 294 | preds_across_boost.append(pred) 295 | 296 | entry = { 297 | "ind": ind, 298 | "example": text, 299 | "prompts": 
prompts_across_boost, 300 | "preds_boost": preds_across_boost, 301 | "gold": gold, 302 | } 303 | expt_log[ind] = entry 304 | all_boost_preds.append(preds_across_boost) 305 | labels.append(gold) 306 | return expt_log, all_boost_preds, labels 307 | 308 | 309 | def main(): 310 | args = get_args() 311 | task_name = "dbpedia" 312 | data_dir = ( 313 | f"{DATA_DIR}/P3/data_feather/dbpedia_14_pick_one_category_for_the_following_text" 314 | ) 315 | decomp = DBPediaDecomp(task_name, data_dir, val_split="test") 316 | decomp.run(args) 317 | 318 | 319 | if __name__ == "__main__": 320 | main() 321 | -------------------------------------------------------------------------------- /tasks/NQ_final.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | import pandas as pd 4 | import numpy as np 5 | import ast 6 | 7 | from tqdm.auto import tqdm 8 | from pathlib import Path 9 | from nltk.corpus import stopwords 10 | 11 | stops = set(stopwords.words("english")) 12 | from decomposition import Decomposition, get_args, DATA_DIR 13 | from utils import get_response, InputOutputPrompt, accuracy_span_overlap 14 | 15 | extract = InputOutputPrompt( 16 | input_formatter=lambda x: f"Question: {x['question']}", 17 | output_formatter=lambda x: f"Answer: {x['answer']}", 18 | required_keys=["question", "answer"], 19 | input_output_sep="\n", 20 | example_sep="\n\n", 21 | instruction="Produce distinct questions.\n\n" 22 | ) 23 | 24 | more_info_examples = [ 25 | pd.DataFrame([ 26 | { 27 | "question": "who plays Carrie Bradshaw in sex and the city?", 28 | "answer": "Caroline \"Carrie\" Bradshaw is a fictional character from the HBO franchise Sex and the City, portrayed by Sarah Jessica Parker." 29 | }, 30 | { 31 | "question": "what are the elements in air?", 32 | "answer": "By mole fraction (i.e., by number of molecules), dry air contains 78.08% nitrogen, 20.95% oxygen, 0.93% argon, 0.04% carbon dioxide, and small amounts of other gases" 33 | }, 34 | { 35 | "question": "what is HP company?", 36 | "answer": "HP Inc. is an American multinational information technology company headquartered in Palo Alto, California, that develops personal computers (PCs)" 37 | }, 38 | { 39 | "question": "when was the last season of FRIENDS released?", 40 | "answer": "The series finale aired on May 6, 2004, and was watched by around 52.5 million American viewers, making it the fifth-most-watched series finale in television history" 41 | } 42 | ]), 43 | 44 | ] 45 | 46 | answer = InputOutputPrompt( 47 | input_formatter=lambda x: f"Context: {x['context']}\nQuestion: {x['question']}", 48 | output_formatter=lambda x: f"Answer: {x['answer']}", 49 | required_keys=["context", "question", "answer"], 50 | input_output_sep="\n", 51 | example_sep="\n\n", 52 | instruction="Answer the question.\n\n" 53 | ) 54 | 55 | answer_question = [ 56 | pd.DataFrame([ 57 | { 58 | 'context': 'The nearest airport to Palm Springs is Indio/Palm Springs (PSP) Airport which is 2.1 miles away. 
', 59 | 'question': 'what airport is closest to palm springs?', 60 | 'answer': 'Palm Springs International Airport' 61 | }, 62 | { 63 | 'context': 'Martin Luther King earned his Bachelor of Divinity degree from Crozer Theological Seminary, followed by a doctorate in Systematic Theology from Boston University.', 64 | 'question': 'what degree did martin luther king get?', 65 | 'answer': 'Bachelor of Divinity' 66 | }, 67 | { 68 | 'context': 'The Niger river runs in a crescent through Libya, Mali, Niger, on the border with Benin and then through Nigeria.', 69 | 'question': 'what countries does the niger river flow through?', 70 | 'answer': 'Libya' 71 | }, 72 | { 73 | 'context': 'Puerto Rico is a territory of the United States and uses the U.S. dollar. ', 74 | 'question': 'what type of currency is used in puerto rico?', 75 | 'answer': 'United States dollar' 76 | }, 77 | { 78 | 'context': 'kitt was voice most often by William daniels.', 79 | 'question': 'who played kitt in knight rider?', 80 | 'answer': 'William Daniels' 81 | } 82 | ]), 83 | pd.DataFrame([ 84 | { 85 | 'context': 'leonardo da vinci invented the parachute, the helicopter, double hull, an armored fighting vehicle,', 86 | 'question': 'what inventions did leonardo da vinci made?', 87 | 'answer': 'Double hull' 88 | }, 89 | { 90 | 'context': "The French franc (F) was the national currency of France prior to France's adoption of the euro (EUR) in January 2002.", 91 | 'question': 'what currency is used in france before euro?', 92 | 'answer': 'French franc' 93 | }, 94 | { 95 | 'context': 'The Isthmus of Panama, contains the country of Panama and the panama canal.', 96 | 'question': 'where is isthmus of panama located?', 97 | 'answer': 'Costa Rica' 98 | }, 99 | { 100 | 'context': 'Hurricane Irene was a large and destructive tropical cyclone which affected much of the Caribbean and East Coast', 101 | 'question': 'where did hurricane irene?', 102 | 'answer': 'Eastern United States' 103 | }, 104 | { 105 | 'context': 'Rihanna acted in This is the End and Battleship.', 106 | 'question': 'what movie did rihanna play in?', 107 | 'answer': 'This Is the End' 108 | } 109 | ]), 110 | pd.DataFrame([ 111 | { 112 | 'context': 'st vincent de paul is buried in the 6th arrondisment of Paris.', 113 | 'question': 'where is st vincent de paul buried?', 114 | 'answer': 'Paris' 115 | }, 116 | { 117 | 'context': 'Thomas Luther "Luke" Bryan (born July 17, 1976) is an American country singer and songwriter from Leesburg.', 118 | 'question': 'where is luke bryan from?', 119 | 'answer': 'Leesburg' 120 | }, 121 | { 122 | 'context': "Klum and Seal got married on 10 May 2005 on a beach in Mexico near Seal's home on Costa Careyes. 
", 123 | 'question': 'where did heidi klum and seal get married?', 124 | 'answer': 'Mexico'}, 125 | { 126 | 'context': 'Tarantino starred in pulp fiction, grindhouse and others.', 127 | 'question': 'what movies did quentin tarantino star in?', 128 | 'answer': 'Grindhouse' 129 | }, 130 | { 131 | 'context': 'Countries that are sometimes considered to be entirely or partially part of the Balkans are Croatia, Serbia, Lake Prespa.', 132 | 'question': 'what country is located in the balkan peninsula?', 133 | 'answer': 'Lake Prespa' 134 | } 135 | ]) 136 | 137 | ] 138 | 139 | prefix_select_zeroshot = """Answer the question.\n\n""" 140 | 141 | 142 | class NQDecomp(Decomposition): 143 | def __init__(self, task_name, data_dir, val_split="validation"): 144 | super().__init__(task_name, data_dir, val_split) 145 | 146 | def get_boost_decomp_examples(self, train_data, boost_id): 147 | return [ 148 | more_info_examples[0], 149 | answer_question[boost_id], 150 | ] 151 | 152 | def read_data(self, save_dir, overwrite_data): 153 | save_data = Path(f"{save_dir}/{self.task_name}/data.feather") 154 | if not save_data.exists() or overwrite_data: 155 | test_data = pd.read_csv(f"{self.data_dir}/{self.task_name}/nq-test.qa.csv", sep="\t", header=None) 156 | test_data.to_feather(f"{save_data}") 157 | else: 158 | print(f"Reading test data from {save_data}") 159 | test_data = pd.read_feather(save_data) 160 | 161 | save_data = Path(f"{self.data_dir}/{self.task_name}/train_data.feather") 162 | if not save_data.exists() or overwrite_data: 163 | train_data = pd.read_json(f"{self.data_dir}/{self.task_name}/biencoder-nq-dev.json") 164 | train_data.to_feather(f"{save_data}") 165 | else: 166 | print(f"Reading train data from {save_data}") 167 | train_data = pd.read_feather(save_data) 168 | print(f"Test Data Size: {len(test_data)}") 169 | print(f"Train Data Size: {len(train_data)}") 170 | return test_data, train_data 171 | 172 | def get_few_shot_examples(self, train_data, k_shot): 173 | """Get few shot examples""" 174 | labels = train_data.question.tolist() 175 | 176 | num_per_class = int(np.ceil(k_shot / len(labels))) 177 | print(f"Selecting {num_per_class} examples per class.") 178 | 179 | dfs = [] 180 | total_in_context = 0 181 | for label in labels: 182 | while num_per_class + total_in_context > k_shot: 183 | num_per_class -= 1 184 | sub_df = train_data[train_data["question"] == label].sample(num_per_class) 185 | dfs.append(sub_df) 186 | total_in_context += num_per_class 187 | if total_in_context == k_shot: 188 | break 189 | mini_df = pd.concat(dfs) 190 | return mini_df 191 | 192 | def zero_few_baseline( 193 | self, 194 | test_data, 195 | few_shot_df, 196 | manifest, 197 | overwrite_manifest, 198 | prompt_suffix="", 199 | do_few_shot=True, 200 | ): 201 | expt_log = {} 202 | preds = [] 203 | labels = [] 204 | 205 | for i, (ind, row) in tqdm( 206 | enumerate(test_data.iterrows()), total=len(test_data) 207 | ): 208 | question = row.question 209 | if isinstance(row.answers, str): 210 | label = ast.literal_eval(row.answers) 211 | else: 212 | label = row.answers.tolist() 213 | gold = label 214 | 215 | icl_str = "" 216 | if do_few_shot: 217 | for s_ind, s_row in few_shot_df.iterrows(): 218 | input = s_row.question 219 | s_question = s_row.question 220 | if isinstance(s_row.answers, str): 221 | label = ast.literal_eval(s_row.answers) 222 | else: 223 | label = s_row.answers.tolist() 224 | icl_str += f"Question: {s_question}\nAnswer: {label[0]}\n\n" 225 | 226 | prompt = ( 227 | icl_str 228 | + "Question: {question:}" 229 | + 
prompt_suffix 230 | + "\nAnswer:" 231 | ) 232 | 233 | if i == 0: 234 | print(prompt.format(question=question)) 235 | prompt = prompt.format(question=question) 236 | raw_answer = get_response( 237 | prompt, 238 | manifest, 239 | overwrite=bool(overwrite_manifest), 240 | max_toks=20, 241 | stop_token="\n\n", 242 | ) 243 | pred = raw_answer.strip("\n").strip().lower() 244 | entry = { 245 | "ind": ind, 246 | "example": question, 247 | "base_prompt": prompt, 248 | "raw_answer": raw_answer, 249 | "pred": pred, 250 | "gold": gold, 251 | } 252 | expt_log[ind] = entry 253 | 254 | preds.append([pred]) 255 | labels.append(gold) 256 | metric = accuracy_span_overlap(preds=preds, golds=labels) 257 | return expt_log, metric 258 | 259 | def run_decomposed_prompt( 260 | self, test_data, boost_data_train, boost_dfs, manifest, overwrite_manifest 261 | ): 262 | expt_log, all_boost_preds, labels = self._run_decomp_single_data(test_data, boost_dfs, manifest, overwrite_manifest) 263 | expt_log_train, all_boost_train_preds, train_labels = self._run_decomp_single_data(boost_data_train, boost_dfs, manifest, overwrite_manifest, run_limit=1000) 264 | 265 | # Do WS 266 | boost_test, boost_train = [], [] 267 | for p in all_boost_preds: 268 | samples = [lf[1] for lf in p] 269 | boost_test.append(samples) 270 | for p in all_boost_train_preds: 271 | samples = [lf[1] for lf in p] 272 | boost_train.append(samples) 273 | 274 | preds = self.merge_boosted_preds(boost_test, boost_train, train_labels, expt_log, expt_log_train) 275 | preds = [(x,y) for x,y in zip([p[0][0] for p in all_boost_preds], preds)] 276 | 277 | # Get accuracies across all boost sets 278 | individual_accuracies = [] 279 | for i in range(len(all_boost_preds[0])): 280 | individual_accuracies.append(accuracy_span_overlap(preds=[p[i] for p in all_boost_preds], golds=labels)) 281 | 282 | metric = accuracy_span_overlap(preds=preds, golds=labels) 283 | return expt_log, expt_log_train, metric, individual_accuracies 284 | 285 | def _run_decomp_single_data(self, test_data, boost_dfs, manifest, overwrite_manifest, run_limit=-1): 286 | expt_log = {} 287 | all_boost_preds = [] 288 | labels = [] 289 | 290 | 291 | for i, (ind, row) in tqdm( 292 | enumerate(test_data.iterrows()), total=len(test_data) 293 | ): 294 | if i == run_limit: 295 | break 296 | 297 | question = row.question 298 | if isinstance(row.answers, str): 299 | label = ast.literal_eval(row.answers) 300 | else: 301 | label = row.answers.tolist() 302 | 303 | gold = label 304 | prompts_across_boost = [] 305 | preds_across_boost = [] 306 | 307 | # extract context 308 | prompt_suffix = extract(boost_dfs[0][0]) 309 | prompt = ( 310 | prompt_suffix + "\n\nQuestion: {question:}\nAnswer:" 311 | ) 312 | more_info_answer = get_response( 313 | prompt.format(question=question), 314 | manifest, 315 | overwrite=bool(overwrite_manifest), 316 | max_toks=20, 317 | stop_token="\n\n", 318 | ) 319 | 320 | for boost_examples in boost_dfs: 321 | all_prompts = [] 322 | prompt_suffix = answer(boost_examples[1]) 323 | prompt = ( 324 | prompt_suffix + "\n\nContext: {text:}\nQuestion: {question:}\nAnswer:" 325 | ) 326 | all_prompts.append(prompt) 327 | if i == 0: 328 | print(prompt.format(text=more_info_answer, question=question)) 329 | 330 | raw_answer = get_response( 331 | prompt.format(text=more_info_answer, question=question), 332 | manifest, 333 | overwrite=bool(overwrite_manifest), 334 | max_toks=20, 335 | stop_token="\n\n", 336 | ) 337 | pred = raw_answer.split("\n")[0].strip().lower() 338 | 
prompts_across_boost.append(all_prompts) 339 | preds_across_boost.append((more_info_answer, pred)) 340 | 341 | entry = { 342 | "ind": ind, 343 | "example": question, 344 | "prompts": prompts_across_boost, 345 | "preds_boost": preds_across_boost, 346 | "gold": gold, 347 | } 348 | expt_log[ind] = entry 349 | all_boost_preds.append(preds_across_boost) 350 | labels.append(gold) 351 | return expt_log, all_boost_preds, labels 352 | 353 | 354 | def main(): 355 | args = get_args() 356 | task_name = "NQ" 357 | data_dir = f"{DATA_DIR}/NQ" 358 | webq = NQDecomp(task_name, data_dir) 359 | webq.run(args) 360 | 361 | 362 | if __name__ == "__main__": 363 | main() 364 | -------------------------------------------------------------------------------- /tasks/SST2_final.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | from tqdm.auto import tqdm 4 | from pathlib import Path 5 | import pandas as pd 6 | import numpy as np 7 | import random 8 | 9 | from sklearn.metrics import classification_report 10 | from decomposition import Decomposition, get_args, DATA_DIR 11 | from utils import get_response 12 | 13 | def format_data(lines): 14 | """from lines in dataset to two lists of sentences and labels respectively""" 15 | def process_raw_data_sst(lines): 16 | labels = [] 17 | sentences = [] 18 | for line in lines: 19 | labels.append(int(line[0])) 20 | sentences.append(line[2:].strip()) 21 | return sentences, labels 22 | 23 | train_sentences, train_labels = process_raw_data_sst(lines) 24 | return train_sentences, train_labels 25 | 26 | 27 | class SST2Decomp(Decomposition): 28 | def __init__(self, task_name, data_dir, val_split="validation"): 29 | super().__init__(task_name, data_dir, val_split) 30 | 31 | def get_few_shot_examples(self, train_data, k_shot): 32 | """Get few shot examples""" 33 | labels = [0, 1] 34 | num_per_class = int(np.ceil(k_shot / len(labels))) 35 | print(f"Selecting {num_per_class} examples per class.") 36 | 37 | dfs = [] 38 | total_in_context = 0 39 | for label in labels: 40 | while num_per_class + total_in_context > k_shot: 41 | num_per_class -= 1 42 | sub_df = train_data[train_data["label"] == label].sample(num_per_class) 43 | dfs.append(sub_df) 44 | total_in_context += num_per_class 45 | if total_in_context == k_shot: 46 | break 47 | mini_df = pd.concat(dfs) 48 | return mini_df 49 | 50 | def read_data(self, save_dir, overwrite_data): 51 | save_data = Path(f"{save_dir}/{self.task_name}/data.feather") 52 | if not save_data.exists() or overwrite_data: 53 | with open(f"{self.data_dir}/stsa.binary.test", "r") as f: 54 | test_lines = f.readlines() 55 | test_sentences, test_labels = format_data(test_lines) 56 | test_data = pd.DataFrame({ 57 | 'sentence': test_sentences, 58 | 'label': test_labels, 59 | }) 60 | test_data.to_feather(f"{save_data}") 61 | else: 62 | print(f"Reading test data from {save_data}") 63 | test_data = pd.read_feather(save_data) 64 | 65 | save_data = Path(f"{save_dir}/{self.task_name}/train_data.feather") 66 | if not save_data.exists() or overwrite_data: 67 | with open(f"{self.data_dir}/stsa.binary.train", "r") as f: 68 | train_lines = f.readlines() 69 | train_sentences, train_labels = format_data(train_lines) 70 | train_data = pd.DataFrame({ 71 | 'sentence': train_sentences, 72 | 'label': train_labels, 73 | }) 74 | train_data.to_feather(f"{save_data}") 75 | else: 76 | print(f"Reading train data from {save_data}") 77 | train_data = pd.read_feather(save_data) 78 | print(f"Test Data Size: 
{len(test_data)}") 79 | print(f"Train Data Size: {len(train_data)}") 80 | return test_data, train_data 81 | 82 | 83 | def get_boost_decomp_examples(self, train_data, boost_id): 84 | seed = [1, 2, 3][boost_id] 85 | k_shot = 16 86 | random.seed(seed) 87 | np.random.seed(seed) 88 | 89 | data_train = pd.DataFrame(train_data) 90 | labels = set(data_train["label"]) 91 | num_per_class = int(np.ceil(k_shot / len(labels))) 92 | 93 | dfs = [] 94 | total_in_context = 0 95 | for label in labels: 96 | while num_per_class + total_in_context > k_shot: 97 | num_per_class -= 1 98 | 99 | if seed % 2 == 1: 100 | sub_df = data_train[data_train["label"] == label].sample(num_per_class, random_state=seed) 101 | elif seed % 2 == 0: 102 | sub_df = data_train[data_train["label"] != label].sample(num_per_class, random_state=seed) 103 | dfs.append(sub_df) 104 | total_in_context += num_per_class 105 | if total_in_context == k_shot: 106 | break 107 | 108 | booster_df = pd.concat(dfs).sample(frac=1, random_state=0) 109 | print(f"Selected: {len(booster_df)} in context examples.") 110 | return [ 111 | booster_df 112 | ] 113 | 114 | def zero_few_baseline( 115 | self, 116 | test_data, 117 | few_shot_df, 118 | manifest, 119 | overwrite_manifest, 120 | do_few_shot=True, 121 | ): 122 | expt_log = {} 123 | preds = [] 124 | labels = [] 125 | 126 | for i, (ind, row) in tqdm( 127 | enumerate(test_data.iterrows()), total=len(test_data) 128 | ): 129 | if ind in expt_log: 130 | pred = entry["pred"] 131 | gold_str = entry["gold"] 132 | else: 133 | sentence = row["sentence"] 134 | label = row["label"] 135 | 136 | icl_str = "" 137 | if do_few_shot: 138 | for s_ind, s_row in few_shot_df.iterrows(): 139 | if s_row["label"] == 0: 140 | demo_label = "negative" 141 | else: 142 | demo_label = "positive" 143 | icl = f"Text: {s_row['sentence']}\nSentiment: {demo_label}" 144 | icl_str += f"{icl}\n\n" 145 | 146 | description = "For each snippet of text, label the sentiment of the text as positive or negative." 
147 | prompt = f"{description}\n\n{icl_str}Text: {{sentence:}}\nSentiment:" 148 | pmp = prompt.format(sentence=sentence) 149 | if i == 0: 150 | print(pmp) 151 | 152 | pred = get_response( 153 | pmp, 154 | manifest, 155 | overwrite=bool(overwrite_manifest), 156 | max_toks=50, 157 | ) 158 | pred = ( 159 | pred.replace(".", "") 160 | .replace(",", "") 161 | .replace("Label: ", "") 162 | .replace("Sentiment: ", "") 163 | ) 164 | pred = [p for p in pred.split("\n") if p] 165 | is_pos = "positive" in pred 166 | is_neg = "negative" in pred 167 | if is_pos and not is_neg: 168 | pred = "positive" 169 | elif is_neg and not is_pos: 170 | pred = "negative" 171 | else: 172 | pred = "" 173 | 174 | if label == 1: 175 | gold_str = "positive" 176 | else: 177 | gold_str = "negative" 178 | 179 | entry = { 180 | "gold": gold_str, 181 | "pred": pred, 182 | "base_prompt": pmp, 183 | "ind": ind, 184 | "example": sentence, 185 | } 186 | expt_log[ind] = entry 187 | 188 | preds.append(pred) 189 | labels.append(gold_str) 190 | 191 | report = classification_report(labels, preds, output_dict=True) 192 | return expt_log, report["accuracy"] 193 | 194 | def run_decomposed_prompt( 195 | self, test_data, boost_data_train, boost_dfs, manifest, overwrite_manifest 196 | ): 197 | expt_log, all_boost_preds, labels = self._run_decomp_single_data(test_data, boost_dfs, manifest, overwrite_manifest, run_limit=-1) 198 | expt_log_train, all_boost_train_preds, train_labels = self._run_decomp_single_data(boost_data_train, boost_dfs, manifest, overwrite_manifest, run_limit=1000) 199 | # Do WS 200 | preds = self.merge_boosted_preds(all_boost_preds, all_boost_train_preds, train_labels, expt_log, expt_log_train) 201 | # Get accuracies across all boost sets 202 | individual_accuracies = [] 203 | for i in range(len(all_boost_preds[0])): 204 | individual_accuracies.append(classification_report(labels, [p[i] for p in all_boost_preds], output_dict=True)["accuracy"]) 205 | report = classification_report(labels, preds, output_dict=True) 206 | return expt_log, expt_log_train, report["accuracy"], individual_accuracies 207 | 208 | def _run_decomp_single_data(self, test_data, boost_dfs, manifest, overwrite_manifest, run_limit=-1): 209 | expt_log = {} 210 | all_boost_preds = [] 211 | labels = [] 212 | 213 | for i, (ind, row) in tqdm( 214 | enumerate(test_data.iterrows()), total=len(test_data) 215 | ): 216 | sentence = row["sentence"] 217 | label = row["label"] 218 | 219 | if i == run_limit: 220 | break 221 | 222 | prompts_across_boost = [] 223 | preds_across_boost = [] 224 | for boost_examples in boost_dfs: 225 | all_prompts = [] 226 | icl_str = "" 227 | for s_ind, s_row in boost_examples[0].iterrows(): 228 | if s_row["label"] == 0: 229 | demo_label = "negative" 230 | else: 231 | demo_label = "positive" 232 | icl = f"Text: {s_row['sentence']}\nSentiment: {demo_label}" 233 | icl_str += f"{icl}\n\n" 234 | 235 | description = "For each snippet of text, label the sentiment of the text as positive or negative." 
236 | prompt = f"{description}\n\n{icl_str}Text: {{sentence:}}\nSentiment:" 237 | pmp = prompt.format(sentence=sentence) 238 | all_prompts.append(pmp) 239 | if i == 0: 240 | print(pmp) 241 | pred = get_response( 242 | pmp, 243 | manifest, 244 | overwrite=bool(overwrite_manifest), 245 | max_toks=5, 246 | ) 247 | pred = pred.replace(".", "").replace(",", "").replace("Label: ", "") 248 | pred = [p for p in pred.split("\n") if p] 249 | if pred: 250 | pred = pred[0] 251 | else: 252 | pred = "" 253 | prompts_across_boost.append(all_prompts) 254 | preds_across_boost.append(pred) 255 | 256 | if label == 1: 257 | gold_str = "positive" 258 | else: 259 | gold_str = "negative" 260 | 261 | entry = { 262 | "gold": gold_str, 263 | "prompts": prompts_across_boost, 264 | "preds_boost": preds_across_boost, 265 | "example": sentence, 266 | "ind": i, 267 | } 268 | expt_log[i] = entry 269 | all_boost_preds.append(preds_across_boost) 270 | labels.append(gold_str) 271 | 272 | return expt_log, all_boost_preds, labels 273 | 274 | 275 | def main(): 276 | args = get_args() 277 | args.num_boost = 3 278 | task_name = "sst2" 279 | data_dir = f"{DATA_DIR}/sst2/" 280 | if not Path(data_dir).exists(): 281 | raise ValueError( 282 | f"Data directory {data_dir} does not exist. Download AGNews from https://github.com/tonyzhaozh/few-shot-learning.") 283 | decomp = SST2Decomp(task_name, data_dir) 284 | decomp.run(args) 285 | 286 | 287 | if __name__ == "__main__": 288 | main() 289 | -------------------------------------------------------------------------------- /tasks/WIC_final.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | from tqdm.auto import tqdm 4 | import pandas as pd 5 | 6 | from sklearn.metrics import classification_report 7 | from decomposition import Decomposition, get_args, DATA_DIR 8 | from utils import get_response, InputOutputPrompt 9 | 10 | synonym = InputOutputPrompt( 11 | input_formatter=lambda x: f"{x['passage']}", 12 | output_formatter=lambda x: f"- {x['answer']}", 13 | required_keys=["passage", "answer"], 14 | input_output_sep="\n", 15 | example_sep="\n\n", 16 | instruction="Give synonyms of the word in the sentence.\n\n" 17 | ) 18 | synonym_examples = [ 19 | pd.DataFrame([ 20 | { 21 | "passage": "In \"She heard the sound of voices in the hall.\", synonyms for the word \"sound\" are:", 22 | "answer": "noise", 23 | }, 24 | { 25 | "passage": "In \"Enter the secret code.\", synonyms for the word \"code\" are:", 26 | "answer": "password", 27 | }, 28 | { 29 | "passage": "In \"She acted in a play on Broadway\", synonyms for the word \"play\" are:", 30 | "answer": "show", 31 | }, 32 | ]), 33 | pd.DataFrame([ 34 | { 35 | "passage": "In \"She rode around the park on her cycle.\", synonyms for the word \"cycle\" are:", 36 | "answer": "bicycle", 37 | }, 38 | { 39 | "passage": "In \"Don't keep everything bottled up.\", synonyms for the word \"bottled\" are:", 40 | "answer": "trapped inside", 41 | }, 42 | { 43 | "passage": "In \"The present is like no other time.\", synonyms for the word \"present\" are:", 44 | "answer": "current moment", 45 | }, 46 | ]), 47 | pd.DataFrame([ 48 | { 49 | "passage": "In \"The movie was awful.\", synonyms for the word \"aweful\" are:", 50 | "answer": "bad and terrible", 51 | }, 52 | { 53 | "passage": "In \"She is so beautiful.\", synonyms for the word \"beautiful\" are:", 54 | "answer": "pretty and gorgeous", 55 | }, 56 | { 57 | "passage": "In \"It was quite cool out so she wore a jacket\", synonyms for 
the word \"cool\" are:", 58 | "answer": "cold and chilly", 59 | }, 60 | ]), 61 | pd.DataFrame([ 62 | { 63 | "passage": "In \"There are so many flies near the food.\", synonyms for the word \"flies\" are:", 64 | "answer": "bugs", 65 | }, 66 | { 67 | "passage": "In \"Eat your noodles with a fork.\", synonyms for the word \"fork\" are:", 68 | "answer": "utensils", 69 | }, 70 | { 71 | "passage": "In \"She and her husband went on a trip.\", synonyms for the word \"trip\" are:", 72 | "answer": "vacation", 73 | }, 74 | ]), 75 | pd.DataFrame([ 76 | { 77 | "passage": "In \"It was definitely a cry for help.\", synonyms for the word \"cry\" are:", 78 | "answer": "call", 79 | }, 80 | { 81 | "passage": "In \"I watch all my students as they take their exams.\", synonyms for the word \"watch\" are:", 82 | "answer": "look at", 83 | }, 84 | { 85 | "passage": "In \"The beginning of the book was fine, but the end was terrible.\", synonyms for the word \"beginning\" are:", 86 | "answer": "start", 87 | }, 88 | ]) 89 | ] 90 | 91 | description = InputOutputPrompt( 92 | input_formatter=lambda x: f"Choices\n{x['choices']}\nFill the [MASK] with the correct \"Choice\": {x['sentence']}", 93 | output_formatter=lambda x: f"[MASK] is \"Choice\": {x['answer']}\n", 94 | required_keys=["choices", "sentence", "answer"], 95 | input_output_sep="\n", 96 | example_sep="\n\n", 97 | instruction="Select the correct choice for the blank.\n\n" 98 | ) 99 | description_examples = [ 100 | pd.DataFrame([ 101 | { 102 | "choices": "1: noise\n2. in good condition", 103 | "sentence": "She heard the [MASK] of voices in the hall.", 104 | "answer": "noise", 105 | }, 106 | { 107 | "choices": "1. not heavy\n2. sun rays", 108 | "sentence": "The [MASK] shined through the window.", 109 | "answer": "sun rays", 110 | }, 111 | { 112 | "choices": "1. widespread\n2. 
commander of an army", 113 | "sentence": "The book is of [MASK] interest.", 114 | "answer": "widespread", 115 | }, 116 | ]) 117 | ] 118 | 119 | class WICDecomp(Decomposition): 120 | def __init__(self, task_name, data_dir, val_split="validation"): 121 | super().__init__(task_name, data_dir, val_split) 122 | 123 | def get_boost_decomp_examples(self, train_data, boost_id): 124 | return [ 125 | synonym_examples[boost_id], 126 | ] 127 | 128 | def zero_few_baseline( 129 | self, 130 | test_data, 131 | few_shot_df, 132 | manifest, 133 | overwrite_manifest, 134 | do_few_shot=True, 135 | ): 136 | expt_log = {} 137 | preds = [] 138 | labels = [] 139 | 140 | labels_names = set(test_data["targets_pretokenized"]) 141 | labels_names = [l.lower().strip() for l in labels_names] 142 | 143 | for i, (ind, row) in tqdm( 144 | enumerate(test_data.iterrows()), total=len(test_data) 145 | ): 146 | if ind in expt_log: 147 | pred = entry["pred"] 148 | gold = entry["gold"] 149 | else: 150 | text = row["inputs_pretokenized"] 151 | gold = row["targets_pretokenized"] 152 | 153 | icl_str = "" 154 | if do_few_shot: 155 | for s_ind, s_row in few_shot_df.iterrows(): 156 | icl_str += f"{s_row['inputs_pretokenized']} {s_row['targets_pretokenized']}\n\n" 157 | 158 | prompt = f"{icl_str}{{text:}}" 159 | pmp = prompt.format(text=text) 160 | if i == 0: 161 | print(pmp) 162 | 163 | raw_answer = get_response( 164 | pmp, 165 | manifest, 166 | overwrite=bool(overwrite_manifest), 167 | max_toks=30, 168 | ) 169 | answer = raw_answer.strip().lower() 170 | answer = answer.split("\n") 171 | answer = [a for a in answer if a] 172 | answer = [ 173 | a 174 | for a in answer 175 | if any(l.lower() in a.lower() for l in labels_names) 176 | ] 177 | if answer: 178 | answer = answer[0] 179 | answer = "".join( 180 | [a for a in answer if a not in [".", ",", "?", ";", ":", "'", '"']] 181 | ) 182 | 183 | is_yes = "yes" in answer.split() 184 | is_no = "no" in answer.split() 185 | pred = "" 186 | if is_yes and (not is_no): 187 | pred = "yes" 188 | if is_no and (not is_yes): 189 | pred = "no" 190 | 191 | gold = gold.strip().lower() 192 | pred = pred.strip().lower() 193 | entry = { 194 | "ind": ind, 195 | "example": text, 196 | "base_prompt": pmp, 197 | "raw_answer": raw_answer, 198 | "pred": pred, 199 | "gold": gold, 200 | } 201 | expt_log[ind] = entry 202 | 203 | preds.append(pred) 204 | labels.append(gold) 205 | 206 | report = classification_report(labels, preds, output_dict=True) 207 | return expt_log, report["accuracy"] 208 | 209 | def get_parts(self, text): 210 | parts = text.split("\n") 211 | sent1 = parts[0] 212 | sent2 = parts[1] 213 | word = parts[2].split("Question: Is the word ")[-1].split()[0] 214 | return word, sent1, sent2 215 | 216 | def clean_sentence(self, sentence): 217 | sentence = sentence.replace("2. ", "") 218 | sentence = sentence.replace("3. 
", "") 219 | sentence = sentence.replace("\n\n", "") 220 | sentence = sentence.replace("A:", "") 221 | sentence = sentence.strip() 222 | sentence = sentence.split(".")[0] 223 | sentence = sentence.split("\n")[0] 224 | return sentence 225 | 226 | def get_sentences( 227 | self, all_constructors, all_boost_exs, sent, word, manifest, overwrite_manifest 228 | ): 229 | synonym = all_constructors[0] 230 | 231 | all_prompts = [] 232 | # synonyms 233 | prompt_suffix = synonym(all_boost_exs[0]) 234 | prompt_combined = f'{prompt_suffix}\n\nIn "{{sent:}}", synonyms for the word \"{{word:}}\" are: ' 235 | all_prompts.append(prompt_combined.format(sent=sent, word=word)) 236 | synonyms = get_response( 237 | prompt_combined.format(sent=sent, word=word), 238 | manifest, 239 | overwrite=bool(overwrite_manifest), 240 | max_toks=10, 241 | ) 242 | synonyms = synonyms.replace("- ", "").split("\n") 243 | synonyms = ", ".join([a for a in synonyms if a][0:1]) 244 | 245 | # generate sentences 246 | quoted_sent = sent.split() 247 | sent = [] 248 | for tok in quoted_sent: 249 | if tok.lower() == word.strip('"').lower(): 250 | sent.append(f'"{tok}"') 251 | else: 252 | sent.append(tok) 253 | if sent: 254 | sent = " ".join(sent) 255 | else: 256 | sent = " ".join(quoted_sent) 257 | 258 | combined_definition = f"{synonyms}" 259 | sentences = [] 260 | return combined_definition, sentences, all_prompts 261 | 262 | def pairwise_comparisons( 263 | self, 264 | description_constructor, 265 | boost_exs, 266 | def1, 267 | sentences_lst1, 268 | def2, 269 | sentences_lst2, 270 | word, 271 | manifest, 272 | overwrite_manifest, 273 | ): 274 | all_prompts = [] 275 | # reconcile the result 276 | answer = "" 277 | if def1.strip() != def2.strip(): 278 | answer = "No" 279 | else: 280 | answer = "Yes" 281 | return answer, all_prompts 282 | 283 | def run_decomposed_prompt( 284 | self, test_data, boost_data_train, boost_dfs, manifest, overwrite_manifest 285 | ): 286 | expt_log, all_boost_preds, labels = self._run_decomp_single_data(test_data, boost_dfs, manifest, overwrite_manifest) 287 | expt_log_train, all_boost_train_preds, train_labels = self._run_decomp_single_data(boost_data_train, boost_dfs, manifest, overwrite_manifest, run_limit=1000) 288 | # Do WS 289 | preds = self.merge_boosted_preds(all_boost_preds, all_boost_train_preds, train_labels, expt_log, expt_log_train) 290 | # Get accuracies across all boost sets 291 | individual_accuracies = [] 292 | for i in range(len(all_boost_preds[0])): 293 | individual_accuracies.append(classification_report(labels, [p[i] for p in all_boost_preds], output_dict=True)["accuracy"]) 294 | report = classification_report(labels, preds, output_dict=True) 295 | return expt_log, expt_log_train, report["accuracy"], individual_accuracies 296 | 297 | def _run_decomp_single_data(self, test_data, boost_dfs, manifest, overwrite_manifest, run_limit=-1): 298 | expt_log = {} 299 | all_boost_preds = [] 300 | labels = [] 301 | 302 | for i, (ind, row) in tqdm( 303 | enumerate(test_data.iterrows()), total=len(test_data) 304 | ): 305 | text = row["inputs_pretokenized"] 306 | gold = row["targets_pretokenized"] 307 | 308 | word, sent1, sent2 = self.get_parts(text) 309 | 310 | if i == run_limit: 311 | break 312 | 313 | prompts_across_boost = [] 314 | preds_across_boost = [] 315 | for boost_examples in boost_dfs: 316 | all_prompts = [] 317 | 318 | def1, sentences_lst1, lst1_prompts = self.get_sentences( 319 | [synonym], [boost_examples[0]], sent1, word, manifest, overwrite_manifest 320 | ) 321 | def2, sentences_lst2, 
lst2_prompts = self.get_sentences( 322 | [synonym], [boost_examples[0]], sent2, word, manifest, overwrite_manifest 323 | ) 324 | 325 | pred, pred_prompts = self.pairwise_comparisons( 326 | description, 327 | boost_examples[-1], 328 | def1, 329 | sentences_lst1, 330 | def2, 331 | sentences_lst2, 332 | word, 333 | manifest, 334 | overwrite_manifest, 335 | ) 336 | all_prompts = lst1_prompts + lst2_prompts + pred_prompts 337 | if i == 0: 338 | print("\n".join(all_prompts)) 339 | prompts_across_boost.append(all_prompts) 340 | preds_across_boost.append(pred) 341 | entry = { 342 | "ind": ind, 343 | "example": text, 344 | "word": word, 345 | "prompts": prompts_across_boost, 346 | "preds_boost": preds_across_boost, 347 | "sent1": sent1, 348 | "sent2": sent2, 349 | "def1": def1, 350 | "def2": def2, 351 | "gold": gold, 352 | "sentences_lst1": sentences_lst1, 353 | "sentences_lst2": sentences_lst2, 354 | } 355 | expt_log[ind] = entry 356 | all_boost_preds.append(preds_across_boost) 357 | labels.append(gold) 358 | return expt_log, all_boost_preds, labels 359 | 360 | 361 | def main(): 362 | args = get_args() 363 | args.num_boost = 5 364 | task_name = "super_glue_wic" 365 | data_dir = f"{DATA_DIR}/P3/data_feather/super_glue_wic_GPT_3_prompt/" 366 | decomp = WICDecomp(task_name, data_dir) 367 | decomp.run(args) 368 | 369 | 370 | if __name__ == "__main__": 371 | main() 372 | -------------------------------------------------------------------------------- /tasks/decomposition.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | from pathlib import Path 4 | 5 | import argparse 6 | from collections import Counter 7 | import pandas as pd 8 | import json 9 | import numpy as np 10 | import datetime 11 | import os 12 | import random 13 | 14 | from utils import save_log, get_manifest_session 15 | 16 | DATA_DIR = os.environ.get("AMA_DATA", "/home/data") 17 | 18 | def get_args(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--seed", type=int, default=3) 21 | parser.add_argument( 22 | "--num_run", type=int, default=-1, help="Number of rows of test data to run" 23 | ) 24 | parser.add_argument( 25 | "--k_shot", type=int, default=3, help="Number of few shot examples" 26 | ) 27 | parser.add_argument( 28 | "--num_boost", type=int, default=3, help="Number of few shot sets to boost over") 29 | parser.add_argument( 30 | "--boost_train_examples", type=int, default=1000, help="Number of training examples to run through for boosting" 31 | ) 32 | parser.add_argument( 33 | "--output_metrics_file", type=str, default="decomposition_metrics.json", help="Output file for all metrics." 
34 | ) 35 | parser.add_argument( 36 | "--save_dir", type=str, default="/home/final_runs/", help="Directory to save results and logs" 37 | ) 38 | parser.add_argument( 39 | "--run_decomp", 40 | type=int, 41 | default=1, 42 | help="Run decomp", 43 | choices=[0, 1], 44 | ) 45 | parser.add_argument( 46 | "--run_zeroshot", 47 | type=int, 48 | default=1, 49 | help="Run zeroshot", 50 | choices=[0, 1], 51 | ) 52 | parser.add_argument( 53 | "--run_fewshot", 54 | type=int, 55 | default=1, 56 | help="Run fewshot", 57 | choices=[0, 1], 58 | ) 59 | parser.add_argument( 60 | "--run_zeroshot_decomp", 61 | type=int, 62 | default=0, 63 | help="Run zero shot decomp", 64 | choices=[0, 1], 65 | ) 66 | parser.add_argument( 67 | "--overwrite_boost_exs", 68 | type=int, 69 | default=0, 70 | help="Overwrite boost examples", 71 | choices=[0, 1], 72 | ) 73 | parser.add_argument( 74 | "--overwrite_data", 75 | type=int, 76 | default=0, 77 | help="Overwrite saved data examples", 78 | choices=[0, 1], 79 | ) 80 | # Manifest 81 | parser.add_argument( 82 | "--client_name", 83 | type=str, 84 | default="huggingface", 85 | help="Manifest client name", 86 | choices=["huggingface", "openai", "ai21"], 87 | ) 88 | parser.add_argument( 89 | "--client_engine", 90 | type=str, 91 | default=None, 92 | help="Manifest client engine; only used for openai/ai21", 93 | choices=["davinci"], 94 | ) 95 | parser.add_argument( 96 | "--client_connection", 97 | type=str, 98 | default="http://127.0.0.1:5000", 99 | help="Client connection string", 100 | ) 101 | parser.add_argument( 102 | "--cache_connection", 103 | type=str, 104 | default="/home/manifest/final_runs.sqlite", 105 | help="Cache connection string", 106 | ) 107 | parser.add_argument( 108 | "--overwrite_manifest", 109 | type=int, 110 | default=0, 111 | help="Overwrite manifest", 112 | choices=[0, 1], 113 | ) 114 | return parser.parse_args() 115 | 116 | class Decomposition: 117 | def __init__(self, task_name, data_dir, val_split="validation"): 118 | self.task_name = task_name 119 | self.data_dir = data_dir 120 | self.val_split = val_split 121 | 122 | def read_data(self, save_dir, overwrite_data): 123 | save_data = Path(f"{save_dir}/{self.task_name}/data.feather") 124 | if not save_data.exists() or overwrite_data: 125 | test_data = pd.read_feather(f"{self.data_dir}/{self.val_split}.feather") 126 | test_data.to_feather(f"{save_data}") 127 | else: 128 | print(f"Reading test data from {save_data}") 129 | test_data = pd.read_feather(save_data) 130 | 131 | save_data = Path(f"{save_dir}/{self.task_name}/train_data.feather") 132 | if not save_data.exists() or overwrite_data: 133 | train_data = pd.read_feather(f"{self.data_dir}/train.feather") 134 | else: 135 | print(f"Reading train data from {save_data}") 136 | train_data = pd.read_feather(save_data) 137 | print(f"Test Data Size: {len(test_data)}") 138 | print(f"Train Data Size: {len(train_data)}") 139 | return test_data, train_data 140 | 141 | def get_few_shot_examples(self, train_data, k_shot): 142 | """Get few shot examples""" 143 | return train_data.sample(k_shot) 144 | 145 | def get_boost_decomp_examples(self, train_data, boost_i=0): 146 | """Get boost examples""" 147 | raise NotImplementedError() 148 | 149 | def zero_few_baseline( 150 | self, test_data, few_shot_df, manifest, overwrite_manifest, do_few_shot=True 151 | ): 152 | """Zero and few shot baseline""" 153 | raise NotImplementedError() 154 | 155 | def run_decomposed_prompt( 156 | self, test_data, boost_data_train, boost_dfs, manifest, overwrite_manifest 157 | ): 158 | """Decomposition run""" 159 | raise 
NotImplementedError() 160 | 161 | def merge_boosted_preds(self, boosted_preds, all_boost_train_preds, train_labels, exp_log, expt_log_train, indecisive_ans=None): 162 | """Merge boosted preds""" 163 | if isinstance(boosted_preds, list): 164 | boosted_preds = np.array(boosted_preds) 165 | if isinstance(all_boost_train_preds, list): 166 | all_boost_train_preds = np.array(all_boost_train_preds) 167 | if isinstance(train_labels, list): 168 | train_labels = np.array(train_labels) 169 | 170 | uniq = np.unique(boosted_preds) 171 | pred_map = {} 172 | if "yes" in uniq: 173 | pred_map = {"yes": 1, "no": -1, "neither": 0} 174 | elif "true" in uniq: 175 | pred_map = {"true": 1, "false": -1, "neither": 0} 176 | elif "positive" in uniq: 177 | pred_map = {"positive": 1, "negative": -1, "neutral": 0} 178 | pred_map_inv = {v:k for k,v in pred_map.items()} 179 | use_pred_map = False 180 | if all(p.lower() in pred_map for p in uniq): 181 | use_pred_map = True 182 | if use_pred_map: 183 | # Cast to integers 184 | boosted_preds = np.array([[pred_map[p.lower()] for p in preds] for preds in boosted_preds]) 185 | all_boost_train_preds = np.array( 186 | [[pred_map[p.lower()] for p in preds] for preds in all_boost_train_preds] 187 | ) 188 | train_labels = np.array([pred_map[p.lower()] for p in train_labels]) 189 | if indecisive_ans: 190 | indecisive_ans = pred_map[indecisive_ans.lower()] 191 | 192 | # Take majority vote 193 | preds_test = [] 194 | for i, voter_preds in enumerate(boosted_preds): 195 | most_common = Counter(voter_preds).most_common(1)[0] 196 | if indecisive_ans and len(voter_preds) > 1 and most_common[1] == 1: 197 | majority_vote_pred = indecisive_ans 198 | else: 199 | majority_vote_pred = most_common[0] 200 | if use_pred_map: 201 | majority_vote_pred = pred_map_inv[majority_vote_pred] 202 | preds_test.append(majority_vote_pred) 203 | exp_log[i]["pred"] = majority_vote_pred 204 | 205 | # Take majority vote 206 | preds_train = [] 207 | for i, voter_preds in enumerate(all_boost_train_preds): 208 | most_common = Counter(voter_preds).most_common(1)[0] 209 | if indecisive_ans and len(voter_preds) > 1 and most_common[1] == 1: 210 | majority_vote_pred = indecisive_ans 211 | else: 212 | majority_vote_pred = most_common[0] 213 | if use_pred_map: 214 | majority_vote_pred = pred_map_inv[majority_vote_pred] 215 | preds_train.append(majority_vote_pred) 216 | expt_log_train[i]["pred"] = majority_vote_pred 217 | return preds_test 218 | 219 | def run(self, args): 220 | print(json.dumps(vars(args), indent=4)) 221 | 222 | random.seed(args.seed) 223 | np.random.seed(args.seed) 224 | save_path = Path(f"{args.save_dir}/{self.task_name}") 225 | save_path.mkdir(parents=True, exist_ok=True) 226 | data_test, data_train = self.read_data(args.save_dir, bool(args.overwrite_data)) 227 | # Subsample train for boost exps 228 | if args.boost_train_examples >= 0: 229 | boost_data_train = data_train.head(min(len(data_train), args.boost_train_examples)) 230 | else: 231 | boost_data_train = data_train 232 | # Reset indexes for enumerations 233 | boost_data_train = boost_data_train.reset_index(drop=True) 234 | data_test = data_test.reset_index(drop=True) 235 | data_train = data_train.reset_index(drop=True) 236 | num_run = ( 237 | min(args.num_run, len(data_test)) if args.num_run > 0 else len(data_test) 238 | ) 239 | save_results = True 240 | if num_run != len(data_test): 241 | print("Using {} rows".format(num_run)) 242 | data_test = data_test.iloc[:num_run] 243 | save_results = False 244 | 245 | runner, model_name = 
get_manifest_session( 246 | client_name=args.client_name, 247 | client_engine=args.client_engine, 248 | client_connection=args.client_connection, 249 | cache_connection=args.cache_connection, 250 | ) 251 | 252 | model_name = model_name.replace("/", "_") 253 | print("Model name:", model_name) 254 | 255 | # Read in few shot examples 256 | few_shot_path = save_path /f"{args.k_shot}_shot_examples.feather" 257 | if bool(args.overwrite_data) or not few_shot_path.exists(): 258 | mini_df = self.get_few_shot_examples(data_train, args.k_shot) 259 | mini_df.reset_index().to_feather(few_shot_path) 260 | else: 261 | print(f"Reading few shot examples from {few_shot_path}") 262 | mini_df = pd.read_feather(few_shot_path) 263 | 264 | # Read in few shot decomp examples - one data frame per decomp step 265 | boost_examples = [] 266 | for i in range(args.num_boost): 267 | boost_examples_per_step = [] 268 | # Get all steps 269 | boost_examples_paths = list(save_path.glob(f"boost_examples_{i}_step*.feather")) 270 | if bool(args.overwrite_boost_exs) or not boost_examples_paths or not all(p.exists() for p in boost_examples_paths): 271 | boost_df_steps = self.get_boost_decomp_examples(data_train, boost_id=i) 272 | if not isinstance(boost_df_steps, list) or not isinstance( 273 | boost_df_steps[0], pd.DataFrame 274 | ): 275 | raise ValueError("Must return list of dataframes, one per step") 276 | for step, boost_df in enumerate(boost_df_steps): 277 | boost_df.reset_index().to_feather(save_path / f"boost_examples_{i}_step{step}.feather") 278 | print(f"Saving boost examples to", save_path / f"boost_examples_{i}_step{step}.feather") 279 | boost_examples_per_step.append(boost_df) 280 | else: 281 | for boost_examples_p in sorted(boost_examples_paths): 282 | print(f"Reading boost examples from {boost_examples_p}") 283 | boost_examples_per_step.append(pd.read_feather(boost_examples_p)) 284 | boost_examples.append(boost_examples_per_step) 285 | 286 | today = datetime.datetime.today().strftime("%m%d%Y") 287 | 288 | # Default metrics 289 | metric_zero = -1.0 290 | metric_few = -1.0 291 | metric_decomposed = -1.0 292 | metric_decomposed_by_boost = [] 293 | metric_zeroshot_decomposed = -1.0 294 | 295 | if bool(args.run_zeroshot): 296 | # Zero Shot 297 | run_name = f"{model_name}_0shot" 298 | exp_zero, metric_zero = self.zero_few_baseline( 299 | test_data=data_test, 300 | few_shot_df=mini_df, 301 | manifest=runner, 302 | overwrite_manifest=args.overwrite_manifest, 303 | do_few_shot=False, 304 | ) 305 | if save_results: 306 | save_log(self.task_name, run_name, exp_zero, args.save_dir) 307 | 308 | if bool(args.run_fewshot): 309 | # Few Shot 310 | run_name = f"{model_name}_{args.k_shot}shot" 311 | exp_few, metric_few = self.zero_few_baseline( 312 | test_data=data_test, 313 | few_shot_df=mini_df, 314 | manifest=runner, 315 | overwrite_manifest=args.overwrite_manifest, 316 | do_few_shot=True, 317 | ) 318 | if save_results: 319 | save_log(self.task_name, run_name, exp_few, args.save_dir) 320 | 321 | if bool(args.run_decomp): 322 | # Decomp 323 | run_name = f"{model_name}_decomposed_{today}" 324 | exp_decomposed, exp_decomposed_train, metric_decomposed, metric_decomposed_by_boost = self.run_decomposed_prompt( 325 | test_data=data_test, boost_data_train=boost_data_train, boost_dfs=boost_examples, manifest=runner, overwrite_manifest=args.overwrite_manifest 326 | ) 327 | if save_results: 328 | save_log( 329 | self.task_name, 330 | run_name, 331 | exp_decomposed, 332 | args.save_dir 333 | ) 334 | if exp_decomposed_train: 335 | save_log( 336 
| self.task_name, 337 | f"{run_name}_train", 338 | exp_decomposed_train, 339 | args.save_dir 340 | ) 341 | 342 | # Zero shot decomp 343 | exp_zeroshot_decomposed = [] 344 | if bool(args.run_zeroshot_decomp): 345 | run_name = f"{model_name}_decomposed_0shot_{today}" 346 | ( 347 | exp_zeroshot_decomposed, 348 | exp_zeroshot_decomposed_train, 349 | metric_zeroshot_decomposed, 350 | _, 351 | ) = self.run_decomposed_prompt( 352 | test_data=data_test, boost_data_train=boost_data_train, boost_dfs=[[pd.DataFrame() for _ in range(len(boost_examples[0]))]], manifest=runner, overwrite_manifest=args.overwrite_manifest, 353 | ) 354 | if save_results and len(exp_zeroshot_decomposed) > 0: 355 | save_log( 356 | self.task_name, 357 | run_name, 358 | exp_zeroshot_decomposed, 359 | args.save_dir, 360 | ) 361 | if exp_zeroshot_decomposed_train: 362 | save_log( 363 | self.task_name, 364 | f"{run_name}_train", 365 | exp_zeroshot_decomposed_train, 366 | args.save_dir, 367 | ) 368 | print("Accuracy Zero Shot", metric_zero) 369 | print("Accuracy Few Shot", metric_few) 370 | if len(metric_decomposed_by_boost) > 0: 371 | print("Accuracy by Boost Set Decomposed", metric_decomposed_by_boost) 372 | print("Accuracy by Boost Set Decomposed Average", np.mean(metric_decomposed_by_boost)) 373 | print("Accuracy Boost Decomposed", metric_decomposed) 374 | if len(exp_zeroshot_decomposed) > 0: 375 | print("Accuracy Zero Shot Decomposed", metric_zeroshot_decomposed) 376 | 377 | metrics = { 378 | "model_name": model_name, 379 | "task_name": self.task_name, 380 | "today": today, 381 | "zero_shot": metric_zero, 382 | "few_shot": metric_few, 383 | "decomposed": metric_decomposed, 384 | "decomposed_by_boost": metric_decomposed_by_boost, 385 | "decomposed_by_boost_avg": np.mean(metric_decomposed_by_boost), 386 | "zero_shot_decomposed": metric_zeroshot_decomposed, 387 | } 388 | output_metrics = Path(args.output_metrics_file) 389 | output_metrics.parent.mkdir(parents=True, exist_ok=True) 390 | with open(output_metrics, "a") as f: 391 | f.write(json.dumps(metrics) + "\n") 392 | print(f"Saved metrics to {output_metrics}") 393 | print(f"Saved final data to", Path(args.save_dir) / self.task_name) 394 | -------------------------------------------------------------------------------- /tasks/drop_final.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | from pathlib import Path 4 | import pandas as pd 5 | import numpy as np 6 | from tqdm.auto import tqdm 7 | from datasets import load_dataset 8 | 9 | from decomposition import Decomposition, get_args 10 | from utils import get_response, text_f1, InputOutputPrompt, load_hf_data 11 | 12 | extract = InputOutputPrompt( 13 | input_formatter=lambda x: f"Context: {x['context']}\nQuestion: {x['question']}", 14 | output_formatter=lambda x: f"Answer: {x['answer']}", 15 | required_keys=["context", "question", "answer"], 16 | input_output_sep="\n", 17 | example_sep="\n\n", 18 | instruction="Answer the question. 
If there is no evidence in the context, return \"Unknown\".\n\n" 19 | ) 20 | 21 | extract_examples = [ 22 | pd.DataFrame([ 23 | { 24 | "context": "According to Biraben, the plague was present somewhere in Europe in every year between 1346 and 1671", 25 | "question": "Where was the plague present?", 26 | "answer": "somewhere in Europe" 27 | }, 28 | { 29 | "context": "Policies aiming at controlling unemployment and in particular at reducing its inequality-associated effects support economic growth.", 30 | "question": "What's one factor in increasing self-esteem?", 31 | "answer": "Unknown" 32 | }, 33 | { 34 | "context": "The term \"matter\" is used throughout physics in a bewildering variety of contexts: for example, one refers to \"condensed matter physics\", \"elementary matter\", \"partonic\" matter, \"dark\" matter, \"anti\"-matter, \"strange\" matter, and \"nuclear\" matter.", 35 | "question": "What is another name for anti-matter?", 36 | "answer": "Unknown" 37 | } 38 | ]), 39 | pd.DataFrame([ 40 | { 41 | "context": "According to Biraben, the plague was present somewhere in Europe in every year between 1346 and 1671", 42 | "question": "Where was the plague present?", 43 | "answer": "somewhere in Europe" 44 | }, 45 | { 46 | "context": "The term \"matter\" is used throughout physics in a bewildering variety of contexts: for example, one refers to \"condensed matter physics\", \"elementary matter\", \"partonic\" matter, \"dark\" matter, \"anti\"-matter, \"strange\" matter, and \"nuclear\" matter.", 47 | "question": "What is another name for anti-matter?", 48 | "answer": "Unknown" 49 | }, 50 | { 51 | "context": "Policies aiming at controlling unemployment and in particular at reducing its inequality-associated effects support economic growth.", 52 | "question": "What's one factor in increasing self-esteem?", 53 | "answer": "Unknown" 54 | }, 55 | ]), 56 | pd.DataFrame([ 57 | { 58 | "context": "Policies aiming at controlling unemployment and in particular at reducing its inequality-associated effects support economic growth.", 59 | "question": "What's one factor in increasing self-esteem?", 60 | "answer": "Unknown" 61 | }, 62 | { 63 | "context": "According to Biraben, the plague was present somewhere in Europe in every year between 1346 and 1671", 64 | "question": "Where was the plague present?", 65 | "answer": "somewhere in Europe" 66 | }, 67 | { 68 | "context": "The term \"matter\" is used throughout physics in a bewildering variety of contexts: for example, one refers to \"condensed matter physics\", \"elementary matter\", \"partonic\" matter, \"dark\" matter, \"anti\"-matter, \"strange\" matter, and \"nuclear\" matter.", 69 | "question": "What is another name for anti-matter?", 70 | "answer": "Unknown" 71 | } 72 | ]), 73 | 74 | ] 75 | 76 | prefix_select_zeroshot = """Answer the question. 
If there is no evidence in the context, return "Unknown".\n\n""" 77 | 78 | 79 | class DropDecomp(Decomposition): 80 | def __init__(self, task_name, data_dir, val_split="validation"): 81 | super().__init__(task_name, data_dir, val_split) 82 | 83 | def read_data(self, save_dir, overwrite_data): 84 | return load_hf_data(save_dir, self.task_name, self.val_split, "drop", overwrite_data) 85 | 86 | def get_boost_decomp_examples(self, train_data, boost_id): 87 | return [ 88 | extract_examples[boost_id], 89 | ] 90 | 91 | def get_few_shot_examples(self, train_data, k_shot): 92 | """Get few shot examples""" 93 | labels = [] 94 | for x in train_data.answers_spans: 95 | if len(x['spans']) == 0: 96 | labels.append("unknown") 97 | else: 98 | labels.append(x['spans'][0]) 99 | train_data['expanded_labels'] = labels 100 | labels = ["unknown"] + list(sorted(set(labels))) 101 | 102 | num_per_class = int(np.ceil(k_shot / len(labels))) 103 | print(f"Selecting {num_per_class} examples per class.") 104 | 105 | dfs = [] 106 | total_in_context = 0 107 | for label in labels: 108 | while num_per_class + total_in_context > k_shot: 109 | num_per_class -= 1 110 | sub_df = train_data[train_data["expanded_labels"] == label].sample(num_per_class) 111 | dfs.append(sub_df) 112 | total_in_context += num_per_class 113 | if total_in_context == k_shot: 114 | break 115 | mini_df = pd.concat(dfs) 116 | return mini_df 117 | 118 | def zero_few_baseline( 119 | self, 120 | test_data, 121 | few_shot_df, 122 | manifest, 123 | overwrite_manifest, 124 | do_few_shot=True, 125 | ): 126 | expt_log = {} 127 | preds = [] 128 | labels = [] 129 | 130 | for i, (ind, row) in tqdm( 131 | enumerate(test_data.iterrows()), total=len(test_data) 132 | ): 133 | if ind in expt_log: 134 | entry = expt_log[ind] 135 | pred = entry["pred"] 136 | gold = entry["gold"] 137 | 138 | else: 139 | text = row.passage 140 | question = row.question 141 | if len(row.answers_spans["spans"]) == 0: 142 | label = "unknown" 143 | else: 144 | label = row.answers_spans["spans"][0] 145 | gold = label.lower() 146 | 147 | icl_str = "" 148 | if do_few_shot: 149 | for s_ind, s_row in few_shot_df.iterrows(): 150 | input = s_row.passage 151 | s_question = s_row.question 152 | if len(s_row.answers_spans["spans"]) == 0: 153 | label = "unknown" 154 | else: 155 | label = s_row.answers_spans["spans"][0] 156 | icl_str += f"Passage: {input}\nQuestion: {s_question}\nAnswer: {label}\n\n" 157 | 158 | prompt = ( 159 | icl_str 160 | + "Passage: {text:}\nQuestion: {question:}" 161 | + "\nAnswer:" 162 | ) 163 | pmp = prompt.format(text=text, question=question) 164 | if i == 0: 165 | print(prompt.format(text=text, question=question)) 166 | 167 | raw_answer = get_response( 168 | pmp, 169 | manifest, 170 | overwrite=bool(overwrite_manifest), 171 | max_toks=10, 172 | stop_token="\n\n", 173 | ) 174 | pred = raw_answer.strip("\n").strip().lower() 175 | entry = { 176 | "ind": ind, 177 | "example": text, 178 | "base_prompt": pmp, 179 | "raw_answer": raw_answer, 180 | "pred": pred, 181 | "gold": gold, 182 | } 183 | expt_log[ind] = entry 184 | 185 | preds.append(pred) 186 | labels.append(gold) 187 | metric = text_f1(preds=preds, golds=labels) 188 | return expt_log, metric 189 | 190 | def run_decomposed_prompt( 191 | self, test_data, boost_data_train, boost_dfs, manifest, overwrite_manifest 192 | ): 193 | expt_log, all_boost_preds, labels = self._run_decomp_single_data(test_data, boost_dfs, manifest, overwrite_manifest) 194 | expt_log_train, all_boost_train_preds, train_labels = 
self._run_decomp_single_data(boost_data_train, boost_dfs, manifest, overwrite_manifest, run_limit=1000) 195 | # Do WS 196 | preds = self.merge_boosted_preds(all_boost_preds, all_boost_train_preds, train_labels, expt_log, expt_log_train) 197 | # Get accuracies across all boost sets 198 | individual_accuracies = [] 199 | for i in range(len(all_boost_preds[0])): 200 | individual_accuracies.append(text_f1(preds=[p[i] for p in all_boost_preds], golds=labels)) 201 | metric = text_f1(preds=preds, golds=labels) 202 | return expt_log, expt_log_train, metric, individual_accuracies 203 | 204 | def _run_decomp_single_data(self, test_data, boost_dfs, manifest, overwrite_manifest, run_limit=-1): 205 | expt_log = {} 206 | all_boost_preds = [] 207 | labels = [] 208 | 209 | for i, (ind, row) in tqdm( 210 | enumerate(test_data.iterrows()), total=len(test_data) 211 | ): 212 | if i == run_limit: 213 | break 214 | 215 | text = row.passage 216 | question = row.question 217 | if len(row.answers_spans["spans"]) == 0: 218 | label = "unknown" 219 | else: 220 | label = row.answers_spans["spans"][0] 221 | gold = label.lower() 222 | prompts_across_boost = [] 223 | preds_across_boost = [] 224 | for boost_examples in boost_dfs: 225 | prompt_suffix = extract(boost_examples[0]) 226 | prompt = ( 227 | prompt_suffix + "\n\nContext: {text:}\nQuestion: {question:}\nAnswer:" 228 | ) 229 | pmp = prompt.format(text=text, question=question) 230 | if i == 0: 231 | print(pmp) 232 | 233 | raw_answer = get_response( 234 | pmp, 235 | manifest, 236 | overwrite=bool(overwrite_manifest), 237 | max_toks=50, 238 | stop_token="\n\n", 239 | ) 240 | pred = raw_answer.split("\n")[0].replace('"', "").strip().lower() 241 | # Single list pmp for one step decomp 242 | prompts_across_boost.append([pmp]) 243 | preds_across_boost.append(pred) 244 | 245 | entry = { 246 | "ind": ind, 247 | "example": text, 248 | "prompts": prompts_across_boost, 249 | "preds_boost": preds_across_boost, 250 | "gold": gold, 251 | } 252 | expt_log[ind] = entry 253 | all_boost_preds.append(preds_across_boost) 254 | labels.append(gold) 255 | return expt_log, all_boost_preds, labels 256 | 257 | 258 | def main(): 259 | args = get_args() 260 | task_name = "drop" 261 | data_dir = "drop" 262 | drop = DropDecomp(task_name, data_dir) 263 | drop.run(args) 264 | 265 | 266 | if __name__ == "__main__": 267 | main() 268 | -------------------------------------------------------------------------------- /tasks/utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from collections import Counter 3 | import json 4 | from datasets import load_dataset 5 | import re 6 | import pandas as pd 7 | from typing import Callable, List 8 | from manifest import Manifest 9 | 10 | 11 | class InputOutputPrompt: 12 | def __init__(self, 13 | input_formatter: Callable, 14 | output_formatter: Callable, 15 | required_keys: List, 16 | input_output_sep: str = "\n", 17 | example_sep: str = "\n\n", 18 | instruction: str = "" 19 | ): 20 | self.input_formatter = input_formatter 21 | self.output_formatter = output_formatter 22 | self.required_keys = required_keys 23 | self.input_output_sep = input_output_sep 24 | self.example_sep = example_sep 25 | self.instruction = instruction 26 | 27 | def __call__(self, input_output_pairs: pd.DataFrame): 28 | examples = [] 29 | for _, example in input_output_pairs.iterrows(): 30 | examples.append(f"{self.input_formatter(example)}{self.input_output_sep}{self.output_formatter(example)}") 31 | if examples: 32 | 
input_str = self.example_sep.join(examples) 33 | res = f"{self.instruction}{input_str}" 34 | else: 35 | res = f"{self.instruction}".rstrip() 36 | return res 37 | 38 | def __repr__(self): 39 | dummy_ex = pd.DataFrame([{k: f"<{k.upper()}>" for k in self.required_keys}]) 40 | st = self(dummy_ex) 41 | return st 42 | 43 | 44 | def prefix_formatter(ex_keys: List[str], prefix: str, error_on_empty: bool = True) -> str: 45 | def full_prefix_formatter(ex: pd.Series): 46 | for k in ex_keys: 47 | if k in ex: 48 | return f"{prefix} {getattr(ex, k)}" 49 | if error_on_empty: 50 | raise ValueError(f"Example {ex} has no value for any of the keys {ex_keys}") 51 | else: 52 | return f"{prefix}" 53 | return full_prefix_formatter 54 | 55 | 56 | def get_manifest_session( 57 | client_name="huggingface", 58 | client_engine=None, 59 | client_connection="http://127.0.0.1:5000", 60 | cache_connection=None, 61 | temperature=0, 62 | top_p=1.0, 63 | ): 64 | if client_name == "huggingface" and temperature == 0: 65 | params = { 66 | "temperature": 0.001, 67 | "do_sample": False, 68 | "top_p": top_p, 69 | } 70 | elif client_name in {"openai", "ai21"}: 71 | params = { 72 | "temperature": temperature, 73 | "top_p": top_p, 74 | "engine": client_engine, 75 | } 76 | else: 77 | raise ValueError(f"{client_name} is not a valid client name") 78 | manifest = Manifest( 79 | client_name=client_name, 80 | client_connection=client_connection, 81 | cache_name="sqlite", 82 | cache_connection=cache_connection, 83 | session_id=None, 84 | **params, 85 | ) 86 | params = manifest.client.get_model_params() 87 | model_name = params["model_name"] 88 | if "engine" in params: 89 | model_name += f"_{params['engine']}" 90 | return manifest, model_name 91 | 92 | 93 | def get_response( 94 | prompt, 95 | manifest, 96 | overwrite=False, 97 | max_toks=10, 98 | stop_token=None, 99 | gold_choices=[], 100 | verbose=False, 101 | ): 102 | prompt = prompt.strip() 103 | if gold_choices: 104 | gold_choices = [" " + g.strip() for g in gold_choices] 105 | response_obj = manifest.run( 106 | prompt, gold_choices=gold_choices, overwrite_cache=overwrite, return_response=True 107 | ) 108 | response_obj = response_obj.get_json_response()["choices"][0] 109 | log_prob = response_obj["text_logprob"] 110 | response = response_obj["text"] 111 | else: 112 | response = manifest.run( 113 | prompt, 114 | max_tokens=max_toks, 115 | stop_token=stop_token, 116 | overwrite_cache=overwrite, 117 | ) 118 | log_prob = None 119 | if verbose: 120 | print("\n***Prompt***\n", prompt) 121 | print("\n***Response***\n", response) 122 | if log_prob: 123 | return response, log_prob 124 | return response 125 | 126 | def load_hf_data(save_dir, task_name, val_split, hf_name, overwrite_data): 127 | save_data = Path(f"{save_dir}/{task_name}/data.feather") 128 | if not save_data.exists() or overwrite_data: 129 | dataset = load_dataset(hf_name) 130 | test_data = dataset[val_split].to_pandas() 131 | test_data.to_feather(f"{save_data}") 132 | else: 133 | print(f"Reading test data from {save_data}") 134 | test_data = pd.read_feather(f"{save_data}") 135 | 136 | save_data_train = Path(f"{save_dir}/{task_name}/train_data.feather") 137 | if not save_data_train.exists() or overwrite_data: 138 | dataset = load_dataset(hf_name) 139 | train_data = dataset["train"].to_pandas() 140 | train_data.to_feather(f"{save_data_train}") 141 | else: 142 | print(f"Reading train data from {save_data_train}") 143 | train_data = pd.read_feather(f"{save_data_train}") 144 | 145 | print(f"Test Data Size: {len(test_data)}") 146 | 
print(f"Train Data Size: {len(train_data)}") 147 | return test_data, train_data 148 | 149 | def save_log(task_name, expt_name, log, final_run_dir): 150 | final_run_dir = Path(final_run_dir) 151 | output_fpath = final_run_dir / task_name 152 | output_fpath.mkdir(parents=True, exist_ok=True) 153 | 154 | print("Saving to", output_fpath / f"{expt_name}.json") 155 | assert all(a in list(log.values())[0].keys() for a in ["ind","example","pred","gold"]) 156 | with open(output_fpath / f"{expt_name}.json", "w") as f: 157 | json.dump(log, f) 158 | 159 | def text_f1(preds, golds): 160 | """Compute average F1 of text spans. 161 | Taken from Squad without prob threshold for no answer. 162 | """ 163 | total_f1 = 0 164 | for pred, gold in zip(preds, golds): 165 | pred_toks = pred.split() 166 | gold_toks = gold.split() 167 | common = Counter(pred_toks) & Counter(gold_toks) 168 | num_same = sum(common.values()) 169 | if len(gold_toks) == 0 or len(pred_toks) == 0: 170 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 171 | total_f1 += int(gold_toks == pred_toks) 172 | elif num_same == 0: 173 | total_f1 += 0 174 | else: 175 | precision = 1.0 * num_same / len(pred_toks) 176 | recall = 1.0 * num_same / len(gold_toks) 177 | f1 = (2 * precision * recall) / (precision + recall) 178 | total_f1 += f1 179 | f1_avg = total_f1 / len(golds) 180 | return f1_avg 181 | 182 | def accuracy_span_overlap(preds, golds): 183 | correct = 0 184 | for pred, gold in zip(preds, golds): 185 | found = False 186 | for p in pred: 187 | for g in gold: 188 | if len(p) < len(g): 189 | if p.lower() in g.lower(): 190 | found = True 191 | break 192 | else: 193 | if g.lower() in p.lower(): 194 | found = True 195 | break 196 | if found: correct += 1 197 | return correct / len(preds) 198 | 199 | 200 | -------------------------------------------------------------------------------- /tasks/webq_final.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | import pandas as pd 4 | import numpy as np 5 | import ast 6 | from tqdm.auto import tqdm 7 | from decomposition import Decomposition, get_args 8 | from utils import get_response, InputOutputPrompt, accuracy_span_overlap, load_hf_data 9 | 10 | extract = InputOutputPrompt( 11 | input_formatter=lambda x: f"Question: {x['question']}", 12 | output_formatter=lambda x: f"Answer: {x['answer']}", 13 | required_keys=["question", "answer"], 14 | input_output_sep="\n", 15 | example_sep="\n\n", 16 | instruction="Produce distinct questions.\n\n" 17 | ) 18 | 19 | more_info_examples = [ 20 | pd.DataFrame([ 21 | { 22 | "question": "who plays Carrie Bradshaw in sex and the city?", 23 | "answer": "Caroline \"Carrie\" Bradshaw is a fictional character from the HBO franchise Sex and the City, portrayed by Sarah Jessica Parker." 24 | }, 25 | { 26 | "question": "what are the elements in air?", 27 | "answer": "By mole fraction (i.e., by number of molecules), dry air contains 78.08% nitrogen, 20.95% oxygen, 0.93% argon, 0.04% carbon dioxide, and small amounts of other gases" 28 | }, 29 | { 30 | "question": "what is HP company?", 31 | "answer": "HP Inc. 
is an American multinational information technology company headquartered in Palo Alto, California, that develops personal computers (PCs)" 32 | }, 33 | { 34 | "question": "when was the last season of FRIENDS released?", 35 | "answer": "The series finale aired on May 6, 2004, and was watched by around 52.5 million American viewers, making it the fifth-most-watched series finale in television history" 36 | } 37 | ]), 38 | 39 | ] 40 | 41 | answer = InputOutputPrompt( 42 | input_formatter=lambda x: f"Context: {x['context']}\nQuestion: {x['question']}", 43 | output_formatter=lambda x: f"Answer: {x['answer']}", 44 | required_keys=["context", "question", "answer"], 45 | input_output_sep="\n", 46 | example_sep="\n\n", 47 | instruction="Answer the question.\n\n" 48 | ) 49 | 50 | answer_question = [ 51 | pd.DataFrame([ 52 | { 53 | 'context': 'The nearest airport to Palm Springs is Indio/Palm Springs (PSP) Airport which is 2.1 miles away. ', 54 | 'question': 'what airport is closest to palm springs?', 55 | 'answer': 'Palm Springs International Airport' 56 | }, 57 | { 58 | 'context': 'Martin Luther King earned his Bachelor of Divinity degree from Crozer Theological Seminary, followed by a doctorate in Systematic Theology from Boston University.', 59 | 'question': 'what degree did martin luther king get?', 60 | 'answer': 'Bachelor of Divinity' 61 | }, 62 | { 63 | 'context': 'The Niger river runs in a crescent through Libya, Mali, Niger, on the border with Benin and then through Nigeria.', 64 | 'question': 'what countries does the niger river flow through?', 65 | 'answer': 'Libya' 66 | }, 67 | { 68 | 'context': 'Puerto Rico is a territory of the United States and uses the U.S. dollar. ', 69 | 'question': 'what type of currency is used in puerto rico?', 70 | 'answer': 'United States dollar' 71 | }, 72 | { 73 | 'context': 'kitt was voice most often by William daniels.', 74 | 'question': 'who played kitt in knight rider?', 75 | 'answer': 'William Daniels' 76 | } 77 | ]), 78 | pd.DataFrame([ 79 | { 80 | 'context': 'leonardo da vinci invented the parachute, the helicopter, double hull, an armored fighting vehicle,', 81 | 'question': 'what inventions did leonardo da vinci made?', 82 | 'answer': 'Double hull' 83 | }, 84 | { 85 | 'context': "The French franc (F) was the national currency of France prior to France's adoption of the euro (EUR) in January 2002.", 86 | 'question': 'what currency is used in france before euro?', 87 | 'answer': 'French franc' 88 | }, 89 | { 90 | 'context': 'The Isthmus of Panama, contains the country of Panama and the panama canal.', 91 | 'question': 'where is isthmus of panama located?', 92 | 'answer': 'Costa Rica' 93 | }, 94 | { 95 | 'context': 'Hurricane Irene was a large and destructive tropical cyclone which affected much of the Caribbean and East Coast', 96 | 'question': 'where did hurricane irene?', 97 | 'answer': 'Eastern United States' 98 | }, 99 | { 100 | 'context': 'Rihanna acted in This is the End and Battleship.', 101 | 'question': 'what movie did rihanna play in?', 102 | 'answer': 'This Is the End' 103 | } 104 | ]), 105 | pd.DataFrame([ 106 | { 107 | 'context': 'st vincent de paul is buried in the 6th arrondisment of Paris.', 108 | 'question': 'where is st vincent de paul buried?', 109 | 'answer': 'Paris' 110 | }, 111 | { 112 | 'context': 'Thomas Luther "Luke" Bryan (born July 17, 1976) is an American country singer and songwriter from Leesburg.', 113 | 'question': 'where is luke bryan from?', 114 | 'answer': 'Leesburg' 115 | }, 116 | { 117 | 'context': "Klum and Seal 
got married on 10 May 2005 on a beach in Mexico near Seal's home on Costa Careyes. ", 118 | 'question': 'where did heidi klum and seal get married?', 119 | 'answer': 'Mexico'}, 120 | { 121 | 'context': 'Tarantino starred in pulp fiction, grindhouse and others.', 122 | 'question': 'what movies did quentin tarantino star in?', 123 | 'answer': 'Grindhouse' 124 | }, 125 | { 126 | 'context': 'Countries that are sometimes considered to be entirely or partially part of the Balkans are Croatia, Serbia, Lake Prespa.', 127 | 'question': 'what country is located in the balkan peninsula?', 128 | 'answer': 'Lake Prespa' 129 | } 130 | ]) 131 | 132 | ] 133 | 134 | prefix_select_zeroshot = """Answer the question.\n\n""" 135 | 136 | 137 | class WebQDecomp(Decomposition): 138 | def __init__(self, task_name, data_dir, val_split="validation"): 139 | super().__init__(task_name, data_dir, val_split) 140 | 141 | def read_data(self, save_dir, overwrite_data): 142 | return load_hf_data(save_dir, self.task_name, self.val_split, "web_questions", overwrite_data) 143 | 144 | def get_boost_decomp_examples(self, train_data, boost_id): 145 | return [ 146 | more_info_examples[0], 147 | answer_question[boost_id], 148 | ] 149 | 150 | def zero_few_baseline( 151 | self, 152 | test_data, 153 | few_shot_df, 154 | manifest, 155 | overwrite_manifest, 156 | prompt_suffix="", 157 | do_few_shot=True, 158 | ): 159 | expt_log = {} 160 | preds = [] 161 | labels = [] 162 | 163 | for i, (ind, row) in tqdm( 164 | enumerate(test_data.iterrows()), total=len(test_data) 165 | ): 166 | if ind in expt_log: 167 | entry = expt_log[ind] 168 | pred = entry["pred"] 169 | gold = entry["gold"] 170 | 171 | else: 172 | question = row.question 173 | if isinstance(row.answers, str): 174 | label = ast.literal_eval(row.answers) 175 | else: 176 | label = row.answers.tolist() 177 | gold = label 178 | 179 | icl_str = "" 180 | if do_few_shot: 181 | for s_ind, s_row in few_shot_df.iterrows(): 182 | s_question = s_row.question 183 | if isinstance(s_row.answers, str): 184 | label = ast.literal_eval(s_row.answers) 185 | else: 186 | label = s_row.answers.tolist() 187 | icl_str += f"Question: {s_question}\nAnswer: {label[0]}\n\n" 188 | 189 | prompt = ( 190 | icl_str 191 | + "Question: {question:}" 192 | + prompt_suffix 193 | + "\nAnswer:" 194 | ) 195 | 196 | prompt = prompt.format(question=question) 197 | if i == 0: 198 | print(prompt) 199 | raw_answer = get_response( 200 | prompt, 201 | manifest, 202 | overwrite=bool(overwrite_manifest), 203 | max_toks=20, 204 | stop_token="\n\n", 205 | ) 206 | pred = raw_answer.strip("\n").strip().lower() 207 | entry = { 208 | "ind": ind, 209 | "example": question, 210 | "base_prompt": prompt, 211 | "raw_answer": raw_answer, 212 | "pred": pred, 213 | "gold": gold, 214 | } 215 | expt_log[ind] = entry 216 | 217 | preds.append([pred]) 218 | labels.append(gold) 219 | metric = accuracy_span_overlap(preds=preds, golds=labels) 220 | return expt_log, metric 221 | 222 | def run_decomposed_prompt( 223 | self, test_data, boost_data_train, boost_dfs, manifest, overwrite_manifest 224 | ): 225 | expt_log, all_boost_preds, labels = self._run_decomp_single_data(test_data, boost_dfs, manifest, overwrite_manifest) 226 | expt_log_train, all_boost_train_preds, train_labels = self._run_decomp_single_data(boost_data_train, boost_dfs, manifest, overwrite_manifest, run_limit=1000) 227 | 228 | # Do WS 229 | boost_test, boost_train = [], [] 230 | for p in all_boost_preds: 231 | samples = [lf[1] 
for lf in p] 232 | boost_test.append(samples) 233 | for p in all_boost_train_preds: 234 | samples = [lf[1] for lf in p] 235 | boost_train.append(samples) 236 | 237 | preds = self.merge_boosted_preds(boost_test, boost_train, train_labels, expt_log, expt_log_train) 238 | preds = [(x,y) for x,y in zip([p[0][0] for p in all_boost_preds], preds)] 239 | 240 | # Get accuracies across all boost sets 241 | individual_accuracies = [] 242 | for i in range(len(all_boost_preds[0])): 243 | individual_accuracies.append(accuracy_span_overlap(preds=[p[i] for p in all_boost_preds], golds=labels)) 244 | 245 | metric = accuracy_span_overlap(preds=preds, golds=labels) 246 | return expt_log, expt_log_train, metric, individual_accuracies 247 | 248 | def _run_decomp_single_data(self, test_data, boost_dfs, manifest, overwrite_manifest, run_limit=-1): 249 | expt_log = {} 250 | all_boost_preds = [] 251 | labels = [] 252 | 253 | for i, (ind, row) in tqdm( 254 | enumerate(test_data.iterrows()), total=len(test_data) 255 | ): 256 | if i == run_limit: 257 | break 258 | 259 | question = row.question 260 | if isinstance(row.answers, str): 261 | label = ast.literal_eval(row.answers) 262 | else: 263 | label = row.answers.tolist() 264 | 265 | gold = label 266 | prompts_across_boost = [] 267 | preds_across_boost = [] 268 | 269 | # Step 1: prompt the model for supporting context about the question 270 | prompt_suffix = extract(boost_dfs[0][0]) 271 | prompt = ( 272 | prompt_suffix + "\n\nQuestion: {question:}\nAnswer:" 273 | ) 274 | more_info_answer = get_response( 275 | prompt.format(question=question), 276 | manifest, 277 | overwrite=bool(overwrite_manifest), 278 | max_toks=20, 279 | stop_token="\n\n", 280 | ) 281 | 282 | # Step 2: answer the question conditioned on the generated context 283 | for boost_examples in boost_dfs: 284 | all_prompts = [] 285 | prompt_suffix = answer(boost_examples[1]) 286 | prompt = ( 287 | prompt_suffix + "\n\nContext: {text:}\nQuestion: {question:}\nAnswer:" 288 | ) 289 | pmp = prompt.format(text=more_info_answer, question=question) 290 | if i == 0: 291 | print(pmp) 292 | all_prompts.append(pmp) 293 | raw_answer = get_response( 294 | pmp, 295 | manifest, 296 | overwrite=bool(overwrite_manifest), 297 | max_toks=20, 298 | stop_token="\n\n", 299 | ) 300 | pred = raw_answer.split("\n")[0].strip().lower() 301 | prompts_across_boost.append(all_prompts) 302 | preds_across_boost.append((more_info_answer, pred)) 303 | 304 | entry = { 305 | "ind": ind, 306 | "example": question, 307 | "prompts": prompts_across_boost, 308 | "preds_boost": preds_across_boost, 309 | "gold": gold, 310 | } 311 | expt_log[ind] = entry 312 | all_boost_preds.append(preds_across_boost) 313 | labels.append(gold) 314 | return expt_log, all_boost_preds, labels 315 | 316 | 317 | def main(): 318 | args = get_args() 319 | task_name = "webq" 320 | data_dir = "webq" 321 | webq = WebQDecomp(task_name, data_dir, val_split="test") 322 | webq.run(args) 323 | 324 | 325 | if __name__ == "__main__": 326 | main() 327 | --------------------------------------------------------------------------------
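Usage sketch (illustrative only, not a repository file): the snippet below shows how the InputOutputPrompt helper from tasks/utils.py renders a set of boost examples into the prompt text that the task scripts send to the model. The formatters and instruction mirror the `extract` prompt defined in tasks/drop_final.py; the demo row and the bare `from utils import InputOutputPrompt` (which assumes the working directory is tasks/) are assumptions made for the demo.

#!/usr/bin/env python
# Demo only: render a one-example few-shot prompt with InputOutputPrompt.
import pandas as pd
from utils import InputOutputPrompt  # assumes the working directory is tasks/

extract_demo = InputOutputPrompt(
    input_formatter=lambda x: f"Context: {x['context']}\nQuestion: {x['question']}",
    output_formatter=lambda x: f"Answer: {x['answer']}",
    required_keys=["context", "question", "answer"],
    input_output_sep="\n",
    example_sep="\n\n",
    instruction="Answer the question. If there is no evidence in the context, return \"Unknown\".\n\n",
)

# Invented demo row; the real boost sets live in the task files
# (e.g. extract_examples in tasks/drop_final.py).
demo_rows = pd.DataFrame([
    {
        "context": "The Nile flows north into the Mediterranean.",
        "question": "Which sea does the Nile flow into?",
        "answer": "the Mediterranean",
    }
])

# Prints the instruction, then each row rendered as
# "Context: ...\nQuestion: ...\nAnswer: ...", with rows joined by blank lines.
# The task code then appends the test example, e.g.
# "\n\nContext: {text:}\nQuestion: {question:}\nAnswer:", fills it in, and
# passes the result to get_response(...).
print(extract_demo(demo_rows))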