├── .gitignore ├── .havenignore ├── LICENSE ├── NOTICE ├── README.md ├── configs └── exp_configs │ ├── __init__.py │ └── fs_exps.py ├── models ├── __init__.py └── backbones │ └── __init__.py ├── prepare_dataset.py ├── requirements.txt ├── runners ├── compile_results.py ├── infer.py ├── modules │ └── bert.py ├── oracle_relabel.py └── train.py ├── scripts ├── 3way_human_eval.py ├── __init__.py ├── compute_ttr.py ├── gpt3_analysis.py ├── make_spreadsheet.py └── openai_sandbox.py └── utils ├── __init__.py ├── data_utils ├── augment_slices.py ├── banking77_utils.py ├── clinc_utils.py ├── data_loader.py ├── eda_utils.py ├── hwu64_utils.py ├── main.py ├── sample_few_shot.py └── snips_utils.py └── metrics.py /.gitignore: -------------------------------------------------------------------------------- 1 | spreadsheets/* 2 | 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | results/ 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | .idea 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | 135 | .vscode 136 | 137 | data 138 | logs 139 | wandb 140 | results 141 | figures 142 | 143 | # data files 144 | *.pkl 145 | *.yml 146 | *.json 147 | 148 | wandb 149 | -------------------------------------------------------------------------------- /.havenignore: -------------------------------------------------------------------------------- 1 | wandb/ 2 | results/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022 ServiceNow 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright 2022 ServiceNow, Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This is the official implementation of the following paper:
2 | [Gaurav Sahu](https://github.com/demfier), Pau Rodriguez, Issam Laradji, Parmida Atighehchian, David Vazquez, and Dzmitry Bahdanau. [Data Augmentation for Intent Classification with Off-the-shelf Large Language Models](https://aclanthology.org/2022.nlp4convai-1.5.pdf). *Proceedings of the 4th Workshop on NLP for Conversational AI, ACL 2022.*
3 |
4 | If you find this code useful, please cite:
5 | ```bibtex
6 | @inproceedings{sahu-etal-2022-data,
7 |     title = "Data Augmentation for Intent Classification with Off-the-shelf Large Language Models",
8 |     author = "Sahu, Gaurav and
9 |       Rodriguez, Pau and
10 |       Laradji, Issam and
11 |       Atighehchian, Parmida and
12 |       Vazquez, David and
13 |       Bahdanau, Dzmitry",
14 |     booktitle = "Proceedings of the 4th Workshop on NLP for Conversational AI",
15 |     month = may,
16 |     year = "2022",
17 |     address = "Dublin, Ireland",
18 |     publisher = "Association for Computational Linguistics",
19 |     url = "https://aclanthology.org/2022.nlp4convai-1.5",
20 |     pages = "47--57",
21 | }
22 | ```
23 |
24 | # Running experiments
25 |
26 | ### Datasets
27 |
28 | #### Preparing data
29 | This step is only required if you are starting from scratch, i.e., *NO* data has been prepared at all.
30 | Note that you are still required to create the symbolic link as suggested in the previous step.
31 | To get started with data preparation, run the following:
32 | ```
33 | python prepare_dataset.py --name <dataset_name> --data_root './data/'
34 | ```
35 |
36 | This will generate samples for all the supported modes (upsample, gpt3, gptj, eda).
37 | You can enable top-k and top-p sampling by specifying appropriate values for the `--top_k` (an integer >= 0) and `--top_p` (a float in [0, 1]) flags.
38 | **NOTE**: If generating for GPT-J, make sure there's enough GPU memory (recommended >=32G).
39 |
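For reference, a typical invocation that prepares CLINC150 with only the GPT-J and EDA modes and nucleus sampling enabled might look like the sketch below (the flags come from `prepare_dataset.py`; the specific values are illustrative):
```
python prepare_dataset.py --name clinc_oos --data_root './data/' --modes gptj eda --top_k 0 --top_p 0.9
```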
40 | This will also set up the data directory structure for `<dataset_name>`.
41 | It will prepare a `dataset.pkl` and a `data_full_suite.pkl`.
42 | It will also generate the corresponding label maps (name2id, id2name).
43 | Make sure you have `wget` installed on your local machine.
44 |
45 | **Note:**
46 | - HWU64 was downloaded from "https://github.com/alexa/dialoglue"
47 | - Banking77 and CLINC150 were downloaded using the HuggingFace "datasets" library
48 | - SNIPS was downloaded from "https://github.com/MiuLab/SlotGated-SLU"
49 |
50 | Refer to the `fewshot_baseline_clinc` configuration in `configs/exp_configs/fs_exps.py` for the full few-shot experiment config.
51 |
52 | #### Running experiments:
53 | To run baseline experiments following the original CLINC setting:
54 | 1. Edit the `baselines` variable inside `configs/exp_configs/fs_exps.py`. Here's an example for running the `small` and `plus` baselines together:
55 |
56 | ```python
57 | baselines = hu.cartesian_exp_group(
58 |     {
59 |         # do multiple runs to account for stochasticity in metrics
60 |         "run#": list(range(10)),
61 |         "dataset": [
62 |             {"name": "clinc_oos", "num_labels": 151, "config": c, "oos_id": 150}
63 |             for c in ["plus", "small"]
64 |         ],  # config: small/plus/full/few_pure
65 |         "model": {"name": "intent_classification", "backbone": "bert-large-uncased"},
66 |         "exp_type": "baseline",  # intrinsic/baseline
67 |         "lr": 4e-5,
68 |         "batch_size": 32,
69 |         "epochs": 6,
70 |         "warmup_ratio": 0.1,
71 |         "weight_decay": 0.01,
72 |         # metrics to compute
73 |         "metrics": [["accuracy", "f1", "precision", "recall"]],
74 |         "metric_best": "accuracy",
75 |         "ngpu": 1,
76 |         "eval_accumulation_steps": 30,
77 |     }
78 | )
79 | ```
80 | **Note:** For CLINC (`name='clinc_oos'`), `oos_id=42` for `small/plus/imbalanced` and `oos_id=150` for `full/full_*`.
81 | It also supports no-OOS classifiers: set `oos_id=None` and `num_labels=150`.
82 | For SNIPS (`name='snips_official'`), `oos_id=None`, and it only supports `full*` settings. Make sure that `oos_id` is set correctly.
83 | Refer to the other config variables inside `configs/exp_configs/fs_exps.py` for partial few-shot (ex2 setup) and full few-shot configs.
84 |
85 | 2. Run experiments:
86 | ```
87 | $(which python) -m runners.train --savedir_base /path/to/save/dir/ --exp_group_list baselines -j 1 -v results.ipynb --python_binary $(which python)
88 | ```
89 | Setting `-j 0` will run it locally. `--exp_group_list ex2_setup` will run the EX2 experiments (make sure that the dataset preparation is complete).
90 |
91 | #### For Oracle relabeling experiments:
92 | * Relabel generated examples using an oracle:
93 | ```
94 | $(which python) -m runners.oracle_relabel -md /path/to/oracle/ -e fewshot_baseline_clinc
95 | ```
96 | * Train classifiers on the relabeled data:
97 | ```
98 | $(which python) -m runners.train --savedir_base /path/to/save/dir/ --exp_group_list fewshot_oracle_clinc -j 1 -v results.ipynb --python_binary $(which python)
99 | ```
100 | 3. To compile results, correctly set the "knobs" in [runners.compile_results](runners/compile_results.py#L9-L30) and then run `python -m runners.compile_results` from the root.
101 |
102 |
103 | #### Adding a new dataset
104 | To add a new dataset, follow these steps:
105 | * Create a utils file for your dataset under `utils/data_utils/`. Let's call it `{dataset}_utils.py`.
106 |   All the dataset-specific processing needs to be added in there. In the end, your `{dataset}_utils.py` needs
107 |   to have a `parse_and_load_{dataset}` function (a skeleton is sketched after this list). Refer to the documentation of `clinc_utils.parse_and_load_clinc()` to understand more.
108 | * Add your dataset to `parse_and_load()` and `get_ds_config()` in `utils/data_utils/main.py`.
109 | * [Running prepare_dataset.py](#preparing-data) for your dataset should now create the required files for the ex2 and non-ex2 setups.
110 | * Finally, refer to [this](#for-oracle-relabeling-experiments) to also generate the dataset for the oracle relabeling experiments in the full few-shot setup.
111 |
112 |
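A minimal, hypothetical skeleton of such a utils file is shown below; the exact signature and return format are assumptions made for illustration, so mirror `clinc_utils.parse_and_load_clinc()` for the authoritative contract:
```python
# utils/data_utils/mydataset_utils.py -- hypothetical skeleton
def parse_and_load_mydataset(data_root):
    """Read the raw dataset files and return splits plus label maps.

    The return format here is an assumption; follow
    clinc_utils.parse_and_load_clinc() for what the pipeline expects.
    """
    splits = {"train": [], "val": [], "test": []}
    # ... dataset-specific parsing: fill splits with (text, label) pairs ...
    name2id = {}  # label name -> integer id
    id2name = {}  # integer id -> label name
    return splits, name2id, id2name
```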
113 | #### List of configs in `configs/exp_configs/fs_exps.py` for different experiments
114 | 1. Reproducing CLINC150 results:
115 |
116 | ```python
117 | baselines = hu.cartesian_exp_group({
118 |     # do multiple runs to account for stochasticity in metrics
119 |     'run#': list(range(10)),
120 |     'dataset': {'name': 'clinc_oos', 'num_labels': 151, 'oos_id': 150, 'config': 'full'},  # config: small/plus/full
121 |     'model': {
122 |         'name': 'intent_classification',
123 |         'backbone': 'bert-large-uncased'
124 |     },
125 |     'exp_type': 'baseline',  # intrinsic/baseline
126 |     'lr': 4e-5,
127 |     'batch_size': 32,
128 |     'epochs': 6,
129 |     'warmup_ratio': 0.1,
130 |     'weight_decay': 0.01,
131 |     # metrics to compute
132 |     'metrics': [['accuracy', 'f1', 'precision', 'recall']],
133 |     'metric_best': 'accuracy',
134 |     'ngpu': 1,
135 |     'eval_accumulation_steps': 30
136 | })
137 | ```
138 |
139 | 2. Running partial few-shot baseline/upsample experiments:
140 |
141 | ```python
142 | ex2_setup = hu.cartesian_exp_group({
143 |     # do multiple runs to account for stochasticity in metrics
144 |     'run#': list(range(10)),
145 |     'dataset': [{
146 |         'name': 'clinc_oos', 'num_labels': 151, 'oos_id': 150,
147 |         'config': 'full_'+v} for v in DOMAINS],  # config -> small/imbalanced/plus/small_aug/full
148 |     'model': {
149 |         'name': 'intent_classification',
150 |         'backbone': 'bert-large-uncased'
151 |     },
152 |     'exp_type': ['baseline', 'upsample'],  # gpt3/upsample/baseline
153 |     'lr': 5e-5,
154 |     'batch_size': 64,
155 |     'epochs': 10,
156 |     'warmup_ratio': 0.1,
157 |     'weight_decay': 0.01,
158 |     # metrics to compute. if oos_id is not None,
159 |     # compute inscope_accuracy and oos_recall as well
160 |     'metrics': [['accuracy', 'f1', 'precision', 'recall']],
161 |     'metric_best': 'f1',
162 |     'eval_accumulation_steps': 30
163 | })
164 | ```
165 |
166 | 3. Partial few-shot augmented (GPT-3) experiments:
167 |
168 | ```python
169 | ex2_setup = hu.cartesian_exp_group({
170 |     # do multiple runs to account for stochasticity in metrics
171 |     'run#': list(range(10)),
172 |     'dataset': [{
173 |         'name': 'clinc_oos', 'num_labels': 151, 'oos_id': 150,
174 |         'config': 'full_'+v} for v in DOMAINS],  # config -> small/imbalanced/plus/small_aug/full
175 |     'model': {
176 |         'name': 'intent_classification',
177 |         'backbone': 'bert-large-uncased'
178 |     },
179 |     'exp_type': ['gpt3'],  # gpt3/upsample/baseline
180 |     'lr': 5e-5,
181 |     'batch_size': 64,
182 |     'epochs': 10,
183 |     'warmup_ratio': 0.1,
184 |     'weight_decay': 0.01,
185 |     # metrics to compute. if oos_id is not None,
186 |     # compute inscope_accuracy and oos_recall as well
187 |     'metrics': [['accuracy', 'f1', 'precision', 'recall']],
188 |     'metric_best': 'f1',
189 |     # 'gpt3_engine': 'ada',  # ada/babbage/curie/davinci
190 |     'gpt3_engine': ['ada', 'babbage', 'curie', 'davinci'],  # ada/babbage/curie/davinci
191 |     # 'gpt3_temp': 1.0,  # 0.5/0.6/0.7/0.8/0.9/1.0/1.5/2.0
192 |     'gpt3_temp': [0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.5, 2.0],  # 0.5-2.0
193 |     'eval_accumulation_steps': 30
194 | })
195 | ```
196 |
197 | 4. 
Training the oracle: 198 | 199 | ```python 200 | baselines = hu.cartesian_exp_group({ 201 | # do multiple runs to account for stochasticity in metrics 202 | 'run#': list(range(1)), 203 | 'dataset': {'name': 'clinc_oos', 'num_labels': 151, 'oos_id': None, 'config': 'full'}, # config: small/plus/full 204 | 'model': { 205 | 'name': 'intent_classification', 206 | 'backbone': 'bert-large-uncased' 207 | }, 208 | 'exp_type': 'intrinsic', # intrinsic/baseline 209 | 'lr': 4e-5, 210 | 'batch_size': 32, 211 | 'epochs': 6, 212 | 'warmup_ratio': 0.1, 213 | 'weight_decay': 0.01, 214 | # metrics to compute 215 | 'metrics': [['accuracy', 'f1', 'precision', 'recall']], 216 | 'metric_best': 'accuracy', 217 | 'ngpu': 1, 218 | 'eval_accumulation_steps': 30 219 | }) 220 | ``` 221 | -------------------------------------------------------------------------------- /configs/exp_configs/__init__.py: -------------------------------------------------------------------------------- 1 | from . import fs_exps 2 | 3 | EXP_GROUPS = {} 4 | EXP_GROUPS.update(fs_exps.EXP_GROUPS) 5 | -------------------------------------------------------------------------------- /configs/exp_configs/fs_exps.py: -------------------------------------------------------------------------------- 1 | from haven import haven_utils as hu 2 | 3 | EXP_GROUPS = {} 4 | 5 | # CLINC 6 | DOMAINS = [ 7 | "banking", 8 | "credit_card", 9 | "dining", 10 | "home", 11 | "auto", 12 | "travel", 13 | "utility", 14 | "work", 15 | "small_talk", 16 | "meta", 17 | ] 18 | 19 | 20 | baselines = hu.cartesian_exp_group( 21 | { 22 | # do multiple runs to account for stochasticity in metrics 23 | "run#": list(range(10)), 24 | "dataset": [ 25 | {"name": "clinc_oos", "num_labels": 151, "config": c, "oos_id": 150} 26 | for c in ["few_pure"] 27 | ], # config: small/plus/full/few_pure 28 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"}, 29 | "exp_type": "baseline", # intrinsic/baseline 30 | "lr": 4e-5, 31 | "batch_size": 32, 32 | "epochs": 6, 33 | "warmup_ratio": 0.1, 34 | "weight_decay": 0.01, 35 | # metrics to compute 36 | "metrics": [["accuracy", "f1", "precision", "recall"]], 37 | "metric_best": "accuracy", 38 | "ngpu": 1, 39 | # 'gpt3_engine': ['ada', 'babbage', 'curie', 'davinci'], 40 | # 'gpt3_temp': 1.0, 41 | "eval_accumulation_steps": 30, 42 | } 43 | ) 44 | 45 | ex2_setup = hu.cartesian_exp_group( 46 | { 47 | # do multiple runs to account for stochasticity in metrics 48 | "run#": list(range(10)), 49 | "dataset": [ 50 | { 51 | "name": "clinc_oos", 52 | "num_labels": 151, 53 | "oos_id": 150, 54 | "config": "full_" + v, 55 | } 56 | for v in DOMAINS 57 | ], # config -> small/imbalanced/plus/small_aug/full 58 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"}, 59 | "exp_type": ["gpt3"], # gpt3/upsample/baseline 60 | "lr": 5e-5, 61 | "batch_size": 64, 62 | "epochs": 10, 63 | "warmup_ratio": 0.1, 64 | "weight_decay": 0.01, 65 | # metrics to compute. 
if oos_id is not None, 66 | # compute inscope_accuracy and oos_recall as well 67 | "metrics": [["accuracy", "f1", "precision", "recall"]], 68 | "metric_best": "accuracy", 69 | # 'gpt3_engine': 'ada', # ada/babbage/curie/davinci 70 | "gpt3_engine": [ 71 | "ada", 72 | "babbage", 73 | "curie", 74 | "davinci", 75 | "gptj", 76 | ], # ada/babbage/curie/davinci 77 | # 'gpt3_temp': 1.0, # 0.5/0.6/0.7/0.8/0.9/1.0/1.5/2.0 78 | "gpt3_temp": 1.0, # 0.5-2.0 79 | "eval_accumulation_steps": 30, 80 | } 81 | ) 82 | 83 | fewshot_oracle_clinc = hu.cartesian_exp_group( 84 | { 85 | "run#": list(range(10)), # for extrinsic evaluation 86 | "dataset": { 87 | "name": "clinc_oos", 88 | "num_labels": 151, 89 | "oos_id": 150, 90 | "config": "few_pure", 91 | }, 92 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"}, 93 | "lr": 5e-5, 94 | "batch_size": 64, 95 | "epochs": 6, 96 | "warmup_ratio": 0.1, 97 | "weight_decay": 0.01, 98 | "metrics": [["accuracy", "f1", "precision", "recall"]], 99 | "metric_best": "accuracy", # accuracy/f1 100 | "exp_type": "gpt3_oracle", 101 | "gpt3_engine": [ 102 | "ada", 103 | "babbage", 104 | "curie", 105 | "davinci", 106 | "gptj", 107 | ], # ada/babbage/curie/davinci/gptj 108 | "gpt3_temp": 1.0, # [0.5, 0.6, 0.7, 0.8, 0.9, 1.0], 109 | "eval_accumulation_steps": 30, 110 | } 111 | ) 112 | 113 | fewshot_gpt3_clinc = hu.cartesian_exp_group( 114 | { 115 | "run#": list(range(10)), # for extrinsic evaluation 116 | "dataset": { 117 | "name": "clinc_oos", 118 | "num_labels": 151, 119 | "oos_id": 150, 120 | "config": "few_pure", 121 | }, 122 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"}, 123 | "lr": 5e-5, 124 | "batch_size": 64, 125 | "epochs": 6, 126 | "warmup_ratio": 0.1, 127 | "weight_decay": 0.01, 128 | "metrics": [["accuracy", "f1", "precision", "recall"]], 129 | "metric_best": "accuracy", # accuracy/f1 130 | "exp_type": "gpt3", 131 | "gpt3_engine": [ 132 | "ada", 133 | "babbage", 134 | "curie", 135 | "davinci", 136 | "gptj", 137 | ], # ada/babbage/curie/davinci/gptj 138 | "gpt3_temp": 1.0, # [0.5, 0.6, 0.7, 0.8, 0.9, 1.0], 139 | "eval_accumulation_steps": 30, 140 | } 141 | ) 142 | 143 | fewshot_gpt3mix_clinc = hu.cartesian_exp_group( 144 | { 145 | "run#": list(range(10)), # for extrinsic evaluation 146 | "dataset": { 147 | "name": "clinc_oos", 148 | "num_labels": 151, 149 | "oos_id": 150, 150 | "config": "few_pure", 151 | }, 152 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"}, 153 | "lr": 5e-5, 154 | "batch_size": 64, 155 | "epochs": 6, 156 | "warmup_ratio": 0.1, 157 | "weight_decay": 0.01, 158 | "metrics": [["accuracy", "f1", "precision", "recall"]], 159 | "metric_best": "accuracy", # accuracy/f1 160 | "exp_type": ["gpt3mix", "gpt3mix_oracle"], 161 | "soft_label": True, 162 | "gpt3_engine": "curie", # only generated for curie 163 | "gpt3_temp": 1.0, # only generated for 1.0 164 | "eval_accumulation_steps": 30, 165 | } 166 | ) 167 | 168 | fewshot_eda_clinc = hu.cartesian_exp_group( 169 | { 170 | "run#": list(range(10)), # for extrinsic evaluation 171 | "dataset": { 172 | "name": "clinc_oos", 173 | "num_labels": 151, 174 | "oos_id": 150, 175 | "config": "few_pure", 176 | }, 177 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"}, 178 | "lr": 5e-5, 179 | "batch_size": 64, 180 | "epochs": 6, 181 | "warmup_ratio": 0.1, 182 | "weight_decay": 0.01, 183 | "metrics": [["accuracy", "f1", "precision", "recall"]], 184 | "metric_best": "accuracy", # accuracy/f1 185 | "exp_type": ["eda", 
"eda_oracle"], # eda/eda_oracle 186 | "eval_accumulation_steps": 30, 187 | } 188 | ) 189 | 190 | 191 | fewshot_baseline_clinc = hu.cartesian_exp_group( 192 | { 193 | "run#": list(range(10)), # for extrinsic evaluation 194 | "dataset": { 195 | "name": "clinc_oos", 196 | "num_labels": 151, 197 | "oos_id": 150, 198 | "config": "few_pure", 199 | }, 200 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"}, 201 | "lr": 3e-5, # 1e-5 for k=3, 5 202 | "batch_size": 8, # 2 for k=3, 5 203 | "epochs": 6, # 20 for k=3, 5 204 | "warmup_ratio": 0.1, 205 | "weight_decay": 0.001, # 0.0001 for k=3, 5 206 | "metrics": [["accuracy", "f1", "precision", "recall"]], 207 | "metric_best": "accuracy", # accuracy/f1 208 | "exp_type": "baseline", 209 | "eval_accumulation_steps": 30, 210 | } 211 | ) 212 | 213 | banking77_baselines = hu.cartesian_exp_group( 214 | { 215 | # do multiple runs to account for stochasticity in metrics 216 | "run#": list(range(10)), 217 | "dataset": { 218 | "name": "banking77", 219 | "num_labels": 77, 220 | "config": "full", 221 | "oos_id": None, 222 | }, # config: small/plus/full/few_pure 223 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"}, 224 | "exp_type": "baseline", # intrinsic/baseline 225 | "lr": 5e-5, 226 | "batch_size": 32, 227 | "epochs": 10, 228 | "warmup_ratio": 0.1, 229 | "weight_decay": 0.01, 230 | # metrics to compute 231 | "metrics": [["accuracy", "f1", "precision", "recall"]], 232 | "metric_best": "accuracy", 233 | "eval_accumulation_steps": 30, 234 | } 235 | ) 236 | 237 | fewshot_baseline_banking77 = hu.cartesian_exp_group( 238 | { 239 | "run#": list(range(10)), # for extrinsic evaluation 240 | "dataset": { 241 | "name": "banking77", 242 | "num_labels": 77, 243 | "oos_id": None, 244 | "config": "few_pure", 245 | }, 246 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"}, 247 | "lr": 3e-5, 248 | "batch_size": 8, 249 | "epochs": 20, 250 | "warmup_ratio": 0.1, 251 | "weight_decay": 0.001, 252 | "metrics": [["accuracy", "f1", "precision", "recall"]], 253 | "metric_best": "accuracy", # accuracy/f1 254 | "exp_type": "baseline", 255 | "eval_accumulation_steps": 30, 256 | } 257 | ) 258 | 259 | fewshot_oracle_banking77 = hu.cartesian_exp_group( 260 | { 261 | "run#": list(range(10)), # for extrinsic evaluation 262 | "dataset": { 263 | "name": "banking77", 264 | "num_labels": 77, 265 | "oos_id": None, 266 | "config": "few_pure", 267 | }, 268 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"}, 269 | "lr": 5e-5, 270 | "batch_size": 32, 271 | "epochs": 10, 272 | "warmup_ratio": 0.1, 273 | "weight_decay": 0.01, 274 | "metrics": [["accuracy", "f1", "precision", "recall"]], 275 | "metric_best": "accuracy", # accuracy/f1 276 | "exp_type": "gpt3_oracle", 277 | "gpt3_engine": [ 278 | "ada", 279 | "babbage", 280 | "curie", 281 | "davinci", 282 | "gptj", 283 | ], # ada/babbage/curie/davinci/gptj 284 | "gpt3_temp": 1.0, # [0.5, 0.6, 0.7, 0.8, 0.9, 1.0], 285 | "eval_accumulation_steps": 30, 286 | } 287 | ) 288 | 289 | fewshot_gpt3_banking77 = hu.cartesian_exp_group( 290 | { 291 | "run#": list(range(10)), # for extrinsic evaluation 292 | "dataset": { 293 | "name": "banking77", 294 | "num_labels": 77, 295 | "oos_id": None, 296 | "config": "few_pure", 297 | }, 298 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"}, 299 | "lr": 5e-5, 300 | "batch_size": 32, 301 | "epochs": 10, 302 | "warmup_ratio": 0.1, 303 | "weight_decay": 0.01, 304 | "metrics": [["accuracy", 
"f1", "precision", "recall"]], 305 | "metric_best": "accuracy", # accuracy/f1 306 | "exp_type": "gpt3", 307 | "gpt3_engine": [ 308 | "ada", 309 | "babbage", 310 | "curie", 311 | "davinci", 312 | "gptj", 313 | ], # ada/babbage/curie/davinci/gptj 314 | # "gpt3_engine": "curie", # ada/babbage/curie/davinci/gptj 315 | "gpt3_temp": 1.0, # [0.5, 0.6, 0.7, 0.8, 0.9, 1.0], 316 | "eval_accumulation_steps": 30, 317 | } 318 | ) 319 | 320 | 321 | fewshot_eda_banking77 = hu.cartesian_exp_group( 322 | { 323 | "run#": list(range(10)), # for extrinsic evaluation 324 | "dataset": { 325 | "name": "banking77", 326 | "num_labels": 77, 327 | "oos_id": None, 328 | "config": "few_pure", 329 | }, 330 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"}, 331 | "lr": 5e-5, 332 | "batch_size": 32, 333 | "epochs": 10, 334 | "warmup_ratio": 0.1, 335 | "weight_decay": 0.01, 336 | "metrics": [["accuracy", "f1", "precision", "recall"]], 337 | "metric_best": "accuracy", # accuracy/f1 338 | "exp_type": ["eda", "eda_oracle"], 339 | "eval_accumulation_steps": 30, 340 | } 341 | ) 342 | 343 | hwu64_baselines = hu.cartesian_exp_group( 344 | { 345 | # do multiple runs to account for stochasticity in metrics 346 | "run#": list(range(10)), 347 | "dataset": { 348 | "name": "hwu64", 349 | "num_labels": 64, 350 | "config": "full", 351 | "oos_id": None, 352 | }, # config: small/plus/full/few_pure 353 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"}, 354 | "exp_type": "intrinsic", # intrinsic/baseline 355 | "lr": 5e-5, 356 | "batch_size": 32, 357 | "epochs": 6, 358 | "warmup_ratio": 0.1, 359 | "weight_decay": 0.01, 360 | # metrics to compute 361 | "metrics": [["accuracy", "f1", "precision", "recall"]], 362 | "metric_best": "accuracy", 363 | "eval_accumulation_steps": 30, 364 | } 365 | ) 366 | 367 | fewshot_baseline_hwu64 = hu.cartesian_exp_group( 368 | { 369 | "run#": list(range(10)), # for extrinsic evaluation 370 | "dataset": { 371 | "name": "hwu64", 372 | "num_labels": 64, 373 | "oos_id": None, 374 | "config": "few_pure", 375 | }, 376 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"}, 377 | "lr": 3e-5, 378 | "batch_size": 8, 379 | "epochs": 20, 380 | "warmup_ratio": 0.1, 381 | "weight_decay": 0.001, 382 | "metrics": [["accuracy", "f1", "precision", "recall"]], 383 | "metric_best": "accuracy", # accuracy/f1 384 | "exp_type": "baseline", 385 | "eval_accumulation_steps": 30, 386 | } 387 | ) 388 | 389 | fewshot_oracle_hwu64 = hu.cartesian_exp_group( 390 | { 391 | "run#": list(range(10)), # for extrinsic evaluation 392 | "dataset": { 393 | "name": "hwu64", 394 | "num_labels": 64, 395 | "oos_id": None, 396 | "config": "few_pure", 397 | }, 398 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"}, 399 | "lr": 3e-5, 400 | "batch_size": 32, 401 | "epochs": 10, 402 | "warmup_ratio": 0.1, 403 | "weight_decay": 0.01, 404 | "metrics": [["accuracy", "f1", "precision", "recall"]], 405 | "metric_best": "accuracy", # accuracy/f1 406 | "exp_type": "gpt3_oracle", 407 | "gpt3_engine": [ 408 | "ada", 409 | "babbage", 410 | "curie", 411 | "davinci", 412 | "gptj", 413 | ], # ada/babbage/curie/davinci/gptj 414 | # "gpt3_engine": "babbage", # ada/babbage/curie/davinci/gptj 415 | "gpt3_temp": 1.0, 416 | "eval_accumulation_steps": 30, 417 | } 418 | ) 419 | 420 | fewshot_gpt3_hwu64 = hu.cartesian_exp_group( 421 | { 422 | "run#": list(range(10)), # for extrinsic evaluation 423 | "dataset": { 424 | "name": "hwu64", 425 | "num_labels": 64, 426 | "oos_id": 
None, 427 | "config": "few_pure", 428 | }, 429 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"}, 430 | "lr": 3e-5, 431 | "batch_size": 32, 432 | "epochs": 10, 433 | "warmup_ratio": 0.1, 434 | "weight_decay": 0.01, 435 | "metrics": [["accuracy", "f1", "precision", "recall"]], 436 | "metric_best": "accuracy", # accuracy/f1 437 | "exp_type": "gpt3", 438 | "gpt3_engine": [ 439 | "ada", 440 | "babbage", 441 | "curie", 442 | "davinci", 443 | "gptj", 444 | ], # ada/babbage/curie/davinci/gptj 445 | # "gpt3_engine": "babbage", # ada/babbage/curie/davinci/gptj 446 | # "gpt3_temp": [round(a, 1) for a in np.linspace(0.5, 2, int((2.1 - 0.5) / 0.1))], 447 | "gpt3_temp": [1.0], 448 | "eval_accumulation_steps": 30, 449 | } 450 | ) 451 | 452 | fewshot_eda_hwu64 = hu.cartesian_exp_group( 453 | { 454 | "run#": list(range(10)), # for extrinsic evaluation 455 | "dataset": { 456 | "name": "hwu64", 457 | "num_labels": 64, 458 | "oos_id": None, 459 | "config": "few_pure", 460 | }, 461 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"}, 462 | "lr": 3e-5, 463 | "batch_size": 32, 464 | "epochs": 10, 465 | "warmup_ratio": 0.1, 466 | "weight_decay": 0.01, 467 | "metrics": [["accuracy", "f1", "precision", "recall"]], 468 | "metric_best": "accuracy", # accuracy/f1 469 | "exp_type": ["eda", "eda_oracle"], 470 | "eval_accumulation_steps": 30, 471 | } 472 | ) 473 | 474 | # SNIPS 475 | SNIPS_DOMAINS = [ 476 | "AddToPlaylist", 477 | "BookRestaurant", 478 | "GetWeather", 479 | "PlayMusic", 480 | "RateBook", 481 | "SearchCreativeWork", 482 | "SearchScreeningEvent", 483 | ] 484 | 485 | # SNIPS_DOMAINS = ["RateBook"] 486 | 487 | snips_baselines = hu.cartesian_exp_group( 488 | { 489 | "dataset": { 490 | "name": "snips_official", 491 | "num_labels": 7, 492 | "oos_id": None, 493 | "config": "full", 494 | }, # config -> small/imbalanced/plus/small_aug/full/intrinsic 495 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"}, 496 | "exp_type": "baseline", # upsample/baseline 497 | "lr": 4e-5, 498 | "batch_size": 32, 499 | "epochs": 6, 500 | "warmup_ratio": 0.1, 501 | "weight_decay": 0.01, 502 | "metrics": [["accuracy"]], 503 | "metric_best": "accuracy", 504 | "ngpu": 1, 505 | "gpt3_temp": 1.0, 506 | "eval_accumulation_steps": 30, 507 | } 508 | ) 509 | 510 | snips_ex2_setup = hu.cartesian_exp_group( 511 | { 512 | # do multiple runs to account for stochasticity in metrics 513 | "run#": list(range(10)), 514 | "dataset": [ 515 | { 516 | "name": "snips_official", 517 | "num_labels": 7, 518 | "oos_id": None, 519 | "config": "full_" + v, 520 | } 521 | for v in SNIPS_DOMAINS 522 | ], # config -> small/imbalanced/plus/small_aug/full 523 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"}, 524 | "exp_type": ["gpt3"], # gpt3/upsample/baseline 525 | "lr": 5e-5, 526 | "batch_size": 64, 527 | "epochs": 10, 528 | "warmup_ratio": 0.1, 529 | "weight_decay": 0.01, 530 | # metrics to compute. 
if oos_id is not None, 531 | # compute inscope_accuracy and oos_recall as well 532 | "metrics": [["accuracy", "f1", "precision", "recall"]], 533 | "metric_best": "accuracy", 534 | # 'gpt3_engine': 'davinci', # ada/babbage/curie/davinci/gptj 535 | "gpt3_engine": [ 536 | "ada", 537 | "babbage", 538 | "curie", 539 | "davinci", 540 | "gptj", 541 | ], # ada/babbage/curie/davinci/gptj 542 | "gpt3_temp": 1.0, # 0.5/0.6/0.7/0.8/0.9/1.0/1.5 543 | "eval_accumulation_steps": 30, 544 | } 545 | ) 546 | 547 | 548 | fewshot_baseline_snips = hu.cartesian_exp_group( 549 | { 550 | "run#": list(range(10)), # for extrinsic evaluation 551 | "dataset": { 552 | "name": "snips_official", 553 | "num_labels": 7, 554 | "oos_id": None, 555 | "config": "few_pure", 556 | }, 557 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"}, 558 | "lr": 3e-5, 559 | "batch_size": 8, 560 | "epochs": 20, 561 | "warmup_ratio": 0.1, 562 | "weight_decay": 0.001, 563 | "metrics": [["accuracy", "f1", "precision", "recall"]], 564 | "metric_best": "accuracy", # accuracy/f1 565 | "exp_type": "baseline", 566 | "eval_accumulation_steps": 30, 567 | } 568 | ) 569 | 570 | fewshot_oracle_snips = hu.cartesian_exp_group( 571 | { 572 | "run#": list(range(10)), # for extrinsic evaluation 573 | "dataset": { 574 | "name": "snips_official", 575 | "num_labels": 7, 576 | "oos_id": None, 577 | "config": "few_pure", 578 | }, 579 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"}, 580 | "lr": 5e-5, 581 | "batch_size": 64, 582 | "epochs": 10, 583 | "warmup_ratio": 0.1, 584 | "weight_decay": 0.01, 585 | "metrics": [["accuracy", "f1", "precision", "recall"]], 586 | "metric_best": "accuracy", # accuracy/f1 587 | "exp_type": "gpt3_oracle", 588 | "gpt3_engine": [ 589 | "ada", 590 | "babbage", 591 | "curie", 592 | "davinci", 593 | "gptj", 594 | ], # ada/babbage/curie/davinci/gptj 595 | "gpt3_temp": 1.0, # [0.5, 0.6, 0.7, 0.8, 0.9, 1.0], 596 | "eval_accumulation_steps": 30, 597 | } 598 | ) 599 | 600 | fewshot_gpt3_snips = hu.cartesian_exp_group( 601 | { 602 | "run#": list(range(10)), # for extrinsic evaluation 603 | "dataset": { 604 | "name": "snips_official", 605 | "num_labels": 7, 606 | "oos_id": None, 607 | "config": "few_pure", 608 | }, 609 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"}, 610 | "lr": 5e-5, 611 | "batch_size": 64, 612 | "epochs": 10, 613 | "warmup_ratio": 0.1, 614 | "weight_decay": 0.01, 615 | "metrics": [["accuracy", "f1", "precision", "recall"]], 616 | "metric_best": "accuracy", # accuracy/f1 617 | "exp_type": "gpt3", 618 | "gpt3_engine": [ 619 | "ada", 620 | "babbage", 621 | "curie", 622 | "davinci", 623 | "gptj", 624 | ], # ada/babbage/curie/davinci/gptj 625 | # "gpt3_engine": "babbage", # ada/babbage/curie/davinci/gptj 626 | "gpt3_temp": 1.0, # [0.5, 0.6, 0.7, 0.8, 0.9, 1.0], 627 | "eval_accumulation_steps": 30, 628 | } 629 | ) 630 | 631 | fewshot_eda_snips = hu.cartesian_exp_group( 632 | { 633 | "run#": list(range(10)), # for extrinsic evaluation 634 | "dataset": { 635 | "name": "snips_official", 636 | "num_labels": 7, 637 | "oos_id": None, 638 | "config": "few_pure", 639 | }, 640 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"}, 641 | "lr": 5e-5, 642 | "batch_size": 64, 643 | "epochs": 10, 644 | "warmup_ratio": 0.1, 645 | "weight_decay": 0.01, 646 | "metrics": [["accuracy", "f1", "precision", "recall"]], 647 | "metric_best": "accuracy", # accuracy/f1 648 | "exp_type": ["eda", "eda_oracle"], 649 | "eval_accumulation_steps": 30, 
650 | } 651 | ) 652 | 653 | # CLINC 654 | EXP_GROUPS["baselines"] = baselines 655 | EXP_GROUPS["ex2_setup"] = ex2_setup 656 | EXP_GROUPS["fewshot_baseline_clinc"] = fewshot_baseline_clinc 657 | EXP_GROUPS["fewshot_oracle_clinc"] = fewshot_oracle_clinc 658 | EXP_GROUPS["fewshot_gpt3_clinc"] = fewshot_gpt3_clinc 659 | EXP_GROUPS["fewshot_gpt3mix_clinc"] = fewshot_gpt3mix_clinc 660 | EXP_GROUPS["fewshot_eda_clinc"] = fewshot_eda_clinc 661 | 662 | # Banking77 663 | EXP_GROUPS["banking77_baselines"] = banking77_baselines 664 | EXP_GROUPS["fewshot_baseline_banking77"] = fewshot_baseline_banking77 665 | EXP_GROUPS["fewshot_gpt3_banking77"] = fewshot_gpt3_banking77 666 | EXP_GROUPS["fewshot_oracle_banking77"] = fewshot_oracle_banking77 667 | EXP_GROUPS["fewshot_eda_banking77"] = fewshot_eda_banking77 668 | 669 | # HWU64 670 | EXP_GROUPS["hwu64_baselines"] = hwu64_baselines 671 | EXP_GROUPS["fewshot_baseline_hwu64"] = fewshot_baseline_hwu64 672 | EXP_GROUPS["fewshot_gpt3_hwu64"] = fewshot_gpt3_hwu64 673 | EXP_GROUPS["fewshot_oracle_hwu64"] = fewshot_oracle_hwu64 674 | EXP_GROUPS["fewshot_eda_hwu64"] = fewshot_eda_hwu64 675 | 676 | # SNIPS 677 | EXP_GROUPS["snips_baselines"] = snips_baselines 678 | EXP_GROUPS["snips_ex2_setup"] = snips_ex2_setup 679 | EXP_GROUPS["fewshot_baseline_snips"] = fewshot_baseline_snips 680 | EXP_GROUPS["fewshot_gpt3_snips"] = fewshot_gpt3_snips 681 | EXP_GROUPS["fewshot_oracle_snips"] = fewshot_oracle_snips 682 | EXP_GROUPS["fewshot_eda_snips"] = fewshot_eda_snips 683 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ServiceNow/data-augmentation-with-llms/77c740df2aaf4bfc1b2449114613f98b37a1faae/models/__init__.py -------------------------------------------------------------------------------- /models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import ( 2 | AutoModelForSequenceClassification, 3 | AutoModelForSeq2SeqLM, 4 | AutoConfig, 5 | BertModel, 6 | ) 7 | 8 | import torch.nn as nn 9 | from transformers.modeling_outputs import SequenceClassifierOutput 10 | 11 | 12 | class BertModelWithCustomLossFunction(nn.Module): 13 | def __init__(self, exp_dict): 14 | super(BertModelWithCustomLossFunction, self).__init__() 15 | self.num_labels = exp_dict["dataset"]["num_labels"] 16 | self.bert = BertModel.from_pretrained( 17 | exp_dict["model"]["backbone"], num_labels=self.num_labels 18 | ) 19 | self.dropout = nn.Dropout(0.1) 20 | self.classifier = nn.Linear(1024, self.num_labels) 21 | 22 | def forward(self, input_ids, attention_mask, token_type_ids, labels=None): 23 | outputs = self.bert( 24 | input_ids=input_ids, 25 | attention_mask=attention_mask, 26 | token_type_ids=token_type_ids, 27 | ) 28 | 29 | output = self.dropout(outputs.pooler_output) 30 | logits = self.classifier(output) 31 | 32 | loss = None 33 | if labels is not None: 34 | # you can define any loss function here yourself 35 | # see https://pytorch.org/docs/stable/nn.html#loss-functions for an overview 36 | loss_fct = nn.CrossEntropyLoss() 37 | # next, compute the loss based on logits + ground-truth labels 38 | loss = loss_fct(logits.view(-1, self.num_labels), labels) 39 | 40 | return SequenceClassifierOutput( 41 | loss=loss, 42 | logits=logits, 43 | hidden_states=outputs.hidden_states, 44 | attentions=outputs.attentions, 45 | ) 46 | 47 | 48 | def get_backbone(exp_dict): 49 
    if exp_dict["exp_type"] == "gpt3mix":
50 |         backbone = BertModelWithCustomLossFunction(exp_dict)
51 |         return backbone
52 |
53 |     if exp_dict["model"]["backbone"] in [
54 |         "distilbert-base-uncased",
55 |         "bert-large-uncased",
56 |         "bert-base-uncased",
57 |     ]:
58 |         backbone = AutoModelForSequenceClassification.from_pretrained(
59 |             exp_dict["model"]["backbone"], num_labels=exp_dict["dataset"]["num_labels"]
60 |         )
61 |         return backbone
62 |     raise ValueError(f"backbone: {exp_dict['model']['backbone']} not supported")
63 |
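For orientation, a training script can turn one of the experiment dicts from `configs/exp_configs/fs_exps.py` into a model via `get_backbone`; the `exp_dict` below is a minimal, hypothetical example (real configs carry more keys):
```python
from models.backbones import get_backbone

exp_dict = {
    "exp_type": "baseline",
    "model": {"name": "intent_classification", "backbone": "bert-large-uncased"},
    "dataset": {"name": "clinc_oos", "num_labels": 151},
}
# returns an AutoModelForSequenceClassification with 151 labels;
# exp_type == "gpt3mix" would instead return the custom-loss BERT above
backbone = get_backbone(exp_dict)
```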
-------------------------------------------------------------------------------- /prepare_dataset.py: --------------------------------------------------------------------------------
1 | """Dataset preparation step. It's mostly a one-time script."""
2 | import os
3 |
4 | os.environ["LOCAL_RANK"] = "-1"
5 | os.environ["TOKENIZERS_PARALLELISM"] = "true"
6 | os.environ["TRANSFORMERS_CACHE"] = "/mnt/home/"
7 | import argparse
8 | from utils import main_data_utils, sample_few_shot, augment_slices
9 |
10 |
11 | def parse_args():
12 |     parser = argparse.ArgumentParser()
13 |     parser.add_argument("--data_root", default="./data")
14 |     parser.add_argument("--name", default="clinc_oos")
15 |     parser.add_argument(
16 |         "--modes", nargs="+", default=["upsample", "gptj", "gpt3", "eda"]
17 |     )
18 |     parser.add_argument("--top_k", default=0, type=int)
19 |     parser.add_argument("--top_p", default=1.0, type=float)
20 |     return parser.parse_args()
21 |
22 |
23 | if __name__ == "__main__":
24 |     args = parse_args()
25 |     if args.top_k < 0 or args.top_p > 1.0:
26 |         print("Invalid flags: expected args.top_k >= 0 and 0.0 <= args.top_p <= 1.0")
27 |         import sys
28 |
29 |         sys.exit()
30 |     ds_config = main_data_utils.get_ds_config(args.name)
31 |     print(f"Loaded dataset config for {args.name}")
32 |     sample_few_shot(args.data_root, ds_config)
33 |     print(f"augmenting for modes: {args.modes}")
34 |     data_slices = augment_slices(
35 |         args.data_root,
36 |         ds_config,
37 |         modes=args.modes,
38 |         top_k=args.top_k,
39 |         top_p=args.top_p,
40 |     )
41 |
-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | datasets==1.18.4
2 | GPUtil==1.4.0
3 | haven-ai @ git+https://github.com/haven-ai/haven-ai
4 | huggingface-hub==0.0.19
5 | matplotlib==3.4.3
6 | nltk
7 | numpy
8 | openai==0.15.0
9 | pandas==1.3.3
10 | scikit-image==0.18.3
11 | scikit-learn==1.0
12 | seaborn==0.11.2
13 | sentencepiece==0.1.96
14 | torch==1.9.1
15 | transformers==4.18.0
16 | wandb==0.12.4
17 | pydantic==1.8.2
-------------------------------------------------------------------------------- /runners/compile_results.py: --------------------------------------------------------------------------------
1 | import os, json, copy, pickle
2 | import numpy as np
3 | import pandas as pd
4 | from collections import defaultdict
5 | from scipy import stats
6 |
7 | from utils.metrics import Metrics
8 |
9 | # this dir contains CLINC and SNIPS partial fewshot experiments
10 | # EXPERIMENTS_DIR = "/path/to/partial-fewshot/savedir/"
11 |
12 | # this dir contains CLINC, SNIPS, Banking77, and HWU64 full fewshot experiments
13 | EXPERIMENTS_DIR = "/path/to/full-fewshot/savedir/"
14 |
15 | # FILTERS (NOTE: any of these filters can be disabled by setting to None)
16 | datasets = None  # clinc_oos, snips_official, banking77, hwu64
17 | augmentation = None  # eda (in FS), ada, babbage, curie, davinci, gptj
18 | temperatures = None  # 0.5-2.0
19 | skip_oracle = False  # True/False
20 |
21 | # exp_type-wise significance tests are always computed w.r.t. davinci models
22 | TTEST = False  # True/False
23 |
24 | # print results as a latex table
25 | TO_LATEX = False
26 |
27 | # save metrics for plotting purposes
28 | GEN_FOR_PLOT = False
29 | # name of the results file (NOTE: only GPTJ for full fewshot is supported)
30 | FNAME = "results/gptj_val_fidelity_et_al.pkl"
31 |
32 | # The below vars are used by to_latex()
33 | # ALL_METRICS and ALL_DATASETS are in the order of reporting in the paper
34 | # defining the order here so that we don't need to apply any convoluted
35 | # ops like sorting to ensure consistent performance reporting.
36 | # NOTE: to_latex() auto-ignores a dataset if it's filtered out in this script
37 | ALL_DATASETS = ["clinc_oos", "hwu64", "banking77", "snips_official"]
38 | BACKBONE = "BERT" if "bert" in EXPERIMENTS_DIR else "T5"
39 | SCOPES = ["overall", "few_shot"] if "ex2" in EXPERIMENTS_DIR else ["test"]
40 | ALL_METRICS = ["IA", "OR"]
41 |
42 | # No need to touch anything else below this line.
43 | FILTER = {
44 |     "datasets": datasets,  # set None to fetch for all datasets
45 |     "augmentation": augmentation,  # set None to fetch for all augmentors
46 |     "temps": temperatures,  # set None to fetch for all temperatures
47 |     "skip_oracle?": skip_oracle,  # set False/None to fetch oracle results
48 | }
49 |
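For example, to compile only the GPT-J results on CLINC150 and Banking77 while skipping oracle runs, the filter variables at the top of this file could be set as follows (values taken from the option lists in the comments above):
```python
datasets = ["clinc_oos", "banking77"]
augmentation = ["gptj"]
temperatures = None  # keep all temperatures
skip_oracle = True
```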
50 | pjoin = os.path.join
51 |
52 |
53 | def fmt(xs):
54 |     return f"{np.mean(xs):.2f} ({np.std(xs):.2f})"
55 |
56 |
57 | def compile_results(folder_list, exp_dir):
58 |     # fetch experiment type, and populate total_results accordingly
59 |     exp_dict_path = pjoin(exp_dir, folder_list[0], "exp_dict.json")
60 |     exp_dict = json.load(open(exp_dict_path, "r"))
61 |     dataset = exp_dict["dataset"]["name"]
62 |     dconfig = exp_dict["dataset"]["config"]
63 |     oos_id = exp_dict["dataset"]["oos_id"]
64 |     exp_type = exp_dict["exp_type"]
65 |     ex2_setup = True if dconfig.startswith("full_") else False
66 |     full_fewshot = True if dconfig == "few_pure" else False
67 |
68 |     if dconfig != "few_pure" and not ex2_setup and exp_type == "baseline":
69 |         org_baseline = True
70 |     else:
71 |         org_baseline = False
72 |
73 |     # declare total_results template based on exp config
74 |     if full_fewshot or org_baseline:  # no need for overall, fewshot keys
75 |         total_results = {
76 |             "test": {"IA": [], "OR": [], "A": []},
77 |             "val": {"IA": [], "OR": [], "A": []},
78 |         }
79 |         if oos_id is None:
80 |             total_results = {
81 |                 "test": {"A": []},
82 |                 "val": {"A": []},
83 |             }
84 |     else:  # ex2 setup
85 |         total_results = {
86 |             "few_shot": {"IA": []},
87 |             "few_shot_val": {"IA": []},
88 |             "overall": {"IA": [], "OR": []},
89 |             "overall_val": {"IA": [], "OR": []},
90 |         }
91 |
92 |     # read all the json files
93 |     for folder in folder_list:
94 |         sub_results = json.load(open(pjoin(exp_dir, folder, "code/results.json")))
95 |         key = list(sub_results.keys())[0]
96 |         sub_results = sub_results[key]
97 |
98 |         if dataset == "snips_official" and ex2_setup:  # snips has no OR
99 |             _overall = sub_results["overall"]
100 |             total_results["overall"]["IA"].append(_overall["test_accuracy"])
101 |             total_results["overall_val"]["IA"].append(_overall["valid_accuracy"])
102 |
103 |         elif full_fewshot or org_baseline:
104 |             total_results["test"]["A"].append(sub_results["test_accuracy"])
105 |             total_results["val"]["A"].append(sub_results["valid_accuracy"])
106 |             if not oos_id:
107 |                 continue
108 |             total_results["test"]["IA"].append(sub_results["test_inscope_accuracy"])
109 |             total_results["val"]["IA"].append(sub_results["valid_inscope_accuracy"])
110 |             total_results["test"]["OR"].append(sub_results["test_oos_recall"])
111 |             total_results["val"]["OR"].append(sub_results["valid_oos_recall"])
112 |             continue
113 |
114 |         elif ex2_setup and dataset == "clinc_oos":
115 |             # not handling oos_id as ex2_setup ALWAYS has an oos_id
116 |             overall = sub_results["overall"]
117 |             total_results["overall_val"]["IA"].append(overall["valid_inscope_accuracy"])
118 |             total_results["overall"]["IA"].append(overall["test_inscope_accuracy"])
119 |             total_results["overall_val"]["OR"].append(overall["valid_oos_recall"])
120 |             total_results["overall"]["OR"].append(overall["test_oos_recall"])
121 |
122 |             fs = sub_results["few_shot"]
123 |             total_results["few_shot"]["IA"].append(fs["test_accuracy"])
124 |             total_results["few_shot_val"]["IA"].append(fs["valid_accuracy"])
125 |
126 |     for k in total_results:
127 |         for m in total_results[k]:
128 |             results = [100 * v for v in total_results[k][m]]
129 |             # the following is averaging across different fewshot domains.
130 |             # For EX2, it's nth run's avg. across full_banking, full_meta, etc.
131 |             # For Full fewshot, it's computing the mean for one domain, i.e.,
132 |             # the value is going to remain the same except it won't be a list.
133 |             total_results[k][m] = np.mean(results)
134 |     return total_results
135 |
136 |
137 | def segregate_sub_folders(exp_dir):
138 |     sub_folder_dict = {}
139 |
140 |     for folder in os.listdir(exp_dir):
141 |         exp_dict_path = pjoin(exp_dir, folder, "exp_dict.json")
142 |         exp_dict = json.load(open(exp_dict_path))
143 |         dname = exp_dict["dataset"]["name"]  # aggregate on the dataset
144 |
145 |         # dataset filter
146 |         if FILTER["datasets"] and dname not in FILTER["datasets"]:
147 |             continue
148 |         exp_dict["gpt3_temp"] = 1.0  # TODO: remove
149 |         if "gpt" in exp_dict["exp_type"]:
150 |             engine, temp = exp_dict["gpt3_engine"], exp_dict["gpt3_temp"]
151 |             # engine and temperature filter
152 |             if FILTER["augmentation"] and engine not in FILTER["augmentation"]:
153 |                 continue
154 |             if FILTER["temps"] and temp not in FILTER["temps"]:
155 |                 continue
156 |             exp_type = f"{exp_dict['exp_type']}_{engine}_{temp}"
157 |         else:
158 |             # NOTE that for non-GPT experiments exp_type is the augmentation mode
159 |             exp_type = exp_dict["exp_type"]  # eda/eda_oracle
160 |             _aug = exp_type.replace("_oracle", "") if "oracle" in exp_type else exp_type
161 |             if FILTER["augmentation"] and _aug not in FILTER["augmentation"]:
162 |                 continue
163 |
164 |         # oracle filter
165 |         if "oracle" in exp_type and FILTER["skip_oracle?"]:
166 |             continue
167 |
168 |         if exp_type not in sub_folder_dict:
169 |             sub_folder_dict[exp_type] = {}
170 |         if dname not in sub_folder_dict[exp_type]:
171 |             sub_folder_dict[exp_type][dname] = defaultdict(list)
172 |         sub_folder_dict[exp_type][dname][exp_dict["run#"]].append(folder)
173 |
174 |     # a sanity check line, prints the number of experiments per config.
175 |     folders_per_exp = []
176 |     for exp_type in sub_folder_dict:
177 |         for dname in sub_folder_dict[exp_type]:
178 |             num_runs = len(sub_folder_dict[exp_type][dname])
179 |             folders_per_exp.append((exp_type, dname, num_runs))
180 |
181 |     print(folders_per_exp, len(folders_per_exp))
182 |     return sub_folder_dict
183 |
184 |
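For reference, the nesting that `segregate_sub_folders` builds looks roughly like this (the folder hashes are made up):
```python
# sub_folder_dict (illustrative shape):
# {
#     "gpt3_ada_1.0": {                # exp_type, with engine/temp baked in
#         "clinc_oos": {               # dataset name
#             0: ["d41d8cd9...", ...], # run# -> list of experiment folders
#             1: ["90015098...", ...],
#         },
#     },
#     "eda": {...},                    # non-GPT modes keep the plain exp_type
# }
```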
175 |     folders_per_exp = []
176 |     for exp_type in sub_folder_dict:
177 |         for dname in sub_folder_dict[exp_type]:
178 |             num_runs = len(sub_folder_dict[exp_type][dname])
179 |             folders_per_exp.append((exp_type, dname, num_runs))
180 | 
181 |     print(folders_per_exp, len(folders_per_exp))
182 |     return sub_folder_dict
183 | 
184 | 
185 | def final_compile(sub_results_dicts):
186 |     final_result = {}
187 |     for s in sub_results_dicts:
188 |         for scope in s.keys():
189 |             if scope not in final_result:
190 |                 final_result[scope] = {}
191 |             for metric in s[scope]:
192 |                 if np.isnan(s[scope][metric]):
193 |                     continue
194 |                 if metric not in final_result[scope]:
195 |                     final_result[scope][metric] = []
196 |                 final_result[scope][metric].append(s[scope][metric])
197 |     return final_result
198 | 
199 | 
200 | def get_performance(exp_dir):
201 |     """
202 |     returns a results dictionary which is not aggregated by runs
203 | 
204 |     An example of hierarchy:
205 |     gpt3_ada_1.0:
206 |         clinc_oos:
207 |             test:
208 |                 IA[90.32, 90.1...90.3]
209 |                 OR[40.23, 39.12...38.1]
210 |             val:
211 |                 IA[92.32, 91.1...93.3]
212 |                 OR[45.23, 41.12...40.1]
213 |         banking77:
214 |             test:
215 |                 A[82.3...80.1]
216 |             val:
217 |                 A[83.2...79.2]
218 |         snips_official:
219 |             .
220 |             .
221 |             .
222 |     gpt3_babbage_1.0:
223 |         .
224 |         .
225 |         .
226 | 
227 |     It's not aggregated so that other functions may use it for:
228 |     - mean and std computation across multiple runs
229 |     - significance testing
230 |     """
231 |     sub_folder_dict = segregate_sub_folders(exp_dir)
232 |     performance = {}
233 |     for exp_type in sorted(list(sub_folder_dict.keys())):
234 |         for dname in sorted(list(sub_folder_dict[exp_type].keys())):
235 |             config_results = []
236 |             for config in sub_folder_dict[exp_type][dname]:
237 |                 folderlist = sub_folder_dict[exp_type][dname][config]
238 |                 config_results.append(compile_results(folderlist, exp_dir))
239 |             if exp_type not in performance:
240 |                 performance[exp_type] = {}
241 |             performance[exp_type][dname] = final_compile(config_results)
242 |     return performance
243 | 
244 | 
245 | def to_latex(performance):
246 |     """
247 |     Generates latex table code for aug, aug+relabel settings
248 |     """
249 |     table_latex = ""
250 |     # backbone mode (aug/aug.+relabel)
251 |     template = "{} {} (Ours) &"  # line template
252 | 
253 |     # num of columns to report will be same for all exp. settings
254 |     _etype = list(performance.keys())[0]
255 |     n_cols = 0
256 |     for dname in performance[_etype]:
257 |         for s in SCOPES:
258 |             curr_metrics = performance[_etype][dname][s]
259 |             for _m in curr_metrics:
260 |                 if _m == "A" and "IA" in curr_metrics:
261 |                     continue
262 |                 n_cols += 1
263 | 
264 |     template += " {} &" * (n_cols - 1)
265 |     template += " {} \\\\\n"
266 | 
267 |     for etype in performance:
268 |         dscores = []
269 |         for dname in ALL_DATASETS:
270 |             # print(dname)
271 |             if dname not in performance[etype]:
272 |                 continue
273 |             # print(dname)
274 |             for s in SCOPES:
275 |                 # print(s)
276 |                 curr_metrics = list(performance[etype][dname][s])
277 |                 for _m in ALL_METRICS:
278 |                     if _m not in curr_metrics:
279 |                         # a dataset without IA means no OOS. in that case,
280 |                         # A is the same as IA.
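                        # e.g., snips_official ships no OOS intent, so its
                        # plain accuracy (A) stands in for in-scope accuracy.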
281 |                         if _m == "IA":
282 |                             _m = "A"
283 |                         else:
284 |                             continue
285 |                     dscores.append(fmt(performance[etype][dname][s][_m]))
286 |                     # print(_m)
287 |             # print("===")
288 | 
289 |         table_latex += template.format(
290 |             BACKBONE, etype.replace("_1.0", "").replace("_", r"\_"), *dscores
291 |         )
292 |     print(table_latex)
293 | 
294 | 
295 | def perform_ttest(performance):
296 |     """
297 |     receives a performance dict for datasets and performs independent two-sample
298 |     t-tests w.r.t. the davinci model at temp. 1.0 for that experiment type
299 |     """
300 |     if performance == {}:
301 |         print("Nothing to show here")
302 |         return
303 | 
304 |     bigger_model = None
305 |     for e in performance:
306 |         if "davinci" in e:
307 |             bigger_model = e
308 |     bigger_results = performance[bigger_model]
309 | 
310 |     # Gather model-wise metrics
311 |     for dname in bigger_results.keys():
312 |         print(f"Dataset: {dname.upper()}")
313 |         print("-" * 30)
314 |         for s in SCOPES:
315 |             for m in bigger_results[dname][s]:
316 |                 _bresults = bigger_results[dname][s][m]
317 |                 print(f"--- {s} {m} test ---")
318 |                 print(f"{bigger_model.upper()}: ({fmt(_bresults)})")
319 |                 for model, results in performance.items():
320 |                     if model == bigger_model:
321 |                         continue
322 |                     _sresults = results[dname][s][m]
323 |                     test_result = stats.ttest_ind(_bresults, _sresults)
324 |                     print(f"  vs {model.upper()} ({fmt(_sresults)}) {test_result}")
325 |                 print()
326 | 
327 | 
328 | def display_results(performance):
329 |     for etype in performance:
330 |         for dname in performance[etype]:
331 |             for scope in performance[etype][dname]:
332 |                 for metric in performance[etype][dname][scope]:
333 |                     results = performance[etype][dname][scope][metric]
334 |                     performance[etype][dname][scope][metric] = fmt(results)
335 | 
336 |     for etype in performance:
337 |         for dname in performance[etype]:
338 |             print(f"Setting: {etype} | {dname}")
339 |             print("-" * 20)
340 |             print(pd.DataFrame().from_dict(performance[etype][dname]))
341 |             print("=" * 30)
342 |             print("\n")
343 | 
344 | 
345 | def gen_for_plot(performance):
346 |     """
347 |     Will save fidelity and few-shot accuracies for all the datasets in a file
348 |     NOTE: this doesn't save metrics for partial fewshot temp. profiling nor
349 |     does it support any engine other than GPTJ right now.
350 |     """
351 |     if FILTER["augmentation"] != ["gptj"]:
352 |         raise NotImplementedError(
353 |             "Metrics generation for plotting only supported for GPTJ!"
354 |         )
355 |     if "fs" not in EXPERIMENTS_DIR:
356 |         raise NotImplementedError(
357 |             "Metrics generation for plotting only supported for Full Fewshot!"
358 |         )
359 | 
360 |     if os.path.exists(FNAME):
361 |         print(f"{FNAME} already exists!! Loading...")
362 |         print("Delete/Rename it to recompute fidelities.")
363 |         return pickle.load(open(FNAME, "rb"))
364 |     print("Compiling plotting metrics in a file...")
365 | 
366 |     # init df
367 |     df = pd.DataFrame(columns=["temp", "ds", "val_acc_mean", "val_acc_std", "fidelity"])
368 | 
369 |     # compute fidelities
370 |     fidelities = {ds: Metrics().compute_fidelities(ds) for ds in ALL_DATASETS}
371 |     for etype, results in performance.items():
372 |         # etype: gpt3_gptj_1.0
373 |         _, temp = etype.rsplit("_", 1)
374 |         for ds in results:
375 |             acc_key = "IA" if ds == "clinc_oos" else "A"
376 |             val_accs = results[ds]["val"][acc_key]
377 |             val_acc_mean, val_acc_std = np.mean(val_accs), np.std(val_accs)
378 |             _fid = fidelities[ds][f"gptj_{temp}"]
379 |             # create a new entry in the dataframe
380 |             df.loc[len(df.index)] = [float(temp), ds, val_acc_mean, val_acc_std, _fid]
381 | 
382 |     # will be used to plot the threshold lines in the fidelity plots
383 |     thresholds = {ds: fidelities[ds]["threshold"] for ds in ALL_DATASETS}
384 |     metrics = {"metrics": df, "thresholds": thresholds}
385 |     print(f"Saving fidelity metrics for plotting {FNAME}")
386 |     with open(FNAME, "wb") as f:
387 |         pickle.dump(metrics, f)
388 |     return metrics
389 | 
390 | 
391 | def main():
392 |     """
393 |     Computes mean and std of metrics obtained by get_performance
394 |     """
395 |     # remove the "deleted" folder if it exists
396 |     if os.path.exists(pjoin(EXPERIMENTS_DIR, "deleted")):
397 |         print(f"Removing {pjoin(EXPERIMENTS_DIR, 'deleted')}...")
398 |         os.system(f"rm -rf {pjoin(EXPERIMENTS_DIR, 'deleted')}")
399 |         print("Removed.")
400 | 
401 |     performance = get_performance(EXPERIMENTS_DIR)
402 | 
403 |     # display results (deepcopy needed as display_results mutates its input)
404 |     display_results(copy.deepcopy(performance))
405 | 
406 |     if TTEST:
407 |         # non-oracle etype
408 |         print("T-Test no oracle")
409 |         perform_ttest({k: v for k, v in performance.items() if "oracle" not in k})
410 | 
411 |         # oracle etype
412 |         print("T-Test with oracle")
413 |         perform_ttest({k: v for k, v in performance.items() if "oracle" in k})
414 | 
415 |     if TO_LATEX:
416 |         to_latex(performance)
417 | 
418 |     if GEN_FOR_PLOT:
419 |         gen_for_plot(performance)
420 | 
421 | 
422 | if __name__ == "__main__":
423 |     main()
424 | 
-------------------------------------------------------------------------------- /runners/infer.py: --------------------------------------------------------------------------------
1 | import os
2 | 
3 | os.environ["WANDB_DISABLED"] = "true"
4 | 
5 | from collections import defaultdict
6 | import argparse
7 | from haven import haven_wizard as hw
8 | import gc
9 | import json
10 | import GPUtil
11 | import torch
12 | from transformers import (
13 |     AutoTokenizer,
14 |     AutoModelForSequenceClassification,
15 |     TrainingArguments,
16 |     Trainer,
17 | )
18 | 
19 | from utils.metrics import Metrics
20 | from utils.data_utils.data_loader import DatasetLoader
21 | 
22 | 
23 | torch.backends.cudnn.benchmark = True
24 | 
25 | NUM_LABELS = 7
26 | DATASET_NAME = "snips_official"
27 | MODEL_DIR = (
28 |     f"/mnt/colab_public/results/few_shot_nlp/model/{DATASET_NAME}/oracle_checkpoint/"
29 | )
30 | 
31 | RESULTS_PATH = f"results/gpt_fidelity_{DATASET_NAME}.json"
32 | 
33 | 
34 | def write_results(results):
35 |     with open(RESULTS_PATH, "w") as f:
36 |         json.dump(results, f)
37 | 
38 | 
39 | def trainval(exp_dict, savedir, args):
40 |     """Main."""
41 |     # ==========================
42 |     # load datasets
43 |     # ==========================
44 |     dataset_loader = DatasetLoader(args.datadir, exp_dict)
45 | 
46 |     # ==========================
47 |     # create model and trainer
48 |     # ==========================
49 |     backbone = AutoModelForSequenceClassification.from_pretrained(
50 |         MODEL_DIR, num_labels=NUM_LABELS
51 |     )
52 | 
53 |     tokenizer = AutoTokenizer.from_pretrained(
54 |         exp_dict["model"]["backbone"], use_fast=True
55 |     )
56 | 
57 |     def preprocess(example):
58 |         # note that setting max_length and truncation will not
59 |         # have any effect for the vanilla baseline experiments
60 |         results = tokenizer(example["text"], max_length=50, truncation=True)
61 |         results["label"] = example["intent"]
62 |         return results
63 | 
64 |     encoded_dataset = dataset_loader.dataset.map(preprocess, batched=True)
65 | 
66 |     if args.job_scheduler == "1":
67 |         from configs import job_configs
68 | 
69 |         n_gpu = job_configs.JOB_CONFIG["resources"]["gpu"]
70 |     else:
71 |         n_gpu = 1
72 | 
73 |     training_args = TrainingArguments(  # renamed from `args` to avoid shadowing the CLI args
74 |         savedir,
75 |         evaluation_strategy="epoch",
76 |         save_strategy="epoch",
77 |         learning_rate=exp_dict["lr"],
78 |         per_device_train_batch_size=exp_dict["batch_size"] // n_gpu,
79 |         per_device_eval_batch_size=exp_dict["batch_size"] // n_gpu,
80 |         num_train_epochs=exp_dict["epochs"],
81 |         warmup_ratio=exp_dict["warmup_ratio"],
82 |         weight_decay=exp_dict["weight_decay"],
83 |         load_best_model_at_end=True,
84 |         metric_for_best_model=exp_dict["metric_best"],
85 |         # push_to_hub=True,
86 |         # push_to_hub_model_id=f"{model_name}-finetuned-{task}",
87 |     )
88 | 
89 |     if "full_validation" in encoded_dataset:
90 |         # for ex2 setup experiments
91 |         eval_split = "full_validation"
92 |     else:
93 |         # for clinc setup experiments
94 |         eval_split = "validation"
95 | 
96 |     trainer = Trainer(
97 |         backbone,
98 |         training_args,
99 |         train_dataset=encoded_dataset["train"],
100 |         eval_dataset=encoded_dataset[eval_split],
101 |         tokenizer=tokenizer,
102 |         compute_metrics=Metrics(exp_dict).compute_metrics(),
103 |     )
104 |     # trainer.train()  # training is skipped here; infer.py only evaluates the loaded checkpoint
105 |     print(GPUtil.showUtilization())
106 | 
107 |     print(f"n_gpus: {trainer.args.n_gpu}, local rank: {trainer.args.local_rank}")
108 |     print("Emptying cache...")  # crucial!
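    # gc.collect() drops lingering Python references while empty_cache()
    # returns the CUDA caching allocator's unused blocks to the driver,
    # freeing GPU memory ahead of the evaluation passes below.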
109 | gc.collect() 110 | torch.cuda.empty_cache() 111 | 112 | # compute metrics 113 | trainer.args.eval_accumulation_steps = exp_dict["eval_accumulation_steps"] 114 | if "full_test" in encoded_dataset: 115 | metrics = {"overall": {}, "few_shot": {}} 116 | else: 117 | metrics = defaultdict(dict) 118 | 119 | for split in encoded_dataset: 120 | if "train" in split: 121 | continue 122 | 123 | # split test, full_test, validation, full_validation 124 | prefix = "test" if "test" in split else "valid" 125 | if "overall" in metrics: 126 | _type = "overall" if "full" in split else "few_shot" 127 | metrics[_type].update( 128 | trainer.evaluate(encoded_dataset[split], metric_key_prefix=prefix) 129 | ) 130 | elif exp_dict["exp_type"] == "intrinsic": 131 | metrics[split].update( 132 | trainer.evaluate(encoded_dataset[split], metric_key_prefix=prefix) 133 | ) 134 | else: 135 | metrics.update( 136 | trainer.evaluate(encoded_dataset[split], metric_key_prefix=prefix) 137 | ) 138 | # print results 139 | print(metrics) 140 | 141 | 142 | if __name__ == "__main__": 143 | parser = argparse.ArgumentParser() 144 | parser.add_argument( 145 | "-e", 146 | "--exp_group_list", 147 | nargs="+", 148 | default="resnet", 149 | help="name of an experiment in exp_configs.py", 150 | ) 151 | parser.add_argument( 152 | "-sb", 153 | "--savedir_base", 154 | default="/mnt/home/haven_output", 155 | help="folder where logs will be saved", 156 | ) 157 | parser.add_argument("-nw", "--num_workers", type=int, default=4) 158 | parser.add_argument("-d", "--datadir", type=str, default="./data") 159 | parser.add_argument( 160 | "-r", "--reset", default=0, type=int, help="Overwrite previous results" 161 | ) 162 | parser.add_argument("-ei", "--exp_id", default=None) 163 | parser.add_argument( 164 | "-j", 165 | "--job_scheduler", 166 | type=str, 167 | default=None, 168 | help="If 1, runs in toolkit in parallel", 169 | ) 170 | parser.add_argument("-v", default="results.ipynb", help="orkestrator") 171 | parser.add_argument( 172 | "--python_binary", default="python", help="path to your python executable" 173 | ) 174 | 175 | args, unknown = parser.parse_known_args() 176 | from configs import exp_configs 177 | 178 | if args.job_scheduler == "1": 179 | from configs import job_configs 180 | 181 | job_config = job_configs.JOB_CONFIG 182 | else: 183 | job_config = None 184 | 185 | file_name = os.path.basename(__file__)[:-3] # remove .py 186 | hw.run_wizard( 187 | func=trainval, 188 | exp_groups=exp_configs.EXP_GROUPS, 189 | job_config=job_config, 190 | python_binary_path=args.python_binary, 191 | python_file_path=f"-m runners.{file_name}", 192 | use_threads=True, 193 | args=args, 194 | ) 195 | -------------------------------------------------------------------------------- /runners/modules/bert.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import gc, os, GPUtil, torch 3 | from transformers import AutoTokenizer, TrainingArguments, Trainer 4 | 5 | from utils.metrics import Metrics 6 | from utils import main_data_utils as mdu 7 | from utils.data_utils.data_loader import DatasetLoader 8 | from models.backbones import get_backbone 9 | 10 | 11 | torch.backends.cudnn.benchmark = True 12 | 13 | 14 | def trainval(exp_dict, savedir, cl_args): 15 | """Main.""" 16 | # ========================== 17 | # load datasets 18 | # ========================== 19 | dataset_loader = DatasetLoader(cl_args.datadir, exp_dict) 20 | 21 | # ========================== 22 | # create model and trainer 23 
| # ========================== 24 | backbone = get_backbone(exp_dict) 25 | 26 | tokenizer = AutoTokenizer.from_pretrained( 27 | exp_dict["model"]["backbone"], use_fast=True 28 | ) 29 | 30 | def preprocess(example): 31 | # note that setting max_length and truncation will not 32 | # have any effect for the vanilla baseline experiments 33 | results = tokenizer(example["text"], max_length=50, truncation=True) 34 | results["label"] = example["intent"] 35 | return results 36 | 37 | encoded_dataset = dataset_loader.dataset.map(preprocess, batched=True) 38 | 39 | if cl_args.job_scheduler == "1": 40 | from configs import job_configs 41 | 42 | n_gpu = job_configs.JOB_CONFIG["resources"]["gpu"] 43 | else: 44 | n_gpu = 1 45 | 46 | args = TrainingArguments( 47 | savedir, 48 | evaluation_strategy="epoch", 49 | save_strategy="epoch", 50 | learning_rate=exp_dict["lr"], 51 | per_device_train_batch_size=exp_dict["batch_size"] // n_gpu, 52 | per_device_eval_batch_size=exp_dict["batch_size"] // n_gpu, 53 | num_train_epochs=exp_dict["epochs"], 54 | warmup_ratio=exp_dict["warmup_ratio"], 55 | weight_decay=exp_dict["weight_decay"], 56 | load_best_model_at_end=True, 57 | metric_for_best_model=exp_dict["metric_best"], 58 | save_total_limit=1, 59 | # push_to_hub=True, 60 | # push_to_hub_model_id=f"{model_name}-finetuned-{task}", 61 | ) 62 | 63 | if "full_validation" in encoded_dataset: 64 | # for ex2 setup experiments 65 | eval_split = "full_validation" 66 | else: 67 | # for clinc setup experiments 68 | eval_split = "validation" 69 | 70 | trainer = Trainer( 71 | backbone, 72 | args, 73 | train_dataset=encoded_dataset["train"], 74 | eval_dataset=encoded_dataset[eval_split], 75 | tokenizer=tokenizer, 76 | compute_metrics=Metrics(exp_dict).compute_metrics(), 77 | ) 78 | trainer.train() 79 | print(GPUtil.showUtilization()) 80 | 81 | print(f"n_gpus: {trainer.args.n_gpu}, local rank: {trainer.args.local_rank}") 82 | print("Emptying cache...") # crucial! 83 | gc.collect() 84 | torch.cuda.empty_cache() 85 | 86 | # compute metrics 87 | trainer.args.eval_accumulation_steps = exp_dict["eval_accumulation_steps"] 88 | if "full_test" in encoded_dataset: 89 | metrics = {"overall": {}, "few_shot": {}} 90 | else: 91 | metrics = defaultdict(dict) 92 | 93 | for split in encoded_dataset: 94 | if "train" in split: 95 | continue 96 | 97 | # split test, full_test, validation, full_validation 98 | prefix = "test" if "test" in split else "valid" 99 | if "overall" in metrics: 100 | _type = "overall" if "full" in split else "few_shot" 101 | metrics[_type].update( 102 | trainer.evaluate(encoded_dataset[split], metric_key_prefix=prefix) 103 | ) 104 | elif exp_dict["exp_type"] == "intrinsic": 105 | metrics[split].update( 106 | trainer.evaluate(encoded_dataset[split], metric_key_prefix=prefix) 107 | ) 108 | else: 109 | metrics.update( 110 | trainer.evaluate(encoded_dataset[split], metric_key_prefix=prefix) 111 | ) 112 | # write results 113 | results_path = os.path.join(savedir, "code/results.json") 114 | # this will happen if not using scheduler 115 | if not os.path.exists(results_path): 116 | os.makedirs(os.path.dirname(results_path)) 117 | mdu.write_json({exp_dict["dataset"]["config"]: metrics}, results_path) 118 | 119 | if not cl_args.retain_checkpoints: 120 | # delete HF checkpoints to save space on toolkit 121 | print(f"DELETING! 
{os.path.join(savedir, 'checkpoint-*')}") 122 | os.system(f"rm -rf {os.path.join(savedir, 'checkpoint-*')}") 123 | -------------------------------------------------------------------------------- /runners/oracle_relabel.py: -------------------------------------------------------------------------------- 1 | from configs import exp_configs 2 | import os, gc, argparse, pickle, torch, numpy as np 3 | from utils import main_data_utils 4 | 5 | from transformers import AutoTokenizer, AutoModelForSequenceClassification 6 | from datasets import DatasetDict, Dataset 7 | 8 | from haven import haven_wizard as hw 9 | 10 | torch.backends.cudnn.benchmark = True 11 | 12 | pjoin = os.path.join 13 | write_pickle = lambda obj, path: pickle.dump(obj, open(path, "wb")) 14 | read_pickle = lambda path: pickle.load(open(path, "rb")) 15 | 16 | 17 | def oracle_correction(exp_dict, savedir, args): 18 | """Main.""" 19 | # ========================== 20 | # load full dataset (containing all generated examples) 21 | # ========================== 22 | ds_name = exp_dict["dataset"]["name"] 23 | ds_config = main_data_utils.get_ds_config(ds_name) 24 | DOMAINS = list(ds_config.domain_to_intent.keys()) 25 | if "oos" in DOMAINS: 26 | print(f"ignoring OOS domain from {ds_name.upper()}") 27 | DOMAINS.pop(DOMAINS.index("oos")) # remove oos 28 | 29 | print(f"{ds_name} domains:") 30 | print(DOMAINS) 31 | engines = ["ada", "babbage", "curie", "davinci", "gptj", "eda", "val", "test"] 32 | temp = [1.0] 33 | 34 | base_data_path = pjoin(args.datadir, ds_name, "full", "dataset.pkl") 35 | exp_data_path = pjoin(args.datadir, ds_name, "full", "data_full_suite.pkl") 36 | 37 | tokenizer = AutoTokenizer.from_pretrained( 38 | exp_dict["model"]["backbone"], use_fast=True 39 | ) 40 | 41 | generated_dataset = DatasetDict(read_pickle(exp_data_path)) 42 | 43 | # remove the domains and make a HF Dataset 44 | for e in engines: 45 | # no need to remove domains, will fetch from the base dataset 46 | if e in ["val", "test"]: 47 | continue 48 | elif e == "gptj": 49 | temp = [round(a, 1) for a in np.linspace(0.5, 2, int((2.1 - 0.5) / 0.1))] 50 | else: 51 | temp = [1.0] 52 | for t in temp: 53 | _lines = [] 54 | _intents = [] 55 | for d in DOMAINS: 56 | attr = e if e in ["eda", "bt"] else f"{e}_{t}" 57 | _lines.extend(generated_dataset[d]["F"][attr]["text"]) 58 | _intents.extend(generated_dataset[d]["F"][attr]["intent"]) 59 | generated_dataset[attr] = Dataset.from_dict( 60 | {"text": _lines, "intent": _intents} 61 | ) 62 | 63 | base_dataset = read_pickle(base_data_path) 64 | # Add validation samples (for computing thresholds in fidelity plots) 65 | generated_dataset["val"] = Dataset.from_dict(base_dataset["val"]) 66 | generated_dataset["test"] = Dataset.from_dict(base_dataset["test"]) 67 | 68 | # ========================== 69 | # create model 70 | # ========================== 71 | print(f"Oracle path: {args.modeldir}") 72 | oracle = AutoModelForSequenceClassification.from_pretrained( 73 | args.modeldir, num_labels=exp_dict["dataset"]["num_labels"] 74 | ) 75 | oracle.cuda() 76 | oracle.eval() 77 | 78 | # Init AL dataset 79 | al_path = pjoin(args.datadir, ds_name, "full", "al_dataset.pkl") 80 | if os.path.exists(al_path): 81 | print(f"Loading existing data from {al_path}") 82 | al_ds = read_pickle(al_path)["generated"] 83 | else: 84 | print(f"Initializing al_dataset.pkl for {ds_name.upper()}") 85 | al_ds = {} 86 | 87 | with torch.no_grad(): 88 | for e in engines: 89 | if e == "gptj": 90 | temp = [ 91 | round(a, 1) for a in np.linspace(0.5, 2, int((2.1 - 0.5) / 
0.1))
92 |                 ]
93 |             else:
94 |                 temp = [1.0]
95 |             for t in temp:
96 |                 attr = e if e in ["eda", "bt", "val", "test"] else f"{e}_{t}"
97 |                 if attr in al_ds:
98 |                     print(f"{attr} already exists in {ds_name}'s AL dataset")
99 |                     continue
100 |                 print(f"relabeling for {attr}")
101 |                 al_ds[attr] = {}
102 |                 al_ds[attr]["text"] = []
103 |                 al_ds[attr]["label"] = []
104 | 
105 |                 encodings = tokenizer(
106 |                     generated_dataset[attr]["text"],
107 |                     max_length=50,
108 |                     padding=True,
109 |                     truncation=True,
110 |                     return_tensors="pt",
111 |                 )
112 |                 total_num = len(encodings["input_ids"])
113 |                 batch_size = 100
114 | 
115 |                 print("total_num", total_num, "and batch_size", batch_size)
116 |                 lbls = []
117 |                 for i in range(0, total_num, batch_size):
118 |                     outputs = oracle(
119 |                         encodings["input_ids"][i : i + batch_size].cuda(),
120 |                         attention_mask=encodings["attention_mask"][i : i + batch_size].cuda(),
121 |                     )
122 |                     probs = torch.softmax(outputs.logits, dim=1)
123 |                     lbls.extend(torch.argmax(probs, dim=1).cpu().tolist())
124 | 
125 |                 al_ds[attr]["text"] = generated_dataset[attr]["text"]
126 |                 al_ds[attr]["intent"] = lbls
127 |                 al_ds[attr]["old_intent"] = generated_dataset[attr]["intent"]
128 | 
129 |     al_dataset = dict(
130 |         train=base_dataset["train"],
131 |         val=base_dataset["val"],
132 |         test=base_dataset["test"],
133 |         generated=al_ds,
134 |     )
135 | 
136 |     with open(al_path, "wb") as f:
137 |         pickle.dump(al_dataset, f)
138 | 
139 |     print("Emptying cache...")  # crucial!
140 |     gc.collect()
141 |     torch.cuda.empty_cache()
142 | 
143 | 
144 | if __name__ == "__main__":
145 |     parser = argparse.ArgumentParser()
146 |     parser.add_argument(
147 |         "-e",
148 |         "--exp_group_list",
149 |         nargs="+",
150 |         default="resnet",
151 |         help="name of an experiment in exp_configs.py",
152 |     )
153 |     parser.add_argument(
154 |         "-sb",
155 |         "--savedir_base",
156 |         default="/mnt/home/haven_output",
157 |         help="folder where logs will be saved",
158 |     )
159 |     parser.add_argument("-nw", "--num_workers", type=int, default=4)
160 |     parser.add_argument("-d", "--datadir", type=str, default="./data")
161 |     parser.add_argument("-md", "--modeldir", type=str, default=None)
162 |     parser.add_argument("-ei", "--exp_id", default=None)
163 |     parser.add_argument(
164 |         "-j",
165 |         "--job_scheduler",
166 |         type=str,
167 |         default=None,
168 |         help="If 1, runs in toolkit in parallel",
169 |     )
170 |     parser.add_argument(
171 |         "--python_binary", default="python", help="path to your python executable"
172 |     )
173 | 
174 |     args, unknown = parser.parse_known_args()
175 |     if args.job_scheduler == "1":
176 |         from configs import job_configs
177 | 
178 |         job_config = job_configs.JOB_CONFIG
179 |     else:
180 |         job_config = None
181 |     file_name = os.path.basename(__file__)[:-3]  # remove .py
182 |     hw.run_wizard(
183 |         func=oracle_correction,
184 |         exp_groups=exp_configs.EXP_GROUPS,
185 |         job_config=job_config,
186 |         python_binary_path=args.python_binary,
187 |         python_file_path=f"-m runners.{file_name}",
188 |         use_threads=True,
189 |         args=args,
190 |     )
191 | 
-------------------------------------------------------------------------------- /runners/train.py: --------------------------------------------------------------------------------
1 | 
2 | from configs import exp_configs
3 | import os, argparse, torch, wandb
4 | from haven import haven_wizard as hw
5 | 
6 | from runners.modules import bert
7 | 
8 | torch.backends.cudnn.benchmark = True
9 | 
10 | 
11 | def init_wandb(exp_dict):
12 |     exp_name = f"{exp_dict['exp_type']}_oosID={exp_dict['dataset']['oos_id']}_"
13 |     exp_name += f'{exp_dict["dataset"]["name"]}_{exp_dict["dataset"]["config"]}_'
14 |     exp_name += f'{exp_dict["lr"]}_{exp_dict["epochs"]}_{exp_dict["batch_size"]}'
15 |     wandb.init(project="few-shot-nlp", name=exp_name)
16 | 
17 | 
18 | def trainval(exp_dict, savedir, args):
19 |     """Main."""
20 |     # ==========================
21 |     # init wandb
22 |     # ==========================
23 |     if not args.disable_wandb:
24 |         init_wandb(exp_dict)
25 |     else:
26 |         os.environ["WANDB_DISABLED"] = "true"
27 | 
28 |     # ==========================
29 |     # Load appropriate trainer
30 |     # ==========================
31 |     if "bert" in exp_dict["model"]["backbone"]:
32 |         return bert.trainval(exp_dict, savedir, args)
33 |     raise ValueError("backbone not recognized")
34 | 
35 | 
36 | if __name__ == "__main__":
37 |     parser = argparse.ArgumentParser()
38 |     parser.add_argument(
39 |         "-e",
40 |         "--exp_group_list",
41 |         nargs="+",
42 |         default="resnet",
43 |         help="name of an experiment in exp_configs.py",
44 |     )
45 |     parser.add_argument(
46 |         "-sb",
47 |         "--savedir_base",
48 |         default="/mnt/home/haven_output",
49 |         help="folder where logs will be saved",
50 |     )
51 |     parser.add_argument("-nw", "--num_workers", type=int, default=4)
52 |     parser.add_argument("-d", "--datadir", type=str, default="./data")
53 |     parser.add_argument("-md", "--modeldir", type=str, default=None)
54 |     parser.add_argument(
55 |         "-r", "--reset", default=0, type=int, help="Overwrite previous results"
56 |     )
57 |     parser.add_argument("-ei", "--exp_id", default=None)
58 |     parser.add_argument(
59 |         "-j",
60 |         "--job_scheduler",
61 |         type=str,
62 |         default=None,
63 |         help="If 1, runs in toolkit in parallel",
64 |     )
65 |     parser.add_argument("-v", default="results.ipynb", help="orkestrator")
66 |     parser.add_argument(
67 |         "--python_binary", default="python", help="path to your python executable"
68 |     )
69 |     parser.add_argument("--disable_wandb", default=1, type=int)
70 |     parser.add_argument("--retain_checkpoints", action="store_true")
71 | 
72 |     args, unknown = parser.parse_known_args()
73 |     if args.job_scheduler == "1":
74 |         from configs import job_configs
75 | 
76 |         job_config = job_configs.JOB_CONFIG
77 |     else:
78 |         job_config = None
79 | 
80 |     file_name = os.path.basename(__file__)[:-3]  # remove .py
81 |     hw.run_wizard(
82 |         func=trainval,
83 |         exp_groups=exp_configs.EXP_GROUPS,
84 |         job_config=job_config,
85 |         python_binary_path=args.python_binary,
86 |         python_file_path=f"-m runners.{file_name}",
87 |         use_threads=True,
88 |         args=args,
89 |         results_fname="notebooks/fsn.ipynb",
90 |     )
91 | 
-------------------------------------------------------------------------------- /scripts/3way_human_eval.py: --------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from utils import main_data_utils as mdu
3 | import argparse, numpy as np
4 | from openpyxl import Workbook
5 | from openpyxl.styles import Font
6 | 
7 | 
8 | def create_content(ds, id2name, args, run_code):
9 |     samples = []
10 |     for (intent_id, sentence) in zip(ds[run_code]["intent"], ds[run_code]["text"]):
11 |         if id2name[str(intent_id)] not in args.intent_triplet:
12 |             continue
13 |         samples.append((sentence, id2name[str(intent_id)]))
14 | 
15 |     np.random.shuffle(samples)
16 |     # start creating an excel sheet
17 |     wb = Workbook()
18 |     wb.remove(wb.active)
19 | 
20 |     # create a fresh worksheet with a meaningful name
21 |     ws = wb.create_sheet("Human eval.")
22 |     ws.append(["Sentence", "Your 
Prediction"]) 23 | labels = "" 24 | max_sent_length = 0 25 | for sent, intent in samples: 26 | if len(sent) > max_sent_length: 27 | max_sent_length = len(sent) 28 | ws.append([sent]) 29 | labels += intent + "\n" 30 | 31 | # increase font size of sentences 32 | for i in range(1, len(samples) + 2): 33 | ws[f"A{i}"].font = Font(size=14) 34 | 35 | # max width of column A 36 | ws.column_dimensions["A"].width = max_sent_length 37 | ws["A1"].font = Font(bold=True, size=14) 38 | ws["B1"].font = Font(bold=True, size=14) 39 | wb.active = 0 40 | wb.save(f"spreadsheets/{args.dataset_name}/{run_code}_human_eval.xlsx") 41 | mdu.write_file(labels, f"spreadsheets/{args.dataset_name}/{run_code}_labels.txt") 42 | 43 | 44 | def evaluate(args, run_code): 45 | labels_path = f"spreadsheets/{args.dataset_name}/{run_code}_labels.txt" 46 | labels = mdu.read_file(labels_path).splitlines() 47 | if args.dataset_name == "hwu64": 48 | predcode2name = { 49 | "1": "music_likeness", 50 | "2": "music_settings", 51 | "3": "play_music", 52 | } 53 | elif args.dataset_name == "banking77": 54 | predcode2name = { 55 | "1": "topping_up_by_card", 56 | "2": "top_up_failed", 57 | "3": "pending_top_up", 58 | } 59 | preds_path = f"spreadsheets/{args.dataset_name}/sahu_{run_code}_preds.txt" 60 | preds = [predcode2name[p] for p in mdu.read_file(preds_path).splitlines()] 61 | assert len(preds) == len(labels) 62 | class_wise_preds = defaultdict(list) 63 | for _pred, _label in zip(preds, labels): 64 | class_wise_preds[_label].append(_pred) 65 | print( 66 | "Overall Val Acc of sahu =", 67 | f"{np.mean([1 if p == l else 0 for p, l in zip(preds, labels)])*100:.2f}", 68 | ) 69 | 70 | for intent, preds in class_wise_preds.items(): 71 | print(f"sahu Acc on {intent}: {preds.count(intent)/len(preds)*100:.2f}") 72 | 73 | 74 | def main(): 75 | # read base data 76 | args = parse_args() 77 | if args.eval: 78 | evaluate(args, "val") 79 | ds = mdu.read_pickle(f"data/{args.dataset_name}/full/dataset.pkl") 80 | id2name = mdu.read_json(f"data/{args.dataset_name}/id2name.json") 81 | 82 | create_content(ds, id2name, args, "val") # 3-way val set 83 | create_content(ds, id2name, args, "test") # 3-way test set 84 | 85 | 86 | def parse_args(): 87 | parser = argparse.ArgumentParser() 88 | parser.add_argument("-d", "--dataset-name", default="hwu64") 89 | # parser.add_argument("-d", "--dataset-name", default="banking77") 90 | parser.add_argument( 91 | "-it", 92 | "--intent-triplet", 93 | nargs="+", 94 | default=["music_likeness", "play_music", "music_settings"], 95 | # default=["topping_up_by_card", "top_up_failed", "pending_top_up"], 96 | ) 97 | parser.add_argument("--eval", action="store_true") 98 | return parser.parse_args() 99 | 100 | 101 | if __name__ == "__main__": 102 | main() 103 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ServiceNow/data-augmentation-with-llms/77c740df2aaf4bfc1b2449114613f98b37a1faae/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/compute_ttr.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import nltk 3 | 4 | 5 | def ttr(sentences, n): 6 | """ 7 | Computes Type-Token Ratio for a given 8 | set of sentences 9 | Params: 10 | ====== 11 | sentences (list): list of sentences in the corpus 12 | n (list): list of n values to consider when computing ngrams 
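
    Example (illustrative): for sentences ["book a table", "book a flight"]
    and n=[1, 2], the corpus has 6 unigram tokens (4 unique) and 4 bigram
    tokens (3 unique), so ttr returns [0.67, 0.75].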
13 |     """
14 |     ngrams = [[] for _ in n]
15 |     for s in sentences:
16 |         tokens = nltk.tokenize.word_tokenize(s.lower())
17 |         for i in range(len(n)):
18 |             ngrams[i].extend(list(nltk.ngrams(tokens, n[i])))
19 |     return [round(len(set(v)) / len(v), 2) for v in ngrams]
20 | 
21 | 
22 | def main():
23 |     for dataset in ["clinc_oos", "snips_official", "hwu64", "banking77"]:
24 |         dpath = f"./data/{dataset}/full/data_full_suite.pkl"
25 |         data = pickle.load(open(dpath, "rb"))
26 |         domains = list(data.keys())
27 |         print(f"Diversity metrics for {dataset}")
28 |         for e in ["eda", "ada", "babbage", "curie", "davinci"]:
29 |             for t in [1.0]:
30 |                 sentences = []
31 |                 for d in domains:
32 |                     # skip OOS for now
33 |                     if d == "oos":
34 |                         continue
35 |                     attr = e if e == "eda" else f"{e}_{t}"
36 |                     sentences.extend(data[d]["F"][attr]["text"])
37 |                 print(f"{e}_{t}: {ttr(sentences, n=[1, 2, 3, 4])}")
38 | 
39 | 
40 | if __name__ == "__main__":
41 |     main()
42 | 
-------------------------------------------------------------------------------- /scripts/gpt3_analysis.py: --------------------------------------------------------------------------------
1 | import argparse, numpy as np
2 | from collections import defaultdict
3 | from tqdm import tqdm
4 | from utils import data_utils as du
5 | from utils import main_data_utils as mdu
6 | 
7 | 
8 | def compute_acc(preds, labels):
9 |     return f"{np.mean([1 if p == l else 0 for p, l in zip(preds, labels)])*100:.2f}"
10 | 
11 | 
12 | def get_seed_tuples(ds, seed_intent, ds_config, name2id):
13 |     """
14 |     returns seed sentences for the given seed intent from the (full suite)
15 |     dataset object ds, as a list of (intent_name, sentence) tuples.
16 |     seed_intent is the intent's name as a string
17 |     """
18 |     # identify the domain for the seed intent
19 |     intent_id = int(name2id[seed_intent])
20 |     seed_domain = None
21 |     for _domain, _intents in ds_config.domain_to_intent.items():
22 |         if seed_intent in _intents:
23 |             seed_domain = _domain
24 |             break
25 | 
26 |     # fetch the seed sentences for the seed intent
27 |     _start = ds[seed_domain]["F"]["train"]["intent"].index(intent_id)
28 |     num_examples = ds[seed_domain]["F"]["train"]["intent"].count(intent_id)
29 |     sents = ds[seed_domain]["F"]["train"]["text"][_start : _start + num_examples]
30 |     return [(seed_intent, s) for s in sents]
31 | 
32 | 
33 | def filter_via_gpt(sentences, seed_sentences, args):
34 |     """
35 |     `sentences` is a list of tuples (old intent, new intent, sentence)
36 |     """
37 |     # construct a prompt
38 |     ltemplate = "sentence: {} ; category:{}"  # line template
39 |     prompt = f"Each example in the following list contains a sentence that belongs to a category. A category is one of the following: {', '.join(args.intent_triplet)}"
40 |     prompt += "\n\n"
41 |     prompt += "\n".join([ltemplate.format(s, " " + i) for (i, s) in seed_sentences])
42 |     prompt += "\n"
43 | 
44 |     print(f"FILTERING generations using {args.gpt_engine.upper()}...")
45 |     retained_sentences = []
46 |     # NOTE: ignoring the new intent (middle element of the tuple)
47 |     for (old_intent, new_intent, sent) in tqdm(sentences):
48 |         # query gpt
49 |         input_prompt = prompt + ltemplate.format(sent, "")
50 |         responses = du.augment_slices.openai_complete(
51 |             prompt=input_prompt,
52 |             n=10,
53 |             engine=args.gpt_engine,
54 |             temp=1.0,
55 |             top_p=1.0,
56 |         )
57 |         responses = [r.text.strip() for r in responses]
58 |         # print(input_prompt)
59 |         # print(responses)
60 |         # breakpoint()
61 |         # NOTE: not handling ties. if there is a tie, it just means that
62 |         # the sentence is not a good enough one, and we shouldn't include it.
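        # majority vote over the 10 sampled completions: responses.count
        # tallies how often each candidate intent was generated, and Python's
        # max keeps the first maximal element, so ties resolve in triplet order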
63 |         response = max(args.intent_triplet, key=responses.count)
64 |         if response == old_intent:
65 |             retained_sentences.append((old_intent, new_intent, sent))
66 | 
67 |     # num retained, num input
68 |     n_r, n_i = len(retained_sentences), len(sentences)
69 |     # percentage retained
70 |     ret_per = f"{(n_r/n_i)*100:.2f}%"
71 |     print(f"{args.gpt_engine.upper()} retained {n_r}/{n_i} sentences ({ret_per}).")
72 |     return retained_sentences
73 | 
74 | 
75 | def run_gpt_eval(eval_sentences, seed_sentences, args):
76 |     """
77 |     `eval_sentences` is a list of tuples (gt label in the dataset, text)
78 |     `seed_sentences` is a list of tuples (seed intent, text)
79 |     """
80 |     # construct a prompt
81 |     ltemplate = "sentence: {} ; category:{}"
82 |     prompt = f"Each example in the following list contains a sentence that belongs to a category. A category is one of the following: {', '.join(args.intent_triplet)}"
83 |     prompt += "\n\n"
84 |     prompt += "\n".join([ltemplate.format(s, " " + i) for (i, s) in seed_sentences])
85 |     prompt += "\n"
86 | 
87 |     preds, labels = [], []
88 |     class_wise_pred_labels = defaultdict(list)
89 |     print("Running evaluation...")
90 |     for (intent, sent) in tqdm(eval_sentences):
91 |         # query gpt
92 |         input_prompt = prompt + ltemplate.format(sent, "")
93 |         responses = du.augment_slices.openai_complete(
94 |             prompt=input_prompt, n=10, engine=args.gpt_engine, temp=1.0, top_p=1.0
95 |         )
96 |         responses = [r.text.strip() for r in responses]
97 |         # NOTE: not handling ties here.
98 |         response = max(args.intent_triplet, key=responses.count)
99 |         # for overall preds and labels
100 |         preds.append(response)
101 |         labels.append(intent)
102 | 
103 |         # for class-wise performance
104 |         class_wise_pred_labels[intent].append(response)
105 | 
106 |     # Evaluation
107 |     print(f"{args.gpt_engine.upper()} performance on {len(eval_sentences)} examples:")
108 |     print(f"Overall accuracy = {compute_acc(preds, labels)}")
109 |     print(f"Class-wise accuracies:")
110 |     for intent, preds in class_wise_pred_labels.items():
111 |         _acc = preds.count(intent) / len(preds)
112 |         print(f"Acc. 
for intent {intent}: {_acc*100:.2f}") 113 | 114 | 115 | def compute_fidelity(sentence_tuples): 116 | """ 117 | `sentence_tuples` is a list of tuples (old intent, new intent, sentence) 118 | and old intent and new intent are intent names (not ids) 119 | """ 120 | overall_3way_fidelity = [] 121 | class_wise_fidelity = defaultdict(list) 122 | for (old_intent, new_intent, _) in sentence_tuples: 123 | class_wise_fidelity[old_intent].append(new_intent) 124 | overall_3way_fidelity.append(1 if new_intent == old_intent else 0) 125 | 126 | print(f"Overall 3-way fidelity = {np.mean(overall_3way_fidelity)*100:.2f}") 127 | for intent, preds in class_wise_fidelity.items(): 128 | _fid = preds.count(intent) / len(preds) 129 | print(f"Fidelity for {intent}: {_fid*100:.2f}") 130 | 131 | 132 | def oracle_eval(eval_sentences, id2name, name2id, args, size): 133 | """ 134 | This will be used to evaluate: 135 | - small oracle (a.k.a 10-shot baseline) 136 | - bigger oracle (a.k.a the oracle) 137 | """ 138 | 139 | intent_ids_of_interest = [name2id[i] for i in args.intent_triplet] 140 | intent_ids_of_interest.sort() 141 | 142 | # segregate sentences and labels 143 | labels, sentences = [], [] 144 | for l, s in eval_sentences: 145 | labels.append(l) 146 | sentences.append(s) 147 | 148 | if args.dataset_name not in ["hwu64", "banking77"]: 149 | raise NotImplementedError(f"Dataset {args.dataset_name} not supported") 150 | 151 | print("Running N-shot baseline evaluation on reduced val+test set...") 152 | model_dir = f"/mnt/colab_public/results/few_shot_nlp/model/{args.dataset_name}/" 153 | model_dir += "10_shot_baseline/" if size == "small" else "oracle_checkpoint/" 154 | 155 | if args.dataset_name == "hwu64": 156 | num_labels = 64 157 | else: 158 | num_labels = 77 159 | 160 | import torch 161 | from transformers import AutoTokenizer, AutoModelForSequenceClassification 162 | 163 | tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased", use_fast=True) 164 | model = AutoModelForSequenceClassification.from_pretrained( 165 | model_dir, num_labels=num_labels 166 | ) 167 | device = "cuda" if torch.cuda.is_available() else "cpu" 168 | model.to(device) 169 | model.eval() 170 | 171 | with torch.no_grad(): 172 | encodings = tokenizer( 173 | sentences, 174 | max_length=50, 175 | padding=True, 176 | truncation=True, 177 | return_tensors="pt", 178 | ) 179 | input_ids, attention_mask = encodings["input_ids"], encodings["attention_mask"] 180 | logits = model(input_ids.to(device), attention_mask.to(device)).logits 181 | # only consider the 3 intents of interest for prediction 182 | _temp_iids = torch.tensor(intent_ids_of_interest).to(device) 183 | logits = logits.index_select(index=_temp_iids, dim=1) 184 | preds = torch.argmax(logits, dim=1).cpu().tolist() 185 | # remap pred ids to dataset intent ids 186 | preds = [id2name[str(intent_ids_of_interest[p])] for p in preds] 187 | 188 | class_wise_preds = defaultdict(list) 189 | for p, l in zip(preds, labels): 190 | class_wise_preds[l].append(p) 191 | 192 | # Evaluation 193 | print(f"N-shot baseline performance on {len(eval_sentences)} examples:") 194 | print(f"Overall accuracy = {compute_acc(preds, labels)}") 195 | print(f"Class-wise accuracies:") 196 | for intent, preds in class_wise_preds.items(): 197 | _acc = preds.count(intent) / len(preds) 198 | print(f"Acc. 
for intent {intent}: {_acc*100:.2f}")
199 | 
200 | 
201 | def main():
202 |     """
203 |     Compares fidelity and 3-way classification performance using:
204 |     - the generated sentences,
205 |     - the GPT-filtered generated sentences,
206 |     - and the val+test sentences for the dataset.
207 |     """
208 |     args = parse_args()
209 |     print(args)
210 |     ds_config = mdu.get_ds_config(args.dataset_name)
211 |     name2id = mdu.read_json(f"data/{args.dataset_name}/name2id.json")
212 |     id2name = mdu.read_json(f"data/{args.dataset_name}/id2name.json")
213 | 
214 |     ############# START: Fetch all required sets of sentences ################
215 | 
216 |     # fetch seed sentences for all the intents in the triplet using the
217 |     # dataset prepared for the partial fewshot experiment
218 |     ds = mdu.read_pickle(f"data/{args.dataset_name}/full/data_full_suite.pkl")
219 |     all_seed_sentences = []
220 |     for intent in args.intent_triplet:
221 |         seed_sents = get_seed_tuples(ds, intent, ds_config, name2id)
222 |         all_seed_sentences.extend(seed_sents)
223 |     # shuffle shuffle
224 |     np.random.shuffle(all_seed_sentences)
225 | 
226 |     ds = mdu.read_pickle(f"data/{args.dataset_name}/full/al_dataset.pkl")
227 |     # Fetch generated sentences as a list of tuples
228 |     # each tuple --> (old intent, new intent, sentence)
229 |     all_gen_sentences = []
230 |     relabeled_pool = ds["generated"][f"{args.gpt_engine}_1.0"]
231 |     for idx, intent in enumerate(relabeled_pool["old_intent"]):
232 |         if id2name[str(intent)] not in args.intent_triplet:
233 |             continue
234 |         old_intent = id2name[str(intent)]
235 |         new_intent = id2name[str(relabeled_pool["intent"][idx])]
236 |         sentence = relabeled_pool["text"][idx]
237 |         all_gen_sentences.append((old_intent, new_intent, sentence))
238 | 
239 |     # fetch val+test sentences (all accuracies are computed on these sentences)
240 |     # NOTE: since it's val and test, old_intent is the ground truth
241 |     eval_sentences = []
242 |     for part in ["val", "test"]:
243 |         sent_pool = ds["generated"][part]
244 |         for (text, label) in zip(sent_pool["text"], sent_pool["old_intent"]):
245 |             if id2name[str(label)] not in args.intent_triplet:
246 |                 continue
247 |             eval_sentences.append((id2name[str(label)], text))
248 | 
249 |     # filtered examples using GPT as the classifier
250 |     retained_generations = filter_via_gpt(all_gen_sentences, all_seed_sentences, args)
251 | 
252 |     ############## END: Fetched all required sets of sentences #################
253 | 
254 |     ############# COMPUTE THE DIFFERENT METRICS ##################
255 |     print(f"\nFIDELITY for ALL {args.gpt_engine.upper()} generations")
256 |     compute_fidelity(all_gen_sentences)
257 |     print(f"\nFIDELITY after FILTERING {args.gpt_engine.upper()} rejections")
258 |     compute_fidelity(retained_generations)
259 | 
260 |     print(f"\n{args.gpt_engine.upper()}'s 3-way classification performance")
261 |     run_gpt_eval(eval_sentences, all_seed_sentences, args)
262 |     print(f"\n10-SHOT-BASELINE's 3-way classification performance")
263 |     oracle_eval(eval_sentences, id2name, name2id, args, "small")
264 |     print(f"\nORACLE's 3-way classification performance")
265 |     oracle_eval(eval_sentences, id2name, name2id, args, "big")
266 | 
267 | 
268 | def parse_args():
269 |     parser = argparse.ArgumentParser()
270 |     # parser.add_argument("-dn", "--dataset-name", default="hwu64")
271 |     parser.add_argument("-dn", "--dataset-name", default="banking77")
272 |     parser.add_argument("-e", "--gpt_engine", default="davinci")
273 |     parser.add_argument(
274 |         "-it",
275 |         "--intent-triplet",
276 |         nargs="+",
277 |         # default=["music_likeness", "play_music", "music_settings"],
278 |         default=["topping_up_by_card", "top_up_failed", "pending_top_up"],
279 
| ) 280 | return parser.parse_args() 281 | 282 | 283 | if __name__ == "__main__": 284 | main() 285 | -------------------------------------------------------------------------------- /scripts/make_spreadsheet.py: -------------------------------------------------------------------------------- 1 | import os, json, pickle 2 | from openpyxl import Workbook 3 | from openpyxl.styles import Font 4 | 5 | DATASET_NAME = "banking77" 6 | 7 | 8 | def load_data(): 9 | data_dir = f"./data/{DATASET_NAME}" 10 | opj = os.path.join 11 | ds_full_suite = pickle.load(open(opj(data_dir, "full/data_full_suite.pkl"), "rb")) 12 | generated_samples = pickle.load(open(opj(data_dir, "full/al_dataset.pkl"), "rb"))[ 13 | "generated" 14 | ] 15 | id2name = json.load(open(opj(data_dir, "id2name.json"))) 16 | return ds_full_suite, generated_samples, id2name 17 | 18 | 19 | def massage_data(ds_full_suite, generated_samples, id2name): 20 | workbooks = {} 21 | for engine in generated_samples: 22 | workbooks[engine] = gather_workbook_data( 23 | ds_full_suite, generated_samples[engine], id2name 24 | ) 25 | return workbooks 26 | 27 | 28 | def gather_workbook_data(ds_full_suite, generated_samples, id2name): 29 | workbook_data = {} 30 | for domain in ds_full_suite: 31 | # generate prompt column 32 | prompt_data = ds_full_suite[domain]["F"]["train"] 33 | for text, intent_id in zip(prompt_data["text"], prompt_data["intent"]): 34 | intent_name = id2name[str(intent_id)].replace("?", "") 35 | sheet_name = f"{domain}<>{intent_name}" 36 | if sheet_name not in workbook_data: 37 | workbook_data[sheet_name] = { 38 | "prompt": [], 39 | "generated": [], 40 | "oracle_prediction": [], 41 | } 42 | workbook_data[sheet_name]["prompt"].append(text) 43 | 44 | # add generated data, and oracle prediction data 45 | for text, oracle_intent_id, org_intent_id in zip( 46 | generated_samples["text"], 47 | generated_samples["intent"], 48 | generated_samples["old_intent"], 49 | ): 50 | oracle_pred = id2name[str(oracle_intent_id)].replace("?", "") 51 | org_intent_name = id2name[str(org_intent_id)].replace("?", "") 52 | sheet_name = f"{domain}<>{org_intent_name}" 53 | if sheet_name not in workbook_data: 54 | # print(f"sheet {sheet_name} doesn't exist") 55 | continue 56 | workbook_data[sheet_name]["generated"].append(text) 57 | workbook_data[sheet_name]["oracle_prediction"].append(oracle_pred) 58 | return workbook_data 59 | 60 | 61 | def create_excel_sheet(name, data): 62 | wb = Workbook() 63 | wb.remove(wb.active) # remove the empty "Sheet" 64 | # create different sheets 65 | for sheet_name in data: 66 | org_intent = sheet_name.split("<>", 1)[1] 67 | ws = wb.create_sheet(sheet_name) 68 | prompts = data[sheet_name]["prompt"] 69 | generated = data[sheet_name]["generated"] 70 | oracle_predictions = data[sheet_name]["oracle_prediction"] 71 | 72 | ############# compute some quantities for formatting ############## 73 | # max width of column A 74 | max_sent_length = max(map(len, prompts + generated)) 75 | # max width of column B 76 | max_pred_length = max(map(len, oracle_predictions)) 77 | total_faithful_samples = oracle_predictions.count(org_intent) 78 | ############# compute end ################# 79 | 80 | # add the first column 81 | ws.append(["Sentences", "Oracle Predictions"]) 82 | # add the sentences column 83 | for irow in range(len(prompts + generated)): 84 | if irow < len(prompts): 85 | ws.append([prompts[irow]]) 86 | else: 87 | new_irow = irow - len(prompts) 88 | ws.append([generated[new_irow], oracle_predictions[new_irow]]) 89 | 90 | # some analysis 91 | 
ws["C1"] = "Total faithful samples"
92 |         ws["C2"] = f"{total_faithful_samples}/{len(generated)}"
93 |         ws["C3"] = f"{total_faithful_samples/len(generated)*100:.2f}%"
94 | 
95 |         # adjust column widths
96 |         ws.column_dimensions["A"].width = max_sent_length
97 |         ws.column_dimensions["B"].width = max_pred_length
98 |         ws.column_dimensions["C"].width = len("Total faithful samples")
99 | 
100 |         # increase font size
101 |         n_rows = len(prompts + generated)
102 |         for col, n_rows in [("A", n_rows), ("B", n_rows), ("C", 3)]:
103 |             for i in range(1, n_rows + 2):
104 |                 ws[f"{col}{i}"].font = Font(size=14)
105 | 
106 |         # bold the first row
107 |         ws["A1"].font = Font(bold=True, size=14)
108 |         ws["B1"].font = Font(bold=True, size=14)
109 |         ws["C1"].font = Font(bold=True, size=14)
110 | 
111 |     # (the empty default sheet was already removed above)
112 |     # sort sheets based on fidelity (ws['C3'] is the fidelity)
113 |     wb._sheets.sort(key=lambda ws: float(ws["C3"].value[:-1]))
114 |     wb.active = 0
115 | 
116 |     save_folder = f"spreadsheets/{DATASET_NAME}"
117 |     if not os.path.exists(save_folder):
118 |         os.mkdir(save_folder)
119 |     wb.save(os.path.join(save_folder, f"{name}.xlsx"))
120 | 
121 | 
122 | if __name__ == "__main__":
123 |     workbooks = massage_data(*load_data())
124 |     for engine_temp, data in workbooks.items():
125 |         if "_" in engine_temp and engine_temp.split("_")[1] != "1.0":
126 |             continue
127 |         create_excel_sheet(engine_temp, data)
128 | 
-------------------------------------------------------------------------------- /scripts/openai_sandbox.py: --------------------------------------------------------------------------------
1 | from utils.data_utils.augment_slices import openai_complete
2 | import argparse
3 | 
4 | 
5 | def main(args):
6 |     if args.examples is None:
7 |         raise ValueError("No seed examples provided.")
8 |     if args.intent_name is None:
9 |         raise ValueError("Please provide the name of a seed intent")
10 |     if args.gpt_engine is None:
11 |         print("No engine provided. 
Using ada...") 12 | ENGINE = "ada" 13 | else: 14 | ENGINE = args.gpt_engine 15 | 16 | intent = args.intent_name 17 | lines = args.examples 18 | k = len(args.examples) 19 | 20 | print("----Default method----") 21 | prompt = f"The following sentences belong to the same category {intent}:\n" 22 | prompt += "\n".join([f"Example {i+1}: {l}" for i, l in enumerate(lines)]) 23 | prompt += "\n" 24 | prompt += f"Example {k+1}:" 25 | print(prompt) 26 | from pprint import pprint 27 | 28 | pprint( 29 | [ 30 | r.text.strip() 31 | for r in openai_complete( 32 | prompt=prompt, n=20, engine=ENGINE, temp=1.0, top_p=1.0 33 | ) 34 | ] 35 | ) 36 | 37 | 38 | if __name__ == "__main__": 39 | parser = argparse.ArgumentParser() 40 | 41 | parser.add_argument( 42 | "-e", 43 | "--examples", 44 | nargs="+", 45 | default=None, 46 | help="Seed examples for prompting GPT", 47 | ) 48 | parser.add_argument( 49 | "-i", "--intent_name", default=None, type=str, help="Name of the seed intent" 50 | ) 51 | 52 | parser.add_argument( 53 | "-g", 54 | "--gpt_engine", 55 | default=None, 56 | help="GPT engine to use for augmentation (ada/babbage/curie/davinci)", 57 | ) 58 | args, unknown = parser.parse_known_args() 59 | 60 | main(args) 61 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_utils.sample_few_shot import sample_few_shot 2 | from .data_utils.augment_slices import augment_slices 3 | from .data_utils import main as main_data_utils -------------------------------------------------------------------------------- /utils/data_utils/augment_slices.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script uses openai to augment the dataset with 3 | samples for different few-shot domains 4 | """ 5 | import os, gc, torch, openai, pickle, json 6 | import numpy as np 7 | from . 
import eda_utils
8 | from collections import Counter
9 | 
10 | pjoin = os.path.join
11 | 
12 | 
13 | class GPTJChoice:
14 |     def __init__(self, text):
15 |         self.text = text
16 | 
17 | 
18 | def load_dataset_slices(data_root, data_name):
19 |     with open(pjoin(data_root, data_name, "full", "data_full_suite.pkl"), "rb") as f:
20 |         return pickle.load(f)
21 | 
22 | 
23 | def openai_complete(
24 |     prompt,
25 |     n,
26 |     engine,
27 |     temp,
28 |     top_p,
29 |     max_tokens=32,
30 |     stop="\n",
31 |     frequency_penalty=0,
32 |     logprobs=None,
33 | ):
34 |     completion = openai.Completion.create(
35 |         engine=engine,
36 |         prompt=prompt,
37 |         max_tokens=max_tokens,
38 |         n=n,
39 |         stop=stop,
40 |         temperature=temp,
41 |         top_p=1 if not top_p else top_p,
42 |         frequency_penalty=frequency_penalty,
43 |         logprobs=logprobs,
44 |     )
45 |     return completion.choices
46 |     # return [c.text.strip() for c in completion.choices]
47 | 
48 | 
49 | def gptj_complete(prompt, n, temp, model, tokenizer, top_k, top_p):
50 |     """
51 |     Parameters:
52 |     ===========
53 |     prompt: Str
54 |         Text to be fed as prompt
55 |     n: Int
56 |         Number of sentences to be generated
57 |     temp: Float
58 |         Sampling temperature for GPTJ
59 |     model: GPTJ model instance
60 |         GPTJ model loaded for inference
61 |     tokenizer: GPTJ tokenizer instance
62 |         GPTJ tokenizer loaded (from Huggingface currently)
63 |     top_k: Anyof False, Int
64 |         top k tokens to consider when sampling
65 |     top_p: Float
66 |         p value for top-p sampling (nucleus sampling)
67 |     """
68 |     # k is the line where the predicted/generated sample resides
69 |     # compensate for the last (incomplete) line "Example {num_seed+1}:"
70 |     k = len(prompt.splitlines()) - 1
71 |     input_ids = tokenizer(prompt, return_tensors="pt").input_ids
72 |     input_ids = input_ids.to("cuda")
73 |     stop_token = tokenizer.encode("\n")[0]
74 |     sentences = []
75 |     while len(sentences) != n:
76 |         # generate multiple sentences at a time
77 |         # NOTE: .generate already sets torch.no_grad()
78 |         gen_tokens = model.generate(
79 |             input_ids,
80 |             do_sample=True,
81 |             max_length=2 * input_ids.shape[1],
82 |             temperature=temp,
83 |             eos_token_id=stop_token,
84 |             # 30 sequences per call is roughly the limit on a 32G GPU,
85 |             # though the actual ceiling varies with prompt length
86 |             num_return_sequences=min(n, 30),
87 |             # to suppress open-end generation warning
88 |             pad_token_id=stop_token,
89 |             top_k=0 if not top_k else top_k,
90 |             top_p=1 if not top_p else top_p,
91 |         )
92 |         generations = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
93 |         del gen_tokens
94 |         # remove the first k lines as they belong to the prompt
95 |         for i in range(min(n, 30)):
96 |             # intentionally using that space after :
97 |             # the model should be predicting that
98 |             s = generations[i].splitlines()[k:][0][len(f"Example {k-1}: ") :].strip()
99 |             # don't add if empty or if already generated n sentences
100 |             if s and len(sentences) < n:
101 |                 sentences.append(s)
102 |             del s
103 |         del generations
104 |     del input_ids
105 |     del model
106 |     gc.collect()
107 |     torch.cuda.empty_cache()
108 |     return [GPTJChoice(s) for s in sentences]
109 | 
110 | 
111 | def upsample_domain(prompt, n):
112 |     lines = prompt.strip().splitlines()
113 |     upsampled = lines * (n // len(lines))
114 |     upsampled.extend(lines[: (n % len(lines))])
115 |     return upsampled
116 | 
117 | 
118 | def eda_domain(prompt, n):
119 |     lines = prompt.strip().splitlines()
120 |     k = len(lines)
121 |     augmented = []
122 |     for line in lines:
123 |         if not line:
124 |             continue
125 |         # augment for a line
126 |         # NOTE: num_aug in the EDA paper is #new lines per training sample
127 |         # using alpha = 0.05 as per the recommendation in the paper
128 |         generated = eda_utils.eda(
129 |             sentence=line,
130 |             alpha_sr=0.05,
131 |             alpha_ri=0.05,
132 |             alpha_rs=0.05,
133 |             p_rd=0.05,
134 |             num_aug=n // k,
135 |         )
136 |         augmented.extend(generated)
137 |     return augmented
138 | 
139 | 
140 | def regenerate(input_prompt, n_empty, engine, temp, top_p):
141 |     new_lines = []
142 |     while n_empty > 0:
143 |         print(f"Saw {n_empty} empty line(s). GPT3ing again...")
144 |         curr_lines = openai_complete(
145 |             prompt=input_prompt,
146 |             n=n_empty,
147 |             engine=engine,
148 |             temp=temp,
149 |             top_p=top_p,
150 |         )
151 |         curr_lines = [r.text.strip() for r in curr_lines]
152 |         n_empty = curr_lines.count("")
153 |         new_lines.extend([t for t in curr_lines if t])
154 |     # return outside the loop so that a call with n_empty == 0 yields [] (not None)
155 |     return new_lines
156 | 
157 | 
158 | def augment_domain(
159 |     dataset_slices,
160 |     val_domain,
161 |     data_save_path,
162 |     id2name,
163 |     num_ex=10,
164 |     n_max=128,
165 |     engine=None,
166 |     temp=None,
167 |     model=None,
168 |     tokenizer=None,
169 |     top_k=False,
170 |     top_p=False,
171 |     mode="upsample",
172 |     mt_dict=None,
173 | ):
174 |     """
175 |     Augments a given domain in dataset_slices AND updates the pickle file
176 |     """
177 |     if len(dataset_slices[val_domain]["M"]["train"]["intent"]) == 0:
178 |         # no many-domain available for this dataset
179 |         data_path = os.path.join(os.path.dirname(data_save_path), "dataset.pkl")
180 |         with open(data_path, "rb") as f:
181 |             dataset = pickle.load(f)
182 |         num_synthetic = int(
183 |             np.median(list(Counter(dataset["train"]["intent"]).values()))
184 |         )
185 |     else:
186 |         # when many-domain is available
187 |         counts = Counter(dataset_slices[val_domain]["M"]["train"]["intent"])
188 |         num_synthetic = int(np.median(list(counts.values())))
189 | 
190 |     f_train_lines = dataset_slices[val_domain]["F"]["train"]["text"]
191 |     f_train_labels = dataset_slices[val_domain]["F"]["train"]["intent"]
192 | 
193 |     f_gen_lines, f_gen_labels = [], []
194 |     for idx in range(0, len(f_train_lines), num_ex):
195 |         prompt_lines = f_train_lines[idx : idx + num_ex]
196 | 
197 |         # simple prompt format (prepend 'Example i: ')
198 |         input_prompt = "\n".join(
199 |             [f"Example {i+1}: {t}" for i, t in enumerate(prompt_lines)]
200 |         )
201 |         input_prompt += f"\nExample {num_ex+1}:"
202 | 
203 |         # prompting with label addition as well
204 |         seed_intent = id2name[f"{f_train_labels[idx : idx + num_ex][0]}"]
205 |         print(f"Seed intent: {seed_intent}")
206 |         input_prompt = (
207 |             f"The following sentences belong to the same category '{seed_intent}':\n"
208 |             + input_prompt
209 |         )
210 | 
211 |         if mode == "upsample":
212 |             print("Upsampling...")
213 |             generated_lines = upsample_domain(input_prompt, num_synthetic)
214 |         elif mode == "eda":
215 |             print("EDAing...")
216 |             generated_lines = eda_domain(input_prompt, num_synthetic)
217 |         elif mode == "gptj":
218 |             engine = "gptj"
219 |             print("GPTJing...")
220 |             generated_lines = gptj_complete(
221 |                 prompt=input_prompt,
222 |                 n=num_synthetic,
223 |                 temp=temp,
224 |                 model=model,
225 |                 tokenizer=tokenizer,
226 |                 top_k=top_k,
227 |                 top_p=top_p,
228 |             )
229 |             generated_lines = [r.text.strip() for r in generated_lines]
230 |         else:
231 |             print("GPT3ing...")
232 |             if num_synthetic <= n_max:
233 |                 generated_lines = openai_complete(
234 |                     prompt=input_prompt,
235 |                     n=num_synthetic,
236 |                     engine=engine,
237 |                     temp=temp,
238 |                     top_p=top_p,
239 |                 )
240 |                 generated_lines = [r.text.strip() for r in generated_lines]
241 |             else:
242 |                 generated_lines = []
243 |                 for _ in range(num_synthetic // n_max):
244 |                     _c = openai_complete(
211 | if mode == "upsample":
212 | print("Upsampling...")
213 | generated_lines = upsample_domain(input_prompt, num_synthetic)
214 | elif mode == "eda":
215 | print("EDAing...")
216 | generated_lines = eda_domain(input_prompt, num_synthetic)
217 | elif mode == "gptj":
218 | engine = "gptj"
219 | print("GPTJing...")
220 | generated_lines = gptj_complete(
221 | prompt=input_prompt,
222 | n=num_synthetic,
223 | temp=temp,
224 | model=model,
225 | tokenizer=tokenizer,
226 | top_k=top_k,
227 | top_p=top_p,
228 | )
229 | generated_lines = [r.text.strip() for r in generated_lines]
230 | else:
231 | print("GPT3ing...")
232 | if num_synthetic <= n_max:
233 | generated_lines = openai_complete(
234 | prompt=input_prompt,
235 | n=num_synthetic,
236 | engine=engine,
237 | temp=temp,
238 | top_p=top_p,
239 | )
240 | generated_lines = [r.text.strip() for r in generated_lines]
241 | else:
242 | generated_lines = []
243 | for _ in range(num_synthetic // n_max):
244 | _c = openai_complete(
245 | prompt=input_prompt,
246 | n=n_max,
247 | engine=engine,
248 | temp=temp,
249 | top_p=top_p,
250 | )
251 | generated_lines.extend([r.text.strip() for r in _c])
252 | # rest of the lines
253 | _c = openai_complete(
254 | prompt=input_prompt,
255 | n=num_synthetic % n_max,
256 | engine=engine,
257 | temp=temp,
258 | top_p=top_p,
259 | )
260 | generated_lines.extend([r.text.strip() for r in _c])
261 |
262 | # sometimes there are empty strings generated by GPT3, try again
263 | n_empty = generated_lines.count("")
264 | if n_empty > 0:
265 | generated_lines = [t for t in generated_lines if t]
266 | if n_empty <= n_max:
267 | generated_lines.extend(
268 | regenerate(
269 | input_prompt,
270 | n_empty,
271 | engine,
272 | temp,
273 | top_p,
274 | )
275 | )
276 | else:
277 | for _ in range(n_empty // n_max):
278 | generated_lines.extend(
279 | regenerate(
280 | input_prompt,
281 | n_max,
282 | engine,
283 | temp,
284 | top_p,
285 | )
286 | )
287 | # rest of the lines
288 | generated_lines.extend(
289 | regenerate(
290 | input_prompt,
291 | n_empty % n_max,
292 | engine,
293 | temp,
294 | top_p,
295 | )
296 | )
297 |
298 | assert len(generated_lines) == num_synthetic
299 |
300 | f_gen_lines.extend(generated_lines)
301 | # using len(generated_lines) to make sure #lines == #labels
302 | # as, for imbalanced datasets, there can be slightly more sentences
303 | # generated by EDA as it augments per sentence.
304 | f_gen_labels.extend([f_train_labels[idx]] * len(generated_lines))
305 |
306 | attr_name = mode if engine is None else f"{engine}_{temp}"
307 |
308 | dataset_slices[val_domain]["F"][attr_name] = {
309 | "text": f_gen_lines,
310 | "intent": f_gen_labels,
311 | }
312 | write_pickle(data_save_path, dataset_slices)
313 |
314 |
315 | def write_pickle(path, data):
316 | with open(path, "wb") as f:
317 | pickle.dump(data, f)
318 |
319 |
320 | def upsample_loop(dataset_slices, domains, data_save_path, id2name):
321 | # Upsample loop
322 | for val_domain in domains:
323 | if "upsample" in dataset_slices[val_domain]["F"].keys():
324 | print(f"upsample for {val_domain} already exists")
325 | continue
326 | print(f"Augmenting for domain: {val_domain}")
327 | augment_domain(
328 | dataset_slices=dataset_slices,
329 | val_domain=val_domain,
330 | data_save_path=data_save_path,
331 | id2name=id2name,
332 | )
333 |
334 |
335 | def eda_loop(dataset_slices, domains, data_save_path, id2name):
336 | """
337 | Easy Data Augmentation (EDA) baseline by Wei and Zou (EMNLP-IJCNLP 2019)
338 | """
339 | # EDA loop:
340 | for val_domain in domains:
341 | if "eda" in dataset_slices[val_domain]["F"].keys():
342 | print(f"eda for {val_domain} already exists")
343 | continue
344 | print(f"Augmenting for domain: {val_domain}")
345 | augment_domain(
346 | dataset_slices=dataset_slices,
347 | val_domain=val_domain,
348 | data_save_path=data_save_path,
349 | id2name=id2name,
350 | mode="eda",
351 | )
352 |
353 |
354 | def load_gptj():
355 | """returns GPT-J and its tokenizer"""
356 | from transformers import GPTJForCausalLM, AutoTokenizer
357 |
358 | print("Loading GPT-J...")
359 | model = GPTJForCausalLM.from_pretrained(
360 | "EleutherAI/gpt-j-6B",
361 | revision="float16",
362 | torch_dtype=torch.float16,
363 | low_cpu_mem_usage=True,
364 | )
365 | model = model.to("cuda")
366 | model.eval()
367 | print("Loaded GPT-J.")
368 |
369 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
370 | return model, tokenizer
371 |
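# A minimal sampling sketch with the loaded model (illustrative only; the
# repo's actual decoding lives in gptj_complete):
#   ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
#   out = model.generate(ids, do_sample=True, temperature=temp, max_new_tokens=40)
#   text = tokenizer.decode(out[0, ids.shape[1]:], skip_special_tokens=True)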
372 |
373 | def gptj_loop(
374 | dataset_slices,
375 | domains,
376 | ds_config,
377 | data_save_path,
378 | id2name,
379 | top_k,
380 | top_p,
381 | ):
382 | # GPT-J loop
383 | model, tokenizer = None, None
384 | # for temp in [1.0]:
385 | # round the temperatures to avoid floating-point artifacts like 1.2000000000000002
386 | for temp in [round(a, 1) for a in np.linspace(0.5, 2, int((2.1 - 0.5) / 0.1))]:
387 | print(f"Engine: GPT-J | Temp: {temp}")
388 | for val_domain in domains:
389 | # comment out this check to *update* existing entries
390 | if f"gptj_{temp}" in dataset_slices[val_domain]["F"].keys():
391 | print(f"gptj_{temp} for {val_domain} already exists")
392 | continue
393 |
394 | if model is None and tokenizer is None:
395 | model, tokenizer = load_gptj()
396 |
397 | print(f"Augmenting for domain: {val_domain}")
398 | augment_domain(
399 | dataset_slices=dataset_slices,
400 | val_domain=val_domain,
401 | data_save_path=data_save_path,
402 | id2name=id2name,
403 | num_ex=ds_config.num_examples,
404 | engine="gptj",
405 | temp=temp,
406 | model=model,
407 | tokenizer=tokenizer,
408 | top_k=top_k,
409 | top_p=top_p,
410 | mode="gptj",
411 | )
412 |
413 |
414 | def gpt3_loop(
415 | dataset_slices,
416 | domains,
417 | ds_config,
418 | data_save_path,
419 | id2name,
420 | top_p,
421 | ):
422 | # GPT3 loop
423 | for engine in ["davinci", "curie", "babbage", "ada"]:
424 | for temp in [1.0]:
425 | # round the temperatures to avoid floating-point artifacts like 1.2000000000000002
426 | # for temp in [round(a, 1) for a in np.linspace(0.5, 2, int((2.1 - 0.5) / 0.1))]:
427 | print(f"Engine: {engine} | Temp: {temp}")
428 | for val_domain in domains:
429 | # comment out this check to *update* existing entries
430 | if f"{engine}_{temp}" in dataset_slices[val_domain]["F"].keys():
431 | print(f"{engine}_{temp} for {val_domain} already exists")
432 | continue
433 | print(f"Augmenting for domain: {val_domain}")
434 | augment_domain(
435 | dataset_slices=dataset_slices,
436 | val_domain=val_domain,
437 | data_save_path=data_save_path,
438 | id2name=id2name,
439 | num_ex=ds_config.num_examples,
440 | n_max=ds_config.gpt3_batch_size,
441 | engine=engine,
442 | temp=temp,
443 | top_p=top_p,
444 | mode="gpt3",
445 | )
446 | # davinci quickly reaches the tokens/min limit, so we must sleep
447 | if engine == "davinci":
448 | print("sleeping, for openai won't let me GPT3 no more...")
449 | import time
450 |
451 | time.sleep(60)
452 |
453 |
454 | def augment_slices(
455 | data_root,
456 | ds_config,
457 | modes=["upsample", "gptj", "gpt3", "eda"],
458 | top_k=False,
459 | top_p=False,
460 | ):
461 | dataset_slices = load_dataset_slices(data_root, ds_config.data_name)
462 | DOMAINS = ds_config.domain_to_intent.keys()
463 |
464 | data_save_path = pjoin(
465 | data_root,
466 | ds_config.data_name,
467 | "full",
468 | "data_full_suite.pkl",
469 | )
470 | id2name = json.load(open(pjoin(data_root, ds_config.data_name, "id2name.json")))
471 |
472 | for mode in modes:
473 | if mode == "upsample":
474 | upsample_loop(dataset_slices, DOMAINS, data_save_path, id2name)
475 | elif mode == "gptj":
476 | gptj_loop(
477 | dataset_slices,
478 | DOMAINS,
479 | ds_config,
480 | data_save_path,
481 | id2name,
482 | top_k=top_k,
483 | top_p=top_p,
484 | )
485 | elif mode == "gpt3":
486 | if top_k:
487 | print("NOTE: ignoring top_k for gpt3 as openai doesn't support it yet")
488 | gpt3_loop(
489 | dataset_slices,
490 | DOMAINS,
491 | ds_config,
492 | data_save_path,
493 | id2name,
494 | top_p=top_p,
495 | )
496 | elif mode == "eda":
497 | eda_loop(dataset_slices, DOMAINS, data_save_path, id2name)
498 | return dataset_slices
499 |
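# Example driver (hypothetical paths/values; prepare_dataset.py is expected to
# wire this up for real runs):
#   from utils.data_utils.main import get_ds_config
#   ds_config = get_ds_config("clinc_oos")
#   slices = augment_slices("./data", ds_config, modes=["upsample", "eda"])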
-------------------------------------------------------------------------------- /utils/data_utils/banking77_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | All banking77 specific utilities are implemented here. 3 | """ 4 | from datasets import load_dataset 5 | from sklearn.model_selection import train_test_split 6 | from haven import haven_utils as hu 7 | import os 8 | 9 | INTENTS = [ 10 | "activate_my_card", 11 | "age_limit", 12 | "apple_pay_or_google_pay", 13 | "atm_support", 14 | "automatic_top_up", 15 | "balance_not_updated_after_bank_transfer", 16 | "balance_not_updated_after_cheque_or_cash_deposit", 17 | "beneficiary_not_allowed", 18 | "cancel_transfer", 19 | "card_about_to_expire", 20 | "card_acceptance", 21 | "card_arrival", 22 | "card_delivery_estimate", 23 | "card_linking", 24 | "card_not_working", 25 | "card_payment_fee_charged", 26 | "card_payment_not_recognised", 27 | "card_payment_wrong_exchange_rate", 28 | "card_swallowed", 29 | "cash_withdrawal_charge", 30 | "cash_withdrawal_not_recognised", 31 | "change_pin", 32 | "compromised_card", 33 | "contactless_not_working", 34 | "country_support", 35 | "declined_card_payment", 36 | "declined_cash_withdrawal", 37 | "declined_transfer", 38 | "direct_debit_payment_not_recognised", 39 | "disposable_card_limits", 40 | "edit_personal_details", 41 | "exchange_charge", 42 | "exchange_rate", 43 | "exchange_via_app", 44 | "extra_charge_on_statement", 45 | "failed_transfer", 46 | "fiat_currency_support", 47 | "get_disposable_virtual_card", 48 | "get_physical_card", 49 | "getting_spare_card", 50 | "getting_virtual_card", 51 | "lost_or_stolen_card", 52 | "lost_or_stolen_phone", 53 | "order_physical_card", 54 | "passcode_forgotten", 55 | "pending_card_payment", 56 | "pending_cash_withdrawal", 57 | "pending_top_up", 58 | "pending_transfer", 59 | "pin_blocked", 60 | "receiving_money", 61 | "Refund_not_showing_up", 62 | "request_refund", 63 | "reverted_card_payment?", 64 | "supported_cards_and_currencies", 65 | "terminate_account", 66 | "top_up_by_bank_transfer_charge", 67 | "top_up_by_card_charge", 68 | "top_up_by_cash_or_cheque", 69 | "top_up_failed", 70 | "top_up_limits", 71 | "top_up_reverted", 72 | "topping_up_by_card", 73 | "transaction_charged_twice", 74 | "transfer_fee_charged", 75 | "transfer_into_account", 76 | "transfer_not_received_by_recipient", 77 | "transfer_timing", 78 | "unable_to_verify_identity", 79 | "verify_my_identity", 80 | "verify_source_of_funds", 81 | "verify_top_up", 82 | "virtual_card_not_working", 83 | "visa_or_mastercard", 84 | "why_verify_identity", 85 | "wrong_amount_of_cash_received", 86 | "wrong_exchange_rate_for_cash_withdrawal", 87 | ] 88 | 89 | 90 | class Banking77: 91 | def __init__(self, name): 92 | self.data_name = name 93 | self.full_path = f"./data/{name}/full/dataset.pkl" 94 | self.num_examples = 10 95 | 96 | # 1. save dataset.pkl 97 | if not os.path.exists(self.full_path): 98 | self.dataset = load_dataset(name).rename_column("label", "intent") 99 | # get the splits 100 | train, validation = train_test_split( 101 | self.dataset["train"], 102 | train_size=0.9, 103 | stratify=self.dataset["train"]["intent"], 104 | ) 105 | 106 | # get data dict 107 | data_dict = { 108 | "train": train, 109 | "val": validation, 110 | "test": self.dataset["test"].to_dict(), 111 | } 112 | hu.save_pkl(self.full_path, data_dict) 113 | else: 114 | data_dict = hu.load_pkl(self.full_path) 115 | 116 | # 2. 
Group by Intent
117 | name2id = {k: i for i, k in enumerate(INTENTS)}
118 | hu.save_json(f"./data/{name}/name2id.json", name2id)
119 | id2name = {i: k for i, k in enumerate(INTENTS)}
120 | hu.save_json(f"./data/{name}/id2name.json", id2name)
121 |
122 | # intents = list(name2id.keys())
123 |
124 | self.dataset_by_intent = {}
125 | for split in ["train", "val", "test"]:
126 | self.dataset_by_intent[split] = {}
127 | text_intent_dict = data_dict[split]
128 | text_list, intent_list = (
129 | text_intent_dict["text"],
130 | map(lambda x: id2name[x], text_intent_dict["intent"]),
131 | )
132 |
133 | # get texts from intent
134 | intent2texts = {}
135 | for t, i in zip(text_list, intent_list):
136 | if i not in intent2texts:
137 | intent2texts[i] = []
138 | intent2texts[i] += [t]
139 | self.dataset_by_intent[split] = intent2texts
140 |
141 | self.domain_to_intent = {name: INTENTS}
142 | self.gpt3_batch_size: int = 128
143 |
144 | def parse_and_load(self):
145 | return self.dataset_by_intent
146 |
-------------------------------------------------------------------------------- /utils/data_utils/clinc_utils.py: --------------------------------------------------------------------------------
1 | """
2 | All CLINC specific utilities must be implemented here.
3 | It defines a DS_CONFIG for the dataset.
4 | The main function that will be used by any other submodule in the repo is
5 | `parse_and_load_clinc`. In a nutshell, it prepares a `dataset.pkl` file for
6 | CLINC if not already prepared AND returns the dataset grouped by intent names.
7 | Refer to the function's documentation for more details.
8 | """
9 |
10 | import pickle
11 | import os
12 | import json
13 | import collections
14 | from typing import Dict
15 | from pydantic import BaseModel
16 |
17 |
18 | class DS_CONFIG(BaseModel):
19 | data_name: str = "clinc_oos"
20 | full_path: str = "./data/clinc_oos/full/dataset.pkl"
21 | num_examples: int = 10
22 | # See https://github.com/clinc/oos-eval/blob/master/supplementary.pdf.
23 | domain_to_intent: Dict = { 24 | "banking": [ 25 | "transfer", 26 | "transactions", 27 | "balance", 28 | "freeze_account", 29 | "pay_bill", 30 | "bill_balance", 31 | "bill_due", 32 | "interest_rate", 33 | "routing", 34 | "min_payment", 35 | "order_checks", 36 | "pin_change", 37 | "report_fraud", 38 | "account_blocked", 39 | "spending_history", 40 | ], 41 | "credit_card": [ 42 | "credit_score", 43 | "report_lost_card", 44 | "credit_limit", 45 | "rewards_balance", 46 | "new_card", 47 | "application_status", 48 | "card_declined", 49 | "international_fees", 50 | "apr", 51 | "redeem_rewards", 52 | "credit_limit_change", 53 | "damaged_card", 54 | "replacement_card_duration", 55 | "improve_credit_score", 56 | "expiration_date", 57 | ], 58 | "dining": [ 59 | "recipe", 60 | "restaurant_reviews", 61 | "calories", 62 | "nutrition_info", 63 | "restaurant_suggestion", 64 | "ingredients_list", 65 | "ingredient_substitution", 66 | "cook_time", 67 | "food_last", 68 | "meal_suggestion", 69 | "restaurant_reservation", 70 | "confirm_reservation", 71 | "how_busy", 72 | "cancel_reservation", 73 | "accept_reservations", 74 | ], 75 | "home": [ 76 | "shopping_list", 77 | "shopping_list_update", 78 | "next_song", 79 | "play_music", 80 | "update_playlist", 81 | "todo_list", 82 | "todo_list_update", 83 | "calendar", 84 | "calendar_update", 85 | "what_song", 86 | "order", 87 | "order_status", 88 | "reminder", 89 | "reminder_update", 90 | "smart_home", 91 | ], 92 | "auto": [ 93 | "traffic", 94 | "directions", 95 | "gas", 96 | "gas_type", 97 | "distance", 98 | "current_location", 99 | "mpg", 100 | "oil_change_when", 101 | "oil_change_how", 102 | "jump_start", 103 | "uber", 104 | "schedule_maintenance", 105 | "last_maintenance", 106 | "tire_pressure", 107 | "tire_change", 108 | ], 109 | "travel": [ 110 | "book_flight", 111 | "book_hotel", 112 | "car_rental", 113 | "travel_suggestion", 114 | "travel_alert", 115 | "travel_notification", 116 | "carry_on", 117 | "timezone", 118 | "vaccines", 119 | "translate", 120 | "flight_status", 121 | "international_visa", 122 | "lost_luggage", 123 | "plug_type", 124 | "exchange_rate", 125 | ], 126 | "utility": [ 127 | "time", 128 | "alarm", 129 | "share_location", 130 | "find_phone", 131 | "weather", 132 | "text", 133 | "spelling", 134 | "make_call", 135 | "timer", 136 | "date", 137 | "calculator", 138 | "measurement_conversion", 139 | "flip_coin", 140 | "roll_dice", 141 | "definition", 142 | ], 143 | "work": [ 144 | "direct_deposit", 145 | "pto_request", 146 | "taxes", 147 | "payday", 148 | "w2", 149 | "pto_balance", 150 | "pto_request_status", 151 | "next_holiday", 152 | "insurance", 153 | "insurance_change", 154 | "schedule_meeting", 155 | "pto_used", 156 | "meeting_schedule", 157 | "rollover_401k", 158 | "income", 159 | ], 160 | "small_talk": [ 161 | "greeting", 162 | "goodbye", 163 | "tell_joke", 164 | "where_are_you_from", 165 | "how_old_are_you", 166 | "what_is_your_name", 167 | "who_made_you", 168 | "thank_you", 169 | "what_can_i_ask_you", 170 | "what_are_your_hobbies", 171 | "do_you_have_pets", 172 | "are_you_a_bot", 173 | "meaning_of_life", 174 | "who_do_you_work_for", 175 | "fun_fact", 176 | ], 177 | "meta": [ 178 | "change_ai_name", 179 | "change_user_name", 180 | "cancel", 181 | "user_name", 182 | "reset_settings", 183 | "whisper_mode", 184 | "repeat", 185 | "no", 186 | "yes", 187 | "maybe", 188 | "change_language", 189 | "change_accent", 190 | "change_volume", 191 | "change_speed", 192 | "sync_device", 193 | ], 194 | "oos": ["oos"], 195 | } 196 | gpt3_batch_size: int 
= 128 197 | 198 | 199 | DOMAIN_TO_INTENT = DS_CONFIG().domain_to_intent 200 | CLINC_FULL_PATH = "./data/clinc_oos/full/data_full.json" 201 | 202 | 203 | def make_label_maps(): 204 | """ 205 | returns a mapping of intent name to intent id and vice versa 206 | """ 207 | intent_names = [] 208 | for intent_list in DOMAIN_TO_INTENT.values(): 209 | if intent_list == ["oos"]: 210 | continue 211 | intent_names.extend(intent_list) 212 | 213 | name2id, id2name = {}, {} 214 | for idx, name in enumerate(sorted(set(intent_names))): 215 | name2id[name] = idx 216 | id2name[idx] = name 217 | 218 | # explicitly assign oos an id of 150 219 | name2id["oos"] = 150 220 | id2name[150] = "oos" 221 | 222 | # storing in JSON for readability later on 223 | # NOTE: keys in id2name.json will be str and not int! 224 | with open("./data/clinc_oos/name2id.json", "w") as f: 225 | json.dump(name2id, f) 226 | with open("./data/clinc_oos/id2name.json", "w") as f: 227 | json.dump(id2name, f) 228 | return name2id, id2name 229 | 230 | 231 | def download_clinc_full(): 232 | download_dir = "./data/clinc_oos/full/" 233 | # create dir if doesn't exist 234 | if not os.path.exists(download_dir): 235 | os.makedirs(download_dir) 236 | clinc_data_url = ( 237 | "https://raw.githubusercontent.com/clinc/oos-eval/master/data/data_full.json" 238 | ) 239 | os.system(f"wget {clinc_data_url} -P {download_dir}") 240 | return json.load(open(os.path.join(download_dir, "data_full.json"))) 241 | 242 | 243 | def load_or_download_clinc(): 244 | try: 245 | full_data = json.load(open(CLINC_FULL_PATH)) 246 | except FileNotFoundError: 247 | full_data = download_clinc_full() 248 | return full_data 249 | 250 | 251 | def get_label_maps(): 252 | try: 253 | intentname2id = json.load(open("./data/clinc_oos/name2id.json")) 254 | intentid2name = json.load(open("./data/clinc_oos/id2name.json")) 255 | except FileNotFoundError: 256 | intentname2id, intentid2name = make_label_maps() 257 | return intentname2id, intentid2name 258 | 259 | 260 | def prepare_clinc(): 261 | full_data = load_or_download_clinc() 262 | intentname2id, intentid2name = get_label_maps() 263 | 264 | data_dict = { 265 | "train": {"text": [], "intent": []}, 266 | "val": {"text": [], "intent": []}, 267 | "test": {"text": [], "intent": []}, 268 | } 269 | 270 | for key, data in full_data.items(): 271 | for sample in data: 272 | line, label = sample 273 | label = intentname2id[label] 274 | # maybe there's a better way to prepare this... 
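# (one such alternative, sketched here for reference; functionally
#  equivalent to the if/elif chain below)
#   split = "train" if "train" in key else "val" if "val" in key else "test"
#   data_dict[split]["text"].append(line)
#   data_dict[split]["intent"].append(label)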
275 | if "train" in key: 276 | data_dict["train"]["text"].append(line) 277 | data_dict["train"]["intent"].append(label) 278 | elif "val" in key: 279 | data_dict["val"]["text"].append(line) 280 | data_dict["val"]["intent"].append(label) 281 | else: 282 | data_dict["test"]["text"].append(line) 283 | data_dict["test"]["intent"].append(label) 284 | 285 | print("Data details:") 286 | print( 287 | f'Train #lines: {len(data_dict["train"]["text"])} #labels: {len(data_dict["train"]["intent"])}' 288 | ) 289 | print( 290 | f'Val #lines: {len(data_dict["val"]["text"])} #labels: {len(data_dict["val"]["intent"])}' 291 | ) 292 | print( 293 | f'Test #lines: {len(data_dict["test"]["text"])} #labels: {len(data_dict["test"]["intent"])}' 294 | ) 295 | 296 | with open("./data/clinc_oos/full/dataset.pkl", "wb") as f: 297 | pickle.dump(data_dict, f) 298 | print("Base dataset.pkl prepared for CLINC!") 299 | 300 | 301 | def load_clinc(): 302 | data = collections.defaultdict(lambda: collections.defaultdict(list)) 303 | with open(CLINC_FULL_PATH, "r") as data_file: 304 | for split_name, split_data in json.load(data_file).items(): 305 | for query, intent in split_data: 306 | data[split_name][intent].append(query) 307 | return data 308 | 309 | 310 | def parse_and_load_clinc(): 311 | """ 312 | This functions has two primary roles: 313 | 1) create `dataset.pkl` for CLINC if it doesn't exist already 314 | 2) return the CLINC dataset grouped by intent names 315 | 316 | Secondary role: 317 | This function also creates name2id.json and id2name.json files for the 318 | dataset in the respective dataset folder (e.g. ./data/clinc_oos/ for 319 | CLINC. Note that the name of the dataset folder matches 320 | DS_CONFIG.data_name). If the dataset contains an OOS class, make it the 321 | last class for that dataset (to easen execution of non-OOS experiments) 322 | 323 | name2id.json is a Dict[Str: Int] and id2name.json is a Dict[Str: Str] 324 | 325 | parse_and_load_clinc() is the only function that will interact with the 326 | outside world, and any dataset specific utility required to accomplish the 327 | described roles above (like downloading, parsing, etc.) may be implemented 328 | in this file. 329 | 330 | NOTE that we store intent ids in dataset.pkl, but the grouped dataset 331 | returned by the function stores intent names! 332 | 333 | `dataset.pkl` contains a Dict object with the following schema: 334 | 335 | { 336 | 'train': {'text': listofStr, 'intent': listofInt}, 337 | 'val': {'text': listofStr, 'intent': listofInt}, 338 | 'test': {'text': listofStr, 'intent': listofInt} 339 | } 340 | 341 | Format of the returned dataset grouped by intent names: 342 | 343 | collections.defaultdict(None, { 344 | 'train': { 345 | Str: listofStr, 346 | Str: listofStr, 347 | . 348 | . 349 | . 350 | Str: listofStr # n_intents 351 | }, 352 | 353 | 'val': { 354 | Str: listofStr, 355 | Str: listofStr, 356 | . 357 | . 358 | . 359 | Str: listofStr # n_intents 360 | }, 361 | 362 | 'test': { 363 | Str: listofStr, 364 | Str: listofStr, 365 | . 366 | . 367 | . 
368 | Str: listofStr # n_intents
369 | }
370 | })
371 |
372 | parse_and_load_clinc: None -> collections.defaultdict
373 | """
374 | if not os.path.exists("./data/clinc_oos/full/dataset.pkl"):
375 | prepare_clinc()
376 | return load_clinc()
377 |
-------------------------------------------------------------------------------- /utils/data_utils/data_loader.py: --------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | from utils import main_data_utils as mdu
3 | from datasets import load_dataset, Dataset, DatasetDict
4 | import os, warnings, regex, numpy as np, collections, math, time
5 |
6 | from utils.data_utils.augment_slices import openai_complete
7 |
8 |
9 | def _normalize(probs):
10 | """
11 | Given label probs Dict[str: list]
12 | first, normalizes probabilities of tokens predicted multiple times
13 | then, normalizes across all the predicted labels
14 | """
15 | # NOTE: not using assertion because davinci and curie give different probs
16 | # for the same prediction sometimes
17 | # for k, v in probs.items():
18 | # prob should be the same for all the multiple predictions of the label
19 | # assert len(set(v)) == 1
20 | # probs[k] = v[0]
21 | probs = {k: np.mean(v) for k, v in probs.items()}
22 | return {k: v / sum(probs.values()) for k, v in probs.items()}
23 |
24 |
25 | def gpt3mix_complete(prompt, n, labels_list, exp_dict, name2id):
26 | """
27 | Given a seed_text and its corresponding seed_intent (name, not id),
28 |
29 | 1. generate x_dash (n augmentations per seed_text)
30 | 2. generate y_dash (soft label using name2id)
31 | """
32 | pattern = regex.compile(r"(?r)Sentence: (.*)\(intent: (.*)\)")
33 | # gpt prompting to generate x_dashes
34 | completions = openai_complete(
35 | engine=exp_dict["gpt3_engine"],
36 | prompt=prompt,
37 | temp=exp_dict["gpt3_temp"],
38 | top_p=1.0,
39 | n=n,
40 | stop="\n",
41 | max_tokens=50,
42 | frequency_penalty=0.02,
43 | )
44 |
45 | augmentations = {"text": [], "intent": []}
46 | for c in completions:
47 | match = pattern.search("Sentence:" + c.text)
48 | if match is None: # invalid prediction
49 | continue
50 | _txt = match.group(1).strip().lower()
51 | if not _txt:
52 | continue
53 |
54 | # prompt GPT3 again to create soft label
55 | label_completions = openai_complete(
56 | engine=exp_dict["gpt3_engine"],
57 | prompt=prompt + f" {_txt}. (intent:",
58 | temp=exp_dict["gpt3_temp"],
59 | n=100,
60 | top_p=1.0,
61 | max_tokens=20,
62 | stop="\n",
63 | logprobs=1,
64 | )
65 |
66 | # construct probabilities for all the predicted labels
67 | label_probs = collections.defaultdict(list)
68 | for _lc in label_completions:
69 | _log_probs = _lc.logprobs
70 | _match = pattern.search(f"Sentence: {_txt}. (intent:" + _lc.text)
71 | if _match is None: # incomplete prediction
72 | continue
73 |
74 | _pred = _match.group(2).strip().lower()
75 | if _pred not in labels_list: # invalid label
76 | continue
77 |
78 | # NOTE: we look at token_logprobs (vs. the top_logprobs used by the
79 | # GPT3Mix paper) because we sample to estimate p(y_dash | x_dash),
80 | # as opposed to reading off the logprobs of the top-100 most likely
81 | # tokens; the API's default top_logprobs limits us to just 5 now.
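# Worked example (hypothetical numbers): if _pred is "pay_bill" and the
# completion tokens are ["pay", "_bill", ")"] with token_logprobs
# [-0.2, -0.1, -0.05], the loop below accumulates -0.2 + (-0.1) = -0.3
# before hitting ")" and records exp(-0.3) ~= 0.74 for "pay_bill".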
82 | _curr_log_prob = 0
83 | for t, p in zip(_log_probs["tokens"], _log_probs["token_logprobs"]):
84 | # if the code reaches here, ")" is guaranteed to be present,
85 | # as the regex check earlier would have triggered a `continue` otherwise
86 | if t == ")":
87 | label_probs[_pred].append(math.exp(_curr_log_prob))
88 | break
89 |
90 | # add logprobs (multiply probs) for sub words of _pred as
91 | # class names are not single tokens
92 | _curr_log_prob += p
93 |
94 | # normalize label_probs
95 | label_probs = _normalize(label_probs)
96 | # create soft label
97 | soft_label = [0] * exp_dict["dataset"]["num_labels"]
98 | for k, v in label_probs.items():
99 | soft_label[name2id[k]] = v
100 |
101 | augmentations["text"].append(_txt)
102 | augmentations["intent"].append(soft_label)
103 | return augmentations
104 |
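# Shape of the returned dict (values illustrative): parallel lists where each
# soft label sums to 1 over num_labels classes, e.g.
#   {"text": ["can i move money to my checking account"],
#    "intent": [[0.0, ..., 0.74, ..., 0.26, ...]]}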
105 |
106 | def generate_for_gpt3mix(base_ds, ex2_ds, exp_dict, interim_save_path):
107 | num_labels = exp_dict["dataset"]["num_labels"]
108 | ds_name = exp_dict["dataset"]["name"]
109 | id2name = mdu.read_json(f"data/{ds_name}/id2name.json")
110 | name2id = mdu.read_json(f"data/{ds_name}/name2id.json")
111 | ds_config = mdu.get_ds_config(ds_name)
112 | k = ds_config.num_examples
113 |
114 | labels_list = list(name2id.keys())
115 | if "oos" in labels_list:
116 | labels_list.remove("oos")
117 |
118 | train_lines, train_labels = [], []
119 | if os.path.exists(interim_save_path):
120 | interim_copy = mdu.read_pickle(interim_save_path)
121 | else:
122 | interim_copy = {}
123 |
124 | for domain in ex2_ds:
125 | if domain in interim_copy:
126 | print(f"Domain: {domain} already GPT3Mix augmented. Moving on...")
127 | continue
128 |
129 | print(f"Augmenting domain: {domain}")
130 | texts = ex2_ds[domain]["F"]["train"]["text"]
131 | hard_labels = ex2_ds[domain]["F"]["train"]["intent"]
132 | _lines, _labels = [], []
133 | # NOTE: this loop will never be executed for oos--both lists will be []
134 | for text, intent in tqdm(zip(texts, hard_labels), total=len(texts)):
135 | # add gold example to training set
136 | one_hot = [0.0] * num_labels
137 | one_hot[intent] = 1.0
138 | # for interim copy
139 | _lines.append(text)
140 | _labels.append(one_hot)
141 |
142 | # construct prompt header
143 | prompt = "Each item in the following list contains a sentence and the respective intent."
144 | label_enum_str = [f"'{l.lower()}'" for l in labels_list]
145 | prompt += f" Intent is one of {', or '.join(label_enum_str)}"
146 | prompt += ".\n"
147 | prompt += f"Sentence: {text}. (intent: {id2name[str(intent)]})\n"
148 |
149 | # remove current intent from candidates to sample from
150 | _lbl_list = [l for l in labels_list if l != id2name[str(intent)]]
151 | # sample k-1 random intents from the label set (9 when k=10)
152 | other_lbls = np.random.choice(_lbl_list, k - 1, replace=False)
153 | # fetch a sample for each of these new intents and add to the prompt
154 | for lbl in other_lbls:
155 | # find the domain of lbl
156 | _domain, _domain_found = None, False
157 | for _d, _i_l in ds_config.domain_to_intent.items():
158 | if not _domain_found and lbl in _i_l:
159 | _domain_found = True
160 | _domain = _d
161 |
162 | gt_txts = ex2_ds[_domain]["F"]["train"]["text"]
163 | gt_lbls = ex2_ds[_domain]["F"]["train"]["intent"]
164 | _start = gt_lbls.index(name2id[lbl])
165 | # select a random sentence for lbl
166 | _text = np.random.choice(gt_txts[_start : _start + k], 1)[0]
167 | # add the _text, lbl pair to prompt
168 | prompt += f"Sentence: {_text}. (intent: {lbl})\n"
169 | prompt += "Sentence:"
170 |
171 | # generated examples with soft labels
172 | augs = gpt3mix_complete(prompt, 10, labels_list, exp_dict, name2id)
173 | _lines.extend(augs["text"])
174 | _labels.extend(augs["intent"])
175 |
176 | train_lines.extend(_lines)
177 | train_labels.extend(_labels)
178 |
179 | # save an interim copy now
180 | interim_copy[domain] = {"text": _lines, "intent": _labels}
181 | mdu.write_pickle(interim_copy, interim_save_path)
182 |
183 | print("Sleeping...for a minute")
184 | time.sleep(60)
185 |
186 | # Add OOS samples
187 | oos_texts, oos_labels = extract_oos(base_ds["train"], exp_dict["dataset"]["oos_id"])
188 | for text, intent in tqdm(zip(oos_texts, oos_labels), total=len(oos_texts)):
189 | # add gold example to training set
190 | one_hot = [0.0] * num_labels
191 | one_hot[intent] = 1.0
192 | train_lines.append(text)
193 | train_labels.append(one_hot)
194 |
195 | # delete interim copy
196 | del interim_copy
197 | return {"text": train_lines, "intent": train_labels}
198 |
199 |
200 | def prepare_for_seq2seq(dataset, id2name_path):
201 | """
202 | dataset: Dict with "text" (list of str) and "intent" (list of int ids)
203 | """
204 | id2name = mdu.read_json(id2name_path)
205 | return {
206 | "text": [t + " " for t in dataset["text"]],
207 | # intents are class ids here, not names
208 | "intent": [id2name[str(i)] + " " for i in dataset["intent"]],
209 | }
210 |
211 |
212 | def filter_oos(data_dict, oos_id, soft_label=False):
213 | """Removes oos samples from the data dict"""
214 | lines, labels = data_dict["text"], data_dict["intent"]
215 | # some datasets (like SNIPS) don't have an OOS class
216 | if oos_id is None:
217 | return lines, labels
218 | _lines, _labels = [], []
219 | for idx, intent_id in enumerate(labels):
220 | if soft_label and np.array(intent_id).argmax(-1) == oos_id:
221 | continue
222 | if not soft_label and intent_id == oos_id:
223 | continue
224 | _lines.append(lines[idx])
225 | _labels.append(labels[idx])
226 | # print(len(_lines), len(_labels))
227 | return _lines, _labels
228 |
229 |
230 | def extract_oos(data_dict, oos_id):
231 | """Extract the OOS samples from the data dict. It is the
232 | opposite of filter_oos"""
233 | lines, labels = data_dict["text"], data_dict["intent"]
234 | # some datasets (like SNIPS) don't have an OOS class
235 | _lines, _labels = [], []
236 | for idx, intent_id in enumerate(labels):
237 | if intent_id != oos_id:
238 | continue
239 | _lines.append(lines[idx])
240 | _labels.append(labels[idx])
241 | return _lines, _labels
242 |
243 |
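# Illustrative round trip on a toy dict (ids hypothetical):
#   d = {"text": ["hi there", "book a flight"], "intent": [150, 42]}
#   filter_oos(d, 150)   -> (["book a flight"], [42])
#   extract_oos(d, 150)  -> (["hi there"], [150])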
244 | class DatasetLoader:
245 | """
246 | Available datasets:
247 | - Clinc original: we can choose either the `full` or the `small` version.
248 | - Pure Few-shot Clinc:
249 | baseline: contains 10 examples per class (except OOS), randomly sampled from the original full CLINC.
250 |
251 | """
252 |
253 | def __init__(self, data_root, exp_dict):
254 | dataset_config = exp_dict["dataset"]["config"]
255 |
256 | var_path = "full" if dataset_config.startswith("f") else "small"
257 | ds_name = exp_dict["dataset"]["name"]
258 | basic_data_path = os.path.join(data_root, ds_name, var_path, "dataset.pkl")
259 | ex2_data_path = os.path.join(
260 | data_root, ds_name, var_path, "data_full_suite.pkl"
261 | )
262 |
263 | if dataset_config == "few_pure":
264 | base_ds = mdu.read_pickle(basic_data_path)
265 | data_set = mdu.read_pickle(ex2_data_path)
266 | oos_id = exp_dict["dataset"]["oos_id"]
267 | train_lines, train_labels = [], []
268 | if exp_dict["exp_type"] == "baseline":
269 | print("Loading dataset for full few-shot baseline")
270 | for domain in data_set:
271 | train_lines.extend(data_set[domain]["F"]["train"]["text"])
272 | train_labels.extend(data_set[domain]["F"]["train"]["intent"])
273 | elif exp_dict["exp_type"] == "eda":
274 | exp_type = exp_dict["exp_type"]
275 | print(f"Loading dataset for full few-shot {exp_type.upper()}")
276 | # lump in EDA examples with all few-shot samples
277 | for domain in data_set:
278 | train_lines.extend(
279 | data_set[domain]["F"]["train"]["text"]
280 | + data_set[domain]["F"][exp_type]["text"]
281 | )
282 | train_labels.extend(
283 | data_set[domain]["F"]["train"]["intent"]
284 | + data_set[domain]["F"][exp_type]["intent"]
285 | )
286 | elif exp_dict["exp_type"] == "gpt3":
287 | print(f"Loading dataset for full few-shot {exp_dict['exp_type']}")
288 | # set correct attribute to fetch from the dataset
289 | # (this branch previously also matched "eda", but that case is
290 | # fully handled by the elif above and never reached here)
291 | engine, temp = exp_dict["gpt3_engine"], exp_dict["gpt3_temp"]
292 | attr = f"{engine}_{temp}"
293 |
294 |
295 | # lump in the fetched examples with all few-shot samples
296 | for domain in data_set:
297 | train_lines.extend(
298 | data_set[domain]["F"]["train"]["text"]
299 | + data_set[domain]["F"][attr]["text"]
300 | )
301 | train_labels.extend(
302 | data_set[domain]["F"]["train"]["intent"]
303 | + data_set[domain]["F"][attr]["intent"]
304 | )
305 | elif exp_dict["exp_type"] in [
306 | "gpt3_oracle",
307 | "eda_oracle",
308 | "gpt3mix_oracle",
309 | ]:
310 | # the few shot sentences are taken from the ex2 setup data
311 | # and the relabeled samples are taken from al_dataset.pkl
312 | print(f"Loading dataset for full few-shot {exp_dict['exp_type']}")
313 | # Use relabeled dataset as base
314 | al_path = os.path.join(data_root, ds_name, "full", "al_dataset.pkl")
315 | al_ds = mdu.read_pickle(al_path)
316 |
317 | # set correct attribute to fetch from the dataset
318 | if exp_dict["exp_type"] == "gpt3_oracle":
319 | engine, temp = exp_dict["gpt3_engine"], exp_dict["gpt3_temp"]
320 | attr = f"{engine}_{temp}"
321 | elif exp_dict["exp_type"] == "gpt3mix_oracle":
322 | attr = f"gpt3mix_{exp_dict['gpt3_engine']}"
323 | else: # eda_oracle
324 | attr = exp_dict["exp_type"].split("_")[0] # just eda
325 |
326 | for domain in data_set:
327 | train_lines.extend(data_set[domain]["F"]["train"]["text"])
328 | train_labels.extend(data_set[domain]["F"]["train"]["intent"])
329 | train_lines.extend(al_ds["generated"][attr]["text"])
330 | train_labels.extend(al_ds["generated"][attr]["intent"])
331 |
332 | elif exp_dict["exp_type"] == "gpt3mix":
333 | print("Loading labelled pool for full few-shot gpt3mix")
334 | engine = exp_dict["gpt3_engine"]
335 | gpt3mix_path = f"data/{ds_name}/full/gpt3mix_{engine}.pkl"
336 |
337 | # augs will also contain the seed samples
338 | if os.path.exists(gpt3mix_path): # load from existing pkl
339 | print(f"Loading existing GPT3Mix data for {engine.upper()}")
340 | augs = mdu.read_pickle(gpt3mix_path)
341 | else: # otherwise, generate gpt3mix pickle
342 | print(f"Generating GPT3Mix data with {engine.upper()}")
343 | interim_save_path = gpt3mix_path[:-4] + "_interim.pkl"
344 | augs = generate_for_gpt3mix(
345 | base_ds, data_set, exp_dict, interim_save_path
346 | )
347 | # save complete augmented data
348 | mdu.write_pickle(augs, gpt3mix_path)
349 |
350 | train_lines, train_labels = augs["text"], augs["intent"]
351 |
352 | val_lines, val_labels = base_ds["val"]["text"], base_ds["val"]["intent"]
353 | test_lines, test_labels = (
354 | base_ds["test"]["text"],
355 | base_ds["test"]["intent"],
356 | )
357 |
358 | # add oos samples to the train set (the gpt3mix setting already adds them)
359 | if oos_id is not None and exp_dict["exp_type"] != "gpt3mix":
360 | # add oos samples to the dataset
361 | oos_lines, oos_labels = extract_oos(base_ds["train"], oos_id)
362 | train_lines.extend(oos_lines)
363 | train_labels.extend(oos_labels)
364 |
365 | # remove oos samples appropriately
366 | if oos_id is None:
367 | name2id_path = os.path.join(data_root, ds_name, "name2id.json")
368 | temp_oos_id = mdu.read_json(name2id_path).get("oos", None)
369 | if exp_dict["exp_type"] == "gpt3mix":
370 | train_set = {"text": train_lines, "intent": train_labels}
371 | # remove oos samples added to the train set by default
372 | train_lines, train_labels = filter_oos(
373 | train_set, temp_oos_id, soft_label=True
374 | )
375 | # remove oos samples from the val set added by default
376 | val_lines, val_labels = filter_oos(base_ds["val"], temp_oos_id)
377 | test_lines, test_labels = filter_oos(base_ds["test"], temp_oos_id)
378 |
379 | print(len(train_lines), len(train_labels))
380 | self.dataset = DatasetDict(
381 | train=Dataset.from_dict({"text": train_lines, "intent": train_labels}),
382 | validation=Dataset.from_dict({"text": val_lines, "intent": val_labels}),
383 | test=Dataset.from_dict({"text": test_lines, "intent": test_labels}),
384 | )
385 | elif dataset_config == "full":
386 | # read the original FULL version of the dataset
387 | data_set = mdu.read_pickle(basic_data_path)
388 |
389 | if exp_dict["exp_type"] == "intrinsic":
390 | print("Loading utils for intrinsic evaluation")
391 |
392 | oos_id = exp_dict["dataset"]["oos_id"]
393 | train_lines, train_labels = filter_oos(data_set["train"], oos_id)
394 | val_lines, val_labels = filter_oos(data_set["val"], oos_id)
395 | test_lines, test_labels = filter_oos(data_set["test"], oos_id)
396 |
397 | self.dataset = DatasetDict(
398 | train=Dataset.from_dict(
399 | {"text": train_lines, "intent": train_labels}
400 | ),
401 | validation=Dataset.from_dict(
402 | {"text": val_lines, "intent": val_labels}
403 | ),
404 | test=Dataset.from_dict({"text": test_lines, "intent": test_labels}),
405 | )
406 | # add different set of generated lines as test set
407 | augmented_data = mdu.read_pickle(ex2_data_path)
408 | domains = list(augmented_data.keys())
409 | for e in ["ada", "babbage", "curie", "davinci", "gptj"]:
410 | # for t in np.linspace(0.5, 2, int((2.1-.5)/.1)):
411 | for t in [1.0]:
412 | _lines, _intents = [], []
413 | for d in domains:
414 | if d == "oos":
415 | continue
416 | _lines.extend(augmented_data[d]["F"][f"{e}_{t}"]["text"])
417 | _intents.extend(
418 | augmented_data[d]["F"][f"{e}_{t}"]["intent"]
419 | )
420 | self.dataset[f"{e}_{t}"] = Dataset.from_dict(
421 | {"text": _lines, "intent": _intents}
422 | )
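# The extra evaluation splits registered above are keyed as
# "ada_1.0", "babbage_1.0", "curie_1.0", "davinci_1.0", and "gptj_1.0"
# (engine name + sampling temperature).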
"intent": _intents} 422 | ) 423 | elif exp_dict["exp_type"] == "baseline": 424 | print("Loading utils for baseline version") 425 | self.dataset = DatasetDict( 426 | train=Dataset.from_dict(data_set["train"]), 427 | validation=Dataset.from_dict(data_set["val"]), 428 | test=Dataset.from_dict(data_set["test"]), 429 | ) 430 | 431 | elif dataset_config.startswith("full_"): 432 | print(f"Loading utils for {dataset_config}") 433 | # read the augmented version of the dataset 434 | data_set = mdu.read_pickle(ex2_data_path) 435 | # the few-shot domain 436 | val_domain = dataset_config.split("_", 1)[1] 437 | # train set = D_{M, train} + D_{F, train} 438 | train_lines = ( 439 | data_set[val_domain]["M"]["train"]["text"] 440 | + data_set[val_domain]["F"]["train"]["text"] 441 | ) 442 | 443 | train_labels = ( 444 | data_set[val_domain]["M"]["train"]["intent"] 445 | + data_set[val_domain]["F"]["train"]["intent"] 446 | ) 447 | 448 | if exp_dict["exp_type"] == "upsample": 449 | train_lines.extend(data_set[val_domain]["F"]["upsample"]["text"]) 450 | train_labels.extend(data_set[val_domain]["F"]["upsample"]["intent"]) 451 | elif exp_dict["exp_type"] == "gpt3": 452 | engine = exp_dict["gpt3_engine"] 453 | temp = exp_dict["gpt3_temp"] 454 | train_lines.extend( 455 | data_set[val_domain]["F"][f"{engine}_{temp}"]["text"] 456 | ) 457 | train_labels.extend( 458 | data_set[val_domain]["F"][f"{engine}_{temp}"]["intent"] 459 | ) 460 | 461 | full_val_lines = ( 462 | data_set[val_domain]["M"]["val"]["text"] 463 | + data_set[val_domain]["F"]["val"]["text"] 464 | ) 465 | 466 | full_val_labels = ( 467 | data_set[val_domain]["M"]["val"]["intent"] 468 | + data_set[val_domain]["F"]["val"]["intent"] 469 | ) 470 | 471 | full_test_lines = ( 472 | data_set[val_domain]["M"]["test"]["text"] 473 | + data_set[val_domain]["F"]["test"]["text"] 474 | ) 475 | 476 | full_test_labels = ( 477 | data_set[val_domain]["M"]["test"]["intent"] 478 | + data_set[val_domain]["F"]["test"]["intent"] 479 | ) 480 | 481 | # add oos samples to the dataset for oos-aware classifiers 482 | if exp_dict["dataset"]["oos_id"] is not None: 483 | print("adding OOS samples to the dataset") 484 | base_ds = mdu.mdu.read_pickle(basic_data_path) 485 | oos_id = exp_dict["dataset"]["oos_id"] 486 | 487 | # augment training set 488 | oos_train_lines, oos_train_labels = extract_oos( 489 | base_ds["train"], oos_id 490 | ) 491 | train_lines.extend(oos_train_lines) 492 | train_labels.extend(oos_train_labels) 493 | 494 | # augment validation set 495 | oos_val_lines, oos_val_labels = extract_oos(base_ds["val"], oos_id) 496 | full_val_lines.extend(oos_val_lines) 497 | full_val_labels.extend(oos_val_labels) 498 | 499 | # augment test set 500 | oos_test_lines, oos_test_labels = extract_oos(base_ds["test"], oos_id) 501 | full_test_lines.extend(oos_test_lines) 502 | full_test_labels.extend(oos_test_labels) 503 | 504 | self.dataset = DatasetDict( 505 | train=Dataset.from_dict({"text": train_lines, "intent": train_labels}), 506 | validation=Dataset.from_dict(data_set[val_domain]["F"]["val"]), 507 | test=Dataset.from_dict(data_set[val_domain]["F"]["test"]), 508 | full_test=Dataset.from_dict( 509 | {"text": full_test_lines, "intent": full_test_labels} 510 | ), 511 | full_validation=Dataset.from_dict( 512 | {"text": full_val_lines, "intent": full_val_labels} 513 | ), 514 | ) 515 | else: 516 | warnings.warn("At the moment we can only load clinc_oos") 517 | self.dataset = load_dataset(ds_name, dataset_config, cache_dir=data_root) 518 | 519 | def get_split(self, split): 520 | return 
self.dataset[split] 521 | -------------------------------------------------------------------------------- /utils/data_utils/eda_utils.py: -------------------------------------------------------------------------------- 1 | # Easy data augmentation techniques for text classification 2 | # Jason Wei and Kai Zou from 3 | # https://github.com/jasonwei20/eda_nlp/blob/04ab29c5b18d2d72f9fa5b304322aaf4793acea0/code/eda.py 4 | 5 | import random 6 | from random import shuffle 7 | 8 | random.seed(1) 9 | 10 | # stop words list 11 | stop_words = [ 12 | "i", 13 | "me", 14 | "my", 15 | "myself", 16 | "we", 17 | "our", 18 | "ours", 19 | "ourselves", 20 | "you", 21 | "your", 22 | "yours", 23 | "yourself", 24 | "yourselves", 25 | "he", 26 | "him", 27 | "his", 28 | "himself", 29 | "she", 30 | "her", 31 | "hers", 32 | "herself", 33 | "it", 34 | "its", 35 | "itself", 36 | "they", 37 | "them", 38 | "their", 39 | "theirs", 40 | "themselves", 41 | "what", 42 | "which", 43 | "who", 44 | "whom", 45 | "this", 46 | "that", 47 | "these", 48 | "those", 49 | "am", 50 | "is", 51 | "are", 52 | "was", 53 | "were", 54 | "be", 55 | "been", 56 | "being", 57 | "have", 58 | "has", 59 | "had", 60 | "having", 61 | "do", 62 | "does", 63 | "did", 64 | "doing", 65 | "a", 66 | "an", 67 | "the", 68 | "and", 69 | "but", 70 | "if", 71 | "or", 72 | "because", 73 | "as", 74 | "until", 75 | "while", 76 | "of", 77 | "at", 78 | "by", 79 | "for", 80 | "with", 81 | "about", 82 | "against", 83 | "between", 84 | "into", 85 | "through", 86 | "during", 87 | "before", 88 | "after", 89 | "above", 90 | "below", 91 | "to", 92 | "from", 93 | "up", 94 | "down", 95 | "in", 96 | "out", 97 | "on", 98 | "off", 99 | "over", 100 | "under", 101 | "again", 102 | "further", 103 | "then", 104 | "once", 105 | "here", 106 | "there", 107 | "when", 108 | "where", 109 | "why", 110 | "how", 111 | "all", 112 | "any", 113 | "both", 114 | "each", 115 | "few", 116 | "more", 117 | "most", 118 | "other", 119 | "some", 120 | "such", 121 | "no", 122 | "nor", 123 | "not", 124 | "only", 125 | "own", 126 | "same", 127 | "so", 128 | "than", 129 | "too", 130 | "very", 131 | "s", 132 | "t", 133 | "can", 134 | "will", 135 | "just", 136 | "don", 137 | "should", 138 | "now", 139 | "", 140 | ] 141 | 142 | # cleaning up text 143 | import re 144 | 145 | 146 | def get_only_chars(line): 147 | 148 | clean_line = "" 149 | 150 | line = line.replace("’", "") 151 | line = line.replace("'", "") 152 | line = line.replace("-", " ") # replace hyphens with spaces 153 | line = line.replace("\t", " ") 154 | line = line.replace("\n", " ") 155 | line = line.lower() 156 | 157 | for char in line: 158 | if char in "qwertyuiopasdfghjklzxcvbnm ": 159 | clean_line += char 160 | else: 161 | clean_line += " " 162 | 163 | clean_line = re.sub(" +", " ", clean_line) # delete extra spaces 164 | if clean_line[0] == " ": 165 | clean_line = clean_line[1:] 166 | return clean_line 167 | 168 | 169 | ######################################################################## 170 | # Synonym replacement 171 | # Replace n words in the sentence with synonyms from wordnet 172 | ######################################################################## 173 | 174 | 175 | # import nltk 176 | # nltk.download('wordnet') 177 | from nltk.corpus import wordnet 178 | 179 | 180 | def synonym_replacement(words, n): 181 | new_words = words.copy() 182 | random_word_list = list(set([word for word in words if word not in stop_words])) 183 | random.shuffle(random_word_list) 184 | num_replaced = 0 185 | for random_word in random_word_list: 186 | 
synonyms = get_synonyms(random_word) 187 | if len(synonyms) >= 1: 188 | synonym = random.choice(list(synonyms)) 189 | new_words = [synonym if word == random_word else word for word in new_words] 190 | # print("replaced", random_word, "with", synonym) 191 | num_replaced += 1 192 | if num_replaced >= n: # only replace up to n words 193 | break 194 | 195 | # this is stupid but we need it, trust me 196 | sentence = " ".join(new_words) 197 | new_words = sentence.split(" ") 198 | 199 | return new_words 200 | 201 | 202 | def get_synonyms(word): 203 | synonyms = set() 204 | for syn in wordnet.synsets(word): 205 | for l in syn.lemmas(): 206 | synonym = l.name().replace("_", " ").replace("-", " ").lower() 207 | synonym = "".join( 208 | [char for char in synonym if char in " qwertyuiopasdfghjklzxcvbnm"] 209 | ) 210 | synonyms.add(synonym) 211 | if word in synonyms: 212 | synonyms.remove(word) 213 | return list(synonyms) 214 | 215 | 216 | ######################################################################## 217 | # Random deletion 218 | # Randomly delete words from the sentence with probability p 219 | ######################################################################## 220 | 221 | 222 | def random_deletion(words, p): 223 | 224 | # obviously, if there's only one word, don't delete it 225 | if len(words) == 1: 226 | return words 227 | 228 | # randomly delete words with probability p 229 | new_words = [] 230 | for word in words: 231 | r = random.uniform(0, 1) 232 | if r > p: 233 | new_words.append(word) 234 | 235 | # if you end up deleting all words, just return a random word 236 | if len(new_words) == 0: 237 | rand_int = random.randint(0, len(words) - 1) 238 | return [words[rand_int]] 239 | 240 | return new_words 241 | 242 | 243 | ######################################################################## 244 | # Random swap 245 | # Randomly swap two words in the sentence n times 246 | ######################################################################## 247 | 248 | 249 | def random_swap(words, n): 250 | new_words = words.copy() 251 | for _ in range(n): 252 | new_words = swap_word(new_words) 253 | return new_words 254 | 255 | 256 | def swap_word(new_words): 257 | random_idx_1 = random.randint(0, len(new_words) - 1) 258 | random_idx_2 = random_idx_1 259 | counter = 0 260 | while random_idx_2 == random_idx_1: 261 | random_idx_2 = random.randint(0, len(new_words) - 1) 262 | counter += 1 263 | if counter > 3: 264 | return new_words 265 | new_words[random_idx_1], new_words[random_idx_2] = ( 266 | new_words[random_idx_2], 267 | new_words[random_idx_1], 268 | ) 269 | return new_words 270 | 271 | 272 | ######################################################################## 273 | # Random insertion 274 | # Randomly insert n words into the sentence 275 | ######################################################################## 276 | 277 | 278 | def random_insertion(words, n): 279 | new_words = words.copy() 280 | for _ in range(n): 281 | add_word(new_words) 282 | return new_words 283 | 284 | 285 | def add_word(new_words): 286 | synonyms = [] 287 | counter = 0 288 | while len(synonyms) < 1: 289 | random_word = new_words[random.randint(0, len(new_words) - 1)] 290 | synonyms = get_synonyms(random_word) 291 | counter += 1 292 | if counter >= 10: 293 | return 294 | random_synonym = synonyms[0] 295 | random_idx = random.randint(0, len(new_words) - 1) 296 | new_words.insert(random_idx, random_synonym) 297 | 298 | 299 | ######################################################################## 300 | # main data 
augmentation function 301 | ######################################################################## 302 | 303 | 304 | def eda(sentence, alpha_sr=0.1, alpha_ri=0.1, alpha_rs=0.1, p_rd=0.1, num_aug=9): 305 | sentence = get_only_chars(sentence) 306 | words = sentence.split(" ") 307 | words = [word for word in words if word != ""] 308 | num_words = len(words) 309 | 310 | augmented_sentences = [] 311 | num_new_per_technique = int(num_aug / 4) + 1 312 | 313 | # sr 314 | if alpha_sr > 0: 315 | n_sr = max(1, int(alpha_sr * num_words)) 316 | for _ in range(num_new_per_technique): 317 | a_words = synonym_replacement(words, n_sr) 318 | augmented_sentences.append(" ".join(a_words)) 319 | 320 | # ri 321 | if alpha_ri > 0: 322 | n_ri = max(1, int(alpha_ri * num_words)) 323 | for _ in range(num_new_per_technique): 324 | a_words = random_insertion(words, n_ri) 325 | augmented_sentences.append(" ".join(a_words)) 326 | 327 | # rs 328 | if alpha_rs > 0: 329 | n_rs = max(1, int(alpha_rs * num_words)) 330 | for _ in range(num_new_per_technique): 331 | a_words = random_swap(words, n_rs) 332 | augmented_sentences.append(" ".join(a_words)) 333 | 334 | # rd 335 | if p_rd > 0: 336 | for _ in range(num_new_per_technique): 337 | a_words = random_deletion(words, p_rd) 338 | augmented_sentences.append(" ".join(a_words)) 339 | 340 | augmented_sentences = [get_only_chars(sentence) for sentence in augmented_sentences] 341 | shuffle(augmented_sentences) 342 | 343 | # trim so that we have the desired number of augmented sentences 344 | if num_aug >= 1: 345 | augmented_sentences = augmented_sentences[:num_aug] 346 | else: 347 | keep_prob = num_aug / len(augmented_sentences) 348 | augmented_sentences = [ 349 | s for s in augmented_sentences if random.uniform(0, 1) < keep_prob 350 | ] 351 | 352 | # DO NOT append the original sentence (we'll fetch it from another part) 353 | # during data loading 354 | # augmented_sentences.append(sentence) 355 | 356 | return augmented_sentences 357 | -------------------------------------------------------------------------------- /utils/data_utils/hwu64_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | All HWU64 specific utilities are implemented here. 
3 | """ 4 | from haven import haven_utils as hu 5 | import os 6 | import pandas as pd 7 | 8 | INTENTS = [ 9 | "alarm_query", 10 | "alarm_remove", 11 | "alarm_set", 12 | "audio_volume_down", 13 | "audio_volume_mute", 14 | "audio_volume_up", 15 | "calendar_query", 16 | "calendar_remove", 17 | "calendar_set", 18 | "cooking_recipe", 19 | "datetime_convert", 20 | "datetime_query", 21 | "email_addcontact", 22 | "email_query", 23 | "email_querycontact", 24 | "email_sendemail", 25 | "general_affirm", 26 | "general_commandstop", 27 | "general_confirm", 28 | "general_dontcare", 29 | "general_explain", 30 | "general_joke", 31 | "general_negate", 32 | "general_praise", 33 | "general_quirky", 34 | "general_repeat", 35 | "iot_cleaning", 36 | "iot_coffee", 37 | "iot_hue_lightchange", 38 | "iot_hue_lightdim", 39 | "iot_hue_lightoff", 40 | "iot_hue_lighton", 41 | "iot_hue_lightup", 42 | "iot_wemo_off", 43 | "iot_wemo_on", 44 | "lists_createoradd", 45 | "lists_query", 46 | "lists_remove", 47 | "music_likeness", 48 | "music_query", 49 | "music_settings", 50 | "news_query", 51 | "play_audiobook", 52 | "play_game", 53 | "play_music", 54 | "play_podcasts", 55 | "play_radio", 56 | "qa_currency", 57 | "qa_definition", 58 | "qa_factoid", 59 | "qa_maths", 60 | "qa_stock", 61 | "recommendation_events", 62 | "recommendation_locations", 63 | "recommendation_movies", 64 | "social_post", 65 | "social_query", 66 | "takeaway_order", 67 | "takeaway_query", 68 | "transport_query", 69 | "transport_taxi", 70 | "transport_ticket", 71 | "transport_traffic", 72 | "weather_query", 73 | ] 74 | 75 | 76 | class Hwu64: 77 | def __init__(self, name): 78 | path_base = "./data/dialoglue/data_utils/dialoglue/hwu" 79 | 80 | self.data_name = name 81 | self.full_path = f"./data/{name}/full/dataset.pkl" 82 | self.num_examples = 10 83 | 84 | # 1. Get Intent Mappings 85 | name2id = {k: i for i, k in enumerate(INTENTS)} 86 | hu.save_json(f"./data/{name}/name2id.json", name2id) 87 | id2name = {i: k for i, k in enumerate(INTENTS)} 88 | hu.save_json(f"./data/{name}/id2name.json", id2name) 89 | 90 | # 2. save dataset.pkl 91 | if not os.path.exists(self.full_path): 92 | # get data dict 93 | data_dict = { 94 | "train": get_data_dict(path_base, "train", name2id), 95 | "val": get_data_dict(path_base, "val", name2id), 96 | "test": get_data_dict(path_base, "test", name2id), 97 | } 98 | hu.save_pkl(self.full_path, data_dict) 99 | else: 100 | data_dict = hu.load_pkl(self.full_path) 101 | 102 | # intents = list(name2id.keys()) 103 | # Groupp by intent 104 | self.dataset_by_intent = {} 105 | for split in ["train", "val", "test"]: 106 | self.dataset_by_intent[split] = {} 107 | text_intent_dict = data_dict[split] 108 | text_list, intent_list = ( 109 | text_intent_dict["text"], 110 | map(lambda x: id2name[x], text_intent_dict["intent"]), 111 | ) 112 | 113 | # get texts from intent 114 | intent2texts = {} 115 | for t, i in zip(text_list, intent_list): 116 | if i not in intent2texts: 117 | intent2texts[i] = [] 118 | intent2texts[i] += [t] 119 | self.dataset_by_intent[split] = intent2texts 120 | 121 | self.domain_to_intent = self.generate_domain_to_intent_map() 122 | self.gpt3_batch_size: int = 128 123 | 124 | def parse_and_load(self): 125 | return self.dataset_by_intent 126 | 127 | def generate_domain_to_intent_map(self): 128 | mapping = {} 129 | for intent in INTENTS: 130 | # the intents are structured as {domain}_{intent} 131 | # E.g. 
127 | def generate_domain_to_intent_map(self):
128 | mapping = {}
129 | for intent in INTENTS:
130 | # the intents are structured as {domain}_{intent}
131 | # e.g. for alarm_query, the domain is alarm and the intent is query
132 | domain = intent.split("_", 1)[0]
133 | if domain not in mapping:
134 | mapping[domain] = []
135 | mapping[domain].append(intent)
136 | return mapping
137 |
138 |
139 | def get_data_dict(path_base, split, name2id):
140 | tmp_dict = pd.read_csv(os.path.join(path_base, f"{split}.csv")).to_dict()
141 |
142 | data_dict = {}
143 | data_dict["text"] = list(tmp_dict["text"].values())
144 | data_dict["intent"] = [int(name2id[c]) for c in tmp_dict["category"].values()]
145 | return data_dict
146 |
-------------------------------------------------------------------------------- /utils/data_utils/main.py: --------------------------------------------------------------------------------
1 | import os, json, hashlib, pickle
2 | from utils.data_utils import clinc_utils, snips_utils, banking77_utils, hwu64_utils
3 |
4 | # some handy aliases
5 | pjoin = os.path.join
6 |
7 | write_json = lambda obj, path: json.dump(obj, open(path, "w"))
8 | read_json = lambda path: json.load(open(path, "r"))
9 |
10 | write_pickle = lambda obj, path: pickle.dump(obj, open(path, "wb"))
11 | read_pickle = lambda path: pickle.load(open(path, "rb"))
12 |
13 | write_file = lambda content, path: open(path, "w").write(content)
14 | read_file = lambda path: open(path, "r").read()
15 |
16 |
17 | def get_hash(x):
18 | return int(hashlib.sha1(x.encode("utf-8")).hexdigest(), 16)
19 |
20 |
21 | def get_label_maps(data_root, data_name):
22 | with open(pjoin(data_root, data_name, "name2id.json"), "r") as f:
23 | name2id = json.load(f)
24 | with open(pjoin(data_root, data_name, "id2name.json"), "r") as f:
25 | id2name = json.load(f)
26 | return name2id, id2name
27 |
28 |
29 | def get_ds_config(ds_name):
30 | # This code is used for "prepare_dataset.py"
31 | if ds_name == "clinc_oos":
32 | return clinc_utils.DS_CONFIG()
33 | elif ds_name == "snips_official":
34 | return snips_utils.DS_CONFIG()
35 | elif ds_name == "banking77":
36 | return banking77_utils.Banking77(ds_name)
37 | elif ds_name == "hwu64":
38 | return hwu64_utils.Hwu64(ds_name)
39 | else:
40 | raise NotImplementedError(f"Dataset {ds_name} not supported")
41 |
42 |
43 | def truncate_data(data, num_examples):
44 | """Used to simulate the few-shot setting for validation intents."""
45 | return sorted(data, key=get_hash)[:num_examples]
46 |
47 |
48 | def parse_and_load(dataset_name):
49 | # This code is used for "prepare_dataset.py"
50 | if dataset_name == "clinc_oos":
51 | return clinc_utils.parse_and_load_clinc()
52 | elif dataset_name == "snips_official":
53 | return snips_utils.parse_and_load_snips()
54 | elif dataset_name == "banking77":
55 | ds = banking77_utils.Banking77(dataset_name)
56 | return ds.parse_and_load()
57 | elif dataset_name == "hwu64":
58 | ds = hwu64_utils.Hwu64(dataset_name)
59 | return ds.parse_and_load()
60 | else:
61 | raise NotImplementedError(f"Dataset {dataset_name} not supported")
62 |
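# Illustrative use of the helpers above (paths hypothetical):
#   ds_config = get_ds_config("clinc_oos")
#   name2id, id2name = get_label_maps("./data", "clinc_oos")
#   few = truncate_data(lines, 10)  # deterministic, hash-sorted subsample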
9 | """ 10 | import os 11 | import pickle 12 | 13 | from utils.data_utils.main import get_label_maps, truncate_data, parse_and_load 14 | 15 | pjoin = os.path.join 16 | 17 | 18 | def add_fs_slice(data_splits, val_domain, data_root, ds_config): 19 | DOMAIN_TO_INTENT = ds_config.domain_to_intent 20 | few_shot_intents = set(DOMAIN_TO_INTENT[val_domain]) 21 | 22 | # load data grouped by intent, will also prepare the basic dataset.pkl 23 | data = parse_and_load(ds_config.data_name) 24 | name2id, id2name = get_label_maps(data_root, ds_config.data_name) 25 | 26 | # NOTE: `data` has oos_train, oos_val, and oos_test keys as well. 27 | # ignoring them here. since we don't perform a ex2-setup run for the OOS 28 | # domain, OOS samples are added when required in data_loader.py 29 | train_data = data["train"] 30 | val_data = data["val"] 31 | test_data = data["test"] 32 | 33 | # construct D_M and D_F (with respective, train, val, test sets) 34 | # MANY SHOT INTENTS 35 | # ----------------- 36 | m_train_lines, m_train_intents = [], [] 37 | for intent, lines in train_data.items(): 38 | if intent not in few_shot_intents: 39 | m_train_lines.extend(lines) 40 | m_train_intents.extend([name2id[intent]] * len(lines)) 41 | 42 | m_val_lines, m_val_intents = [], [] 43 | for intent, lines in val_data.items(): 44 | if intent not in few_shot_intents: 45 | m_val_lines.extend(lines) 46 | m_val_intents.extend([name2id[intent]] * len(lines)) 47 | 48 | m_test_lines, m_test_intents = [], [] 49 | for intent, lines in test_data.items(): 50 | if intent not in few_shot_intents: 51 | m_test_lines.extend(lines) 52 | m_test_intents.extend([name2id[intent]] * len(lines)) 53 | 54 | # FEW SHOT INTENTS 55 | # ---------------- 56 | f_train_lines, f_train_intents = [], [] 57 | for intent, lines in train_data.items(): 58 | if intent in few_shot_intents: 59 | f_train_lines.extend(truncate_data(lines, ds_config.num_examples)) 60 | f_train_intents.extend([name2id[intent]] * ds_config.num_examples) 61 | 62 | f_val_lines, f_val_intents = [], [] 63 | for intent, lines in val_data.items(): 64 | if intent in few_shot_intents: 65 | f_val_lines.extend(lines) 66 | f_val_intents.extend([name2id[intent]] * len(lines)) 67 | 68 | f_test_lines, f_test_intents = [], [] 69 | for intent, lines in test_data.items(): 70 | if intent in few_shot_intents: 71 | f_test_lines.extend(lines) 72 | f_test_intents.extend([name2id[intent]] * len(lines)) 73 | 74 | data_splits[val_domain] = {} 75 | 76 | # add the many-shot split 77 | data_splits[val_domain]["M"] = { 78 | "train": { 79 | "text": m_train_lines, 80 | "intent": m_train_intents, 81 | }, 82 | "val": { 83 | "text": m_val_lines, 84 | "intent": m_val_intents, 85 | }, 86 | "test": { 87 | "text": m_test_lines, 88 | "intent": m_test_intents, 89 | }, 90 | } 91 | # add the few-shot split 92 | data_splits[val_domain]["F"] = { 93 | "train": { 94 | "text": f_train_lines, 95 | "intent": f_train_intents, 96 | }, 97 | "val": { 98 | "text": f_val_lines, 99 | "intent": f_val_intents, 100 | }, 101 | "test": { 102 | "text": f_test_lines, 103 | "intent": f_test_intents, 104 | }, 105 | } 106 | 107 | 108 | def sample_few_shot(data_root, ds_config): 109 | save_path = pjoin(data_root, ds_config.data_name, "full", "data_full_suite.pkl") 110 | if os.path.exists(save_path): 111 | print(f"Sample few shot has already been called for {ds_config.data_name}") 112 | return 113 | data_splits = {} 114 | DOMAIN_TO_INTENT = ds_config.domain_to_intent.keys() 115 | 116 | for val_domain in DOMAIN_TO_INTENT: 117 | # in-place modifies data_splits 118 
--------------------------------------------------------------------------------
/utils/data_utils/snips_utils.py:
--------------------------------------------------------------------------------
# prepare the SNIPS dataset for intent classification
import os
import json
import pickle
import collections
from typing import Dict
from pydantic import BaseModel


class DS_CONFIG(BaseModel):
    data_name: str = "snips_official"
    full_path: str = "./data/snips_official/full/dataset.pkl"
    num_examples: int = 10
    # every SNIPS domain contains exactly one intent
    domain_to_intent: Dict = {
        "AddToPlaylist": ["AddToPlaylist"],
        "BookRestaurant": ["BookRestaurant"],
        "GetWeather": ["GetWeather"],
        "PlayMusic": ["PlayMusic"],
        "RateBook": ["RateBook"],
        "SearchCreativeWork": ["SearchCreativeWork"],
        "SearchScreeningEvent": ["SearchScreeningEvent"],
    }
    gpt3_batch_size: int = 128


SNIPS_DIR = './data/snips_official/'


def make_label_maps():
    with open(os.path.join(SNIPS_DIR, 'full', 'train', 'label')) as f:
        labels = sorted(set([x.strip() for x in f.readlines()]))
    intent_name2id = {label: idx for idx, label in enumerate(labels)}
    intent_id2name = {idx: label for idx, label in enumerate(labels)}
    with open(os.path.join(SNIPS_DIR, 'name2id.json'), 'w') as f:
        json.dump(intent_name2id, f)
    with open(os.path.join(SNIPS_DIR, 'id2name.json'), 'w') as f:
        json.dump(intent_id2name, f)


def download_snips_full():
    # NOTE: the official SNIPS repo uses valid, train, and test as split names
    for split in ['train', 'valid', 'test']:
        download_snips_files(split)


def download_snips_files(split):
    print(f'Downloading SNIPS {split} files')
    snips_base_url = f'https://raw.githubusercontent.com/MiuLab/SlotGated-SLU/master/data/snips/{split}/'
    download_dir = os.path.join(SNIPS_DIR, 'full', split)
    if not os.path.exists(download_dir):
        os.makedirs(download_dir)
    os.system(f'wget {snips_base_url + "seq.in"} -P {download_dir}')
    os.system(f'wget {snips_base_url + "label"} -P {download_dir}')


def parse_snips_split(dataset, split, name2id_map):
    with open(os.path.join(SNIPS_DIR, 'full', split, 'seq.in')) as f:
        sentences = [x.strip() for x in f.readlines()]
    with open(os.path.join(SNIPS_DIR, 'full', split, 'label')) as f:
        labels = [name2id_map[x.strip()] for x in f.readlines()]

    # rename 'valid' to 'val' to keep split names consistent across datasets
    dataset['val' if split == 'valid' else split] = {
        'text': sentences,
        'intent': labels
    }


def prepare_snips():
    download_snips_full()
    make_label_maps()

    dataset = {}
    name2id_map = json.load(open(os.path.join(SNIPS_DIR, 'name2id.json')))
    for split in ['train', 'valid', 'test']:
        parse_snips_split(dataset, split, name2id_map)

    with open(os.path.join(SNIPS_DIR, 'full', 'dataset.pkl'), 'wb') as f:
        pickle.dump(dataset, f)
    print('Base dataset.pkl prepared for SNIPS')


def check_snips():
    with open(os.path.join(SNIPS_DIR, 'full', 'dataset.pkl'), 'rb') as f:
        dataset = pickle.load(f)
    print('train size:', len(dataset['train']['text']))
    print('val size:', len(dataset['val']['text']))
    print('test size:', len(dataset['test']['text']))
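

# For reference (an assumption based on the official SlotGated-SLU split),
# check_snips should report roughly 13084 / 700 / 700 examples.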


def load_snips():
    print('Loading SNIPS dataset')
    id2name = json.load(open(os.path.join(SNIPS_DIR, 'id2name.json')))
    data_full_path = DS_CONFIG().full_path
    data = collections.defaultdict(lambda: collections.defaultdict(list))
    with open(data_full_path, 'rb') as data_file:
        for split_name, split_data in pickle.load(data_file).items():
            for query, intent in zip(split_data['text'], split_data['intent']):
                # JSON object keys are strings, hence the str() lookup
                intent = id2name[str(intent)]
                data[split_name][intent].append(query)
    return data


def parse_and_load_snips():
    if not os.path.exists(os.path.join(SNIPS_DIR, 'full', 'dataset.pkl')):
        prepare_snips()
        check_snips()
    return load_snips()
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
import os

import numpy as np
from datasets import load_metric
from sklearn.metrics import confusion_matrix

from utils import main_data_utils as mdu


class Metrics:
    def __init__(self, exp_dict=None, tokenizer=None):
        self.exp_dict = exp_dict
        self.tokenizer = tokenizer

    def compute_metrics(self):
        """
        Choose the appropriate metric-computation function based on the config.
        """
        if "bert" in self.exp_dict["model"]["backbone"]:
            return self.compute_metrics_bert
        raise ValueError(f"Incompatible backbone {self.exp_dict['model']['backbone']}.")

    def compute_metrics_bert(self, eval_pred):
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)
        metrics = {}
        # sort so that `accuracy` always runs first: it caches the in-scope
        # predictions and labels that `f1` reuses below
        for metric in sorted(self.exp_dict["metrics"]):
            _metric = getattr(self, metric)
            metrics.update(_metric(predictions, labels))
        return metrics
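
    # Illustrative config consumed above (entries must name methods on this
    # class), e.g.: exp_dict["metrics"] = ["accuracy", "f1", "recall"]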
self.exp_dict["dataset"]["oos_id"] 77 | if oos_id is not None and oos_id in labels: 78 | # compute OOS recall 79 | outscope_preds = [] 80 | for idx in range(len(labels)): 81 | if labels[idx] == oos_id: 82 | outscope_preds.append(1 if predictions[idx] == oos_id else -1) 83 | recalls["oos_recall"] = outscope_preds.count(1) / len(outscope_preds) 84 | return recalls 85 | 86 | def confusion_matrix(self, predictions, references): 87 | return { 88 | "confusion_matrix": confusion_matrix( 89 | y_true=references, 90 | y_pred=predictions, 91 | labels=list(range(self.exp_dict["dataset"]["num_labels"])), 92 | ) 93 | } 94 | 95 | def compute_fidelities(self, ds): 96 | """ 97 | Returns and saves fidelities for all engines for given a barebone exp_dict 98 | containing dataset name, num_labels, few_pure setting 99 | its number of classes 100 | Example schema of the returned result 101 | eda: Float 102 | ada_1.0: Float 103 | . 104 | . 105 | . 106 | gptj_2.0: Float 107 | NOTE: al_dataset.pkl must've been generated already for the dataset 108 | """ 109 | al_ds_path = mdu.pjoin("data", ds, "full", "al_dataset.pkl") 110 | if not os.path.exists(al_ds_path): 111 | print("Oracle relabelling hasn't been done on the generated samples yet") 112 | print("Go run runners.oracle_relabel first") 113 | return 114 | generated_ds = mdu.read_pickle(al_ds_path)["generated"] 115 | # path to save this dataset's fidelity results 116 | results_path = f"results/{ds}_fidelity.json" 117 | oos_id = mdu.read_json(f"data/{ds}/name2id.json").get("oos") 118 | if os.path.exists(results_path): 119 | print(f"Loading {results_path} that already exists!") 120 | print(f"Delete/Rename it to compute fidelity for {ds} again") 121 | return mdu.read_json(results_path) 122 | 123 | print(f"Computing fidelity numbers for {ds}") 124 | results = {} 125 | for engine, samples in generated_ds.items(): 126 | if engine == "val": 127 | engine = "threshold" 128 | _a = [ 129 | 1 if old == new else 0 130 | for old, new in zip(samples["old_intent"], samples["intent"]) 131 | ] 132 | results[engine] = np.mean(_a) 133 | 134 | print(f"Saving fidelity numbers for {ds}") 135 | mdu.write_json(results, results_path) 136 | return results 137 | --------------------------------------------------------------------------------