├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── Dockerfile ├── LICENSE ├── NOTICE ├── README.md ├── THIRD-PARTY-LICENSES ├── data ├── README.md ├── anchor_spans.py ├── convert_spans.py ├── convert_to_allennlp.py ├── irl-annotations.jsonl ├── requirements-process-data.txt ├── requirements-train-model.txt ├── run_process_data.sh ├── run_train_model.sh └── splits.json ├── downstream ├── README.md ├── encoders │ ├── __init__.py │ ├── bert_encoder.py │ └── sbert_encoder.py ├── gen_labelname.py ├── protaugment │ ├── paraphrase │ │ ├── modeling.py │ │ └── utils │ │ │ └── data.py │ └── utils │ │ ├── __init__.py │ │ ├── data.py │ │ ├── few_shot.py │ │ ├── math.py │ │ └── python.py ├── run_eval.py ├── run_protaugment.py ├── run_protonet.py ├── scripts │ ├── run_eval_after_ft.sh │ ├── run_eval_before_ft.sh │ ├── run_protaugment.sh │ └── run_protonet.sh └── utils │ ├── __init__.py │ ├── dataloader.py │ └── train_utils.py ├── pretraining ├── README.md ├── iae │ ├── __init__.py │ ├── contrastive_learning.py │ ├── data_loader.py │ ├── drophead.py │ └── iae_model.py ├── preprocess │ ├── create_pretrain_dataset.py │ ├── irl │ │ ├── __init__.py │ │ ├── data.py │ │ ├── intent_role_labelers.py │ │ ├── irl_tagger.py │ │ └── utils.py │ ├── load_data.py │ └── utils.py ├── run_eval.py ├── run_pretrain.py ├── scripts │ ├── create_dataset.sh │ ├── run_eval.sh │ └── run_pretrain.sh └── utils.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | output 3 | models 4 | !downstream/protaugment/models -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. 
You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute to. As our projects use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project, we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 
60 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.1.1-cudnn8-devel-ubuntu18.04 2 | 3 | ARG PYTHON=python3.8 4 | 5 | SHELL ["/bin/bash", "-c"] 6 | 7 | ENV TZ=Asia/Kolkata \ 8 | DEBIAN_FRONTEND=noninteractive 9 | 10 | RUN apt-get update && apt-get install -y --no-install-recommends \ 11 | build-essential \ 12 | cmake \ 13 | git \ 14 | curl \ 15 | ca-certificates \ 16 | libjpeg-dev \ 17 | libatlas-base-dev \ 18 | libcurl4-openssl-dev \ 19 | libgomp1 \ 20 | libopencv-dev \ 21 | openssh-client \ 22 | openssh-server \ 23 | wget \ 24 | vim \ 25 | libpng-dev && \ 26 | rm -rf /var/lib/apt/lists/* 27 | 28 | RUN apt-get update && apt-get install -y --no-install-recommends \ 29 | libreadline-gplv2-dev \ 30 | libncursesw5-dev \ 31 | libssl-dev \ 32 | libsqlite3-dev \ 33 | tk-dev \ 34 | libgdbm-dev \ 35 | libc6-dev \ 36 | libbz2-dev 37 | 38 | ## Install python3.8 39 | RUN apt-get update 40 | RUN apt-get --yes install software-properties-common 41 | RUN add-apt-repository ppa:deadsnakes/ppa 42 | RUN apt-get update 43 | RUN apt-get --yes install ${PYTHON} 44 | RUN apt-get --yes install ${PYTHON}-distutils 45 | RUN apt-get --yes install ${PYTHON}-dev 46 | RUN ln -s $(which ${PYTHON}) /usr/local/bin/python 47 | 48 | ## Install pip for python3.8 49 | RUN wget https://bootstrap.pypa.io/get-pip.py 50 | RUN python get-pip.py 51 | RUN ln -s /usr/local/bin/pip3 /usr/bin/pip 52 | RUN pip --no-cache-dir install --upgrade pip setuptools 53 | 54 | RUN apt-get install -y jq 55 | 56 | ENV PYTHONDONTWRITEBYTECODE=1 \ 57 | PYTHONUNBUFFERED=1 \ 58 | LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/cuda-11.1/lib64:/usr/local/cuda-11.1/extras/CUPTI/lib64:/usr/local/cuda-11.0/lib:/usr/lib64/openmpi/lib/:/usr/local/lib:/usr/lib:/usr/local/mpi/lib:/lib/:" \ 59 | PYTHONIOENCODING=UTF-8 \ 60 | LANG=C.UTF-8 \ 61 | LC_ALL=C.UTF-8 62 | 63 | RUN pip install --no-cache --upgrade \ 64 | transformers==4.18.0 \ 65 | numpy==1.23.0 \ 66 | torch==1.11.0 \ 67 | PyYAML==6.0 \ 68 | regex==2022.6.2 \ 69 | tqdm==4.64.0 \ 70 | tensorboardX==2.5.1 \ 71 | sentencepiece==0.1.96 \ 72 | nltk==3.7 \ 73 | pytorch-metric-learning==1.5.0 \ 74 | sentence-transformers==2.2.2 \ 75 | sacrebleu==2.1.0 \ 76 | allennlp==2.9.3 \ 77 | cached-path==1.1.2 \ 78 | https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.2.0/en_core_web_md-3.2.0.tar.gz#egg=en_core_web_md 79 | 80 | WORKDIR / 81 | 82 | SHELL ["/bin/bash", "-c"] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🥧 Pre-Training Intent-Aware Encoders 2 | 3 | This repository is used for **P**re-training **I**ntent-Aware **E**ncoders (PIE) and evaluating on four intent classification datasets 4 | ([BANKING77](https://arxiv.org/abs/2003.04807), [HWU64](https://arxiv.org/abs/1903.05566), 5 | [Liu54](https://arxiv.org/abs/1903.05566), and [CLINC150](https://aclanthology.org/D19-1131/)). 
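Once pre-trained, a PIE encoder behaves like an ordinary sentence encoder: utterances and intent label names can be embedded into the same space and compared with cosine similarity. Below is a minimal usage sketch; the checkpoint path is a placeholder (no such artifact is shipped here), and it assumes a checkpoint saved in sentence-transformers format, as used by `downstream/encoders/sbert_encoder.py`.
```python
# Minimal usage sketch. "models/pie-checkpoint" is a placeholder path, not a
# released artifact; this assumes a PIE checkpoint exported in
# sentence-transformers format (cf. downstream/encoders/sbert_encoder.py).
from sentence_transformers import SentenceTransformer, util

encoder = SentenceTransformer("models/pie-checkpoint")
utterances = ["how do i reset my card pin?"]
label_names = ["card pin change", "play music", "book a flight"]

utt_emb = encoder.encode(utterances, convert_to_tensor=True)
label_emb = encoder.encode(label_names, convert_to_tensor=True)
scores = util.cos_sim(utt_emb, label_emb)   # (1, 3) similarity matrix
print(label_names[int(scores.argmax())])    # nearest label name wins
```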
6 | 7 | ## Environment setup 8 | 9 | ### Option 1: Docker 10 | ``` 11 | image_name=pie 12 | code_path=/path/to/intent-aware-encoder 13 | docker build -t $image_name . 14 | nvidia-docker run -it -v ${code_path}:/code $image_name 15 | cd code 16 | ``` 17 | 18 | ### Option 2: Conda 19 | ``` 20 | conda create -n pie python=3.8 21 | conda activate pie 22 | pip install -r requirements.txt 23 | python -m spacy download en_core_web_md 24 | ``` 25 | 26 | ## Pre-training 27 | See the readme in the `pretraining` directory. 28 | ``` 29 | cd pretraining 30 | ``` 31 | 32 | ## Fine-tuning and Evaluation 33 | See the readme in the `downstream` directory. 34 | ``` 35 | cd downstream 36 | ``` 37 | 38 | ## Acknowledgement 39 | Parts of the code are modified from [mirror-bert](https://github.com/cambridgeltl/mirror-bert), [IDML](https://github.com/microsoft/KC/tree/main/papers/IDML), and [ProtAugment](https://github.com/tdopierre/ProtAugment). We thank the authors for open-sourcing their projects. 40 | 41 | ## Security 42 | 43 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. 44 | 45 | ## License 46 | 47 | This project is licensed under the Apache-2.0 License. 48 | 49 | Please cite the following paper if using the code or data from this project in your work: 50 | ```bibtex 51 | @misc{sung2023pretraining, 52 | title={Pre-training Intent-Aware Encoders for Zero- and Few-Shot Intent Classification}, 53 | author={Mujeen Sung and James Gung and Elman Mansimov and Nikolaos Pappas and Raphael Shu and Salvatore Romeo and Yi Zhang and Vittorio Castelli}, 54 | year={2023}, 55 | eprint={2305.14827}, 56 | archivePrefix={arXiv}, 57 | primaryClass={cs.CL} 58 | } 59 | ``` -------------------------------------------------------------------------------- /THIRD-PARTY-LICENSES: -------------------------------------------------------------------------------- 1 | This Intent Aware Encoder project includes the following third-party software/licensing: 2 | 3 | ** Mirror-BERT - https://github.com/cambridgeltl/mirror-bert 4 | Copyright (c) 2021 Cambridge Language Technology Lab 5 | ** IDML - https://github.com/microsoft/KC/tree/main/papers/IDML 6 | Copyright (c) Microsoft Corporation. 7 | ** DropHead - https://github.com/Kirill-Kravtsov/drophead-pytorch 8 | Copyright (c) 2020 Kirill Kravtsov 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
15 | 16 | ---------------- 17 | 18 | ** paraphrastic-representations-at-scale - https://github.com/jwieting/paraphrastic-representations-at-scale 19 | Copyright (c) 2020, John Wieting 20 | All rights reserved. 21 | 22 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 23 | 24 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 25 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 26 | Neither the name of the author nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 27 | 28 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | ---------------- 31 | 32 | ** ProtAugment - https://github.com/tdopierre/ProtAugment 33 | Copyright (c) 2021, Thomas Dopierre 34 | 35 | Apache License 36 | Version 2.0, January 2004 37 | http://www.apache.org/licenses/ 38 | 39 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 40 | 41 | 1. Definitions. 42 | 43 | "License" shall mean the terms and conditions for use, reproduction, 44 | and distribution as defined by Sections 1 through 9 of this document. 45 | 46 | "Licensor" shall mean the copyright owner or entity authorized by 47 | the copyright owner that is granting the License. 48 | 49 | "Legal Entity" shall mean the union of the acting entity and all 50 | other entities that control, are controlled by, or are under common 51 | control with that entity. For the purposes of this definition, 52 | "control" means (i) the power, direct or indirect, to cause the 53 | direction or management of such entity, whether by contract or 54 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 55 | outstanding shares, or (iii) beneficial ownership of such entity. 56 | 57 | "You" (or "Your") shall mean an individual or Legal Entity 58 | exercising permissions granted by this License. 59 | 60 | "Source" form shall mean the preferred form for making modifications, 61 | including but not limited to software source code, documentation 62 | source, and configuration files. 63 | 64 | "Object" form shall mean any form resulting from mechanical 65 | transformation or translation of a Source form, including but 66 | not limited to compiled object code, generated documentation, 67 | and conversions to other media types. 
68 | 69 | "Work" shall mean the work of authorship, whether in Source or 70 | Object form, made available under the License, as indicated by a 71 | copyright notice that is included in or attached to the work 72 | (an example is provided in the Appendix below). 73 | 74 | "Derivative Works" shall mean any work, whether in Source or Object 75 | form, that is based on (or derived from) the Work and for which the 76 | editorial revisions, annotations, elaborations, or other modifications 77 | represent, as a whole, an original work of authorship. For the purposes 78 | of this License, Derivative Works shall not include works that remain 79 | separable from, or merely link (or bind by name) to the interfaces of, 80 | the Work and Derivative Works thereof. 81 | 82 | "Contribution" shall mean any work of authorship, including 83 | the original version of the Work and any modifications or additions 84 | to that Work or Derivative Works thereof, that is intentionally 85 | submitted to Licensor for inclusion in the Work by the copyright owner 86 | or by an individual or Legal Entity authorized to submit on behalf of 87 | the copyright owner. For the purposes of this definition, "submitted" 88 | means any form of electronic, verbal, or written communication sent 89 | to the Licensor or its representatives, including but not limited to 90 | communication on electronic mailing lists, source code control systems, 91 | and issue tracking systems that are managed by, or on behalf of, the 92 | Licensor for the purpose of discussing and improving the Work, but 93 | excluding communication that is conspicuously marked or otherwise 94 | designated in writing by the copyright owner as "Not a Contribution." 95 | 96 | "Contributor" shall mean Licensor and any individual or Legal Entity 97 | on behalf of whom a Contribution has been received by Licensor and 98 | subsequently incorporated within the Work. 99 | 100 | 2. Grant of Copyright License. Subject to the terms and conditions of 101 | this License, each Contributor hereby grants to You a perpetual, 102 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 103 | copyright license to reproduce, prepare Derivative Works of, 104 | publicly display, publicly perform, sublicense, and distribute the 105 | Work and such Derivative Works in Source or Object form. 106 | 107 | 3. Grant of Patent License. Subject to the terms and conditions of 108 | this License, each Contributor hereby grants to You a perpetual, 109 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 110 | (except as stated in this section) patent license to make, have made, 111 | use, offer to sell, sell, import, and otherwise transfer the Work, 112 | where such license applies only to those patent claims licensable 113 | by such Contributor that are necessarily infringed by their 114 | Contribution(s) alone or by combination of their Contribution(s) 115 | with the Work to which such Contribution(s) was submitted. If You 116 | institute patent litigation against any entity (including a 117 | cross-claim or counterclaim in a lawsuit) alleging that the Work 118 | or a Contribution incorporated within the Work constitutes direct 119 | or contributory patent infringement, then any patent licenses 120 | granted to You under this License for that Work shall terminate 121 | as of the date such litigation is filed. 122 | 123 | 4. Redistribution. 
You may reproduce and distribute copies of the 124 | Work or Derivative Works thereof in any medium, with or without 125 | modifications, and in Source or Object form, provided that You 126 | meet the following conditions: 127 | 128 | (a) You must give any other recipients of the Work or 129 | Derivative Works a copy of this License; and 130 | 131 | (b) You must cause any modified files to carry prominent notices 132 | stating that You changed the files; and 133 | 134 | (c) You must retain, in the Source form of any Derivative Works 135 | that You distribute, all copyright, patent, trademark, and 136 | attribution notices from the Source form of the Work, 137 | excluding those notices that do not pertain to any part of 138 | the Derivative Works; and 139 | 140 | (d) If the Work includes a "NOTICE" text file as part of its 141 | distribution, then any Derivative Works that You distribute must 142 | include a readable copy of the attribution notices contained 143 | within such NOTICE file, excluding those notices that do not 144 | pertain to any part of the Derivative Works, in at least one 145 | of the following places: within a NOTICE text file distributed 146 | as part of the Derivative Works; within the Source form or 147 | documentation, if provided along with the Derivative Works; or, 148 | within a display generated by the Derivative Works, if and 149 | wherever such third-party notices normally appear. The contents 150 | of the NOTICE file are for informational purposes only and 151 | do not modify the License. You may add Your own attribution 152 | notices within Derivative Works that You distribute, alongside 153 | or as an addendum to the NOTICE text from the Work, provided 154 | that such additional attribution notices cannot be construed 155 | as modifying the License. 156 | 157 | You may add Your own copyright statement to Your modifications and 158 | may provide additional or different license terms and conditions 159 | for use, reproduction, or distribution of Your modifications, or 160 | for any such Derivative Works as a whole, provided Your use, 161 | reproduction, and distribution of the Work otherwise complies with 162 | the conditions stated in this License. 163 | 164 | 5. Submission of Contributions. Unless You explicitly state otherwise, 165 | any Contribution intentionally submitted for inclusion in the Work 166 | by You to the Licensor shall be under the terms and conditions of 167 | this License, without any additional terms or conditions. 168 | Notwithstanding the above, nothing herein shall supersede or modify 169 | the terms of any separate license agreement you may have executed 170 | with Licensor regarding such Contributions. 171 | 172 | 6. Trademarks. This License does not grant permission to use the trade 173 | names, trademarks, service marks, or product names of the Licensor, 174 | except as required for reasonable and customary use in describing the 175 | origin of the Work and reproducing the content of the NOTICE file. 176 | 177 | 7. Disclaimer of Warranty. Unless required by applicable law or 178 | agreed to in writing, Licensor provides the Work (and each 179 | Contributor provides its Contributions) on an "AS IS" BASIS, 180 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 181 | implied, including, without limitation, any warranties or conditions 182 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 183 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 184 | appropriateness of using or redistributing the Work and assume any 185 | risks associated with Your exercise of permissions under this License. 186 | 187 | 8. Limitation of Liability. In no event and under no legal theory, 188 | whether in tort (including negligence), contract, or otherwise, 189 | unless required by applicable law (such as deliberate and grossly 190 | negligent acts) or agreed to in writing, shall any Contributor be 191 | liable to You for damages, including any direct, indirect, special, 192 | incidental, or consequential damages of any character arising as a 193 | result of this License or out of the use or inability to use the 194 | Work (including but not limited to damages for loss of goodwill, 195 | work stoppage, computer failure or malfunction, or any and all 196 | other commercial damages or losses), even if such Contributor 197 | has been advised of the possibility of such damages. 198 | 199 | 9. Accepting Warranty or Additional Liability. While redistributing 200 | the Work or Derivative Works thereof, You may choose to offer, 201 | and charge a fee for, acceptance of support, warranty, indemnity, 202 | or other liability obligations and/or rights consistent with this 203 | License. However, in accepting such obligations, You may act only 204 | on Your own behalf and on Your sole responsibility, not on behalf 205 | of any other Contributor, and only if You agree to indemnify, 206 | defend, and hold each Contributor harmless for any liability 207 | incurred by, or claims asserted against, such Contributor by reason 208 | of your accepting any such warranty or additional liability. 209 | 210 | END OF TERMS AND CONDITIONS 211 | 212 | * For Apache Commons IO see also this required NOTICE: 213 | Apache Commons IO 214 | Copyright 2002-2014 The Apache Software Foundation 215 | 216 | This product includes software developed at 217 | The Apache Software Foundation (http://www.apache.org/). 218 | 219 | * For Apache Commons Lang see also this required NOTICE: 220 | Apache Commons Lang 221 | Copyright 2001-2013 The Apache Software Foundation 222 | 223 | This product includes software developed at 224 | The Apache Software Foundation (http://www.apache.org/). 225 | 226 | This product includes software from the Spring Framework, 227 | under the Apache License 2.0 (see: StringUtils.containsWhitespace()) -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Intent Role Label Annotations 2 | This directory contains annotations used for training an intent role labeling (IRL) model. 3 | 4 | * Annotation span offsets are provided in the `irl-annotations.jsonl` file 5 | * The train/dev/test split used in the paper is provided in `splits.json` 6 | 7 | ## Prerequisites 8 | * Python ≥ 3.9 9 | 10 | ## Data Processing 11 | To generate annotations with corresponding text from the 12 | [SGD](https://github.com/google-research-datasets/dstc8-schema-guided-dialogue) 13 | dataset, tokenize and create a train/dev/test split, run the script [run_process_data.sh](run_process_data.sh): 14 | ```bash 15 | cd data 16 | bash run_process_data.sh 17 | ``` 18 | 19 | This script does the following: 20 | * runs `anchor_spans.py` to get corresponding text for `irl-annotations.jsonl`. 
21 | * runs `convert_spans.py` to tokenize, apply IOB labeling, and prepare train/dev/test split 22 | 23 | ## Model Training 24 | To train an intent role labeling model on the resulting annotations, run the [run_train_model.sh](run_train_model.sh): 25 | ```bash 26 | cd data 27 | bash run_train_model.sh 28 | ``` -------------------------------------------------------------------------------- /data/anchor_spans.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to anchor labeled spans in dialogues from the SGD dataset 3 | (https://github.com/google-research-datasets/dstc8-schema-guided-dialogue). 4 | 5 | Usage: 6 | python anchor_spans.py --sgd-dir path/to/dstc8-schema-guided-dialogue --offsets-file path/to/annotations.jsonl 7 | 8 | (Will write anchored annotations at path/to/annotations.anchored.jsonl) 9 | """ 10 | import argparse 11 | import json 12 | from dataclasses import dataclass, asdict, replace 13 | from pathlib import Path 14 | from typing import Dict, List 15 | 16 | from tqdm import tqdm 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--sgd-dir', 20 | help='Path to SGD repository root directory from' 21 | ' https://github.com/google-research-datasets/dstc8-schema-guided-dialogue', required=True) 22 | parser.add_argument('--offsets-file', 23 | help='Path to file defining labeled spans with offsets in SGD conversations', required=True) 24 | 25 | 26 | @dataclass 27 | class LabeledSpan: 28 | label: str 29 | start: int 30 | exclusive_end: int 31 | text: str = None 32 | 33 | @staticmethod 34 | def from_dict(json_dict: Dict): 35 | return LabeledSpan( 36 | label=json_dict['label'], 37 | start=json_dict['start'], 38 | exclusive_end=json_dict['exclusive_end'], 39 | text=json_dict.get('text'), 40 | ) 41 | 42 | 43 | @dataclass 44 | class LabeledUtterance: 45 | spans: List[LabeledSpan] 46 | dialogue: str 47 | turn: int 48 | text: str = None 49 | 50 | @staticmethod 51 | def from_dict(json_dict: Dict): 52 | return LabeledUtterance( 53 | [LabeledSpan.from_dict(span) for span in json_dict['spans']], 54 | json_dict.get('dialogue'), 55 | json_dict.get('turn'), 56 | json_dict.get('text') 57 | ) 58 | 59 | @staticmethod 60 | def read_lines(path: Path): 61 | with path.open() as lines: 62 | lines = [line.strip() for line in lines if line.strip()] 63 | result = [] 64 | for line in lines: 65 | result.append(LabeledUtterance.from_dict(json.loads(line))) 66 | return result 67 | 68 | 69 | def write_utterances(path: Path, utterances: List[LabeledUtterance]): 70 | with path.open(mode='w') as out: 71 | for utterance in utterances: 72 | out.write(json.dumps(asdict(utterance)) + '\n') 73 | 74 | 75 | def anchor_spans_from_sgd(sgd_dir: Path, offsets_file: Path) -> List[LabeledUtterance]: 76 | """ 77 | :param sgd_dir: Path to SGD repository root directory 78 | :param offsets_file: Path to file defining labeled spans with offsets in SGD conversations 79 | :return: utterances with spans anchored in text from SGD 80 | """ 81 | result = [] 82 | uid_to_utterance = { 83 | (utt.dialogue, utt.turn): utt for utt in LabeledUtterance.read_lines(offsets_file) 84 | } 85 | for dialogues_path in tqdm(sorted(sgd_dir.glob('train/dialogues*.json'))): 86 | dialogues = json.loads(dialogues_path.read_text(encoding='utf-8')) 87 | for dialogue in dialogues: 88 | for i, turn in enumerate(dialogue['turns']): 89 | text = turn['utterance'] 90 | uid = (dialogue["dialogue_id"], i) 91 | if uid not in uid_to_utterance: 92 | continue 93 | utterance = uid_to_utterance[uid] 94 
| anchored_utterance = replace( 95 | utterance, 96 | spans=[replace(span, text=text[span.start:span.exclusive_end]) for span in utterance.spans], 97 | text=text 98 | ) 99 | result.append(anchored_utterance) 100 | return result 101 | 102 | 103 | if __name__ == '__main__': 104 | args = parser.parse_args() 105 | anchored = anchor_spans_from_sgd(Path(args.sgd_dir), Path(args.offsets_file)) 106 | write_utterances(Path(args.offsets_file).with_suffix('.anchored.jsonl'), anchored) 107 | 108 | -------------------------------------------------------------------------------- /data/convert_spans.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert JSONL anchored/labeled spans to training format (IOB) and splits. 3 | 4 | Usage: 5 | ``` 6 | python3 -m venv irl-data 7 | source irl-data/bin/activate 8 | pip install -r requirements-process-data.txt 9 | python convert_spans.py irl-annotations.anchored.jsonl splits.json 10 | ``` 11 | """ 12 | import argparse 13 | import json 14 | from dataclasses import dataclass 15 | from pathlib import Path 16 | from typing import Dict, List, Tuple, Optional 17 | 18 | from spacy.tokens import Doc 19 | from tqdm import tqdm 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('spans', type=Path) 23 | parser.add_argument('splits', type=Path) 24 | 25 | 26 | @dataclass(frozen=True) 27 | class LabeledSpan: 28 | label: str 29 | start: int 30 | exclusive_end: int 31 | text: Optional[str] = None 32 | 33 | @classmethod 34 | def from_dict(cls, json_dict: Dict) -> 'LabeledSpan': 35 | return LabeledSpan(**json_dict) 36 | 37 | 38 | @dataclass(frozen=True) 39 | class LabeledUtterance: 40 | text: str 41 | spans: List[LabeledSpan] 42 | uid: str = '' 43 | 44 | @classmethod 45 | def from_dict(cls, json_dict: Dict) -> 'LabeledUtterance': 46 | return LabeledUtterance( 47 | text=json_dict['text'], 48 | spans=[LabeledSpan.from_dict(span) for span in json_dict['spans']], 49 | uid=f"{json_dict['dialogue']}.{json_dict['turn']}" 50 | ) 51 | 52 | @staticmethod 53 | def read_lines(path: Path) -> List['LabeledUtterance']: 54 | result = [] 55 | with path.open() as lines: 56 | for line in lines: 57 | result.append(LabeledUtterance.from_dict(json.loads(line))) 58 | return result 59 | 60 | 61 | def _init_parser(name="en_core_web_md"): 62 | import spacy 63 | from spacy.attrs import ORTH 64 | from spacy import Language 65 | 66 | nlp = spacy.load(name) 67 | 68 | @Language.component("set_custom_boundaries") 69 | def set_custom_boundaries(_doc): 70 | for _token in _doc[:-1]: 71 | if _token.text == "\n": 72 | _doc[_token.i + 1].is_sent_start = True 73 | else: 74 | _doc[_token.i + 1].is_sent_start = False 75 | return _doc 76 | 77 | # special cases for tokenizer to avoid sentence splitting issues 78 | special_cases = [ 79 | ("ride?How", [{ORTH: "ride"}, {ORTH: "?"}, {ORTH: "How"}]), 80 | ("travel?I", [{ORTH: "travel"}, {ORTH: "?"}, {ORTH: "I"}]), 81 | ("cab?it", [{ORTH: "cab"}, {ORTH: "?"}, {ORTH: "it"}]), 82 | ("events?2", [{ORTH: "events"}, {ORTH: "?"}, {ORTH: "2"}]), 83 | ("Great!Buy", [{ORTH: "Great"}, {ORTH: "!"}, {ORTH: "Buy"}]), 84 | ("there/", [{ORTH: "there"}, {ORTH: "/"}]), 85 | ("calendar/", [{ORTH: "calendar"}, {ORTH: "/"}]), 86 | ("then-", [{ORTH: "then"}, {ORTH: "-"}]), 87 | ("Stovall'S.", [{ORTH: "Stovall"}, {ORTH: "'S"}, {ORTH: "."}]), 88 | ] 89 | for letter in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ': 90 | special_cases.append((f'{letter}..', [{ORTH: f'{letter}.'}, {ORTH: '.'}])) 91 | 92 | for key, val in special_cases: 93 | 
nlp.tokenizer.add_special_case(key, val) 94 | nlp.add_pipe("set_custom_boundaries", first=True) 95 | return nlp 96 | 97 | 98 | def convert_to_labels_and_tokens( 99 | spans: List[LabeledSpan], 100 | parse: Doc, 101 | ) -> Tuple[List[str], List[str]]: 102 | ann_map = {} # resolved spans by word index 103 | for span in spans: 104 | character_span = parse.char_span(span.start, span.exclusive_end) 105 | for word in character_span: 106 | ann_map[word.i] = span 107 | 108 | tokens = [] 109 | labels = [] 110 | prev_ann_start = -1 111 | for tok in parse: 112 | label, ann_start = 'O', -1 113 | if tok.i in ann_map: 114 | span = ann_map[tok.i] 115 | ann_start = span.start 116 | label = f'{"B" if ann_start != prev_ann_start else "I"}-{span.label}' 117 | labels.append(label) 118 | tokens.append(tok.text.strip()) 119 | prev_ann_start = ann_start 120 | return labels, tokens 121 | 122 | 123 | def _convert_utterance(labeled_utterance: LabeledUtterance, parse: Doc) -> Optional[Dict]: 124 | spans = labeled_utterance.spans 125 | spans = sorted(spans, key=lambda x: (x.start, x.exclusive_end)) 126 | labels, tokens = convert_to_labels_and_tokens(spans, parse) 127 | return dict(tokens=tokens, labels=labels, uid=labeled_utterance.uid) 128 | 129 | 130 | def _tokenize_and_apply_iob_labels(utterances: List[LabeledUtterance]) -> Dict[str, Dict]: 131 | # tokenize 132 | nlp = _init_parser() 133 | parses = tqdm(nlp.pipe([utt.text for utt in utterances], batch_size=64), total=len(utterances)) 134 | # convert to IOB 135 | tokenized_utterances = {} 136 | for parse, utterance in zip(parses, utterances): 137 | converted = _convert_utterance(utterance, parse) 138 | tokenized_utterances[converted['uid']] = converted 139 | return tokenized_utterances 140 | 141 | 142 | def main(in_pth: Path, splits_pth: Path, out_pth: Path): 143 | # read anchored utterances 144 | utterances = LabeledUtterance.read_lines(in_pth) 145 | 146 | # tokenize utterances and convert spans to IOB labels (e.g. B-Query, I-Query, O) 147 | tokenized_utterances = _tokenize_and_apply_iob_labels(utterances) 148 | 149 | # split into train/dev/test using splits from paper 150 | split_to_uids = json.loads(splits_pth.read_text('utf-8')) 151 | splits = {} 152 | for split, uids in split_to_uids.items(): 153 | print(f'{split}: {len(uids)}') 154 | splits[split] = [tokenized_utterances[uid] for uid in uids] 155 | 156 | # write splits in JSONL format, with a line per sentence 157 | out_pth.mkdir(exist_ok=True, parents=True) 158 | for split, utterances in splits.items(): 159 | with (out_pth / f'{split}.json').open(mode='w') as out: 160 | for conv in utterances: 161 | out.write(json.dumps(conv) + '\n') 162 | 163 | 164 | if __name__ == '__main__': 165 | args = parser.parse_args() 166 | main(args.spans, args.splits, args.spans.parent) 167 | -------------------------------------------------------------------------------- /data/convert_to_allennlp.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script converts from a transformers IRL model to an AllenNLP predictor (sequence tagger). 
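Usage (this matches the invocation in data/run_train_model.sh; `results` is the
output directory produced by run_ner.py):
    python convert_to_allennlp.py results
The AllenNLP archive is written to an `archive` directory created next to the
model directory (archive/model.tar.gz).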
3 | """ 4 | import argparse 5 | import json 6 | from pathlib import Path 7 | 8 | import torch 9 | from allennlp.common import Params 10 | from allennlp.data.vocabulary import Vocabulary 11 | from allennlp.models import archive_model 12 | from allennlp.modules.seq2seq_encoders import PassThroughEncoder 13 | from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder 14 | from allennlp.modules.token_embedders import PretrainedTransformerMismatchedEmbedder 15 | from allennlp.nn import InitializerApplicator 16 | from allennlp.nn.initializers import PretrainedModelInitializer 17 | 18 | from irl.intent_role_labelers import IntentRoleLabeler 19 | from irl.irl_tagger import TurnTagger 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('model_path', type=Path) 23 | 24 | 25 | def _vocabulary(model_name: Path): 26 | from allennlp.common import cached_transformers 27 | tokenizer = cached_transformers.get_tokenizer(str(model_name)) 28 | vocab = Vocabulary(non_padded_namespaces=['tokenizer', 'labels'], oov_token='') 29 | # add tokenizer vocabulary 30 | vocab.add_transformer_vocab(tokenizer, 'tokenizer') 31 | # add label vocabulary 32 | id2label = json.loads((model_name / 'config.json').read_text(encoding='utf-8'))['id2label'] 33 | for k, v in id2label.items(): 34 | vocab.add_token_to_namespace(v, "labels") 35 | return vocab 36 | 37 | 38 | def _model(model_name: Path, vocab: Vocabulary): 39 | text_field_embedder = BasicTextFieldEmbedder(token_embedders={ 40 | "bert": PretrainedTransformerMismatchedEmbedder(str(model_name)), 41 | }) 42 | model = TurnTagger( 43 | vocab=vocab, 44 | text_field_embedder=text_field_embedder, 45 | encoder=PassThroughEncoder(768), 46 | initializer=InitializerApplicator([( 47 | # map to parameter names expected by AllenNLP 48 | 'tag_projection_layer.*', PretrainedModelInitializer( 49 | weights_file_path=str(model_name / 'pytorch_model.bin'), 50 | parameter_name_overrides={ 51 | 'tag_projection_layer._module.bias': 'classifier.bias', 52 | 'tag_projection_layer._module.weight': 'classifier.weight', 53 | })) 54 | ]) 55 | ) 56 | return model 57 | 58 | 59 | def _config(): 60 | # hard-coded configuration for AllenNLP inference 61 | return { 62 | "dataset_reader": { 63 | "type": "tagger_dataset_reader", 64 | "tagger_preprocessor": { 65 | "tokenizer": { 66 | "type": "spacy", 67 | "pos_tags": True 68 | } 69 | }, 70 | "token_indexers": { 71 | "bert": { 72 | "type": "pretrained_transformer_mismatched", 73 | "model_name": "roberta-base", 74 | "namespace": "tokens" 75 | } 76 | } 77 | }, 78 | "model": { 79 | "type": "turn_tagger", 80 | "encoder": { 81 | "type": "pass_through", 82 | "input_dim": 768 83 | }, 84 | "text_field_embedder": { 85 | "token_embedders": { 86 | "bert": { 87 | "type": "pretrained_transformer_mismatched", 88 | "model_name": "roberta-base" 89 | } 90 | } 91 | } 92 | }} 93 | 94 | 95 | def main(): 96 | args = parser.parse_args() 97 | model_path = args.model_path 98 | archive_path = model_path.parent / 'archive' 99 | archive_path.mkdir(parents=True, exist_ok=True) 100 | # prepare the vocabulary 101 | vocab = _vocabulary(model_path) 102 | # prepare the model 103 | model = _model(model_path, vocab) 104 | # add weights 105 | torch.save(model.state_dict(), archive_path / 'weights.th') 106 | # meta file indicating allenNLP version number 107 | (archive_path / 'meta.json').write_text(json.dumps({'version': '2.9.3'}), encoding='utf-8') 108 | # add vocabulary 109 | vocab.save_to_files(str(archive_path / 'vocabulary')) 110 | # add config 111 | 
(archive_path / 'config.json').write_text(json.dumps(_config()), encoding='utf-8') 112 | # write archive to model.tar.gz 113 | archive_model(str(archive_path), 'weights.th') 114 | 115 | # prepare the dataset reader / tokenizer 116 | predictor = IntentRoleLabeler.from_params(Params({ 117 | 'type': 'tagger_based_intent_role_labeler', 118 | 'model_path': archive_path / 'model.tar.gz', 119 | 'cuda_device': -1, 120 | })) 121 | 122 | # sanity check 123 | text_sample = "I'm looking to purchase movie tickets." 124 | print(predictor.label_batch([text_sample])) 125 | 126 | 127 | if __name__ == '__main__': 128 | main() 129 | -------------------------------------------------------------------------------- /data/requirements-process-data.txt: -------------------------------------------------------------------------------- 1 | spacy==3.2.6 2 | https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.2.0/en_core_web_md-3.2.0.tar.gz#egg=en_core_web_md 3 | tqdm 4 | -------------------------------------------------------------------------------- /data/requirements-train-model.txt: -------------------------------------------------------------------------------- 1 | # to serialize results 2 | allennlp==2.9.3 3 | # workaround for allennlp issue 4 | cached-path==1.1.2 5 | # for training model 6 | torch==1.11.0 7 | transformers==4.18 8 | seqeval 9 | datasets 10 | # just for testing serialized predictor 11 | spacy==3.2.6 12 | https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.2.0/en_core_web_md-3.2.0.tar.gz#egg=en_core_web_md -------------------------------------------------------------------------------- /data/run_process_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -euxo pipefail 4 | cd "$(dirname "$0")" || exit 5 | 6 | if [ ! -d "irl-data" ]; then 7 | python3 -m venv irl-data 8 | fi 9 | 10 | # Install dependencies 11 | source irl-data/bin/activate 12 | pip install -r requirements-process-data.txt 13 | 14 | # Add text from SGD to IRL annotations 15 | if [ ! -d "dstc8-schema-guided-dialogue" ]; then 16 | git clone https://github.com/google-research-datasets/dstc8-schema-guided-dialogue.git 17 | fi 18 | 19 | python3 anchor_spans.py --sgd-dir dstc8-schema-guided-dialogue --offsets-file irl-annotations.jsonl 20 | 21 | # Tokenize text and convert spans to IOB labels corresponding to each token, split into train/dev/test 22 | python3 convert_spans.py irl-annotations.anchored.jsonl splits.json 23 | -------------------------------------------------------------------------------- /data/run_train_model.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | MODEL_NAME='irl-model-sgd-08-16-2022.tar.gz' 4 | 5 | set -euxo pipefail 6 | cd "$(dirname "$0")" || exit 7 | 8 | if [ ! -d "irl-model-train" ]; then 9 | python3 -m venv irl-model-train 10 | fi 11 | 12 | # Install dependencies 13 | source irl-model-train/bin/activate 14 | pip install -r requirements-train-model.txt 15 | 16 | # download training script 17 | if [ ! 
-f "run_ner.py" ]; then 18 | wget https://raw.githubusercontent.com/huggingface/transformers/v4.18.0/examples/pytorch/token-classification/run_ner.py 19 | fi 20 | 21 | # i/o 22 | train_file='train.json' 23 | validation_file='dev.json' 24 | test_file='test.json' 25 | output_dir='results' 26 | label_column_name='labels' 27 | 28 | # parameters 29 | base_model='roberta-base' 30 | max_seq_length=256 31 | train_batch_size=16 32 | eval_batch_size=32 33 | lr=2e-5 34 | weight_decay=0.01 35 | lr_scheduler_type='linear' 36 | num_train_epochs=8 37 | warmup_ratio=0.06 38 | seed=1 39 | 40 | # run training 41 | WANDB_DISABLED="true" python run_ner.py \ 42 | --output_dir $output_dir \ 43 | --train_file $train_file \ 44 | --validation_file $validation_file \ 45 | --test_file $test_file \ 46 | --model_name_or_path $base_model \ 47 | --max_seq_length $max_seq_length \ 48 | --evaluation_strategy epoch \ 49 | --per_device_train_batch_size $train_batch_size \ 50 | --per_device_eval_batch_size $eval_batch_size \ 51 | --learning_rate $lr \ 52 | --weight_decay $weight_decay \ 53 | --lr_scheduler_type $lr_scheduler_type \ 54 | --num_train_epochs $num_train_epochs \ 55 | --warmup_ratio $warmup_ratio \ 56 | --seed $seed \ 57 | --label_column_name $label_column_name \ 58 | --do_train \ 59 | --do_eval \ 60 | --do_predict 61 | 62 | # package model 63 | PYTHONPATH="$(pwd)/../pretraining/preprocess" python convert_to_allennlp.py results 64 | mkdir -p ../models/irl_model/ 65 | mv archive/model.tar.gz ../models/irl_model/$MODEL_NAME 66 | -------------------------------------------------------------------------------- /downstream/README.md: -------------------------------------------------------------------------------- 1 | # Downstream 2 | 3 | This directory is about how to evaluate the pre-trained model on each of four intent classification datasets (BANKING77, HWU64, Liu, OOS), and how to fine-tune the model using ProtoNet or ProtAugment. 4 | 5 | ## Download Datasets & Get Label Names 6 | For convenience, you can download the intent dataset splits from [this repo](https://github.com/tdopierre/ProtAugment/tree/main/data) and put the downloaded data into `data` folder. 7 | Before evaluating or fine-tuning model, label name files should be generated to use label names as support example. 8 | 9 | ``` 10 | DATA_DIR=../data/downstream 11 | python gen_labelname.py --data-dir ${DATA_DIR} 12 | ``` 13 | 14 | ## Eval the Model Before Fine-tuning 15 | To evaluate the pre-trained model on the intent classification dataset, run the following script. 16 | ``` 17 | bash ./scripts/run_eval_before_ft.sh 18 | ``` 19 | 20 | ## ProtoNet 21 | To fine-tune the pre-trained model using ProtoNet, run the following script. 22 | ``` 23 | bash ./scripts/run_protonet.sh 24 | ``` 25 | 26 | ## ProtAugment 27 | To fine-tune the pre-trained model using ProtAugment, run the following script. 28 | ``` 29 | bash ./scripts/run_protaugment.sh 30 | ``` 31 | 32 | ## Eval the Model After Fine-tuning 33 | 34 | To evaluate the fine-tuned model, run the following script. 35 | ``` 36 | bash ./scripts/run_eval_after_ft.sh 37 | ``` -------------------------------------------------------------------------------- /downstream/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. Licensed under the MIT license. 
2 | 3 | import sys, os 4 | sys.path.append(os.path.dirname(__file__) + os.sep + '../') 5 | -------------------------------------------------------------------------------- /downstream/encoders/bert_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Thomas Dopierre. Licensed under Apache License 2.0 2 | 3 | from typing import List 4 | 5 | import torch.nn as nn 6 | import logging 7 | import warnings 8 | import torch 9 | from transformers import AutoModel, AutoTokenizer 10 | 11 | logging.basicConfig() 12 | logger = logging.getLogger(__name__) 13 | logger.setLevel(logging.DEBUG) 14 | 15 | warnings.simplefilter('ignore') 16 | 17 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 18 | 19 | 20 | class BERTEncoder(nn.Module): 21 | def __init__(self, config_name_or_path): 22 | super(BERTEncoder, self).__init__() 23 | logger.info(f"Loading Encoder @ {config_name_or_path}") 24 | self.tokenizer = AutoTokenizer.from_pretrained(config_name_or_path) 25 | self.bert = AutoModel.from_pretrained(config_name_or_path).to(device) 26 | logger.info(f"Encoder loaded.") 27 | self.warmed: bool = False 28 | # transformer_models/OOS/fine-tuned 29 | def embed_sentences(self, sentences: List[str]): 30 | if self.warmed: 31 | padding = True 32 | else: 33 | padding = "max_length" 34 | self.warmed = True 35 | batch = self.tokenizer.batch_encode_plus( 36 | sentences, 37 | return_tensors="pt", 38 | max_length=64, 39 | truncation=True, 40 | padding=padding 41 | ) 42 | batch = {k: v.to(device) for k, v in batch.items()} 43 | 44 | fw = self.bert.forward(**batch) 45 | return fw.pooler_output 46 | 47 | 48 | def test(): 49 | encoder = BERTEncoder("bert-base-cased") 50 | sentences = ["this is one", "why not another"] 51 | encoder.embed_sentences(sentences) 52 | -------------------------------------------------------------------------------- /downstream/encoders/sbert_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. Licensed under the MIT license. 
2 | 3 | from typing import List 4 | 5 | import torch.nn as nn 6 | import logging 7 | import warnings 8 | import torch 9 | from sentence_transformers import SentenceTransformer 10 | 11 | logging.basicConfig() 12 | logger = logging.getLogger(__name__) 13 | logger.setLevel(logging.DEBUG) 14 | 15 | warnings.simplefilter('ignore') 16 | 17 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 18 | 19 | 20 | class SentEncoder(nn.Module): 21 | def __init__(self, config_name_or_path): 22 | super(SentEncoder, self).__init__() 23 | logger.info(f"Loading Encoder @ {config_name_or_path}") 24 | self.bert = SentenceTransformer(config_name_or_path).to(device) 25 | logger.info(f"Encoder loaded.") 26 | self.emb_dim = 768 27 | 28 | def embed_sentences(self, sentences: List[str]): 29 | features = self.bert.tokenize(sentences) 30 | for key in features: 31 | if isinstance(features[key], torch.Tensor): 32 | features[key] = features[key].to(device) 33 | 34 | out_features = self.bert.forward(features) 35 | embeddings = out_features['sentence_embedding'] 36 | return embeddings 37 | 38 | def forward(self, sentences: List[str]): 39 | return self.embed_sentences(sentences) 40 | -------------------------------------------------------------------------------- /downstream/gen_labelname.py: -------------------------------------------------------------------------------- 1 | # Original Copyright (c) Microsoft Corporation. Licensed under the MIT license. 2 | # Modifications Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | 4 | import os, json 5 | import argparse 6 | 7 | def generate_labels(input_fn, output_fn): 8 | all_labels = [] 9 | with open(input_fn, mode="r", encoding="utf-8") as fp: 10 | for line in fp: 11 | label = json.loads(line)["label"] 12 | if label not in all_labels: 13 | all_labels.append(label) 14 | with open(output_fn, mode="w", encoding="utf-8") as fp: 15 | fp.writelines([x + "\n" for x in all_labels]) 16 | return 17 | 18 | if __name__ == "__main__": 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--data-dir", type=str, default=None) 21 | 22 | args = parser.parse_args() 23 | 24 | for dataset in ['BANKING77', 'HWU64', 'Liu', 'OOS']: 25 | input_path = os.path.join(args.data_dir, dataset, 'full.jsonl') 26 | output_path = os.path.join(args.data_dir, dataset, 'labels.txt') 27 | generate_labels(input_path, output_path) -------------------------------------------------------------------------------- /downstream/protaugment/paraphrase/modeling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Thomas Dopierre. 
Licensed under Apache License 2.0 2 | 3 | import numpy as np 4 | import random 5 | import torch 6 | import logging 7 | from typing import List, Dict, Callable, Union 8 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer 9 | 10 | default_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 11 | 12 | logging.basicConfig() 13 | logger = logging.getLogger(__name__) 14 | logger.setLevel(logging.DEBUG) 15 | 16 | 17 | class ParaphraseModel: 18 | def __init__(self, device=None): 19 | self.device = device if device else default_device 20 | 21 | def paraphrase(self, src_texts: List[str], **kwargs): 22 | raise NotImplementedError 23 | 24 | 25 | class BaseParaphraseModel(ParaphraseModel): 26 | def __init__( 27 | self, 28 | model_name_or_path: str, 29 | tok_name_or_path: str = None, 30 | num_return_sequences: int = 1, 31 | num_beams: int = None, 32 | device=None 33 | ): 34 | super().__init__(device=device) 35 | self.device = device if device else default_device 36 | self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path).to(self.device) 37 | self.tok = AutoTokenizer.from_pretrained(tok_name_or_path if tok_name_or_path else model_name_or_path) 38 | self.num_return_sequences = num_return_sequences 39 | self.num_beams = num_beams if num_beams else self.num_return_sequences 40 | assert self.num_beams >= self.num_return_sequences 41 | 42 | def paraphrase(self, src_texts: List[str], **kwargs): 43 | batch = self.tok.prepare_seq2seq_batch( 44 | src_texts=src_texts, 45 | max_length=128, 46 | return_tensors="pt", 47 | ) 48 | batch = {k: v.to(self.device) for k, v in batch.items()} 49 | preds = self.model.generate(**batch, max_length=512, num_beams=self.num_beams, num_return_sequences=self.num_return_sequences) 50 | tgt_texts = self.tok.batch_decode(preds.detach().cpu(), skip_special_tokens=True) 51 | return [tgt_texts[i:i + self.num_return_sequences] for i in range(0, len(src_texts) * self.num_return_sequences, self.num_return_sequences)] 52 | 53 | 54 | class DropChances: 55 | def __init__(self, auc: float): 56 | self.auc = auc 57 | 58 | def flat_drop_chance(self, current_value: Union[int, float], max_value: Union[int, float]): 59 | return self.auc 60 | 61 | def down_decrease_drop_chance(self, current_value: Union[int, float], max_value: Union[int, float]): 62 | return (self.auc / .5) * (.25 + .5 * (1 - current_value / max_value)) 63 | 64 | def fast_decrease_drop_chance(self, current_value: Union[int, float], max_value: Union[int, float]): 65 | return (self.auc / .5) * (0 + 1 * (1 - current_value / max_value)) 66 | 67 | def up_drop_chance(self, current_value: Union[int, float], max_value: Union[int, float]): 68 | return (self.auc / .5) * (.25 + .5 * (current_value / max_value)) 69 | 70 | def get_drop_fn(self, drop_fn_string: str) -> Callable: 71 | assert drop_fn_string in ("flat", "slow", "fast", "up", "down") 72 | if drop_fn_string in ("slow",): 73 | logger.warning(f"drop_chance_speed `{drop_fn_string}` will be deprecated. 
Valid values are:\n" 74 | f" `flat` : same drop chance for every token in the sentence;\n" 75 | f" `down` : more weight on the first tokens (linear decrease);\n" 76 | f" `up` : more weight on the last tokens (linear increase);\n" 77 | f" `fast` : linear decrease, higher slope") 78 | if drop_fn_string == "flat": 79 | return self.flat_drop_chance 80 | elif drop_fn_string in ("slow", "down"): 81 | return self.down_decrease_drop_chance 82 | elif drop_fn_string == "fast": 83 | return self.fast_decrease_drop_chance 84 | elif drop_fn_string == "up": 85 | return self.up_drop_chance 86 | else: 87 | raise NotImplementedError 88 | 89 | 90 | class ForbidStrategies: 91 | def __init__(self, special_ids: List[int]): 92 | self.special_ids = special_ids 93 | 94 | def unigram_dropping_strategy(self, input_ids: torch.Tensor, drop_chance_fn: Callable): 95 | bad_words_ids = list() 96 | for row in input_ids.tolist(): 97 | row = [item for item in row if item not in self.special_ids] 98 | for item_ix, item in enumerate(row): 99 | drop_chance = drop_chance_fn(item_ix, len(row)) 100 | if random.random() < drop_chance: 101 | bad_words_ids.append(item) 102 | 103 | # Reshape to correct format 104 | bad_words_ids = [[item] for item in bad_words_ids] 105 | return bad_words_ids 106 | 107 | def bigram_dropping_strategy(self, input_ids: torch.Tensor): 108 | bad_words_ids = list() 109 | for row in input_ids.tolist(): 110 | row = [item for item in row if item not in self.special_ids] 111 | for i in range(0, len(row) - 1): 112 | bad_words_ids.append(row[i:i + 2]) 113 | return bad_words_ids 114 | 115 | 116 | class BaseParaphraseBatchPreparer: 117 | def __init__(self, tokenizer, device=None): 118 | self.tokenizer = tokenizer 119 | self.device = device if device else default_device 120 | 121 | def prepare_batch(self, src_texts: List[str]): 122 | batch = self.tokenizer.prepare_seq2seq_batch(src_texts=src_texts, return_tensors="pt", max_length=512) 123 | batch = {k: v.to(self.device) for k, v in batch.items()} 124 | self.pimp_batch(batch) 125 | return batch 126 | 127 | def pimp_batch(self, batch: Dict[str, torch.Tensor], **kwargs): 128 | # This must be implemented elsewhere! 
129 | return 130 | 131 | 132 | class UnigramRandomDropParaphraseBatchPreparer(BaseParaphraseBatchPreparer): 133 | 134 | def __init__(self, tokenizer, auc: float = None, drop_chance_speed: str = None, device=None): 135 | super().__init__(tokenizer=tokenizer, device=device) 136 | 137 | # Args checking 138 | self.auc = auc 139 | assert 0 <= self.auc <= 1 140 | self.drop_chance_speed = drop_chance_speed 141 | assert self.drop_chance_speed in ("flat", "slow", "fast", "up", "down") 142 | 143 | def pimp_batch(self, batch: Dict[str, torch.Tensor], **kwargs): 144 | bad_words_ids = ForbidStrategies( 145 | special_ids=self.tokenizer.all_special_ids 146 | ).unigram_dropping_strategy( 147 | batch["input_ids"], 148 | drop_chance_fn=DropChances(auc=self.auc).get_drop_fn(self.drop_chance_speed) 149 | ) 150 | if len(bad_words_ids): 151 | batch["bad_words_ids"] = bad_words_ids 152 | 153 | 154 | class BigramDropParaphraseBatchPreparer(BaseParaphraseBatchPreparer): 155 | def __init__(self, tokenizer, device=None): 156 | super().__init__(tokenizer=tokenizer, device=device) 157 | 158 | def pimp_batch(self, batch: Dict[str, torch.Tensor], **kwargs): 159 | bad_words_ids = ForbidStrategies(special_ids=self.tokenizer.all_special_ids).bigram_dropping_strategy(batch["input_ids"]) 160 | if len(bad_words_ids): 161 | batch["bad_words_ids"] = bad_words_ids 162 | 163 | 164 | def tune_batch_random_drop(batch: Dict[str, torch.Tensor], drop_prob: float = 1): 165 | input_ids = batch["input_ids"] 166 | attention_mask = batch["attention_mask"] 167 | ids_filtered = input_ids * attention_mask * (input_ids != 2)  # assumes token id 2 is a special token (e.g., </s>) that should never be forbidden 168 | ids_filtered = (torch.rand_like(ids_filtered.to(float)) > 1 - drop_prob) * ids_filtered 169 | ids_filtered = ids_filtered[ids_filtered != 0] 170 | bad_words_ids = ids_filtered.view(-1).unique().view(-1, 1).tolist() 171 | if len(bad_words_ids): 172 | batch["bad_words_ids"] = bad_words_ids 173 | 174 | 175 | def bleu_score(src: str, dst: str): 176 | from sacrebleu import sentence_bleu 177 | return sentence_bleu(dst, [src]).score 178 | 179 | 180 | def filter_generated_texts_with_clustering(texts: List[str], n_return_sequences: int): 181 | assert len(texts) >= n_return_sequences 182 | embeddings = use_embedder.embed_many(texts)  # NOTE: `use_embedder` is not defined in this module; the caller must supply a sentence-embedding wrapper exposing embed_many(), otherwise this raises a NameError 183 | 184 | # KMeans (this is too slow) 185 | # from sklearn.cluster import KMeans, AgglomerativeClustering 186 | # clustering_algo = KMeans(n_clusters=self.num_beam_groups, max_iter=10) 187 | 188 | # Agglomerative Clustering (note: 'affinity' was renamed to 'metric' in scikit-learn >= 1.2) 189 | from sklearn.cluster import AgglomerativeClustering 190 | clustering_algo = AgglomerativeClustering(n_clusters=n_return_sequences, affinity='euclidean', linkage='ward') 191 | labels = clustering_algo.fit_predict(embeddings) 192 | 193 | # Organise labels & data into clusters 194 | cluster = dict() 195 | for txt_ix, (txt, label) in enumerate(zip(texts, labels)): 196 | cluster.setdefault(label, []).append((txt_ix, txt)) 197 | 198 | # Write to output 199 | output = list() 200 | from sklearn.metrics.pairwise import pairwise_distances 201 | for label, txts in cluster.items(): 202 | # In a cluster, select the sentence closest to the center 203 | distances = pairwise_distances([ 204 | embeddings[txt_ix] for txt_ix, _ in txts 205 | ], [embeddings[labels == label].mean(0)]) 206 | output.append(txts[distances.flatten().argmin()][1]) 207 | # batch_output.append(txts[0]) 208 | return output 209 | 210 | 211 | def filter_generated_texts_with_distance_metric(texts: List[List[str]], src: str, distance_metric_fn: Callable[[str, str], float], lower_is_better: bool = True): 212 | scores = [
213 | [distance_metric_fn(src, text) for text in group] 214 | for group in texts 215 | ] 216 | 217 | if lower_is_better: 218 | ranking_fn = np.argmin 219 | else: 220 | ranking_fn = np.argmax 221 | return [ 222 | group[ranking_fn(scores_)] 223 | for group, scores_ in zip(texts, scores) 224 | ] 225 | 226 | 227 | class DBSParaphraseModel(ParaphraseModel): 228 | def __init__( 229 | self, 230 | model_name_or_path: str, 231 | tok_name_or_path: str = None, 232 | beam_group_size: int = 4, 233 | num_beams: int = 20, 234 | diversity_penalty: float = 1.0, 235 | filtering_strategy: str = None, 236 | paraphrase_batch_preparer: BaseParaphraseBatchPreparer = None, 237 | device=None 238 | ): 239 | super().__init__(device=device) 240 | self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path).to(self.device) 241 | self.tokenizer = AutoTokenizer.from_pretrained(tok_name_or_path if tok_name_or_path else model_name_or_path) 242 | self.num_return_sequences = self.num_beams = num_beams 243 | self.beam_group_size = beam_group_size 244 | self.num_beam_groups = self.num_beams // self.beam_group_size 245 | assert self.num_beams % self.beam_group_size == 0 246 | self.filtering_strategy = filtering_strategy 247 | self.diversity_penalty = diversity_penalty 248 | if paraphrase_batch_preparer is None: 249 | paraphrase_batch_preparer = BaseParaphraseBatchPreparer(tokenizer=self.tokenizer) 250 | self.paraphrase_batch_preparer = paraphrase_batch_preparer 251 | 252 | def paraphrase(self, src_texts: List[str], **kwargs): 253 | batch = self.paraphrase_batch_preparer.prepare_batch(src_texts=src_texts) 254 | max_length = batch["input_ids"].shape[1] 255 | with torch.no_grad(): 256 | preds = self.model.generate( 257 | **batch, 258 | max_length=max_length, 259 | num_beams=self.num_beams, 260 | num_beam_groups=self.num_beam_groups,  # generate() expects the number of groups (num_beams // beam_group_size), not the group size 261 | diversity_penalty=self.diversity_penalty, 262 | num_return_sequences=self.num_return_sequences 263 | ) 264 | 265 | tgt_texts = self.tokenizer.batch_decode(preds.detach().cpu(), skip_special_tokens=True) 266 | 267 | batches = [tgt_texts[i:i + self.num_return_sequences] for i in range(0, len(src_texts) * self.num_return_sequences, self.num_return_sequences)] 268 | 269 | output = list() 270 | 271 | for src, batch in zip(src_texts, batches): 272 | if self.filtering_strategy == "clustering": 273 | filtered = filter_generated_texts_with_clustering(batch, self.num_beam_groups) 274 | elif self.filtering_strategy == "bleu": 275 | filtered = filter_generated_texts_with_distance_metric( 276 | texts=[batch[i:i + self.beam_group_size] for i in range(0, len(batch), self.beam_group_size)], 277 | src=src, 278 | distance_metric_fn=bleu_score, 279 | lower_is_better=True 280 | ) 281 | else: 282 | raise ValueError(f"Unknown filtering strategy: {self.filtering_strategy}") 283 | output.append(filtered) 284 | 285 | return output 286 | 287 | 288 | class EDAParaphraseModel(ParaphraseModel): 289 | def __init__( 290 | self, 291 | num_paraphrases: int = 5 292 | ): 293 | logger.info(f"Instantiating EDAParaphraseModel(num_paraphrases={num_paraphrases}) to generate paraphrases") 294 | self.num_paraphrases = num_paraphrases 295 | super().__init__(device=None) 296 | 297 | def paraphrase(self, src_texts: List[str], **kwargs): 298 | from .eda import eda 299 | 300 | output = list() 301 | for text in src_texts: 302 | paraphrases = eda(text, num_aug=self.num_paraphrases) 303 | output.append(paraphrases) 304 | return output 305 | -------------------------------------------------------------------------------- /downstream/protaugment/paraphrase/utils/data.py:
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Thomas Dopierre. Licensed under Apache License 2.0 2 | 3 | import numpy as np 4 | import collections 5 | from typing import List, Dict 6 | 7 | import random 8 | 9 | from protaugment.paraphrase.modeling import ParaphraseModel 10 | from protaugment.utils.data import get_jsonl_data, get_txt_data 11 | import torch 12 | import logging 13 | 14 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 15 | 16 | logging.basicConfig() 17 | logger = logging.getLogger(__name__) 18 | logger.setLevel(logging.DEBUG) 19 | 20 | 21 | class FewShotDataset: 22 | def __init__( 23 | self, 24 | data_path: str, 25 | n_classes: int, 26 | n_support: int, 27 | n_query: int, 28 | labels_path: str = None 29 | ): 30 | self.data_path = data_path 31 | self.labels_path = labels_path 32 | self.n_classes = n_classes 33 | self.n_support = n_support 34 | self.n_query = n_query 35 | self.data: Dict[str, List[Dict]] = None 36 | self.counter: Dict[str, int] = None 37 | self.load_file(data_path, labels_path) 38 | 39 | def load_file(self, data_path: str, labels_path: str = None): 40 | data = get_jsonl_data(data_path) 41 | if labels_path: 42 | labels = get_txt_data(labels_path) 43 | else: 44 | labels = sorted(set([item["label"] for item in data])) 45 | 46 | labels_dict = collections.defaultdict(list) 47 | for item in data: 48 | if item["label"] in labels: 49 | labels_dict[item['label']].append(item) 50 | labels_dict = dict(labels_dict) 51 | 52 | for key, val in labels_dict.items(): 53 | random.shuffle(val) 54 | self.data = labels_dict 55 | self.counter = {key: 0 for key, _ in self.data.items()} 56 | 57 | def get_episode(self) -> Dict: 58 | episode = dict() 59 | if self.n_classes: 60 | assert self.n_classes <= len(self.data.keys()) 61 | rand_keys = np.random.choice(list(self.data.keys()), self.n_classes, replace=False) 62 | 63 | # Ensure enough data are query-able 64 | assert min([len(val) for val in self.data.values()]) >= self.n_support + self.n_query 65 | 66 | # Shuffle data 67 | for key in rand_keys: 68 | random.shuffle(self.data[key]) 69 | 70 | if self.n_support: 71 | episode["xs"] = [[self.data[k][i] for i in range(self.n_support)] for k in rand_keys] 72 | if self.n_query: 73 | episode["xq"] = [[self.data[k][self.n_support + i] for i in range(self.n_query)] for k in rand_keys] 74 | return episode, rand_keys 75 | 76 | def __len__(self): 77 | return sum([len(label_data) for label, label_data in self.data.items()]) 78 | 79 | 80 | class FewShotPPDataset(FewShotDataset): 81 | def __init__( 82 | self, 83 | data_path: str, 84 | n_classes: int, 85 | n_support: int, 86 | n_query: int, 87 | n_unlabeled: int, 88 | labels_path: str): 89 | super().__init__(data_path=data_path, n_classes=n_classes, n_support=n_support, n_query=n_query, labels_path=labels_path) 90 | self.n_unlabeled = n_unlabeled 91 | 92 | def get_episode(self) -> Dict: 93 | episode, classes = super().get_episode() 94 | if self.n_classes: 95 | assert self.n_classes <= len(self.data.keys()) 96 | rand_keys = np.random.choice(list(self.data.keys()), self.n_classes, replace=False) 97 | 98 | assert set(classes) == set(rand_keys) 99 | 100 | # Ensure enough data are query-able 101 | assert all(len(self.data[key]) >= self.n_support + self.n_query + self.n_unlabeled for key in rand_keys) 102 | 103 | # Shuffle data 104 | for key in rand_keys: 105 | random.shuffle(self.data[key]) 106 | 107 | if self.n_support: 108 | episode["xs"] = [[self.data[k][i] 
for i in range(self.n_support)] for k in rand_keys] 109 | if self.n_query: 110 | episode["xq"] = [[self.data[k][self.n_support + i] for i in range(self.n_query)] for k in rand_keys] 111 | 112 | if self.n_unlabeled: 113 | episode['xu'] = [item for k in rand_keys for item in self.data[k][self.n_support + self.n_query:self.n_support + self.n_query + self.n_unlabeled]] 114 | 115 | return episode, classes 116 | 117 | 118 | class FewShotSSLFileDataset(FewShotDataset): 119 | def __init__( 120 | self, 121 | data_path: str, 122 | n_classes: int, 123 | n_support: int, 124 | n_query: int, 125 | n_unlabeled: int, 126 | unlabeled_file_path: str, 127 | labels_path: str): 128 | super().__init__(data_path=data_path, n_classes=n_classes, n_support=n_support, n_query=n_query, labels_path=labels_path) 129 | self.n_unlabeled = n_unlabeled 130 | logger.debug(f"Using augmented data @ {unlabeled_file_path}") 131 | self.unlabeled_data = get_jsonl_data(unlabeled_file_path) 132 | logger.debug(f"Dataset has {len(self.unlabeled_data)} unlabeled samples") 133 | 134 | def get_episode(self) -> Dict: 135 | # Get episode from regular few-shot 136 | episode, classes = super().get_episode() 137 | 138 | # Get random augmentations in the file 139 | unlabeled = np.random.choice(self.unlabeled_data, self.n_unlabeled).tolist() 140 | 141 | episode["x_augment"] = [ 142 | { 143 | "src_text": u["src_text"], 144 | "tgt_texts": u["tgt_texts"] 145 | } 146 | for u in unlabeled 147 | ] 148 | 149 | return episode, classes 150 | 151 | 152 | class FewShotSSLParaphraseDataset(FewShotDataset): 153 | n_unlabeled: int 154 | unlabeled_data: List[str] 155 | paraphrase_model: ParaphraseModel 156 | 157 | def __init__( 158 | self, 159 | data_path: str, 160 | n_classes: int, 161 | n_support: int, 162 | n_query: int, 163 | n_unlabeled: int, 164 | unlabeled_file_path: str, 165 | paraphrase_model: ParaphraseModel, 166 | labels_path: str): 167 | super().__init__(data_path=data_path, n_classes=n_classes, n_support=n_support, n_query=n_query, labels_path=labels_path) 168 | self.n_unlabeled = n_unlabeled 169 | self.unlabeled_data = get_txt_data(unlabeled_file_path) 170 | self.paraphrase_model = paraphrase_model 171 | 172 | def get_episode(self, **kwargs) -> Dict: 173 | episode, classes = super().get_episode() 174 | 175 | # Get random augmentations in the file 176 | # unlabeled = np.random.choice(self.unlabeled_data, self.n_unlabeled).tolist() 177 | unlabeled = np.random.choice(self.unlabeled_data, len(classes)).tolist() 178 | tgt_texts = self.paraphrase_model.paraphrase(unlabeled, **kwargs) 179 | episode["x_augment"] = [ 180 | { 181 | "src_text": src, 182 | "tgt_texts": tgts 183 | } 184 | for src, tgts in zip(unlabeled, tgt_texts) 185 | ] 186 | 187 | return episode, classes 188 | -------------------------------------------------------------------------------- /downstream/protaugment/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import * 2 | -------------------------------------------------------------------------------- /downstream/protaugment/utils/data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Thomas Dopierre. 
Licensed under Apache License 2.0 2 | 3 | import numpy as np 4 | import random 5 | import collections 6 | import os 7 | import json 8 | from typing import List, Dict 9 | 10 | 11 | def get_jsonl_data(jsonl_path: str): 12 | assert jsonl_path.endswith(".jsonl") 13 | out = list() 14 | with open(jsonl_path, 'r', encoding="utf-8") as file: 15 | for line in file: 16 | j = json.loads(line.strip()) 17 | out.append(j) 18 | return out 19 | 20 | 21 | def write_jsonl_data(jsonl_data: List[Dict], jsonl_path: str, force=False): 22 | if os.path.exists(jsonl_path) and not force: 23 | raise FileExistsError 24 | with open(jsonl_path, 'w') as file: 25 | for line in jsonl_data: 26 | file.write(json.dumps(line, ensure_ascii=False) + '\n') 27 | 28 | 29 | def get_txt_data(txt_path: str): 30 | assert txt_path.endswith(".txt") 31 | with open(txt_path, "r") as file: 32 | return [line.strip() for line in file.readlines()] 33 | 34 | 35 | def write_txt_data(data: List[str], path: str, force: bool = False): 36 | if os.path.exists(path) and not force: 37 | raise FileExistsError 38 | with open(path, "w") as file: 39 | for line in data: 40 | file.write(line + "\n") 41 | 42 | 43 | def get_tsv_data(tsv_path: str, label: str = None): 44 | out = list() 45 | with open(tsv_path, "r") as file: 46 | for line in file: 47 | line = line.strip().split('\t') 48 | if not label: 49 | label = tsv_path.split('/')[-1] 50 | 51 | out.append({ 52 | "sentence": line[0], 53 | "label": label + str(line[1]) 54 | }) 55 | return out 56 | 57 | 58 | def raw_data_to_dict(data, shuffle=True): 59 | labels_dict = collections.defaultdict(list) 60 | for item in data: 61 | labels_dict[item['label']].append(item) 62 | labels_dict = dict(labels_dict) 63 | if shuffle: 64 | for key, val in labels_dict.items(): 65 | random.shuffle(val) 66 | return labels_dict 67 | 68 | 69 | class UnlabeledDataLoader: 70 | def __init__(self, file_path: str): 71 | self.file_path = file_path 72 | self.raw_data = get_jsonl_data(self.file_path) 73 | self.data_dict = raw_data_to_dict(self.raw_data, shuffle=True) 74 | 75 | def create_episode(self, n_augment: int = 0): 76 | episode = dict() 77 | augmentations = list() 78 | if n_augment: 79 | already_done = list() 80 | for i in range(n_augment): 81 | # Draw a random label 82 | key = random.choice(list(self.data_dict.keys())) 83 | # Draw a random data index 84 | ix = random.choice(range(len(self.data_dict[key]))) 85 | # If already used, re-sample 86 | while (key, ix) in already_done: 87 | key = random.choice(list(self.data_dict.keys())) 88 | ix = random.choice(range(len(self.data_dict[key]))) 89 | already_done.append((key, ix)) 90 | if "augmentations" not in self.data_dict[key][ix]: 91 | raise KeyError(f"Input data {self.data_dict[key][ix]} does not contain any augmentations / is not properly formatted.") 92 | augmentations.append((self.data_dict[key][ix])) 93 | 94 | episode["x_augment"] = augmentations 95 | 96 | return episode 97 | 98 | 99 | class FewShotDataLoader: 100 | def __init__(self, file_path, unlabeled_file_path: str = None): 101 | self.raw_data = get_jsonl_data(file_path) 102 | self.data_dict = raw_data_to_dict(self.raw_data, shuffle=True) 103 | self.unlabeled_file_path = unlabeled_file_path 104 | if self.unlabeled_file_path: 105 | self.unlabeled_data_loader = UnlabeledDataLoader(file_path=self.unlabeled_file_path) 106 | 107 | def create_episode(self, n_support: int = 0, n_classes: int = 0, n_query: int = 0, n_unlabeled: int = 0, n_augment: int = 0): 108 | episode = dict() 109 | if n_classes: 110 | n_classes = min(n_classes, 
len(self.data_dict.keys())) 111 | rand_keys = np.random.choice(list(self.data_dict.keys()), n_classes, replace=False) 112 | 113 | assert min([len(val) for val in self.data_dict.values()]) >= n_support + n_query + n_unlabeled 114 | 115 | for key, val in self.data_dict.items(): 116 | random.shuffle(val) 117 | 118 | if n_support: 119 | episode["xs"] = [[self.data_dict[k][i] for i in range(n_support)] for k in rand_keys] 120 | if n_query: 121 | episode["xq"] = [[self.data_dict[k][n_support + i] for i in range(n_query)] for k in rand_keys] 122 | 123 | if n_unlabeled: 124 | episode['xu'] = [item for k in rand_keys for item in self.data_dict[k][n_support + n_query:n_support + n_query + n_unlabeled]] 125 | 126 | if n_augment: 127 | episode = dict(**episode, **self.unlabeled_data_loader.create_episode(n_augment=n_augment)) 128 | 129 | return episode, rand_keys 130 | -------------------------------------------------------------------------------- /downstream/protaugment/utils/few_shot.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Thomas Dopierre. Licensed under Apache License 2.0 2 | 3 | import collections 4 | 5 | import numpy as np 6 | import random 7 | from typing import List 8 | from protaugment.utils.data import get_tsv_data 9 | import torch 10 | 11 | 12 | def random_sample_cls(sentences: List[str], labels: List[str], n_support: int, n_query: int, label: str): 13 | """ 14 | Randomly samples Ns examples as support set and Nq as Query set 15 | """ 16 | data = [sentences[i] for i, lab in enumerate(labels) if lab == label] 17 | perm = torch.randperm(len(data)) 18 | idx = perm[:n_support] 19 | support = [data[i] for i in idx] 20 | idx = perm[n_support: n_support + n_query] 21 | query = [data[i] for i in idx] 22 | 23 | return support, query 24 | 25 | 26 | def create_episode(data_dict, n_support, n_classes, n_query, n_unlabeled=0, n_augment=0): 27 | n_classes = min(n_classes, len(data_dict.keys())) 28 | rand_keys = np.random.choice(list(data_dict.keys()), n_classes, replace=False) 29 | 30 | assert min([len(val) for val in data_dict.values()]) >= n_support + n_query + n_unlabeled 31 | 32 | for key, val in data_dict.items(): 33 | random.shuffle(val) 34 | 35 | episode = { 36 | "xs": [ 37 | [data_dict[k][i] for i in range(n_support)] for k in rand_keys 38 | ], 39 | "xq": [ 40 | [data_dict[k][n_support + i] for i in range(n_query)] for k in rand_keys 41 | ] 42 | } 43 | 44 | if n_unlabeled: 45 | episode['xu'] = [ 46 | item for k in rand_keys for item in data_dict[k][n_support + n_query:n_support + n_query + 10] 47 | ] 48 | 49 | if n_augment: 50 | augmentations = list() 51 | already_done = list() 52 | for i in range(n_augment): 53 | # Draw a random label 54 | key = random.choice(list(data_dict.keys())) 55 | # Draw a random data index 56 | ix = random.choice(range(len(data_dict[key]))) 57 | # If already used, re-sample 58 | while (key, ix) in already_done: 59 | key = random.choice(list(data_dict.keys())) 60 | ix = random.choice(range(len(data_dict[key]))) 61 | already_done.append((key, ix)) 62 | if "augmentations" not in data_dict[key][ix]: 63 | raise KeyError(f"Input data {data_dict[key][ix]} does not contain any augmentations / is not properly formatted.") 64 | augmentations.append(( 65 | data_dict[key][ix]["sentence"], 66 | [item["text"] for item in data_dict[key][ix]["augmentations"]] 67 | )) 68 | episode["x_augment"] = augmentations 69 | 70 | return episode 71 | 72 | 73 | def create_ARSC_train_episode(prefix: str = "data/ARSC-Yu/raw", 
n_support: int = 5, n_query: int = 5, n_unlabeled=0): 74 | labels = sorted( 75 | set([line.strip() for line in open(f"{prefix}/workspace.filtered.list", "r").readlines()]) 76 | - set([line.strip() for line in open(f"{prefix}/workspace.target.list", "r").readlines()])) 77 | 78 | # Pick a random label 79 | label = random.choice(labels) 80 | 81 | # Pick a random binary task (2, 4, 5) 82 | binary_task = random.choice([2, 4, 5]) 83 | 84 | # Fix: this label/binary task sucks 85 | while label == "office_products" and binary_task == 2: 86 | # Pick a random label 87 | label = random.choice(labels) 88 | 89 | # Pick a random binary task (2, 4, 5) 90 | binary_task = random.choice([2, 4, 5]) 91 | 92 | data = ( 93 | get_tsv_data(f"{prefix}/{label}.t{binary_task}.train", label=label) + 94 | get_tsv_data(f"{prefix}/{label}.t{binary_task}.dev", label=label) + 95 | get_tsv_data(f"{prefix}/{label}.t{binary_task}.test", label=label) 96 | ) 97 | 98 | random.shuffle(data) 99 | task = collections.defaultdict(list) 100 | for d in data: 101 | task[d['label']].append(d['sentence']) 102 | task = dict(task) 103 | 104 | assert min([len(val) for val in task.values()]) >= n_support + n_query + n_unlabeled, \ 105 | f"Label {label}_{binary_task}: min samples is {min([len(val) for val in task.values()])} while K+Q+U={n_support + n_query + n_unlabeled}" 106 | 107 | for key, val in task.items(): 108 | random.shuffle(val) 109 | 110 | episode = { 111 | "xs": [ 112 | [task[k][i] for i in range(n_support)] for k in task.keys() 113 | ], 114 | "xq": [ 115 | [task[k][n_support + i] for i in range(n_query)] for k in task.keys() 116 | ] 117 | } 118 | 119 | if n_unlabeled: 120 | episode['xu'] = [ 121 | item for k in task.keys() for item in task[k][n_support + n_query:n_support + n_query + n_unlabeled] 122 | ] 123 | return episode 124 | 125 | 126 | def create_ARSC_test_episode(prefix: str = "data/ARSC-Yu/raw", n_query: int = 5, n_unlabeled=0, set_type: str = "test"): 127 | assert set_type in ("test", "dev") 128 | labels = [line.strip() for line in open(f"{prefix}/workspace.target.list", "r").readlines()] 129 | 130 | # Pick a random label 131 | label = random.choice(labels) 132 | 133 | # Pick a random binary task (2, 4, 5) 134 | binary_task = random.choice([2, 4, 5]) 135 | 136 | support_data = get_tsv_data(f"{prefix}/{label}.t{binary_task}.train", label=label) 137 | assert len(support_data) == 10 # 2 * 5 shots 138 | support_dict = collections.defaultdict(list) 139 | for d in support_data: 140 | support_dict[d['label']].append(d['sentence']) 141 | 142 | query_data = get_tsv_data(f"data/ARSC-Yu/raw/{label}.t{binary_task}.{set_type}", label=label) 143 | query_dict = collections.defaultdict(list) 144 | for d in query_data: 145 | query_dict[d['label']].append(d['sentence']) 146 | 147 | assert min([len(val) for val in query_dict.values()]) >= n_query + n_unlabeled 148 | 149 | for key, val in query_dict.items(): 150 | random.shuffle(val) 151 | 152 | episode = { 153 | "xs": [ 154 | [sentence for sentence in support_dict[k]] for k in sorted(query_dict.keys()) 155 | ], 156 | "xq": [ 157 | [query_dict[k][i] for i in range(n_query)] for k in sorted(query_dict.keys()) 158 | ] 159 | } 160 | 161 | if n_unlabeled: 162 | episode['xu'] = [ 163 | item for k in sorted(query_dict.keys()) for item in query_dict[k][n_query:n_query + n_unlabeled] 164 | ] 165 | return episode 166 | 167 | 168 | def create_ARSC_train_baseline_episode(): 169 | labels = sorted( 170 | set([line.strip() for line in open("data/ARSC-Yu/raw/workspace.filtered.list", "r").readlines()]) 
171 | - set([line.strip() for line in open("data/ARSC-Yu/raw/workspace.target.list", "r").readlines()])) 172 | 173 | # Pick a random label 174 | label = random.choice(labels) 175 | 176 | # Pick a random binary task (2, 4, 5) 177 | binary_task = random.choice([2, 4, 5]) 178 | 179 | data = ( 180 | get_tsv_data(f"data/ARSC-Yu/raw/{label}.t{binary_task}.train", label=label) + 181 | get_tsv_data(f"data/ARSC-Yu/raw/{label}.t{binary_task}.dev", label=label) + 182 | get_tsv_data(f"data/ARSC-Yu/raw/{label}.t{binary_task}.test", label=label) 183 | ) 184 | 185 | random.shuffle(data) 186 | task = collections.defaultdict(list) 187 | for d in data: 188 | task[d['label']].append(d['sentence']) 189 | task = dict(task) 190 | 191 | for key, val in task.items(): 192 | random.shuffle(val) 193 | 194 | episode = { 195 | "xs": [ 196 | list(task[k]) for k in task.keys() 197 | ] 198 | } 199 | 200 | return episode 201 | 202 | 203 | def get_ARSC_test_tasks(): 204 | labels = sorted(set([line.strip() for line in open("data/ARSC-Yu/raw/workspace.target.list", "r").readlines()])) 205 | 206 | tasks = list() 207 | for label in labels: 208 | for binary_task in (2, 4, 5): 209 | train_data = get_tsv_data(f"data/ARSC-Yu/raw/{label}.t{binary_task}.train", label=label) 210 | valid_data = get_tsv_data(f"data/ARSC-Yu/raw/{label}.t{binary_task}.dev", label=label) 211 | test_data = get_tsv_data(f"data/ARSC-Yu/raw/{label}.t{binary_task}.test", label=label) 212 | tasks.append({ 213 | "xs": [ 214 | [d['sentence'] for d in train_data if d['label'] == f"{label}-1"], 215 | [d['sentence'] for d in train_data if d['label'] == f"{label}1"], 216 | ], 217 | "x_valid": [ 218 | [d['sentence'] for d in valid_data if d['label'] == f"{label}-1"], 219 | [d['sentence'] for d in valid_data if d['label'] == f"{label}1"], 220 | ], 221 | "x_test": [ 222 | [d['sentence'] for d in test_data if d['label'] == f"{label}-1"], 223 | [d['sentence'] for d in test_data if d['label'] == f"{label}1"], 224 | ], 225 | }) 226 | 227 | assert all([len(task['xs'][0]) == len(task['xs'][1]) for task in tasks]) 228 | return tasks 229 | -------------------------------------------------------------------------------- /downstream/protaugment/utils/math.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Thomas Dopierre. Licensed under Apache License 2.0 2 | 3 | import torch 4 | 5 | def euclidean_dist(x, y): 6 | # x: N x D 7 | # y: M x D 8 | n = x.size(0) 9 | m = y.size(0) 10 | d = x.size(1) 11 | assert d == y.size(1) 12 | 13 | x = x.unsqueeze(1).expand(n, m, d) 14 | y = y.unsqueeze(0).expand(n, m, d) 15 | 16 | return torch.pow(x - y, 2).sum(2) 17 | 18 | 19 | def cosine_similarity(x, y): 20 | x = (x / x.norm(dim=1).view(-1, 1)) 21 | y = (y / y.norm(dim=1).view(-1, 1)) 22 | 23 | return x @ y.T 24 | -------------------------------------------------------------------------------- /downstream/protaugment/utils/python.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Thomas Dopierre. 
Licensed under Apache License 2.0 2 | 3 | import datetime 4 | import random 5 | import numpy as np 6 | import torch 7 | 8 | 9 | def now(): 10 | return datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S.%f") 11 | 12 | 13 | def set_seeds(seed: int) -> None: 14 | """ 15 | set random seeds 16 | :param seed: int 17 | :return: None 18 | """ 19 | random.seed(seed) 20 | np.random.seed(seed) 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed_all(seed) 23 | -------------------------------------------------------------------------------- /downstream/run_eval.py: -------------------------------------------------------------------------------- 1 | # Original Copyright (c) Microsoft Corporation. Licensed under the MIT license. 2 | # Modifications Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | 4 | import sys, os 5 | 6 | import json 7 | import argparse 8 | from encoders.bert_encoder import BERTEncoder 9 | from encoders.sbert_encoder import SentEncoder 10 | from utils.dataloader import FewShotDataLoader 11 | from utils.train_utils import set_seeds, euclidean_dist, cosine_similarity 12 | import collections 13 | import os 14 | import numpy as np 15 | import torch 16 | import torch.nn as nn 17 | from torch.autograd import Variable 18 | from torch.nn import CrossEntropyLoss 19 | import warnings 20 | import logging 21 | 22 | logging.basicConfig() 23 | logger = logging.getLogger(__name__) 24 | logger.setLevel(logging.DEBUG) 25 | 26 | warnings.simplefilter('ignore') 27 | 28 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 29 | 30 | def load_label_file(label_fn): 31 | class2name_mp = {} 32 | replace_pair = [("lightdim", "light dim"), ("lightchange", "light change"), ("lightup", "light up"), 33 | ("commandstop", "command stop"), ("lighton", "light on"), ("dontcare", "don't care"), 34 | ("lightoff", "light off"), ("querycontact", "query contact"), ("addcontact", "add contact"), 35 | ("sendemail", "send email"), ("createoradd", "create or add"), ("qa", "what")] 36 | 37 | with open(label_fn, mode="r", encoding="utf-8") as fp: 38 | for line in fp: 39 | line = line.strip("\n") 40 | tmp = line.replace("_", " ").replace("/", " ") 41 | for x, y in replace_pair: 42 | tmp = tmp.replace(x, y) 43 | class2name_mp[line] = tmp 44 | return class2name_mp 45 | 46 | class ProtoNet(nn.Module): 47 | def __init__(self, encoder, metric="euclidean", label_fn=None, zero_shot=False): 48 | super(ProtoNet, self).__init__() 49 | self.encoder = encoder 50 | self.metric = metric 51 | assert self.metric in ('euclidean', 'cosine') 52 | self.class2name_mp = load_label_file(label_fn) if label_fn else None 53 | # self.role_labeler = IntentRoleLabeler.from_params(Params.from_file('config/tagger-intent-role-labeler.jsonnet')) 54 | self.zero_shot = zero_shot 55 | 56 | def eval_with_log(self, sample, classes, pooling): 57 | xs = sample['xs'] # support 58 | xq = sample['xq'] # query 59 | supports = [item["sentence"] for xs_ in xs for item in xs_] 60 | queries = [item["sentence"] for xq_ in xq for item in xq_] 61 | gold_labels = [classes[i] for i in range(len(xq)) for j in xq[i]] 62 | sp_gold_labels = [classes[i] for i in range(len(xs)) for j in xs[i]] 63 | loss, loss_dict = self.loss(sample, classes, pooling) 64 | logs = {"support": [], "query": [], "accuracy": loss_dict["metrics"]["acc"]} 65 | for i in range(len(supports)): 66 | logs["support"].append({"words": supports[i], "gold": sp_gold_labels[i]}) 67 | for i in range(len(queries)): 68 | logs["query"].append({"words": queries[i], 
"gold": gold_labels[i], "pred": classes[loss_dict["pred"][i]]}) 69 | return logs 70 | 71 | def loss(self, sample, classes=None, pooling='avg'): 72 | xs = sample['xs'] # support 73 | xq = sample['xq'] # query 74 | 75 | n_class = len(xs) 76 | assert len(xq) == n_class 77 | n_support = len(xs[0]) 78 | n_query = len(xq[0]) 79 | 80 | target_inds = torch.arange(0, n_class).view(n_class, 1, 1).expand(n_class, n_query, 1).long() 81 | target_inds = Variable(target_inds, requires_grad=False).to(device) 82 | 83 | # When not using augmentations 84 | supports = [item["sentence"] for xs_ in xs for item in xs_] 85 | queries = [item["sentence"] for xq_ in xq for item in xq_] 86 | 87 | # Encode 88 | x = supports + queries 89 | z = self.encoder.embed_sentences(x) 90 | z_dim = z.size(-1) 91 | 92 | # Dispatch 93 | z_support = z[:len(supports)].view(n_class, n_support, z_dim) 94 | z_query = z[len(supports):len(supports) + len(queries)] 95 | 96 | if self.class2name_mp: 97 | class_names = [self.class2name_mp[classes[i]] for i in range(len(xs))] 98 | z_class = self.encoder.embed_sentences(class_names).view(n_class, 1, z_dim) 99 | z_support = torch.cat([z_support, z_class], dim=1) 100 | 101 | # TODO! need for refactoring 102 | if self.zero_shot: 103 | class_names = [self.class2name_mp[classes[i]] for i in range(len(xs))] 104 | z_class = self.encoder.embed_sentences(class_names).view(n_class, 1, z_dim) 105 | z_support = z_class 106 | 107 | if pooling == 'avg': 108 | z_support = z_support.mean(dim=[1]) 109 | else: 110 | assert pooling == 'nn' 111 | z_support = z_support.view(-1, z_dim) 112 | if self.metric == "euclidean": 113 | supervised_dists = euclidean_dist(z_query, z_support) 114 | elif self.metric == "cosine": 115 | supervised_dists = (-cosine_similarity(z_query, z_support) + 1) * 5 116 | else: 117 | raise NotImplementedError 118 | 119 | if pooling == 'nn': 120 | supervised_dists = supervised_dists.view(len(queries), n_class, -1) 121 | # print(supervised_dists.shape) 122 | supervised_dists = torch.min(supervised_dists, 2)[0] 123 | # Supervised loss 124 | supervised_loss = CrossEntropyLoss()(-supervised_dists, target_inds.reshape(-1)) 125 | _, y_hat_supervised = (-supervised_dists).max(1) 126 | acc_val_supervised = torch.eq(y_hat_supervised, target_inds.reshape(-1)).float().mean() 127 | 128 | return supervised_loss, { 129 | "metrics": { 130 | "acc": acc_val_supervised.item(), 131 | "loss": supervised_loss.item(), 132 | }, 133 | "dists": supervised_dists, 134 | "target": target_inds.cpu().tolist(), 135 | "pred": y_hat_supervised.cpu().tolist() 136 | } 137 | 138 | def test_step(self, 139 | data_loader: FewShotDataLoader, 140 | n_support: int, 141 | n_query: int, 142 | n_classes: int, 143 | n_episodes: int = 1000): 144 | metrics = collections.defaultdict(list) 145 | 146 | self.eval() 147 | for i in range(n_episodes): 148 | episode, classes = data_loader.create_episode( 149 | n_support=n_support, 150 | n_query=n_query, 151 | n_classes=n_classes, 152 | ) 153 | with torch.no_grad(): 154 | loss, loss_dict = self.loss(episode) 155 | 156 | for k, v in loss_dict["metrics"].items(): 157 | metrics[k].append(v) 158 | 159 | return { 160 | key: np.mean(value) for key, value in metrics.items() 161 | } 162 | 163 | def test_proto( 164 | encoder: str, 165 | model_name_or_path: str, 166 | n_support: int, 167 | n_query: int, 168 | n_classes: int, 169 | test_path: str = None, 170 | output_path: str = './output', 171 | n_test_episodes: int = 600, 172 | metric: str = "euclidean", 173 | label_fn: str = None, 174 | zero_shot:bool = 
False, 175 | load_ckpt: bool = False, 176 | ckpt_path: str = None, 177 | pooling: str = 'avg' 178 | ): 179 | 180 | logs = [] 181 | # Load model 182 | if encoder == "bert": 183 | bert = BERTEncoder(model_name_or_path).to(device) 184 | elif encoder == "sentbert": 185 | bert = SentEncoder(model_name_or_path).to(device) 186 | else: 187 | raise ValueError("encoder name unk") 188 | 189 | protonet = ProtoNet(encoder=bert, 190 | metric=metric, 191 | label_fn=label_fn, 192 | zero_shot=zero_shot) 193 | if load_ckpt: 194 | protonet.load_state_dict(torch.load(ckpt_path)) 195 | 196 | if not os.path.exists(output_path): 197 | os.makedirs(output_path) 198 | 199 | # Load data 200 | valid_data_loader = FewShotDataLoader(test_path) 201 | logger.info(f"valid labels: {valid_data_loader.data_dict.keys()}") 202 | 203 | protonet.eval() 204 | for i in range(n_test_episodes): 205 | episode, classes = valid_data_loader.create_episode( 206 | n_support=n_support, 207 | n_query=n_query, 208 | n_classes=n_classes, 209 | ) 210 | with torch.no_grad(): 211 | logs.append(protonet.eval_with_log(episode, classes, pooling)) 212 | acc = [] 213 | with open(os.path.join(output_path, "logs.json"), mode="w", encoding="utf-8") as fp: 214 | for line in logs: 215 | fp.write(json.dumps(line) + "\n") 216 | fp.flush() 217 | acc.append(line["accuracy"]) 218 | avg_acc = np.mean(acc) 219 | with open(os.path.join(output_path, 'metrics.json'), "w") as file: 220 | json.dump({"accuracy": avg_acc}, file, ensure_ascii=False) 221 | return avg_acc 222 | 223 | def str2bool(arg): 224 | if arg.lower() == "true": 225 | return True 226 | return False 227 | 228 | def add_args(): 229 | parser = argparse.ArgumentParser() 230 | parser.add_argument("--test-path", type=str, default=None, help="Path to testing data") 231 | parser.add_argument("--output-path", type=str, default=None, required=True) 232 | 233 | parser.add_argument("--encoder", type=str, default="bert") 234 | parser.add_argument("--model-name-or-path", type=str, required=True, help="Transformer model to use") 235 | parser.add_argument("--load_ckpt", type=str2bool, default=False) 236 | parser.add_argument("--ckpt_path", default=None) 237 | parser.add_argument("--seed", type=int, default=42, help="Random seed to set") 238 | 239 | # Few-Shot related stuff 240 | parser.add_argument("--n-support", type=int, default=5, help="Number of support points for each class") 241 | parser.add_argument("--n-query", type=int, default=5, help="Number of query points for each class") 242 | parser.add_argument("--n-classes", type=int, default=5, help="Number of classes per episode") 243 | parser.add_argument("--n-test-episodes", type=int, default=600, help="Number of episodes during evaluation (valid, test)") 244 | 245 | #currently it is only for test... 
246 | parser.add_argument("--label_fn", type=str, default=None) 247 | parser.add_argument("--zero_shot", action='store_true') 248 | parser.add_argument("--cv", default='01,02,03,04,05') 249 | parser.add_argument("--draft", action='store_true') 250 | 251 | # Metric to use in proto distance calculation 252 | parser.add_argument("--metric", type=str, default="euclidean", help="Metric to use", choices=("euclidean", "cosine")) 253 | parser.add_argument("--pooling", type=str, default="avg", help="Pooling strategy to use", choices=("avg", "nn")) 254 | 255 | 256 | args = parser.parse_args() 257 | 258 | if args.load_ckpt and args.ckpt_path is None: 259 | raise ValueError("--ckpt_path must be provided when --load_ckpt is true") 260 | 261 | return args 262 | 263 | def read_results(output_dir): 264 | for model_name in ["paraphrase-distilroberta-base-v2-euc-wotrain-label", 265 | "nli-roberta-base-v2-euc-wotrain-label", 266 | "simcse-nli-euc-wotrain-label", 267 | "declutr-base-euc-wotrain-label", 268 | "sp-paraphrase-cos-wotrain-label" 269 | ]: 270 | for K in [1, 5]: 271 | acc_list = [] 272 | for split in [1, 2, 3, 4, 5]: 273 | res_fn = os.path.join(output_dir, f"0{split}/proto-5way{K}shot-{model_name}/metrics.json") 274 | with open(res_fn, mode="r", encoding="utf-8") as fp: 275 | acc = json.load(fp)["accuracy"] 276 | acc_list.append(acc) 277 | mu = np.mean(acc_list) * 100 278 | std = np.std(acc_list) * 100 279 | print("model {}, {} shot : {} +- {}".format(model_name, K, mu, std)) 280 | return 281 | 282 | def main(args): 283 | test_path = args.test_path 284 | ckpt_path = args.ckpt_path 285 | output_path = args.output_path 286 | 287 | if args.n_support == 0: 288 | args.n_support = 1 289 | args.zero_shot = True 290 | 291 | accs = [] 292 | if args.draft: 293 | args.n_test_episodes = 1 294 | 295 | splits = args.cv.split(",") 296 | for split in splits: 297 | set_seeds(args.seed) 298 | if ckpt_path is not None: 299 | args.ckpt_path = ckpt_path + f"/{split}/encoder.pkl" 300 | args.test_path = test_path + f"/{split}/test.jsonl" 301 | args.output_path = output_path + f"/{split}" 302 | 303 | acc = test_proto( 304 | encoder=args.encoder, 305 | model_name_or_path=args.model_name_or_path, 306 | n_support=args.n_support, 307 | n_query=args.n_query, 308 | n_classes=args.n_classes, 309 | n_test_episodes=args.n_test_episodes, 310 | output_path=args.output_path, 311 | metric=args.metric, 312 | test_path=args.test_path, 313 | label_fn=args.label_fn, 314 | zero_shot=args.zero_shot, 315 | load_ckpt=args.load_ckpt, 316 | ckpt_path=args.ckpt_path, 317 | pooling=args.pooling, 318 | ) 319 | # Save config 320 | with open(os.path.join(args.output_path, "config.json"), "w") as file: 321 | json.dump(vars(args), file, ensure_ascii=False) 322 | 323 | accs.append(acc) 324 | print(acc) 325 | 326 | acc = np.mean(accs) * 100 327 | std = np.std(accs) * 100 328 | 329 | print(acc, std) 330 | 331 | if __name__ == '__main__': 332 | args = add_args() 333 | 334 | main(args) 335 | -------------------------------------------------------------------------------- /downstream/run_protaugment.py: -------------------------------------------------------------------------------- 1 | # Original Copyright (c) 2021, Thomas Dopierre. Licensed under Apache License 2.0 2 | # Modifications Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 | 4 | import logging 5 | 6 | logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 7 | 8 | import json 9 | import argparse 10 | 11 | from encoders.bert_encoder import BERTEncoder 12 | from encoders.sbert_encoder import SentEncoder 13 | 14 | from protaugment.paraphrase.utils.data import FewShotDataset, FewShotSSLFileDataset 15 | from protaugment.utils.python import now, set_seeds 16 | import collections 17 | import os 18 | from typing import Callable, Union 19 | import numpy as np 20 | import torch 21 | import torch.nn as nn 22 | from torch.autograd import Variable 23 | import warnings 24 | import logging 25 | import copy 26 | from protaugment.utils.math import euclidean_dist, cosine_similarity 27 | logger = logging.getLogger(__name__) 28 | logger.setLevel(logging.DEBUG) 29 | 30 | warnings.simplefilter('ignore') 31 | 32 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 33 | 34 | def load_label_file(label_fn): 35 | class2name_mp = {} 36 | replace_pair = [("lightdim", "light dim"), ("lightchange", "light change"), ("lightup", "light up"), 37 | ("commandstop", "command stop"), ("lighton", "light on"), ("dontcare", "don't care"), 38 | ("lightoff", "light off"), ("querycontact", "query contact"), ("addcontact", "add contact"), 39 | ("sendemail", "send email"), ("createoradd", "create or add"), ("qa", "what")] 40 | 41 | with open(label_fn, mode="r", encoding="utf-8") as fp: 42 | for line in fp: 43 | line = line.strip("\n") 44 | tmp = line.replace("_", " ").replace("/", " ") 45 | for x, y in replace_pair: 46 | tmp = tmp.replace(x, y) 47 | class2name_mp[line] = tmp 48 | return class2name_mp 49 | 50 | class ProtAugmentNet(nn.Module): 51 | def __init__(self, encoder, metric="euclidean", label_fn=None, zero_shot=False): 52 | super(ProtAugmentNet, self).__init__() 53 | 54 | self.encoder = encoder 55 | self.metric = metric 56 | self.class2name_mp = load_label_file(label_fn) if label_fn else None 57 | self.zero_shot = zero_shot 58 | assert self.metric in ('euclidean', 'cosine') 59 | 60 | def loss(self, sample, supervised_loss_share: float = 0, classes=None): 61 | """ 62 | :param supervised_loss_share: share of supervised loss in total loss 63 | :param sample: { 64 | "xs": [ 65 | [support_A_1, support_A_2, ...], 66 | [support_B_1, support_B_2, ...], 67 | [support_C_1, support_C_2, ...], 68 | ... 69 | ], 70 | "xq": [ 71 | [query_A_1, query_A_2, ...], 72 | [query_B_1, query_B_2, ...], 73 | [query_C_1, query_C_2, ...], 74 | ... 75 | ], 76 | "x_augment":[ 77 | { 78 | "src_text": str, 79 | "tgt_texts: List[str] 80 | }, . 81 | ] 82 | } 83 | :return: 84 | """ 85 | xs = sample['xs'] # support 86 | xq = sample['xq'] # query 87 | 88 | n_class = len(xs) 89 | assert len(xq) == n_class 90 | n_support = len(xs[0]) 91 | n_query = len(xq[0]) 92 | 93 | target_inds = torch.arange(0, n_class).view(n_class, 1, 1).expand(n_class, n_query, 1).long() 94 | target_inds = Variable(target_inds, requires_grad=False).to(device) 95 | 96 | # x_augment is not always present in `sample` 97 | # Indeed, at evaluation / test time, the network is judged on a regular meta-learning episode (i.e. 
only support and query points) 98 | has_augment = "x_augment" in sample 99 | 100 | if has_augment: 101 | augmentations = sample["x_augment"] 102 | 103 | # expectation: n_augmentations_samples == n_classes 104 | # actual: n_augmentations_samples == n_unlabeled 105 | n_augmentations_samples = len(sample["x_augment"]) 106 | n_augmentations_per_sample = [len(item['tgt_texts']) for item in augmentations] 107 | assert len(set(n_augmentations_per_sample)) == 1 108 | n_augmentations_per_sample = n_augmentations_per_sample[0] 109 | 110 | supports = [item["sentence"] for xs_ in xs for item in xs_] 111 | queries = [item["sentence"] for xq_ in xq for item in xq_] 112 | augmentations_supports = [[item2 for item2 in item["tgt_texts"]] for item in sample["x_augment"]] 113 | augmentation_queries = [item["src_text"] for item in sample["x_augment"]] 114 | 115 | # Encode 116 | x = supports + queries + [item2 for item1 in augmentations_supports for item2 in item1] + augmentation_queries 117 | z = self.encoder.embed_sentences(x) 118 | z_dim = z.size(-1) 119 | 120 | # Dispatch 121 | z_support = z[:len(supports)].view(n_class, n_support, z_dim) 122 | z_query = z[len(supports):len(supports) + len(queries)] 123 | z_aug_support = (z[len(supports) + len(queries):len(supports) + len(queries) + n_augmentations_per_sample * n_augmentations_samples] 124 | .view(n_augmentations_samples, n_augmentations_per_sample, z_dim).mean(dim=[1])) 125 | z_aug_query = z[-len(augmentation_queries):] 126 | else: 127 | # When not using augmentations 128 | supports = [item["sentence"] for xs_ in xs for item in xs_] 129 | queries = [item["sentence"] for xq_ in xq for item in xq_] 130 | 131 | # Encode 132 | x = supports + queries 133 | z = self.encoder.embed_sentences(x) 134 | z_dim = z.size(-1) 135 | 136 | # Dispatch 137 | z_support = z[:len(supports)].view(n_class, n_support, z_dim) 138 | z_query = z[len(supports):len(supports) + len(queries)] 139 | 140 | if self.class2name_mp: 141 | class_names = [self.class2name_mp[classes[i]] for i in range(len(xs))] 142 | z_class = self.encoder.embed_sentences(class_names).view(n_class, 1, z_dim) 143 | if self.zero_shot: 144 | z_support = z_class 145 | else: 146 | z_support = torch.cat([z_support, z_class], dim=1) 147 | 148 | # avg pooling 149 | z_support = z_support.mean(dim=[1]) 150 | if self.metric == "euclidean": 151 | supervised_dists = euclidean_dist(z_query, z_support) 152 | if has_augment: 153 | unsupervised_dists = euclidean_dist(z_aug_query, z_aug_support) 154 | elif self.metric == "cosine": 155 | supervised_dists = (-cosine_similarity(z_query, z_support) + 1) * 5 156 | if has_augment: 157 | unsupervised_dists = (-cosine_similarity(z_aug_query, z_aug_support) + 1) * 5 158 | else: 159 | raise NotImplementedError 160 | 161 | from torch.nn import CrossEntropyLoss 162 | supervised_loss = CrossEntropyLoss()(-supervised_dists, target_inds.reshape(-1)) 163 | _, y_hat_supervised = (-supervised_dists).max(1) 164 | acc_val_supervised = torch.eq(y_hat_supervised, target_inds.reshape(-1)).float().mean() 165 | 166 | if has_augment: 167 | # Unsupervised loss 168 | unsupervised_target_inds = torch.arange(n_augmentations_samples).to(device).long()  # one target per augmentation source 169 | unsupervised_loss = CrossEntropyLoss()(-unsupervised_dists, unsupervised_target_inds) 170 | _, y_hat_unsupervised = (-unsupervised_dists).max(1) 171 | acc_val_unsupervised = torch.eq(y_hat_unsupervised, unsupervised_target_inds.reshape(-1)).float().mean() 172 | 173 | # Final loss 174 | assert 0 <= supervised_loss_share <= 1 175 | final_loss =
(supervised_loss_share) * supervised_loss + (1 - supervised_loss_share) * unsupervised_loss 176 | 177 | return final_loss, { 178 | "metrics": { 179 | "supervised_acc": acc_val_supervised.item(), 180 | "unsupervised_acc": acc_val_unsupervised.item(), 181 | "supervised_loss": supervised_loss.item(), 182 | "unsupervised_loss": unsupervised_loss.item(), 183 | "supervised_loss_share": supervised_loss_share, 184 | "final_loss": final_loss.item(), 185 | }, 186 | "supervised_dists": supervised_dists, 187 | "unsupervised_dists": unsupervised_dists, 188 | "target": target_inds 189 | } 190 | 191 | return supervised_loss, { 192 | "metrics": { 193 | "acc": acc_val_supervised.item(), 194 | "loss": supervised_loss.item(), 195 | }, 196 | "dists": supervised_dists, 197 | "target": target_inds 198 | } 199 | 200 | def train_step(self, optimizer, episode, supervised_loss_share: float, classes): 201 | self.train() 202 | optimizer.zero_grad() 203 | torch.cuda.empty_cache() 204 | loss, loss_dict = self.loss(episode, supervised_loss_share=supervised_loss_share, classes=classes) 205 | loss.backward() 206 | optimizer.step() 207 | 208 | return loss, loss_dict 209 | 210 | def test_step(self, dataset: FewShotDataset, n_episodes: int = 1000): 211 | metrics = collections.defaultdict(list) 212 | 213 | self.eval() 214 | for i in range(n_episodes): 215 | episode, classes = dataset.get_episode() 216 | 217 | with torch.no_grad(): 218 | loss, loss_dict = self.loss(episode, supervised_loss_share=1, classes=classes) 219 | 220 | for k, v in loss_dict["metrics"].items(): 221 | metrics[k].append(v) 222 | 223 | return { 224 | key: np.mean(value) for key, value in metrics.items() 225 | } 226 | 227 | 228 | def run_protaugment( 229 | # Compulsory! 230 | data_path: str, 231 | train_labels_path: str, 232 | model_name_or_path: str, 233 | 234 | # Few-shot Stuff 235 | n_support: int, 236 | n_query: int, 237 | n_classes: int, 238 | metric: str = "euclidean", 239 | 240 | # Optional path to augmented data 241 | unlabeled_path: str = None, 242 | 243 | # Path training data ONLY (optional) 244 | train_path: str = None, 245 | 246 | # Validation & test 247 | valid_labels_path: str = None, 248 | test_labels_path: str = None, 249 | evaluate_every: int = 100, 250 | n_test_episodes: int = 1000, 251 | 252 | # Logging & Saving 253 | output_path: str = f'runs/{now()}', 254 | log_every: int = 10, 255 | lr: float = 2e-5, 256 | 257 | # Training stuff 258 | max_iter: int = 10000, 259 | early_stop: int = None, 260 | 261 | # Augmentation & paraphrase 262 | n_unlabeled: int = 5, 263 | paraphrase_model_name_or_path: str = None, 264 | paraphrase_tokenizer_name_or_path: str = None, 265 | paraphrase_num_beams: int = None, 266 | paraphrase_beam_group_size: int = None, 267 | paraphrase_diversity_penalty: float = None, 268 | paraphrase_filtering_strategy: str = None, 269 | paraphrase_drop_strategy: str = None, 270 | paraphrase_drop_chance_speed: str = None, 271 | paraphrase_drop_chance_auc: float = None, 272 | supervised_loss_share_fn: Callable[[int, int], float] = lambda x, y: 1 - (x / y), 273 | 274 | paraphrase_generation_method: str = None, 275 | 276 | augmentation_data_path: str = None, 277 | 278 | encoder: str = 'bert', 279 | draft: bool = False, 280 | zero_shot: bool = False, 281 | label_fn: str = None, 282 | ): 283 | # -------------------- 284 | # Creating Log Writers 285 | # -------------------- 286 | os.makedirs(output_path,exist_ok=True) 287 | log_dict = dict(train=list()) 288 | 289 | # ---------- 290 | # Load model 291 | # ---------- 292 | if encoder == 
292 | if encoder == 'bert': 293 | bert = BERTEncoder(model_name_or_path).to(device) 294 | elif encoder == 'sentbert': 295 | bert = SentEncoder(model_name_or_path).to(device) 296 | else: 297 | raise NotImplementedError 298 | 299 | protonet: ProtAugmentNet = ProtAugmentNet(encoder=bert, metric=metric, label_fn=label_fn, zero_shot=zero_shot) 300 | optimizer = torch.optim.Adam(protonet.parameters(), lr=lr) 301 | 302 | # ------------------ 303 | # Load Train Dataset 304 | # ------------------ 305 | if augmentation_data_path: 306 | # If an augmentation data path is provided, use the pre-generated augmentations 307 | train_dataset = FewShotSSLFileDataset( 308 | data_path=train_path if train_path else data_path, 309 | labels_path=train_labels_path, 310 | n_classes=n_classes, 311 | n_support=n_support, 312 | n_query=n_query, 313 | n_unlabeled=n_unlabeled, 314 | unlabeled_file_path=augmentation_data_path, 315 | ) 316 | 317 | else: 318 | raise NotImplementedError 319 | logger.info(f"Train dataset has {len(train_dataset)} items") 320 | 321 | # ---------------------------- 322 | # Load validation & test data 323 | # ---------------------------- 324 | logger.info(f"train labels: {train_dataset.data.keys()}") 325 | valid_dataset: FewShotDataset = None 326 | if valid_labels_path: 327 | log_dict["valid"] = list() 328 | valid_dataset = FewShotDataset(data_path=data_path, labels_path=valid_labels_path, n_classes=n_classes, n_support=n_support, n_query=n_query) 329 | logger.info(f"valid labels: {valid_dataset.data.keys()}") 330 | assert len(set(valid_dataset.data.keys()) & set(train_dataset.data.keys())) == 0 331 | 332 | test_dataset: FewShotDataset = None 333 | if test_labels_path: 334 | log_dict["test"] = list() 335 | test_dataset = FewShotDataset(data_path=data_path, labels_path=test_labels_path, n_classes=n_classes, n_support=n_support, n_query=n_query) 336 | logger.info(f"test labels: {test_dataset.data.keys()}") 337 | assert len(set(test_dataset.data.keys()) & set(train_dataset.data.keys())) == 0 338 | 339 | train_metrics = collections.defaultdict(list) 340 | n_eval_since_last_best = 0 341 | best_valid_acc = 0.0 342 | best_protonet = copy.deepcopy(protonet) 343 | 344 | for step in range(max_iter): 345 | episode, classes = train_dataset.get_episode() 346 | 347 | supervised_loss_share = supervised_loss_share_fn(step, max_iter) 348 | loss, loss_dict = protonet.train_step(optimizer=optimizer, episode=episode, supervised_loss_share=supervised_loss_share, classes=classes) 349 | 350 | for key, value in loss_dict["metrics"].items(): 351 | train_metrics[key].append(value) 352 | 353 | # Logging 354 | if (step + 1) % log_every == 0: 355 | # TODO! logging 356 | # for key, value in train_metrics.items(): 357 | 358 | logger.info(f"train | " + " | ".join([f"{key}:{np.mean(value):.4f}" for key, value in train_metrics.items()])) 359 | train_metrics = collections.defaultdict(list) 360 | 361 | if valid_labels_path or test_labels_path: 362 | if (step + 1) % evaluate_every == 0: 363 | for labels_path, set_type, set_dataset in zip( 364 | [valid_labels_path, test_labels_path], 365 | ["valid", "test"], 366 | [valid_dataset, test_dataset] 367 | ): 368 | if set_dataset: 369 | 370 | set_results = protonet.test_step( 371 | dataset=set_dataset, 372 | n_episodes=n_test_episodes 373 | ) 374 | 375 | # TODO! 
logging 376 | # for key, val in set_results.items(): 377 | logger.info(f"{set_type} | " + " | ".join([f"{key}:{np.mean(value):.4f}" for key, value in set_results.items()])) 378 | if set_type == "valid": 379 | if set_results["acc"] > best_valid_acc: 380 | best_valid_acc = set_results["acc"] 381 | best_protonet = copy.deepcopy(protonet) 382 | n_eval_since_last_best = 0 383 | logger.info(f"Better eval results!") 384 | # TODO! logging 385 | else: 386 | n_eval_since_last_best += 1 387 | logger.info(f"Worse eval results ({n_eval_since_last_best}/{early_stop})") 388 | 389 | if draft: 390 | break 391 | 392 | if early_stop and n_eval_since_last_best >= early_stop: 393 | logger.warning(f"Early-stopping.") 394 | break 395 | 396 | # save model in encoder.pkl 397 | with open(os.path.join(output_path,'encoder.pkl'), 'wb') as handle: 398 | torch.save(best_protonet.state_dict(), handle) 399 | print(os.path.join(output_path,'encoder.pkl')) 400 | 401 | def main(args): 402 | logger.debug(f"Received args: {json.dumps(args.__dict__, sort_keys=True, ensure_ascii=False, indent=1)}") 403 | 404 | # # Check if data path(s) exist 405 | # for arg in [args.data_path, args.train_labels_path, args.valid_labels_path, args.test_labels_path]: 406 | # if arg and not os.path.exists(arg): 407 | # raise FileNotFoundError(f"Data @ {arg} not found.") 408 | 409 | # Create supervised_loss_share_fn 410 | def get_supervised_loss_share_fn(supervised_loss_share_power: Union[int, float]) -> Callable[[int, int], float]: 411 | def _supervised_loss_share_fn(current_step: int, max_steps: int) -> float: 412 | assert current_step <= max_steps 413 | return 1 - (current_step / max_steps) ** supervised_loss_share_power 414 | 415 | return _supervised_loss_share_fn 416 | 417 | supervised_loss_share_fn = get_supervised_loss_share_fn(args.supervised_loss_share_power) 418 | 419 | if args.n_support == 0: 420 | args.n_support = 1 421 | args.zero_shot = True 422 | 423 | orig_args = copy.deepcopy(args) 424 | cvs = args.cv.split(",") 425 | # for cv in ['01', '02', '03', '04', '05']: 426 | for cv in cvs: 427 | args.train_labels_path = os.path.join(orig_args.train_labels_path,cv,'labels.train.txt') 428 | args.valid_labels_path = os.path.join(orig_args.valid_labels_path,cv,'labels.valid.txt') 429 | args.test_labels_path = os.path.join(orig_args.test_labels_path,cv,'labels.test.txt') 430 | args.output_path = os.path.join(orig_args.output_path, cv) 431 | 432 | set_seeds(args.seed) 433 | # Run 434 | run_protaugment( 435 | data_path=args.data_path, 436 | train_labels_path=args.train_labels_path, 437 | train_path=args.train_path, 438 | model_name_or_path=args.model_name_or_path, 439 | n_support=args.n_support, 440 | n_query=args.n_query, 441 | n_classes=args.n_classes, 442 | metric=args.metric, 443 | 444 | valid_labels_path=args.valid_labels_path, 445 | test_labels_path=args.test_labels_path, 446 | evaluate_every=args.evaluate_every, 447 | n_test_episodes=args.n_test_episodes, 448 | 449 | output_path=args.output_path, 450 | log_every=args.log_every, 451 | max_iter=args.max_iter, 452 | early_stop=args.early_stop, 453 | 454 | unlabeled_path=args.unlabeled_path, 455 | n_unlabeled=args.n_unlabeled, 456 | 457 | # Paraphrase generation model 458 | paraphrase_model_name_or_path=args.paraphrase_model_name_or_path, 459 | paraphrase_tokenizer_name_or_path=args.paraphrase_tokenizer_name_or_path, 460 | paraphrase_num_beams=args.paraphrase_num_beams, 461 | paraphrase_beam_group_size=args.paraphrase_beam_group_size, 462 | 
paraphrase_filtering_strategy=args.paraphrase_filtering_strategy, 463 | paraphrase_drop_strategy=args.paraphrase_drop_strategy, 464 | paraphrase_drop_chance_speed=args.paraphrase_drop_chance_speed, 465 | paraphrase_drop_chance_auc=args.paraphrase_drop_chance_auc, 466 | supervised_loss_share_fn=supervised_loss_share_fn, 467 | 468 | # Other paraphrase generation method 469 | paraphrase_generation_method=args.paraphrase_generation_method, 470 | 471 | # Or just path to augmented data 472 | augmentation_data_path=args.augmentation_data_path, 473 | 474 | encoder=args.encoder, 475 | draft=args.draft, 476 | lr=args.lr, 477 | zero_shot=args.zero_shot, 478 | label_fn=args.label_fn 479 | ) 480 | 481 | if __name__ == '__main__': 482 | parser = argparse.ArgumentParser() 483 | parser.add_argument("--data-path", type=str, required=True, help="Path to the full data") 484 | parser.add_argument("--train-labels-path", type=str, required=True, help="Path to train labels. This file contains unique names of labels (i.e. one row per label)") 485 | parser.add_argument("--train-path", type=str, help="Path to training data (if provided, picks training data from this path instead of --data-path)") 486 | parser.add_argument("--model-name-or-path", type=str, required=True, help="Language Model PROTAUGMENT initializes from") 487 | 488 | # Few-Shot related stuff 489 | parser.add_argument("--n-support", type=int, default=5, help="Number of support points for each class") 490 | parser.add_argument("--n-query", type=int, default=5, help="Number of query points for each class") 491 | parser.add_argument("--n-classes", type=int, default=5, help="Number of classes per episode") 492 | parser.add_argument("--metric", type=str, default="euclidean", help="Distance function to use", choices=("euclidean", "cosine")) 493 | 494 | # Validation & test 495 | parser.add_argument("--valid-labels-path", type=str, required=True, help="Path to valid labels. This file contains unique names of labels (i.e. one row per label)") 496 | parser.add_argument("--test-labels-path", type=str, required=True, help="Path to test labels. This file contains unique names of labels (i.e. one row per label)") 497 | parser.add_argument("--evaluate-every", type=int, default=100, help="Number of training episodes between each evaluation (on both valid, test)") 498 | parser.add_argument("--n-test-episodes", type=int, default=1000, help="Number of episodes during evaluation (valid, test)") 499 | 500 | # Logging & Saving 501 | parser.add_argument("--output-path", type=str, default=f'runs/{now()}') 502 | parser.add_argument("--log-every", type=int, default=10, help="Number of training episodes between each logging") 503 | 504 | # Training stuff 505 | parser.add_argument("--max-iter", type=int, default=10000, help="Max number of training episodes") 506 | parser.add_argument("--early-stop", type=int, default=0, help="Number of worse evaluation steps before stopping. 
0=disabled") 507 | 508 | # Augmentation & Paraphrase 509 | parser.add_argument("--unlabeled-path", type=str, help="Path to raw data (one sentence per line), to generate paraphrases from.") 510 | parser.add_argument("--n-unlabeled", type=int, help="Number of rows to draw from `--unlabeled-path` at each episode", default=5) 511 | 512 | # If you are using a paraphrase generation model 513 | parser.add_argument("--paraphrase-model-name-or-path", type=str, help="Name or path to the paraphrase model") 514 | parser.add_argument("--paraphrase-tokenizer-name-or-path", type=str, help="Name or path to the paraphrase model's tokenizer") 515 | parser.add_argument("--paraphrase-num-beams", type=int, help="Total number of beams in the Beam Search algorithm") 516 | parser.add_argument("--paraphrase-beam-group-size", type=int, help="Size of each group of beams") 517 | parser.add_argument("--paraphrase-diversity-penalty", type=float, help="Diversity penalty (float) to use in Diverse Beam Search") 518 | parser.add_argument("--paraphrase-filtering-strategy", type=str, choices=["bleu", "clustering"], help="Filtering strategy to apply to a group of generated paraphrases to choose the one to pick. `bleu` takes the sentence which has the highest bleu_score w.r.t. the original sentence.") 519 | parser.add_argument("--paraphrase-drop-strategy", type=str, choices=["bigram", "unigram"], help="Drop strategy to use to constrain the paraphrase generation. If not set, no words are forbidden.") 520 | parser.add_argument("--paraphrase-drop-chance-speed", type=str, choices=["flat", "down", "up"], help="Curve of drop probability depending on token position in the sentence") 521 | parser.add_argument("--paraphrase-drop-chance-auc", type=float, help="Area of the drop chance probability w.r.t. the position in the sentence. When --paraphrase-drop-chance-speed=flat (same chance for all tokens to be forbidden no matter the position in the sentence), this parameter equals p_{mask}") 522 | 523 | # If you want to use another augmentation technique, e.g. EDA (https://github.com/jasonwei20/eda_nlp/) 524 | parser.add_argument("--paraphrase-generation-method", type=str, choices=["eda"]) 525 | 526 | # Augmentation file path (optional, but if provided it will be used) 527 | parser.add_argument("--augmentation-data-path", type=str, help="Path to a .jsonl file containing augmentations. 
Refer to `back-translation.jsonl` for an example") 528 | 529 | # Seed 530 | parser.add_argument("--seed", type=int, default=42, help="Random seed to set") 531 | 532 | # Supervised loss share 533 | parser.add_argument("--supervised-loss-share-power", default=1.0, type=float, help="supervised_loss_share = 1 - (step / max_steps) ** power") 534 | 535 | parser.add_argument("--encoder", type=str, default="bert", help="Encoder to use", choices=("bert", "sentbert")) 536 | parser.add_argument("--draft", action="store_true") 537 | parser.add_argument("--lr", type=float, default=2e-5) 538 | parser.add_argument("--zero-shot", action='store_true') 539 | parser.add_argument("--label_fn", type=str, help="Path to the label names file (e.g. labels.txt)") 540 | parser.add_argument("--cv", type=str, default='01,02,03,04,05', help="Comma-separated list of cross-validation folds to run") 541 | 542 | args = parser.parse_args() 543 | 544 | main(args) 545 | -------------------------------------------------------------------------------- /downstream/scripts/run_eval_after_ft.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR=../data/downstream 2 | OUTPUT_DIR=../output/eval 3 | MODEL_PATH=../models/iae_model 4 | CKPT_DIR=../output/finetuned_iae_model 5 | dataset=OOS 6 | n_classes=50 # N-way (50-way) 7 | n_support=1 # K-shot (1-shot) 8 | ckpt_n_classes=5 # assume that ckpt was created in 5-way 9 | CKPT_NAME=protaugment_${dataset}_${ckpt_n_classes}_${n_support} 10 | OUTPUT_NAME=iae_protaugment_${dataset}_${n_classes}_${n_support} 11 | python run_eval.py \ 12 | --test-path ${DATA_DIR}/${dataset}/few_shot \ 13 | --output-path ${OUTPUT_DIR}/${dataset}/${OUTPUT_NAME} \ 14 | --encoder sentbert \ 15 | --model-name-or-path ${MODEL_PATH} \ 16 | --load_ckpt True \ 17 | --ckpt_path ${CKPT_DIR}/protaugment/${CKPT_NAME} \ 18 | --n-test-episodes 6 --n-support ${n_support} --n-classes ${n_classes} --n-query 5 --metric euclidean --pooling avg \ 19 | --label_fn ${DATA_DIR}/${dataset}/labels.txt -------------------------------------------------------------------------------- /downstream/scripts/run_eval_before_ft.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR=../data/downstream 2 | OUTPUT_DIR=../output/eval 3 | MODEL_PATH=../models/iae_model 4 | dataset=OOS 5 | n_classes=50 # N-way (50-way) 6 | n_support=1 # K-shot (1-shot) 7 | OUTPUT_NAME=iae_${dataset}_${n_classes}_${n_support} 8 | python run_eval.py \ 9 | --test-path ${DATA_DIR}/${dataset}/few_shot \ 10 | --output-path ${OUTPUT_DIR}/${dataset}/${OUTPUT_NAME} \ 11 | --encoder sentbert \ 12 | --model-name-or-path ${MODEL_PATH} \ 13 | --load_ckpt False \ 14 | --n-test-episodes 600 --n-support ${n_support} --n-classes ${n_classes} --n-query 5 --metric euclidean --pooling avg \ 15 | --label_fn ${DATA_DIR}/${dataset}/labels.txt -------------------------------------------------------------------------------- /downstream/scripts/run_protaugment.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR=../data/downstream 2 | OUTPUT_DIR=../output/finetuned_iae_model 3 | MODEL_PATH=../models/iae_model 4 | dataset=OOS 5 | n_classes=5 # N-way 6 | n_support=1 # K-shot 7 | MODEL_NAME=protaugment_${dataset}_${n_classes}_${n_support} 8 | python run_protaugment.py \ 9 | --data-path ${DATA_DIR}/${dataset}/full.jsonl \ 10 | --train-labels-path ${DATA_DIR}/${dataset}/few_shot \ 11 | --valid-labels-path ${DATA_DIR}/${dataset}/few_shot \ 12 | --test-labels-path ${DATA_DIR}/${dataset}/few_shot \ 13 | --unlabeled-path ${DATA_DIR}/${dataset}/raw.txt \ 14 | --n-support ${n_support} \ 15 | 
--n-query 5 \ 16 | --n-classes ${n_classes} \ 17 | --evaluate-every 100 \ 18 | --n-test-episodes 600 \ 19 | --max-iter 10 \ 20 | --early-stop 20 \ 21 | --log-every 10 \ 22 | --seed 42 \ 23 | --n-unlabeled 5 \ 24 | --augmentation-data-path ${DATA_DIR}/${dataset}/paraphrases/DBS-unigram-flat-1.0/paraphrases.jsonl \ 25 | --metric euclidean \ 26 | --lr 1e-6 \ 27 | --encoder sentbert \ 28 | --supervised-loss-share-power 1 \ 29 | --model-name-or-path ${MODEL_PATH} \ 30 | --output-path ${OUTPUT_DIR}/protaugment/${MODEL_NAME} \ 31 | --label_fn ${DATA_DIR}/${dataset}/labels.txt -------------------------------------------------------------------------------- /downstream/scripts/run_protonet.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR=../data/downstream 2 | OUTPUT_DIR=../output/finetuned_iae_model 3 | MODEL_PATH=../models/iae_model 4 | dataset=OOS 5 | n_classes=5 # N-way 6 | n_support=1 # K-shot 7 | MODEL_NAME=iae_${dataset}_${n_classes}_${n_support} 8 | python run_protonet.py \ 9 | --train-path ${DATA_DIR}/${dataset}/few_shot \ 10 | --valid-path ${DATA_DIR}/${dataset}/few_shot \ 11 | --test-path ${DATA_DIR}/${dataset}/few_shot \ 12 | --model-name-or-path ${MODEL_PATH} \ 13 | --n-support ${n_support} \ 14 | --n-query 5 \ 15 | --n-classes ${n_classes} \ 16 | --n-augment 0 \ 17 | --lr 1e-6 \ 18 | --encoder sentbert \ 19 | --evaluate-every 100 \ 20 | --n-test-episodes 600 \ 21 | --max-iter 10000 \ 22 | --early-stop 20 \ 23 | --log-every 10 \ 24 | --seed 42 \ 25 | --metric euclidean \ 26 | --output-path ${OUTPUT_DIR}/protonet/${MODEL_NAME} \ 27 | --label_fn ${DATA_DIR}/${dataset}/labels.txt -------------------------------------------------------------------------------- /downstream/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/intent-aware-encoder/e4dd281a7310659adeeb807c1f3ef8140a27013c/downstream/utils/__init__.py -------------------------------------------------------------------------------- /downstream/utils/dataloader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Thomas Dopierre. 
Licensed under Apache License 2.0 2 | 3 | import numpy as np 4 | import random 5 | import collections 6 | import os 7 | import json 8 | from typing import List, Dict 9 | 10 | def get_jsonl_data(jsonl_path: str): 11 | assert jsonl_path.endswith(".jsonl") 12 | out = list() 13 | with open(jsonl_path, 'r', encoding="utf-8") as file: 14 | for line in file: 15 | j = json.loads(line.strip()) 16 | out.append(j) 17 | return out 18 | 19 | def write_jsonl_data(jsonl_data: List[Dict], jsonl_path: str, force=False): 20 | if os.path.exists(jsonl_path) and not force: 21 | raise FileExistsError 22 | with open(jsonl_path, 'w') as file: 23 | for line in jsonl_data: 24 | file.write(json.dumps(line, ensure_ascii=False) + '\n') 25 | 26 | def raw_data_to_dict(data, shuffle=True): 27 | labels_dict = collections.defaultdict(list) 28 | for item in data: 29 | labels_dict[item['label']].append(item) 30 | labels_dict = dict(labels_dict) 31 | if shuffle: 32 | for key, val in labels_dict.items(): 33 | random.shuffle(val) 34 | print(list(labels_dict.keys())) 35 | return labels_dict 36 | 37 | class FewShotDataLoader: 38 | def __init__(self, file_path): 39 | self.raw_data = get_jsonl_data(file_path) 40 | self.data_dict = raw_data_to_dict(self.raw_data, shuffle=True) 41 | 42 | def create_episode(self, n_support: int = 0, n_classes: int = 0, n_query: int = 0): 43 | episode = dict() 44 | if n_classes: 45 | n_classes = min(n_classes, len(self.data_dict.keys())) 46 | rand_keys = np.random.choice(list(self.data_dict.keys()), n_classes, replace=False) 47 | 48 | while min([len(self.data_dict[k]) for k in rand_keys]) < n_support + n_query: # resample until every chosen class has enough examples 49 | rand_keys = np.random.choice(list(self.data_dict.keys()), n_classes, replace=False) 50 | # assert min([len(val) for val in self.data_dict.values()]) >= n_support + n_query + n_unlabeled 51 | 52 | for key, val in self.data_dict.items(): 53 | random.shuffle(val) 54 | 55 | if n_support: 56 | episode["xs"] = [[self.data_dict[k][i] for i in range(n_support)] for k in rand_keys] 57 | if n_query: 58 | episode["xq"] = [[self.data_dict[k][n_support + i] for i in range(n_query)] for k in rand_keys] 59 | 60 | return episode, rand_keys 61 | -------------------------------------------------------------------------------- /downstream/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. Licensed under the MIT license. 2 | 3 | import random 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | def set_seeds(seed): 9 | random.seed(seed) 10 | np.random.seed(seed) 11 | torch.manual_seed(seed) 12 | torch.cuda.manual_seed(seed) 13 | torch.cuda.manual_seed_all(seed) 14 | torch.backends.cudnn.deterministic = True 15 | return 16 | 17 | def euclidean_dist(x, y): # x: (n, d), y: (m, d) -> pairwise squared distances (n, m) 18 | x = x.unsqueeze(1) 19 | y = y.unsqueeze(0) 20 | return torch.pow(x - y, 2).sum(2) 21 | 22 | 23 | def cosine_similarity(x, y): # x: (n, d), y: (m, d) -> pairwise cosine similarities (n, m) 24 | x = F.normalize(x, dim=-1) 25 | y = F.normalize(y, dim=-1) 26 | sim = torch.matmul(x, y.transpose(1, 0)) 27 | return sim 28 | -------------------------------------------------------------------------------- /pretraining/README.md: -------------------------------------------------------------------------------- 1 | # Pre-training IAE Model 2 | 3 | This directory describes how to preprocess the pre-training datasets, how to pre-train an encoder on them, and how to validate the trained model. 
4 | 5 | ## Prerequisites 6 | To pre-process data for pre-training, you will first need to prepare an intent role labeling model. 7 | Instructions for this can be found in [README.md](/data/README.md). 8 | 9 | ## Pre-training Dataset Creation 10 | 11 | To create the pre-training dataset, run the following script. 12 | 13 | ``` 14 | bash ./scripts/create_dataset.sh 15 | ``` 16 | 17 | ## Pretraining 18 | 19 | To pre-train an encoder, run the following script. 20 | ``` 21 | bash ./scripts/run_pretrain.sh 22 | ``` 23 | 24 | ## Evaluation 25 | 26 | To evaluate the pre-trained encoder, run the following script. 27 | ``` 28 | bash ./scripts/run_eval.sh 29 | ``` -------------------------------------------------------------------------------- /pretraining/iae/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_loader import ContrastiveLearningDataset 2 | from .contrastive_learning import ContrastiveLearningPairwise 3 | from .iae_model import IAEModel 4 | -------------------------------------------------------------------------------- /pretraining/iae/contrastive_learning.py: -------------------------------------------------------------------------------- 1 | # Original Copyright (c) 2021 Cambridge Language Technology Lab. Licensed under the MIT License. 2 | # Modifications Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | 4 | import logging 5 | import torch 6 | import torch.nn as nn 7 | from torch.cuda.amp import autocast 8 | from pytorch_metric_learning import losses 9 | from transformers import ( 10 | AdamW, 11 | ) 12 | LOGGER = logging.getLogger(__name__) 13 | 14 | 15 | class ContrastiveLearningPairwise(nn.Module): 16 | def __init__(self, encoder, learning_rate, weight_decay, use_cuda=True, \ 17 | agg_mode="cls", infoNCE_tau=0.04): 18 | 19 | LOGGER.info(f"ContrastiveLearningPairwise! learning_rate={learning_rate} weight_decay={weight_decay} " \ 20 | f"agg_mode={agg_mode} infoNCE_tau={infoNCE_tau}") 21 | super(ContrastiveLearningPairwise, self).__init__() 22 | self.encoder = encoder 23 | self.learning_rate = learning_rate 24 | self.weight_decay = weight_decay 25 | self.use_cuda = use_cuda 26 | self.agg_mode = agg_mode 27 | self.optimizer = AdamW([{'params': self.encoder.parameters()},], 28 | lr=self.learning_rate, weight_decay=self.weight_decay 29 | ) 30 | self.infoNCE_tau = float(infoNCE_tau) # cast defensively, the temperature must be a float # sentence & phrase: 0.04, word: 0.2 # The MoCo paper uses 0.07, while SimCLR uses 0.5. 
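# Added note: NTXentLoss treats embeddings that share a label as positives and
# every other embedding in the batch as an in-batch negative. In forward() below
# the two views are concatenated and given duplicated labels, so row i pairs
# with row i + batch_size, e.g.:
#   labels = torch.arange(4); torch.cat([labels, labels])
#   -> tensor([0, 1, 2, 3, 0, 1, 2, 3])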
31 | self.loss = losses.NTXentLoss(temperature=self.infoNCE_tau) 32 | 33 | print ("loss:", self.loss) 34 | 35 | @autocast() 36 | def forward(self, query_toks1, query_toks2, labels=None): 37 | outputs1 = self.encoder(**query_toks1, return_dict=True, output_hidden_states=False) 38 | outputs2 = self.encoder(**query_toks2, return_dict=True, output_hidden_states=False) 39 | last_hidden_state1 = outputs1.last_hidden_state 40 | last_hidden_state2 = outputs2.last_hidden_state 41 | 42 | if self.agg_mode=="cls": 43 | query_embed1 = last_hidden_state1[:,0] 44 | query_embed2 = last_hidden_state2[:,0] 45 | elif self.agg_mode == "mean": # include padded tokens 46 | query_embed1 = last_hidden_state1.mean(1) 47 | query_embed2 = last_hidden_state2.mean(1) 48 | elif self.agg_mode == "mean_std": 49 | query_embed1 = (last_hidden_state1 * query_toks1['attention_mask'].unsqueeze(-1)).sum(1) / query_toks1['attention_mask'].sum(-1).unsqueeze(-1) 50 | query_embed2 = (last_hidden_state2 * query_toks2['attention_mask'].unsqueeze(-1)).sum(1) / query_toks2['attention_mask'].sum(-1).unsqueeze(-1) 51 | else: 52 | raise NotImplementedError() 53 | 54 | query_embed = torch.cat([query_embed1, query_embed2], dim=0) 55 | 56 | if labels is None: 57 | labels = torch.arange(query_embed1.size(0)) 58 | labels = torch.cat([labels, labels], dim=0) 59 | 60 | return self.loss(query_embed, labels) -------------------------------------------------------------------------------- /pretraining/iae/data_loader.py: -------------------------------------------------------------------------------- 1 | # Original Copyright (c) 2021 Cambridge Language Technology Lab. Licensed under the MIT License. 2 | # Modifications Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | 4 | import numpy as np 5 | import random 6 | from torch.utils.data import Dataset 7 | import logging 8 | 9 | LOGGER = logging.getLogger(__name__) 10 | 11 | def erase_and_mask(s, tokenizer, mask_len=5): 12 | """ 13 | Randomly replace a span in input s with "[MASK]". 
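    Illustrative example (added; the span position is random and masking is at
    the character level, so word fragments can remain):
        erase_and_mask("please remind me to water the plants", tok, mask_len=5)
        -> "please remind me t [MASK] er the plants"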
14 | """ 15 | if len(s) <= mask_len: return s 16 | if len(s) < 30: return s # if too short, no span masking 17 | ind = np.random.randint(len(s)-mask_len) 18 | left, right = s[:ind], s[ind+mask_len:] # slice directly so the sampled span is masked (str.split would cut at the first occurrence of the substring) 19 | return " ".join([left, tokenizer.mask_token, right]) 20 | 21 | # 2022-09-09: Amazon modification 22 | class ContrastiveLearningDataset(Dataset): 23 | def __init__(self, path, tokenizer, random_span_mask=0, pairwise=False, triplewise=False, masking_strat='opt1', draft=False): 24 | with open(path, 'r') as f: 25 | lines = f.readlines() 26 | self.sent_pairs = [] 27 | self.pairwise = pairwise 28 | self.intent2label = {} 29 | self.utterance2label = {} 30 | 31 | for line in lines: 32 | line = line.rstrip("\n") 33 | try: 34 | utterance, intent, irl = line.split("||") 35 | except ValueError: # skip malformed lines 36 | continue 37 | 38 | if intent not in self.intent2label: 39 | self.intent2label[intent] = len(self.intent2label) 40 | self.utterance2label[utterance] = self.intent2label[intent] 41 | self.sent_pairs.append((utterance, intent, irl)) 42 | 43 | if draft: 44 | self.sent_pairs = self.sent_pairs[:1000] 45 | 46 | self.tokenizer = tokenizer 47 | self.random_span_mask = random_span_mask 48 | self.masking_strat = masking_strat 49 | 50 | self.intent2utterances = {} 51 | for utterance, intent, irl in self.sent_pairs: 52 | if intent not in self.intent2utterances: 53 | self.intent2utterances[intent] = [] 54 | 55 | self.intent2utterances[intent].append(utterance) 56 | 57 | def __getitem__(self, idx): 58 | # batch_x1: input utterance 59 | # batch_x2: gold intent 60 | # batch_x3: gold utterance 61 | # batch_x4: pseudo intent 62 | 63 | utterance = self.sent_pairs[idx][0] 64 | gold_intent = self.sent_pairs[idx][1] 65 | pseudo_intent = self.sent_pairs[idx][2] 66 | 67 | # gold_utterances: utterances with the same gold intent as the input utterance 68 | gold_utterances = [u for u in self.intent2utterances[gold_intent] if u != utterance] 69 | 70 | if len(gold_utterances) > 0: 71 | gold_utterance = random.sample(gold_utterances, k=1)[0] 72 | else: # random masking 73 | gold_utterance = erase_and_mask(utterance, self.tokenizer, mask_len=int(self.random_span_mask)) 74 | 75 | if idx < 5: 76 | print(f"{idx},input utterance=",utterance) 77 | print(f"{idx},gold intent=",gold_intent) 78 | print(f"{idx},gold utterance=",gold_utterance) 79 | print(f"{idx},pseudo intent=",pseudo_intent) 80 | 81 | return utterance, gold_intent, gold_utterance, pseudo_intent 82 | 83 | def __len__(self): 84 | assert len(self.sent_pairs) != 0 85 | return len(self.sent_pairs) 86 | # End of Amazon modification -------------------------------------------------------------------------------- /pretraining/iae/drophead.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Kirill Kravtsov. Licensed under the MIT License. 
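# Added note: "drophead" randomly zeroes entire attention heads during training
# (dropout at the head level) and rescales the surviving heads by
# num_attention_heads / num_kept_heads so the expected output magnitude is unchanged.
# Usage sketch (illustrative; the checkpoint name is an assumption):
#   from transformers import BertModel
#   model = BertModel.from_pretrained("bert-base-uncased")
#   set_drophead(model, p=0.1)  # registers forward hooks in-place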
2 | 3 | import torch 4 | from transformers import BertModel, RobertaModel, XLMRobertaModel 5 | 6 | 7 | VALID_CLS = (BertModel, RobertaModel, XLMRobertaModel) 8 | 9 | 10 | def _drophead_hook(module, input, output): 11 | """ 12 | Pytorch forward hook for transformers.modeling_bert.BertSelfAttention layer 13 | """ 14 | if (not module.training) or (module.p_drophead==0): 15 | return output 16 | 17 | orig_shape = output[0].shape 18 | dist = torch.distributions.Bernoulli(torch.tensor([1-module.p_drophead])) 19 | mask = dist.sample((orig_shape[0], module.num_attention_heads)) 20 | mask = mask.to(output[0].device).unsqueeze(-1) 21 | count_ones = mask.sum(dim=1).unsqueeze(-1) # calc num of active heads 22 | 23 | self_att_out = module.transpose_for_scores(output[0]) 24 | self_att_out = self_att_out * mask * module.num_attention_heads / count_ones 25 | self_att_out = self_att_out.permute(0, 2, 1, 3).view(*orig_shape) 26 | return (self_att_out,) + output[1:] 27 | 28 | 29 | def valid_type(obj): 30 | return isinstance(obj, VALID_CLS) 31 | 32 | 33 | def get_base_model(model): 34 | """ 35 | Check model type. If correct then return the model itself. 36 | If not correct then try to find in attributes and return correct type 37 | attribute if found 38 | """ 39 | if not valid_type(model): 40 | attrs = [name for name in dir(model) if valid_type(getattr(model, name))] 41 | if len(attrs) == 0: 42 | raise ValueError("Please provide valid model") 43 | model = getattr(model, attrs[0]) 44 | return model 45 | 46 | 47 | def set_drophead(model, p=0.1): 48 | """ 49 | Adds drophead to model. Works inplace. 50 | Args: 51 | model: an instance of transformers.BertModel / transformers.RobertaModel / 52 | transformers.XLMRobertaModel or downstream model (e.g. transformers.BertForSequenceClassification) 53 | or any custom downstream model 54 | p: drophead probability 55 | """ 56 | if (p < 0) or (p > 1): 57 | raise ValueError("Wrong p argument") 58 | 59 | model = get_base_model(model) 60 | 61 | for bert_layer in model.encoder.layer: 62 | if not hasattr(bert_layer.attention.self, "p_drophead"): 63 | bert_layer.attention.self.register_forward_hook(_drophead_hook) 64 | bert_layer.attention.self.p_drophead = p 65 | -------------------------------------------------------------------------------- /pretraining/iae/iae_model.py: -------------------------------------------------------------------------------- 1 | # Original Copyright (c) 2021 Cambridge Language Technology Lab. Licensed under the MIT License. 2 | # Modifications Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
3 | 4 | import logging 5 | import torch 6 | from tqdm import tqdm 7 | from transformers import ( 8 | AutoTokenizer, 9 | AutoModel, 10 | ) 11 | 12 | from .contrastive_learning import * 13 | 14 | LOGGER = logging.getLogger() 15 | 16 | 17 | class IAEModel(object): 18 | """ 19 | Wrapper class for IAEModel 20 | """ 21 | 22 | def __init__(self): 23 | self.tokenizer = None 24 | self.encoder = None 25 | 26 | def get_encoder(self): 27 | assert (self.encoder is not None) 28 | 29 | return self.encoder 30 | 31 | def get_tokenizer(self): 32 | assert (self.tokenizer is not None) 33 | 34 | return self.tokenizer 35 | 36 | def save_model(self, path, context=False): 37 | # save bert model, bert config 38 | self.encoder.save_pretrained(path) 39 | 40 | # save bert vocab 41 | self.tokenizer.save_pretrained(path) 42 | 43 | def load_model(self, path, max_length=50, lowercase=True, 44 | use_cuda=True, return_model=False): 45 | 46 | self.tokenizer = AutoTokenizer.from_pretrained(path, 47 | use_fast=True, do_lower_case=lowercase) 48 | self.encoder = AutoModel.from_pretrained(path) 49 | if use_cuda: 50 | self.encoder = self.encoder.cuda() 51 | if not return_model: 52 | return 53 | return self.encoder, self.tokenizer 54 | 55 | def encode(self, sentences, max_length=50, agg_mode="cls"): 56 | sent_toks = self.tokenizer.batch_encode_plus( 57 | list(sentences), 58 | max_length=max_length, 59 | padding="max_length", 60 | truncation=True, 61 | add_special_tokens=True, 62 | return_tensors="pt") 63 | sent_toks_cuda = {} 64 | for k, v in sent_toks.items(): 65 | sent_toks_cuda[k] = v.cuda() 66 | with torch.no_grad(): 67 | outputs = self.encoder(**sent_toks_cuda, return_dict=True, output_hidden_states=False) 68 | last_hidden_state = outputs.last_hidden_state 69 | 70 | if agg_mode == "cls": 71 | query_embed = last_hidden_state[:, 0] 72 | elif agg_mode == "mean": # masked mean over non-padding tokens 73 | # query_embed = last_hidden_state.mean(1) # (plain mean, would include padding) 74 | query_embed = last_hidden_state * sent_toks_cuda['attention_mask'].unsqueeze(-1) 75 | query_embed = query_embed.sum(1) / sent_toks_cuda['attention_mask'].sum(1).unsqueeze(-1) 76 | elif agg_mode == "mean_std": 77 | query_embed = (last_hidden_state * sent_toks_cuda['attention_mask'].unsqueeze(-1)).sum(1) / sent_toks_cuda['attention_mask'].sum(-1).unsqueeze(-1) 78 | else: 79 | raise NotImplementedError() 80 | return query_embed 81 | 82 | def get_embeddings(self, sentences, batch_size=1024, max_length=50, agg_mode="cls"): 83 | """ 84 | Compute embeddings for a list of sentences. 
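        Usage sketch (illustrative; the model path is an assumption):
            iae = IAEModel()
            iae.load_model("models/iae_model")
            embs = iae.get_embeddings(["book a flight", "play some jazz"], agg_mode="cls")
            # embs: torch.FloatTensor of shape (2, hidden_size), gathered on CPU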
85 | """ 86 | embedding_table = [] 87 | with torch.no_grad(): 88 | for start in tqdm(range(0, len(sentences), batch_size)): 89 | end = min(start + batch_size, len(sentences)) 90 | batch = sentences[start:end] 91 | batch_embedding = self.encode(batch, max_length=max_length, agg_mode=agg_mode) 92 | batch_embedding = batch_embedding.cpu() 93 | embedding_table.append(batch_embedding) 94 | embedding_table = torch.cat(embedding_table, dim=0) 95 | return embedding_table 96 | 97 | 98 | -------------------------------------------------------------------------------- /pretraining/preprocess/create_pretrain_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from allennlp.common import Params 4 | import random 5 | from utils import construct_irl_text 6 | from irl.intent_role_labelers import IntentRoleLabeler 7 | import argparse 8 | from dataclasses import asdict 9 | from load_data import load_top, load_dstc11t2, load_sgd, load_multiwoz 10 | 11 | def process_irl(utterances, intents, role_labeler): 12 | irls = role_labeler.label_batch(utterances) 13 | 14 | output = [] 15 | new_utterances = {} 16 | new_intents = {} 17 | new_irl_texts = {} 18 | for utterance, intent, irl in zip(utterances, intents, irls): 19 | irl = asdict(irl) # irl to dict 20 | if utterance.lower() in new_utterances: # deduplicate 21 | continue 22 | 23 | irl_text, has_irl = construct_irl_text(irl) 24 | if has_irl: 25 | output.append('||'.join([utterance, intent, irl_text])) 26 | new_utterances[utterance.lower()] = "" 27 | new_intents[intent.lower()] = "" 28 | new_irl_texts[irl_text.lower()] = "" 29 | 30 | print("len(new_utterances):", len(new_utterances)) 31 | print("len(new_intents):", len(new_intents)) 32 | print("len(new_irl_texts):", len(new_irl_texts)) 33 | 34 | return output 35 | 36 | def save_data(output, path): 37 | # stat 38 | utterances = {} 39 | intents ={} 40 | irl_texts ={} 41 | for o in output: 42 | utterance, intent, irl = o.split("||") 43 | utterances[utterance.lower()] = "" 44 | intents[intent.lower()] = "" 45 | irl_texts[irl.lower()] = "" 46 | 47 | print("total len(utterances):", len(utterances)) 48 | print("total len(intents):", len(intents)) 49 | print("total len(irl_texts):", len(irl_texts)) 50 | 51 | random.shuffle(output) 52 | with open(path, 'w') as f: 53 | for line in output: 54 | f.write(line + "\n") 55 | print(path) 56 | 57 | def main(args): 58 | # load irl model 59 | role_labeler = IntentRoleLabeler.from_params(Params({ 60 | 'type': 'tagger_based_intent_role_labeler', 61 | 'model_path': args.irl_model_path, 62 | 'cuda_device': 0 63 | })) 64 | 65 | # preprocess train dataset 66 | train_data = [] 67 | utterances, intents = load_top(1000, top1_data_dir=args.top1_dir, top2_data_dir=args.top2_dir) 68 | train_data += process_irl(utterances, intents, role_labeler) 69 | utterances, intents = load_dstc11t2(1000, data_dir=args.dstc11t2_dir) 70 | train_data += process_irl(utterances, intents, role_labeler) 71 | utterances, intents = load_sgd(100, single_sent=True, data_dir=args.sgd_dir) 72 | train_data += process_irl(utterances, intents, role_labeler) 73 | train_data = list(set(train_data)) # deduplicate 74 | 75 | # preprocess validation dataset 76 | val_data = [] 77 | utterances, intents = load_multiwoz(100, data_dir=args.multiwoz_dir) 78 | val_data += process_irl(utterances, intents, role_labeler) 79 | 80 | # save 81 | os.makedirs(args.output_dir, exist_ok = True) 82 | train_path = os.path.join(args.output_dir, 'train.txt') 83 | val_path = 
os.path.join(args.output_dir, 'val.txt') 84 | save_data(train_data, train_path) 85 | save_data(val_data, val_path) 86 | 87 | if __name__ == '__main__': 88 | parser = argparse.ArgumentParser() 89 | parser.add_argument('--irl_model_path', type=str, required=True) 90 | parser.add_argument('--top1_dir', type=str, required=True) 91 | parser.add_argument('--top2_dir', type=str, required=True) 92 | parser.add_argument('--dstc11t2_dir', type=str, required=True) 93 | parser.add_argument('--sgd_dir', type=str, required=True) 94 | parser.add_argument('--multiwoz_dir', type=str, required=True) 95 | parser.add_argument('--output_dir', type=str, required=True) 96 | 97 | args = parser.parse_args() 98 | 99 | print(args) 100 | main(args) 101 | 102 | -------------------------------------------------------------------------------- /pretraining/preprocess/irl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/intent-aware-encoder/e4dd281a7310659adeeb807c1f3ef8140a27013c/pretraining/preprocess/irl/__init__.py -------------------------------------------------------------------------------- /pretraining/preprocess/irl/data.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | from typing import List, Dict 5 | 6 | 7 | @dataclass(frozen=True, eq=True) 8 | class LabeledSpan: 9 | label: str 10 | start: int 11 | exclusive_end: int 12 | text: str 13 | 14 | @staticmethod 15 | def from_dict(json_dict: Dict): 16 | return LabeledSpan( 17 | json_dict['label'], 18 | json_dict['start'], 19 | json_dict['exclusive_end'], 20 | json_dict['text'] 21 | ) 22 | 23 | 24 | @dataclass(frozen=True, eq=True) 25 | class Frame: 26 | spans: List[LabeledSpan] 27 | 28 | @staticmethod 29 | def from_dict(json_dict: Dict): 30 | return Frame([LabeledSpan.from_dict(span) for span in json_dict['spans']]) 31 | 32 | 33 | @dataclass 34 | class LabeledUtterance: 35 | text: str 36 | frames: List[Frame] 37 | 38 | @staticmethod 39 | def from_dict(json_dict: Dict): 40 | return LabeledUtterance(json_dict['text'], [Frame.from_dict(frame) for frame in json_dict['frames']]) 41 | 42 | @staticmethod 43 | def read_lines(path: Path): 44 | with path.open() as lines: 45 | lines = [line.strip() for line in lines if line.strip()] 46 | result = [] 47 | for line in lines: 48 | result.append(LabeledUtterance.from_dict(json.loads(line))) 49 | return result -------------------------------------------------------------------------------- /pretraining/preprocess/irl/intent_role_labelers.py: -------------------------------------------------------------------------------- 1 | """ 2 | I want to find and reserve a room. 3 | I want to reserve a flight and a hotel. 4 | """ 5 | from collections import defaultdict 6 | from dataclasses import replace 7 | from typing import List, Dict, Tuple 8 | 9 | from allennlp.common import Registrable 10 | from spacy.tokens import Span, Doc 11 | 12 | from irl.data import LabeledUtterance, Frame 13 | from irl.irl_tagger import load_irl_model 14 | 15 | 16 | class IntentRoleLabeler(Registrable): 17 | def label(self, utterance: str) -> LabeledUtterance: 18 | """ 19 | Predict intent role labels for a single input utterance. 
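        For example (illustrative; actual spans depend on the trained tagger),
        the utterance "book a table for two" might yield a frame with spans
        such as Action("book") and Argument("a table for two").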
20 | :param utterance: input utterance 21 | :return: predicted role labels 22 | """ 23 | return self.label_batch([utterance])[0] 24 | 25 | def label_batch(self, utterances: List[str]) -> List[LabeledUtterance]: 26 | """ 27 | Predict intent role labels for a batch of input utterances. 28 | :param utterances: batch of input utterances 29 | :return: predicted role labels for each input utterance 30 | """ 31 | raise NotImplementedError 32 | 33 | 34 | @IntentRoleLabeler.register('tagger_based_intent_role_labeler') 35 | class TaggerBasedIntentRoleLabeler(IntentRoleLabeler): 36 | def __init__( 37 | self, 38 | model_path: str, 39 | spacy_model: str = 'en_core_web_md', 40 | cuda_device: int = 0 41 | ) -> None: 42 | super().__init__() 43 | self._predictor = load_irl_model(model_path, cuda_device=cuda_device) 44 | from spacy import load 45 | self._nlp = load(spacy_model) 46 | 47 | @staticmethod 48 | def _convert_to_labeled_utterance( 49 | predictions: List[LabeledUtterance], 50 | doc: Doc 51 | ) -> LabeledUtterance: 52 | frames = [] 53 | for prediction, sent in zip(predictions, doc.sents): 54 | for prop in prediction.frames: 55 | spans = [ 56 | replace( 57 | span, 58 | start=span.start + sent.start_char, 59 | exclusive_end=span.exclusive_end + sent.start_char 60 | ) for span in prop.spans] 61 | if not spans: 62 | continue 63 | frames.append(Frame(spans)) 64 | return LabeledUtterance(doc.text, frames) 65 | 66 | @staticmethod 67 | def convert_to_labeled_utterances( 68 | predictions: List[LabeledUtterance], 69 | inputs: List[Span], 70 | sent_idx_to_utterance_idx: Dict[int, int], 71 | parsed: List[Doc], 72 | ): 73 | predictions_by_utterance = defaultdict(list) 74 | for i, (prediction, parse) in enumerate(zip(predictions, inputs)): 75 | utterance_idx = sent_idx_to_utterance_idx[i] 76 | predictions_by_utterance[utterance_idx].append(prediction) 77 | 78 | result = [ 79 | TaggerBasedIntentRoleLabeler._convert_to_labeled_utterance(predictions, parsed[idx]) 80 | for idx, predictions in predictions_by_utterance.items() 81 | ] 82 | return result 83 | 84 | def label_batch(self, utterances: List[str]) -> List[LabeledUtterance]: 85 | # split into sentences and parse 86 | sent_idx_to_utterance_idx, inputs, parsed = parse_utterances(utterances, self._nlp) 87 | 88 | sentences = [sent.text for sent in inputs] 89 | predictions = [] 90 | for utterance, spans in zip(sentences, self._predictor.tag_batch(sentences)): 91 | predictions.append( 92 | LabeledUtterance( 93 | utterance, 94 | [Frame(spans)] 95 | ) 96 | ) 97 | 98 | # regroup predictions by utterance 99 | result = self.convert_to_labeled_utterances( 100 | predictions, 101 | inputs, 102 | sent_idx_to_utterance_idx, 103 | parsed, 104 | ) 105 | return result 106 | 107 | 108 | def parse_utterances( 109 | utterances: List[str], nlp, batch_size: int = 64 110 | ) -> Tuple[Dict[int, int], List[Span], List[Doc]]: 111 | sent_idx_to_utterance_idx = {} 112 | inputs = [] 113 | parsed = list(nlp.pipe(utterances, batch_size=batch_size, n_process=1, disable=["ner"])) 114 | for i, parse in enumerate(parsed): 115 | for _ in parse.sents: 116 | sent_idx_to_utterance_idx[len(sent_idx_to_utterance_idx)] = i 117 | inputs.extend(parse.sents) 118 | return sent_idx_to_utterance_idx, inputs, parsed 119 | -------------------------------------------------------------------------------- /pretraining/preprocess/irl/irl_tagger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Predictors applied to dialogs/dialog turns, providing facades over AllenNLP 
models for inference. 3 | 4 | Essentially a lot of boilerplate code to load AllenNLP predictor wrapper. 5 | """ 6 | import logging 7 | from typing import Dict, Iterable, List, Tuple, Any 8 | 9 | import torch 10 | from allennlp.common import JsonDict, Registrable 11 | from allennlp.common.util import lazy_groups_of 12 | from allennlp.data import DatasetReader, Instance, Vocabulary, TextFieldTensors, Token, Tokenizer, TokenIndexer 13 | from allennlp.data.fields import SequenceLabelField, TextField, MetadataField 14 | from allennlp.data.token_indexers import SingleIdTokenIndexer 15 | from allennlp.models import Model, SimpleTagger, load_archive 16 | from allennlp.modules import TextFieldEmbedder, Seq2SeqEncoder 17 | from allennlp.modules.conditional_random_field import allowed_transitions, ConditionalRandomField 18 | from allennlp.predictors import Predictor 19 | 20 | from irl.data import LabeledSpan 21 | from irl.utils import labels_to_spans 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | class TaggerConstants: 27 | TAGS = 'tags' 28 | TOKENS = 'tokens' 29 | WORDS = 'words' 30 | TEXT = 'text' 31 | BEGIN_ = 'B-' 32 | IN_ = 'I-' 33 | OUT = 'O' 34 | 35 | 36 | @Predictor.register('tagger_predictor') 37 | class TaggerPredictor(Predictor): 38 | """ 39 | Sequence tagger of `DialogTurn`s, used as a facade over AllenNLP models. Provides conversion 40 | logic of IOB tags to spans. 41 | """ 42 | 43 | def __init__(self, 44 | model: Model, 45 | dataset_reader: DatasetReader, 46 | frozen: bool = True, 47 | batch_size: int = 128) -> None: 48 | """ 49 | Initialize a predictor used for inference on dialog turns. 50 | Args: 51 | model: AllenNLP model 52 | dataset_reader: dataset reader used for feature extraction 53 | frozen: whether model is frozen or not 54 | batch_size: batch size for inference 55 | """ 56 | super().__init__(model, dataset_reader, frozen) 57 | self._batch_size = batch_size 58 | 59 | def _json_to_instance(self, json_dict: JsonDict) -> Instance: 60 | dialog_turn = json_dict[TaggerConstants.TEXT] 61 | return self._dataset_reader.text_to_instance(dialog_turn, []) 62 | 63 | def predict_instance(self, instance: Instance) -> JsonDict: 64 | prediction = super().predict_instance(instance) 65 | return self._update_prediction(instance, prediction) 66 | 67 | def predict_batch_instance(self, instances: List[Instance]) -> List[JsonDict]: 68 | results = [] 69 | for model_input_json, prediction in zip(instances, 70 | super().predict_batch_instance(instances)): 71 | results.append(self._update_prediction(model_input_json, prediction)) 72 | return results 73 | 74 | def tag_batch(self, dialog_turns: Iterable[str]) -> List[List[LabeledSpan]]: 75 | instances = self._batch_json_to_instances( 76 | [{TaggerConstants.TEXT: dialog_turn} for dialog_turn in dialog_turns] 77 | ) 78 | 79 | results = [] 80 | for batch_instance in lazy_groups_of(instances, self._batch_size): 81 | for model_input_json, prediction in zip(batch_instance, 82 | self.predict_batch_instance( 83 | batch_instance)): 84 | results.append(prediction) 85 | 86 | results = [self.prediction_to_spans(pred) for pred in results] 87 | return results 88 | 89 | @staticmethod 90 | def _update_prediction(instance: Instance, prediction: JsonDict) -> JsonDict: 91 | tags = prediction[TaggerConstants.TAGS] 92 | # noinspection PyUnresolvedReferences 93 | tokens = instance[TaggerConstants.TOKENS].tokens 94 | # noinspection PyUnresolvedReferences 95 | return { 96 | TaggerConstants.TOKENS: tokens, 97 | TaggerConstants.TAGS: tags[:len(tokens)], 98 | 
TaggerConstants.TEXT: instance.fields[TaggerConstants.TEXT].metadata 99 | } 100 | 101 | @staticmethod 102 | def prediction_to_spans(prediction: JsonDict) -> List[LabeledSpan]: 103 | """ 104 | Convert tagger predictions to `LabeledSpan` which consists of start and end 105 | offsets in original text. 106 | Args: 107 | prediction: predict output 108 | 109 | Returns: list of labeled spans 110 | 111 | """ 112 | tokens = prediction[TaggerConstants.TOKENS] 113 | tags = prediction[TaggerConstants.TAGS] 114 | labeled_spans = [] 115 | for name, start, end in labels_to_spans(tags): 116 | start_idx = tokens[start].idx 117 | end_idx = tokens[end - 1].idx_end 118 | text = prediction[TaggerConstants.TEXT][start_idx:end_idx] 119 | labeled_spans.append(LabeledSpan(name, start_idx, end_idx, text=text)) 120 | return labeled_spans 121 | 122 | 123 | @Model.register('turn_tagger') 124 | class TurnTagger(SimpleTagger): 125 | """ 126 | Override `SimpleTagger` mostly to add **kwargs to the forward method, allowing us 127 | to include the original UID for inputs in output without causing an assertion error. 128 | """ 129 | 130 | def __init__(self, 131 | vocab: Vocabulary, 132 | text_field_embedder: TextFieldEmbedder, 133 | encoder: Seq2SeqEncoder, 134 | label_encoding='BIO', 135 | viterbi_decoding=False, 136 | **kwargs) -> None: 137 | super().__init__(vocab, 138 | text_field_embedder, 139 | encoder, 140 | calculate_span_f1=True, 141 | label_encoding=label_encoding, 142 | **kwargs) 143 | if viterbi_decoding: 144 | logger.info(f'Initializing tagger with {label_encoding}' 145 | f'-constrained Viterbi decoding enabled') 146 | self.viterbi_decoding = viterbi_decoding 147 | constraints = allowed_transitions( 148 | constraint_type=label_encoding, 149 | labels=self.vocab.get_index_to_token_vocabulary(self.label_namespace) 150 | ) 151 | self._crf = ConditionalRandomField( 152 | num_tags=self.vocab.get_vocab_size(self.label_namespace), 153 | constraints=constraints, 154 | ) 155 | self._crf.transitions.requires_grad_(False).fill_(0) 156 | 157 | def forward( 158 | self, 159 | tokens: TextFieldTensors, 160 | tags: torch.LongTensor = None, 161 | **kwargs 162 | ) -> Dict[str, torch.Tensor]: 163 | output_dict = super().forward(tokens, 164 | tags, 165 | kwargs.get(TaggerConstants.WORDS)) 166 | return {**output_dict, **kwargs} 167 | 168 | def _decode( 169 | self, output_dict: Dict[str, Any] 170 | ) -> Dict[str, Any]: 171 | """ 172 | Perform IOB-constrained Viterbi decoding. 
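        Added note: with BIO constraints, illegal transitions such as O -> I-X
        are disallowed, so decoded tag sequences always form well-formed spans
        (the transition scores themselves are frozen at zero in __init__ above).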
173 | """ 174 | viterbi_tags = self._crf.viterbi_tags( 175 | logits=output_dict['logits'], 176 | ) 177 | all_tags = [] 178 | for indices, score in viterbi_tags: 179 | tags = [ 180 | self.vocab.get_token_from_index(x, namespace=self.label_namespace) 181 | for x in indices 182 | ] 183 | all_tags.append(tags) 184 | output_dict[TaggerConstants.TAGS] = all_tags 185 | return output_dict 186 | 187 | def make_output_human_readable( 188 | self, 189 | output_dict: Dict[str, torch.Tensor] 190 | ) -> Dict[str, torch.Tensor]: 191 | readable_output = super().make_output_human_readable(output_dict) 192 | 193 | return { 194 | TaggerConstants.TAGS: readable_output[TaggerConstants.TAGS] 195 | } 196 | 197 | def get_metrics(self, reset: bool = False) -> Dict[str, float]: 198 | return super().get_metrics(reset) 199 | 200 | 201 | class TaggerPreprocessor(Registrable): 202 | default_implementation = 'transformer_tagger_preprocessor' 203 | 204 | def text_to_tokens_and_tags( 205 | self, 206 | text: str, 207 | spans: List[LabeledSpan] 208 | ) -> Tuple[List[Token], List[str]]: 209 | raise NotImplementedError 210 | 211 | 212 | @TaggerPreprocessor.register('transformer_tagger_preprocessor') 213 | class TransformerTaggerPreprocessor(TaggerPreprocessor): 214 | """ 215 | Tagger pre-processor that extracts labels and corresponding tokens for Transformer-based models. 216 | """ 217 | 218 | def __init__(self, tokenizer: Tokenizer) -> None: 219 | self.tokenizer = tokenizer 220 | 221 | def text_to_tokens_and_tags( 222 | self, 223 | text: str, 224 | spans: List[LabeledSpan] 225 | ) -> Tuple[List[Token], List[str]]: 226 | text = text.lower() 227 | utterance_tokens = self.tokenizer.tokenize(text=text) 228 | 229 | annotation_map = {} 230 | for span in spans: 231 | for i in range(span.start, span.exclusive_end): 232 | annotation_map[i] = span 233 | prev_annotation_start = -1 234 | labels = [] 235 | for tok in utterance_tokens: 236 | label, annotation_start = TaggerConstants.OUT, -1 237 | if tok.idx in annotation_map: 238 | span = annotation_map[tok.idx] 239 | annotation_start = span.start 240 | tag = (TaggerConstants.BEGIN_ if annotation_start != prev_annotation_start 241 | else TaggerConstants.IN_) 242 | label = f'{tag}{span.label}' 243 | # truncate tokens above the max length 244 | prev_annotation_start = annotation_start 245 | labels.append(label) 246 | 247 | return utterance_tokens, labels 248 | 249 | 250 | @DatasetReader.register('tagger_dataset_reader') 251 | class TaggerDatasetReader(DatasetReader): 252 | """ 253 | Dataset reader that extracts token-level labels from each turn in a `Dialog`. 
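    Each instance carries a TOKENS TextField, a TAGS SequenceLabelField, and
    WORDS/TEXT metadata fields (see `text_to_instance` below).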
254 | """ 255 | 256 | def __init__(self, 257 | tagger_preprocessor: TransformerTaggerPreprocessor, 258 | token_indexers: Dict[str, TokenIndexer] = None, 259 | **kwargs) -> None: 260 | super().__init__(**kwargs) 261 | self._tagger_preprocessor = tagger_preprocessor 262 | self._token_indexers = token_indexers or {TaggerConstants.TOKENS: SingleIdTokenIndexer()} 263 | 264 | def text_to_instance(self, turn: str, labels: List[LabeledSpan]): 265 | tokens, labels = self._tagger_preprocessor.text_to_tokens_and_tags(turn, labels) 266 | tokens_field = TextField(tokens=tokens, token_indexers=self._token_indexers) 267 | fields = { 268 | TaggerConstants.WORDS: MetadataField( 269 | { 270 | TaggerConstants.WORDS: [x.text for x in tokens] 271 | } 272 | ), 273 | TaggerConstants.TOKENS: tokens_field, 274 | TaggerConstants.TAGS: SequenceLabelField(labels, tokens_field), 275 | TaggerConstants.TEXT: MetadataField(turn) 276 | } 277 | 278 | return Instance(fields) 279 | 280 | def _read(self, file_path) -> Iterable[Instance]: 281 | with open(file_path) as lines: 282 | for line in lines: 283 | yield self.text_to_instance(line, []) 284 | 285 | 286 | def load_irl_model(model_path: str, override: Dict = None, cuda_device: int = 0) -> TaggerPredictor: 287 | """Load IRL model as TaggerPredictor""" 288 | if not override: 289 | override = {} 290 | archive = load_archive(model_path, cuda_device=cuda_device, overrides=override) 291 | return TaggerPredictor.from_archive(archive, predictor_name='tagger_predictor') 292 | -------------------------------------------------------------------------------- /pretraining/preprocess/irl/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Iterable 2 | 3 | 4 | class BioesConfig: 5 | BEGIN = "B" 6 | BEGIN_ = "B-" 7 | END = "E" 8 | END_ = "E-" 9 | SINGLE = "S" 10 | SINGLE_ = "S-" 11 | IN = "I" 12 | IN_ = "I-" 13 | OUT = "O" 14 | DELIM = "-" 15 | 16 | 17 | def labels_to_spans(labeling: Iterable[str]) -> List[Tuple[str, int, int]]: 18 | """ 19 | Given an IOB/BESIO chunking produce a list of labeled spans--triples of (label, start index, 20 | end index exclusive). 21 | >>> labels_to_spans(['O', 'B-PER', 'I-PER', 'O', 'B-ORG']) 22 | [('PER', 1, 3), ('ORG', 4, 5)] 23 | 24 | :param labeling: list of IOB/BESIO labels 25 | :return: list of spans 26 | """ 27 | 28 | def _start_of_chunk(curr): 29 | curr_tag, _ = _get_val_and_tag(curr) 30 | return curr_tag in {BioesConfig.SINGLE, BioesConfig.BEGIN} 31 | 32 | def _end_of_chunk(curr): 33 | curr_tag, _ = _get_val_and_tag(curr) 34 | return curr_tag in {BioesConfig.END, BioesConfig.SINGLE} 35 | 36 | besio = chunk(labeling, besio=True) 37 | 38 | result = [] 39 | curr_label, start = None, None 40 | for index, label in enumerate(besio): 41 | if _start_of_chunk(label): 42 | if curr_label: 43 | result.append((curr_label, start, index)) 44 | curr_label, start = _get_val_and_tag(label)[1], index 45 | if _end_of_chunk(label): 46 | result.append((curr_label, start, index + 1)) 47 | curr_label = None 48 | if curr_label: 49 | result.append((curr_label, start, len(besio))) 50 | 51 | return result 52 | 53 | 54 | def chunk(labeling: Iterable[str], besio=False) -> List[str]: 55 | """ 56 | Convert an IO/BIO/BESIO-formatted sequence of labels to BIO, BESIO, or CoNLL-2005 formatted. 
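    >>> chunk(['I-PER', 'I-PER', 'O'])
    ['B-PER', 'I-PER', 'O']
    >>> chunk(['B-PER', 'I-PER', 'O', 'I-ORG'], besio=True)
    ['B-PER', 'E-PER', 'O', 'S-ORG']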
57 | :param labeling: original labels 58 | :param besio: (optional) convert to BESIO format, `False` by default 59 | :return: converted labels 60 | """ 61 | result = [] 62 | prev_type = None 63 | curr = [] 64 | for label in labeling: 65 | if label == BioesConfig.OUT: 66 | state, chunk_type = BioesConfig.OUT, '' 67 | else: 68 | split_index = label.index(BioesConfig.DELIM) 69 | state, chunk_type = label[:split_index], label[split_index + 1:] 70 | if state == BioesConfig.IN and chunk_type != prev_type: # new chunk of different type 71 | state = BioesConfig.BEGIN 72 | if state in [BioesConfig.BEGIN, BioesConfig.OUT] and curr: # end of chunk 73 | result += _to_besio(curr) if besio else curr 74 | curr = [] 75 | if state == BioesConfig.OUT: 76 | result.append(state) 77 | else: 78 | curr.append(state + BioesConfig.DELIM + chunk_type) 79 | prev_type = chunk_type 80 | if curr: 81 | result += _to_besio(curr) if besio else curr 82 | return result 83 | 84 | 85 | def _to_besio(iob_labeling): 86 | if len(iob_labeling) == 1: 87 | return [BioesConfig.SINGLE + iob_labeling[0][1:]] 88 | return iob_labeling[:-1] + [BioesConfig.END + iob_labeling[-1][1:]] 89 | 90 | 91 | def _get_val_and_tag(label): 92 | if not label: 93 | return '', '' 94 | if label == BioesConfig.OUT: 95 | return label, '' 96 | return label.split(BioesConfig.DELIM, 1) 97 | -------------------------------------------------------------------------------- /pretraining/preprocess/load_data.py: -------------------------------------------------------------------------------- 1 | from utils import camel_terms, filter_text, process_text, sample_min_num 2 | from glob import glob 3 | import csv 4 | import json 5 | from tqdm import tqdm 6 | 7 | def load_top(min_num = 1000, single_sent=False, top1_data_dir="", top2_data_dir=""): 8 | print("load top") 9 | intent_correction = { 10 | 'gettimer': 'get timer' 11 | } 12 | 13 | paths = glob(f'{top1_data_dir}/*.tsv') 14 | paths2 = glob(f'{top2_data_dir}/*.tsv') 15 | paths += paths2 16 | intent2utterances = {} 17 | utterances = {} # to detect duplication 18 | for path in paths: 19 | with open(path) as f: 20 | reader = csv.reader(f, delimiter='\t', quotechar='\"') 21 | next(reader, None) 22 | for row in reader: 23 | _, utterance, tag = row 24 | utterance = utterance.strip() 25 | if not filter_text(utterance, single_sent): 26 | continue 27 | utterance = process_text(utterance) 28 | 29 | if tag.count('IN:') != 1: 30 | continue 31 | 32 | for t in tag.split(): 33 | if 'IN:' in t: 34 | intent = t[t.index("IN:")+3:].replace("_"," ").lower() 35 | break 36 | if 'unsupported' in intent or 'unintelligible' in intent: 37 | continue 38 | if intent in intent_correction: 39 | intent = intent_correction[intent] 40 | if intent not in intent2utterances: 41 | intent2utterances[intent] = [] 42 | 43 | # deduplicate 44 | if utterance in utterances: 45 | continue 46 | utterances[utterance] = "" 47 | 48 | intent2utterances[intent].append(utterance) 49 | return sample_min_num(intent2utterances, min_num) 50 | 51 | def load_dstc11t2(min_num = 1000, single_sent=False, data_dir=""): 52 | print("load dstc11t2") 53 | intent_correction = { 54 | 'getbankstatement': 'get bank statement', 55 | 'getloaninfo': 'get loan info', 56 | 'netincome': 'net income' 57 | } 58 | 59 | intent2utterances = {} 60 | utterances = {} # to detect duplication 61 | path = f'{data_dir}/dstc11t2-intent-data.tsv' 62 | with open(path) as f: 63 | reader = csv.reader(f, delimiter='\t', quotechar='\"') 64 | next(reader, None) 65 | for row in reader: 66 | source, intent, 
utterance = row 67 | utterance = utterance.strip() 68 | if not filter_text(utterance, single_sent): 69 | continue 70 | utterance = process_text(utterance) 71 | 72 | intent = ' '.join(camel_terms(intent)).lower() 73 | if 'faq' in intent: 74 | continue 75 | 76 | if intent in intent_correction: 77 | intent = intent_correction[intent] 78 | 79 | if intent not in intent2utterances: 80 | intent2utterances[intent] = [] 81 | 82 | # deduplicate 83 | if utterance in utterances: 84 | continue 85 | utterances[utterance] = "" 86 | 87 | intent2utterances[intent].append(utterance) 88 | 89 | return sample_min_num(intent2utterances, min_num) 90 | 91 | 92 | def load_sgd(min_num=100, single_sent=False, data_dir=""): 93 | intent2utterances = {} 94 | intent_names = {} 95 | utterances = {} # to detect duplication 96 | 97 | for datatype in ['train', 'dev', 'test']: 98 | paths = glob(f'{data_dir}/{datatype}/dialogues_*.json') 99 | for path in tqdm(paths): 100 | with open(path) as f: 101 | data = json.load(f) 102 | for dialogue in data: 103 | turns = dialogue['turns'] 104 | for turn in turns[:1]: # constraint: first turn 105 | utterance = turn['utterance'].strip() 106 | if not filter_text(utterance, single_sent): 107 | continue 108 | utterance = process_text(utterance) 109 | 110 | frames = turn['frames'] 111 | for frame in frames: 112 | actions = frame['actions'] 113 | for action in actions: 114 | slot = action['slot'] 115 | canonical_values = action['canonical_values'] 116 | if slot == 'intent': 117 | for canonical_value in canonical_values: 118 | if canonical_value in intent_names: 119 | intent = intent_names[canonical_value] 120 | else: 121 | intent = ' '.join(camel_terms(canonical_value)).lower() 122 | intent_names[canonical_value] = intent 123 | 124 | if intent not in intent2utterances: 125 | intent2utterances[intent] = [] 126 | 127 | # deduplicate 128 | if utterance in utterances: 129 | continue 130 | utterances[utterance] = "" 131 | 132 | intent2utterances[intent].append(utterance) 133 | return sample_min_num(intent2utterances, min_num) 134 | 135 | def load_multiwoz(min_num=100, data_dir=""): 136 | intent2utterances = {} 137 | utterances = {} # to detect duplication 138 | 139 | for datatype in ['train', 'dev', 'test']: 140 | paths = glob(f'{data_dir}/{datatype}/dialogues_*.json') 141 | for path in tqdm(paths): 142 | with open(path) as f: 143 | data = json.load(f) 144 | 145 | for dialogue in data: 146 | turns = dialogue['turns'] 147 | for turn in turns[:1]: # constraint: first turn 148 | utterance = turn['utterance'].strip() 149 | if not filter_text(utterance): 150 | continue 151 | utterance = process_text(utterance) 152 | 153 | frames = turn['frames'] 154 | for frame in frames: 155 | if 'state' not in frame: 156 | continue 157 | if not frame['slots']: 158 | continue 159 | intent = frame['state']['active_intent'] 160 | if intent == 'NONE': 161 | continue 162 | 163 | intent = intent.replace("_", " ") 164 | 165 | # deduplicate 166 | if utterance in utterances: 167 | continue 168 | utterances[utterance] = "" 169 | 170 | if intent not in intent2utterances: 171 | intent2utterances[intent] = [] 172 | intent2utterances[intent].append(utterance) 173 | return sample_min_num(intent2utterances, min_num) 174 | -------------------------------------------------------------------------------- /pretraining/preprocess/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from nltk.tokenize import sent_tokenize 3 | import html 4 | import random 5 | import nltk 6 | 
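# NOTE (defensive addition): sent_tokenize in filter_text below also needs NLTK's
# 'punkt' models; fetching them up front (a no-op if already installed) avoids a
# LookupError the first time single-sentence filtering runs.
nltk.download('punkt')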
nltk.download('stopwords') 7 | from nltk.corpus import stopwords 8 | stops = set(stopwords.words('english')) 9 | 10 | emoji_pattern = re.compile("[" 11 | u"\U0001F600-\U0001F64F" # emoticons 12 | u"\U0001F300-\U0001F5FF" # symbols & pictographs 13 | u"\U0001F680-\U0001F6FF" # transport & map symbols 14 | u"\U0001F1E0-\U0001F1FF" # flags (iOS) 15 | "]+", flags=re.UNICODE) 16 | 17 | def camel_terms(value): 18 | return re.findall('[A-Z][a-z]+|[0-9A-Z]+(?=[A-Z][a-z])|[0-9A-Z]{2,}|[a-z0-9]{2,}|[a-zA-Z0-9]', value) 19 | 20 | def construct_irl_text(irl, const_action=False): 21 | # TODO: add ablation study logic (different IRL subsets) 22 | 23 | org_utterance = irl['text'] 24 | frames = irl['frames'] 25 | if len(frames) == 0: 26 | return org_utterance, False 27 | 28 | valid_spans = [] 29 | for frame in frames: 30 | spans = frame['spans'] 31 | # Heuristic 0: Set STOPWORDS 32 | for span in spans: 33 | text = span['text'] 34 | if text in stops: 35 | span['label'] = 'Stopword' 36 | 37 | # IRL Labels: Action, Argument, Request, Query, Slot, Problem 38 | labels = set([span['label'] for span in spans]) 39 | labels -= set(['Slot']) # to filter 40 | if len(labels) == 0: 41 | continue 42 | 43 | valid_spans += [{'text': span['text'], 'start': span['start']} for span in spans] 44 | 45 | if len(valid_spans) == 0: 46 | return org_utterance, False 47 | 48 | valid_spans = sorted(valid_spans, key = lambda x: x['start']) 49 | irl_text = ' '.join([i['text'] for i in valid_spans]) 50 | return irl_text, True 51 | 52 | def filter_text(text, single_sent=False): 53 | # start filtering criteria 54 | if len(text.split()) < 3: # text is shorter than 3 words 55 | return False # (also filters out [removed] or [deleted] comments) 56 | if len(text.split()) > 20: # text is paragraph-length or longer. probably not intentful utterance. 57 | return False 58 | if "http://" in text or "https://" in text or ".com " in text: # contains URL. 
these utterances tend to be messy/noisy
59 | return False
60 | if single_sent and len(sent_tokenize(text)) > 1:
61 | return False
62 | return True
63 |
64 | def process_text(text):
65 | text = emoji_pattern.sub(r'', text) # remove emoji
66 | text = text.replace("\n", ' ')
67 | text = text.replace("|", ' ')
68 | text = html.unescape(text)
69 |
70 | return text
71 |
72 | def sample_min_num(intent2utterances, min_num):
73 | new_intent2utterances = {}
74 | for intent, utterances in intent2utterances.items():
75 | utterances = list(set(utterances))
76 | random.shuffle(utterances)
77 | new_intent2utterances[intent] = utterances[:min_num]
78 | intent2utterances = new_intent2utterances
79 |
80 | final_intents = []
81 | final_utterances = []
82 | for intent, utterances in intent2utterances.items():
83 | for utt in utterances:
84 | final_utterances.append(utt)
85 | final_intents.append(intent)
86 |
87 | intent2counts = {intent: len(utterances) for intent, utterances in intent2utterances.items()}
88 | intent2counts = dict(sorted(intent2counts.items(), key=lambda item: item[1], reverse=True))
89 |
90 | # print(intent2counts)
91 | print("len(intents):", len(intent2counts))
92 | print("len(utterances):", len(final_utterances))
93 |
94 | assert len(final_utterances) == len(final_intents)
95 | return final_utterances, final_intents
96 | -------------------------------------------------------------------------------- /pretraining/run_eval.py: --------------------------------------------------------------------------------
1 | from sentence_transformers import SentenceTransformer, util
2 | import argparse
3 | import random
4 | random.seed(0)
5 | import logging
6 | import torch
7 | from utils import init_logging
8 |
9 | LOGGER = logging.getLogger()
10 |
11 | def load_data(test_path, draft=False):
12 | utterances = []
13 | labels = []
14 |
15 | with open(test_path) as f:
16 | intent2count = {}
17 | for row in f:
18 | utterance, label, _ = row.strip().split("||")
19 |
20 | if label not in intent2count:
21 | intent2count[label] = 0
22 | intent2count[label] += 1
23 |
24 | if draft and intent2count[label] > 3:
25 | continue
26 |
27 | utterances.append(utterance)
28 | labels.append(label)
29 |
30 | return utterances, labels
31 |
32 | def find_top1_intent_idxs(model, utterances, intent_embeddings, distance_metric='cosine'):
33 | utterance_embeddings = model.encode(utterances)
34 | # score every utterance against every intent embedding
35 | if distance_metric == 'cosine':
36 | cosine_scores = util.cos_sim(utterance_embeddings, intent_embeddings)
37 | top1_intent_idxs = torch.argmax(cosine_scores, dim=1)
38 | else: # euclidean
39 | raise NotImplementedError
40 |
41 | return top1_intent_idxs
42 |
43 | def run_eval(test_path, model_name_or_path="", distance_metric='cosine', draft=False, verbose=False):
44 | utterances, labels = load_data(test_path, draft)
45 | unique_intents = list(set(labels))
46 | LOGGER.info(f"number of unique intents={len(unique_intents)}")
47 |
48 | model = SentenceTransformer(model_name_or_path, device='cuda')
49 |
50 | with torch.no_grad():
51 | intent_embeddings = model.encode(unique_intents)
52 | hit = 0
53 | total = 0
54 | batch_size = 100
55 | for i in range(0, len(utterances), batch_size):
56 | b_utterances = utterances[i:i+batch_size]
57 | b_labels = labels[i:i+batch_size]
58 |
59 | top1_intent_idxs = find_top1_intent_idxs(model, b_utterances, intent_embeddings, distance_metric)
60 |
61 | for label, top1_intent_idx in zip(b_labels, top1_intent_idxs):
62 | top1_intent = unique_intents[top1_intent_idx]
63 | if label == top1_intent:
64 | hit += 1
65 | total += 1
66 |
67 | acc = hit / total * 100
68 | if verbose:
69 | LOGGER.info(f"acc={acc}")
70 | return acc
71 |
72 | def main(args):
73 | run_eval(
74 | test_path=args.test_path,
75 | model_name_or_path=args.model_name_or_path,
76 | distance_metric=args.distance_metric,
77 | draft=args.draft,
78 | verbose=args.verbose
79 | )
80 |
81 | if __name__ == '__main__':
82 | parser = argparse.ArgumentParser()
83 | parser.add_argument('--test_path', required=True)
84 | parser.add_argument('--model_name_or_path', required=True)
85 | parser.add_argument('--distance_metric', default='cosine', choices=['cosine', 'euclidean'])
86 | parser.add_argument('--draft', action='store_true')
87 | parser.add_argument('--verbose', action='store_true')
88 | args = parser.parse_args()
89 |
90 | init_logging(LOGGER)
91 | print(args)
92 | main(args)
93 | -------------------------------------------------------------------------------- /pretraining/run_pretrain.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Original Copyright (c) 2021 Cambridge Language Technology Lab. Licensed under the MIT License.
3 | # Modifications Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
4 |
5 | import argparse
6 | import copy
7 | import logging
8 | import os
9 | import time
10 | import shutil
11 | import torch
12 | from run_eval import run_eval
13 | from torch.cuda.amp import GradScaler, autocast
14 | from tqdm import tqdm
15 |
16 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
17 |
18 | from iae.contrastive_learning import ContrastiveLearningPairwise
19 | from iae.data_loader import ContrastiveLearningDataset
20 | from iae.drophead import set_drophead
21 | from iae.iae_model import IAEModel
22 | from utils import init_logging
23 |
24 | LOGGER = logging.getLogger()
25 |
26 | def train(args, data_loader, model, scaler=None, iae_model=None, step_global=0):
27 | LOGGER.info("train!")
28 | train_loss = 0
29 | train_steps = 0
30 | model.cuda()
31 | model.train()
32 | best_acc = 0
33 | best_model = copy.deepcopy(iae_model)
34 | for i, data in tqdm(enumerate(data_loader), total=len(data_loader)):
35 | model.optimizer.zero_grad()
36 | # batch_x1: input utterance
37 | # batch_x2: gold intent
38 | # batch_x3: gold utterance
39 | # batch_x4: pseudo intent
40 |
41 | batch_x1, batch_x2, batch_x3, batch_x4 = data
42 | batch_x_cuda1, batch_x_cuda2, batch_x_cuda3, batch_x_cuda4 = {}, {}, {}, {}
43 | for k, v in batch_x1.items():
44 | batch_x_cuda1[k] = v.cuda()
45 | for k, v in batch_x2.items():
46 | batch_x_cuda2[k] = v.cuda()
47 | for k, v in batch_x3.items():
48 | batch_x_cuda3[k] = v.cuda()
49 | for k, v in batch_x4.items():
50 | batch_x_cuda4[k] = v.cuda()
51 |
52 | if args.amp:
53 | with autocast():
54 | loss = torch.tensor(0.0, requires_grad=True).cuda()
55 | loss += model(batch_x_cuda1, batch_x_cuda2)
56 | loss += model(batch_x_cuda1, batch_x_cuda3)
57 | if args.pseudo_weight: loss += args.pseudo_weight * model(batch_x_cuda1, batch_x_cuda4)
58 | else:
59 | loss = torch.tensor(0.0, requires_grad=True).cuda()
60 | loss += model(batch_x_cuda1, batch_x_cuda2)
61 | loss += model(batch_x_cuda1, batch_x_cuda3)
62 | if args.pseudo_weight: loss += args.pseudo_weight * model(batch_x_cuda1, batch_x_cuda4)
63 |
64 | # NOTE: gradient clipping belongs after backward(); the original pre-backward clip
65 | # ran on freshly zeroed gradients and left the actual update unclipped
66 | if args.amp:
67 | scaler.scale(loss).backward()
68 | scaler.unscale_(model.optimizer) # under AMP, unscale before clipping so the norm is computed on true gradients
69 | torch.nn.utils.clip_grad_norm_(model.encoder.parameters(), 1.0)
70 | scaler.step(model.optimizer)
71 | scaler.update()
72 | else:
73 | loss.backward()
74 | torch.nn.utils.clip_grad_norm_(model.encoder.parameters(), 1.0)
model.optimizer.step() 73 | 74 | train_loss += loss.item() 75 | train_steps += 1 76 | step_global += 1 77 | 78 | if args.eval_during_training and (step_global % args.eval_step == 0): 79 | checkpoint_dir = os.path.join(args.output_dir, f"checkpoint_tmp") 80 | if not os.path.exists(checkpoint_dir): 81 | os.makedirs(checkpoint_dir) 82 | iae_model.save_model(checkpoint_dir) 83 | 84 | acc = run_eval(test_path=args.val_path, 85 | model_name_or_path=checkpoint_dir, 86 | distance_metric=args.distance_metric) 87 | 88 | if acc > best_acc: 89 | best_acc = acc 90 | best_model = copy.deepcopy(iae_model) 91 | 92 | LOGGER.info(f"step:{step_global}, val_acc:{acc}") 93 | 94 | train_loss /= (train_steps + 1e-9) 95 | return train_loss, step_global, best_model, best_acc 96 | 97 | def main(args): 98 | init_logging(LOGGER) 99 | print(args) 100 | 101 | torch.manual_seed(args.random_seed) 102 | # by default 42 is used, also tried 33, 44, 55 103 | # results don't seem to change too much 104 | 105 | # prepare for output 106 | if not os.path.exists(args.output_dir): 107 | os.makedirs(args.output_dir) 108 | 109 | # load BERT tokenizer, dense_encoder 110 | iae_model = IAEModel() 111 | encoder, tokenizer = iae_model.load_model( 112 | path=args.model_name_or_path, 113 | max_length=args.max_length, 114 | use_cuda=args.use_cuda, 115 | return_model=True 116 | ) 117 | 118 | # adjust dropout rates 119 | encoder.embeddings.dropout = torch.nn.Dropout(p=args.dropout_rate) 120 | for i in range(len(encoder.encoder.layer)): 121 | # hotfix 122 | try: 123 | encoder.encoder.layer[i].attention.self.dropout = torch.nn.Dropout(p=args.dropout_rate) 124 | encoder.encoder.layer[i].attention.output.dropout = torch.nn.Dropout(p=args.dropout_rate) 125 | except: 126 | encoder.encoder.layer[i].attention.attn.dropout = torch.nn.Dropout(p=args.dropout_rate) 127 | encoder.encoder.layer[i].attention.dropout = torch.nn.Dropout(p=args.dropout_rate) 128 | 129 | encoder.encoder.layer[i].output.dropout = torch.nn.Dropout(p=args.dropout_rate) 130 | 131 | # set drophead rate 132 | if args.drophead_rate != 0: 133 | set_drophead(encoder, args.drophead_rate) 134 | 135 | def collate_fn_batch_encoding(batch): 136 | sent1, sent2, sent3, sent4 = zip(*batch) 137 | sent1_toks = tokenizer.batch_encode_plus( 138 | list(sent1), 139 | max_length=args.max_length, 140 | padding="max_length", 141 | truncation=True, 142 | add_special_tokens=True, 143 | return_tensors="pt") 144 | sent2_toks = tokenizer.batch_encode_plus( 145 | list(sent2), 146 | max_length=args.max_length, 147 | padding="max_length", 148 | truncation=True, 149 | add_special_tokens=True, 150 | return_tensors="pt") 151 | sent3_toks = tokenizer.batch_encode_plus( 152 | list(sent3), 153 | max_length=args.max_length, 154 | padding="max_length", 155 | truncation=True, 156 | add_special_tokens=True, 157 | return_tensors="pt") 158 | sent4_toks = tokenizer.batch_encode_plus( 159 | list(sent4), 160 | max_length=args.max_length, 161 | padding="max_length", 162 | truncation=True, 163 | add_special_tokens=True, 164 | return_tensors="pt") 165 | 166 | return sent1_toks, sent2_toks, sent3_toks, sent4_toks 167 | 168 | train_set = ContrastiveLearningDataset( 169 | args.train_path, 170 | tokenizer=tokenizer, 171 | random_span_mask=args.random_span_mask, 172 | draft=args.draft 173 | ) 174 | 175 | train_loader = torch.utils.data.DataLoader( 176 | train_set, 177 | batch_size=args.train_batch_size, 178 | shuffle=True, 179 | num_workers=16, 180 | collate_fn=collate_fn_batch_encoding, 181 | drop_last=True 182 | ) 183 | model = 
ContrastiveLearningPairwise( 184 | encoder=encoder, 185 | learning_rate=args.learning_rate, 186 | weight_decay=args.weight_decay, 187 | use_cuda=args.use_cuda, 188 | infoNCE_tau=args.infoNCE_tau, 189 | agg_mode=args.agg_mode 190 | ) 191 | if args.parallel: 192 | model.encoder = torch.nn.DataParallel(model.encoder) 193 | LOGGER.info("using nn.DataParallel") 194 | # mixed precision training 195 | if args.amp: 196 | scaler = GradScaler() 197 | else: 198 | scaler = None 199 | 200 | start = time.time() 201 | step_global = 0 202 | best_acc = 0 203 | best_model = copy.deepcopy(iae_model) 204 | for epoch in range(1,args.epoch+1): 205 | LOGGER.info(f"Epoch {epoch}/{args.epoch}") 206 | 207 | # train 208 | train_loss, step_global, ep_best_model, ep_best_acc = train(args, data_loader=train_loader, model=model, 209 | scaler=scaler, iae_model=iae_model, step_global=step_global) 210 | LOGGER.info(f'loss/train_per_epoch={train_loss}/{epoch}') 211 | if ep_best_acc > best_acc: 212 | best_model = ep_best_model 213 | best_acc = ep_best_acc 214 | 215 | # eval after one epoch 216 | tmp_checkpoint_dir = os.path.join(args.output_dir, f"checkpoint_tmp") 217 | if not os.path.exists(tmp_checkpoint_dir): 218 | os.makedirs(tmp_checkpoint_dir) 219 | iae_model.save_model(tmp_checkpoint_dir) 220 | ep_acc = run_eval(test_path=args.val_path, 221 | model_name_or_path=tmp_checkpoint_dir, 222 | distance_metric=args.distance_metric 223 | ) 224 | # remove tmp directory 225 | shutil.rmtree(tmp_checkpoint_dir) 226 | 227 | if ep_acc > best_acc: 228 | best_acc = ep_acc 229 | best_model = copy.deepcopy(iae_model) 230 | 231 | LOGGER.info(f"step:{step_global}, val_acc:{ep_acc}") 232 | 233 | best_model.save_model(args.output_dir) 234 | 235 | end = time.time() 236 | training_time = end-start 237 | training_hour = int(training_time/60/60) 238 | training_minute = int(training_time/60 % 60) 239 | training_second = int(training_time % 60) 240 | LOGGER.info(f"Training Time!{training_hour} hours {training_minute} minutes {training_second} seconds") 241 | LOGGER.info(f"Best val acc={best_acc}") 242 | 243 | if __name__ == '__main__': 244 | """ 245 | Parse input arguments 246 | """ 247 | parser = argparse.ArgumentParser(description='train IAE Model') 248 | 249 | # Required 250 | parser.add_argument('--train_path', type=str, required=True, help='training set directory') 251 | parser.add_argument('--val_path', type=str, help='validation set directory') 252 | parser.add_argument('--output_dir', type=str, required=True, help='Directory for output') 253 | 254 | parser.add_argument('--model_name_or_path', type=str, \ 255 | help='Directory for pretrained model', \ 256 | default="roberta-base") 257 | parser.add_argument('--max_length', default=50, type=int) 258 | parser.add_argument('--learning_rate', default=2e-5, type=float) 259 | parser.add_argument('--weight_decay', default=0.01, type=float) 260 | parser.add_argument('--train_batch_size', default=200, type=int) 261 | parser.add_argument('--epoch', default=3, type=int) 262 | parser.add_argument('--infoNCE_tau', default=0.04, type=float) 263 | parser.add_argument('--agg_mode', default="cls", type=str, help="{cls|mean|mean_std}") 264 | parser.add_argument('--use_cuda', action="store_true") 265 | parser.add_argument('--save_checkpoint_all', action="store_true") 266 | parser.add_argument('--checkpoint_step', type=int, default=10000000) 267 | parser.add_argument('--parallel', action="store_true") 268 | parser.add_argument('--amp', action="store_true", \ 269 | help="automatic mixed precision training") 270 
| parser.add_argument('--random_seed', default=42, type=int) 271 | 272 | # data augmentation config 273 | parser.add_argument('--dropout_rate', default=0.1, type=float) 274 | parser.add_argument('--drophead_rate', default=0.0, type=float) 275 | parser.add_argument('--random_span_mask', default=5, type=int, 276 | help="number of chars to be randomly masked on one side of the input") 277 | 278 | parser.add_argument('--distance_metric', default='cosine', choices=['cosine', 'euclidean']) 279 | parser.add_argument('--draft', action='store_true') 280 | parser.add_argument('--eval_during_training', action='store_true') 281 | parser.add_argument('--eval_step', default=200, type=int) 282 | parser.add_argument('--pseudo_weight', default=2, type=float) 283 | args = parser.parse_args() 284 | 285 | main(args) 286 | -------------------------------------------------------------------------------- /pretraining/scripts/create_dataset.sh: -------------------------------------------------------------------------------- 1 | python preprocess/create_pretrain_dataset.py \ 2 | --irl_model_path ../models/irl_model/irl-model-sgd-08-16-2022.tar.gz \ 3 | --top1_dir ../data/sources/topv1 \ 4 | --top2_dir ../data/sources/topv2 \ 5 | --dstc11t2_dir ../data/sources/dstc11t2 \ 6 | --sgd_dir ../data/sources/sgd \ 7 | --multiwoz_dir ../data/sources/multiwoz2.2 \ 8 | --output_dir ../data/pretraining -------------------------------------------------------------------------------- /pretraining/scripts/run_eval.sh: -------------------------------------------------------------------------------- 1 | output_name=iae_model 2 | python run_eval.py \ 3 | --test_path ../data/pretraining/val.txt \ 4 | --model_name_or_path ../output/${output_name} \ 5 | --distance_metric cosine \ 6 | --verbose -------------------------------------------------------------------------------- /pretraining/scripts/run_pretrain.sh: -------------------------------------------------------------------------------- 1 | output_name=iae_model 2 | python run_pretrain.py \ 3 | --model_name_or_path sentence-transformers/paraphrase-mpnet-base-v2 \ 4 | --train_path ../data/pretraining/train.txt \ 5 | --val_path ../data/pretraining/val.txt \ 6 | --output_dir ../models/${output_name} \ 7 | --epoch 1 \ 8 | --train_batch_size 50 \ 9 | --learning_rate 1e-6 \ 10 | --max_length 50 \ 11 | --infoNCE_tau 0.05 \ 12 | --dropout_rate 0.1 \ 13 | --drophead_rate 0.0 \ 14 | --random_span_mask 5 \ 15 | --random_seed 42 \ 16 | --agg_mode mean_std \ 17 | --amp \ 18 | --parallel \ 19 | --use_cuda \ 20 | --pseudo_weight 2 -------------------------------------------------------------------------------- /pretraining/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | def init_logging(LOGGER): 4 | LOGGER.setLevel(logging.INFO) 5 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', 6 | '%m/%d/%Y %I:%M:%S %p') 7 | console = logging.StreamHandler() 8 | console.setFormatter(fmt) 9 | LOGGER.addHandler(console) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.18.0 2 | numpy==1.23.0 3 | torch==1.11.0 4 | PyYAML==6.0 5 | regex==2022.6.2 6 | tqdm==4.64.0 7 | tensorboardX==2.5.1 8 | sentencepiece==0.1.96 9 | nltk==3.7 10 | pytorch-metric-learning==1.5.0 11 | sentence-transformers==2.2.2 12 | sacrebleu==2.1.0 13 | allennlp==2.9.3 14 | cached-path==1.1.2 15 | 
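# direct-URL requirement: the en_core_web_md model below is pinned to release 3.2.0,
# presumably to match the spaCy 3.2.x version resolved by allennlp 2.9.3; spaCy models
# only load against the minor version they were built for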
https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.2.0/en_core_web_md-3.2.0.tar.gz#egg=en_core_web_md --------------------------------------------------------------------------------
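As a usage note for the pretraining evaluation above: the sketch below shows one way to exercise run_eval.py end to end. It is a minimal illustration under stated assumptions, not part of the repository: the file name tiny_val.txt and the two sample rows are invented, the encoder checkpoint is the paraphrase-mpnet-base-v2 model referenced in scripts/run_pretrain.sh, and a CUDA device is assumed because run_eval.py hardcodes device='cuda'. The `||`-delimited row layout follows load_data() in pretraining/run_eval.py.

    # sketch: zero-shot intent accuracy with the pipeline's ||-delimited data format
    from run_eval import run_eval

    # load_data() expects "utterance||intent||<unused third field>" per line
    rows = [
        "set an alarm for 7 am||create alarm||-",
        "what is my checking account balance||check balance||-",
    ]
    with open("tiny_val.txt", "w") as f:  # hypothetical test file
        f.write("\n".join(rows))

    # run_eval embeds the unique intent names once, then argmax-matches
    # each utterance embedding against them by cosine similarity
    acc = run_eval(
        test_path="tiny_val.txt",
        model_name_or_path="sentence-transformers/paraphrase-mpnet-base-v2",
        distance_metric="cosine",
        verbose=True,
    )
    print(f"zero-shot intent accuracy: {acc:.1f}%")

The same file format works for the --val_path consumed during pretraining, which reuses run_eval() for validation between checkpoints.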