├── .gitignore
├── .havenignore
├── LICENSE
├── NOTICE
├── README.md
├── configs
│   └── exp_configs
│       ├── __init__.py
│       └── fs_exps.py
├── models
│   ├── __init__.py
│   └── backbones
│       └── __init__.py
├── prepare_dataset.py
├── requirements.txt
├── runners
│   ├── compile_results.py
│   ├── infer.py
│   ├── modules
│   │   └── bert.py
│   ├── oracle_relabel.py
│   └── train.py
├── scripts
│   ├── 3way_human_eval.py
│   ├── __init__.py
│   ├── compute_ttr.py
│   ├── gpt3_analysis.py
│   ├── make_spreadsheet.py
│   └── openai_sandbox.py
└── utils
    ├── __init__.py
    ├── data_utils
    │   ├── augment_slices.py
    │   ├── banking77_utils.py
    │   ├── clinc_utils.py
    │   ├── data_loader.py
    │   ├── eda_utils.py
    │   ├── hwu64_utils.py
    │   ├── main.py
    │   ├── sample_few_shot.py
    │   └── snips_utils.py
    └── metrics.py
/.gitignore:
--------------------------------------------------------------------------------
1 | spreadsheets/*
2 |
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 | results/
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 | .idea
123 |
124 | # mkdocs documentation
125 | /site
126 |
127 | # mypy
128 | .mypy_cache/
129 | .dmypy.json
130 | dmypy.json
131 |
132 | # Pyre type checker
133 | .pyre/
134 |
135 | .vscode
136 |
137 | data
138 | logs
139 | wandb
140 | results
141 | figures
142 |
143 | # data files
144 | *.pkl
145 | *.yml
146 | *.json
147 |
148 | wandb
149 |
--------------------------------------------------------------------------------
/.havenignore:
--------------------------------------------------------------------------------
1 | wandb/
2 | results/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2022 ServiceNow
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | Copyright 2022 ServiceNow, Inc.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | This is the official implementation of the following paper:
2 | [Gaurav Sahu](https://github.com/demfier), Pau Rodriguez, Issam Laradji, Parmida Atighehchian, David Vazquez, and Dzmitry Bahdanau. [Data Augmentation for Intent Classification with Off-the-shelf Large Language Models](https://aclanthology.org/2022.nlp4convai-1.5.pdf). *Proceedings of the 4th Workshop on NLP for Conversational AI, ACL 2022.*
3 |
4 | If you find this code useful, please cite:
5 | ```bibtex
6 | @inproceedings{sahu-etal-2022-data,
7 | title = "Data Augmentation for Intent Classification with Off-the-shelf Large Language Models",
8 | author = "Sahu, Gaurav and
9 | Rodriguez, Pau and
10 | Laradji, Issam and
11 | Atighehchian, Parmida and
12 | Vazquez, David and
13 | Bahdanau, Dzmitry",
14 | booktitle = "Proceedings of the 4th Workshop on NLP for Conversational AI",
15 | month = may,
16 | year = "2022",
17 | address = "Dublin, Ireland",
18 | publisher = "Association for Computational Linguistics",
19 | url = "https://aclanthology.org/2022.nlp4convai-1.5",
20 | pages = "47--57",
21 | }
22 | ```
23 |
24 | # Running experiments
25 |
26 | ### Datasets
27 |
28 | #### Preparing data
29 | This step is only required if you are starting from scratch, i.e., *NO* data has been prepared at all.
30 | Note that you are still required to create the symbolic link as suggested in the previous step.
31 | To get started with data preparation, run the following:
32 | ```
33 | python prepare_dataset.py --name <dataset_name> --data_root './data/'
34 | ```
35 |
36 | This will generate samples for all the supported modes (upsample, gpt3, gptj, eda).
37 | You can enable top-k and top-p sampling by passing appropriate values for the `--top_k` (an integer >= 0) and `--top_p` (a float in [0, 1]) flags.
38 | **NOTE**: If generating for GPTJ, make sure there's enough GPU memory (>= 32 GB recommended).
39 |
40 | This will also set up the data directory structure for `<dataset_name>`.
41 | It will prepare a `dataset.pkl` AND a `data_full_suite.pkl`.
42 | It will also generate the corresponding label maps (name2id, id2name).
43 | Make sure you have `wget` installed on your local machine.
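
For example, to prepare CLINC with only the GPT-3 and EDA augmentation modes and nucleus sampling enabled (the flag names come from `prepare_dataset.py`; the values here are only illustrative):

```
python prepare_dataset.py --name clinc_oos --data_root './data/' --modes gpt3 eda --top_k 0 --top_p 0.9
```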
44 |
45 | **Note:**
46 | - HWU64 was downloaded from "https://github.com/alexa/dialoglue"
47 | - Banking77 and CLINC150 were downloaded using the HuggingFace "datasets" library
48 | - SNIPS was downloaded from "https://github.com/MiuLab/SlotGated-SLU"
49 |
50 | Refer to the `fewshot_baseline_clinc` configuration in `configs/exp_configs/fs_exps.py` for full few-shot experiment config.
51 |
52 | #### Running experiments:
53 | To run baseline experiments following the original CLINC setting:
54 | 1. Edit the `baselines` variable inside `configs/exp_configs/fs_exps.py`. Here's an example for running `small` and `plus` baselines together:
55 |
56 | ```python
57 | baselines = hu.cartesian_exp_group(
58 | {
59 | # do multiple runs to account for stochasticity in metrics
60 | "run#": list(range(10)),
61 | "dataset": [
62 | {"name": "clinc_oos", "num_labels": 151, "config": c, "oos_id": 150}
63 | for c in ["plus", "small"]
64 | ], # config: small/plus/full/few_pure
65 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"},
66 | "exp_type": "baseline", # intrinsic/baseline
67 | "lr": 4e-5,
68 | "batch_size": 32,
69 | "epochs": 6,
70 | "warmup_ratio": 0.1,
71 | "weight_decay": 0.01,
72 | # metrics to compute
73 | "metrics": [["accuracy", "f1", "precision", "recall"]],
74 | "metric_best": "accuracy",
75 | "ngpu": 1,
76 | "eval_accumulation_steps": 30,
77 | }
78 | )
79 | ```
80 | **Note:** For CLINC (`name='clinc_oos'`), `oos_id=42` for `small/plus/imbalanced` and `oos_id=150` for `full/full_*`.
81 | No-OOS classifiers are also supported: set `oos_id=None` and `num_labels=150`.
82 | For SNIPS (`name='snips_official'`), `oos_id=None`, and only the `full*` settings are supported. Make sure `oos_id` is set correctly.
83 | Refer to the other config variables inside `configs/exp_configs/fs_exps.py` for the partial few-shot (EX2 setup) and full few-shot configs.
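
As a minimal sketch, a no-OOS CLINC dataset entry would look like:

```python
# hypothetical no-OOS CLINC entry for the "dataset" key of a config group
dataset = {"name": "clinc_oos", "num_labels": 150, "config": "plus", "oos_id": None}
```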
84 |
85 | 2. Run experiments:
86 | ```
87 | $(which python) -m runners.train --savedir_base /path/to/save/dir/ --exp_group_list baselines -j 1 -v results.ipynb --python_binary $(which python)
88 | ```
89 | Setting `-j 0` will run it locally. `--exp_group_list ex2_setup` will run the EX2 experiments (make sure that the dataset preparation is complete).
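
For a quick local run without a job scheduler, a command along these lines should work (the save directory is just an example):

```
python -m runners.train --savedir_base ./results --exp_group_list fewshot_baseline_clinc -j 0
```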
90 |
91 | #### For Oracle relabeling experiments:
92 | * Relabel generated examples using an oracle:
93 | ```
94 | $(which python) -m runners.oracle_relabel -md /path/to/oracle/ -e fewshot_baseline_clinc
95 | ```
96 | * Train classifiers on the relabeled data:
97 | ```
98 | $(which python) -m runners.train --savedir_base /path/to/save/dir/ --exp_group_list fewshot_oracle_clinc -j 1 -v results.ipynb --python_binary $(which python)
99 | ```
100 | * To compile results, correctly set the "knobs" in [runners.compile_results](runners/compile_results.py#L9-L30) and then run `python -m runners.compile_results` from the repository root.
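
For instance, to compile only the CLINC results for davinci-augmented runs at temperature 1.0 while skipping oracle runs, the knobs at the top of `runners/compile_results.py` would be set roughly as follows:

```python
EXPERIMENTS_DIR = "/path/to/full-fewshot/savedir/"  # where runners.train saved results

datasets = ["clinc_oos"]    # keep only CLINC experiments
augmentation = ["davinci"]  # keep only davinci-augmented runs
temperatures = [1.0]        # keep only generation temperature 1.0
skip_oracle = True          # drop oracle-relabeled runs
```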
101 |
102 |
103 | #### Adding a new dataset
104 | To add a new dataset, follow these steps:
105 | * Create a utils file for your dataset under `utils/data_utils/`. Let's call it `{dataset}_utils.py`.
106 | All the dataset-specific processing needs to be added there. In the end, your `{dataset}_utils.py` needs
107 | to have a `parse_and_load_{dataset}` function (see the sketch after this list). Refer to the documentation of `clinc_utils.parse_and_load_clinc()` for details.
108 | * Add your dataset to `parse_and_load()` and `get_ds_config()` in `utils/data_utils/main.py`.
109 | * [Running prepare_dataset.py](#preparing-data) for your dataset name now should create the required files for ex2 and non-ex2 setup.
110 | * Finally, refer to [this](#for-oracle-relabeling-experiments) to also generate the dataset for the oracle relabeling experiments in the full few-shot setup.
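
A minimal sketch of such a utils file, assuming a hypothetical `mydataset` stored as one JSON file of `(text, intent)` pairs per split (everything below is illustrative, not the repo's actual API; the return format should mirror whatever `clinc_utils.parse_and_load_clinc()` actually returns):

```python
# utils/data_utils/mydataset_utils.py (hypothetical)
import json
import os


def parse_and_load_mydataset(data_root):
    """Parse the raw `mydataset` files and return the splits plus label maps."""
    with open(os.path.join(data_root, "mydataset", "data.json")) as f:
        raw = json.load(f)
    # build the label maps (name2id, id2name) from the training intents
    labels = sorted({intent for _, intent in raw["train"]})
    name2id = {name: i for i, name in enumerate(labels)}
    id2name = {i: name for name, i in name2id.items()}
    # map every split to parallel lists of texts and integer label ids
    dataset = {
        split: {
            "text": [text for text, _ in raw[split]],
            "labels": [name2id[intent] for _, intent in raw[split]],
        }
        for split in ("train", "val", "test")
    }
    return dataset, name2id, id2name
```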
111 |
112 |
113 | #### List of configs in `configs/exp_configs/fs_exps.py` for different experiments
114 | 1. Reproducing CLINC150 results:
115 |
116 | ```python
117 | baselines = hu.cartesian_exp_group({
118 | # do multiple runs to account for stochasticity in metrics
119 | 'run#': list(range(10)),
120 | 'dataset': {'name': 'clinc_oos', 'num_labels': 151, 'oos_id': 150, 'config': 'full'}, # config: small/plus/full
121 | 'model': {
122 | 'name': 'intent_classification',
123 | 'backbone': 'bert-large-uncased'
124 | },
125 | 'exp_type': 'baseline', # intrinsic/baseline
126 | 'lr': 4e-5,
127 | 'batch_size': 32,
128 | 'epochs': 6,
129 | 'warmup_ratio': 0.1,
130 | 'weight_decay': 0.01,
131 | # metrics to compute
132 | 'metrics': [['accuracy', 'f1', 'precision', 'recall']],
133 | 'metric_best': 'accuracy',
134 | 'ngpu': 1,
135 | 'eval_accumulation_steps': 30
136 | })
137 | ```
138 |
139 | 2. Running Partial few-shot baseline/upsample experiments:
140 |
141 | ```python
142 | ex2_setup = hu.cartesian_exp_group({
143 | # do multiple runs to account for stochasticity in metrics
144 | 'run#': list(range(10)),
145 | 'dataset': [{
146 | 'name': 'clinc_oos', 'num_labels': 151, 'oos_id': 150,
147 | 'config': 'full_'+v} for v in DOMAINS], # config -> small/imbalanced/plus/small_aug/full
148 | 'model': {
149 | 'name': 'intent_classification',
150 | 'backbone': 'bert-large-uncased'
151 | },
152 | 'exp_type': ['baseline', 'upsample'], # gpt3/upsample/baseline
153 | 'lr': 5e-5,
154 | 'batch_size': 64,
155 | 'epochs': 10,
156 | 'warmup_ratio': 0.1,
157 | 'weight_decay': 0.01,
158 | # metrics to compute. if oos_id is not None,
159 | # compute inscope_accuracy and oos_recall as well
160 | 'metrics': [['accuracy', 'f1', 'precision', 'recall']],
161 | 'metric_best': 'f1',
162 | 'eval_accumulation_steps': 30
163 | })
164 | ```
165 |
166 | 3. Partial few-shot augmented (GPT3) experiments:
167 |
168 | ```python
169 | ex2_setup = hu.cartesian_exp_group({
170 | # do multiple runs to account for stochasticity in metrics
171 | 'run#': list(range(10)),
172 | 'dataset': [{
173 | 'name': 'clinc_oos', 'num_labels': 151, 'oos_id': 150,
174 | 'config': 'full_'+v} for v in DOMAINS], # config -> small/imbalanced/plus/small_aug/full
175 | 'model': {
176 | 'name': 'intent_classification',
177 | 'backbone': 'bert-large-uncased'
178 | },
179 | 'exp_type': ['gpt3'], # gpt3/upsample/baseline
180 | 'lr': 5e-5,
181 | 'batch_size': 64,
182 | 'epochs': 10,
183 | 'warmup_ratio': 0.1,
184 | 'weight_decay': 0.01,
185 | # metrics to compute. if oos_id is not None,
186 | # compute inscope_accuracy and oos_recall as well
187 | 'metrics': [['accuracy', 'f1', 'precision', 'recall']],
188 | 'metric_best': 'f1',
189 | # 'gpt3_engine': 'ada', # ada/babbage/curie/davinci
190 | 'gpt3_engine': ['ada', 'babbage', 'curie', 'davinci'], # ada/babbage/curie/davinci
191 | # 'gpt3_temp': 1.0, # 0.5/0.6/0.7/0.8/0.9/1.0/1.5/2.0
192 | 'gpt3_temp': [0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.5, 2.0], # 0.5-2.0
193 | 'eval_accumulation_steps': 30
194 | })
195 | ```
196 |
197 | 4. Training the oracle:
198 |
199 | ```python
200 | baselines = hu.cartesian_exp_group({
201 | # do multiple runs to account for stochasticity in metrics
202 | 'run#': list(range(1)),
203 | 'dataset': {'name': 'clinc_oos', 'num_labels': 151, 'oos_id': None, 'config': 'full'}, # config: small/plus/full
204 | 'model': {
205 | 'name': 'intent_classification',
206 | 'backbone': 'bert-large-uncased'
207 | },
208 | 'exp_type': 'intrinsic', # intrinsic/baseline
209 | 'lr': 4e-5,
210 | 'batch_size': 32,
211 | 'epochs': 6,
212 | 'warmup_ratio': 0.1,
213 | 'weight_decay': 0.01,
214 | # metrics to compute
215 | 'metrics': [['accuracy', 'f1', 'precision', 'recall']],
216 | 'metric_best': 'accuracy',
217 | 'ngpu': 1,
218 | 'eval_accumulation_steps': 30
219 | })
220 | ```
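
Each of these config variables only becomes runnable via `--exp_group_list` once it is registered in `EXP_GROUPS` at the bottom of `configs/exp_configs/fs_exps.py`, e.g.:

```python
EXP_GROUPS["baselines"] = baselines  # enables: --exp_group_list baselines
```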
221 |
--------------------------------------------------------------------------------
/configs/exp_configs/__init__.py:
--------------------------------------------------------------------------------
1 | from . import fs_exps
2 |
3 | EXP_GROUPS = {}
4 | EXP_GROUPS.update(fs_exps.EXP_GROUPS)
5 |
--------------------------------------------------------------------------------
/configs/exp_configs/fs_exps.py:
--------------------------------------------------------------------------------
1 | from haven import haven_utils as hu
2 |
3 | EXP_GROUPS = {}
4 |
5 | # CLINC
6 | DOMAINS = [
7 | "banking",
8 | "credit_card",
9 | "dining",
10 | "home",
11 | "auto",
12 | "travel",
13 | "utility",
14 | "work",
15 | "small_talk",
16 | "meta",
17 | ]
18 |
19 |
20 | baselines = hu.cartesian_exp_group(
21 | {
22 | # do multiple runs to account for stochasticity in metrics
23 | "run#": list(range(10)),
24 | "dataset": [
25 | {"name": "clinc_oos", "num_labels": 151, "config": c, "oos_id": 150}
26 | for c in ["few_pure"]
27 | ], # config: small/plus/full/few_pure
28 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"},
29 | "exp_type": "baseline", # intrinsic/baseline
30 | "lr": 4e-5,
31 | "batch_size": 32,
32 | "epochs": 6,
33 | "warmup_ratio": 0.1,
34 | "weight_decay": 0.01,
35 | # metrics to compute
36 | "metrics": [["accuracy", "f1", "precision", "recall"]],
37 | "metric_best": "accuracy",
38 | "ngpu": 1,
39 | # 'gpt3_engine': ['ada', 'babbage', 'curie', 'davinci'],
40 | # 'gpt3_temp': 1.0,
41 | "eval_accumulation_steps": 30,
42 | }
43 | )
44 |
45 | ex2_setup = hu.cartesian_exp_group(
46 | {
47 | # do multiple runs to account for stochasticity in metrics
48 | "run#": list(range(10)),
49 | "dataset": [
50 | {
51 | "name": "clinc_oos",
52 | "num_labels": 151,
53 | "oos_id": 150,
54 | "config": "full_" + v,
55 | }
56 | for v in DOMAINS
57 | ], # config -> small/imbalanced/plus/small_aug/full
58 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"},
59 | "exp_type": ["gpt3"], # gpt3/upsample/baseline
60 | "lr": 5e-5,
61 | "batch_size": 64,
62 | "epochs": 10,
63 | "warmup_ratio": 0.1,
64 | "weight_decay": 0.01,
65 | # metrics to compute. if oos_id is not None,
66 | # compute inscope_accuracy and oos_recall as well
67 | "metrics": [["accuracy", "f1", "precision", "recall"]],
68 | "metric_best": "accuracy",
69 | # 'gpt3_engine': 'ada', # ada/babbage/curie/davinci
70 | "gpt3_engine": [
71 | "ada",
72 | "babbage",
73 | "curie",
74 | "davinci",
75 | "gptj",
76 | ], # ada/babbage/curie/davinci
77 | # 'gpt3_temp': 1.0, # 0.5/0.6/0.7/0.8/0.9/1.0/1.5/2.0
78 | "gpt3_temp": 1.0, # 0.5-2.0
79 | "eval_accumulation_steps": 30,
80 | }
81 | )
82 |
83 | fewshot_oracle_clinc = hu.cartesian_exp_group(
84 | {
85 | "run#": list(range(10)), # for extrinsic evaluation
86 | "dataset": {
87 | "name": "clinc_oos",
88 | "num_labels": 151,
89 | "oos_id": 150,
90 | "config": "few_pure",
91 | },
92 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"},
93 | "lr": 5e-5,
94 | "batch_size": 64,
95 | "epochs": 6,
96 | "warmup_ratio": 0.1,
97 | "weight_decay": 0.01,
98 | "metrics": [["accuracy", "f1", "precision", "recall"]],
99 | "metric_best": "accuracy", # accuracy/f1
100 | "exp_type": "gpt3_oracle",
101 | "gpt3_engine": [
102 | "ada",
103 | "babbage",
104 | "curie",
105 | "davinci",
106 | "gptj",
107 | ], # ada/babbage/curie/davinci/gptj
108 | "gpt3_temp": 1.0, # [0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
109 | "eval_accumulation_steps": 30,
110 | }
111 | )
112 |
113 | fewshot_gpt3_clinc = hu.cartesian_exp_group(
114 | {
115 | "run#": list(range(10)), # for extrinsic evaluation
116 | "dataset": {
117 | "name": "clinc_oos",
118 | "num_labels": 151,
119 | "oos_id": 150,
120 | "config": "few_pure",
121 | },
122 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"},
123 | "lr": 5e-5,
124 | "batch_size": 64,
125 | "epochs": 6,
126 | "warmup_ratio": 0.1,
127 | "weight_decay": 0.01,
128 | "metrics": [["accuracy", "f1", "precision", "recall"]],
129 | "metric_best": "accuracy", # accuracy/f1
130 | "exp_type": "gpt3",
131 | "gpt3_engine": [
132 | "ada",
133 | "babbage",
134 | "curie",
135 | "davinci",
136 | "gptj",
137 | ], # ada/babbage/curie/davinci/gptj
138 | "gpt3_temp": 1.0, # [0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
139 | "eval_accumulation_steps": 30,
140 | }
141 | )
142 |
143 | fewshot_gpt3mix_clinc = hu.cartesian_exp_group(
144 | {
145 | "run#": list(range(10)), # for extrinsic evaluation
146 | "dataset": {
147 | "name": "clinc_oos",
148 | "num_labels": 151,
149 | "oos_id": 150,
150 | "config": "few_pure",
151 | },
152 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"},
153 | "lr": 5e-5,
154 | "batch_size": 64,
155 | "epochs": 6,
156 | "warmup_ratio": 0.1,
157 | "weight_decay": 0.01,
158 | "metrics": [["accuracy", "f1", "precision", "recall"]],
159 | "metric_best": "accuracy", # accuracy/f1
160 | "exp_type": ["gpt3mix", "gpt3mix_oracle"],
161 | "soft_label": True,
162 | "gpt3_engine": "curie", # only generated for curie
163 | "gpt3_temp": 1.0, # only generated for 1.0
164 | "eval_accumulation_steps": 30,
165 | }
166 | )
167 |
168 | fewshot_eda_clinc = hu.cartesian_exp_group(
169 | {
170 | "run#": list(range(10)), # for extrinsic evaluation
171 | "dataset": {
172 | "name": "clinc_oos",
173 | "num_labels": 151,
174 | "oos_id": 150,
175 | "config": "few_pure",
176 | },
177 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"},
178 | "lr": 5e-5,
179 | "batch_size": 64,
180 | "epochs": 6,
181 | "warmup_ratio": 0.1,
182 | "weight_decay": 0.01,
183 | "metrics": [["accuracy", "f1", "precision", "recall"]],
184 | "metric_best": "accuracy", # accuracy/f1
185 | "exp_type": ["eda", "eda_oracle"], # eda/eda_oracle
186 | "eval_accumulation_steps": 30,
187 | }
188 | )
189 |
190 |
191 | fewshot_baseline_clinc = hu.cartesian_exp_group(
192 | {
193 | "run#": list(range(10)), # for extrinsic evaluation
194 | "dataset": {
195 | "name": "clinc_oos",
196 | "num_labels": 151,
197 | "oos_id": 150,
198 | "config": "few_pure",
199 | },
200 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"},
201 | "lr": 3e-5, # 1e-5 for k=3, 5
202 | "batch_size": 8, # 2 for k=3, 5
203 | "epochs": 6, # 20 for k=3, 5
204 | "warmup_ratio": 0.1,
205 | "weight_decay": 0.001, # 0.0001 for k=3, 5
206 | "metrics": [["accuracy", "f1", "precision", "recall"]],
207 | "metric_best": "accuracy", # accuracy/f1
208 | "exp_type": "baseline",
209 | "eval_accumulation_steps": 30,
210 | }
211 | )
212 |
213 | banking77_baselines = hu.cartesian_exp_group(
214 | {
215 | # do multiple runs to account for stochasticity in metrics
216 | "run#": list(range(10)),
217 | "dataset": {
218 | "name": "banking77",
219 | "num_labels": 77,
220 | "config": "full",
221 | "oos_id": None,
222 | }, # config: small/plus/full/few_pure
223 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"},
224 | "exp_type": "baseline", # intrinsic/baseline
225 | "lr": 5e-5,
226 | "batch_size": 32,
227 | "epochs": 10,
228 | "warmup_ratio": 0.1,
229 | "weight_decay": 0.01,
230 | # metrics to compute
231 | "metrics": [["accuracy", "f1", "precision", "recall"]],
232 | "metric_best": "accuracy",
233 | "eval_accumulation_steps": 30,
234 | }
235 | )
236 |
237 | fewshot_baseline_banking77 = hu.cartesian_exp_group(
238 | {
239 | "run#": list(range(10)), # for extrinsic evaluation
240 | "dataset": {
241 | "name": "banking77",
242 | "num_labels": 77,
243 | "oos_id": None,
244 | "config": "few_pure",
245 | },
246 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"},
247 | "lr": 3e-5,
248 | "batch_size": 8,
249 | "epochs": 20,
250 | "warmup_ratio": 0.1,
251 | "weight_decay": 0.001,
252 | "metrics": [["accuracy", "f1", "precision", "recall"]],
253 | "metric_best": "accuracy", # accuracy/f1
254 | "exp_type": "baseline",
255 | "eval_accumulation_steps": 30,
256 | }
257 | )
258 |
259 | fewshot_oracle_banking77 = hu.cartesian_exp_group(
260 | {
261 | "run#": list(range(10)), # for extrinsic evaluation
262 | "dataset": {
263 | "name": "banking77",
264 | "num_labels": 77,
265 | "oos_id": None,
266 | "config": "few_pure",
267 | },
268 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"},
269 | "lr": 5e-5,
270 | "batch_size": 32,
271 | "epochs": 10,
272 | "warmup_ratio": 0.1,
273 | "weight_decay": 0.01,
274 | "metrics": [["accuracy", "f1", "precision", "recall"]],
275 | "metric_best": "accuracy", # accuracy/f1
276 | "exp_type": "gpt3_oracle",
277 | "gpt3_engine": [
278 | "ada",
279 | "babbage",
280 | "curie",
281 | "davinci",
282 | "gptj",
283 | ], # ada/babbage/curie/davinci/gptj
284 | "gpt3_temp": 1.0, # [0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
285 | "eval_accumulation_steps": 30,
286 | }
287 | )
288 |
289 | fewshot_gpt3_banking77 = hu.cartesian_exp_group(
290 | {
291 | "run#": list(range(10)), # for extrinsic evaluation
292 | "dataset": {
293 | "name": "banking77",
294 | "num_labels": 77,
295 | "oos_id": None,
296 | "config": "few_pure",
297 | },
298 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"},
299 | "lr": 5e-5,
300 | "batch_size": 32,
301 | "epochs": 10,
302 | "warmup_ratio": 0.1,
303 | "weight_decay": 0.01,
304 | "metrics": [["accuracy", "f1", "precision", "recall"]],
305 | "metric_best": "accuracy", # accuracy/f1
306 | "exp_type": "gpt3",
307 | "gpt3_engine": [
308 | "ada",
309 | "babbage",
310 | "curie",
311 | "davinci",
312 | "gptj",
313 | ], # ada/babbage/curie/davinci/gptj
314 | # "gpt3_engine": "curie", # ada/babbage/curie/davinci/gptj
315 | "gpt3_temp": 1.0, # [0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
316 | "eval_accumulation_steps": 30,
317 | }
318 | )
319 |
320 |
321 | fewshot_eda_banking77 = hu.cartesian_exp_group(
322 | {
323 | "run#": list(range(10)), # for extrinsic evaluation
324 | "dataset": {
325 | "name": "banking77",
326 | "num_labels": 77,
327 | "oos_id": None,
328 | "config": "few_pure",
329 | },
330 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"},
331 | "lr": 5e-5,
332 | "batch_size": 32,
333 | "epochs": 10,
334 | "warmup_ratio": 0.1,
335 | "weight_decay": 0.01,
336 | "metrics": [["accuracy", "f1", "precision", "recall"]],
337 | "metric_best": "accuracy", # accuracy/f1
338 | "exp_type": ["eda", "eda_oracle"],
339 | "eval_accumulation_steps": 30,
340 | }
341 | )
342 |
343 | hwu64_baselines = hu.cartesian_exp_group(
344 | {
345 | # do multiple runs to account for stochasticity in metrics
346 | "run#": list(range(10)),
347 | "dataset": {
348 | "name": "hwu64",
349 | "num_labels": 64,
350 | "config": "full",
351 | "oos_id": None,
352 | }, # config: small/plus/full/few_pure
353 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"},
354 | "exp_type": "intrinsic", # intrinsic/baseline
355 | "lr": 5e-5,
356 | "batch_size": 32,
357 | "epochs": 6,
358 | "warmup_ratio": 0.1,
359 | "weight_decay": 0.01,
360 | # metrics to compute
361 | "metrics": [["accuracy", "f1", "precision", "recall"]],
362 | "metric_best": "accuracy",
363 | "eval_accumulation_steps": 30,
364 | }
365 | )
366 |
367 | fewshot_baseline_hwu64 = hu.cartesian_exp_group(
368 | {
369 | "run#": list(range(10)), # for extrinsic evaluation
370 | "dataset": {
371 | "name": "hwu64",
372 | "num_labels": 64,
373 | "oos_id": None,
374 | "config": "few_pure",
375 | },
376 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"},
377 | "lr": 3e-5,
378 | "batch_size": 8,
379 | "epochs": 20,
380 | "warmup_ratio": 0.1,
381 | "weight_decay": 0.001,
382 | "metrics": [["accuracy", "f1", "precision", "recall"]],
383 | "metric_best": "accuracy", # accuracy/f1
384 | "exp_type": "baseline",
385 | "eval_accumulation_steps": 30,
386 | }
387 | )
388 |
389 | fewshot_oracle_hwu64 = hu.cartesian_exp_group(
390 | {
391 | "run#": list(range(10)), # for extrinsic evaluation
392 | "dataset": {
393 | "name": "hwu64",
394 | "num_labels": 64,
395 | "oos_id": None,
396 | "config": "few_pure",
397 | },
398 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"},
399 | "lr": 3e-5,
400 | "batch_size": 32,
401 | "epochs": 10,
402 | "warmup_ratio": 0.1,
403 | "weight_decay": 0.01,
404 | "metrics": [["accuracy", "f1", "precision", "recall"]],
405 | "metric_best": "accuracy", # accuracy/f1
406 | "exp_type": "gpt3_oracle",
407 | "gpt3_engine": [
408 | "ada",
409 | "babbage",
410 | "curie",
411 | "davinci",
412 | "gptj",
413 | ], # ada/babbage/curie/davinci/gptj
414 | # "gpt3_engine": "babbage", # ada/babbage/curie/davinci/gptj
415 | "gpt3_temp": 1.0,
416 | "eval_accumulation_steps": 30,
417 | }
418 | )
419 |
420 | fewshot_gpt3_hwu64 = hu.cartesian_exp_group(
421 | {
422 | "run#": list(range(10)), # for extrinsic evaluation
423 | "dataset": {
424 | "name": "hwu64",
425 | "num_labels": 64,
426 | "oos_id": None,
427 | "config": "few_pure",
428 | },
429 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"},
430 | "lr": 3e-5,
431 | "batch_size": 32,
432 | "epochs": 10,
433 | "warmup_ratio": 0.1,
434 | "weight_decay": 0.01,
435 | "metrics": [["accuracy", "f1", "precision", "recall"]],
436 | "metric_best": "accuracy", # accuracy/f1
437 | "exp_type": "gpt3",
438 | "gpt3_engine": [
439 | "ada",
440 | "babbage",
441 | "curie",
442 | "davinci",
443 | "gptj",
444 | ], # ada/babbage/curie/davinci/gptj
445 | # "gpt3_engine": "babbage", # ada/babbage/curie/davinci/gptj
446 | # "gpt3_temp": [round(a, 1) for a in np.linspace(0.5, 2, int((2.1 - 0.5) / 0.1))],
447 | "gpt3_temp": [1.0],
448 | "eval_accumulation_steps": 30,
449 | }
450 | )
451 |
452 | fewshot_eda_hwu64 = hu.cartesian_exp_group(
453 | {
454 | "run#": list(range(10)), # for extrinsic evaluation
455 | "dataset": {
456 | "name": "hwu64",
457 | "num_labels": 64,
458 | "oos_id": None,
459 | "config": "few_pure",
460 | },
461 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"},
462 | "lr": 3e-5,
463 | "batch_size": 32,
464 | "epochs": 10,
465 | "warmup_ratio": 0.1,
466 | "weight_decay": 0.01,
467 | "metrics": [["accuracy", "f1", "precision", "recall"]],
468 | "metric_best": "accuracy", # accuracy/f1
469 | "exp_type": ["eda", "eda_oracle"],
470 | "eval_accumulation_steps": 30,
471 | }
472 | )
473 |
474 | # SNIPS
475 | SNIPS_DOMAINS = [
476 | "AddToPlaylist",
477 | "BookRestaurant",
478 | "GetWeather",
479 | "PlayMusic",
480 | "RateBook",
481 | "SearchCreativeWork",
482 | "SearchScreeningEvent",
483 | ]
484 |
485 | # SNIPS_DOMAINS = ["RateBook"]
486 |
487 | snips_baselines = hu.cartesian_exp_group(
488 | {
489 | "dataset": {
490 | "name": "snips_official",
491 | "num_labels": 7,
492 | "oos_id": None,
493 | "config": "full",
494 | }, # config -> small/imbalanced/plus/small_aug/full/intrinsic
495 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"},
496 | "exp_type": "baseline", # upsample/baseline
497 | "lr": 4e-5,
498 | "batch_size": 32,
499 | "epochs": 6,
500 | "warmup_ratio": 0.1,
501 | "weight_decay": 0.01,
502 | "metrics": [["accuracy"]],
503 | "metric_best": "accuracy",
504 | "ngpu": 1,
505 | "gpt3_temp": 1.0,
506 | "eval_accumulation_steps": 30,
507 | }
508 | )
509 |
510 | snips_ex2_setup = hu.cartesian_exp_group(
511 | {
512 | # do multiple runs to account for stochasticity in metrics
513 | "run#": list(range(10)),
514 | "dataset": [
515 | {
516 | "name": "snips_official",
517 | "num_labels": 7,
518 | "oos_id": None,
519 | "config": "full_" + v,
520 | }
521 | for v in SNIPS_DOMAINS
522 | ], # config -> small/imbalanced/plus/small_aug/full
523 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"},
524 | "exp_type": ["gpt3"], # gpt3/upsample/baseline
525 | "lr": 5e-5,
526 | "batch_size": 64,
527 | "epochs": 10,
528 | "warmup_ratio": 0.1,
529 | "weight_decay": 0.01,
530 | # metrics to compute. if oos_id is not None,
531 | # compute inscope_accuracy and oos_recall as well
532 | "metrics": [["accuracy", "f1", "precision", "recall"]],
533 | "metric_best": "accuracy",
534 | # 'gpt3_engine': 'davinci', # ada/babbage/curie/davinci/gptj
535 | "gpt3_engine": [
536 | "ada",
537 | "babbage",
538 | "curie",
539 | "davinci",
540 | "gptj",
541 | ], # ada/babbage/curie/davinci/gptj
542 | "gpt3_temp": 1.0, # 0.5/0.6/0.7/0.8/0.9/1.0/1.5
543 | "eval_accumulation_steps": 30,
544 | }
545 | )
546 |
547 |
548 | fewshot_baseline_snips = hu.cartesian_exp_group(
549 | {
550 | "run#": list(range(10)), # for extrinsic evaluation
551 | "dataset": {
552 | "name": "snips_official",
553 | "num_labels": 7,
554 | "oos_id": None,
555 | "config": "few_pure",
556 | },
557 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"},
558 | "lr": 3e-5,
559 | "batch_size": 8,
560 | "epochs": 20,
561 | "warmup_ratio": 0.1,
562 | "weight_decay": 0.001,
563 | "metrics": [["accuracy", "f1", "precision", "recall"]],
564 | "metric_best": "accuracy", # accuracy/f1
565 | "exp_type": "baseline",
566 | "eval_accumulation_steps": 30,
567 | }
568 | )
569 |
570 | fewshot_oracle_snips = hu.cartesian_exp_group(
571 | {
572 | "run#": list(range(10)), # for extrinsic evaluation
573 | "dataset": {
574 | "name": "snips_official",
575 | "num_labels": 7,
576 | "oos_id": None,
577 | "config": "few_pure",
578 | },
579 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"},
580 | "lr": 5e-5,
581 | "batch_size": 64,
582 | "epochs": 10,
583 | "warmup_ratio": 0.1,
584 | "weight_decay": 0.01,
585 | "metrics": [["accuracy", "f1", "precision", "recall"]],
586 | "metric_best": "accuracy", # accuracy/f1
587 | "exp_type": "gpt3_oracle",
588 | "gpt3_engine": [
589 | "ada",
590 | "babbage",
591 | "curie",
592 | "davinci",
593 | "gptj",
594 | ], # ada/babbage/curie/davinci/gptj
595 | "gpt3_temp": 1.0, # [0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
596 | "eval_accumulation_steps": 30,
597 | }
598 | )
599 |
600 | fewshot_gpt3_snips = hu.cartesian_exp_group(
601 | {
602 | "run#": list(range(10)), # for extrinsic evaluation
603 | "dataset": {
604 | "name": "snips_official",
605 | "num_labels": 7,
606 | "oos_id": None,
607 | "config": "few_pure",
608 | },
609 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"},
610 | "lr": 5e-5,
611 | "batch_size": 64,
612 | "epochs": 10,
613 | "warmup_ratio": 0.1,
614 | "weight_decay": 0.01,
615 | "metrics": [["accuracy", "f1", "precision", "recall"]],
616 | "metric_best": "accuracy", # accuracy/f1
617 | "exp_type": "gpt3",
618 | "gpt3_engine": [
619 | "ada",
620 | "babbage",
621 | "curie",
622 | "davinci",
623 | "gptj",
624 | ], # ada/babbage/curie/davinci/gptj
625 | # "gpt3_engine": "babbage", # ada/babbage/curie/davinci/gptj
626 | "gpt3_temp": 1.0, # [0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
627 | "eval_accumulation_steps": 30,
628 | }
629 | )
630 |
631 | fewshot_eda_snips = hu.cartesian_exp_group(
632 | {
633 | "run#": list(range(10)), # for extrinsic evaluation
634 | "dataset": {
635 | "name": "snips_official",
636 | "num_labels": 7,
637 | "oos_id": None,
638 | "config": "few_pure",
639 | },
640 | "model": {"name": "intent_classification", "backbone": "bert-large-uncased"},
641 | "lr": 5e-5,
642 | "batch_size": 64,
643 | "epochs": 10,
644 | "warmup_ratio": 0.1,
645 | "weight_decay": 0.01,
646 | "metrics": [["accuracy", "f1", "precision", "recall"]],
647 | "metric_best": "accuracy", # accuracy/f1
648 | "exp_type": ["eda", "eda_oracle"],
649 | "eval_accumulation_steps": 30,
650 | }
651 | )
652 |
653 | # CLINC
654 | EXP_GROUPS["baselines"] = baselines
655 | EXP_GROUPS["ex2_setup"] = ex2_setup
656 | EXP_GROUPS["fewshot_baseline_clinc"] = fewshot_baseline_clinc
657 | EXP_GROUPS["fewshot_oracle_clinc"] = fewshot_oracle_clinc
658 | EXP_GROUPS["fewshot_gpt3_clinc"] = fewshot_gpt3_clinc
659 | EXP_GROUPS["fewshot_gpt3mix_clinc"] = fewshot_gpt3mix_clinc
660 | EXP_GROUPS["fewshot_eda_clinc"] = fewshot_eda_clinc
661 |
662 | # Banking77
663 | EXP_GROUPS["banking77_baselines"] = banking77_baselines
664 | EXP_GROUPS["fewshot_baseline_banking77"] = fewshot_baseline_banking77
665 | EXP_GROUPS["fewshot_gpt3_banking77"] = fewshot_gpt3_banking77
666 | EXP_GROUPS["fewshot_oracle_banking77"] = fewshot_oracle_banking77
667 | EXP_GROUPS["fewshot_eda_banking77"] = fewshot_eda_banking77
668 |
669 | # HWU64
670 | EXP_GROUPS["hwu64_baselines"] = hwu64_baselines
671 | EXP_GROUPS["fewshot_baseline_hwu64"] = fewshot_baseline_hwu64
672 | EXP_GROUPS["fewshot_gpt3_hwu64"] = fewshot_gpt3_hwu64
673 | EXP_GROUPS["fewshot_oracle_hwu64"] = fewshot_oracle_hwu64
674 | EXP_GROUPS["fewshot_eda_hwu64"] = fewshot_eda_hwu64
675 |
676 | # SNIPS
677 | EXP_GROUPS["snips_baselines"] = snips_baselines
678 | EXP_GROUPS["snips_ex2_setup"] = snips_ex2_setup
679 | EXP_GROUPS["fewshot_baseline_snips"] = fewshot_baseline_snips
680 | EXP_GROUPS["fewshot_gpt3_snips"] = fewshot_gpt3_snips
681 | EXP_GROUPS["fewshot_oracle_snips"] = fewshot_oracle_snips
682 | EXP_GROUPS["fewshot_eda_snips"] = fewshot_eda_snips
683 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ServiceNow/data-augmentation-with-llms/77c740df2aaf4bfc1b2449114613f98b37a1faae/models/__init__.py
--------------------------------------------------------------------------------
/models/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | from transformers import (
2 | AutoModelForSequenceClassification,
3 | AutoModelForSeq2SeqLM,
4 | AutoConfig,
5 | BertModel,
6 | )
7 |
8 | import torch.nn as nn
9 | from transformers.modeling_outputs import SequenceClassifierOutput
10 |
11 |
12 | class BertModelWithCustomLossFunction(nn.Module):
13 | def __init__(self, exp_dict):
14 | super(BertModelWithCustomLossFunction, self).__init__()
15 | self.num_labels = exp_dict["dataset"]["num_labels"]
16 | self.bert = BertModel.from_pretrained(
17 | exp_dict["model"]["backbone"], num_labels=self.num_labels
18 | )
19 | self.dropout = nn.Dropout(0.1)
20 |         self.classifier = nn.Linear(1024, self.num_labels)  # 1024 = hidden size of bert-large-uncased
21 |
22 | def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
23 | outputs = self.bert(
24 | input_ids=input_ids,
25 | attention_mask=attention_mask,
26 | token_type_ids=token_type_ids,
27 | )
28 |
29 | output = self.dropout(outputs.pooler_output)
30 | logits = self.classifier(output)
31 |
32 | loss = None
33 | if labels is not None:
34 | # you can define any loss function here yourself
35 | # see https://pytorch.org/docs/stable/nn.html#loss-functions for an overview
36 | loss_fct = nn.CrossEntropyLoss()
37 | # next, compute the loss based on logits + ground-truth labels
38 | loss = loss_fct(logits.view(-1, self.num_labels), labels)
39 |
40 | return SequenceClassifierOutput(
41 | loss=loss,
42 | logits=logits,
43 | hidden_states=outputs.hidden_states,
44 | attentions=outputs.attentions,
45 | )
46 |
47 |
48 | def get_backbone(exp_dict):
49 | if exp_dict["exp_type"] == "gpt3mix":
50 | backbone = BertModelWithCustomLossFunction(exp_dict)
51 | return backbone
52 |
53 | if exp_dict["model"]["backbone"] in [
54 | "distilbert-base-uncased",
55 | "bert-large-uncased",
56 | "bert-base-uncased",
57 | ]:
58 | backbone = AutoModelForSequenceClassification.from_pretrained(
59 | exp_dict["model"]["backbone"], num_labels=exp_dict["dataset"]["num_labels"]
60 | )
61 | return backbone
62 | raise ValueError(f"backbone: {exp_dict['model']['backbone']} not supported")
63 |
--------------------------------------------------------------------------------
/prepare_dataset.py:
--------------------------------------------------------------------------------
1 | """Dataset preparation step. It's mostly a one-time script."""
2 | import os
3 |
4 | os.environ["LOCAL_RANK"] = "-1"
5 | os.environ["TOKENIZERS_PARALLELISM"] = "true"
6 | os.environ["TRANSFORMERS_CACHE"] = "/mnt/home/"
7 | import argparse
8 | from utils import main_data_utils, sample_few_shot, augment_slices
9 |
10 |
11 | def parse_args():
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument("--data_root", default="./data")
14 | parser.add_argument("--name", default="clinc_oos")
15 | parser.add_argument(
16 | "--modes", nargs="+", default=["upsample", "gptj", "gpt3", "eda"]
17 | )
18 | parser.add_argument("--top_k", default=0, type=int)
19 | parser.add_argument("--top_p", default=1.0, type=float)
20 | return parser.parse_args()
21 |
22 |
23 | if __name__ == "__main__":
24 | args = parse_args()
25 |     if args.top_k < 0 or not 0.0 <= args.top_p <= 1.0:
26 |         print("Invalid flags: need args.top_k >= 0 and 0.0 <= args.top_p <= 1.0")
27 |         import sys
28 | 
29 |         sys.exit(1)
30 | ds_config = main_data_utils.get_ds_config(args.name)
31 | print(f"Loaded dataset config for {args.name}")
32 | sample_few_shot(args.data_root, ds_config)
33 | print(f"augmenting for modes: {args.modes}")
34 | data_slices = augment_slices(
35 | args.data_root,
36 | ds_config,
37 | modes=args.modes,
38 | top_k=args.top_k,
39 | top_p=args.top_p,
40 | )
41 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | datasets==1.18.4
2 | GPUtil==1.4.0
3 | haven-ai @ git+https://github.com/haven-ai/haven-ai
4 | huggingface-hub==0.0.19
5 | matplotlib==3.4.3
6 | nltk
7 | numpy
8 | openai==0.15.0
9 | pandas==1.3.3
10 | scikit-image==0.18.3
11 | scikit-learn==1.0
12 | seaborn==0.11.2
13 | sentencepiece==0.1.96
14 | torch==1.9.1
15 | transformers==4.18.0
16 | wandb==0.12.4
17 | pydantic==1.8.2
--------------------------------------------------------------------------------
/runners/compile_results.py:
--------------------------------------------------------------------------------
1 | import os, json, copy, pickle
2 | import numpy as np
3 | import pandas as pd
4 | from collections import defaultdict
5 | from scipy import stats
6 |
7 | from utils.metrics import Metrics
8 |
9 | # this dir contains CLINC and SNIPS partial fewshot experiments
10 | # EXPERIMENTS_DIR = "/path/to/partial-fewshot/savedir/"
11 |
12 | # this dir contains CLINC, SNIPS, Banking77, and HWU64 full fewshot experiments
13 | EXPERIMENTS_DIR = "/path/to/full-fewshot/savedir/"
14 |
15 | # FILTERS (NOTE: any of these filters can be disabled by setting to None)
16 | datasets = None # clinc_oos, snips_official, banking77, hwu64
17 | augmentation = None # eda (in FS), ada, babbage, curie, davinci, gptj
18 | temperatures = None # 0.5-2.0
19 | skip_oracle = False # True/False
20 |
21 | # exp_type-wise significance tests are always computed w.r.t davinci models
22 | TTEST = False # True/False
23 |
24 | # print results as a latex table
25 | TO_LATEX = False
26 |
27 | # save metrics for plotting purposes
28 | GEN_FOR_PLOT = False
29 | # name of the results file (NOTE: only GPTJ for full fewshot is supported)
30 | FNAME = "results/gptj_val_fidelity_et_al.pkl"
31 |
32 | # The below vars are used by to_latex()
33 | # ALL_METRICS and ALL_DATASETS are in the order of reporting in the paper.
34 | # Defining the order here means we don't need to apply any convoluted
35 | # ops like sorting to ensure consistent performance reporting.
36 | # NOTE: to_latex() auto-ignores a dataset if it's filtered out in this script
37 | ALL_DATASETS = ["clinc_oos", "hwu64", "banking77", "snips_official"]
38 | BACKBONE = "BERT" if "bert" in EXPERIMENTS_DIR else "T5"
39 | SCOPES = ["overall", "few_shot"] if "ex2" in EXPERIMENTS_DIR else ["test"]
40 | ALL_METRICS = ["IA", "OR"]
41 |
42 | # No need to touch anything else below this line.
43 | FILTER = {
44 | "datasets": datasets, # set None to fetch for all datasets
45 | "augmentation": augmentation, # set None to fetch for all augmentors
46 | "temps": temperatures, # set None to fetch for all temperatures
47 | "skip_oracle?": skip_oracle, # set False/None to fetch oracle results
48 | }
49 |
50 | pjoin = os.path.join
51 |
52 |
53 | def fmt(xs):
54 | return f"{np.mean(xs):.2f} ({np.std(xs):.2f})"
55 |
56 |
57 | def compile_results(folder_list, exp_dir):
58 | # fetch experiment type, and populate total_results accordingly
59 | exp_dict_path = pjoin(exp_dir, folder_list[0], "exp_dict.json")
60 | exp_dict = json.load(open(exp_dict_path, "r"))
61 | dataset = exp_dict["dataset"]["name"]
62 | dconfig = exp_dict["dataset"]["config"]
63 | oos_id = exp_dict["dataset"]["oos_id"]
64 | exp_type = exp_dict["exp_type"]
65 |     ex2_setup = dconfig.startswith("full_")
66 |     full_fewshot = dconfig == "few_pure"
67 |
68 | if dconfig != "few_pure" and not ex2_setup and exp_type == "baseline":
69 | org_baseline = True
70 | else:
71 | org_baseline = False
72 |
73 | # declare total_results template based on exp config
74 | if full_fewshot or org_baseline: # no need for overall, fewshot keys
75 | total_results = {
76 | "test": {"IA": [], "OR": [], "A": []},
77 | "val": {"IA": [], "OR": [], "A": []},
78 | }
79 | if oos_id is None:
80 | total_results = {
81 | "test": {"A": []},
82 | "val": {"A": []},
83 | }
84 | else: # ex2 setup
85 | total_results = {
86 | "few_shot": {"IA": []},
87 | "few_shot_val": {"IA": []},
88 | "overall": {"IA": [], "OR": []},
89 | "overall_val": {"IA": [], "OR": []},
90 | }
91 |
92 | # read all the json files
93 | for folder in folder_list:
94 | sub_results = json.load(open(pjoin(exp_dir, folder, "code/results.json")))
95 | key = list(sub_results.keys())[0]
96 | sub_results = sub_results[key]
97 |
98 | if dataset == "snips_official" and ex2_setup: # snips has no OR
99 | _overall = sub_results["overall"]
100 | total_results["overall"]["IA"].append(_overall["test_accuracy"])
101 | total_results["overall_val"]["IA"].append(_overall["valid_accuracy"])
102 |
103 | elif full_fewshot or org_baseline:
104 | total_results["test"]["A"].append(sub_results["test_accuracy"])
105 | total_results["val"]["A"].append(sub_results["valid_accuracy"])
106 | if not oos_id:
107 | continue
108 | total_results["test"]["IA"].append(sub_results["test_inscope_accuracy"])
109 | total_results["val"]["IA"].append(sub_results["valid_inscope_accuracy"])
110 | total_results["test"]["OR"].append(sub_results["test_oos_recall"])
111 | total_results["val"]["OR"].append(sub_results["valid_oos_recall"])
112 | continue
113 |
114 | elif ex2_setup and dataset == "clinc_oos":
115 | # not handling oos_id as ex2_setup ALWAYS has an oos_id
116 | overall = sub_results["overall"]
117 | total_results["overall_val"]["IA"].append(overall["valid_inscope_accuracy"])
118 | total_results["overall"]["IA"].append(overall["test_inscope_accuracy"])
119 | total_results["overall_val"]["OR"].append(overall["valid_oos_recall"])
120 | total_results["overall"]["OR"].append(overall["test_oos_recall"])
121 |
122 | fs = sub_results["few_shot"]
123 | total_results["few_shot"]["IA"].append(fs["test_accuracy"])
124 | total_results["few_shot_val"]["IA"].append(fs["valid_accuracy"])
125 |
126 | for k in total_results:
127 | for m in total_results[k]:
128 | results = [100 * v for v in total_results[k][m]]
129 |         # the following averages across the different fewshot domains.
130 |         # For EX2, it's the nth run's avg. across full_banking, full_meta, etc.
131 |         # For full fewshot, it's computing the mean for one domain, i.e.,
132 | # the value is going to remain the same except it won't be a list.
133 | total_results[k][m] = np.mean(results)
134 | return total_results
135 |
136 |
137 | def segregate_sub_folders(exp_dir):
138 | sub_folder_dict = {}
139 |
140 | for folder in os.listdir(exp_dir):
141 | exp_dict_path = pjoin(exp_dir, folder, "exp_dict.json")
142 | exp_dict = json.load(open(exp_dict_path))
143 | dname = exp_dict["dataset"]["name"] # aggregate on the dataset
144 |
145 | # dataset filter
146 | if FILTER["datasets"] and dname not in FILTER["datasets"]:
147 | continue
148 | exp_dict["gpt3_temp"] = 1.0 # TODO: remove
149 | if "gpt" in exp_dict["exp_type"]:
150 | engine, temp = exp_dict["gpt3_engine"], exp_dict["gpt3_temp"]
151 | # engine and temperature filter
152 | if FILTER["augmentation"] and engine not in FILTER["augmentation"]:
153 | continue
154 | if FILTER["temps"] and temp not in FILTER["temps"]:
155 | continue
156 | exp_type = f"{exp_dict['exp_type']}_{engine}_{temp}"
157 | else:
158 | # NOTE that for non-GPT experiments exp_type is the augmentation mode
159 | exp_type = exp_dict["exp_type"] # eda/eda_oracle
160 | _aug = exp_type.replace("_oracle", "") if "oracle" in exp_type else exp_type
161 | if FILTER["augmentation"] and _aug not in FILTER["augmentation"]:
162 | continue
163 |
164 | # oracle filter
165 | if "oracle" in exp_type and FILTER["skip_oracle?"]:
166 | continue
167 |
168 | if exp_type not in sub_folder_dict:
169 | sub_folder_dict[exp_type] = {}
170 | if dname not in sub_folder_dict[exp_type]:
171 | sub_folder_dict[exp_type][dname] = defaultdict(list)
172 | sub_folder_dict[exp_type][dname][exp_dict["run#"]].append(folder)
173 |
174 | # a sanity check line, prints the number of experiments per config.
175 | folders_per_exp = []
176 | for exp_type in sub_folder_dict:
177 | for dname in sub_folder_dict[exp_type]:
178 | num_runs = len(sub_folder_dict[exp_type][dname])
179 | folders_per_exp.append((exp_type, dname, num_runs))
180 |
181 | print(folders_per_exp, len(folders_per_exp))
182 | return sub_folder_dict
183 |
184 |
185 | def final_compile(sub_results_dicts):
186 | final_result = {}
187 | for s in sub_results_dicts:
188 | for scope in s.keys():
189 | if scope not in final_result:
190 | final_result[scope] = {}
191 | for metric in s[scope]:
192 | if np.isnan(s[scope][metric]):
193 | continue
194 | if metric not in final_result[scope]:
195 | final_result[scope][metric] = []
196 | final_result[scope][metric].append(s[scope][metric])
197 | return final_result
198 |
199 |
200 | def get_performance(exp_dir):
201 | """
202 | returns a results dictionary which is not aggregated by runs
203 |
204 | An example of hierarchy:
205 | gpt3_ada_1.0:
206 | clinc_oos:
207 | test:
208 | IA[90.32, 90.1...90.3]
209 | OR[40.23, 39.12...38.1]
210 | val:
211 | IA[92.32, 91.1...93.3]
212 | OR[45.23, 41.12...40.1]
213 | banking77:
214 | test:
215 | A[82.3...80.1]
216 | val:
217 | A[83.2...79.2]
218 | snips_official:
219 | .
220 | .
221 | .
222 | gpt3_babbage_1.0:
223 | .
224 | .
225 | .
226 |
227 |     It's not aggregated so that other functions may use it for:
228 |     - mean and std computation across multiple runs
229 |     - significance testing
230 | """
231 | sub_folder_dict = segregate_sub_folders(exp_dir)
232 | performance = {}
233 | for exp_type in sorted(list(sub_folder_dict.keys())):
234 | for dname in sorted(list(sub_folder_dict[exp_type].keys())):
235 | config_results = []
236 | for config in sub_folder_dict[exp_type][dname]:
237 | folderlist = sub_folder_dict[exp_type][dname][config]
238 | config_results.append(compile_results(folderlist, exp_dir))
239 | if exp_type not in performance:
240 | performance[exp_type] = {}
241 | performance[exp_type][dname] = final_compile(config_results)
242 | return performance
243 |
244 |
245 | def to_latex(performance):
246 | """
247 |     Generates LaTeX table code for the aug and aug+relabel settings
248 | """
249 | table_latex = ""
250 | # backbone mode (aug/aug.+relabel)
251 | template = "{} {} (Ours) &" # line template
252 |
253 | # num of columns to report will be same for all exp. settings
254 | _etype = list(performance.keys())[0]
255 | n_cols = 0
256 | for dname in performance[_etype]:
257 | for s in SCOPES:
258 | curr_metrics = performance[_etype][dname][s]
259 | for _m in curr_metrics:
260 | if _m == "A" and "IA" in curr_metrics:
261 | continue
262 | n_cols += 1
263 |
264 | template += " {} &" * (n_cols - 1)
265 | template += " {} \\\\\n"
266 |
267 | for etype in performance:
268 | dscores = []
269 | for dname in ALL_DATASETS:
270 | # print(dname)
271 | if dname not in performance[etype]:
272 | continue
273 | # print(dname)
274 | for s in SCOPES:
275 | # print(s)
276 | curr_metrics = list(performance[etype][dname][s])
277 | for _m in ALL_METRICS:
278 | if _m not in curr_metrics:
279 | # a dataset without IA means no OOS. in that case,
280 | # A is the same as IA.
281 | if _m == "IA":
282 | _m = "A"
283 | else:
284 | continue
285 | dscores.append(fmt(performance[etype][dname][s][_m]))
286 | # print(_m)
287 | # print("===")
288 |
289 | table_latex += template.format(
290 | BACKBONE, etype.replace("_1.0", "").replace("_", "\_"), *dscores
291 | )
292 | print(table_latex)
293 |
294 |
295 | def perform_ttest(performance):
296 | """
297 |     receives per-dataset performance and runs two-sample t-tests of every
298 |     model against the davinci run at temp 1.0 for that experiment type
299 | """
300 | if performance == {}:
301 | print("Nothing to show here")
302 | return
303 |
304 |     bigger_model = next((e for e in performance if "davinci" in e), None)
305 |     if bigger_model is None:
306 |         print("No davinci results found to test against")
307 |         return
308 |     bigger_results = performance[bigger_model]
309 |
310 | # Gather model-wise metrics
311 | for dname in bigger_results.keys():
312 | print(f"Dataset: {dname.upper()}")
313 | print("-" * 30)
314 | for s in SCOPES:
315 | for m in bigger_results[dname][s]:
316 | _bresults = bigger_results[dname][s][m]
317 | print(f"--- {s} {m} test ---")
318 | print(f"{bigger_model.upper()}: ({fmt(_bresults)})")
319 | for model, results in performance.items():
320 | if model == bigger_model:
321 | continue
322 | _sresults = results[dname][s][m]
323 | test_result = stats.ttest_ind(_bresults, _sresults)
324 | print(f" vs {model.upper()} ({fmt(_sresults)}) {test_result}")
325 | print()
326 |
327 |
328 | def display_results(performance):
329 | for etype in performance:
330 | for dname in performance[etype]:
331 | for scope in performance[etype][dname]:
332 | for metric in performance[etype][dname][scope]:
333 | results = performance[etype][dname][scope][metric]
334 | performance[etype][dname][scope][metric] = fmt(results)
335 |
336 | for etype in performance:
337 | for dname in performance[etype]:
338 | print(f"Setting: {etype} | {dname}")
339 | print("-" * 20)
340 | print(pd.DataFrame().from_dict(performance[etype][dname]))
341 | print("=" * 30)
342 | print("\n")
343 |
344 |
345 | def gen_for_plot(performance):
346 | """
347 |     Saves fidelity and few-shot accuracies for all the datasets in a file
348 | NOTE: this doesn't save metrics for partial fewshot temp. profiling nor
349 | does it support any engine other than GPTJ right now.
350 | """
351 | if FILTER["augmentation"] != ["gptj"]:
352 | raise NotImplementedError(
353 | "Metrics generation for plotting only supported for GPTJ!"
354 | )
355 | if "fs" not in EXPERIMENTS_DIR:
356 | raise NotImplementedError(
357 | "Metrics generation for plotting only supported for Full Fewshot!"
358 | )
359 |
360 | if os.path.exists(FNAME):
361 | print(f"{FNAME} already exists!! Loading...")
362 | print("Delete/Rename it to recompute fidelities.")
363 | return pickle.load(open(FNAME, "rb"))
364 | print("Compiling plotting metrics in a file...")
365 |
366 | # init df
367 | df = pd.DataFrame(columns=["temp", "ds", "val_acc_mean", "val_acc_std", "fidelity"])
368 |
369 | # compute fidelities
370 | fidelities = {ds: Metrics().compute_fidelities(ds) for ds in ALL_DATASETS}
371 | for etype, results in performance.items():
372 | # etype: gpt3_gptj_1.0
373 | _, temp = etype.rsplit("_", 1)
374 | for ds in results:
375 | acc_key = "IA" if ds == "clinc_oos" else "A"
376 | val_accs = results[ds]["val"][acc_key]
377 | val_acc_mean, val_acc_std = np.mean(val_accs), np.std(val_accs)
378 | _fid = fidelities[ds][f"gptj_{temp}"]
379 | # create a new entry in the dataframe
380 | df.loc[len(df.index)] = [float(temp), ds, val_acc_mean, val_acc_std, _fid]
381 |
382 | # will be used to plot the threshold lines in the fidelity plots
383 | thresholds = {ds: fidelities[ds]["threshold"] for ds in ALL_DATASETS}
384 | metrics = {"metrics": df, "thresholds": thresholds}
385 | print(f"Saving fidelity metrics for plotting {FNAME}")
386 | with open(FNAME, "wb") as f:
387 | pickle.dump(metrics, f)
388 | return metrics
389 |
390 |
391 | def main():
392 | """
393 | Computes mean and std of metrics obtained by get_performance
394 | """
395 | # remove the "deleted" folder if it exists
396 | if os.path.exists(pjoin(EXPERIMENTS_DIR, "deleted")):
397 | print(f"Removing {pjoin(EXPERIMENTS_DIR, 'deleted')}...")
398 | os.system(f"rm -rf {pjoin(EXPERIMENTS_DIR, 'deleted')}")
399 | print("Removed.")
400 |
401 | performance = get_performance(EXPERIMENTS_DIR)
402 |
403 |     # display results (deepcopy needed as display_results mutates its input)
404 | display_results(copy.deepcopy(performance))
405 |
406 | if TTEST:
407 | # non-oracle etype
408 | print("T-Test no oracle")
409 | perform_ttest({k: v for k, v in performance.items() if "oracle" not in k})
410 |
411 | # oracle etype
412 | print("T-Test with oracle")
413 | perform_ttest({k: v for k, v in performance.items() if "oracle" in k})
414 |
415 | if TO_LATEX:
416 | to_latex(performance)
417 |
418 | if GEN_FOR_PLOT:
419 | gen_for_plot(performance)
420 |
421 |
422 | if __name__ == "__main__":
423 | main()
424 |
--------------------------------------------------------------------------------
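A minimal, self-contained sketch (illustrative only, not a repo file) of the aggregation above: compile_results() yields one {scope: {metric: mean}} dict per run, final_compile() stacks those per-run means into lists, and fmt() reports "mean (std)" across runs. The numbers below are made up.

    import numpy as np

    # two hypothetical runs, already averaged within-run by compile_results()
    run_results = [
        {"test": {"IA": 90.3, "OR": 40.2}, "val": {"IA": 92.1, "OR": 45.0}},
        {"test": {"IA": 89.9, "OR": 39.1}, "val": {"IA": 91.8, "OR": 44.2}},
    ]

    # final_compile()-style stacking into lists per scope/metric
    stacked = {}
    for run in run_results:
        for scope, metrics in run.items():
            for metric, value in metrics.items():
                stacked.setdefault(scope, {}).setdefault(metric, []).append(value)

    # fmt()-style reporting: mean (std) across runs
    for scope, metrics in stacked.items():
        for metric, values in metrics.items():
            print(scope, metric, f"{np.mean(values):.2f} ({np.std(values):.2f})")

--------------------------------------------------------------------------------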
/runners/infer.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | os.environ["WANDB_DISABLED"] = "true"
4 |
5 | from collections import defaultdict
6 | import argparse
7 | from haven import haven_wizard as hw
8 | import gc
9 | import json
10 | import GPUtil
11 | import torch
12 | from transformers import (
13 | AutoTokenizer,
14 | AutoModelForSequenceClassification,
15 | TrainingArguments,
16 | Trainer,
17 | )
18 |
19 | from utils.metrics import Metrics
20 | from utils.data_utils.data_loader import DatasetLoader
21 |
22 |
23 | torch.backends.cudnn.benchmark = True
24 |
25 | NUM_LABELS = 7
26 | DATASET_NAME = "snips_official"
27 | MODEL_DIR = (
28 | f"/mnt/colab_public/results/few_shot_nlp/model/{DATASET_NAME}/oracle_checkpoint/"
29 | )
30 |
31 | RESULTS_PATH = f"results/gpt_fidelity_{DATASET_NAME}.json"
32 |
33 |
34 | def write_results(results):
35 | with open(RESULTS_PATH, "w") as f:
36 | json.dump(results, f)
37 |
38 |
39 | def trainval(exp_dict, savedir, args):
40 | """Main."""
41 | # ==========================
42 | # load datasets
43 | # ==========================
44 | dataset_loader = DatasetLoader(args.datadir, exp_dict)
45 |
46 | # ==========================
47 | # create model and trainer
48 | # ==========================
49 | backbone = AutoModelForSequenceClassification.from_pretrained(
50 | MODEL_DIR, num_labels=NUM_LABELS
51 | )
52 |
53 | tokenizer = AutoTokenizer.from_pretrained(
54 | exp_dict["model"]["backbone"], use_fast=True
55 | )
56 |
57 | def preprocess(example):
58 | # note that setting max_length and truncation will not
59 | # have any effect for the vanilla baseline experiments
60 | results = tokenizer(example["text"], max_length=50, truncation=True)
61 | results["label"] = example["intent"]
62 | return results
63 |
64 | encoded_dataset = dataset_loader.dataset.map(preprocess, batched=True)
65 |
66 | if args.job_scheduler == "1":
67 | from configs import job_configs
68 |
69 | n_gpu = job_configs.JOB_CONFIG["resources"]["gpu"]
70 | else:
71 | n_gpu = 1
72 |
73 |     targs = TrainingArguments(  # "targs" avoids shadowing the CLI args
74 | savedir,
75 | evaluation_strategy="epoch",
76 | save_strategy="epoch",
77 | learning_rate=exp_dict["lr"],
78 | per_device_train_batch_size=exp_dict["batch_size"] // n_gpu,
79 | per_device_eval_batch_size=exp_dict["batch_size"] // n_gpu,
80 | num_train_epochs=exp_dict["epochs"],
81 | warmup_ratio=exp_dict["warmup_ratio"],
82 | weight_decay=exp_dict["weight_decay"],
83 | load_best_model_at_end=True,
84 | metric_for_best_model=exp_dict["metric_best"],
85 | # push_to_hub=True,
86 | # push_to_hub_model_id=f"{model_name}-finetuned-{task}",
87 | )
88 |
89 | if "full_validation" in encoded_dataset:
90 | # for ex2 setup experiments
91 | eval_split = "full_validation"
92 | else:
93 | # for clinc setup experiments
94 | eval_split = "validation"
95 |
96 | trainer = Trainer(
97 | backbone,
98 |         targs,
99 | train_dataset=encoded_dataset["train"],
100 | eval_dataset=encoded_dataset[eval_split],
101 | tokenizer=tokenizer,
102 | compute_metrics=Metrics(exp_dict).compute_metrics(),
103 | )
104 |     # no trainer.train(): this runner only evaluates the pretrained checkpoint
105 | print(GPUtil.showUtilization())
106 |
107 | print(f"n_gpus: {trainer.args.n_gpu}, local rank: {trainer.args.local_rank}")
108 | print("Emptying cache...") # crucial!
109 | gc.collect()
110 | torch.cuda.empty_cache()
111 |
112 | # compute metrics
113 | trainer.args.eval_accumulation_steps = exp_dict["eval_accumulation_steps"]
114 | if "full_test" in encoded_dataset:
115 | metrics = {"overall": {}, "few_shot": {}}
116 | else:
117 | metrics = defaultdict(dict)
118 |
119 | for split in encoded_dataset:
120 | if "train" in split:
121 | continue
122 |
123 | # split test, full_test, validation, full_validation
124 | prefix = "test" if "test" in split else "valid"
125 | if "overall" in metrics:
126 | _type = "overall" if "full" in split else "few_shot"
127 | metrics[_type].update(
128 | trainer.evaluate(encoded_dataset[split], metric_key_prefix=prefix)
129 | )
130 | elif exp_dict["exp_type"] == "intrinsic":
131 | metrics[split].update(
132 | trainer.evaluate(encoded_dataset[split], metric_key_prefix=prefix)
133 | )
134 | else:
135 | metrics.update(
136 | trainer.evaluate(encoded_dataset[split], metric_key_prefix=prefix)
137 | )
138 |     print(metrics)
139 |     write_results(metrics)  # persist to RESULTS_PATH via the helper above
140 |
141 |
142 | if __name__ == "__main__":
143 | parser = argparse.ArgumentParser()
144 | parser.add_argument(
145 | "-e",
146 | "--exp_group_list",
147 | nargs="+",
148 | default="resnet",
149 | help="name of an experiment in exp_configs.py",
150 | )
151 | parser.add_argument(
152 | "-sb",
153 | "--savedir_base",
154 | default="/mnt/home/haven_output",
155 | help="folder where logs will be saved",
156 | )
157 | parser.add_argument("-nw", "--num_workers", type=int, default=4)
158 | parser.add_argument("-d", "--datadir", type=str, default="./data")
159 | parser.add_argument(
160 | "-r", "--reset", default=0, type=int, help="Overwrite previous results"
161 | )
162 | parser.add_argument("-ei", "--exp_id", default=None)
163 | parser.add_argument(
164 | "-j",
165 | "--job_scheduler",
166 | type=str,
167 | default=None,
168 | help="If 1, runs in toolkit in parallel",
169 | )
170 | parser.add_argument("-v", default="results.ipynb", help="orkestrator")
171 | parser.add_argument(
172 | "--python_binary", default="python", help="path to your python executable"
173 | )
174 |
175 | args, unknown = parser.parse_known_args()
176 | from configs import exp_configs
177 |
178 | if args.job_scheduler == "1":
179 | from configs import job_configs
180 |
181 | job_config = job_configs.JOB_CONFIG
182 | else:
183 | job_config = None
184 |
185 | file_name = os.path.basename(__file__)[:-3] # remove .py
186 | hw.run_wizard(
187 | func=trainval,
188 | exp_groups=exp_configs.EXP_GROUPS,
189 | job_config=job_config,
190 | python_binary_path=args.python_binary,
191 | python_file_path=f"-m runners.{file_name}",
192 | use_threads=True,
193 | args=args,
194 | )
195 |
--------------------------------------------------------------------------------
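A sketch of querying the same oracle checkpoint directly, outside the Trainer. The checkpoint path is hypothetical, the tokenizer name assumes the bert-large-uncased backbone used elsewhere in this repo (see scripts/gpt3_analysis.py), and num_labels matches NUM_LABELS above.

    import torch
    from transformers import AutoTokenizer, AutoModelForSequenceClassification

    model_dir = "/path/to/oracle_checkpoint"  # hypothetical path
    tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased", use_fast=True)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir, num_labels=7)
    model.eval()

    # same preprocessing as preprocess() above: max_length=50 with truncation
    enc = tokenizer("play some jazz", max_length=50, truncation=True, return_tensors="pt")
    with torch.no_grad():
        probs = torch.softmax(model(**enc).logits, dim=1)
    print(int(torch.argmax(probs, dim=1)))  # predicted intent id

--------------------------------------------------------------------------------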
/runners/modules/bert.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import gc, os, GPUtil, torch
3 | from transformers import AutoTokenizer, TrainingArguments, Trainer
4 |
5 | from utils.metrics import Metrics
6 | from utils import main_data_utils as mdu
7 | from utils.data_utils.data_loader import DatasetLoader
8 | from models.backbones import get_backbone
9 |
10 |
11 | torch.backends.cudnn.benchmark = True
12 |
13 |
14 | def trainval(exp_dict, savedir, cl_args):
15 | """Main."""
16 | # ==========================
17 | # load datasets
18 | # ==========================
19 | dataset_loader = DatasetLoader(cl_args.datadir, exp_dict)
20 |
21 | # ==========================
22 | # create model and trainer
23 | # ==========================
24 | backbone = get_backbone(exp_dict)
25 |
26 | tokenizer = AutoTokenizer.from_pretrained(
27 | exp_dict["model"]["backbone"], use_fast=True
28 | )
29 |
30 | def preprocess(example):
31 | # note that setting max_length and truncation will not
32 | # have any effect for the vanilla baseline experiments
33 | results = tokenizer(example["text"], max_length=50, truncation=True)
34 | results["label"] = example["intent"]
35 | return results
36 |
37 | encoded_dataset = dataset_loader.dataset.map(preprocess, batched=True)
38 |
39 | if cl_args.job_scheduler == "1":
40 | from configs import job_configs
41 |
42 | n_gpu = job_configs.JOB_CONFIG["resources"]["gpu"]
43 | else:
44 | n_gpu = 1
45 |
46 | args = TrainingArguments(
47 | savedir,
48 | evaluation_strategy="epoch",
49 | save_strategy="epoch",
50 | learning_rate=exp_dict["lr"],
51 | per_device_train_batch_size=exp_dict["batch_size"] // n_gpu,
52 | per_device_eval_batch_size=exp_dict["batch_size"] // n_gpu,
53 | num_train_epochs=exp_dict["epochs"],
54 | warmup_ratio=exp_dict["warmup_ratio"],
55 | weight_decay=exp_dict["weight_decay"],
56 | load_best_model_at_end=True,
57 | metric_for_best_model=exp_dict["metric_best"],
58 | save_total_limit=1,
59 | # push_to_hub=True,
60 | # push_to_hub_model_id=f"{model_name}-finetuned-{task}",
61 | )
62 |
63 | if "full_validation" in encoded_dataset:
64 | # for ex2 setup experiments
65 | eval_split = "full_validation"
66 | else:
67 | # for clinc setup experiments
68 | eval_split = "validation"
69 |
70 | trainer = Trainer(
71 | backbone,
72 | args,
73 | train_dataset=encoded_dataset["train"],
74 | eval_dataset=encoded_dataset[eval_split],
75 | tokenizer=tokenizer,
76 | compute_metrics=Metrics(exp_dict).compute_metrics(),
77 | )
78 | trainer.train()
79 | print(GPUtil.showUtilization())
80 |
81 | print(f"n_gpus: {trainer.args.n_gpu}, local rank: {trainer.args.local_rank}")
82 | print("Emptying cache...") # crucial!
83 | gc.collect()
84 | torch.cuda.empty_cache()
85 |
86 | # compute metrics
87 | trainer.args.eval_accumulation_steps = exp_dict["eval_accumulation_steps"]
88 | if "full_test" in encoded_dataset:
89 | metrics = {"overall": {}, "few_shot": {}}
90 | else:
91 | metrics = defaultdict(dict)
92 |
93 | for split in encoded_dataset:
94 | if "train" in split:
95 | continue
96 |
97 | # split test, full_test, validation, full_validation
98 | prefix = "test" if "test" in split else "valid"
99 | if "overall" in metrics:
100 | _type = "overall" if "full" in split else "few_shot"
101 | metrics[_type].update(
102 | trainer.evaluate(encoded_dataset[split], metric_key_prefix=prefix)
103 | )
104 | elif exp_dict["exp_type"] == "intrinsic":
105 | metrics[split].update(
106 | trainer.evaluate(encoded_dataset[split], metric_key_prefix=prefix)
107 | )
108 | else:
109 | metrics.update(
110 | trainer.evaluate(encoded_dataset[split], metric_key_prefix=prefix)
111 | )
112 | # write results
113 | results_path = os.path.join(savedir, "code/results.json")
114 |     # the results dir may not exist when running without the job scheduler;
115 |     # exist_ok also covers the dir existing while results.json does not
116 |     os.makedirs(os.path.dirname(results_path), exist_ok=True)
117 | mdu.write_json({exp_dict["dataset"]["config"]: metrics}, results_path)
118 |
119 | if not cl_args.retain_checkpoints:
120 | # delete HF checkpoints to save space on toolkit
121 | print(f"DELETING! {os.path.join(savedir, 'checkpoint-*')}")
122 | os.system(f"rm -rf {os.path.join(savedir, 'checkpoint-*')}")
123 |
--------------------------------------------------------------------------------
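A tiny sketch of how the evaluation loop above routes splits into the metrics dict: "full_*" splits fill the "overall" bucket (ex2 setup), other splits fill "few_shot", and the metric prefix is "test" for test splits and "valid" otherwise.

    for split in ["validation", "full_validation", "test", "full_test"]:
        prefix = "test" if "test" in split else "valid"
        bucket = "overall" if "full" in split else "few_shot"
        print(f"{split} -> {bucket} (metric prefix: {prefix})")

--------------------------------------------------------------------------------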
/runners/oracle_relabel.py:
--------------------------------------------------------------------------------
1 | from configs import exp_configs
2 | import os, gc, argparse, pickle, torch, numpy as np
3 | from utils import main_data_utils
4 |
5 | from transformers import AutoTokenizer, AutoModelForSequenceClassification
6 | from datasets import DatasetDict, Dataset
7 |
8 | from haven import haven_wizard as hw
9 |
10 | torch.backends.cudnn.benchmark = True
11 |
12 | pjoin = os.path.join
13 | write_pickle = lambda obj, path: pickle.dump(obj, open(path, "wb"))
14 | read_pickle = lambda path: pickle.load(open(path, "rb"))
15 |
16 |
17 | def oracle_correction(exp_dict, savedir, args):
18 | """Main."""
19 | # ==========================
20 | # load full dataset (containing all generated examples)
21 | # ==========================
22 | ds_name = exp_dict["dataset"]["name"]
23 | ds_config = main_data_utils.get_ds_config(ds_name)
24 | DOMAINS = list(ds_config.domain_to_intent.keys())
25 | if "oos" in DOMAINS:
26 | print(f"ignoring OOS domain from {ds_name.upper()}")
27 | DOMAINS.pop(DOMAINS.index("oos")) # remove oos
28 |
29 | print(f"{ds_name} domains:")
30 | print(DOMAINS)
31 | engines = ["ada", "babbage", "curie", "davinci", "gptj", "eda", "val", "test"]
32 | temp = [1.0]
33 |
34 | base_data_path = pjoin(args.datadir, ds_name, "full", "dataset.pkl")
35 | exp_data_path = pjoin(args.datadir, ds_name, "full", "data_full_suite.pkl")
36 |
37 | tokenizer = AutoTokenizer.from_pretrained(
38 | exp_dict["model"]["backbone"], use_fast=True
39 | )
40 |
41 | generated_dataset = DatasetDict(read_pickle(exp_data_path))
42 |
43 | # remove the domains and make a HF Dataset
44 | for e in engines:
45 | # no need to remove domains, will fetch from the base dataset
46 | if e in ["val", "test"]:
47 | continue
48 | elif e == "gptj":
49 | temp = [round(a, 1) for a in np.linspace(0.5, 2, int((2.1 - 0.5) / 0.1))]
50 | else:
51 | temp = [1.0]
52 | for t in temp:
53 | _lines = []
54 | _intents = []
55 | for d in DOMAINS:
56 | attr = e if e in ["eda", "bt"] else f"{e}_{t}"
57 | _lines.extend(generated_dataset[d]["F"][attr]["text"])
58 | _intents.extend(generated_dataset[d]["F"][attr]["intent"])
59 | generated_dataset[attr] = Dataset.from_dict(
60 | {"text": _lines, "intent": _intents}
61 | )
62 |
63 | base_dataset = read_pickle(base_data_path)
64 | # Add validation samples (for computing thresholds in fidelity plots)
65 | generated_dataset["val"] = Dataset.from_dict(base_dataset["val"])
66 | generated_dataset["test"] = Dataset.from_dict(base_dataset["test"])
67 |
68 | # ==========================
69 | # create model
70 | # ==========================
71 | print(f"Oracle path: {args.modeldir}")
72 | oracle = AutoModelForSequenceClassification.from_pretrained(
73 | args.modeldir, num_labels=exp_dict["dataset"]["num_labels"]
74 | )
75 | oracle.cuda()
76 | oracle.eval()
77 |
78 | # Init AL dataset
79 | al_path = pjoin(args.datadir, ds_name, "full", "al_dataset.pkl")
80 | if os.path.exists(al_path):
81 | print(f"Loading existing data from {al_path}")
82 | al_ds = read_pickle(al_path)["generated"]
83 | else:
84 | print(f"Initializing al_dataset.pkl for {ds_name.upper()}")
85 | al_ds = {}
86 |
87 | with torch.no_grad():
88 | for e in engines:
89 | if e == "gptj":
90 | temp = [
91 | round(a, 1) for a in np.linspace(0.5, 2, int((2.1 - 0.5) / 0.1))
92 | ]
93 | else:
94 | temp = [1.0]
95 | for t in temp:
96 | attr = e if e in ["eda", "bt", "val", "test"] else f"{e}_{t}"
97 | if attr in al_ds:
98 | print(f"{attr} already exists in {ds_name}'s AL dataset")
99 | continue
100 | print(f"relabeling for {attr}")
101 | al_ds[attr] = {}
102 | al_ds[attr]["text"] = []
103 |                 al_ds[attr]["intent"] = []
104 |
105 | encodings = tokenizer(
106 | generated_dataset[attr]["text"],
107 | max_length=50,
108 | padding=True,
109 | truncation=True,
110 | return_tensors="pt",
111 | )
112 | total_num = len(encodings["input_ids"])
113 | batch_size = 100
114 |
115 | print("total_num", total_num, "and batch_size", batch_size)
116 | lbls = []
117 | for i in range(0, total_num, batch_size):
118 | outputs = oracle(
119 |                         encodings["input_ids"][i : i + batch_size].cuda(),
120 |                         attention_mask=encodings["attention_mask"][i : i + batch_size].cuda(),
121 | )
122 | probs = torch.softmax(outputs.logits, dim=1)
123 | lbls.extend(torch.argmax(probs, dim=1).cpu().tolist())
124 |
125 | al_ds[attr]["text"] = generated_dataset[attr]["text"]
126 | al_ds[attr]["intent"] = lbls
127 | al_ds[attr]["old_intent"] = generated_dataset[attr]["intent"]
128 |
129 | al_dataset = dict(
130 | train=base_dataset["train"],
131 | val=base_dataset["val"],
132 | test=base_dataset["test"],
133 | generated=al_ds,
134 | )
135 |
136 | with open(al_path, "wb") as f:
137 | pickle.dump(al_dataset, f)
138 |
139 | print("Emptying cache...") # crucial!
140 | gc.collect()
141 | torch.cuda.empty_cache()
142 |
143 |
144 | if __name__ == "__main__":
145 | parser = argparse.ArgumentParser()
146 | parser.add_argument(
147 | "-e",
148 | "--exp_group_list",
149 | nargs="+",
150 | default="resnet",
151 | help="name of an experiment in exp_configs.py",
152 | )
153 | parser.add_argument(
154 | "-sb",
155 | "--savedir_base",
156 | default="/mnt/home/haven_output",
157 | help="folder where logs will be saved",
158 | )
159 | parser.add_argument("-nw", "--num_workers", type=int, default=4)
160 | parser.add_argument("-d", "--datadir", type=str, default="./data")
161 | parser.add_argument("-md", "--modeldir", type=str, default=None)
162 | parser.add_argument("-ei", "--exp_id", default=None)
163 | parser.add_argument(
164 | "-j",
165 | "--job_scheduler",
166 | type=str,
167 | default=None,
168 | help="If 1, runs in toolkit in parallel",
169 | )
170 | parser.add_argument(
171 | "--python_binary", default="python", help="path to your python executable"
172 | )
173 |
174 | args, unknown = parser.parse_known_args()
175 | if args.job_scheduler == "1":
176 | from configs import job_configs
177 |
178 | job_config = job_configs.JOB_CONFIG
179 | else:
180 | job_config = None
181 | file_name = os.path.basename(__file__)[:-3] # remove .py
182 | hw.run_wizard(
183 | func=oracle_correction,
184 | exp_groups=exp_configs.EXP_GROUPS,
185 | job_config=job_config,
186 | python_binary_path=args.python_binary,
187 | python_file_path=f"-m runners.{file_name}",
188 | use_threads=True,
189 | args=args,
190 | )
191 |
--------------------------------------------------------------------------------
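The layout of the al_dataset.pkl written above, sketched from how this file builds it; the field values are illustrative placeholders.

    al_dataset = {
        "train": ...,  # base splits copied from dataset.pkl
        "val": ...,
        "test": ...,
        "generated": {
            # one entry per engine/temperature attr, e.g. "gptj_0.5", "eda",
            # plus the "val"/"test" pools relabeled for fidelity thresholds
            "davinci_1.0": {
                "text": ["..."],       # generated utterances
                "intent": [3, 7],      # labels re-assigned by the oracle
                "old_intent": [3, 5],  # labels of the seed intents
            },
        },
    }

--------------------------------------------------------------------------------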
/runners/train.py:
--------------------------------------------------------------------------------
1 | from configs import exp_configs
2 | import os, argparse, torch, wandb
3 | from haven import haven_wizard as hw
4 | 
5 | from runners.modules import bert
6 | 
7 |
8 | torch.backends.cudnn.benchmark = True
9 |
10 |
11 | def init_wandb(exp_dict):
12 | exp_name = f"{exp_dict['exp_type']}_oosID={exp_dict['dataset']['oos_id']}_"
13 | exp_name += f'{exp_dict["dataset"]["name"]}_{exp_dict["dataset"]["config"]}_'
14 | exp_name += f'{exp_dict["lr"]}_{exp_dict["epochs"]}_{exp_dict["batch_size"]}'
15 | wandb.init(project="few-shot-nlp", name=exp_name)
16 |
17 |
18 | def trainval(exp_dict, savedir, args):
19 | """Main."""
20 | # ==========================
21 | # init wandb
22 | # ==========================
23 | if not args.disable_wandb:
24 | init_wandb(exp_dict)
25 | else:
26 | os.environ["WANDB_DISABLED"] = "true"
27 |
28 | # ==========================
29 | # Load appropriate trainer
30 | # ==========================
31 | if "bert" in exp_dict["model"]["backbone"]:
32 | return bert.trainval(exp_dict, savedir, args)
33 |     raise ValueError("backbone not recognized")
34 |
35 |
36 | if __name__ == "__main__":
37 | parser = argparse.ArgumentParser()
38 | parser.add_argument(
39 | "-e",
40 | "--exp_group_list",
41 | nargs="+",
42 | default="resnet",
43 | help="name of an experiment in exp_configs.py",
44 | )
45 | parser.add_argument(
46 | "-sb",
47 | "--savedir_base",
48 | default="/mnt/home/haven_output",
49 | help="folder where logs will be saved",
50 | )
51 | parser.add_argument("-nw", "--num_workers", type=int, default=4)
52 | parser.add_argument("-d", "--datadir", type=str, default="./data")
53 | parser.add_argument("-md", "--modeldir", type=str, default=None)
54 | parser.add_argument(
55 | "-r", "--reset", default=0, type=int, help="Overwrite previous results"
56 | )
57 | parser.add_argument("-ei", "--exp_id", default=None)
58 | parser.add_argument(
59 | "-j",
60 | "--job_scheduler",
61 | type=str,
62 | default=None,
63 | help="If 1, runs in toolkit in parallel",
64 | )
65 | parser.add_argument("-v", default="results.ipynb", help="orkestrator")
66 | parser.add_argument(
67 | "--python_binary", default="python", help="path to your python executable"
68 | )
69 | parser.add_argument("--disable_wandb", default=1, type=int)
70 | parser.add_argument("--retain_checkpoints", action="store_true")
71 |
72 | args, unknown = parser.parse_known_args()
73 | if args.job_scheduler == "1":
74 | from configs import job_configs
75 |
76 | job_config = job_configs.JOB_CONFIG
77 | else:
78 | job_config = None
79 |
80 | file_name = os.path.basename(__file__)[:-3] # remove .py
81 | hw.run_wizard(
82 | func=trainval,
83 | exp_groups=exp_configs.EXP_GROUPS,
84 | job_config=job_config,
85 | python_binary_path=args.python_binary,
86 | python_file_path=f"-m runners.{file_name}",
87 | use_threads=True,
88 | args=args,
89 | results_fname="notebooks/fsn.ipynb",
90 | )
91 |
--------------------------------------------------------------------------------
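For a hypothetical exp_dict, init_wandb() above builds a run name like the following (all values made up):

    exp_dict = {
        "exp_type": "gpt3",
        "dataset": {"oos_id": 42, "name": "clinc_oos", "config": "few_pure"},
        "lr": 4e-05, "epochs": 6, "batch_size": 32,
    }
    name = f"{exp_dict['exp_type']}_oosID={exp_dict['dataset']['oos_id']}_"
    name += f"{exp_dict['dataset']['name']}_{exp_dict['dataset']['config']}_"
    name += f"{exp_dict['lr']}_{exp_dict['epochs']}_{exp_dict['batch_size']}"
    print(name)  # gpt3_oosID=42_clinc_oos_few_pure_4e-05_6_32

--------------------------------------------------------------------------------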
/scripts/3way_human_eval.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from utils import main_data_utils as mdu
3 | import argparse, numpy as np
4 | from openpyxl import Workbook
5 | from openpyxl.styles import Font
6 |
7 |
8 | def create_content(ds, id2name, args, run_code):
9 | samples = []
10 | for (intent_id, sentence) in zip(ds[run_code]["intent"], ds[run_code]["text"]):
11 | if id2name[str(intent_id)] not in args.intent_triplet:
12 | continue
13 | samples.append((sentence, id2name[str(intent_id)]))
14 |
15 | np.random.shuffle(samples)
16 | # start creating an excel sheet
17 | wb = Workbook()
18 | wb.remove(wb.active)
19 |
20 | # create a fresh worksheet with meaningful name
21 | ws = wb.create_sheet("Human eval.")
22 | ws.append(["Sentence", "Your Prediction"])
23 | labels = ""
24 | max_sent_length = 0
25 | for sent, intent in samples:
26 | if len(sent) > max_sent_length:
27 | max_sent_length = len(sent)
28 | ws.append([sent])
29 | labels += intent + "\n"
30 |
31 | # increase font size of sentences
32 | for i in range(1, len(samples) + 2):
33 | ws[f"A{i}"].font = Font(size=14)
34 |
35 | # max width of column A
36 | ws.column_dimensions["A"].width = max_sent_length
37 | ws["A1"].font = Font(bold=True, size=14)
38 | ws["B1"].font = Font(bold=True, size=14)
39 | wb.active = 0
40 | wb.save(f"spreadsheets/{args.dataset_name}/{run_code}_human_eval.xlsx")
41 | mdu.write_file(labels, f"spreadsheets/{args.dataset_name}/{run_code}_labels.txt")
42 |
43 |
44 | def evaluate(args, run_code):
45 | labels_path = f"spreadsheets/{args.dataset_name}/{run_code}_labels.txt"
46 | labels = mdu.read_file(labels_path).splitlines()
47 | if args.dataset_name == "hwu64":
48 | predcode2name = {
49 | "1": "music_likeness",
50 | "2": "music_settings",
51 | "3": "play_music",
52 | }
53 | elif args.dataset_name == "banking77":
54 | predcode2name = {
55 | "1": "topping_up_by_card",
56 | "2": "top_up_failed",
57 | "3": "pending_top_up",
58 | }
59 | preds_path = f"spreadsheets/{args.dataset_name}/sahu_{run_code}_preds.txt"
60 | preds = [predcode2name[p] for p in mdu.read_file(preds_path).splitlines()]
61 | assert len(preds) == len(labels)
62 | class_wise_preds = defaultdict(list)
63 | for _pred, _label in zip(preds, labels):
64 | class_wise_preds[_label].append(_pred)
65 | print(
66 | "Overall Val Acc of sahu =",
67 | f"{np.mean([1 if p == l else 0 for p, l in zip(preds, labels)])*100:.2f}",
68 | )
69 |
70 | for intent, preds in class_wise_preds.items():
71 | print(f"sahu Acc on {intent}: {preds.count(intent)/len(preds)*100:.2f}")
72 |
73 |
74 | def main():
75 | # read base data
76 | args = parse_args()
77 |     if args.eval:
78 |         return evaluate(args, "val")  # don't rebuild (and reshuffle) the sheets
79 | ds = mdu.read_pickle(f"data/{args.dataset_name}/full/dataset.pkl")
80 | id2name = mdu.read_json(f"data/{args.dataset_name}/id2name.json")
81 |
82 | create_content(ds, id2name, args, "val") # 3-way val set
83 | create_content(ds, id2name, args, "test") # 3-way test set
84 |
85 |
86 | def parse_args():
87 | parser = argparse.ArgumentParser()
88 | parser.add_argument("-d", "--dataset-name", default="hwu64")
89 | # parser.add_argument("-d", "--dataset-name", default="banking77")
90 | parser.add_argument(
91 | "-it",
92 | "--intent-triplet",
93 | nargs="+",
94 | default=["music_likeness", "play_music", "music_settings"],
95 | # default=["topping_up_by_card", "top_up_failed", "pending_top_up"],
96 | )
97 | parser.add_argument("--eval", action="store_true")
98 | return parser.parse_args()
99 |
100 |
101 | if __name__ == "__main__":
102 | main()
103 |
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ServiceNow/data-augmentation-with-llms/77c740df2aaf4bfc1b2449114613f98b37a1faae/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/compute_ttr.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import nltk
3 |
4 |
5 | def ttr(sentences, n):
6 | """
7 | Computes Type-Token Ratio for a given
8 | set of sentences
9 | Params:
10 | ======
11 | sentences (list): list of sentences in the corpus
12 | n (list): list of n values to consider when computing ngrams
13 | """
14 | ngrams = [[] for _ in n]
15 | for s in sentences:
16 | tokens = nltk.tokenize.word_tokenize(s.lower())
17 | for i in range(len(n)):
18 | ngrams[i].extend(list(nltk.ngrams(tokens, n[i])))
19 | return [round(len(set(v)) / len(v), 2) for v in ngrams]
20 |
21 |
22 | def main():
23 | for dataset in ["clinc_oos", "snips_official", "hwu64", "banking77"]:
24 | dpath = f"./data/{dataset}/full/data_full_suite.pkl"
25 | data = pickle.load(open(dpath, "rb"))
26 | domains = list(data.keys())
27 | print(f"Diversity metrics for {dataset}")
28 | for e in ["eda", "ada", "babbage", "curie", "davinci"]:
29 | for t in [1.0]:
30 | sentences = []
31 | for d in domains:
32 | # skip OOS for now
33 | if d == "oos":
34 | continue
35 | attr = e if e == "eda" else f"{e}_{t}"
36 | sentences.extend(data[d]["F"][attr]["text"])
37 | print(f"{e}_{t}: {ttr(sentences, n=[1, 2, 3, 4])}")
38 |
39 |
40 | if __name__ == "__main__":
41 | main()
42 |
--------------------------------------------------------------------------------
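A toy usage sketch of ttr() above; it needs nltk's "punkt" tokenizer data ("punkt_tab" on newer nltk releases), and the expected values were computed by hand.

    import nltk
    nltk.download("punkt", quiet=True)

    sentences = ["book a table for two", "book a flight for two people"]
    # 7 distinct of 11 total unigrams -> 0.64; 7 distinct of 9 total bigrams -> 0.78
    print(ttr(sentences, n=[1, 2]))  # [0.64, 0.78]

--------------------------------------------------------------------------------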
/scripts/gpt3_analysis.py:
--------------------------------------------------------------------------------
1 | import argparse, numpy as np
2 | from collections import defaultdict
3 | from tqdm import tqdm
4 | from utils import data_utils as du
5 | from utils import main_data_utils as mdu
6 |
7 |
8 | def compute_acc(preds, labels):
9 | return f"{np.mean([1 if p == l else 0 for p, l in zip(preds, labels)])*100:.2f}"
10 |
11 |
12 | def get_seed_tuples(ds, seed_intent, ds_config, name2id):
13 | """
14 | returns seed sentences for a given intent_id in (full suite) dataset object
15 | ds as a list of tuples with (intent_id, sentence).
16 | seed_intent is a string
17 | """
18 | # identify the domain for the seed intent
19 | intent_id = int(name2id[seed_intent])
20 | seed_domain = None
21 | for _domain, _intents in ds_config.domain_to_intent.items():
22 | if seed_intent in _intents:
23 | seed_domain = _domain
24 | break
25 |
26 | # fetch the seed sentences for seed intent
27 | _start = ds[seed_domain]["F"]["train"]["intent"].index(intent_id)
28 | num_examples = ds[seed_domain]["F"]["train"]["intent"].count(intent_id)
29 | sents = ds[seed_domain]["F"]["train"]["text"][_start : _start + num_examples]
30 | return [(seed_intent, s) for s in sents]
31 |
32 |
33 | def filter_via_gpt(sentences, seed_sentences, args):
34 | """
35 | `sentences` is a list of tuples (old intent, new intent, sentence)
36 | """
37 | # construct a prompt
38 | ltemplate = "sentence: {} ; category:{}" # line template
39 | prompt = f"Each example in the following list contains a sentence that belongs to a category. A category is one of the following: {', '.join(args.intent_triplet)}"
40 | prompt += "\n\n"
41 | prompt += "\n".join([ltemplate.format(s, " " + i) for (i, s) in seed_sentences])
42 | prompt += "\n"
43 |
44 | print(f"FILTERING generations using {args.gpt_engine.upper()}...")
45 | retained_sentences = []
46 | # NOTE: ignoring the new intent (mid element of the tuple)
47 | for (old_intent, new_intent, sent) in tqdm(sentences):
48 | # query gpt
49 | input_prompt = prompt + ltemplate.format(sent, "")
50 | responses = du.augment_slices.openai_complete(
51 | prompt=input_prompt,
52 | n=10,
53 | engine=args.gpt_engine,
54 | temp=1.0,
55 | top_p=1.0,
56 | )
57 | responses = [r.text.strip() for r in responses]
58 | # print(input_prompt)
59 | # print(responses)
60 | # breakpoint()
61 | # NOTE: not handling ties. if there is a tie, it just means that
62 | # the sentence is not a good enough one, and we shouldn't include it.
63 | response = max(args.intent_triplet, key=responses.count)
64 | if response == old_intent:
65 | retained_sentences.append((old_intent, new_intent, sent))
66 |
67 | # num retained, num input
68 | n_r, n_i = len(retained_sentences), len(sentences)
69 | # percentage retained
70 | ret_per = f"{(n_r/n_i)*100:.2f}%"
71 | print(f"{args.gpt_engine.upper()} retained {n_r}/{n_i} sentences ({ret_per}).")
72 | return retained_sentences
73 |
74 |
75 | def run_gpt_eval(eval_sentences, seed_sentences, args):
76 | """
77 | `eval_sentences` is a list of tuples (gt label in the dataset, text)
78 |     `seed_sentences` is a list of tuples (seed intent, text)
79 | """
80 | # construct a prompt
81 | ltemplate = "sentence: {} ; category:{}"
82 | prompt = f"Each example in the following list contains a sentence that belongs to a category. A category is one of the following: {', '.join(args.intent_triplet)}"
83 | prompt += "\n\n"
84 | prompt += "\n".join([ltemplate.format(s, " " + i) for (i, s) in seed_sentences])
85 | prompt += "\n"
86 |
87 | preds, labels = [], []
88 | class_wise_pred_labels = defaultdict(list)
89 | print("Running evaluation...")
90 | for (intent, sent) in tqdm(eval_sentences):
91 | # query gpt
92 | input_prompt = prompt + ltemplate.format(sent, "")
93 | responses = du.augment_slices.openai_complete(
94 | prompt=input_prompt, n=10, engine=args.gpt_engine, temp=1.0, top_p=1.0
95 | )
96 | responses = [r.text.strip() for r in responses]
97 | # NOTE: not handling ties here.
98 | response = max(args.intent_triplet, key=responses.count)
99 | # for overall preds and labels
100 | preds.append(response)
101 | labels.append(intent)
102 |
103 | # for class-wise performance
104 | class_wise_pred_labels[intent].append(response)
105 |
106 | # Evaluation
107 | print(f"{args.gpt_engine.upper()} performance on {len(eval_sentences)} examples:")
108 | print(f"Overall accuracy = {compute_acc(preds, labels)}")
109 | print(f"Class-wise accuracies:")
110 | for intent, preds in class_wise_pred_labels.items():
111 | _acc = preds.count(intent) / len(preds)
112 | print(f"Acc. for intent {intent}: {_acc*100:.2f}")
113 |
114 |
115 | def compute_fidelity(sentence_tuples):
116 | """
117 | `sentence_tuples` is a list of tuples (old intent, new intent, sentence)
118 | and old intent and new intent are intent names (not ids)
119 | """
120 | overall_3way_fidelity = []
121 | class_wise_fidelity = defaultdict(list)
122 | for (old_intent, new_intent, _) in sentence_tuples:
123 | class_wise_fidelity[old_intent].append(new_intent)
124 | overall_3way_fidelity.append(1 if new_intent == old_intent else 0)
125 |
126 | print(f"Overall 3-way fidelity = {np.mean(overall_3way_fidelity)*100:.2f}")
127 | for intent, preds in class_wise_fidelity.items():
128 | _fid = preds.count(intent) / len(preds)
129 | print(f"Fidelity for {intent}: {_fid*100:.2f}")
130 |
131 |
132 | def oracle_eval(eval_sentences, id2name, name2id, args, size):
133 | """
134 | This will be used to evaluate:
135 | - small oracle (a.k.a 10-shot baseline)
136 | - bigger oracle (a.k.a the oracle)
137 | """
138 |
139 | intent_ids_of_interest = [name2id[i] for i in args.intent_triplet]
140 | intent_ids_of_interest.sort()
141 |
142 | # segregate sentences and labels
143 | labels, sentences = [], []
144 | for l, s in eval_sentences:
145 | labels.append(l)
146 | sentences.append(s)
147 |
148 | if args.dataset_name not in ["hwu64", "banking77"]:
149 | raise NotImplementedError(f"Dataset {args.dataset_name} not supported")
150 |
151 | print("Running N-shot baseline evaluation on reduced val+test set...")
152 | model_dir = f"/mnt/colab_public/results/few_shot_nlp/model/{args.dataset_name}/"
153 | model_dir += "10_shot_baseline/" if size == "small" else "oracle_checkpoint/"
154 |
155 | if args.dataset_name == "hwu64":
156 | num_labels = 64
157 | else:
158 | num_labels = 77
159 |
160 | import torch
161 | from transformers import AutoTokenizer, AutoModelForSequenceClassification
162 |
163 | tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased", use_fast=True)
164 | model = AutoModelForSequenceClassification.from_pretrained(
165 | model_dir, num_labels=num_labels
166 | )
167 | device = "cuda" if torch.cuda.is_available() else "cpu"
168 | model.to(device)
169 | model.eval()
170 |
171 | with torch.no_grad():
172 | encodings = tokenizer(
173 | sentences,
174 | max_length=50,
175 | padding=True,
176 | truncation=True,
177 | return_tensors="pt",
178 | )
179 | input_ids, attention_mask = encodings["input_ids"], encodings["attention_mask"]
180 | logits = model(input_ids.to(device), attention_mask.to(device)).logits
181 | # only consider the 3 intents of interest for prediction
182 | _temp_iids = torch.tensor(intent_ids_of_interest).to(device)
183 | logits = logits.index_select(index=_temp_iids, dim=1)
184 | preds = torch.argmax(logits, dim=1).cpu().tolist()
185 | # remap pred ids to dataset intent ids
186 | preds = [id2name[str(intent_ids_of_interest[p])] for p in preds]
187 |
188 | class_wise_preds = defaultdict(list)
189 | for p, l in zip(preds, labels):
190 | class_wise_preds[l].append(p)
191 |
192 | # Evaluation
193 | print(f"N-shot baseline performance on {len(eval_sentences)} examples:")
194 | print(f"Overall accuracy = {compute_acc(preds, labels)}")
195 | print(f"Class-wise accuracies:")
196 | for intent, preds in class_wise_preds.items():
197 | _acc = preds.count(intent) / len(preds)
198 | print(f"Acc. for intent {intent}: {_acc*100:.2f}")
199 |
200 |
201 | def main():
202 | """
203 | generated sentences
204 | filtered generated sentences
205 |
206 | val+test sentences for the dataset
207 | """
208 | args = parse_args()
209 | print(args)
210 | ds_config = mdu.get_ds_config(args.dataset_name)
211 | name2id = mdu.read_json(f"data/{args.dataset_name}/name2id.json")
212 | id2name = mdu.read_json(f"data/{args.dataset_name}/id2name.json")
213 |
214 | ############# START: Fetch all required set of sentences ################
215 |
216 | # fetch seed sentences for all the intents in the triplet using the
217 | # dataset prepared for partial fewshot experiment
218 | ds = mdu.read_pickle(f"data/{args.dataset_name}/full/data_full_suite.pkl")
219 | all_seed_sentences = []
220 | for intent in args.intent_triplet:
221 | seed_sents = get_seed_tuples(ds, intent, ds_config, name2id)
222 | all_seed_sentences.extend(seed_sents)
223 | # shuffle shuffle
224 | np.random.shuffle(all_seed_sentences)
225 |
226 | ds = mdu.read_pickle(f"data/{args.dataset_name}/full/al_dataset.pkl")
227 | # Fetch generated sentences as a list of tuples
228 | # each tuple --> (old intent, new intent, sentence)
229 | all_gen_sentences = []
230 |     relabeled_pool = ds["generated"][f"{args.gpt_engine}_1.0"]
231 |     for idx, intent in enumerate(relabeled_pool["old_intent"]):
232 |         if id2name[str(intent)] not in args.intent_triplet:
233 |             continue
234 |         old_intent = id2name[str(intent)]
235 |         new_intent = id2name[str(relabeled_pool["intent"][idx])]
236 |         sentence = relabeled_pool["text"][idx]
237 | all_gen_sentences.append((old_intent, new_intent, sentence))
238 |
239 | # fetch val+test sentences (all accuracies are computed on these sentences)
240 | # NOTE: since it's val and test, old_intent is the ground truth
241 | eval_sentences = []
242 | for part in ["val", "test"]:
243 | sent_pool = ds["generated"][part]
244 | for (text, label) in zip(sent_pool["text"], sent_pool["old_intent"]):
245 | if id2name[str(label)] not in args.intent_triplet:
246 | continue
247 | eval_sentences.append((id2name[str(label)], text))
248 |
249 | # filtered examples using GPT as the classifier
250 | retained_generations = filter_via_gpt(all_gen_sentences, all_seed_sentences, args)
251 |
252 |     ############## END: Fetched all required sets of sentences ###############
253 | 
254 |     ############# COMPUTE DIFFERENT METRICS ##################
255 | print(f"\nFIDELITY for ALL {args.gpt_engine.upper()} generations")
256 | compute_fidelity(all_gen_sentences)
257 | print(f"\nFIDELITY after FILTERING {args.gpt_engine.upper()} rejections")
258 | compute_fidelity(retained_generations)
259 |
260 | print(f"\n{args.gpt_engine.upper()}'s 3-way classification performance")
261 | run_gpt_eval(eval_sentences, all_seed_sentences, args)
262 | print(f"\n10-SHOT-BASELINE's 3-way classification performance")
263 | oracle_eval(eval_sentences, id2name, name2id, args, "small")
264 | print(f"\nORACLE's 3-way classification performance")
265 | oracle_eval(eval_sentences, id2name, name2id, args, "big")
266 |
267 |
268 | def parse_args():
269 | parser = argparse.ArgumentParser()
270 | # parser.add_argument("-dn", "--dataset-name", default="hwu64")
271 | parser.add_argument("-dn", "--dataset-name", default="banking77")
272 | parser.add_argument("-e", "--gpt_engine", default="davinci")
273 | parser.add_argument(
274 | "-it",
275 | "--intent-triplet",
276 | nargs="+",
277 | # default=["music_likeness", "play_music", "music_settings"],
278 | default=["topping_up_by_card", "top_up_failed", "pending_top_up"],
279 | )
280 | return parser.parse_args()
281 |
282 |
283 | if __name__ == "__main__":
284 | main()
285 |
--------------------------------------------------------------------------------
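The majority vote used by filter_via_gpt() and run_gpt_eval() above, isolated as a sketch: the prediction is whichever triplet member occurs most often among the sampled completions; on a tie, max() returns the member listed first in intent_triplet. The responses below are made up.

    intent_triplet = ["topping_up_by_card", "top_up_failed", "pending_top_up"]
    # hypothetical stripped completions from the n=10 sampled responses
    responses = ["top_up_failed", "top_up_failed", "pending_top_up", "card"]
    prediction = max(intent_triplet, key=responses.count)
    print(prediction)  # top_up_failed

--------------------------------------------------------------------------------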
/scripts/make_spreadsheet.py:
--------------------------------------------------------------------------------
1 | import os, json, pickle
2 | from openpyxl import Workbook
3 | from openpyxl.styles import Font
4 |
5 | DATASET_NAME = "banking77"
6 |
7 |
8 | def load_data():
9 | data_dir = f"./data/{DATASET_NAME}"
10 | opj = os.path.join
11 | ds_full_suite = pickle.load(open(opj(data_dir, "full/data_full_suite.pkl"), "rb"))
12 | generated_samples = pickle.load(open(opj(data_dir, "full/al_dataset.pkl"), "rb"))[
13 | "generated"
14 | ]
15 | id2name = json.load(open(opj(data_dir, "id2name.json")))
16 | return ds_full_suite, generated_samples, id2name
17 |
18 |
19 | def massage_data(ds_full_suite, generated_samples, id2name):
20 | workbooks = {}
21 | for engine in generated_samples:
22 | workbooks[engine] = gather_workbook_data(
23 | ds_full_suite, generated_samples[engine], id2name
24 | )
25 | return workbooks
26 |
27 |
28 | def gather_workbook_data(ds_full_suite, generated_samples, id2name):
29 | workbook_data = {}
30 | for domain in ds_full_suite:
31 | # generate prompt column
32 | prompt_data = ds_full_suite[domain]["F"]["train"]
33 | for text, intent_id in zip(prompt_data["text"], prompt_data["intent"]):
34 | intent_name = id2name[str(intent_id)].replace("?", "")
35 | sheet_name = f"{domain}<>{intent_name}"
36 | if sheet_name not in workbook_data:
37 | workbook_data[sheet_name] = {
38 | "prompt": [],
39 | "generated": [],
40 | "oracle_prediction": [],
41 | }
42 | workbook_data[sheet_name]["prompt"].append(text)
43 |
44 | # add generated data, and oracle prediction data
45 | for text, oracle_intent_id, org_intent_id in zip(
46 | generated_samples["text"],
47 | generated_samples["intent"],
48 | generated_samples["old_intent"],
49 | ):
50 | oracle_pred = id2name[str(oracle_intent_id)].replace("?", "")
51 | org_intent_name = id2name[str(org_intent_id)].replace("?", "")
52 | sheet_name = f"{domain}<>{org_intent_name}"
53 | if sheet_name not in workbook_data:
54 | # print(f"sheet {sheet_name} doesn't exist")
55 | continue
56 | workbook_data[sheet_name]["generated"].append(text)
57 | workbook_data[sheet_name]["oracle_prediction"].append(oracle_pred)
58 | return workbook_data
59 |
60 |
61 | def create_excel_sheet(name, data):
62 | wb = Workbook()
63 | wb.remove(wb.active) # remove the empty "Sheet"
64 | # create different sheets
65 | for sheet_name in data:
66 | org_intent = sheet_name.split("<>", 1)[1]
67 | ws = wb.create_sheet(sheet_name)
68 | prompts = data[sheet_name]["prompt"]
69 | generated = data[sheet_name]["generated"]
70 | oracle_predictions = data[sheet_name]["oracle_prediction"]
71 |
72 | ############# compute some quantities for formatting ##############
73 | # max width of column A
74 | max_sent_length = max(map(len, prompts + generated))
75 | # max width of column B
76 | max_pred_length = max(map(len, oracle_predictions))
77 | total_faithful_samples = oracle_predictions.count(org_intent)
78 | ############# compute end #################
79 |
80 | # add the first column
81 | ws.append(["Sentences", "Oracle Predictions"])
82 | # add the sentences column
83 | for irow in range(len(prompts + generated)):
84 | if irow < len(prompts):
85 | ws.append([prompts[irow]])
86 | else:
87 | new_irow = irow - len(prompts)
88 | ws.append([generated[new_irow], oracle_predictions[new_irow]])
89 |
90 | # some analysis
91 | ws["C1"] = "Total faithful samples"
92 | ws["C2"] = f"{total_faithful_samples}/{len(generated)}"
93 | ws["C3"] = f"{total_faithful_samples/len(generated)*100:.2f}%"
94 |
95 | # adjust column widths
96 | ws.column_dimensions["A"].width = max_sent_length
97 | ws.column_dimensions["B"].width = max_pred_length
98 | ws.column_dimensions["C"].width = len("Total faithful samples")
99 |
100 | # increase font size
101 | n_rows = len(prompts + generated)
102 | for col, n_rows in [("A", n_rows), ("B", n_rows), ("C", 3)]:
103 | for i in range(1, n_rows + 2):
104 | ws[f"{col}{i}"].font = Font(size=14)
105 |
106 | # bold the first row
107 | ws["A1"].font = Font(bold=True, size=14)
108 | ws["B1"].font = Font(bold=True, size=14)
109 | ws["C1"].font = Font(bold=True, size=14)
110 |
111 |     # sort sheets by fidelity, least faithful intents first
112 |     # (ws['C3'] holds the fidelity as a percentage string, e.g. "45.00%")
113 |     wb._sheets.sort(key=lambda ws: float(ws["C3"].value[:-1]))
114 | wb.active = 0
115 |
116 | save_folder = f"spreadsheets/{DATASET_NAME}"
117 | if not os.path.exists(save_folder):
118 | os.mkdir(save_folder)
119 | wb.save(os.path.join(save_folder, f"{name}.xlsx"))
120 |
121 |
122 | if __name__ == "__main__":
123 | workbooks = massage_data(*load_data())
124 | for engine_temp, data in workbooks.items():
125 | if "_" in engine_temp and engine_temp.split("_")[1] != "1.0":
126 | continue
127 | create_excel_sheet(engine_temp, data)
128 |
--------------------------------------------------------------------------------
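A sketch of the fidelity sort above in isolation: cell C3 stores fidelity as a percentage string, and the trailing "%" is stripped before comparing numerically (toy values).

    values = ["12.50%", "87.10%", "45.00%"]
    print(sorted(values, key=lambda v: float(v[:-1])))
    # ['12.50%', '45.00%', '87.10%']

--------------------------------------------------------------------------------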
/scripts/openai_sandbox.py:
--------------------------------------------------------------------------------
1 | from utils.data_utils.augment_slices import openai_complete
2 | import argparse
3 |
4 |
5 | def main(args):
6 | if args.examples is None:
7 | raise ValueError("No seed examples provided.")
8 | if args.intent_name is None:
9 | raise ValueError("Please provide the name of a seed intent")
10 | if args.gpt_engine is None:
11 | print("No engine provided. Using ada...")
12 | ENGINE = "ada"
13 | else:
14 | ENGINE = args.gpt_engine
15 |
16 | intent = args.intent_name
17 | lines = args.examples
18 | k = len(args.examples)
19 |
20 | print("----Default method----")
21 | prompt = f"The following sentences belong to the same category {intent}:\n"
22 | prompt += "\n".join([f"Example {i+1}: {l}" for i, l in enumerate(lines)])
23 | prompt += "\n"
24 | prompt += f"Example {k+1}:"
25 | print(prompt)
26 | from pprint import pprint
27 |
28 | pprint(
29 | [
30 | r.text.strip()
31 | for r in openai_complete(
32 | prompt=prompt, n=20, engine=ENGINE, temp=1.0, top_p=1.0
33 | )
34 | ]
35 | )
36 |
37 |
38 | if __name__ == "__main__":
39 | parser = argparse.ArgumentParser()
40 |
41 | parser.add_argument(
42 | "-e",
43 | "--examples",
44 | nargs="+",
45 | default=None,
46 | help="Seed examples for prompting GPT",
47 | )
48 | parser.add_argument(
49 | "-i", "--intent_name", default=None, type=str, help="Name of the seed intent"
50 | )
51 |
52 | parser.add_argument(
53 | "-g",
54 | "--gpt_engine",
55 | default=None,
56 | help="GPT engine to use for augmentation (ada/babbage/curie/davinci)",
57 | )
58 | args, unknown = parser.parse_known_args()
59 |
60 | main(args)
61 |
--------------------------------------------------------------------------------
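The prompt the sandbox builds, reproduced as a sketch for three made-up seed examples:

    intent = "play_music"
    lines = ["play some jazz", "put on my workout playlist", "start the radio"]
    prompt = f"The following sentences belong to the same category {intent}:\n"
    prompt += "\n".join(f"Example {i+1}: {l}" for i, l in enumerate(lines))
    prompt += f"\nExample {len(lines)+1}:"
    print(prompt)
    # The following sentences belong to the same category play_music:
    # Example 1: play some jazz
    # Example 2: put on my workout playlist
    # Example 3: start the radio
    # Example 4:

--------------------------------------------------------------------------------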
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_utils.sample_few_shot import sample_few_shot
2 | from .data_utils.augment_slices import augment_slices
3 | from .data_utils import main as main_data_utils
--------------------------------------------------------------------------------
/utils/data_utils/augment_slices.py:
--------------------------------------------------------------------------------
1 | """
2 | This script uses openai to augment the dataset with
3 | samples for different few-shot domains
4 | """
5 | import os, gc, torch, openai, pickle, json
6 | import numpy as np
7 | from . import eda_utils
8 | from collections import Counter
9 |
10 | pjoin = os.path.join
11 |
12 |
13 | class GPTJChoice:
14 | def __init__(self, text):
15 | self.text = text
16 |
17 |
18 | def load_dataset_slices(data_root, data_name):
19 | with open(pjoin(data_root, data_name, "full", "data_full_suite.pkl"), "rb") as f:
20 | return pickle.load(f)
21 |
22 |
23 | def openai_complete(
24 | prompt,
25 | n,
26 | engine,
27 | temp,
28 | top_p,
29 | max_tokens=32,
30 | stop="\n",
31 | frequency_penalty=0,
32 | logprobs=None,
33 | ):
34 | completion = openai.Completion.create(
35 | engine=engine,
36 | prompt=prompt,
37 | max_tokens=max_tokens,
38 | n=n,
39 | stop=stop,
40 | temperature=temp,
41 | top_p=1 if not top_p else top_p,
42 | frequency_penalty=frequency_penalty,
43 | logprobs=logprobs,
44 | )
45 | return completion.choices
46 | # return [c.text.strip() for c in completion.choices]
47 |
48 |
49 | def gptj_complete(prompt, n, temp, model, tokenizer, top_k, top_p):
50 | """
51 | Parameters:
52 | ===========
53 | prompt: Str
54 | Text to be fed as prompt
55 | n: Int
56 | Number of sentences to be generated
57 | temp: Float
58 | Sampling temperature for GPTJ
59 | model: GPTJ model instance
60 | GPTJ model loaded for inference
61 | tokenizer: GPTJ tokenizer instance
62 |         GPTJ tokenizer loaded (from Huggingface currently)
63 |     top_k: one of False, Int
64 | top k tokens to consider when sampling
65 | top_p: Float
66 | p value for top-p sampling (nucleus sampling)
67 | """
68 | # k is the line where the predicted/generated sample resides
69 | # compensate for the last (incomplete) line "Example {num_seed+1}:"
70 | k = len(prompt.splitlines()) - 1
71 | input_ids = tokenizer(prompt, return_tensors="pt").input_ids
72 | input_ids = input_ids.to("cuda")
73 | stop_token = tokenizer.encode("\n")[0]
74 | sentences = []
75 | while len(sentences) != n:
76 | # generate multiple sentences at a time
77 | # NOTE: .generate already sets torch.no_grad()
78 | gen_tokens = model.generate(
79 | input_ids,
80 | do_sample=True,
81 | max_length=2 * input_ids.shape[1],
82 | temperature=temp,
83 | eos_token_id=stop_token,
84 |             # ~30 return sequences per call is a rough limit on a 32G GPU,
85 |             # though it can still OOM in practice
86 | num_return_sequences=min(n, 30),
87 | # to suppress open-end generation warning
88 | pad_token_id=stop_token,
89 | top_k=0 if not top_k else top_k,
90 | top_p=1 if not top_p else top_p,
91 | )
92 | generations = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
93 | del gen_tokens
94 | # remove the first k lines as they belong to the prompt
95 | for i in range(min(n, 30)):
96 | # intentionally using that space after :
97 | # the model should be predicting that
98 | s = generations[i].splitlines()[k:][0][len(f"Example {k-1}: ") :].strip()
99 | # don't add if empty or if already generated n sentences
100 | if s and len(sentences) < n:
101 | sentences.append(s)
102 | del s
103 | del generations
104 | del input_ids
105 | del model
106 | gc.collect()
107 | torch.cuda.empty_cache()
108 | return [GPTJChoice(s) for s in sentences]
109 |
110 |
111 | def upsample_domain(prompt, n):
112 |     lines = [l.split(": ", 1)[1] for l in prompt.strip().splitlines() if ": " in l]
113 | upsampled = lines * (n // len(lines))
114 | upsampled.extend(lines[: (n % len(lines))])
115 | return upsampled
116 |
117 |
118 | def eda_domain(prompt, n):
119 |     lines = [l.split(": ", 1)[1] for l in prompt.strip().splitlines() if ": " in l]
120 | k = len(lines)
121 | augmented = []
122 | for line in lines:
123 | if not line:
124 | continue
125 | # augment for a line
126 | # NOTE: num_aug in the EDA paper is #new lines per training sample
127 | # using alpha = 0.05 as per recommendation in the paper
128 | generated = eda_utils.eda(
129 | sentence=line,
130 | alpha_sr=0.05,
131 | alpha_ri=0.05,
132 | alpha_rs=0.05,
133 | p_rd=0.05,
134 | num_aug=n // k,
135 | )
136 | augmented.extend(generated)
137 | return augmented
138 |
139 |
140 | def regenerate(input_prompt, n_empty, engine, temp, top_p):
141 | new_lines = []
142 | while n_empty > 0:
143 | print(f"Saw {n_empty} empty line(s). GPT3ing again...")
144 | curr_lines = openai_complete(
145 | prompt=input_prompt,
146 | n=n_empty,
147 | engine=engine,
148 | temp=temp,
149 | top_p=top_p,
150 | )
151 | curr_lines = [r.text.strip() for r in curr_lines]
152 | n_empty = curr_lines.count("")
153 | new_lines.extend([t for t in curr_lines if t])
154 |     # loop exits once no empties remain; also returns [] when n_empty == 0
155 |     return new_lines
156 |
157 |
158 | def augment_domain(
159 | dataset_slices,
160 | val_domain,
161 | data_save_path,
162 | id2name,
163 | num_ex=10,
164 | n_max=128,
165 | engine=None,
166 | temp=None,
167 | model=None,
168 | tokenizer=None,
169 | top_k=False,
170 | top_p=False,
171 | mode="upsample",
172 | mt_dict=None,
173 | ):
174 | """
175 | Augments a given domain in dataset_slices AND updates the pickle file
176 | """
177 | if len(dataset_slices[val_domain]["M"]["train"]["intent"]) == 0:
178 | # no many-domain available for this dataset
179 | data_path = os.path.join(os.path.dirname(data_save_path), "dataset.pkl")
180 | with open(data_path, "rb") as f:
181 | dataset = pickle.load(f)
182 | num_synthetic = int(
183 | np.median(list(Counter(dataset["train"]["intent"]).values()))
184 | )
185 | else:
186 | # when many-domain is available
187 | counts = Counter(dataset_slices[val_domain]["M"]["train"]["intent"])
188 | num_synthetic = int(np.median(list(counts.values())))
189 |
190 | f_train_lines = dataset_slices[val_domain]["F"]["train"]["text"]
191 | f_train_labels = dataset_slices[val_domain]["F"]["train"]["intent"]
192 |
193 | f_gen_lines, f_gen_labels = [], []
194 | for idx in range(0, len(f_train_lines), num_ex):
195 | prompt_lines = f_train_lines[idx : idx + num_ex]
196 |
197 | # simple prompt format (prepend 'Example i: ')
198 | input_prompt = "\n".join(
199 | [f"Example {i+1}: {t}" for i, t in enumerate(prompt_lines)]
200 | )
201 | input_prompt += f"\nExample {num_ex+1}:"
202 |
203 | # prompting with label addition as well
204 | seed_intent = id2name[f"{f_train_labels[idx : idx + num_ex][0]}"]
205 | print(f"Seed intent: {seed_intent}")
206 | input_prompt = (
207 | f"The following sentences belong to the same category '{seed_intent}':\n"
208 | + input_prompt
209 | )
210 |
211 | if mode == "upsample":
212 | print("Upsampling...")
213 | generated_lines = upsample_domain(input_prompt, num_synthetic)
214 | elif mode == "eda":
215 | print("EDAing...")
216 | generated_lines = eda_domain(input_prompt, num_synthetic)
217 | elif mode == "gptj":
218 | engine = "gptj"
219 | print("GPTJing...")
220 | generated_lines = gptj_complete(
221 | prompt=input_prompt,
222 | n=num_synthetic,
223 | temp=temp,
224 | model=model,
225 | tokenizer=tokenizer,
226 | top_k=top_k,
227 | top_p=top_p,
228 | )
229 | generated_lines = [r.text.strip() for r in generated_lines]
230 | else:
231 | print("GPT3ing...")
232 | if num_synthetic <= n_max:
233 | generated_lines = openai_complete(
234 | prompt=input_prompt,
235 | n=num_synthetic,
236 | engine=engine,
237 | temp=temp,
238 | top_p=top_p,
239 | )
240 | generated_lines = [r.text.strip() for r in generated_lines]
241 | else:
242 | generated_lines = []
243 | for _ in range(num_synthetic // n_max):
244 | _c = openai_complete(
245 | prompt=input_prompt,
246 | n=n_max,
247 | engine=engine,
248 | temp=temp,
249 | top_p=top_p,
250 | )
251 | generated_lines.extend([r.text.strip() for r in _c])
252 | # rest of the lines
253 | _c = openai_complete(
254 | prompt=input_prompt,
255 | n=num_synthetic % n_max,
256 | engine=engine,
257 | temp=temp,
258 | top_p=top_p,
259 | )
260 | generated_lines.extend([r.text.strip() for r in _c])
261 |
262 | # sometimes there are empty strings generated by GPT3, try again
263 | n_empty = generated_lines.count("")
264 | if n_empty > 0:
265 | generated_lines = [t for t in generated_lines if t]
266 | if n_empty <= n_max:
267 | generated_lines.extend(
268 | regenerate(
269 | input_prompt,
270 | n_empty,
271 | engine,
272 | temp,
273 | top_p,
274 | )
275 | )
276 | else:
277 | for _ in range(n_empty // n_max):
278 | generated_lines.extend(
279 | regenerate(
280 | input_prompt,
281 | n_max,
282 | engine,
283 | temp,
284 | top_p,
285 | )
286 | )
287 | # rest of the lines
288 | generated_lines.extend(
289 | regenerate(
290 | input_prompt,
291 | n_empty % n_max,
292 | engine,
293 | temp,
294 | top_p,
295 | )
296 | )
297 |
298 |         assert mode == "eda" or len(generated_lines) == num_synthetic
299 |
300 | f_gen_lines.extend(generated_lines)
301 | # using len(generated_lines) to make sure #lines == #labels
302 | # as, for imbalanced datasets, there can be slightly more sentences
303 | # generated by EDA as it augments per sentence.
304 | f_gen_labels.extend([f_train_labels[idx]] * len(generated_lines))
305 |
306 | attr_name = mode if engine is None else f"{engine}_{temp}"
307 |
308 | dataset_slices[val_domain]["F"][attr_name] = {
309 | "text": f_gen_lines,
310 | "intent": f_gen_labels,
311 | }
312 | write_pickle(data_save_path, dataset_slices)
313 |
314 |
315 | def write_pickle(path, data):
316 | with open(path, "wb") as f:
317 | pickle.dump(data, f)
318 |
319 |
320 | def upsample_loop(dataset_slices, domains, data_save_path, id2name):
321 | # Upsample loop
322 | for val_domain in domains:
323 | if "upsample" in dataset_slices[val_domain]["F"].keys():
324 | print(f"upsample for {val_domain} already exists")
325 | continue
326 | print(f"Augmenting for domain: {val_domain}")
327 | augment_domain(
328 | dataset_slices=dataset_slices,
329 | val_domain=val_domain,
330 | data_save_path=data_save_path,
331 | id2name=id2name,
332 | )
333 |
334 |
335 | def eda_loop(dataset_slices, domains, data_save_path, id2name):
336 | """
337 |     Easy Data Augmentation (EDA) baseline by Wei and Zou (EMNLP 2019)
338 | """
339 | # EDA loop:
340 | for val_domain in domains:
341 | if "eda" in dataset_slices[val_domain]["F"].keys():
342 | print(f"eda for {val_domain} already exists")
343 | continue
344 | print(f"Augmenting for domain: {val_domain}")
345 | augment_domain(
346 | dataset_slices=dataset_slices,
347 | val_domain=val_domain,
348 | data_save_path=data_save_path,
349 | id2name=id2name,
350 | mode="eda",
351 | )
352 |
353 |
354 | def load_gptj():
355 | """returns GPTJ and its tokenizer"""
356 | from transformers import GPTJForCausalLM, AutoTokenizer
357 |
358 | print("Loading GPT-J...")
359 | model = GPTJForCausalLM.from_pretrained(
360 | "EleutherAI/gpt-j-6B",
361 | revision="float16",
362 | torch_dtype=torch.float16,
363 | low_cpu_mem_usage=True,
364 | )
365 | model = model.to("cuda")
366 | model.eval()
367 | print("Loaded GPT-J.")
368 |
369 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
370 | return model, tokenizer
371 |
372 |
373 | def gptj_loop(
374 | dataset_slices,
375 | domains,
376 | ds_config,
377 | data_save_path,
378 | id2name,
379 | top_k,
380 | top_p,
381 | ):
382 | # GPTJ loop
383 | model, tokenizer = None, None
384 | # for temp in [1.0]:
385 |     # round the temperatures to avoid floating-point values like 1.200003...
386 | for temp in [round(a, 1) for a in np.linspace(0.5, 2, int((2.1 - 0.5) / 0.1))]:
387 | print(f"Engine: GPT-J | Temp: {temp}")
388 | for val_domain in domains:
389 |         # comment out the following check to *update* existing generations
390 | if f"gptj_{temp}" in dataset_slices[val_domain]["F"].keys():
391 | print(f"gptj_{temp} for {val_domain} already exists")
392 | continue
393 |
394 | if model is None and tokenizer is None:
395 | model, tokenizer = load_gptj()
396 |
397 | print(f"Augmenting for domain: {val_domain}")
398 | augment_domain(
399 | dataset_slices=dataset_slices,
400 | val_domain=val_domain,
401 | data_save_path=data_save_path,
402 | id2name=id2name,
403 | num_ex=ds_config.num_examples,
404 | engine="gptj",
405 | temp=temp,
406 | model=model,
407 | tokenizer=tokenizer,
408 | top_k=top_k,
409 | top_p=top_p,
410 | mode="gptj",
411 | )
412 |
413 |
414 | def gpt3_loop(
415 | dataset_slices,
416 | domains,
417 | ds_config,
418 | data_save_path,
419 | id2name,
420 | top_p,
421 | ):
422 | # GPT3 loop
423 | for engine in ["davinci", "curie", "babbage", "ada"]:
424 | for temp in [1.0]:
425 |             # round the temperatures to avoid floating-point values like 1.200003...
426 | # for temp in [round(a, 1) for a in np.linspace(0.5, 2, int((2.1 - 0.5) / 0.1))]:
427 | print(f"Engine: {engine} | Temp: {temp}")
428 | for val_domain in domains:
429 |             # comment out the following check to *update* existing generations
430 | if f"{engine}_{temp}" in dataset_slices[val_domain]["F"].keys():
431 | print(f"{engine}_{temp} for {val_domain} already exists")
432 | continue
433 | print(f"Augmenting for domain: {val_domain}")
434 | augment_domain(
435 | dataset_slices=dataset_slices,
436 | val_domain=val_domain,
437 | data_save_path=data_save_path,
438 | id2name=id2name,
439 | num_ex=ds_config.num_examples,
440 | n_max=ds_config.gpt3_batch_size,
441 | engine=engine,
442 | temp=temp,
443 | top_p=top_p,
444 | mode="gpt3",
445 | )
446 | # davinci quickly reaches the token/min limit, so we must sleep
447 | if engine == "davinci":
448 | print("sleeping, for openai won't let me GPT3 no more...")
449 | import time
450 |
451 | time.sleep(60)
452 |
453 |
454 | def augment_slices(
455 | data_root,
456 | ds_config,
457 | modes=["upsample", "gptj", "gpt3", "eda"],
458 | top_k=False,
459 | top_p=False,
460 | ):
461 | dataset_slices = load_dataset_slices(data_root, ds_config.data_name)
462 | DOMAINS = ds_config.domain_to_intent.keys()
463 |
464 | data_save_path = pjoin(
465 | data_root,
466 | ds_config.data_name,
467 | "full",
468 | "data_full_suite.pkl",
469 | )
470 | id2name = json.load(open(pjoin(data_root, ds_config.data_name, "id2name.json")))
471 |
472 | for mode in modes:
473 | if mode == "upsample":
474 | upsample_loop(dataset_slices, DOMAINS, data_save_path, id2name)
475 | elif mode == "gptj":
476 | gptj_loop(
477 | dataset_slices,
478 | DOMAINS,
479 | ds_config,
480 | data_save_path,
481 | id2name,
482 | top_k=top_k,
483 | top_p=top_p,
484 | )
485 | elif mode == "gpt3":
486 | if top_k:
487 | print("NOTE: ignoring top_k for gpt3 as openai doesn't support it yet")
488 | gpt3_loop(
489 | dataset_slices,
490 | DOMAINS,
491 | ds_config,
492 | data_save_path,
493 | id2name,
494 | top_p=top_p,
495 | )
496 | elif mode == "eda":
497 | eda_loop(dataset_slices, DOMAINS, data_save_path, id2name)
498 | return dataset_slices
499 |
--------------------------------------------------------------------------------
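
For reference, a minimal sketch of driving augment_slices above (it assumes data_full_suite.pkl and id2name.json have been prepared under ./data/clinc_oos; DS_CONFIG comes from the CLINC utilities below):

    from utils.data_utils.augment_slices import augment_slices
    from utils.data_utils.clinc_utils import DS_CONFIG

    # run only the offline baselines here; the "gpt3"/"gptj" modes additionally
    # need openai.api_key set (and a large GPU for GPT-J)
    slices = augment_slices(
        data_root="./data",
        ds_config=DS_CONFIG(),
        modes=["upsample", "eda"],
    )
    print(slices["banking"]["F"].keys())  # now includes "upsample" and "eda"
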
/utils/data_utils/banking77_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | All banking77 specific utilities are implemented here.
3 | """
4 | from datasets import load_dataset
5 | from sklearn.model_selection import train_test_split
6 | from haven import haven_utils as hu
7 | import os
8 |
9 | INTENTS = [
10 | "activate_my_card",
11 | "age_limit",
12 | "apple_pay_or_google_pay",
13 | "atm_support",
14 | "automatic_top_up",
15 | "balance_not_updated_after_bank_transfer",
16 | "balance_not_updated_after_cheque_or_cash_deposit",
17 | "beneficiary_not_allowed",
18 | "cancel_transfer",
19 | "card_about_to_expire",
20 | "card_acceptance",
21 | "card_arrival",
22 | "card_delivery_estimate",
23 | "card_linking",
24 | "card_not_working",
25 | "card_payment_fee_charged",
26 | "card_payment_not_recognised",
27 | "card_payment_wrong_exchange_rate",
28 | "card_swallowed",
29 | "cash_withdrawal_charge",
30 | "cash_withdrawal_not_recognised",
31 | "change_pin",
32 | "compromised_card",
33 | "contactless_not_working",
34 | "country_support",
35 | "declined_card_payment",
36 | "declined_cash_withdrawal",
37 | "declined_transfer",
38 | "direct_debit_payment_not_recognised",
39 | "disposable_card_limits",
40 | "edit_personal_details",
41 | "exchange_charge",
42 | "exchange_rate",
43 | "exchange_via_app",
44 | "extra_charge_on_statement",
45 | "failed_transfer",
46 | "fiat_currency_support",
47 | "get_disposable_virtual_card",
48 | "get_physical_card",
49 | "getting_spare_card",
50 | "getting_virtual_card",
51 | "lost_or_stolen_card",
52 | "lost_or_stolen_phone",
53 | "order_physical_card",
54 | "passcode_forgotten",
55 | "pending_card_payment",
56 | "pending_cash_withdrawal",
57 | "pending_top_up",
58 | "pending_transfer",
59 | "pin_blocked",
60 | "receiving_money",
61 | "Refund_not_showing_up",
62 | "request_refund",
63 | "reverted_card_payment?",
64 | "supported_cards_and_currencies",
65 | "terminate_account",
66 | "top_up_by_bank_transfer_charge",
67 | "top_up_by_card_charge",
68 | "top_up_by_cash_or_cheque",
69 | "top_up_failed",
70 | "top_up_limits",
71 | "top_up_reverted",
72 | "topping_up_by_card",
73 | "transaction_charged_twice",
74 | "transfer_fee_charged",
75 | "transfer_into_account",
76 | "transfer_not_received_by_recipient",
77 | "transfer_timing",
78 | "unable_to_verify_identity",
79 | "verify_my_identity",
80 | "verify_source_of_funds",
81 | "verify_top_up",
82 | "virtual_card_not_working",
83 | "visa_or_mastercard",
84 | "why_verify_identity",
85 | "wrong_amount_of_cash_received",
86 | "wrong_exchange_rate_for_cash_withdrawal",
87 | ]
88 |
89 |
90 | class Banking77:
91 | def __init__(self, name):
92 | self.data_name = name
93 | self.full_path = f"./data/{name}/full/dataset.pkl"
94 | self.num_examples = 10
95 |
96 | # 1. save dataset.pkl
97 | if not os.path.exists(self.full_path):
98 | self.dataset = load_dataset(name).rename_column("label", "intent")
99 | # get the splits
100 | train, validation = train_test_split(
101 | self.dataset["train"],
102 | train_size=0.9,
103 | stratify=self.dataset["train"]["intent"],
104 | )
105 |
106 | # get data dict
107 | data_dict = {
108 | "train": train,
109 | "val": validation,
110 | "test": self.dataset["test"].to_dict(),
111 | }
112 | hu.save_pkl(self.full_path, data_dict)
113 | else:
114 | data_dict = hu.load_pkl(self.full_path)
115 |
116 | # 2. Group by Intent
117 | name2id = {k: i for i, k in enumerate(INTENTS)}
118 | hu.save_json(f"./data/{name}/name2id.json", name2id)
119 | id2name = {i: k for i, k in enumerate(INTENTS)}
120 | hu.save_json(f"./data/{name}/id2name.json", id2name)
121 |
122 | # intents = list(name2id.keys())
123 |
124 | self.dataset_by_intent = {}
125 | for split in ["train", "val", "test"]:
126 | self.dataset_by_intent[split] = {}
127 | text_intent_dict = data_dict[split]
128 | text_list, intent_list = (
129 | text_intent_dict["text"],
130 | map(lambda x: id2name[x], text_intent_dict["intent"]),
131 | )
132 |
133 | # get texts from intent
134 | intent2texts = {}
135 | for t, i in zip(text_list, intent_list):
136 | if i not in intent2texts:
137 | intent2texts[i] = []
138 | intent2texts[i] += [t]
139 | self.dataset_by_intent[split] = intent2texts
140 |
141 | self.domain_to_intent = {name: INTENTS}
142 | self.gpt3_batch_size: int = 128
143 |
144 | def parse_and_load(self):
145 | return self.dataset_by_intent
146 |
--------------------------------------------------------------------------------
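
A small usage sketch for the Banking77 wrapper above (the "banking77" dataset id on the HuggingFace Hub is an assumption; first instantiation downloads the data and writes ./data/banking77/full/dataset.pkl):

    from utils.data_utils.banking77_utils import Banking77

    b77 = Banking77("banking77")
    grouped = b77.parse_and_load()  # split -> intent name -> list of utterances
    print(len(grouped["train"]["card_arrival"]))
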
/utils/data_utils/clinc_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | All CLINC specific utilities must be implemented here.
3 | It defines a DS_CONFIG for the dataset.
4 | The main function that will be used by any other submodule in the repo is
5 | `parse_and_load_clinc`. In a nutshell, it prepares a `dataset.pkl` file for
6 | CLINC if not already prepared AND returns the dataset grouped by intent names.
7 | Refer to the function's documentation to know more details.
8 | """
9 |
10 | import pickle
11 | import os
12 | import json
13 | import collections
14 | from typing import Dict
15 | from pydantic import BaseModel
16 |
17 |
18 | class DS_CONFIG(BaseModel):
19 | data_name: str = "clinc_oos"
20 | full_path: str = "./data/clinc_oos/full/dataset.pkl"
21 | num_examples: int = 10
22 | # See https://github.com/clinc/oos-eval/blob/master/supplementary.pdf.
23 | domain_to_intent: Dict = {
24 | "banking": [
25 | "transfer",
26 | "transactions",
27 | "balance",
28 | "freeze_account",
29 | "pay_bill",
30 | "bill_balance",
31 | "bill_due",
32 | "interest_rate",
33 | "routing",
34 | "min_payment",
35 | "order_checks",
36 | "pin_change",
37 | "report_fraud",
38 | "account_blocked",
39 | "spending_history",
40 | ],
41 | "credit_card": [
42 | "credit_score",
43 | "report_lost_card",
44 | "credit_limit",
45 | "rewards_balance",
46 | "new_card",
47 | "application_status",
48 | "card_declined",
49 | "international_fees",
50 | "apr",
51 | "redeem_rewards",
52 | "credit_limit_change",
53 | "damaged_card",
54 | "replacement_card_duration",
55 | "improve_credit_score",
56 | "expiration_date",
57 | ],
58 | "dining": [
59 | "recipe",
60 | "restaurant_reviews",
61 | "calories",
62 | "nutrition_info",
63 | "restaurant_suggestion",
64 | "ingredients_list",
65 | "ingredient_substitution",
66 | "cook_time",
67 | "food_last",
68 | "meal_suggestion",
69 | "restaurant_reservation",
70 | "confirm_reservation",
71 | "how_busy",
72 | "cancel_reservation",
73 | "accept_reservations",
74 | ],
75 | "home": [
76 | "shopping_list",
77 | "shopping_list_update",
78 | "next_song",
79 | "play_music",
80 | "update_playlist",
81 | "todo_list",
82 | "todo_list_update",
83 | "calendar",
84 | "calendar_update",
85 | "what_song",
86 | "order",
87 | "order_status",
88 | "reminder",
89 | "reminder_update",
90 | "smart_home",
91 | ],
92 | "auto": [
93 | "traffic",
94 | "directions",
95 | "gas",
96 | "gas_type",
97 | "distance",
98 | "current_location",
99 | "mpg",
100 | "oil_change_when",
101 | "oil_change_how",
102 | "jump_start",
103 | "uber",
104 | "schedule_maintenance",
105 | "last_maintenance",
106 | "tire_pressure",
107 | "tire_change",
108 | ],
109 | "travel": [
110 | "book_flight",
111 | "book_hotel",
112 | "car_rental",
113 | "travel_suggestion",
114 | "travel_alert",
115 | "travel_notification",
116 | "carry_on",
117 | "timezone",
118 | "vaccines",
119 | "translate",
120 | "flight_status",
121 | "international_visa",
122 | "lost_luggage",
123 | "plug_type",
124 | "exchange_rate",
125 | ],
126 | "utility": [
127 | "time",
128 | "alarm",
129 | "share_location",
130 | "find_phone",
131 | "weather",
132 | "text",
133 | "spelling",
134 | "make_call",
135 | "timer",
136 | "date",
137 | "calculator",
138 | "measurement_conversion",
139 | "flip_coin",
140 | "roll_dice",
141 | "definition",
142 | ],
143 | "work": [
144 | "direct_deposit",
145 | "pto_request",
146 | "taxes",
147 | "payday",
148 | "w2",
149 | "pto_balance",
150 | "pto_request_status",
151 | "next_holiday",
152 | "insurance",
153 | "insurance_change",
154 | "schedule_meeting",
155 | "pto_used",
156 | "meeting_schedule",
157 | "rollover_401k",
158 | "income",
159 | ],
160 | "small_talk": [
161 | "greeting",
162 | "goodbye",
163 | "tell_joke",
164 | "where_are_you_from",
165 | "how_old_are_you",
166 | "what_is_your_name",
167 | "who_made_you",
168 | "thank_you",
169 | "what_can_i_ask_you",
170 | "what_are_your_hobbies",
171 | "do_you_have_pets",
172 | "are_you_a_bot",
173 | "meaning_of_life",
174 | "who_do_you_work_for",
175 | "fun_fact",
176 | ],
177 | "meta": [
178 | "change_ai_name",
179 | "change_user_name",
180 | "cancel",
181 | "user_name",
182 | "reset_settings",
183 | "whisper_mode",
184 | "repeat",
185 | "no",
186 | "yes",
187 | "maybe",
188 | "change_language",
189 | "change_accent",
190 | "change_volume",
191 | "change_speed",
192 | "sync_device",
193 | ],
194 | "oos": ["oos"],
195 | }
196 | gpt3_batch_size: int = 128
197 |
198 |
199 | DOMAIN_TO_INTENT = DS_CONFIG().domain_to_intent
200 | CLINC_FULL_PATH = "./data/clinc_oos/full/data_full.json"
201 |
202 |
203 | def make_label_maps():
204 | """
205 | returns a mapping of intent name to intent id and vice versa
206 | """
207 | intent_names = []
208 | for intent_list in DOMAIN_TO_INTENT.values():
209 | if intent_list == ["oos"]:
210 | continue
211 | intent_names.extend(intent_list)
212 |
213 | name2id, id2name = {}, {}
214 | for idx, name in enumerate(sorted(set(intent_names))):
215 | name2id[name] = idx
216 | id2name[idx] = name
217 |
218 | # explicitly assign oos an id of 150
219 | name2id["oos"] = 150
220 | id2name[150] = "oos"
221 |
222 | # storing in JSON for readability later on
223 | # NOTE: keys in id2name.json will be str and not int!
224 | with open("./data/clinc_oos/name2id.json", "w") as f:
225 | json.dump(name2id, f)
226 | with open("./data/clinc_oos/id2name.json", "w") as f:
227 | json.dump(id2name, f)
228 | return name2id, id2name
229 |
230 |
231 | def download_clinc_full():
232 | download_dir = "./data/clinc_oos/full/"
233 | # create dir if doesn't exist
234 | if not os.path.exists(download_dir):
235 | os.makedirs(download_dir)
236 | clinc_data_url = (
237 | "https://raw.githubusercontent.com/clinc/oos-eval/master/data/data_full.json"
238 | )
239 | os.system(f"wget {clinc_data_url} -P {download_dir}")
240 | return json.load(open(os.path.join(download_dir, "data_full.json")))
241 |
242 |
243 | def load_or_download_clinc():
244 | try:
245 | full_data = json.load(open(CLINC_FULL_PATH))
246 | except FileNotFoundError:
247 | full_data = download_clinc_full()
248 | return full_data
249 |
250 |
251 | def get_label_maps():
252 | try:
253 | intentname2id = json.load(open("./data/clinc_oos/name2id.json"))
254 | intentid2name = json.load(open("./data/clinc_oos/id2name.json"))
255 | except FileNotFoundError:
256 | intentname2id, intentid2name = make_label_maps()
257 | return intentname2id, intentid2name
258 |
259 |
260 | def prepare_clinc():
261 | full_data = load_or_download_clinc()
262 | intentname2id, intentid2name = get_label_maps()
263 |
264 | data_dict = {
265 | "train": {"text": [], "intent": []},
266 | "val": {"text": [], "intent": []},
267 | "test": {"text": [], "intent": []},
268 | }
269 |
270 | for key, data in full_data.items():
271 | for sample in data:
272 | line, label = sample
273 | label = intentname2id[label]
274 | # maybe there's a better way to prepare this...
275 | if "train" in key:
276 | data_dict["train"]["text"].append(line)
277 | data_dict["train"]["intent"].append(label)
278 | elif "val" in key:
279 | data_dict["val"]["text"].append(line)
280 | data_dict["val"]["intent"].append(label)
281 | else:
282 | data_dict["test"]["text"].append(line)
283 | data_dict["test"]["intent"].append(label)
284 |
285 | print("Data details:")
286 | print(
287 | f'Train #lines: {len(data_dict["train"]["text"])} #labels: {len(data_dict["train"]["intent"])}'
288 | )
289 | print(
290 | f'Val #lines: {len(data_dict["val"]["text"])} #labels: {len(data_dict["val"]["intent"])}'
291 | )
292 | print(
293 | f'Test #lines: {len(data_dict["test"]["text"])} #labels: {len(data_dict["test"]["intent"])}'
294 | )
295 |
296 | with open("./data/clinc_oos/full/dataset.pkl", "wb") as f:
297 | pickle.dump(data_dict, f)
298 | print("Base dataset.pkl prepared for CLINC!")
299 |
300 |
301 | def load_clinc():
302 | data = collections.defaultdict(lambda: collections.defaultdict(list))
303 | with open(CLINC_FULL_PATH, "r") as data_file:
304 | for split_name, split_data in json.load(data_file).items():
305 | for query, intent in split_data:
306 | data[split_name][intent].append(query)
307 | return data
308 |
309 |
310 | def parse_and_load_clinc():
311 | """
312 |     This function has two primary roles:
313 | 1) create `dataset.pkl` for CLINC if it doesn't exist already
314 | 2) return the CLINC dataset grouped by intent names
315 |
316 | Secondary role:
317 | This function also creates name2id.json and id2name.json files for the
318 | dataset in the respective dataset folder (e.g. ./data/clinc_oos/ for
319 | CLINC. Note that the name of the dataset folder matches
320 | DS_CONFIG.data_name). If the dataset contains an OOS class, make it the
321 |     last class for that dataset (to ease the execution of non-OOS experiments)
322 |
323 | name2id.json is a Dict[Str: Int] and id2name.json is a Dict[Str: Str]
324 |
325 | parse_and_load_clinc() is the only function that will interact with the
326 | outside world, and any dataset specific utility required to accomplish the
327 | described roles above (like downloading, parsing, etc.) may be implemented
328 | in this file.
329 |
330 | NOTE that we store intent ids in dataset.pkl, but the grouped dataset
331 | returned by the function stores intent names!
332 |
333 | `dataset.pkl` contains a Dict object with the following schema:
334 |
335 | {
336 | 'train': {'text': listofStr, 'intent': listofInt},
337 | 'val': {'text': listofStr, 'intent': listofInt},
338 | 'test': {'text': listofStr, 'intent': listofInt}
339 | }
340 |
341 | Format of the returned dataset grouped by intent names:
342 |
343 | collections.defaultdict(None, {
344 | 'train': {
345 | Str: listofStr,
346 | Str: listofStr,
347 | .
348 | .
349 | .
350 | Str: listofStr # n_intents
351 | },
352 |
353 | 'val': {
354 | Str: listofStr,
355 | Str: listofStr,
356 | .
357 | .
358 | .
359 | Str: listofStr # n_intents
360 | },
361 |
362 | 'test': {
363 | Str: listofStr,
364 | Str: listofStr,
365 | .
366 | .
367 | .
368 | Str: listofStr # n_intents
369 | }
370 | })
371 |
372 | parse_and_load_clinc: None -> collections.defaultdict
373 | """
374 | if not os.path.exists(os.path.join("./data/clinc_oos/full/dataset.pkl")):
375 | prepare_clinc()
376 | return load_clinc()
377 |
--------------------------------------------------------------------------------
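
A usage sketch for the loader above (the first call downloads data_full.json with wget and writes dataset.pkl plus the label maps):

    from utils.data_utils.clinc_utils import parse_and_load_clinc

    clinc = parse_and_load_clinc()
    # split -> intent name -> list of utterances
    print(len(clinc["train"]["transfer"]))
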
/utils/data_utils/data_loader.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | from utils import main_data_utils as mdu
3 | from datasets import load_dataset, Dataset, DatasetDict
4 | import os, warnings, regex, numpy as np, collections, math, time
5 |
6 | from utils.data_utils.augment_slices import openai_complete
7 |
8 |
9 | def _normalize(probs):
10 | """
11 | Given label probs Dict[str: list]
12 |     first, normalizes probabilities of tokens predicted multiple times
13 |     then, normalizes across all the predicted labels
14 |     """
15 |     # NOTE: not using an assertion because davinci and curie give different probs
16 | # for the same prediction sometimes
17 | # for k, v in probs.items():
18 | # prob should be the same for all the multiple predictions of the label
19 | # assert len(set(v)) == 1
20 | # probs[k] = v[0]
21 | probs = {k: np.mean(v) for k, v in probs.items()}
22 | return {k: v / sum(probs.values()) for k, v in probs.items()}
23 |
24 |
25 | def gpt3mix_complete(prompt, n, labels_list, exp_dict, name2id):
26 | """
27 | Given a seed_text and its corresponding seed_intent (name, not id),
28 |
29 | 1. generate x_dash (n augmentations per seed_text)
30 | 2. generate y_dash (soft label using name2id)
31 | """
32 | pattern = regex.compile(rf"(?r)Sentence: (.*)\(intent: (.*)\)")
33 | # gpt prompting to generate x_dashes
34 | completions = openai_complete(
35 | engine=exp_dict["gpt3_engine"],
36 | prompt=prompt,
37 | temp=exp_dict["gpt3_temp"],
38 | top_p=1.0,
39 | n=n,
40 | stop="\n",
41 | max_tokens=50,
42 | frequency_penalty=0.02,
43 | )
44 |
45 | augmentations = {"text": [], "intent": []}
46 | for c in completions:
47 | match = pattern.search("Sentence:" + c.text)
48 | if match is None: # invalid prediction
49 | continue
50 | _txt = match.group(1).strip().lower()
51 | if not _txt:
52 | continue
53 |
54 | # prompt GPT3 again to create soft label
55 | label_completions = openai_complete(
56 | engine=exp_dict["gpt3_engine"],
57 | prompt=prompt + f" {_txt}. (intent:",
58 | temp=exp_dict["gpt3_temp"],
59 | n=100,
60 | top_p=1.0,
61 | max_tokens=20,
62 | stop="\n",
63 | logprobs=1,
64 | )
65 |
66 | # construct probabilities for all the predicted labels
67 | label_probs = collections.defaultdict(list)
68 | for _lc in label_completions:
69 | _log_probs = _lc.logprobs
70 | _match = pattern.search(f"Sentence: {_txt}. (intent:" + _lc.text)
71 | if _match is None: # incomplete prediction
72 | continue
73 |
74 | _pred = _match.group(2).strip().lower()
75 | if _pred not in labels_list: # invalid label
76 | continue
77 |
78 | # NOTE: we are looking at token_logprobs v/s top_logprobs used
79 | # by the GPT3Mix paper because we are sampling to compute
80 | # p(y_dash| x_dash) as opposed to looking at logprobs of top 100
81 | # most likely tokens. default value limits us to just 5 now.
82 | _curr_log_prob = 0
83 | for t, p in zip(_log_probs["tokens"], _log_probs["token_logprobs"]):
84 | # if the code reaches here, ) is guaranteed to be present
85 | # as regex check earlier would trigger a `continue` otherwise
86 | if t == ")":
87 | label_probs[_pred].append(math.exp(_curr_log_prob))
88 | break
89 |
90 | # add logprobs (multiply probs) for sub words of _pred as
91 | # class names are not single tokens
92 | _curr_log_prob += p
93 |
94 | # normalize label_probs
95 | label_probs = _normalize(label_probs)
96 | # create soft label
97 | soft_label = [0] * exp_dict["dataset"]["num_labels"]
98 | for k, v in label_probs.items():
99 | soft_label[name2id[k]] = v
100 |
101 | augmentations["text"].append(_txt)
102 | augmentations["intent"].append(soft_label)
103 | return augmentations
104 |
105 |
106 | def generate_for_gpt3mix(base_ds, ex2_ds, exp_dict, interim_save_path):
107 | num_labels = exp_dict["dataset"]["num_labels"]
108 | ds_name = exp_dict["dataset"]["name"]
109 | id2name = mdu.read_json(f"data/{ds_name}/id2name.json")
110 | name2id = mdu.read_json(f"data/{ds_name}/name2id.json")
111 | ds_config = mdu.get_ds_config(ds_name)
112 | k = ds_config.num_examples
113 |
114 | labels_list = list(name2id.keys())
115 | if "oos" in labels_list:
116 | labels_list.remove("oos")
117 |
118 | train_lines, train_labels = [], []
119 | if os.path.exists(interim_save_path):
120 | interim_copy = mdu.read_pickle(interim_save_path)
121 | else:
122 | interim_copy = {}
123 |
124 | for domain in ex2_ds:
125 | if domain in interim_copy:
126 | print(f"Domain: {domain} already GPT3Mix augmented. Moving on...")
127 | continue
128 |
129 | print(f"Augmenting domain: {domain}")
130 | texts = ex2_ds[domain]["F"]["train"]["text"]
131 | hard_labels = ex2_ds[domain]["F"]["train"]["intent"]
132 | _lines, _labels = [], []
133 | # NOTE: this loop will never be executed for oos--both lists will be []
134 | for text, intent in tqdm(zip(texts, hard_labels), total=len(texts)):
135 | # add gold example to training set
136 | one_hot = [0.0] * num_labels
137 | one_hot[intent] = 1.0
138 | # for interim copy
139 | _lines.append(text)
140 | _labels.append(one_hot)
141 |
142 | # construct prompt header
143 | prompt = "Each item in the following list contains a sentence and the respective intent."
144 | label_enum_str = [f"'{l.lower()}'" for l in labels_list]
145 | prompt += f" Intent is one of {', or '.join(label_enum_str)}"
146 | prompt += ".\n"
147 | prompt += f"Sentence: {text}. (intent: {id2name[str(intent)]})\n"
148 |
149 | # remove current intent from candidates to sample from
150 | _lbl_list = [l for l in labels_list if l != id2name[str(intent)]]
151 | # sample k-1 random intents from the label_set (k=9)
152 | other_lbls = np.random.choice(_lbl_list, k - 1, replace=False)
153 | # fetch a sample for each of these new intents and add to the prompt
154 | for lbl in other_lbls:
155 | # find the domain of lbl
156 | _domain, _domain_found = None, False
157 | for _d, _i_l in ds_config.domain_to_intent.items():
158 | if not _domain_found and lbl in _i_l:
159 | _domain_found = True
160 | _domain = _d
161 |
162 | gt_txts = ex2_ds[_domain]["F"]["train"]["text"]
163 | gt_lbls = ex2_ds[_domain]["F"]["train"]["intent"]
164 | _start = gt_lbls.index(name2id[lbl])
165 | # select a random sentence for lbl
166 | _text = np.random.choice(gt_txts[_start : _start + k], 1)[0]
167 | # add the _text, lbl pair to prompt
168 | prompt += f"Sentence: {_text}. (intent: {lbl})\n"
169 | prompt += "Sentence:"
170 |
171 | # generated examples with soft labels
172 | augs = gpt3mix_complete(prompt, 10, labels_list, exp_dict, name2id)
173 | _lines.extend(augs["text"])
174 | _labels.extend(augs["intent"])
175 |
176 | train_lines.extend(_lines)
177 | train_labels.extend(_labels)
178 |
179 | # save an interim copy now
180 | interim_copy[domain] = {"text": _lines, "intent": _labels}
181 | mdu.write_pickle(interim_copy, interim_save_path)
182 |
183 | print("Sleeping...for a minute")
184 | time.sleep(60)
185 |
186 | # Add OOS samples
187 | oos_texts, oos_labels = extract_oos(base_ds["train"], exp_dict["dataset"]["oos_id"])
188 | for text, intent in tqdm(zip(oos_texts, oos_labels), total=len(oos_texts)):
189 | # add gold example to training set
190 | one_hot = [0.0] * num_labels
191 | one_hot[intent] = 1.0
192 | train_lines.append(text)
193 | train_labels.append(one_hot)
194 |
195 | # delete interim copy
196 | del interim_copy
197 | return {"text": train_lines, "intent": train_labels}
198 |
199 |
200 | def prepare_for_seq2seq(dataset, id2name_path):
201 | """
202 | dataset: Dict[str]:
203 | """
204 | id2name = mdu.read_json(id2name_path)
205 | return {
206 | "text": [t + " " for t in dataset["text"]],
207 | # intents are class ids here, not names
208 | "intent": [id2name[str(i)] + " " for i in dataset["intent"]],
209 | }
210 |
211 |
212 | def filter_oos(data_dict, oos_id, soft_label=False):
213 | """Removes oos samples from the data dict"""
214 | lines, labels = data_dict["text"], data_dict["intent"]
215 | # some datasets (like SNIPS) don't have an OOS class
216 | if oos_id is None:
217 | return lines, labels
218 | _lines, _labels = [], []
219 | for idx, intent_id in enumerate(labels):
220 | if soft_label and np.array(intent_id).argmax(-1) == oos_id:
221 | continue
222 | if not soft_label and intent_id == oos_id:
223 | continue
224 | _lines.append(lines[idx])
225 | _labels.append(labels[idx])
226 | # print(len(_lines), len(_labels))
227 | return _lines, _labels
228 |
229 |
230 | def extract_oos(data_dict, oos_id):
231 | """Extract the OOS samples from the data dict. It is the
232 | opposite of filter_oos"""
233 | lines, labels = data_dict["text"], data_dict["intent"]
234 | # some datasets (like SNIPS) don't have an OOS class
235 | _lines, _labels = [], []
236 | for idx, intent_id in enumerate(labels):
237 | if intent_id != oos_id:
238 | continue
239 | _lines.append(lines[idx])
240 | _labels.append(labels[idx])
241 | return _lines, _labels
242 |
243 |
244 | class DatasetLoader:
245 | """
246 | Available datasets:
247 | - Clinc original: We can define whether to get the `full` version or the `small` version.
248 | - Pure Fewshot Clinc:
249 |         baseline: Contains 10 examples per class (except the OOS), randomly sampled from the original full clinc.
250 |
251 | """
252 |
253 | def __init__(self, data_root, exp_dict):
254 | dataset_config = exp_dict["dataset"]["config"]
255 |
256 | var_path = "full" if dataset_config.startswith("f") else "small"
257 | ds_name = exp_dict["dataset"]["name"]
258 | basic_data_path = os.path.join(data_root, ds_name, var_path, "dataset.pkl")
259 | ex2_data_path = os.path.join(
260 | data_root, ds_name, var_path, "data_full_suite.pkl"
261 | )
262 |
263 | if dataset_config == "few_pure":
264 | base_ds = mdu.read_pickle(basic_data_path)
265 | data_set = mdu.read_pickle(ex2_data_path)
266 | oos_id = exp_dict["dataset"]["oos_id"]
267 | train_lines, train_labels = [], []
268 | if exp_dict["exp_type"] == "baseline":
269 | print("Loading dataset for full few-shot baseline")
270 | for domain in data_set:
271 | train_lines.extend(data_set[domain]["F"]["train"]["text"])
272 | train_labels.extend(data_set[domain]["F"]["train"]["intent"])
273 | elif exp_dict["exp_type"] in ["eda"]:
274 | exp_type = exp_dict["exp_type"]
275 | print(f"Loading dataset for full few-shot {exp_type.upper()}")
276 | # lump in EDA examples with all few-shot samples
277 | for domain in data_set:
278 | train_lines.extend(
279 | data_set[domain]["F"]["train"]["text"]
280 | + data_set[domain]["F"][exp_type]["text"]
281 | )
282 | train_labels.extend(
283 | data_set[domain]["F"]["train"]["intent"]
284 | + data_set[domain]["F"][exp_type]["intent"]
285 | )
286 |             elif exp_dict["exp_type"] in ["gpt3", "eda"]:  # NOTE: "eda" never reaches here (handled above)
287 | print(f"Loading dataset for full few-shot {exp_dict['exp_type']}")
288 | # set correct attribute to fetch from the dataset
289 | if exp_dict["exp_type"] == "gpt3":
290 | engine, temp = exp_dict["gpt3_engine"], exp_dict["gpt3_temp"]
291 | attr = f"{engine}_{temp}"
292 | else: # eda
293 | attr = exp_dict["exp_type"]
294 |
295 | # lump in the fetched examples with all few-shot samples
296 | for domain in data_set:
297 | train_lines.extend(
298 | data_set[domain]["F"]["train"]["text"]
299 | + data_set[domain]["F"][attr]["text"]
300 | )
301 | train_labels.extend(
302 | data_set[domain]["F"]["train"]["intent"]
303 | + data_set[domain]["F"][attr]["intent"]
304 | )
305 | elif exp_dict["exp_type"] in [
306 | "gpt3_oracle",
307 | "eda_oracle",
308 | "gpt3mix_oracle",
309 | ]:
310 | # the few shot sentences are taken from the ex2 setup data
311 | # and the relabeled samples are taken from al_dataset.pkl
312 | print(f"Loading dataset for full few-shot {exp_dict['exp_type']}")
313 | # Use relabeled dataset as base
314 | al_path = os.path.join(data_root, ds_name, "full", "al_dataset.pkl")
315 | al_ds = mdu.read_pickle(al_path)
316 |
317 | # set correct attribute to fetch from the dataset
318 | if exp_dict["exp_type"] == "gpt3_oracle":
319 | engine, temp = exp_dict["gpt3_engine"], exp_dict["gpt3_temp"]
320 | attr = f"{engine}_{temp}"
321 | elif exp_dict["exp_type"] == "gpt3mix_oracle":
322 | attr = f"gpt3mix_{exp_dict['gpt3_engine']}"
323 | else: # eda_oracle
324 | attr = exp_dict["exp_type"].split("_")[0] # just eda
325 |
326 | for domain in data_set:
327 | train_lines.extend(data_set[domain]["F"]["train"]["text"])
328 | train_labels.extend(data_set[domain]["F"]["train"]["intent"])
329 | train_lines.extend(al_ds["generated"][attr]["text"])
330 | train_labels.extend(al_ds["generated"][attr]["intent"])
331 |
332 | elif exp_dict["exp_type"] == "gpt3mix":
333 | print("Loading labelled pool for full few-shot gpt3mix")
334 | engine = exp_dict["gpt3_engine"]
335 | gpt3mix_path = f"data/{ds_name}/full/gpt3mix_{engine}.pkl"
336 |
337 | # augs will also contain the seed samples
338 | if os.path.exists(gpt3mix_path): # load from existing pkl
339 | print(f"Loading existing GPT3Mix data for {engine.upper()}")
340 | augs = mdu.read_pickle(gpt3mix_path)
341 | else: # otherwise, generate gpt3mix pickle
342 | print(f"Generating GPT3Mix data with {engine.upper()}")
343 | interim_save_path = gpt3mix_path[:-4] + "_interim.pkl"
344 | augs = generate_for_gpt3mix(
345 | base_ds, data_set, exp_dict, interim_save_path
346 | )
347 | # save complete augmented data
348 | mdu.write_pickle(augs, gpt3mix_path)
349 |
350 | train_lines, train_labels = augs["text"], augs["intent"]
351 |
352 | val_lines, val_labels = base_ds["val"]["text"], base_ds["val"]["intent"]
353 | test_lines, test_labels = (
354 | base_ds["test"]["text"],
355 | base_ds["test"]["intent"],
356 | )
357 |
358 | # add oos samples to train set (gpt3mix setting already adds)
359 | if oos_id is not None and exp_dict["exp_type"] != "gpt3mix":
360 | # add oos samples to the dataset
361 | oos_lines, oos_labels = extract_oos(base_ds["train"], oos_id)
362 | train_lines.extend(oos_lines)
363 | train_labels.extend(oos_labels)
364 |
365 | # remove oos samples appropriately
366 | if oos_id is None:
367 | name2id_path = os.path.join(data_root, ds_name, "name2id.json")
368 | temp_oos_id = mdu.read_json(name2id_path).get("oos", None)
369 | if exp_dict["exp_type"] == "gpt3mix":
370 | train_set = {"text": train_lines, "intent": train_labels}
371 |                 # remove oos samples added to the train set by default
372 |                 # (oos_id is None here, so use the dataset's real oos id)
373 |                 train_lines, train_labels = filter_oos(
374 |                     train_set, temp_oos_id, soft_label=True)
375 | # remove oos samples from the val set added by default
376 | val_lines, val_labels = filter_oos(base_ds["val"], temp_oos_id)
377 | test_lines, test_labels = filter_oos(base_ds["test"], temp_oos_id)
378 |
379 | print(len(train_lines), len(train_labels))
380 | self.dataset = DatasetDict(
381 | train=Dataset.from_dict({"text": train_lines, "intent": train_labels}),
382 | validation=Dataset.from_dict({"text": val_lines, "intent": val_labels}),
383 | test=Dataset.from_dict({"text": test_lines, "intent": test_labels}),
384 | )
385 | elif dataset_config == "full":
386 | # read the original FULL version of the dataset
387 | data_set = mdu.read_pickle(basic_data_path)
388 |
389 | if exp_dict["exp_type"] == "intrinsic":
390 | print("Loading utils for intrinsic evaluation")
391 |
392 | oos_id = exp_dict["dataset"]["oos_id"]
393 | train_lines, train_labels = filter_oos(data_set["train"], oos_id)
394 | val_lines, val_labels = filter_oos(data_set["val"], oos_id)
395 | test_lines, test_labels = filter_oos(data_set["test"], oos_id)
396 |
397 | self.dataset = DatasetDict(
398 | train=Dataset.from_dict(
399 | {"text": train_lines, "intent": train_labels}
400 | ),
401 | validation=Dataset.from_dict(
402 | {"text": val_lines, "intent": val_labels}
403 | ),
404 | test=Dataset.from_dict({"text": test_lines, "intent": test_labels}),
405 | )
406 | # add different set of generated lines as test set
407 |                 augmented_data = mdu.read_pickle(ex2_data_path)
408 | domains = list(augmented_data.keys())
409 | for e in ["ada", "babbage", "curie", "davinci", "gptj"]:
410 | # for t in np.linspace(0.5, 2, int((2.1-.5)/.1)):
411 | for t in [1.0]:
412 | _lines, _intents = [], []
413 | for d in domains:
414 | if d == "oos":
415 | continue
416 | _lines.extend(augmented_data[d]["F"][f"{e}_{t}"]["text"])
417 | _intents.extend(
418 | augmented_data[d]["F"][f"{e}_{t}"]["intent"]
419 | )
420 | self.dataset[f"{e}_{t}"] = Dataset.from_dict(
421 | {"text": _lines, "intent": _intents}
422 | )
423 | elif exp_dict["exp_type"] == "baseline":
424 | print("Loading utils for baseline version")
425 | self.dataset = DatasetDict(
426 | train=Dataset.from_dict(data_set["train"]),
427 | validation=Dataset.from_dict(data_set["val"]),
428 | test=Dataset.from_dict(data_set["test"]),
429 | )
430 |
431 | elif dataset_config.startswith("full_"):
432 | print(f"Loading utils for {dataset_config}")
433 | # read the augmented version of the dataset
434 | data_set = mdu.read_pickle(ex2_data_path)
435 | # the few-shot domain
436 | val_domain = dataset_config.split("_", 1)[1]
437 | # train set = D_{M, train} + D_{F, train}
438 | train_lines = (
439 | data_set[val_domain]["M"]["train"]["text"]
440 | + data_set[val_domain]["F"]["train"]["text"]
441 | )
442 |
443 | train_labels = (
444 | data_set[val_domain]["M"]["train"]["intent"]
445 | + data_set[val_domain]["F"]["train"]["intent"]
446 | )
447 |
448 | if exp_dict["exp_type"] == "upsample":
449 | train_lines.extend(data_set[val_domain]["F"]["upsample"]["text"])
450 | train_labels.extend(data_set[val_domain]["F"]["upsample"]["intent"])
451 | elif exp_dict["exp_type"] == "gpt3":
452 | engine = exp_dict["gpt3_engine"]
453 | temp = exp_dict["gpt3_temp"]
454 | train_lines.extend(
455 | data_set[val_domain]["F"][f"{engine}_{temp}"]["text"]
456 | )
457 | train_labels.extend(
458 | data_set[val_domain]["F"][f"{engine}_{temp}"]["intent"]
459 | )
460 |
461 | full_val_lines = (
462 | data_set[val_domain]["M"]["val"]["text"]
463 | + data_set[val_domain]["F"]["val"]["text"]
464 | )
465 |
466 | full_val_labels = (
467 | data_set[val_domain]["M"]["val"]["intent"]
468 | + data_set[val_domain]["F"]["val"]["intent"]
469 | )
470 |
471 | full_test_lines = (
472 | data_set[val_domain]["M"]["test"]["text"]
473 | + data_set[val_domain]["F"]["test"]["text"]
474 | )
475 |
476 | full_test_labels = (
477 | data_set[val_domain]["M"]["test"]["intent"]
478 | + data_set[val_domain]["F"]["test"]["intent"]
479 | )
480 |
481 | # add oos samples to the dataset for oos-aware classifiers
482 | if exp_dict["dataset"]["oos_id"] is not None:
483 | print("adding OOS samples to the dataset")
484 |                 base_ds = mdu.read_pickle(basic_data_path)
485 | oos_id = exp_dict["dataset"]["oos_id"]
486 |
487 | # augment training set
488 | oos_train_lines, oos_train_labels = extract_oos(
489 | base_ds["train"], oos_id
490 | )
491 | train_lines.extend(oos_train_lines)
492 | train_labels.extend(oos_train_labels)
493 |
494 | # augment validation set
495 | oos_val_lines, oos_val_labels = extract_oos(base_ds["val"], oos_id)
496 | full_val_lines.extend(oos_val_lines)
497 | full_val_labels.extend(oos_val_labels)
498 |
499 | # augment test set
500 | oos_test_lines, oos_test_labels = extract_oos(base_ds["test"], oos_id)
501 | full_test_lines.extend(oos_test_lines)
502 | full_test_labels.extend(oos_test_labels)
503 |
504 | self.dataset = DatasetDict(
505 | train=Dataset.from_dict({"text": train_lines, "intent": train_labels}),
506 | validation=Dataset.from_dict(data_set[val_domain]["F"]["val"]),
507 | test=Dataset.from_dict(data_set[val_domain]["F"]["test"]),
508 | full_test=Dataset.from_dict(
509 | {"text": full_test_lines, "intent": full_test_labels}
510 | ),
511 | full_validation=Dataset.from_dict(
512 | {"text": full_val_lines, "intent": full_val_labels}
513 | ),
514 | )
515 | else:
516 | warnings.warn("At the moment we can only load clinc_oos")
517 | self.dataset = load_dataset(ds_name, dataset_config, cache_dir=data_root)
518 |
519 | def get_split(self, split):
520 | return self.dataset[split]
521 |
--------------------------------------------------------------------------------
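
A sketch of instantiating DatasetLoader above for the full few-shot baseline; the exp_dict fields mirror the keys the class reads, and the concrete values (oos id 150 for CLINC, as assigned in clinc_utils) are illustrative:

    from utils.data_utils.data_loader import DatasetLoader

    exp_dict = {
        "dataset": {"name": "clinc_oos", "config": "few_pure", "oos_id": 150},
        "exp_type": "baseline",
    }
    loader = DatasetLoader(data_root="./data", exp_dict=exp_dict)
    print(loader.get_split("train"))  # a datasets.Dataset with text/intent columns
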
/utils/data_utils/eda_utils.py:
--------------------------------------------------------------------------------
1 | # Easy data augmentation techniques for text classification
2 | # Jason Wei and Kai Zou from
3 | # https://github.com/jasonwei20/eda_nlp/blob/04ab29c5b18d2d72f9fa5b304322aaf4793acea0/code/eda.py
4 |
5 | import random
6 | from random import shuffle
7 |
8 | random.seed(1)
9 |
10 | # stop words list
11 | stop_words = [
12 | "i",
13 | "me",
14 | "my",
15 | "myself",
16 | "we",
17 | "our",
18 | "ours",
19 | "ourselves",
20 | "you",
21 | "your",
22 | "yours",
23 | "yourself",
24 | "yourselves",
25 | "he",
26 | "him",
27 | "his",
28 | "himself",
29 | "she",
30 | "her",
31 | "hers",
32 | "herself",
33 | "it",
34 | "its",
35 | "itself",
36 | "they",
37 | "them",
38 | "their",
39 | "theirs",
40 | "themselves",
41 | "what",
42 | "which",
43 | "who",
44 | "whom",
45 | "this",
46 | "that",
47 | "these",
48 | "those",
49 | "am",
50 | "is",
51 | "are",
52 | "was",
53 | "were",
54 | "be",
55 | "been",
56 | "being",
57 | "have",
58 | "has",
59 | "had",
60 | "having",
61 | "do",
62 | "does",
63 | "did",
64 | "doing",
65 | "a",
66 | "an",
67 | "the",
68 | "and",
69 | "but",
70 | "if",
71 | "or",
72 | "because",
73 | "as",
74 | "until",
75 | "while",
76 | "of",
77 | "at",
78 | "by",
79 | "for",
80 | "with",
81 | "about",
82 | "against",
83 | "between",
84 | "into",
85 | "through",
86 | "during",
87 | "before",
88 | "after",
89 | "above",
90 | "below",
91 | "to",
92 | "from",
93 | "up",
94 | "down",
95 | "in",
96 | "out",
97 | "on",
98 | "off",
99 | "over",
100 | "under",
101 | "again",
102 | "further",
103 | "then",
104 | "once",
105 | "here",
106 | "there",
107 | "when",
108 | "where",
109 | "why",
110 | "how",
111 | "all",
112 | "any",
113 | "both",
114 | "each",
115 | "few",
116 | "more",
117 | "most",
118 | "other",
119 | "some",
120 | "such",
121 | "no",
122 | "nor",
123 | "not",
124 | "only",
125 | "own",
126 | "same",
127 | "so",
128 | "than",
129 | "too",
130 | "very",
131 | "s",
132 | "t",
133 | "can",
134 | "will",
135 | "just",
136 | "don",
137 | "should",
138 | "now",
139 | "",
140 | ]
141 |
142 | # cleaning up text
143 | import re
144 |
145 |
146 | def get_only_chars(line):
147 |
148 | clean_line = ""
149 |
150 | line = line.replace("’", "")
151 | line = line.replace("'", "")
152 | line = line.replace("-", " ") # replace hyphens with spaces
153 | line = line.replace("\t", " ")
154 | line = line.replace("\n", " ")
155 | line = line.lower()
156 |
157 | for char in line:
158 | if char in "qwertyuiopasdfghjklzxcvbnm ":
159 | clean_line += char
160 | else:
161 | clean_line += " "
162 |
163 | clean_line = re.sub(" +", " ", clean_line) # delete extra spaces
164 |     if clean_line.startswith(" "):  # also safe when clean_line is empty
165 | clean_line = clean_line[1:]
166 | return clean_line
167 |
168 |
169 | ########################################################################
170 | # Synonym replacement
171 | # Replace n words in the sentence with synonyms from wordnet
172 | ########################################################################
173 |
174 |
175 | # import nltk
176 | # nltk.download('wordnet')
177 | from nltk.corpus import wordnet
178 |
179 |
180 | def synonym_replacement(words, n):
181 | new_words = words.copy()
182 | random_word_list = list(set([word for word in words if word not in stop_words]))
183 | random.shuffle(random_word_list)
184 | num_replaced = 0
185 | for random_word in random_word_list:
186 | synonyms = get_synonyms(random_word)
187 | if len(synonyms) >= 1:
188 | synonym = random.choice(list(synonyms))
189 | new_words = [synonym if word == random_word else word for word in new_words]
190 | # print("replaced", random_word, "with", synonym)
191 | num_replaced += 1
192 | if num_replaced >= n: # only replace up to n words
193 | break
194 |
195 | # this is stupid but we need it, trust me
196 | sentence = " ".join(new_words)
197 | new_words = sentence.split(" ")
198 |
199 | return new_words
200 |
201 |
202 | def get_synonyms(word):
203 | synonyms = set()
204 | for syn in wordnet.synsets(word):
205 | for l in syn.lemmas():
206 | synonym = l.name().replace("_", " ").replace("-", " ").lower()
207 | synonym = "".join(
208 | [char for char in synonym if char in " qwertyuiopasdfghjklzxcvbnm"]
209 | )
210 | synonyms.add(synonym)
211 | if word in synonyms:
212 | synonyms.remove(word)
213 | return list(synonyms)
214 |
215 |
216 | ########################################################################
217 | # Random deletion
218 | # Randomly delete words from the sentence with probability p
219 | ########################################################################
220 |
221 |
222 | def random_deletion(words, p):
223 |
224 | # obviously, if there's only one word, don't delete it
225 | if len(words) == 1:
226 | return words
227 |
228 | # randomly delete words with probability p
229 | new_words = []
230 | for word in words:
231 | r = random.uniform(0, 1)
232 | if r > p:
233 | new_words.append(word)
234 |
235 | # if you end up deleting all words, just return a random word
236 | if len(new_words) == 0:
237 | rand_int = random.randint(0, len(words) - 1)
238 | return [words[rand_int]]
239 |
240 | return new_words
241 |
242 |
243 | ########################################################################
244 | # Random swap
245 | # Randomly swap two words in the sentence n times
246 | ########################################################################
247 |
248 |
249 | def random_swap(words, n):
250 | new_words = words.copy()
251 | for _ in range(n):
252 | new_words = swap_word(new_words)
253 | return new_words
254 |
255 |
256 | def swap_word(new_words):
257 | random_idx_1 = random.randint(0, len(new_words) - 1)
258 | random_idx_2 = random_idx_1
259 | counter = 0
260 | while random_idx_2 == random_idx_1:
261 | random_idx_2 = random.randint(0, len(new_words) - 1)
262 | counter += 1
263 | if counter > 3:
264 | return new_words
265 | new_words[random_idx_1], new_words[random_idx_2] = (
266 | new_words[random_idx_2],
267 | new_words[random_idx_1],
268 | )
269 | return new_words
270 |
271 |
272 | ########################################################################
273 | # Random insertion
274 | # Randomly insert n words into the sentence
275 | ########################################################################
276 |
277 |
278 | def random_insertion(words, n):
279 | new_words = words.copy()
280 | for _ in range(n):
281 | add_word(new_words)
282 | return new_words
283 |
284 |
285 | def add_word(new_words):
286 | synonyms = []
287 | counter = 0
288 | while len(synonyms) < 1:
289 | random_word = new_words[random.randint(0, len(new_words) - 1)]
290 | synonyms = get_synonyms(random_word)
291 | counter += 1
292 | if counter >= 10:
293 | return
294 | random_synonym = synonyms[0]
295 | random_idx = random.randint(0, len(new_words) - 1)
296 | new_words.insert(random_idx, random_synonym)
297 |
298 |
299 | ########################################################################
300 | # main data augmentation function
301 | ########################################################################
302 |
303 |
304 | def eda(sentence, alpha_sr=0.1, alpha_ri=0.1, alpha_rs=0.1, p_rd=0.1, num_aug=9):
305 | sentence = get_only_chars(sentence)
306 | words = sentence.split(" ")
307 | words = [word for word in words if word != ""]
308 | num_words = len(words)
309 |
310 | augmented_sentences = []
311 | num_new_per_technique = int(num_aug / 4) + 1
312 |
313 | # sr
314 | if alpha_sr > 0:
315 | n_sr = max(1, int(alpha_sr * num_words))
316 | for _ in range(num_new_per_technique):
317 | a_words = synonym_replacement(words, n_sr)
318 | augmented_sentences.append(" ".join(a_words))
319 |
320 | # ri
321 | if alpha_ri > 0:
322 | n_ri = max(1, int(alpha_ri * num_words))
323 | for _ in range(num_new_per_technique):
324 | a_words = random_insertion(words, n_ri)
325 | augmented_sentences.append(" ".join(a_words))
326 |
327 | # rs
328 | if alpha_rs > 0:
329 | n_rs = max(1, int(alpha_rs * num_words))
330 | for _ in range(num_new_per_technique):
331 | a_words = random_swap(words, n_rs)
332 | augmented_sentences.append(" ".join(a_words))
333 |
334 | # rd
335 | if p_rd > 0:
336 | for _ in range(num_new_per_technique):
337 | a_words = random_deletion(words, p_rd)
338 | augmented_sentences.append(" ".join(a_words))
339 |
340 | augmented_sentences = [get_only_chars(sentence) for sentence in augmented_sentences]
341 | shuffle(augmented_sentences)
342 |
343 | # trim so that we have the desired number of augmented sentences
344 | if num_aug >= 1:
345 | augmented_sentences = augmented_sentences[:num_aug]
346 | else:
347 | keep_prob = num_aug / len(augmented_sentences)
348 | augmented_sentences = [
349 | s for s in augmented_sentences if random.uniform(0, 1) < keep_prob
350 | ]
351 |
352 | # DO NOT append the original sentence (we'll fetch it from another part)
353 | # during data loading
354 | # augmented_sentences.append(sentence)
355 |
356 | return augmented_sentences
357 |
--------------------------------------------------------------------------------
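
A quick sketch of calling eda above with the same alpha settings used by eda_domain in augment_slices.py (the WordNet corpus must be available, see the commented nltk download hint in the file):

    from utils.data_utils.eda_utils import eda

    variants = eda(
        "what is the balance in my checking account",
        alpha_sr=0.05, alpha_ri=0.05, alpha_rs=0.05, p_rd=0.05,
        num_aug=4,
    )
    print(variants)  # up to 4 augmented sentences; the original is not appended
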
/utils/data_utils/hwu64_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | All HWU64 specific utilities are implemented here.
3 | """
4 | from haven import haven_utils as hu
5 | import os
6 | import pandas as pd
7 |
8 | INTENTS = [
9 | "alarm_query",
10 | "alarm_remove",
11 | "alarm_set",
12 | "audio_volume_down",
13 | "audio_volume_mute",
14 | "audio_volume_up",
15 | "calendar_query",
16 | "calendar_remove",
17 | "calendar_set",
18 | "cooking_recipe",
19 | "datetime_convert",
20 | "datetime_query",
21 | "email_addcontact",
22 | "email_query",
23 | "email_querycontact",
24 | "email_sendemail",
25 | "general_affirm",
26 | "general_commandstop",
27 | "general_confirm",
28 | "general_dontcare",
29 | "general_explain",
30 | "general_joke",
31 | "general_negate",
32 | "general_praise",
33 | "general_quirky",
34 | "general_repeat",
35 | "iot_cleaning",
36 | "iot_coffee",
37 | "iot_hue_lightchange",
38 | "iot_hue_lightdim",
39 | "iot_hue_lightoff",
40 | "iot_hue_lighton",
41 | "iot_hue_lightup",
42 | "iot_wemo_off",
43 | "iot_wemo_on",
44 | "lists_createoradd",
45 | "lists_query",
46 | "lists_remove",
47 | "music_likeness",
48 | "music_query",
49 | "music_settings",
50 | "news_query",
51 | "play_audiobook",
52 | "play_game",
53 | "play_music",
54 | "play_podcasts",
55 | "play_radio",
56 | "qa_currency",
57 | "qa_definition",
58 | "qa_factoid",
59 | "qa_maths",
60 | "qa_stock",
61 | "recommendation_events",
62 | "recommendation_locations",
63 | "recommendation_movies",
64 | "social_post",
65 | "social_query",
66 | "takeaway_order",
67 | "takeaway_query",
68 | "transport_query",
69 | "transport_taxi",
70 | "transport_ticket",
71 | "transport_traffic",
72 | "weather_query",
73 | ]
74 |
75 |
76 | class Hwu64:
77 | def __init__(self, name):
78 | path_base = "./data/dialoglue/data_utils/dialoglue/hwu"
79 |
80 | self.data_name = name
81 | self.full_path = f"./data/{name}/full/dataset.pkl"
82 | self.num_examples = 10
83 |
84 | # 1. Get Intent Mappings
85 | name2id = {k: i for i, k in enumerate(INTENTS)}
86 | hu.save_json(f"./data/{name}/name2id.json", name2id)
87 | id2name = {i: k for i, k in enumerate(INTENTS)}
88 | hu.save_json(f"./data/{name}/id2name.json", id2name)
89 |
90 | # 2. save dataset.pkl
91 | if not os.path.exists(self.full_path):
92 | # get data dict
93 | data_dict = {
94 | "train": get_data_dict(path_base, "train", name2id),
95 | "val": get_data_dict(path_base, "val", name2id),
96 | "test": get_data_dict(path_base, "test", name2id),
97 | }
98 | hu.save_pkl(self.full_path, data_dict)
99 | else:
100 | data_dict = hu.load_pkl(self.full_path)
101 |
102 | # intents = list(name2id.keys())
103 |         # Group by intent
104 | self.dataset_by_intent = {}
105 | for split in ["train", "val", "test"]:
106 | self.dataset_by_intent[split] = {}
107 | text_intent_dict = data_dict[split]
108 | text_list, intent_list = (
109 | text_intent_dict["text"],
110 | map(lambda x: id2name[x], text_intent_dict["intent"]),
111 | )
112 |
113 | # get texts from intent
114 | intent2texts = {}
115 | for t, i in zip(text_list, intent_list):
116 | if i not in intent2texts:
117 | intent2texts[i] = []
118 | intent2texts[i] += [t]
119 | self.dataset_by_intent[split] = intent2texts
120 |
121 | self.domain_to_intent = self.generate_domain_to_intent_map()
122 | self.gpt3_batch_size: int = 128
123 |
124 | def parse_and_load(self):
125 | return self.dataset_by_intent
126 |
127 | def generate_domain_to_intent_map(self):
128 | mapping = {}
129 | for intent in INTENTS:
130 | # the intents are structured as {domain}_{intent}
131 |             # e.g. for alarm_query, the domain is alarm and the intent is query
132 | domain = intent.split("_", 1)[0]
133 | if domain not in mapping:
134 | mapping[domain] = []
135 | mapping[domain].append(intent)
136 | return mapping
137 |
138 |
139 | def get_data_dict(path_base, split, name2id):
140 | tmp_dict = pd.read_csv(os.path.join(path_base, f"{split}.csv")).to_dict()
141 |
142 | data_dict = {}
143 | data_dict["text"] = list(tmp_dict["text"].values())
144 | data_dict["intent"] = [int(name2id[c]) for c in tmp_dict["category"].values()]
145 | return data_dict
146 |
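Because HWU64 intent names follow the {domain}_{intent} convention, generate_domain_to_intent_map() buckets intents by the text before the first underscore. A small standalone check reproducing that logic on the module-level INTENTS list:

    # Same grouping as Hwu64.generate_domain_to_intent_map()
    mapping = {}
    for intent in INTENTS:
        mapping.setdefault(intent.split("_", 1)[0], []).append(intent)

    assert mapping["alarm"] == ["alarm_query", "alarm_remove", "alarm_set"]
    assert mapping["weather"] == ["weather_query"]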
--------------------------------------------------------------------------------
/utils/data_utils/main.py:
--------------------------------------------------------------------------------
1 | import os, json, hashlib, pickle
2 | from utils.data_utils import clinc_utils, snips_utils, banking77_utils, hwu64_utils
3 |
4 | # some handy aliases
5 | pjoin = os.path.join
6 |
7 | write_json = lambda obj, path: json.dump(obj, open(path, "w"))
8 | read_json = lambda path: json.load(open(path, "r"))
9 |
10 | write_pickle = lambda obj, path: pickle.dump(obj, open(path, "wb"))
11 | read_pickle = lambda path: pickle.load(open(path, "rb"))
12 |
13 | write_file = lambda content, path: open(path, "w").write(content)
14 | read_file = lambda path: open(path, "r").read()
15 |
16 |
17 | def get_hash(x):
18 | return int(hashlib.sha1(x.encode("utf-8")).hexdigest(), 16)
19 |
20 |
21 | def get_label_maps(data_root, data_name):
22 | with open(pjoin(data_root, data_name, "name2id.json"), "r") as f:
23 | name2id = json.load(f)
24 | with open(pjoin(data_root, data_name, "id2name.json"), "r") as f:
25 | id2name = json.load(f)
26 | return name2id, id2name
27 |
28 |
29 | def get_ds_config(ds_name):
30 |     # Used by prepare_dataset.py
31 | if ds_name == "clinc_oos":
32 | return clinc_utils.DS_CONFIG()
33 | elif ds_name == "snips_official":
34 | return snips_utils.DS_CONFIG()
35 | elif ds_name == "banking77":
36 | return banking77_utils.Banking77(ds_name)
37 | elif ds_name == "hwu64":
38 | return hwu64_utils.Hwu64(ds_name)
39 | else:
40 | raise NotImplementedError(f"Dataset {ds_name} not supported")
41 |
42 |
43 | def truncate_data(data, num_examples):
44 | """Used to simulate the few-shot setting for validation intents."""
45 | return sorted(data, key=get_hash)[:num_examples]
46 |
47 |
48 | def parse_and_load(dataset_name):
49 |     # Used by prepare_dataset.py
50 | if dataset_name == "clinc_oos":
51 | return clinc_utils.parse_and_load_clinc()
52 | elif dataset_name == "snips_official":
53 | return snips_utils.parse_and_load_snips()
54 | elif dataset_name == "banking77":
55 | ds = banking77_utils.Banking77(dataset_name)
56 | return ds.parse_and_load()
57 | elif dataset_name == "hwu64":
58 | ds = hwu64_utils.Hwu64(dataset_name)
59 | return ds.parse_and_load()
60 | else:
61 |         raise NotImplementedError(f"Dataset {dataset_name} not supported")
62 |
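truncate_data() gives a deterministic few-shot sample: sorting by the SHA-1 hash of each string imposes a fixed pseudo-random order, so the same K examples are selected on every run and every machine. A quick sketch (the utterances are made up):

    from utils.data_utils.main import truncate_data

    utterances = ["book a table for two", "will it rain today", "play some jazz"]
    # The hash-based sort ignores input order, so both calls pick the same 2.
    assert truncate_data(utterances, 2) == truncate_data(utterances[::-1], 2)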
--------------------------------------------------------------------------------
/utils/data_utils/sample_few_shot.py:
--------------------------------------------------------------------------------
1 | """
2 | This script samples K (default=10) exemplars per intent class in the few-shot
3 | slice using the heuristic used in the EX2 paper:
4 | https://arxiv.org/pdf/2102.01335.pdf
5 |
6 | Since we will do cross-validation, every domain in the SNIPS dataset will be
7 | treated as a few-shot slice once.
8 | See https://github.com/clinc/oos-eval/blob/master/supplementary.pdf.
9 | """
10 | import os
11 | import pickle
12 |
13 | from utils.data_utils.main import get_label_maps, truncate_data, parse_and_load
14 |
15 | pjoin = os.path.join
16 |
17 |
18 | def add_fs_slice(data_splits, val_domain, data_root, ds_config):
19 | DOMAIN_TO_INTENT = ds_config.domain_to_intent
20 | few_shot_intents = set(DOMAIN_TO_INTENT[val_domain])
21 |
22 | # load data grouped by intent, will also prepare the basic dataset.pkl
23 | data = parse_and_load(ds_config.data_name)
24 | name2id, id2name = get_label_maps(data_root, ds_config.data_name)
25 |
26 |     # NOTE: `data` also has oos_train, oos_val, and oos_test keys; we ignore
27 |     # them here since we don't perform an ex2-setup run for the OOS domain.
28 |     # OOS samples are added when required in data_loader.py.
29 | train_data = data["train"]
30 | val_data = data["val"]
31 | test_data = data["test"]
32 |
33 |     # construct D_M and D_F (each with its own train, val, and test sets)
34 | # MANY SHOT INTENTS
35 | # -----------------
36 | m_train_lines, m_train_intents = [], []
37 | for intent, lines in train_data.items():
38 | if intent not in few_shot_intents:
39 | m_train_lines.extend(lines)
40 | m_train_intents.extend([name2id[intent]] * len(lines))
41 |
42 | m_val_lines, m_val_intents = [], []
43 | for intent, lines in val_data.items():
44 | if intent not in few_shot_intents:
45 | m_val_lines.extend(lines)
46 | m_val_intents.extend([name2id[intent]] * len(lines))
47 |
48 | m_test_lines, m_test_intents = [], []
49 | for intent, lines in test_data.items():
50 | if intent not in few_shot_intents:
51 | m_test_lines.extend(lines)
52 | m_test_intents.extend([name2id[intent]] * len(lines))
53 |
54 | # FEW SHOT INTENTS
55 | # ----------------
56 | f_train_lines, f_train_intents = [], []
57 | for intent, lines in train_data.items():
58 | if intent in few_shot_intents:
59 |             f_train_lines.extend(truncate_data(lines, ds_config.num_examples))
60 |             f_train_intents.extend([name2id[intent]] * min(ds_config.num_examples, len(lines)))
61 |
62 | f_val_lines, f_val_intents = [], []
63 | for intent, lines in val_data.items():
64 | if intent in few_shot_intents:
65 | f_val_lines.extend(lines)
66 | f_val_intents.extend([name2id[intent]] * len(lines))
67 |
68 | f_test_lines, f_test_intents = [], []
69 | for intent, lines in test_data.items():
70 | if intent in few_shot_intents:
71 | f_test_lines.extend(lines)
72 | f_test_intents.extend([name2id[intent]] * len(lines))
73 |
74 | data_splits[val_domain] = {}
75 |
76 | # add the many-shot split
77 | data_splits[val_domain]["M"] = {
78 | "train": {
79 | "text": m_train_lines,
80 | "intent": m_train_intents,
81 | },
82 | "val": {
83 | "text": m_val_lines,
84 | "intent": m_val_intents,
85 | },
86 | "test": {
87 | "text": m_test_lines,
88 | "intent": m_test_intents,
89 | },
90 | }
91 | # add the few-shot split
92 | data_splits[val_domain]["F"] = {
93 | "train": {
94 | "text": f_train_lines,
95 | "intent": f_train_intents,
96 | },
97 | "val": {
98 | "text": f_val_lines,
99 | "intent": f_val_intents,
100 | },
101 | "test": {
102 | "text": f_test_lines,
103 | "intent": f_test_intents,
104 | },
105 | }
106 |
107 |
108 | def sample_few_shot(data_root, ds_config):
109 | save_path = pjoin(data_root, ds_config.data_name, "full", "data_full_suite.pkl")
110 | if os.path.exists(save_path):
111 | print(f"Sample few shot has already been called for {ds_config.data_name}")
112 |         return pickle.load(open(save_path, "rb"))
113 | data_splits = {}
114 |     val_domains = ds_config.domain_to_intent.keys()
115 |
116 |     for val_domain in val_domains:
117 | # in-place modifies data_splits
118 | add_fs_slice(data_splits, val_domain, data_root, ds_config)
119 |
120 | # note that this pkl file will be updated
121 | with open(save_path, "wb") as f:
122 | pickle.dump(data_splits, f)
123 | return data_splits
124 |
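The saved data_full_suite.pkl maps each held-out domain to a many-shot slice "M" and a few-shot slice "F", each holding train/val/test dicts with parallel "text" and "intent" lists. A sketch of reading it back, assuming data_root is "data" and using one of the SNIPS domains defined in snips_utils:

    import pickle

    with open("data/snips_official/full/data_full_suite.pkl", "rb") as f:
        data_splits = pickle.load(f)

    # Few-shot training slice when GetWeather is the held-out domain.
    fs_train = data_splits["GetWeather"]["F"]["train"]
    assert len(fs_train["text"]) == len(fs_train["intent"])  # parallel lists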
--------------------------------------------------------------------------------
/utils/data_utils/snips_utils.py:
--------------------------------------------------------------------------------
1 | # prepare SNIPS dataset for intent classification
2 | import os
3 | import json
4 | import pickle
5 | import collections
6 | from typing import Dict
7 | from pydantic import BaseModel
8 |
9 |
10 | class DS_CONFIG(BaseModel):
11 | data_name: str = "snips_official"
12 | full_path: str = "./data/snips_official/full/dataset.pkl"
13 | num_examples: int = 10
14 | domain_to_intent: Dict = {
15 | "AddToPlaylist": ["AddToPlaylist"],
16 | "BookRestaurant": ["BookRestaurant"],
17 | "GetWeather": ["GetWeather"],
18 | "PlayMusic": ["PlayMusic"],
19 | "RateBook": ["RateBook"],
20 | "SearchCreativeWork": ["SearchCreativeWork"],
21 | "SearchScreeningEvent": ["SearchScreeningEvent"],
22 | }
23 | gpt3_batch_size: int = 128
24 |
25 |
26 | SNIPS_DIR = './data/snips_official/'
27 |
28 |
29 | def make_label_maps():
30 | with open(os.path.join(SNIPS_DIR, 'full', 'train', 'label')) as f:
31 | labels = sorted(set([x.strip() for x in f.readlines()]))
32 | intent_name2id = {label: idx for idx, label in enumerate(labels)}
33 | intent_id2name = {idx: label for idx, label in enumerate(labels)}
34 | with open(os.path.join(SNIPS_DIR, 'name2id.json'), 'w') as f:
35 | json.dump(intent_name2id, f)
36 | with open(os.path.join(SNIPS_DIR, 'id2name.json'), 'w') as f:
37 | json.dump(intent_id2name, f)
38 |
39 |
40 | def download_snips_full():
41 | # NOTE: the official SNIPS repo uses valid, train, and test as split names
42 | for split in ['train', 'valid', 'test']:
43 | download_snips_files(split)
44 |
45 | def download_snips_files(split):
46 | print(f'Downloading SNIPS {split} files')
47 | snips_base_url = f'https://raw.githubusercontent.com/MiuLab/SlotGated-SLU/master/data/snips/{split}/'
48 | download_dir = os.path.join(SNIPS_DIR, 'full', split)
49 | if not os.path.exists(download_dir):
50 | os.makedirs(download_dir)
51 | os.system(f'wget {snips_base_url + "seq.in"} -P {download_dir}')
52 | os.system(f'wget {snips_base_url + "label"} -P {download_dir}')
53 |
54 |
55 | def parse_snips_split(dataset, split, name2id_map):
56 | with open(os.path.join(SNIPS_DIR, 'full', split, 'seq.in')) as f:
57 | sentences = [x.strip() for x in f.readlines()]
58 | with open(os.path.join(SNIPS_DIR, 'full', split, 'label')) as f:
59 | labels = [name2id_map[x.strip()] for x in f.readlines()]
60 |
61 |     # rename 'valid' to 'val' to keep split names consistent across datasets
62 | dataset['val' if split == 'valid' else split] = {
63 | 'text': sentences,
64 | 'intent': labels
65 | }
66 |
67 |
68 | def prepare_snips():
69 | download_snips_full()
70 | make_label_maps()
71 |
72 | dataset = {}
73 | name2id_map = json.load(open(os.path.join(SNIPS_DIR, 'name2id.json')))
74 | for split in ['train', 'valid', 'test']:
75 | parse_snips_split(dataset, split, name2id_map)
76 |
77 | with open(os.path.join(SNIPS_DIR, 'full', 'dataset.pkl'), 'wb') as f:
78 | pickle.dump(dataset, f)
79 | print('Base dataset.pkl prepared for SNIPS')
80 |
81 |
82 | def check_snips():
83 | with open(os.path.join(SNIPS_DIR, 'full', 'dataset.pkl'), 'rb') as f:
84 | dataset = pickle.load(f)
85 | print(len(dataset['train']['text']))
86 | print(len(dataset['val']['text']))
87 | print(len(dataset['test']['text']))
88 |
89 |
90 | def load_snips():
91 | print('Loading SNIPS dataset')
92 | id2name = json.load(open(os.path.join(SNIPS_DIR, 'id2name.json')))
93 | data_full_path = DS_CONFIG().full_path
94 | data = collections.defaultdict(lambda: collections.defaultdict(list))
95 | with open(data_full_path, 'rb') as data_file:
96 | for split_name, split_data in pickle.load(data_file).items():
97 | for query, intent in zip(split_data['text'], split_data['intent']):
98 | intent = id2name[str(intent)]
99 | data[split_name][intent].append(query)
100 | return data
101 |
102 |
103 | def parse_and_load_snips():
104 | if not os.path.exists(os.path.join(SNIPS_DIR, 'full', 'dataset.pkl')):
105 | prepare_snips()
106 | check_snips()
107 | return load_snips()
108 |
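parse_and_load_snips() ties the pipeline together: download the raw files, build the label maps and dataset.pkl once, then return the data regrouped as data[split][intent_name] -> list of utterances. A usage sketch (the first call needs network access for the wget downloads):

    from utils.data_utils import snips_utils

    data = snips_utils.parse_and_load_snips()
    print(len(data["train"]["GetWeather"]))  # training queries for one intent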
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import json, re, os, numpy as np
2 |
3 | from datasets import load_metric
4 | from utils import main_data_utils as mdu
5 | from sklearn.metrics import confusion_matrix
6 |
7 |
8 | class Metrics:
9 | def __init__(self, exp_dict=None, tokenizer=None):
10 | self.exp_dict = exp_dict
11 | self.tokenizer = tokenizer
12 |
13 | def compute_metrics(self):
14 | """
15 | Will choose the appropriate metric computer based on the config
16 | """
17 | if "bert" in self.exp_dict["model"]["backbone"]:
18 | return self.compute_metrics_bert
19 | raise ValueError(f"Incompatible backbone {self.exp_dict['model']['backbone']}.")
20 |
21 | def compute_metrics_bert(self, eval_pred):
22 | predictions, labels = eval_pred
23 | predictions = np.argmax(predictions, axis=1)
24 | metrics = {}
25 |         # sort so "accuracy" runs first; f1 reuses the in-scope preds it caches
26 |         for metric in sorted(self.exp_dict["metrics"]):
27 |             _metric = getattr(self, metric)
28 | metrics.update(_metric(predictions, labels))
29 | return metrics
30 |
31 | def accuracy(self, predictions, labels):
32 | accuracies = {}
33 | acc = load_metric("accuracy")
34 | accuracies.update(acc.compute(predictions=predictions, references=labels))
35 | oos_id = self.exp_dict["dataset"]["oos_id"]
36 | if oos_id is not None:
37 | # compute in_scope accuracy as well
38 | inscope_preds, inscope_labels = [], []
39 | for idx in range(len(labels)):
40 | if labels[idx] == oos_id:
41 | continue
42 | inscope_labels.append(labels[idx])
43 | inscope_preds.append(predictions[idx])
44 | self.inscope_preds, self.inscope_labels = inscope_preds, inscope_labels
45 | accuracies["inscope_accuracy"] = acc.compute(
46 | predictions=inscope_preds, references=inscope_labels
47 | )["accuracy"]
48 | return accuracies
49 |
50 | def f1(self, predictions, labels):
51 | f1s = {}
52 | f1 = load_metric("f1")
53 | f1s.update(
54 | f1.compute(predictions=predictions, references=labels, average="macro")
55 | )
56 | if self.exp_dict["dataset"]["oos_id"] is not None:
57 | f1s["inscope_f1"] = f1.compute(
58 | predictions=self.inscope_preds,
59 | references=self.inscope_labels,
60 | average="macro",
61 | )["f1"]
62 | return f1s
63 |
64 | def precision(self, predictions, labels):
65 | precision = load_metric("precision")
66 | return precision.compute(
67 | predictions=predictions, references=labels, average="macro"
68 | )
69 |
70 | def recall(self, predictions, labels):
71 | recalls = {}
72 | recall = load_metric("recall")
73 | recalls.update(
74 | recall.compute(predictions=predictions, references=labels, average="macro")
75 | )
76 | oos_id = self.exp_dict["dataset"]["oos_id"]
77 | if oos_id is not None and oos_id in labels:
78 | # compute OOS recall
79 | outscope_preds = []
80 | for idx in range(len(labels)):
81 | if labels[idx] == oos_id:
82 | outscope_preds.append(1 if predictions[idx] == oos_id else -1)
83 | recalls["oos_recall"] = outscope_preds.count(1) / len(outscope_preds)
84 | return recalls
85 |
86 | def confusion_matrix(self, predictions, references):
87 | return {
88 | "confusion_matrix": confusion_matrix(
89 | y_true=references,
90 | y_pred=predictions,
91 | labels=list(range(self.exp_dict["dataset"]["num_labels"])),
92 | )
93 | }
94 |
95 | def compute_fidelities(self, ds):
96 | """
97 |         Returns and saves fidelities for all engines, given a barebones
98 |         exp_dict containing the dataset name, its number of classes
99 |         (num_labels), and the few_pure setting.
100 |         Example schema of the returned result:
101 | eda: Float
102 | ada_1.0: Float
103 | .
104 | .
105 | .
106 | gptj_2.0: Float
107 | NOTE: al_dataset.pkl must've been generated already for the dataset
108 | """
109 | al_ds_path = mdu.pjoin("data", ds, "full", "al_dataset.pkl")
110 | if not os.path.exists(al_ds_path):
111 | print("Oracle relabelling hasn't been done on the generated samples yet")
112 | print("Go run runners.oracle_relabel first")
113 | return
114 | generated_ds = mdu.read_pickle(al_ds_path)["generated"]
115 | # path to save this dataset's fidelity results
116 | results_path = f"results/{ds}_fidelity.json"
117 | oos_id = mdu.read_json(f"data/{ds}/name2id.json").get("oos")
118 | if os.path.exists(results_path):
119 | print(f"Loading {results_path} that already exists!")
120 | print(f"Delete/Rename it to compute fidelity for {ds} again")
121 | return mdu.read_json(results_path)
122 |
123 | print(f"Computing fidelity numbers for {ds}")
124 | results = {}
125 | for engine, samples in generated_ds.items():
126 | if engine == "val":
127 | engine = "threshold"
128 | _a = [
129 | 1 if old == new else 0
130 | for old, new in zip(samples["old_intent"], samples["intent"])
131 | ]
132 | results[engine] = np.mean(_a)
133 |
134 | print(f"Saving fidelity numbers for {ds}")
135 | mdu.write_json(results, results_path)
136 | return results
137 |
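Fidelity, as computed above, is just the fraction of generated samples whose oracle-assigned label ("intent") agrees with the label they were generated for ("old_intent"). A toy version of the same computation, with made-up label ids:

    import numpy as np

    samples = {"old_intent": [3, 3, 7, 7], "intent": [3, 5, 7, 7]}
    fidelity = np.mean(
        [1 if old == new else 0
         for old, new in zip(samples["old_intent"], samples["intent"])]
    )
    print(fidelity)  # 0.75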
--------------------------------------------------------------------------------