├── CODEOWNERS
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING-ARCHIVED.md
├── LICENSE
├── README.md
├── SECURITY.md
├── adversarial-training
│   ├── codemixer.py
│   ├── run_codemixer_nli.py
│   └── run_codemixer_sa.py
├── attacks
│   ├── bumblebee.py
│   ├── polygloss.py
│   ├── run_bumblebee_nli.py
│   ├── run_bumblebee_qa.py
│   ├── run_polygloss_nli.py
│   └── squad_utils.py
└── scripts
    ├── create_polygloss_dictionaries_from_muse.py
    ├── extract-xnli-sentences-to-dict.py
    ├── extract-xquad-questions-to-dict.py
    ├── extract_en_xnli.py
    ├── generate_xnli_cleanDL.py
    ├── language-opus-nmt-map.json
    ├── run_sentiment_analysis.py
    ├── translate_tweeteval.py
    └── translate_xnli.py

--------------------------------------------------------------------------------
/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing.
2 | #ECCN:Open Source
3 | 
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Salesforce Open Source Community Code of Conduct
2 | 
3 | ## About the Code of Conduct
4 | 
5 | Equality is a core value at Salesforce. We believe a diverse and inclusive
6 | community fosters innovation and creativity, and are committed to building a
7 | culture where everyone feels included.
8 | 
9 | Salesforce open-source projects are committed to providing a friendly, safe, and
10 | welcoming environment for all, regardless of gender identity and expression,
11 | sexual orientation, disability, physical appearance, body size, ethnicity, nationality,
12 | race, age, religion, level of experience, education, socioeconomic status, or
13 | other similar personal characteristics.
14 | 
15 | The goal of this code of conduct is to specify a baseline standard of behavior so
16 | that people with different social values and communication styles can work
17 | together effectively, productively, and respectfully in our open source community.
18 | It also establishes a mechanism for reporting issues and resolving conflicts.
19 | 
20 | All questions and reports of abusive, harassing, or otherwise unacceptable behavior
21 | in a Salesforce open-source project may be reported by contacting the Salesforce
22 | Open Source Conduct Committee at ossconduct@salesforce.com.
23 | 
24 | ## Our Pledge
25 | 
26 | In the interest of fostering an open and welcoming environment, we as
27 | contributors and maintainers pledge to making participation in our project and
28 | our community a harassment-free experience for everyone, regardless of gender
29 | identity and expression, sexual orientation, disability, physical appearance,
30 | body size, ethnicity, nationality, race, age, religion, level of experience, education,
31 | socioeconomic status, or other similar personal characteristics.
32 | 33 | ## Our Standards 34 | 35 | Examples of behavior that contributes to creating a positive environment 36 | include: 37 | 38 | * Using welcoming and inclusive language 39 | * Being respectful of differing viewpoints and experiences 40 | * Gracefully accepting constructive criticism 41 | * Focusing on what is best for the community 42 | * Showing empathy toward other community members 43 | 44 | Examples of unacceptable behavior by participants include: 45 | 46 | * The use of sexualized language or imagery and unwelcome sexual attention or 47 | advances 48 | * Personal attacks, insulting/derogatory comments, or trolling 49 | * Public or private harassment 50 | * Publishing, or threatening to publish, others' private information—such as 51 | a physical or electronic address—without explicit permission 52 | * Other conduct which could reasonably be considered inappropriate in a 53 | professional setting 54 | * Advocating for or encouraging any of the above behaviors 55 | 56 | ## Our Responsibilities 57 | 58 | Project maintainers are responsible for clarifying the standards of acceptable 59 | behavior and are expected to take appropriate and fair corrective action in 60 | response to any instances of unacceptable behavior. 61 | 62 | Project maintainers have the right and responsibility to remove, edit, or 63 | reject comments, commits, code, wiki edits, issues, and other contributions 64 | that are not aligned with this Code of Conduct, or to ban temporarily or 65 | permanently any contributor for other behaviors that they deem inappropriate, 66 | threatening, offensive, or harmful. 67 | 68 | ## Scope 69 | 70 | This Code of Conduct applies both within project spaces and in public spaces 71 | when an individual is representing the project or its community. Examples of 72 | representing a project or community include using an official project email 73 | address, posting via an official social media account, or acting as an appointed 74 | representative at an online or offline event. Representation of a project may be 75 | further defined and clarified by project maintainers. 76 | 77 | ## Enforcement 78 | 79 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 80 | reported by contacting the Salesforce Open Source Conduct Committee 81 | at ossconduct@salesforce.com. All complaints will be reviewed and investigated 82 | and will result in a response that is deemed necessary and appropriate to the 83 | circumstances. The committee is obligated to maintain confidentiality with 84 | regard to the reporter of an incident. Further details of specific enforcement 85 | policies may be posted separately. 86 | 87 | Project maintainers who do not follow or enforce the Code of Conduct in good 88 | faith may face temporary or permanent repercussions as determined by other 89 | members of the project's leadership and the Salesforce Open Source Conduct 90 | Committee. 91 | 92 | ## Attribution 93 | 94 | This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home], 95 | version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html. 96 | It includes adaptions and additions from [Go Community Code of Conduct][golang-coc], 97 | [CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc]. 98 | 99 | This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us]. 
100 | 101 | [contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/) 102 | [golang-coc]: https://golang.org/conduct 103 | [cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md 104 | [microsoft-coc]: https://opensource.microsoft.com/codeofconduct/ 105 | [cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/ 106 | -------------------------------------------------------------------------------- /CONTRIBUTING-ARCHIVED.md: -------------------------------------------------------------------------------- 1 | # ARCHIVED 2 | 3 | This project is `Archived` and is no longer actively maintained; 4 | We are not accepting contributions or Pull Requests. 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adversarial Polyglots 2 | This repository contains code for the paper "[Code-Mixing on Sesame Street: Dawn of the Adversarial Polyglots](https://arxiv.org/abs/2103.09593)" (NAACL-HLT 2021). 
3 | 
4 | Authors: [Samson Tan](https://samsontmr.github.io) and [Shafiq Joty](https://raihanjoty.github.io)
5 | 
6 | 
7 | # Usage
8 | 
9 | ## Adversarial Polyglots
10 | Scripts for running `PolyGloss` and `Bumblebee` on NLI and QA datasets (MNLI/SQuAD formats) are in the `attacks` folder. Preprocessing scripts can be found in `scripts`. `PolyGloss` and `Bumblebee` each return two _adversarial_ examples: the one that induced the highest loss and the one that induced the lowest loss. The lowest-loss example is usually the less perturbed of the two _(minimally perturbed)_, while the highest-loss example _(maximally perturbed)_ should transfer more successfully to other models.
11 | 
12 | `PolyGloss` requires a dictionary constructed from the bilingual [MUSE dictionaries](https://github.com/facebookresearch/MUSE#ground-truth-bilingual-dictionaries). After downloading the dictionaries into a folder named `dictionaries` under `scripts`, run `scripts/create_polygloss_dictionaries_from_muse.py`.
13 | 
14 | `Bumblebee` requires a dictionary/JSON consisting of sentence-translation pairs. The `extract-xnli-sentences-to-dict.py` and `extract-xquad-questions-to-dict.py` scripts in `scripts` can be used to create these dictionaries for files in the MNLI and SQuAD formats (e.g., XNLI and XQuAD). JSONs for the XNLI test set can be found [here](https://github.com/salesforce/adversarial-polyglots-data).
15 | 
16 | ## Code-mixed Adversarial Training
17 | Code for generating code-mixed adversarial training (CAT) examples is in `adversarial-training`. Since the alignment step is the most time-consuming, we decouple it from the example perturbation step. Users can generate only the alignments with the `--extract_phrases` option or load precomputed alignments via the `--phrase_alignments` option.
18 | 
19 | Similar to `Bumblebee`, `Code-Mixer` requires a dictionary/JSON consisting of sentence-translation pairs. The `extract-xnli-sentences-to-dict.py` and `extract-xquad-questions-to-dict.py` scripts in `scripts` can be used to create these dictionaries for files in the MNLI and SQuAD formats (e.g., XNLI and XQuAD). A sketch of the expected dictionary shape and the decoupled alignment/perturbation flow follows.
20 | 
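21 | The following minimal sketch illustrates the expected shape of a sentence-translations dictionary and the two-step alignment/perturbation flow (the sentence, translations, and parameter values here are illustrative placeholders, and running it requires `simalign`'s XLM-R aligner to be available):
22 | 
23 | ```python
24 | from codemixer import CodeMixer  # adversarial-training/codemixer.py
25 | 
26 | # Shape produced by the extraction scripts: a matrix-language sentence maps to
27 | # a dict of {embedded-language code: translation}. Tokens are space-separated.
28 | translations = {
29 |     "The cat sat on the mat .": {
30 |         "fr": "Le chat était assis sur le tapis .",
31 |         "es": "El gato se sentó en la alfombra .",
32 |     }
33 | }
34 | 
35 | sentence = "The cat sat on the mat ."
36 | mixer = CodeMixer(matrix_lg="en", embedded_lgs=["fr", "es"], device="cpu")
37 | 
38 | # Step 1 (slow): align the matrix sentence with each of its translations.
39 | alignments = mixer.get_phrases(sentence, translations[sentence])
40 | 
41 | # Step 2 (fast, repeatable): sample a code-mixed perturbation.
42 | print(mixer.generate_precomputed_alignments(sentence, alignments, probability=0.5))
43 | ```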
44 | 
45 | # Translated XNLI Data
46 | We translated the [XNLI data](https://cims.nyu.edu/~sbowman/xnli) to 18 other languages using machine-translation systems (see paper for details). The translation script is in `scripts`. Translated data can be found [here](https://github.com/salesforce/adversarial-polyglots-data).
47 | 
48 | # Citation
49 | Please cite the following if you use the code/data in this repository:
50 | ```
51 | @inproceedings{tan-joty-2021-code-mixing,
52 |     title = "Code-Mixing on Sesame Street: {D}awn of the Adversarial Polyglots",
53 |     author = "Tan, Samson and Joty, Shafiq",
54 |     booktitle = "Proceedings of the 2021 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies",
55 |     month = jun,
56 |     year = "2021",
57 |     address = "Online",
58 |     publisher = "Association for Computational Linguistics",
59 |     url = "https://www.aclweb.org/anthology/2021.naacl-main.282",
60 |     pages = "3596--3616",
61 | }
62 | 
63 | ```
64 | 
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | ## Security
2 | 
3 | Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com)
4 | as soon as it is discovered. This library limits its runtime dependencies in
5 | order to reduce the total cost of ownership as much as possible, but all consumers
6 | should remain vigilant and have their security stakeholders review all third-party
7 | products (3PP) like this one and their dependencies.
--------------------------------------------------------------------------------
/adversarial-training/codemixer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import random, json, string, warnings
3 | from typing import Union, List, Set, Dict
4 | from simalign import SentenceAligner
5 | from nltk.translate.phrase_based import phrase_extraction
6 | 
7 | 
8 | class CodeMixer(object):
9 |     def __init__(self, matrix_lg: str, embedded_lgs: Union[List[str],Set[str]], device='cuda', precomputed_phrases=False):
10 |         warnings.filterwarnings("ignore", category=FutureWarning)
11 |         if not precomputed_phrases:
12 |             self.aligner = SentenceAligner(model="xlmr", token_type="bpe", matching_methods="m", device=device)
13 |         self.rtl_lgs = {'ar', 'he'}
14 |         self.matrix_lg = matrix_lg
15 |         self.embedded_lgs = set(embedded_lgs)
16 | 
17 | 
18 |     def get_phrases(self, matrix_sentence: str, translations: Dict[str,str],
19 |                     sample_lgs: Union[List[str],Set[str]]=None):
20 |         if not sample_lgs:
21 |             sample_lgs = self.embedded_lgs
22 |         filtered_translations = {k: v for k,v in translations.items() if k in sample_lgs}
23 |         matrix_tokens = matrix_sentence.split()
24 |         phrases = []
25 |         for lg, embedded_sentence in filtered_translations.items():
26 |             tokenized_embedded = embedded_sentence.split()
27 |             alignments = self.aligner.get_word_aligns(matrix_tokens, tokenized_embedded)
28 |             candidate_phrases = phrase_extraction(matrix_sentence, embedded_sentence, alignments['mwmf'], 4)
29 |             candidate_phrases = [(p[0], p[1], p[2], p[3], lg) for p in candidate_phrases if p[2][0] not in string.punctuation]
30 |             if lg == 'zh':
31 |                 candidate_phrases = [(p[0], p[1], p[2], p[3].replace(' ', ''), p[4]) for p in candidate_phrases]
32 |             if lg == 'th':
33 |                 candidate_phrases = [(p[0], p[1], p[2], p[3].replace(' ', ' ').replace(' ,', ','), p[4]) for p in candidate_phrases]
34 | 
35 |             phrases += candidate_phrases
36 | 
37 |         # sort by matrix start position, longer phrases first
38 |         sorted_phrases = sorted(phrases, key=lambda x: (x[0][0], -x[0][1]))
39 | 
40 |         # group candidate phrases by the matrix token position where they start
41 |         grouped_phrases = {i:[] for i in range(len(matrix_tokens))}
42 |         for phrase in sorted_phrases:
43 |             grouped_phrases[phrase[0][0]].append(phrase)
44 | 
45 |         return grouped_phrases
46 | 
47 |     def swap_phrase(self, tokens, replace_start_idx, replace_end_idx, to_replace):
48 |         return tokens[0:replace_start_idx] + [to_replace] + tokens[replace_end_idx:]
49 | 
50 |     def get_weights(self, lg_counts):
51 |         filtered_lg_counts = {k: v for k,v in lg_counts.items() if k in self.embedded_lgs or k == self.matrix_lg}
52 | 
53 |         total_count = sum(filtered_lg_counts.values())
54 |         return {k: v/total_count for k,v in filtered_lg_counts.items()}
55 | 
56 |     def generate(self, sentence, reference_translations, probability=0.15, lg_counts: Dict[str,int]=None):
57 |         phrases = self.get_phrases(sentence, reference_translations)
58 |         return self.generate_precomputed_alignments(sentence, phrases, probability)
59 | 
60 |     def generate_precomputed_alignments(self, sentence, phrase_alignments, probability=0.15):#, lg_counts: Dict[str,int]=None):
61 |         tokens = sentence.split()
62 | 
63 |         token_length = len(tokens)
64 |         pos = 0
65 |         prev_lg = self.matrix_lg
66 |         prev_replacement_pos = pos
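67 |         # Left-to-right scan over the matrix sentence. Replacement indices are
68 |         # end-relative (idx - token_length): earlier swaps may shorten the token
69 |         # list, but the not-yet-visited tail keeps its distance from the end of
70 |         # the list, so negative indices stay valid without extra bookkeeping.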
71 |         while pos < token_length:
72 |             candidates = phrase_alignments.get(pos)
73 | 
74 |             pos += 1
75 |             if random.random() >= probability or not candidates:
76 |                 prev_lg = self.matrix_lg
77 |                 continue
78 | 
79 |             eligible_candidates = []
80 | 
81 |             for candidate in candidates:
82 |                 phrase_to_replace = candidate[2]
83 |                 replacement = candidate[3]
84 |                 replacement_lg = candidate[4]
85 |                 replace_start_idx = candidate[0][0] - token_length
86 |                 replace_end_idx = candidate[0][1] - token_length
87 | 
88 |                 # skip candidates whose matrix-side text no longer matches (span already perturbed)
89 |                 if phrase_to_replace.split() != tokens[replace_start_idx:replace_end_idx]:
90 |                     continue
91 |                 if replacement_lg not in self.rtl_lgs and replacement_lg == prev_lg and candidate[1][1] <= prev_replacement_pos:
92 |                     continue
93 |                 eligible_candidates.append(candidate)
94 | 
95 |             if eligible_candidates:
96 |                 chosen_candidate = random.choice(eligible_candidates)
97 | 
98 |                 replacement_lg = chosen_candidate[4]
99 |                 replacement = chosen_candidate[3]
100 |                 replace_start_idx = chosen_candidate[0][0] - token_length
101 |                 replace_end_idx = chosen_candidate[0][1] - token_length
102 | 
103 |                 tokens = self.swap_phrase(tokens, replace_start_idx, replace_end_idx, replacement)
104 |                 prev_lg = replacement_lg
105 |                 prev_replacement_pos = pos
106 |                 pos = max(chosen_candidate[0][1], pos)  # advance past the span that was just replaced
107 | 
108 |             else:
109 |                 prev_lg = self.matrix_lg
110 |                 continue
111 | 
112 |         return ' '.join(tokens)
--------------------------------------------------------------------------------
/adversarial-training/run_codemixer_nli.py:
--------------------------------------------------------------------------------
1 | from transformers import glue_processors as processors
2 | import csv, os, argparse, json, ray, time, torch, jsonlines, random
3 | from tqdm import tqdm
4 | from pathlib import Path
5 | from codemixer import CodeMixer
6 | from ray.util import ActorPool
7 | from math import ceil
8 | 
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument("--data", "-d", default=None, type=str, required=False, help="The input data file, e.g., 'data/MNLI/train.tsv'.")
11 | parser.add_argument("--output_dir", "-o", default=None, type=str, required=True, help="The output directory.")
12 | parser.add_argument("--emb_lgs", '-t', default='fr,es,de,zh,el,bg,ru,tr,ar,vi,th,hi,sw,ur', type=str, required=False, help="Embedded languages.")
13 | parser.add_argument("--seed", '-sd', default=42, type=int, required=False, help="Random seed")
14 | parser.add_argument("--sample_lgs", '-sl', default=2, type=int, required=False, help="Number of embedded languages per example.")
15 | parser.add_argument("--sample_prob", '-sp', default=0.33, type=float, required=False, help="Probability of modifying a sentence.")
16 | parser.add_argument("--perturb_prob", '-ptb', default=0.15, type=float, required=False, help="Probability of perturbing a word/phrase.")
17 | parser.add_argument("--device", default='cuda', type=str, required=False, help="Device to use {'tpu', 'cuda', 'cpu'}.")
18 | parser.add_argument("--lg_counts", default=None, type=str, required=False, help="Path to adversarial language distribution.")
19 | parser.add_argument("--extract_phrases", '-ep', action='store_true', required=False, help="Extract phrases only.")
20 | parser.add_argument("--phrase_alignments", '-pa', type=str, default=None, required=False, help="Path to extracted phrase alignments.")
21 | parser.add_argument("--split", '-s', default='train', type=str, required=False, help="Use the train or test data as the source.")
22 | parser.add_argument("--num_k", '-k', default=1, type=int, required=False, help="Number of perturbed examples per clean
example.") 23 | parser.add_argument("--gpu", '-g', default=0.25, type=float, required=False, help="GPU allocation per actor (% of one GPU). Total number of parallel actors is calculated using this value. Set to 0 to use CPU.") 24 | args = parser.parse_args() 25 | 26 | 27 | MTX_LG = 'en' 28 | 29 | emb_lgs = args.emb_lgs.split(',') 30 | 31 | # V100: 0.1428, A100: 32 | USE_CUDA = torch.cuda.is_available() and args.gpu > 0 33 | NUM_GPU_PER_ACTOR = args.gpu if USE_CUDA else 0 34 | NUM_ACTOR_CPU_ONLY = 80 35 | TOP_UP = 0 # Increase this value if number of generated examples is insufficient due to lack of candidates for some examples. We compensate by increasing k for other examples with more candidates. 36 | 37 | 38 | @ray.remote(num_gpus=NUM_GPU_PER_ACTOR) 39 | class CodemixActor(object): 40 | def __init__(self, mtx_lg, emb_lgs, lg_counts, device, actor_id, reference_translations=None): 41 | print(str(actor_id) + ' spawned') 42 | if args.phrase_alignments: 43 | self.mixer = CodeMixer(mtx_lg, emb_lgs, device, True) 44 | else: 45 | self.mixer = CodeMixer(mtx_lg, emb_lgs, device) 46 | 47 | self.refs = reference_translations 48 | self.lg_counts_dict = {k:v for k,v in lg_counts.items() if k in emb_lgs} if lg_counts else None 49 | self.lg_counts = list(self.lg_counts_dict.values()) if self.lg_counts_dict else None 50 | self.emb_lgs = list(self.lg_counts_dict.keys()) if self.lg_counts_dict else emb_lgs 51 | 52 | def mutate(self, batch, k, top_up=0): 53 | results = [] 54 | for example in tqdm(batch): 55 | 56 | if example.get('preserved', False): 57 | results.append({'sentence1': example['sentence1'], 58 | 'sentence2': example['sentence2'], 59 | 'gold_label': example['gold_label'], 60 | 'preserved': 1}) 61 | continue 62 | 63 | 64 | num_tries = 0 65 | results_set = set() 66 | 67 | adjusted_k = k+1 if top_up > 0 else k 68 | 69 | while len(results_set) < adjusted_k and num_tries < k*50: 70 | result = {} 71 | example['sentence1_phrases'] = {int(k): v for k,v in example['sentence1_phrases'].items()} 72 | example['sentence2_phrases'] = {int(k): v for k,v in example['sentence2_phrases'].items()} 73 | 74 | result['sentence1'] = self.mixer.generate_precomputed_alignments(example['sentence1'], 75 | example['sentence1_phrases'], 76 | args.perturb_prob) 77 | result['sentence1'] = result['sentence1'].replace(' ,', ',').replace(' .', '.')\ 78 | .replace(" '", "'").replace('( ', '(')\ 79 | .replace(' )', ')').replace(' ?', '?').replace(' !', '!') 80 | result['sentence2'] = self.mixer.generate_precomputed_alignments(example['sentence2'], 81 | example['sentence2_phrases'], 82 | args.perturb_prob) 83 | result['sentence2'] = result['sentence2'].replace(' ,', ',').replace(' .', '.')\ 84 | .replace(" '", "'").replace('( ', '(')\ 85 | .replace(' )', ')').replace(' ?', '?').replace(' !', '!') 86 | result['gold_label'] = example['gold_label'] 87 | if result['sentence1'] != example['sentence1'] or result['sentence2'] != example['sentence2']: 88 | result['preserved'] = 0 89 | else: 90 | result['preserved'] = 1 91 | 92 | if (result['sentence1'], result['sentence2']) not in results_set: 93 | results_set.add((result['sentence1'], result['sentence2'])) 94 | results.append(result) 95 | if len(results_set) > k: 96 | top_up -= 1 97 | num_tries += 1 98 | 99 | return results 100 | 101 | def extract_phrases(self, batch): 102 | results = [] 103 | for i, example in enumerate(tqdm(batch)): 104 | text_a_phrases = {} 105 | text_b_phrases = {} 106 | random.seed(args.seed+i+len(example.text_a)) 107 | if random.random() <= args.sample_prob: 108 | 
random.seed(args.seed+i+len(example.text_a))
109 |                 chosen_lgs = []
110 |                 tries = 0
111 |                 while len(list(filter(None, [self.refs[0][example.text_a][lg] for lg in chosen_lgs]))) < args.sample_lgs and tries < 5*args.sample_lgs:
112 |                     tries += 1
113 |                     chosen_lgs = random.choices(self.emb_lgs, weights=self.lg_counts, k=args.sample_lgs)
114 |                 text_a_phrases = self.mixer.get_phrases(example.text_a, self.refs[0][example.text_a], chosen_lgs)
115 | 
116 | 
117 |             random.seed(args.seed+i+2*len(example.text_b))
118 |             if random.random() <= args.sample_prob:
119 |                 random.seed(args.seed+i+2*len(example.text_b))
120 |                 chosen_lgs = [] #random.choices(self.emb_lgs, weights=self.lg_counts, k=args.sample_lgs)
121 |                 tries = 0
122 |                 while len(list(filter(None, [self.refs[1][example.text_b][lg] for lg in chosen_lgs]))) < args.sample_lgs and tries < 5*args.sample_lgs:
123 |                     tries += 1
124 |                     chosen_lgs = random.choices(self.emb_lgs, weights=self.lg_counts, k=args.sample_lgs)
125 |                 text_b_phrases = self.mixer.get_phrases(example.text_b, self.refs[1][example.text_b], chosen_lgs)
126 |             if example.label == 'contradictory':
127 |                 example.label = 'contradiction'
128 |             if text_a_phrases or text_b_phrases:
129 |                 preserved = 0
130 |             else:
131 |                 preserved = 1
132 |             results.append({'sentence1': example.text_a, 'sentence2': example.text_b,
133 |                             'sentence1_phrases': text_a_phrases, 'sentence2_phrases': text_b_phrases,
134 |                             'preserved': preserved, 'gold_label': example.label})
135 | 
136 |         return results
137 | 
138 | 
139 | def _create_output_data(examples):
140 |     #output = [['sentence1', 'sentence2', 'gold_label', 'preserved']]
141 |     output = []
142 |     for example in tqdm(examples, desc='Creating output'):
143 |         output_line = [example['sentence1'], example['sentence2'], example['gold_label'], example['preserved']]
144 |         output.append(output_line)
145 |     return output
146 | 
147 | 
148 | def _write_tsv(output, output_file):
149 |     with open(output_file, "w", encoding="utf-8-sig") as f:
150 |         writer = csv.writer(f, delimiter="\t", quotechar=None)
151 |         for row in tqdm(output, desc='Writing output'):
152 |             writer.writerow(row)
153 | 
154 | def get_examples(data_dir, split):
155 |     if split == 'train':
156 |         return processors['mnli']().get_train_examples(data_dir)
157 |     elif split == 'test':
158 |         return processors['mnli']().get_test_examples(data_dir)
159 | 
160 | def get_examples_w_phrases(data_file):
161 |     examples = []
162 |     with jsonlines.open(data_file, mode='r') as reader:
163 |         for example in reader:
164 |             examples.append(example)
165 |     return examples
166 | 
167 | if args.phrase_alignments:
168 |     reference_translations = None
169 |     examples = get_examples_w_phrases(args.phrase_alignments)
170 | else:
171 |     reference_translations = [json.load(open('xnli-'+args.split+'-sentence1-reference-translations-en-head.json','r')),
172 |                               json.load(open('xnli-'+args.split+'-sentence2-reference-translations-en-head.json','r'))]
173 |     examples = get_examples(args.data, args.split)
174 |     examples.reverse()
175 | 
176 | if args.lg_counts:
177 |     weight_flag = '.weighted'
178 |     lg_counts = json.load(open(args.lg_counts,'r'))
179 | else:
180 |     weight_flag = '.unweighted'
181 |     lg_counts = None
182 | 
183 | output_path = Path(args.output_dir, 'codemixed_mnli.'+'_'.join(emb_lgs)+weight_flag)
184 | 
185 | 
186 | 
187 | args.device = 'cpu' if args.device == 'cuda' and not torch.cuda.is_available() else args.device
188 | 
189 | num_actors = int(torch.cuda.device_count() // NUM_GPU_PER_ACTOR) if args.device == 'cuda' else (64 if args.device == 'tpu' else NUM_ACTOR_CPU_ONLY)
190 | print('Number of
CodeMixers:', num_actors) 191 | 192 | total_exs = len(examples) 193 | print(total_exs) 194 | len_per_batch = ceil(total_exs / num_actors) 195 | 196 | batches = [examples[i:i+len_per_batch] for i in range(0, total_exs, len_per_batch)] 197 | 198 | ray.init() 199 | 200 | 201 | 202 | actors = ActorPool([CodemixActor.remote(MTX_LG, emb_lgs, lg_counts, args.device, i, reference_translations) 203 | for i in range(num_actors)]) 204 | start = time.time() 205 | 206 | 207 | if args.phrase_alignments: 208 | results = list(actors.map(lambda actor, batch: actor.mutate.remote(batch, args.num_k, TOP_UP), batches)) 209 | time_taken = time.time() - start 210 | condition_name = '.'.join(args.phrase_alignments.split('/')[-1].split('.')[1:-1]) 211 | else: 212 | phrase_results = list(actors.map(lambda actor, batch: actor.extract_phrases.remote(batch), batches)) 213 | 214 | output_path.mkdir(parents=True, exist_ok=True) 215 | 216 | 217 | time_taken = time.time() - start 218 | condition_name = 'seed-' + str(args.seed) + '.smp_lgs-' + str(args.sample_lgs) + '.smp_prob-'+ str(args.sample_prob) 219 | combined_phrase_results = [ex for batch in phrase_results for ex in batch] 220 | output_file_phrases = Path(output_path, args.split+'-extracted_phrases.' + condition_name + '.jsonl') 221 | with jsonlines.open(output_file_phrases, mode='w') as writer: 222 | for result in tqdm(combined_phrase_results, desc='Writing output'): 223 | writer.write(result) 224 | if not args.extract_phrases: 225 | results = list(actors.map(lambda actor, batch: actor.mutate.remote(batch, args.num_k), phrase_results)) 226 | 227 | 228 | if not args.extract_phrases: 229 | results = [ex for batch in results for ex in batch] 230 | output_path = Path(output_path, condition_name + '.ptb_prob-' + str(args.perturb_prob) + '.k-' + str(args.num_k)) 231 | output_path.mkdir(parents=True, exist_ok=True) 232 | if args.split == 'train': 233 | output_file = Path(output_path, 'train.tsv') 234 | elif args.split == 'test': 235 | output_file = Path(output_path, 'test_matched.tsv') 236 | _write_tsv(_create_output_data(results), output_file) 237 | 238 | print("Time taken:", time_taken / 60) 239 | -------------------------------------------------------------------------------- /adversarial-training/run_codemixer_sa.py: -------------------------------------------------------------------------------- 1 | from transformers import glue_processors as processors 2 | from datasets import load_from_disk 3 | import csv, os, argparse, json, ray, time, torch, jsonlines, random 4 | from tqdm import tqdm 5 | from pathlib import Path 6 | from codemixer import CodeMixer 7 | from ray.util import ActorPool 8 | from math import ceil 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--data", "-d", default=None, type=str, required=False, help="The input data file, e.g., 'data/MNLI/train.tsv'.") 12 | parser.add_argument("--output_dir", "-o", default=None, type=str, required=True, help="The output directory.") 13 | parser.add_argument("--emb_lgs", '-t', default='es,hi', type=str, required=False, help="Embedded languages.") 14 | parser.add_argument("--seed", '-sd', default=42, type=int, required=False, help="Random seed") 15 | parser.add_argument("--sample_lgs", '-sl', default=2, type=int, required=False, help="Number of embedded languages per example.") 16 | parser.add_argument("--sample_prob", '-sp', default=1.0, type=float, required=False, help="Probability of modifying a sentence.") 17 | parser.add_argument("--perturb_prob", '-ptb', default=0.15, type=float, 
required=False, help="Probability of perturbing a word/phrase.") 18 | parser.add_argument("--device", default='cuda', type=str, required=False, help="Device to use {'tpu', 'cuda', 'cpu'}.") 19 | parser.add_argument("--lg_counts", default=None, type=str, required=False, help="Path to adversarial language distribution.") 20 | parser.add_argument("--extract_phrases", '-ep', action='store_true', required=False, help="Extract phrases only.") 21 | parser.add_argument("--phrase_alignments", '-pa', type=str, default=None, required=False, help="Path to extracted phrase alignments.") 22 | parser.add_argument("--split", '-s', default='train', type=str, required=False, help="Use the train or test data as the source.") 23 | parser.add_argument("--num_k", '-k', default=1, type=int, required=False, help="Number of perturbed examples per clean example.") 24 | 25 | args = parser.parse_args() 26 | 27 | 28 | MTX_LG = 'en' 29 | 30 | emb_lgs = args.emb_lgs.split(',') 31 | 32 | # V100: 0.1428, A100: 33 | NUM_GPU_PER_ACTOR = 0.5 34 | NUM_ACTOR_CPU_ONLY = 80 35 | TOP_UP = 0#4891 36 | 37 | 38 | @ray.remote(num_gpus=NUM_GPU_PER_ACTOR) 39 | class CodemixActor(object): 40 | def __init__(self, mtx_lg, emb_lgs, lg_counts, device, actor_id):#, reference_translations=None): 41 | print(str(actor_id) + ' spawned') 42 | if args.phrase_alignments: 43 | self.mixer = CodeMixer(mtx_lg, emb_lgs, device, True) 44 | else: 45 | self.mixer = CodeMixer(mtx_lg, emb_lgs, device) 46 | 47 | #self.refs = reference_translations 48 | self.lg_counts_dict = {k:v for k,v in lg_counts.items() if k in emb_lgs} if lg_counts else None 49 | self.lg_counts = list(self.lg_counts_dict.values()) if self.lg_counts_dict else None 50 | self.emb_lgs = list(self.lg_counts_dict.keys()) if self.lg_counts_dict else emb_lgs 51 | 52 | def mutate(self, batch, k, top_up=0): 53 | results = [] 54 | for example in tqdm(batch): 55 | 56 | if example.get('preserved', False): 57 | continue 58 | 59 | num_tries = 0 60 | results_set = set() 61 | 62 | adjusted_k = k+1 if top_up > 0 else k 63 | 64 | while len(results_set) < adjusted_k and num_tries < k*50: 65 | result = {} 66 | example['text_phrases'] = {int(key): [phrase_tuple for phrase_tuple in v if phrase_tuple[-1] in self.emb_lgs] 67 | for key,v in example['text_phrases'].items()} 68 | 69 | result['text'] = self.mixer.generate_precomputed_alignments(example['text'], 70 | example['text_phrases'], 71 | args.perturb_prob) 72 | result['text'] = result['text'].replace(' ,', ',').replace(' .', '.')\ 73 | .replace(" '", "'").replace('( ', '(')\ 74 | .replace(' )', ')').replace(' ?', '?').replace(' !', '!') 75 | 76 | result['label'] = example['label'] 77 | if result['text'] != example['text']: 78 | result['preserved'] = 0 79 | else: 80 | result['preserved'] = 1 81 | 82 | if result['text'] not in results_set: 83 | results_set.add(result['text']) 84 | results.append(result) 85 | if len(results_set) > k: 86 | top_up -= 1 87 | num_tries += 1 88 | 89 | return results 90 | 91 | def extract_phrases(self, batch): 92 | results = [] 93 | for i, example in enumerate(tqdm(batch)): 94 | text_phrases = {} 95 | random.seed(args.seed+i+len(example['text'])) 96 | if random.random() <= args.sample_prob: 97 | random.seed(args.seed+i+len(example['text'])) 98 | chosen_lgs = [] 99 | tries = 0 100 | while len(chosen_lgs) < args.sample_lgs and tries < 5*args.sample_lgs: 101 | tries += 1 102 | chosen_lgs = random.choices(self.emb_lgs, weights=self.lg_counts, k=args.sample_lgs) 103 | refs = {lg: example[lg] for lg in chosen_lgs} 104 | text_phrases = 
self.mixer.get_phrases(example['text'], refs, chosen_lgs) 105 | 106 | if text_phrases: 107 | preserved = 0 108 | else: 109 | preserved = 1 110 | results.append({'text': example['text'], 111 | 'text_phrases': text_phrases, 112 | 'preserved': preserved, 'label': example['label']}) 113 | 114 | return results 115 | 116 | 117 | def _create_output_data(examples): 118 | output = [] 119 | for example in tqdm(examples, desc='Creating output'): 120 | if example['preserved'] == 0: 121 | output_line = [example['text'], example['label']] 122 | output.append(output_line) 123 | return output 124 | 125 | 126 | def _write_tsv(output, output_file): 127 | with open(output_file, "w", encoding="utf-8-sig") as f: 128 | writer = csv.writer(f, delimiter="\t", quotechar='"') 129 | for row in tqdm(output, desc='Writing output'): 130 | writer.writerow(row) 131 | 132 | def get_examples(data_dir, split): 133 | return load_from_disk(data_dir)[split] 134 | 135 | def get_examples_w_phrases(data_file): 136 | examples = [] 137 | with jsonlines.open(data_file, mode='r') as reader: 138 | for example in reader: 139 | examples.append(example) 140 | return examples 141 | 142 | if args.phrase_alignments: 143 | reference_translations = None 144 | examples = get_examples_w_phrases(args.phrase_alignments) 145 | else: 146 | examples = load_from_disk(args.data)[args.split] 147 | 148 | if args.lg_counts: 149 | weight_flag = '.weighted' 150 | lg_counts = json.load(open(args.lg_counts,'r')) 151 | else: 152 | weight_flag = '.unweighted' 153 | lg_counts = None 154 | 155 | output_path = Path(args.output_dir, 'codemixed_sa.'+'_'.join(emb_lgs)+weight_flag) 156 | 157 | 158 | args.device = 'cpu' if args.device == 'cuda' and not torch.cuda.is_available() else args.device 159 | 160 | num_actors = int(torch.cuda.device_count() // NUM_GPU_PER_ACTOR) if args.device == 'cuda' else (64 if args.device == 'tpu' else NUM_ACTOR_CPU_ONLY) 161 | print('Number of CodeMixers:', num_actors) 162 | 163 | total_exs = len(examples) 164 | print(total_exs) 165 | len_per_batch = ceil(total_exs / num_actors) 166 | 167 | 168 | if not args.phrase_alignments: 169 | batches = [examples.select(range(i, i+len_per_batch)) for i in range(0, total_exs, len_per_batch)] 170 | else: 171 | batches = [examples[i:i+len_per_batch] for i in range(0, total_exs, len_per_batch)] 172 | 173 | ray.init() 174 | 175 | 176 | 177 | actors = ActorPool([CodemixActor.remote(MTX_LG, emb_lgs, lg_counts, args.device, i) 178 | for i in range(num_actors)]) 179 | start = time.time() 180 | 181 | 182 | if args.phrase_alignments: 183 | results = list(actors.map(lambda actor, batch: actor.mutate.remote(batch, args.num_k, TOP_UP), batches)) 184 | time_taken = time.time() - start 185 | condition_name = '.'.join(args.phrase_alignments.split('/')[-1].split('.')[1:-1]) 186 | else: 187 | phrase_results = list(actors.map(lambda actor, batch: actor.extract_phrases.remote(batch), batches)) 188 | 189 | output_path.mkdir(parents=True, exist_ok=True) 190 | 191 | 192 | time_taken = time.time() - start 193 | condition_name = 'seed-' + str(args.seed) + '.smp_lgs-' + str(args.sample_lgs) + '.smp_prob-'+ str(args.sample_prob) 194 | combined_phrase_results = [ex for batch in phrase_results for ex in batch] 195 | output_file_phrases = Path(output_path, args.split+'-extracted_phrases.' 
+ condition_name + '.jsonl') 196 | with jsonlines.open(output_file_phrases, mode='w') as writer: 197 | for result in tqdm(combined_phrase_results, desc='Writing output'): 198 | writer.write(result) 199 | if not args.extract_phrases: 200 | results = list(actors.map(lambda actor, batch: actor.mutate.remote(batch, args.num_k), phrase_results)) 201 | 202 | 203 | if not args.extract_phrases: 204 | results = [ex for batch in results for ex in batch] 205 | output_path = Path(output_path, condition_name + '.ptb_prob-' + str(args.perturb_prob) + '.k-' + str(args.num_k)) 206 | output_path.mkdir(parents=True, exist_ok=True) 207 | if args.split == 'train': 208 | output_file = Path(output_path, 'train.tsv') 209 | elif args.split == 'test': 210 | output_file = Path(output_path, 'test_matched.tsv') 211 | _write_tsv(_create_output_data(results), output_file) 212 | 213 | print("Time taken:", time_taken / 60) 214 | 215 | 216 | 217 | -------------------------------------------------------------------------------- /attacks/bumblebee.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from abc import ABCMeta, abstractmethod 3 | import json, random, torch 4 | from typing import Union, List, Set, Dict 5 | from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForQuestionAnswering, pipeline 6 | from transformers import BartTokenizer, BartTokenizerFast, RobertaTokenizer, RobertaTokenizerFast, XLMRobertaTokenizer 7 | from torch.nn import CrossEntropyLoss 8 | from sortedcontainers import SortedKeyList 9 | from math import copysign 10 | from simalign import SentenceAligner 11 | from nltk.translate.phrase_based import phrase_extraction 12 | import string, warnings 13 | from squad_utils import f1_max 14 | 15 | DEFAULT_IGNORE_WORDS = {'is','was','am','are','were','the','a', 'that'} 16 | 17 | class BumblebeeBase(metaclass=ABCMeta): 18 | def __init__(self, 19 | source_lg: str, 20 | target_lgs: Union[List[str],Set[str]], 21 | ignore_words: Union[List[str],Set[str]]=DEFAULT_IGNORE_WORDS, 22 | transliteration_map: Dict[str,str]=None 23 | ): 24 | 25 | self.aligner = SentenceAligner(model="xlmr", token_type="bpe", matching_methods="m") 26 | 27 | self.rtl_lgs = {'ar', 'he'} 28 | self.source_lg = source_lg 29 | self.target_lgs = set(target_lgs) 30 | self.ignore_words = set(ignore_words) 31 | self.transliteration_map = transliteration_map 32 | 33 | 34 | @abstractmethod 35 | def generate(self): 36 | return 37 | 38 | def transliterate_if_hindi(self, phrase, lg): 39 | if self.transliteration_map and lg == 'hi': 40 | return self.transliterate(phrase) 41 | else: 42 | return phrase 43 | 44 | def transliterate(self, phrase): 45 | words = phrase.split() 46 | new_words = [] 47 | for word in words: 48 | if word in self.transliteration_map: 49 | new_words.append(self.transliteration_map[word]) 50 | else: 51 | new_words.append(word) 52 | return ' '.join(new_words) 53 | 54 | 55 | def swap_phrase_simple(self, sentence, to_be_replaced, to_replace): 56 | return (' ' + sentence + ' ').replace(' ' + to_be_replaced + ' ', ' ' + to_replace + ' ', 1).strip() 57 | 58 | def swap_phrase(self, tokens, replace_start_idx, replace_end_idx, to_replace): 59 | return ' '.join(tokens[0:replace_start_idx] + [to_replace] + tokens[replace_end_idx:]) 60 | 61 | 62 | def swap_phrase_w_split_bwd(self, sentence, split_pos, to_be_replaced, to_replace): 63 | tokenized = sentence.split() 64 | back = ' '.join(tokenized[split_pos:]) 65 | front_tok = 
tokenized[:split_pos] 66 | 67 | new_back_tok = (' ' + back + ' ').replace(' ' + to_be_replaced + ' ', ' ' + to_replace + ' ', 1).strip().split() 68 | 69 | return ' '.join(front_tok + new_back_tok) 70 | 71 | 72 | def filter_equiv_constraint(self, phrases): 73 | sorted_phrases = sorted(phrases, key=lambda x: (x[0][0], x[0][1])) 74 | forward_constrained_candidates = [] 75 | curr_matrix_pos = sorted_phrases[0][0][1] 76 | curr_embedded_pos = sorted_phrases[0][1][1] 77 | for phrase in sorted_phrases: 78 | if phrase[0][0] > curr_matrix_pos and phrase[1][0] < curr_embedded_pos: 79 | continue 80 | curr_matrix_pos = phrase[0][1] 81 | curr_embedded_pos = phrase[1][1] 82 | forward_constrained_candidates.append(phrase) 83 | 84 | return forward_constrained_candidates 85 | 86 | def get_phrases(self, matrix_sentence: str, translations: Dict[str,str]): 87 | filtered_translations = {k: v for k,v in translations.items() if k in self.target_lgs} 88 | matrix_tokens = matrix_sentence.split() 89 | phrases = [] 90 | for lg, embedded_sentence in filtered_translations.items(): 91 | 92 | tokenized_embedded = embedded_sentence.split() 93 | alignments = self.aligner.get_word_aligns(matrix_tokens, tokenized_embedded) 94 | 95 | candidate_phrases = phrase_extraction(matrix_sentence, embedded_sentence, alignments['mwmf'], 4) 96 | # [matrix range, matrix text, embedded range, embedded text, embedded language] 97 | candidate_phrases = [(p[0], p[1], p[2], p[3], lg) for p in candidate_phrases if p[2][0] not in string.punctuation] 98 | if lg == 'zh': 99 | candidate_phrases = [(p[0], p[1], p[2], p[3].replace(' ', ''), p[4]) for p in candidate_phrases] 100 | if lg == 'th': 101 | candidate_phrases = [(p[0], p[1], p[2], p[3].replace(' ', ' ').replace(' ,', ','), p[4]) for p in candidate_phrases] 102 | if self.transliteration_map and lg == 'hi': 103 | candidate_phrases = [(p[0], p[1], p[2], self.transliterate(p[3]), p[4]) for p in candidate_phrases] 104 | 105 | phrases += candidate_phrases 106 | 107 | sorted_phrases = sorted(phrases, key=lambda x: (x[0][0], -x[0][1])) 108 | 109 | 110 | grouped_phrases = {i:[] for i in range(len(matrix_tokens))} 111 | for phrase in sorted_phrases: 112 | grouped_phrases[phrase[0][0]].append(phrase) 113 | 114 | return grouped_phrases 115 | 116 | 117 | class BumblebeePairSequenceClassification(BumblebeeBase): 118 | def __init__(self, source_lg, target_lgs, labels: List[str], transliteration_map: Dict[str,str]=None): 119 | super().__init__(source_lg, target_lgs, transliteration_map=transliteration_map) 120 | self.labels = labels 121 | 122 | def is_flipped(self, predicted, label): 123 | return predicted != label 124 | 125 | def generate(self, sentence1, sentence2, label_text, 126 | reference_translations: Union[List[Dict[str,Dict[str,str]]],Dict[str,Dict[str,str]]], 127 | beam_size=1, early_terminate=False): 128 | assert label_text in self.labels 129 | warnings.filterwarnings("ignore", category=FutureWarning) 130 | label = self.labels.index(label_text) 131 | orig_s1_tokens = sentence1.split() 132 | orig_s2_tokens = sentence2.split() 133 | num_queries = 1 134 | original_loss, init_predicted = self.get_loss(sentence1, sentence2, label) 135 | 136 | if self.is_flipped(init_predicted, label): 137 | return sentence1, sentence2, self.labels[init_predicted], {self.source_lg: -1}, num_queries, sentence1, sentence2, self.labels[init_predicted], {self.source_lg: -1} 138 | 139 | # search 140 | 141 | s1_reference_translations = reference_translations[0] 142 | s2_reference_translations = reference_translations[1] 143 
| 144 |         s1_phrases = self.get_phrases(sentence1, s1_reference_translations)
145 |         s2_phrases = self.get_phrases(sentence2, s2_reference_translations)
146 | 
147 | 
148 |         # init beam
149 | 
150 |         successful_candidates = SortedKeyList(key=lambda x: x[:-1])
151 | 
152 |         s1_token_length = len(orig_s1_tokens)
153 |         s2_token_length = len(orig_s2_tokens)
154 | 
155 | 
156 |         # (loss, pos1, pos2, predicted_label, s1_tokens, s2_tokens,
157 |         #  lg1, prev_s1_replacement_token_positions, lg2, prev_s2_replacement_token_positions)
158 |         '''
159 |         Start from the front
160 |         '''
161 |         init_lg_counts = {lg: 0 for lg in self.target_lgs}
162 |         curr_beam = [(original_loss, 0, 0,
163 |                       init_predicted, sentence1, sentence2,
164 |                       self.source_lg, (0, 1),
165 |                       self.source_lg, (0, 1),
166 |                       init_lg_counts)]
167 |         new_beam = SortedKeyList(key=lambda x: x[:-1])
168 |         early_terminate_flag = False
169 |         while curr_beam and (not early_terminate or not early_terminate_flag):
170 |             for prev_loss, curr_pos1, curr_pos2, prev_predicted, prev_s1, prev_s2, prev_lg1, \
171 |                     prev_s1_replacement_pos, prev_lg2, prev_s2_replacement_pos, lg_counts in curr_beam:
172 | 
173 |                 prev_s1_tokens = prev_s1.split()
174 |                 prev_s2_tokens = prev_s2.split()
175 | 
176 |                 s1_candidates = s1_phrases.get(curr_pos1)
177 |                 s2_candidates = s2_phrases.get(curr_pos2)
178 |                 # always keep branches that leave the current position unperturbed
179 |                 if curr_pos1+1 < s1_token_length:
180 |                     new_beam.add((original_loss, curr_pos1+1, curr_pos2,
181 |                                   init_predicted, prev_s1, prev_s2,
182 |                                   self.source_lg, (curr_pos1+1, curr_pos1+2),
183 |                                   prev_lg2, prev_s2_replacement_pos,
184 |                                   lg_counts))
185 |                 if curr_pos2+1 < s2_token_length:
186 |                     new_beam.add((original_loss, curr_pos1, curr_pos2+1,
187 |                                   init_predicted, prev_s1, prev_s2,
188 |                                   prev_lg1, prev_s1_replacement_pos,
189 |                                   self.source_lg, (curr_pos2+1, curr_pos2+2),
190 |                                   lg_counts))
191 | 
192 | 
193 |                 for s1_candidate in s1_candidates:
194 |                     phrase_to_replace = s1_candidate[2]
195 |                     replacement = s1_candidate[3]
196 |                     replacement_lg = s1_candidate[4]
197 |                     replace_start_idx = s1_candidate[0][0] - s1_token_length
198 |                     replace_end_idx = s1_candidate[0][1] - s1_token_length
199 |                     # skip candidates whose matrix-side text has already been perturbed
200 |                     if phrase_to_replace.split() != prev_s1_tokens[replace_start_idx:replace_end_idx]:
201 |                         continue
202 |                     # same-language swaps must move forward in the embedded translation (RTL scripts exempt)
203 |                     if replacement_lg not in self.rtl_lgs and replacement_lg == prev_lg1 and s1_candidate[1][1] <= prev_s1_replacement_pos[0]:
204 |                         continue
205 | 
206 |                     s1_perturbed = self.swap_phrase(prev_s1_tokens, replace_start_idx, replace_end_idx, replacement)
207 | 
208 |                     new_loss, new_predicted = self.get_loss(s1_perturbed, prev_s2, label)
209 |                     new_lg_counts = lg_counts.copy()
210 |                     new_lg_counts[replacement_lg] += 1
211 |                     if self.is_flipped(new_predicted, label):
212 |                         successful_candidates.add((new_loss, new_predicted, s1_perturbed, prev_s2, new_lg_counts))
213 |                         early_terminate_flag = True
214 |                         if early_terminate: break
215 |                     if curr_pos1+1 < s1_token_length:
216 |                         new_beam.add((new_loss, curr_pos1+1, curr_pos2, new_predicted, s1_perturbed, prev_s2,
217 |                                       replacement_lg, s1_candidate[1], prev_lg2, prev_s2_replacement_pos,
218 |                                       new_lg_counts))
219 |                     num_queries += 1
220 | 
221 |                 for s2_candidate in s2_candidates:
222 |                     phrase_to_replace = s2_candidate[2]
223 |                     replacement = s2_candidate[3]
224 |                     replacement_lg = s2_candidate[4]
225 |                     replace_start_idx = s2_candidate[0][0] - s2_token_length
226 |                     replace_end_idx = s2_candidate[0][1] - s2_token_length
227 | 
228 |                     if phrase_to_replace.split() != prev_s2_tokens[replace_start_idx:replace_end_idx]:
229 |                         continue
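230 |                     # same ordering constraint as for sentence1 above: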
and replacement_lg == prev_lg2 and s2_candidate[1][1] <= prev_s2_replacement_pos[0]: 232 | continue 233 | 234 | s2_perturbed = self.swap_phrase(prev_s2_tokens, replace_start_idx, replace_end_idx, replacement) 235 | 236 | new_loss, new_predicted = self.get_loss(prev_s1, s2_perturbed, label) 237 | new_lg_counts = lg_counts.copy() 238 | new_lg_counts[replacement_lg] += 1 239 | if self.is_flipped(new_predicted, label): 240 | successful_candidates.add((new_loss, new_predicted, prev_s1, s2_perturbed, new_lg_counts)) 241 | early_terminate_flag = True 242 | if early_terminate: break 243 | if curr_pos2+1 < s2_token_length: # advance within bounds, mirroring the sentence1 branch 244 | new_beam.add((new_loss, curr_pos1, curr_pos2+1, new_predicted, prev_s1, s2_perturbed, 245 | prev_lg1, prev_s1_replacement_pos, replacement_lg, s2_candidate[1], 246 | new_lg_counts)) 247 | num_queries += 1 248 | 249 | curr_beam = new_beam[-beam_size:] # trim beam 250 | new_beam = SortedKeyList(key=lambda x: x[:-1]) 251 | 252 | if successful_candidates: 253 | _, final_predicted, sentence1, sentence2, final_lg_counts = successful_candidates[-1] 254 | _, lowest_final_predicted, lowest_sentence1, lowest_sentence2, lowest_final_lg_counts = successful_candidates[0] 255 | else: 256 | final_predicted = init_predicted 257 | lowest_sentence1 = sentence1 258 | lowest_sentence2 = sentence2 259 | lowest_final_predicted = init_predicted 260 | final_lg_counts = {self.source_lg: -1} 261 | lowest_final_lg_counts = {self.source_lg: -1} 262 | 263 | return sentence1, sentence2, self.labels[final_predicted], final_lg_counts, num_queries, lowest_sentence1, lowest_sentence2, self.labels[lowest_final_predicted], lowest_final_lg_counts 264 | 265 | 266 | class BumblebeePairSequenceClassificationHF(BumblebeePairSequenceClassification): 267 | def __init__(self, model_path, source_lg: str, 268 | target_lgs: Union[List[str],Set[str]], 269 | labels: List[str], 270 | is_nli=False, 271 | use_cuda=True, 272 | transliteration_map: Dict[str,str]=None): 273 | super().__init__(source_lg, target_lgs, labels, transliteration_map=transliteration_map) 274 | if torch.cuda.is_available() and use_cuda: 275 | self.device = 'cuda' 276 | else: 277 | self.device = 'cpu' 278 | config = AutoConfig.from_pretrained(model_path) 279 | self.tokenizer = AutoTokenizer.from_pretrained(model_path) 280 | if is_nli and self.tokenizer.__class__ in (RobertaTokenizer, 281 | RobertaTokenizerFast, 282 | XLMRobertaTokenizer, 283 | BartTokenizer, 284 | BartTokenizerFast): 285 | # hack to handle roberta models from huggingface 286 | self.labels[1], self.labels[2] = self.labels[2], self.labels[1] 287 | 288 | self.model = AutoModelForSequenceClassification.from_pretrained(model_path, config=config) 289 | 290 | self.model.eval() 291 | self.model.to(self.device) 292 | self.loss_fn = CrossEntropyLoss() 293 | 294 | def get_loss(self, sentence1, sentence2, label, max_seq_len=128): 295 | logits, _ = self.model_predict(sentence1, sentence2, max_seq_len) 296 | label_tensor = torch.tensor([label]).to(self.device) 297 | loss = self.loss_fn(logits, label_tensor) 298 | 299 | return loss.item(), logits.argmax().item() 300 | 301 | def model_predict(self, sentence1, sentence2, max_seq_len=512): 302 | inputs = self.tokenizer.encode_plus(sentence1, sentence2, add_special_tokens=True, max_length=max_seq_len, truncation=True) 303 | input_ids = torch.tensor(inputs["input_ids"]).unsqueeze(0).to(self.device) 304 | if "token_type_ids" in inputs.keys(): 305 | token_type_ids = torch.tensor(inputs["token_type_ids"]).unsqueeze(0).to(self.device) 306 | else: 307 | # handle
XLM-R 308 | token_type_ids = torch.tensor([0]*len(input_ids)).unsqueeze(0).to(self.device) 309 | 310 | outputs = self.model(input_ids,token_type_ids=token_type_ids) 311 | logits = outputs[0] 312 | return logits, self.labels[logits.argmax().item()] 313 | 314 | 315 | class BumblebeeQuestionAnswering(BumblebeeBase): 316 | def is_flipped(self, predicted, answer_texts: List[str]): 317 | return f1_max(predicted, answer_texts) == 0 318 | 319 | def generate(self, question_dict, context, 320 | reference_translations: Dict[str,Dict[str,str]], 321 | beam_size=1): 322 | warnings.filterwarnings("ignore", category=FutureWarning) 323 | question = question_dict['question'] 324 | answer_dict = {} 325 | gold_starts = [ans['answer_start'] for ans in question_dict['answers']] 326 | gold_texts = [ans['text'] for ans in question_dict['answers']] 327 | gold_ends = [gold_starts[i]+len(text) for i, text in enumerate(gold_texts)] 328 | answer_dict['gold_char_spans'] = list(zip(gold_starts, gold_ends)) 329 | 330 | answer_dict['gold_texts'] = gold_texts 331 | 332 | 333 | num_queries = 1 334 | original_loss, init_predicted = self.get_loss(question, context, answer_dict) 335 | 336 | # search 337 | question_phrases = self.get_phrases(question, reference_translations) 338 | 339 | # init beam 340 | beam = SortedList() 341 | successful_candidates = SortedList() 342 | partially_successful_candidates = SortedList() 343 | 344 | orig_qn_tokens = question.split() 345 | qn_token_length = len(orig_qn_tokens) 346 | 347 | # (loss, pos, predicted_answer, qn_tokens) 348 | ''' 349 | Start from the front 350 | ''' 351 | beam.add((original_loss, 0, init_predicted, question, self.source_lg, (0, 1))) 352 | 353 | while beam: 354 | prev_loss, curr_pos, prev_predicted, prev_qn, prev_lg, prev_replacement_pos = beam.pop() 355 | prev_qn_tokens = prev_qn.split() 356 | 357 | qn_candidates = question_phrases.get(curr_pos) 358 | 359 | if curr_pos+1 < qn_token_length: 360 | beam.add((original_loss, curr_pos+1, init_predicted, prev_qn, self.source_lg, (curr_pos+1, curr_pos+2))) 361 | 362 | for qn_candidate in qn_candidates: 363 | phrase_to_replace = qn_candidate[2] 364 | replacement = qn_candidate[3] 365 | replacement_lg = qn_candidate[4] 366 | replace_start_idx = qn_candidate[0][0] - qn_token_length 367 | replace_end_idx = qn_candidate[0][1] - qn_token_length 368 | 369 | if phrase_to_replace.split() != prev_qn_tokens[replace_start_idx:replace_end_idx]: 370 | continue 371 | 372 | if replacement_lg not in self.rtl_lgs and replacement_lg == prev_lg and qn_candidate[1][1] <= prev_replacement_pos[0]: 373 | continue 374 | 375 | 376 | qn_perturbed = self.swap_phrase(prev_qn_tokens, replace_start_idx, replace_end_idx, replacement) 377 | 378 | new_loss, new_predicted = self.get_loss(qn_perturbed, context, answer_dict) 379 | if self.is_flipped(new_predicted, answer_dict['gold_texts']): 380 | successful_candidates.add((new_loss, new_predicted, qn_perturbed)) 381 | elif f1_max(new_predicted, answer_dict['gold_texts']) < 1.0: 382 | partially_successful_candidates.add((new_loss, new_predicted, qn_perturbed)) 383 | if curr_pos+1 < qn_token_length: 384 | beam.add((new_loss, curr_pos+1, new_predicted, qn_perturbed, replacement_lg, qn_candidate[1])) 385 | num_queries += 1 386 | 387 | beam = SortedList(beam[-beam_size:]) # trim beam 388 | 389 | if successful_candidates: 390 | _, final_predicted, question = successful_candidates[-1] 391 | _, lowest_final_predicted, lowest_question = successful_candidates[0] 392 | elif partially_successful_candidates: 393 | _, 
final_predicted, question = partially_successful_candidates[-1] 394 | _, lowest_final_predicted, lowest_question = partially_successful_candidates[0] 395 | else: 396 | final_predicted = init_predicted 397 | lowest_question = question 398 | lowest_final_predicted = init_predicted 399 | 400 | return question, final_predicted, f1_max(final_predicted, answer_dict['gold_texts']), lowest_question, lowest_final_predicted, f1_max(lowest_final_predicted, answer_dict['gold_texts']) 401 | 402 | 403 | def bwd_generate(self, question_dict, context, 404 | reference_translations: Dict[str,Dict[str,str]], 405 | beam_size=1): 406 | 407 | question = question_dict['question'] 408 | answer_dict = {} 409 | gold_starts = [ans['answer_start'] for ans in question_dict['answers']] 410 | gold_texts = [ans['text'] for ans in question_dict['answers']] 411 | gold_ends = [gold_starts[i]+len(text) for i, text in enumerate(gold_texts)] 412 | answer_dict['gold_char_spans'] = list(zip(gold_starts, gold_ends)) 413 | 414 | answer_dict['gold_texts'] = gold_texts 415 | 416 | 417 | num_queries = 1 418 | original_loss, init_predicted = self.get_loss(question, context, answer_dict) 419 | 420 | 421 | # search 422 | question_phrases = self.get_phrases(question, reference_translations) 423 | 424 | # init beam 425 | beam = SortedList() 426 | successful_candidates = SortedList() 427 | partially_successful_candidates = SortedList() 428 | 429 | orig_qn_tokens = question.split() 430 | qn_token_length = len(orig_qn_tokens) 431 | 432 | # (loss, pos, predicted_answer, qn_tokens) 433 | ''' 434 | Start from the back 435 | ''' 436 | beam.add((original_loss, qn_token_length-1, init_predicted, question)) 437 | 438 | while beam: 439 | prev_loss, curr_pos, prev_predicted, prev_qn = beam.pop() 440 | prev_qn_tokens = prev_qn.split() 441 | 442 | qn_candidates = question_phrases.get(curr_pos) 443 | 444 | if curr_pos-1 > -1: 445 | beam.add((original_loss, curr_pos-1, init_predicted, prev_qn)) 446 | 447 | 448 | for qn_candidate in qn_candidates: 449 | phrase_to_replace = qn_candidate[2] 450 | replacement = qn_candidate[3] 451 | 452 | if phrase_to_replace not in prev_qn: 453 | continue 454 | 455 | qn_perturbed = self.swap_phrase_w_split(prev_qn, curr_pos, phrase_to_replace, replacement) 456 | 457 | new_loss, new_predicted = self.get_loss(qn_perturbed, context, answer_dict) 458 | if self.is_flipped(new_predicted, answer_dict['gold_texts']): 459 | successful_candidates.add((new_loss, new_predicted, qn_perturbed)) 460 | elif f1_max(new_predicted, answer_dict['gold_texts']) < 0.5: 461 | partially_successful_candidates.add((new_loss, new_predicted, qn_perturbed)) 462 | if curr_pos-1 > -1: 463 | beam.add((new_loss, curr_pos-1, new_predicted, qn_perturbed)) 464 | num_queries += 1 465 | 466 | beam = SortedList(beam[-beam_size:]) # trim beam 467 | 468 | if successful_candidates: 469 | _, final_predicted, question = successful_candidates[-1] 470 | _, lowest_final_predicted, lowest_question = successful_candidates[0] 471 | elif partially_successful_candidates: 472 | _, final_predicted, question = partially_successful_candidates[-1] 473 | _, lowest_final_predicted, lowest_question = partially_successful_candidates[0] 474 | else: 475 | final_predicted = init_predicted 476 | lowest_question = question 477 | lowest_final_predicted = init_predicted 478 | 479 | return question, final_predicted, f1_max(final_predicted, answer_dict['gold_texts']), lowest_question, lowest_final_predicted, f1_max(lowest_final_predicted, answer_dict['gold_texts']) 480 | 481 | 482 | def 
get_lowest_loss(self, start_logits_tensor, end_logits_tensor, gold_spans): 483 | start_logits_tensor = start_logits_tensor.to('cpu') 484 | end_logits_tensor = end_logits_tensor.to('cpu') 485 | target_tensors = [(torch.tensor([gold_start]), torch.tensor([gold_end])) \ 486 | for gold_start, gold_end in gold_spans] 487 | 488 | losses = [] 489 | for target_start, target_end in target_tensors: 490 | avg_loss = (self.loss_fn(start_logits_tensor, target_start) \ 491 | + self.loss_fn(end_logits_tensor, target_end))/2 492 | losses.append(avg_loss) 493 | return min(losses).item() 494 | 495 | 496 | class BumblebeeQuestionAnsweringHF(BumblebeeQuestionAnswering): 497 | def __init__(self, model_path, source_lg: str, 498 | target_lgs: Union[List[str],Set[str]], 499 | use_cuda=True): 500 | super().__init__(source_lg, target_lgs) 501 | if torch.cuda.is_available() and use_cuda: 502 | device = 'cuda' 503 | else: 504 | device = 'cpu' 505 | self.qa_pipeline = pipeline('question-answering', model=model_path, 506 | tokenizer=model_path, device=device) 507 | 508 | 509 | def get_loss(self, question, context, answer_dict): 510 | result = self.qa_pipeline(question=question, context=context) 511 | return copysign(result['score'], -f1_max(result['answer'], answer_dict['gold_texts'])), result['answer'] 512 | -------------------------------------------------------------------------------- /attacks/polygloss.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | import json, random, torch 3 | from typing import Union, List, Set, Dict 4 | from tokenizers.pre_tokenizers import BertPreTokenizer 5 | from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification 6 | from transformers import BartTokenizer, BartTokenizerFast, RobertaTokenizer, RobertaTokenizerFast, XLMRobertaTokenizer 7 | from torch.nn import CrossEntropyLoss 8 | from sortedcontainers import SortedList 9 | from opencc import OpenCC 10 | 11 | 12 | class PolyTokenizer: 13 | contractions = {"won't": "will not", 14 | "n't": " not", 15 | "'ll": " will", 16 | "I'm": "I am", 17 | "she's": "she is", 18 | "he's": "he is", 19 | "'re": " are" 20 | } 21 | 22 | def __init__(self): 23 | self.tokenizer = BertPreTokenizer() 24 | 25 | def expand_en_contractions(self, sentence): 26 | new_sentence = sentence 27 | for contraction, expansion in self.contractions.items(): 28 | new_sentence = new_sentence.replace(contraction, expansion) 29 | return new_sentence 30 | 31 | def tokenize(self, sentence): 32 | expanded = self.expand_en_contractions(sentence) 33 | tokenized = self.tokenizer.pre_tokenize(expanded) 34 | tokens, char_ids = list(zip(*tokenized)) 35 | 36 | return list(tokens), char_ids 37 | 38 | 39 | def detokenize(self, tokens, char_ids): 40 | out_tokens = [] 41 | for i, token in enumerate(tokens): 42 | if i != 0 and char_ids[i-1][1] != char_ids[i][0]: 43 | out_tokens.append(' ') 44 | out_tokens.append(token) 45 | 46 | return ''.join(out_tokens) 47 | 48 | 49 | DEFAULT_IGNORE_WORDS = {'is','was','am','are','were','the','a', 'that'} 50 | 51 | 52 | class PolyglossBase(metaclass=ABCMeta): 53 | def __init__(self, 54 | source_lg: str, 55 | target_lgs: Union[List[str],Set[str]], 56 | map_path: str, 57 | ignore_words: Union[List[str],Set[str]]=DEFAULT_IGNORE_WORDS 58 | ): 59 | with open(map_path, 'r') as f: 60 | self.word_map = json.load(f) 61 | self.source_lg = source_lg 62 | self.target_lgs = set(target_lgs) 63 | self.ignore_words = set(ignore_words) 64 | self.polytokenizer = 
PolyTokenizer() 65 | self.cc = OpenCC('s2t') 66 | 67 | @abstractmethod 68 | def generate(self): 69 | pass 70 | 71 | def get_candidates(self, word, reference_translations=None): 72 | all_lgs_cands = self.word_map.get(word) 73 | if not all_lgs_cands: 74 | return [word] 75 | candidates = all_lgs_cands.items() 76 | 77 | filtered_candidates = [word] 78 | for lg, lst in candidates: 79 | if lg not in self.target_lgs: 80 | continue 81 | for w in lst: 82 | if not reference_translations or w in reference_translations[lg]: 83 | filtered_candidates.append(w) 84 | if lg =='zh': 85 | trad = self.cc.convert(w) 86 | if trad != w: filtered_candidates.append(trad) 87 | return filtered_candidates 88 | 89 | 90 | class PolyglossPairSequenceClassification(PolyglossBase): 91 | def __init__(self, source_lg, target_lgs, map_path, labels: List[str]): 92 | super().__init__(source_lg, target_lgs, map_path) 93 | self.labels = labels 94 | 95 | def is_flipped(self, predicted, label): 96 | return predicted != label 97 | 98 | def generate(self, sentence1, sentence2, label_text, beam_size=1, 99 | reference_translations: Union[List[Dict[str,Dict[str,str]]],Dict[str,Dict[str,str]]]=None): 100 | assert label_text in self.labels 101 | label = self.labels.index(label_text) 102 | orig_s1_tokens, orig_s1_char_ids = self.polytokenizer.tokenize(sentence1) 103 | orig_s2_tokens, orig_s2_char_ids = self.polytokenizer.tokenize(sentence2) 104 | num_queries = 1 105 | original_loss, init_predicted = self.get_loss(sentence1, sentence2, label) 106 | 107 | if self.is_flipped(init_predicted, label): 108 | return sentence1, sentence2, self.labels[init_predicted], num_queries, sentence1, sentence2, self.labels[init_predicted] 109 | 110 | # search 111 | 112 | # init beam 113 | beam = SortedList() 114 | successful_candidates = SortedList() 115 | 116 | s1_token_length = len(orig_s1_tokens) 117 | s2_token_length = len(orig_s2_tokens) 118 | 119 | 120 | if reference_translations: 121 | s1_reference_translations = reference_translations[0] 122 | s2_reference_translations = reference_translations[1] 123 | else: 124 | s1_reference_translations = None 125 | s2_reference_translations = None 126 | 127 | # (loss, pos1, pos2, predicted_label, s1_tokens, s2_tokens) 128 | beam.add((original_loss, 0, 0, init_predicted, orig_s1_tokens, orig_s2_tokens)) 129 | 130 | while beam: 131 | prev_loss, curr_pos1, curr_pos2, prev_predicted, prev_s1_tokens, prev_s2_tokens = beam.pop() 132 | s2_candidates = None 133 | 134 | s1_token_to_modify = prev_s1_tokens[curr_pos1] 135 | if s1_token_to_modify not in self.ignore_words: 136 | s1_candidates = self.get_candidates(s1_token_to_modify, s1_reference_translations) 137 | else: 138 | s1_candidates = [s1_token_to_modify] 139 | 140 | 141 | s2_token_to_modify = prev_s2_tokens[curr_pos2] 142 | if s2_token_to_modify not in self.ignore_words: 143 | s2_candidates = self.get_candidates(s2_token_to_modify, s2_reference_translations) 144 | else: 145 | s2_candidates = [s2_token_to_modify] 146 | 147 | 148 | for s1_candidate in s1_candidates: 149 | new_s1_tokens = prev_s1_tokens.copy() 150 | if new_s1_tokens[curr_pos1] == s1_candidate: 151 | if curr_pos1+1 < s1_token_length: 152 | beam.add((prev_loss, curr_pos1+1, curr_pos2, prev_predicted, prev_s1_tokens, prev_s2_tokens)) 153 | continue 154 | new_s1_tokens[curr_pos1] = s1_candidate 155 | s1_perturbed = self.polytokenizer.detokenize(new_s1_tokens, orig_s1_char_ids) 156 | s2_prev = self.polytokenizer.detokenize(prev_s2_tokens, orig_s2_char_ids) 157 | new_loss, new_predicted = 
self.get_loss(s1_perturbed, s2_prev, label) 158 | if self.is_flipped(new_predicted, label): 159 | successful_candidates.add((new_loss, new_predicted, new_s1_tokens, prev_s2_tokens)) 160 | if curr_pos1+1 < s1_token_length: 161 | beam.add((new_loss, curr_pos1+1, curr_pos2, new_predicted, new_s1_tokens, prev_s2_tokens)) 162 | num_queries += 1 163 | 164 | for s2_candidate in s2_candidates: 165 | new_s2_tokens = prev_s2_tokens.copy() 166 | if new_s2_tokens[curr_pos2] == s2_candidate: 167 | if curr_pos2+1 < s2_token_length: 168 | beam.add((prev_loss, curr_pos1, curr_pos2+1, prev_predicted, prev_s1_tokens, prev_s2_tokens)) 169 | continue 170 | new_s2_tokens[curr_pos2] = s2_candidate 171 | s1_prev = self.polytokenizer.detokenize(prev_s1_tokens, orig_s1_char_ids) 172 | s2_perturbed = self.polytokenizer.detokenize(new_s2_tokens, orig_s2_char_ids) 173 | new_loss, new_predicted = self.get_loss(s1_prev, s2_perturbed, label) 174 | if self.is_flipped(new_predicted, label): 175 | successful_candidates.add((new_loss, new_predicted, prev_s1_tokens, new_s2_tokens)) 176 | num_queries += 1 177 | if curr_pos2+1 < s2_token_length: 178 | beam.add((new_loss, curr_pos1, curr_pos2+1, new_predicted, prev_s1_tokens, new_s2_tokens)) 179 | 180 | beam = SortedList(beam[-beam_size:]) # trim beam 181 | 182 | if successful_candidates: 183 | _, final_predicted, final_s1_tokens, final_s2_tokens = successful_candidates[-1] 184 | sentence1 = self.polytokenizer.detokenize(final_s1_tokens, orig_s1_char_ids) 185 | sentence2 = self.polytokenizer.detokenize(final_s2_tokens, orig_s2_char_ids) 186 | _, lowest_final_predicted, lowest_final_s1_tokens, lowest_final_s2_tokens = successful_candidates[0] 187 | lowest_sentence1 = self.polytokenizer.detokenize(lowest_final_s1_tokens, orig_s1_char_ids) 188 | lowest_sentence2 = self.polytokenizer.detokenize(lowest_final_s2_tokens, orig_s2_char_ids) 189 | else: 190 | final_predicted = init_predicted 191 | lowest_sentence1 = sentence1 192 | lowest_sentence2 = sentence2 193 | lowest_final_predicted = init_predicted 194 | 195 | return sentence1, sentence2, self.labels[final_predicted], num_queries, lowest_sentence1, lowest_sentence2, self.labels[lowest_final_predicted] 196 | 197 | 198 | class PolyglossPairSequenceClassificationHF(PolyglossPairSequenceClassification): 199 | def __init__(self, model_path, source_lg: str, 200 | target_lgs: Union[List[str],Set[str]], 201 | map_path: str, 202 | labels: List[str], 203 | is_nli=False, 204 | use_cuda=True): 205 | super().__init__(source_lg, target_lgs, map_path, labels) 206 | if torch.cuda.is_available() and use_cuda: 207 | self.device = 'cuda' 208 | else: 209 | self.device = 'cpu' 210 | config = AutoConfig.from_pretrained(model_path) 211 | self.tokenizer = AutoTokenizer.from_pretrained(model_path) 212 | if is_nli and self.tokenizer.__class__ in (RobertaTokenizer, 213 | RobertaTokenizerFast, 214 | XLMRobertaTokenizer, 215 | BartTokenizer, 216 | BartTokenizerFast): 217 | # hack to handle roberta models from huggingface 218 | labels[1], labels[2] = labels[2], labels[1] 219 | 220 | self.model = AutoModelForSequenceClassification.from_pretrained(model_path, config=config) 221 | 222 | self.model.eval() 223 | self.model.to(self.device) 224 | self.loss_fn = CrossEntropyLoss() 225 | 226 | def get_loss(self, sentence1, sentence2, label, max_seq_len=128): 227 | logits, _ = self.model_predict(sentence1, sentence2, max_seq_len) 228 | label_tensor = torch.tensor([label]).to(self.device) 229 | loss = self.loss_fn(logits, label_tensor) 230 | 231 | return loss.item(), 
logits.argmax().item() 232 | 233 | def model_predict(self, sentence1, sentence2, max_seq_len=512): 234 | inputs = self.tokenizer.encode_plus(sentence1, sentence2, add_special_tokens=True, max_length=max_seq_len, truncation=True) 235 | input_ids = torch.tensor(inputs["input_ids"]).unsqueeze(0).to(self.device) 236 | if "token_type_ids" in inputs.keys(): 237 | token_type_ids = torch.tensor(inputs["token_type_ids"]).unsqueeze(0).to(self.device) 238 | else: 239 | # handle XLM-R 240 | token_type_ids = torch.tensor([0]*len(input_ids)).unsqueeze(0).to(self.device) 241 | 242 | outputs = self.model(input_ids,token_type_ids=token_type_ids) 243 | logits = outputs[0] 244 | return logits, self.labels[logits.argmax().item()] 245 | -------------------------------------------------------------------------------- /attacks/run_bumblebee_nli.py: -------------------------------------------------------------------------------- 1 | from transformers import glue_processors as processors 2 | import csv, os, argparse, ray, torch, time, json 3 | from pathlib import Path 4 | from bumblebee import BumblebeePairSequenceClassificationHF 5 | from tqdm import tqdm 6 | from ray.util import ActorPool 7 | from math import ceil 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--data", "-d", default=None, type=str, required=True, help="The input data directory, e.g., 'data/MNLI'.") 11 | parser.add_argument("--model", "-m", type=str, required=True) 12 | parser.add_argument("--output_dir", "-o", default=None, type=str, required=True, help="The output directory.") 13 | parser.add_argument("--mm", action='store_true', required=False, help="Use Mismatch dev data.") 14 | parser.add_argument("--split", '-s', default='test', type=str, required=False, help="Use the train, dev, or test data as the source.") 15 | parser.add_argument("--beam", '-b', default=1, type=int, required=False, help="Beam size.") 16 | parser.add_argument("--alpha", '-a', default=1, type=float, required=False, help="Alpha.") 17 | parser.add_argument("--tgt_langs", '-t', default='fr,es,de,zh,el,bg,ru,tr,ar,vi,th,hi,sw,ur', type=str, required=False, help="Embedded languages.") 18 | parser.add_argument("--transliterate", '-trl', action='store_true', required=False, help="Transliterate non-Latin scripts") 19 | parser.add_argument("--unsup", '-u', action='store_true', required=False, help="use unsupervised translations.") 20 | parser.add_argument("--gpu", '-g', default=0.25, type=float, required=False, help="GPU allocation per actor (% of one GPU). Total number of parallel actors is calculated using this value. 
Set to 0 to use CPU.") 21 | args = parser.parse_args() 22 | 23 | USE_CUDA = torch.cuda.is_available() and args.gpu > 0 24 | NUM_GPU_PER_ACTOR = args.gpu if USE_CUDA else 0 25 | SRC_LANG = 'en' 26 | TGT_LANGS = args.tgt_langs.split(',') 27 | LABELS = ["contradiction", "entailment", "neutral"] 28 | 29 | @ray.remote(num_gpus=NUM_GPU_PER_ACTOR) 30 | class PolyglotActor(object): 31 | def __init__(self, model, src_lang, tgt_langs, reference_translations, transliteration_map=None): 32 | self.polyglot = BumblebeePairSequenceClassificationHF(model, src_lang, tgt_langs, LABELS, 33 | is_nli=True, use_cuda=USE_CUDA, 34 | transliteration_map=transliteration_map) 35 | self.reference_translations = reference_translations 36 | 37 | def mutate(self, batch, beam): 38 | score = 0 39 | early_terminate = args.split == 'train' 40 | results = [] 41 | for example in tqdm(batch): 42 | prem_refs = self.reference_translations[0][example.text_a] 43 | hypo_refs = self.reference_translations[1][example.text_b] 44 | if example.label == 'contradictory': 45 | example.label = 'contradiction' 46 | refs = [prem_refs, hypo_refs] 47 | prem, hypo, text_label, lg_counts, _, \ 48 | lowest_prem, lowest_hypo, lowest_text_label, lowest_lg_counts = self.polyglot.generate(example.text_a, 49 | example.text_b, 50 | example.label, 51 | reference_translations=refs, 52 | beam_size=beam, 53 | early_terminate=early_terminate) 54 | if sum(lg_counts.values()) > 0: 55 | example.text_a = prem.replace(' ,', ',').replace(' .', '.')\ 56 | .replace(" '", "'").replace('( ', '(')\ 57 | .replace(' )', ')').replace(' ?', '?').replace(' !', '!') 58 | example.text_b = hypo.replace(' ,', ',').replace(' .', '.')\ 59 | .replace(" '", "'").replace('( ', '(')\ 60 | .replace(' )', ')').replace(' ?', '?').replace(' !', '!') 61 | 62 | example.text_a_lowest = lowest_prem.replace(' ,', ',').replace(' .', '.')\ 63 | .replace(" '", "'").replace('( ', '(')\ 64 | .replace(' )', ')').replace(' ?', '?').replace(' !', '!') 65 | example.text_b_lowest = lowest_hypo.replace(' ,', ',').replace(' .', '.')\ 66 | .replace(" '", "'").replace('( ', '(')\ 67 | .replace(' )', ')').replace(' ?', '?').replace(' !', '!') 68 | 69 | example.adv_label = text_label 70 | example.adv_label_lowest = lowest_text_label 71 | example.lg_counts = lg_counts 72 | if args.split == 'train': 73 | results.append(example) 74 | elif text_label == example.label: 75 | score +=1 76 | 77 | if args.split != 'train': # in testing mode we want to output all examples, not just the successfully perturbed ones, for easy evaluation 78 | results.append(example) 79 | 80 | return results, score 81 | 82 | 83 | def _create_output_data(examples, input_tsv_list): 84 | output = [] 85 | columns = {} 86 | for (i, line) in enumerate(input_tsv_list): 87 | output_line = line.copy() 88 | if i == 0: 89 | output_line.insert(-1, 'predicted_label') 90 | output_line.insert(-1, 'sentence1_lowest') 91 | output_line.insert(-1, 'sentence2_lowest') 92 | output_line.insert(-1, 'predicted_label_lowest') 93 | columns = {col:i for i, col in enumerate(output_line)} 94 | output.append(output_line) 95 | continue 96 | #output_line[4] = '-' 97 | #output_line[5] = '-' 98 | #output_line[6] = '-' 99 | #output_line[7] = '-' 100 | output_line[columns['sentence1']] = examples[i-1].text_a 101 | output_line[columns['sentence2']] = examples[i-1].text_b 102 | try: 103 | output_line.insert(-1, examples[i-1].adv_label) 104 | except AttributeError: 105 | output_line.insert(-1, '-') 106 | try: 107 | output_line.insert(-1, examples[i-1].text_a_lowest) 108 | 
output_line.insert(-1, examples[i-1].text_b_lowest) 109 | output_line.insert(-1, examples[i-1].adv_label_lowest) 110 | except AttributeError: 111 | output_line.insert(-1, '-') 112 | output_line.insert(-1, '-') 113 | output_line.insert(-1, '-') 114 | output.append(output_line) 115 | return output 116 | 117 | def _write_tsv(output, output_file): 118 | with open(output_file, "w", encoding="utf-8-sig") as f: 119 | writer = csv.writer(f, delimiter="\t", quotechar=None) 120 | for row in output: 121 | writer.writerow(row) 122 | 123 | def collate_lg_counts(examples, lgs): 124 | lg_counts = {lg:0 for lg in lgs} 125 | for example in examples: 126 | try: 127 | for lg, count in example.lg_counts.items(): 128 | lg_counts[lg] += count 129 | except AttributeError: 130 | continue 131 | return lg_counts 132 | 133 | 134 | def get_examples(data_dir, task, split): 135 | if split == 'dev': 136 | return processors[task]().get_dev_examples(data_dir) 137 | elif split == 'test': 138 | return processors[task]().get_test_examples(data_dir) 139 | elif split == 'train': 140 | return processors[task]().get_train_examples(data_dir) 141 | raise ValueError('Must be train, dev, or test') 142 | 143 | 144 | if args.unsup: 145 | unsup_param = '.unsup' 146 | else: 147 | unsup_param = '' 148 | 149 | if args.transliterate: 150 | transliterate_param = '.translit' 151 | else: 152 | transliterate_param = '' 153 | 154 | output_path = Path(args.output_dir, 155 | 'bumblebee-pair_' + \ 156 | args.data.split('/')[-1] + '.' + args.split + \ 157 | '.' + args.model.strip('/').split('/')[-1] + \ 158 | '.' + '_'.join(TGT_LANGS) + unsup_param + transliterate_param + \ 159 | '.beam-' + str(args.beam) + '.equiv_constr') 160 | 161 | 162 | output_file = args.split 163 | if args.mm: 164 | output_file += '_mismatched.tsv' 165 | input_tsv = processors['mnli-mm']()._read_tsv(Path(args.data, output_file)) 166 | examples = get_examples(args.data, 'mnli-mm', args.split) 167 | else: 168 | if args.split != 'train': 169 | output_file += '_matched' 170 | output_file += '.tsv' 171 | input_tsv = processors['mnli']()._read_tsv(Path(args.data, output_file)) 172 | examples = get_examples(args.data, 'mnli', args.split) 173 | 174 | output_file = str(Path(output_path, output_file)) 175 | print('Output file path:', output_file) 176 | 177 | if args.transliterate: 178 | transliteration_map = json.load(open('en-hi.transliterations.json','r')) 179 | else: 180 | transliteration_map = None 181 | 182 | if args.unsup: 183 | reference_translations = [json.load(open('../dictionaries/xnli-unsup-sentence1-reference-translations-en-head.json','r')), 184 | json.load(open('../dictionaries/xnli-unsup-sentence2-reference-translations-en-head.json','r'))] 185 | else: 186 | reference_translations = [json.load(open('../dictionaries/xnli-'+args.split+'-sentence1-reference-translations-en-head.json','r')), 187 | json.load(open('../dictionaries/xnli-'+args.split+'-sentence2-reference-translations-en-head.json','r'))] 188 | 189 | 190 | num_actors = int(torch.cuda.device_count() // NUM_GPU_PER_ACTOR) if USE_CUDA else 10 191 | print('Number of Polyglots:', num_actors) 192 | 193 | total_exs = len(examples) 194 | print(total_exs) 195 | len_per_batch = ceil(total_exs / num_actors) 196 | 197 | batches = [examples[i:i+len_per_batch] for i in range(0, total_exs, len_per_batch)] 198 | 199 | ray.init() 200 | actors = ActorPool([PolyglotActor.remote(args.model, SRC_LANG, TGT_LANGS, reference_translations, transliteration_map) 201 | for i in range(num_actors)]) 202 | start = time.time() 203 | results, 
scores = map(list, zip(*actors.map(lambda actor, batch: actor.mutate.remote(batch, args.beam), batches))) 204 | time_taken = time.time() - start 205 | results = [ex for batch in results for ex in batch] 206 | 207 | print("Acc:", str(sum(scores)/total_exs * 100)) 208 | print("Time taken:", time_taken / 60) 209 | print("Output: ", str(output_path)) 210 | output_path.mkdir(parents=True, exist_ok=True) 211 | _write_tsv(_create_output_data(results, input_tsv), output_file) 212 | 213 | lg_counts = collate_lg_counts(results, TGT_LANGS) 214 | json.dump(lg_counts, open(Path(output_path, 'lg_counts.json'),'w')) 215 | print("Adversarial language distribution: ", str(lg_counts)) 216 | with open(str(Path(output_path, args.split+'_results.txt')), 'w') as t: 217 | t.write('Acc: '+str(sum(scores)/total_exs * 100)+'\n') 218 | t.write("Time taken: "+str(time_taken / 60)) 219 | -------------------------------------------------------------------------------- /attacks/run_bumblebee_qa.py: -------------------------------------------------------------------------------- 1 | import os, argparse, ray, torch, time, json, copy 2 | from pathlib import Path 3 | from bumblebee import BumblebeeQuestionAnsweringHF 4 | from tqdm import tqdm 5 | from ray.util import ActorPool 6 | from math import ceil 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--data", "-d", default=None, type=str, required=True, help="The input data file, e.g., 'data/xquad.en.json'.") 10 | parser.add_argument("--model", "-m", type=str, required=True) 11 | parser.add_argument("--output_dir", "-o", default=None, type=str, required=True, help="The output directory.") 12 | parser.add_argument("--beam", '-b', default=0, type=int, required=False, help="Beam size.") 13 | parser.add_argument("--tgt_langs", '-t', default='ar,de,el,es,hi,ru,th,tr,vi,zh', type=str, required=False, help="Embedded languages.") 14 | parser.add_argument("--gpu", '-g', default=0.25, type=float, required=False, help="GPU allocation per actor (% of one GPU). Total number of parallel actors is calculated using this value.
Set to 0 to use CPU.") 15 | args = parser.parse_args() 16 | 17 | 18 | USE_CUDA = torch.cuda.is_available() and args.gpu > 0 19 | NUM_GPU_PER_ACTOR = args.gpu if USE_CUDA else 0 20 | SRC_LANG = 'en' 21 | TGT_LANGS = args.tgt_langs.split(',') 22 | 23 | @ray.remote(num_gpus=NUM_GPU_PER_ACTOR) 24 | class PolyglotActor(object): 25 | def __init__(self, model, src_lang, tgt_langs, reference_translations): 26 | self.polyglot = BumblebeeQuestionAnsweringHF(model, src_lang, tgt_langs, use_cuda=USE_CUDA) 27 | self.reference_translations = reference_translations 28 | 29 | def mutate(self, batch, beam): 30 | score = 0 31 | perturbed_questions = {"highest": {}, 32 | "lowest": {}} 33 | 34 | 35 | for question_dict, context in tqdm(batch): 36 | refs = self.reference_translations[question_dict['question']] 37 | 38 | question, text_label, f1, lowest_question, lowest_text_label, lowest_f1 = self.polyglot.generate(question_dict, 39 | context, 40 | reference_translations=refs, 41 | beam_size=beam) 42 | perturbed_questions["highest"][question_dict['id']] = question 43 | perturbed_questions["lowest"][question_dict['id']] = lowest_question 44 | score += f1 45 | return perturbed_questions, score 46 | 47 | 48 | def _create_output_data(questions, input_data_path): 49 | input_data = json.load(open(input_data_path)) 50 | data = copy.deepcopy(input_data) 51 | for i, article in enumerate(input_data['data']): 52 | for j, paragraph in enumerate(article['paragraphs']): 53 | for k, qa in enumerate(paragraph['qas']): 54 | data['data'][i]['paragraphs'][j]['qas'][k]['question'] = questions[qa['id']] 55 | return data 56 | 57 | def get_examples(data_path): 58 | input_data = json.load(open(data_path)) 59 | examples = [] 60 | for i, article in enumerate(input_data['data']): 61 | for j, paragraph in enumerate(article['paragraphs']): 62 | for k, qa in enumerate(paragraph['qas']): 63 | examples.append((qa, paragraph['context'])) 64 | return examples 65 | 66 | 67 | output_path = Path(args.output_dir, 68 | 'bumblebee-squad.' + \ 69 | args.model.strip('/').split('/')[-1] + \ 70 | '.' + '_'.join(TGT_LANGS) + \ 71 | '.beam-' + str(args.beam) + '.equiv_constr'+ '.pipe') 72 | 73 | 74 | output_file = 'bumblebee.' 
+ args.data.strip('/').split('/')[-1].split('.')[0] 75 | 76 | output_file = str(Path(output_path, output_file)) 77 | print('Output file path:', output_file) 78 | 79 | 80 | examples = get_examples(args.data) 81 | 82 | reference_translations = json.load(open('xquad-question-reference-translations-en_head-th_zh_ws_tokenized.json','r')) 83 | 84 | 85 | num_actors = int(torch.cuda.device_count() // NUM_GPU_PER_ACTOR) if USE_CUDA else 15 86 | print('Number of Polyglots:', num_actors) 87 | 88 | total_exs = len(examples) 89 | print(total_exs) 90 | len_per_batch = ceil(total_exs / num_actors) 91 | 92 | batches = [examples[i:i+len_per_batch] for i in range(0, total_exs, len_per_batch)] 93 | 94 | ray.init() 95 | actors = ActorPool([PolyglotActor.remote(args.model, SRC_LANG, TGT_LANGS, reference_translations) 96 | for i in range(num_actors)]) 97 | 98 | 99 | start = time.time() 100 | results, scores = map(list, zip(*actors.map(lambda actor, batch: actor.mutate.remote(batch, args.beam), batches))) 101 | time_taken = time.time() - start 102 | results = {'highest': {qid: ex for batch in results for qid, ex in batch['highest'].items()}, 103 | 'lowest': {qid: ex for batch in results for qid, ex in batch['lowest'].items()}} 104 | 105 | print("F1:", str(sum(scores)/total_exs * 100)) 106 | print("Time taken:", time_taken / 60) 107 | print("Output: ", str(output_path)) 108 | output_path.mkdir(parents=True, exist_ok=True) 109 | 110 | for loss in ['highest', 'lowest']: 111 | with open(output_file+'.'+loss+'.json', 'w') as outf: 112 | json.dump(_create_output_data(results[loss], args.data), outf, ensure_ascii=False, indent=4) 113 | 114 | with open(str(Path(output_path, 'results.txt')), 'w') as t: 115 | t.write('F1: '+str(sum(scores)/total_exs * 100)+'\n') 116 | t.write("Time taken: "+str(time_taken / 60)+'\n') 117 | -------------------------------------------------------------------------------- /attacks/run_polygloss_nli.py: -------------------------------------------------------------------------------- 1 | from transformers import glue_processors as processors 2 | import csv, os, argparse, ray, torch, time, json 3 | from pathlib import Path 4 | from polygloss import PolyglossPairSequenceClassificationHF 5 | from tqdm import tqdm 6 | from ray.util import ActorPool 7 | from math import ceil 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--data", "-d", default=None, type=str, required=True, help="The input data directory, e.g., 'data/MNLI'.") 11 | parser.add_argument("--model", "-m", type=str, required=True) 12 | parser.add_argument("--output_dir", "-o", default=None, type=str, required=True, help="The output directory.") 13 | parser.add_argument("--mm", action='store_true', required=False, help="Use Mismatch dev data.") 14 | parser.add_argument("--split", '-s', default='test', type=str, required=False, help="Use the dev or test data as the source.") 15 | parser.add_argument("--beam", '-b', default=0, type=int, required=False, help="Beam size.") 16 | parser.add_argument("--tgt_langs", '-t', default='zh,hi,fr,de,tr', type=str, required=False, help="Comma separated list of embedded languages.") 17 | parser.add_argument("--use_reference_translations", '-r', action='store_true', required=False, help="Filter candidates with reference translations.") 18 | parser.add_argument("--simplified_zh", '-szh', action='store_true', required=False, help="Use simplified Chinese dict") 19 | parser.add_argument("--gpu", '-g', default=0.33, type=float, required=False, help="GPU allocation per actor (% of one GPU). 
Total number of parallel actors is calculated using this value. Set to 0 to use CPU.") 20 | 21 | args = parser.parse_args() 22 | 23 | USE_CUDA = torch.cuda.is_available() and args.gpu > 0 24 | NUM_GPU_PER_ACTOR = args.gpu if USE_CUDA else 0 # set gpu usage 25 | SRC_LANG = 'en' # matrix language 26 | LABELS = ["contradiction", "entailment", "neutral"] 27 | 28 | @ray.remote(num_gpus=NUM_GPU_PER_ACTOR) 29 | class PolyglotActor(object): 30 | def __init__(self, model, src_lang, tgt_langs, src_tgts_map, reference_translations=None): 31 | self.polyglot = PolyglossPairSequenceClassificationHF(model, src_lang, tgt_langs, src_tgts_map, LABELS, is_nli=True, use_cuda=USE_CUDA) 32 | self.reference_translations = reference_translations 33 | 34 | def mutate(self, batch, beam): 35 | score = 0 36 | for example in tqdm(batch): 37 | prem_refs = self.reference_translations[0][example.text_a] if self.reference_translations else None 38 | hypo_refs = self.reference_translations[1][example.text_b] if self.reference_translations else None 39 | refs = [prem_refs, hypo_refs] 40 | prem, hypo, text_label, _, lowest_prem, lowest_hypo, lowest_text_label = self.polyglot.generate(example.text_a, 41 | example.text_b, 42 | example.label, 43 | beam_size=beam, 44 | reference_translations=refs) 45 | if text_label != example.label: 46 | example.text_a = prem 47 | example.text_b = hypo 48 | 49 | example.text_a_lowest = lowest_prem 50 | example.text_b_lowest = lowest_hypo 51 | 52 | example.adv_label = text_label 53 | example.adv_label_lowest = lowest_text_label 54 | else: 55 | score +=1 56 | return batch, score 57 | 58 | def _create_output_data(examples, input_tsv_list): 59 | output = [] 60 | columns = {} 61 | for (i, line) in enumerate(input_tsv_list): 62 | output_line = line.copy() 63 | if i == 0: 64 | output_line.insert(-1, 'predicted_label') 65 | output_line.insert(-1, 'sentence1_lowest') 66 | output_line.insert(-1, 'sentence2_lowest') 67 | output_line.insert(-1, 'predicted_label_lowest') 68 | columns = {col:i for i, col in enumerate(output_line)} 69 | output.append(output_line) 70 | continue 71 | output_line[4] = '-' 72 | output_line[5] = '-' 73 | output_line[6] = '-' 74 | output_line[7] = '-' 75 | output_line[columns['sentence1']] = examples[i-1].text_a 76 | output_line[columns['sentence2']] = examples[i-1].text_b 77 | try: 78 | output_line.insert(-1, examples[i-1].adv_label) 79 | except AttributeError: 80 | output_line.insert(-1, '-') 81 | try: 82 | output_line.insert(-1, examples[i-1].text_a_lowest) 83 | output_line.insert(-1, examples[i-1].text_b_lowest) 84 | output_line.insert(-1, examples[i-1].adv_label_lowest) 85 | except AttributeError: 86 | output_line.insert(-1, '-') 87 | output_line.insert(-1, '-') 88 | output_line.insert(-1, '-') 89 | output.append(output_line) 90 | return output 91 | 92 | def _write_tsv(output, output_file): 93 | with open(output_file, "w", encoding="utf-8-sig") as f: 94 | writer = csv.writer(f, delimiter="\t", quotechar=None) 95 | for row in output: 96 | writer.writerow(row) 97 | 98 | def get_examples(data_dir, task, split): 99 | if split == 'dev': 100 | return processors[task]().get_dev_examples(data_dir) 101 | if split == 'test': 102 | return processors[task]().get_test_examples(data_dir) 103 | raise ValueError('Must be dev or test') 104 | 105 | 106 | TGT_LANGS = args.tgt_langs.split(',') 107 | 108 | reference_translations = None 109 | refs_var = 'no_ref' 110 | if args.use_reference_translations: 111 | reference_translations = 
[json.load(open('../dictionaries/xnli-'+args.split+'-sentence1-reference-translations-en-head.json','r')), 112 | json.load(open('../dictionaries/xnli-'+args.split+'-sentence2-reference-translations-en-head.json','r'))] 113 | refs_var = 'ref_constrained' 114 | 115 | output_path = Path(args.output_dir, 116 | 'polygloss_pairseqcls_' + \ 117 | args.data.strip('/').split('/')[-1] + \ 118 | '.' + args.model.split('/')[-1] + \ 119 | '.' + '_'.join(TGT_LANGS) + \ 120 | '.beam-' + str(args.beam) + \ 121 | '.' + refs_var) 122 | output_path.mkdir(parents=True, exist_ok=True) 123 | 124 | output_file = args.split + '_' 125 | if args.mm: 126 | output_file += 'mismatched' 127 | else: 128 | output_file += 'matched' 129 | output_file += '.tsv' 130 | output_file = str(Path(output_path, output_file)) 131 | print('Output file path:', output_file) 132 | 133 | if args.mm: 134 | input_tsv = processors['mnli-mm']()._read_tsv(args.data+'/'+args.split +'_mismatched.tsv') 135 | examples = get_examples(args.data, 'mnli-mm', args.split) 136 | else: 137 | input_tsv = processors['mnli']()._read_tsv(args.data+'/'+args.split +'_matched.tsv') 138 | examples = get_examples(args.data, 'mnli', args.split) 139 | 140 | if args.simplified_zh: 141 | word_map = 'en_to_all_map_simplified_zh.json' 142 | else: 143 | word_map = 'en_to_all_map.json' 144 | 145 | 146 | num_actors = int(torch.cuda.device_count() // NUM_GPU_PER_ACTOR) if USE_CUDA else int(25 // max(1, 0.5 * args.beam)) 147 | print('Number of Polyglots:', num_actors) 148 | 149 | total_exs = len(examples) 150 | print(total_exs) 151 | len_per_batch = ceil(total_exs / num_actors) 152 | 153 | batches = [examples[i:i+len_per_batch] for i in range(0, total_exs, len_per_batch)] 154 | 155 | ray.init() 156 | actors = ActorPool([PolyglotActor.remote(args.model, SRC_LANG, TGT_LANGS, word_map, reference_translations) 157 | for i in range(num_actors)]) 158 | start = time.time() 159 | results, scores = map(list, zip(*actors.map(lambda actor, batch: actor.mutate.remote(batch, args.beam), batches))) 160 | time_taken = time.time() - start 161 | results = [ex for batch in results for ex in batch] 162 | 163 | print("Acc:", str(sum(scores)/total_exs * 100)) 164 | print("Time taken:", time_taken / 60) 165 | print("Output: ", str(output_path)) 166 | 167 | _write_tsv(_create_output_data(results, input_tsv), output_file) 168 | with open(str(Path(output_path, 'time_taken.txt')), 'w') as t: 169 | t.write('Acc: '+str(sum(scores)/total_exs * 100)+'\n') 170 | t.write("Time taken:"+str(time_taken / 60)) 171 | -------------------------------------------------------------------------------- /attacks/squad_utils.py: -------------------------------------------------------------------------------- 1 | # SQuAD evaluation script 2 | import collections, re, string 3 | 4 | def normalize_answer(s): 5 | """Lower text and remove punctuation, articles and extra whitespace.""" 6 | def remove_articles(text): 7 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) 8 | return re.sub(regex, ' ', text) 9 | 10 | def white_space_fix(text): 11 | return ' '.join(text.split()) 12 | 13 | def remove_punc(text): 14 | exclude = set(string.punctuation) 15 | return ''.join(ch for ch in text if ch not in exclude) 16 | 17 | def lower(text): 18 | return text.lower() 19 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 20 | 21 | 22 | def get_tokens(s): 23 | if not s: 24 | return [] 25 | return normalize_answer(s).split() 26 | 27 | 28 | def compute_exact(a_gold, a_pred): 29 | return int(normalize_answer(a_gold) == 
normalize_answer(a_pred)) 30 | 31 | 32 | def compute_f1(a_gold, a_pred): 33 | gold_toks = get_tokens(a_gold) 34 | pred_toks = get_tokens(a_pred) 35 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks) 36 | num_same = sum(common.values()) 37 | if len(gold_toks) == 0 or len(pred_toks) == 0: 38 | return int(gold_toks == pred_toks) 39 | if num_same == 0: 40 | return 0 41 | precision = 1.0 * num_same / len(pred_toks) 42 | recall = 1.0 * num_same / len(gold_toks) 43 | f1 = (2 * precision * recall) / (precision + recall) 44 | return f1 45 | 46 | 47 | def metric_max(metric, pred, labels): 48 | scores = [] 49 | for label in labels: 50 | score = metric(pred, label) 51 | scores.append(score) 52 | return max(scores) 53 | 54 | 55 | def f1_max(pred, labels): 56 | return metric_max(compute_f1, pred, labels) 57 | -------------------------------------------------------------------------------- /scripts/create_polygloss_dictionaries_from_muse.py: -------------------------------------------------------------------------------- 1 | import json 2 | from os import listdir 3 | from os.path import isfile, join 4 | from opencc import OpenCC 5 | cc = OpenCC('t2s') 6 | 7 | ''' 8 | Output format 9 | { 10 | "word": { 11 | "lg": [word, word, word], 12 | "lg": [word, word, word], 13 | } 14 | } 15 | ''' 16 | 17 | # get all files in dictionaries and lg ids, each dictionary should be named en-<lg>.txt 18 | dict_path = './dictionaries' 19 | dict_files = {f.split('.')[0].split('-')[1]:f for f in listdir(dict_path) if isfile(join(dict_path, f))} 20 | print(dict_files) 21 | 22 | en_to_all_map = {} 23 | 24 | # populate en_to_all_map, use simplified zh 25 | for lg_id, file in dict_files.items(): 26 | with open(join(dict_path, file),'r') as fin: 27 | for line in fin: 28 | if '\t' in line: 29 | en_word, other_word = line.strip().split('\t') 30 | else: 31 | en_word, other_word = line.strip().split() 32 | 33 | if en_word == other_word: 34 | continue 35 | print(en_word, other_word) 36 | if not en_to_all_map.get(en_word): 37 | en_to_all_map[en_word] = {} 38 | if not en_to_all_map[en_word].get(lg_id): 39 | en_to_all_map[en_word][lg_id] = [] 40 | if lg_id == 'zh': 41 | other_word = cc.convert(other_word) 42 | print(other_word) 43 | en_to_all_map[en_word][lg_id].append(other_word) 44 | 45 | 46 | 47 | json.dump(en_to_all_map,open('en_to_all_map_simplified_zh.json','w'),indent=4,ensure_ascii=False) 48 |
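A minimal sketch of how the generated en_to_all_map JSON can be consumed, mirroring the lookup that PolyglossBase.get_candidates performs in attacks/polygloss.py; the word 'cat' and its translations below are illustrative placeholders, not entries taken from the real MUSE dictionaries:

    import json

    # an entry might look like: {"cat": {"fr": ["chat"], "de": ["Katze"], "zh": ["猫"]}}
    en_to_all_map = json.load(open('en_to_all_map_simplified_zh.json'))

    target_lgs = {'fr', 'zh'}
    word = 'cat'
    candidates = [word]  # always keep the original word as a no-op candidate
    for lg, words in en_to_all_map.get(word, {}).items():
        if lg in target_lgs:
            candidates.extend(words)
    print(candidates)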
-------------------------------------------------------------------------------- /scripts/extract-xnli-sentences-to-dict.py: -------------------------------------------------------------------------------- 1 | import argparse, json 2 | from pathlib import Path 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument("--data", "-d", default=None, type=str, required=True, help="The input data file, e.g., 'data/xnli.tsv'.") 6 | parser.add_argument("--matrix", "-m", default='en', type=str, required=True, help="The matrix language.") 7 | parser.add_argument("--output", "-o", default='../dictionaries', type=str, required=True, help="The output directory.") 8 | parser.add_argument("--split", '-s', default='test', type=str, required=False, help="train, dev, or test.") 9 | 10 | args = parser.parse_args() 11 | sentence1s = {} 12 | sentence2s = {} 13 | 14 | with open(args.data,'r') as fin: 15 | headerline = next(fin).strip() 16 | header = {col:i for i, col in enumerate(headerline.split('\t'))} 17 | for i, line in enumerate(fin): 18 | cells = line.split('\t') 19 | if not sentence1s.get(cells[header['pairID']]): 20 | sentence1s[cells[header['pairID']]] = {} 21 | if not sentence2s.get(cells[header['pairID']]): 22 | sentence2s[cells[header['pairID']]] = {} 23 | if cells[header['language']] == 'th' or cells[header['language']] == 'zh': 24 | sentence1s[cells[header['pairID']]][cells[header['language']]] = cells[header['sentence1_tokenized']] 25 | sentence2s[cells[header['pairID']]][cells[header['language']]] = cells[header['sentence2_tokenized']] 26 | else: 27 | sentence1s[cells[header['pairID']]][cells[header['language']]] = cells[header['sentence1']] 28 | sentence2s[cells[header['pairID']]][cells[header['language']]] = cells[header['sentence2']] 29 | 30 | new_sentence1s = {} 31 | new_sentence2s = {} 32 | # re-key each dictionary by its matrix-language sentence 33 | for src, dst in ((sentence1s, new_sentence1s), (sentence2s, new_sentence2s)): 34 | for k, v in src.items(): 35 | dst[v[args.matrix]] = v 36 | 37 | json.dump(new_sentence1s, open(Path(args.output,'xnli-'+args.split+'-sentence1-reference-translations-'+args.matrix+'-head.json'),'w'), indent=False, ensure_ascii=False) 38 | json.dump(new_sentence2s, open(Path(args.output,'xnli-'+args.split+'-sentence2-reference-translations-'+args.matrix+'-head.json'),'w'), indent=False, ensure_ascii=False) 39 | -------------------------------------------------------------------------------- /scripts/extract-xquad-questions-to-dict.py: -------------------------------------------------------------------------------- 1 | import argparse, json, jieba 2 | from pythainlp.tokenize import word_tokenize 3 | from pathlib import Path 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--data_dir", "-d", default=None, type=str, required=True, help="The input data directory, e.g., 'data'. Files are expected to be named 'xquad.<lg>.json'") 7 | parser.add_argument("--matrix", "-m", default='en', type=str, required=True, help="The matrix language.") 8 | parser.add_argument("--output", "-o", default='../dictionaries', type=str, required=True, help="The output directory.") 9 | args = parser.parse_args() 10 | # Extract question 11 | 12 | 13 | data = {} 14 | 15 | for lg in ['en', 'es', 'de', 'el', 'ru', 'tr', 'ar', 'vi', 'th', 'zh', 'hi']: 16 | xquad_lg_data = json.load(open(Path(args.data_dir, 'xquad.'+lg+'.json'), 'r')) 17 | for article in xquad_lg_data['data']: 18 | for paragraph in article['paragraphs']: 19 | for qa in paragraph['qas']: 20 | question = qa['question'] 21 | if lg == 'zh': 22 | question = ' '.join(jieba.cut(qa['question'], cut_all=False)) 23 | 24 | if lg == 'th': 25 | question = ' '.join(word_tokenize(qa['question'])) 26 | 27 | if not data.get(qa['id']): 28 | data[qa['id']] = {} 29 | data[qa['id']][lg] = question 30 | 31 | new_data = {} 32 | for k,v in data.items(): 33 | new_data[v[args.matrix]] = v 34 | 35 | json.dump(new_data, open(Path(args.output,'xquad-question-reference-translations-'+args.matrix+'-head.json'),'w'), indent=False, ensure_ascii=False) 36 |
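For orientation, a minimal sketch of the dictionary shape the two extraction scripts above emit and how the attack runners index it (run_bumblebee_qa.py looks entries up by the matrix-language question); the question and translations below are invented for illustration:

    import json

    # keyed by the matrix-language text, mapping language id -> reference translation, e.g.:
    # {"Where is the Amazon rainforest?": {"en": "Where is the Amazon rainforest?",
    #                                      "es": "¿Dónde está la selva amazónica?"}}
    refs = json.load(open('../dictionaries/xquad-question-reference-translations-en-head.json'))
    es_reference = refs["Where is the Amazon rainforest?"]["es"]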
-------------------------------------------------------------------------------- /scripts/extract_en_xnli.py: -------------------------------------------------------------------------------- 1 | import os, argparse 2 | 3 | parser = argparse.ArgumentParser() 4 | parser.add_argument("--data", "-d", default=None, type=str, required=True, help="The input data file, e.g., 'data/XNLI/test.tsv'.") 5 | parser.add_argument("--output_dir", "-o", default=None, type=str, required=True, help="The output directory.") 6 | args = parser.parse_args() 7 | 8 | LG = 'en' 9 | 10 | out_dir = os.path.join(args.output_dir, 'xnli-en') 11 | out_file_path = os.path.join(out_dir, 'test_matched.tsv') 12 | if not os.path.exists(out_dir): 13 | os.makedirs(out_dir) 14 | 15 | with open(args.data, 'r') as fin, open(out_file_path, 'w') as fout: 16 | headerline = next(fin) 17 | fout.write(headerline) 18 | header = {col:i for i, col in enumerate(headerline.strip().split('\t'))} 19 | for i, line in enumerate(fin): 20 | cells = line.split('\t') 21 | if cells[header['language']] == LG: 22 | fout.write(line) 23 | -------------------------------------------------------------------------------- /scripts/generate_xnli_cleanDL.py: -------------------------------------------------------------------------------- 1 | import json, random, argparse 2 | 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument("--data", "-d", default=None, type=str, required=True, help="The input data file, e.g., 'data/XNLI/test.tsv'.") 6 | parser.add_argument("--output_file", "-o", default=None, type=str, required=True, help="The output file path and name.") 7 | args = parser.parse_args() 8 | 9 | 10 | reference_translations = [json.load(open('../dictionaries/xnli-extended-sentence1-reference-translations-en_head.json','r')), 11 | json.load(open('../dictionaries/xnli-extended-sentence2-reference-translations-en_head.json','r'))] 12 | 13 | 14 | languages = ['en','fr','es','de','zh','el','bg','ru','tr','ar','vi','th','hi','sw','ur'] 15 | 16 | with open(args.data,'r') as fin, open(args.output_file,'w') as fout: 17 | headerline = next(fin).strip() 18 | fout.write(headerline+'\n') 19 | header = {col:i for i, col in enumerate(headerline.split('\t'))} 20 | for line in fin: 21 | cells = line.strip().split('\t') 22 | lg1 = random.choice(languages) 23 | lg2 = random.choice([lg for lg in languages if lg != lg1]) 24 | new_sent1 = reference_translations[0][cells[header['sentence1']]][lg1] 25 | new_sent2 = reference_translations[1][cells[header['sentence2']]][lg2] 26 | cells[header['sentence1']] = new_sent1 27 | cells[header['sentence2']] = new_sent2 28 | fout.write('\t'.join(cells)+'\n') 29 | -------------------------------------------------------------------------------- /scripts/language-opus-nmt-map.json: -------------------------------------------------------------------------------- 1 | { 2 | "af": "Helsinki-NLP/opus-mt-en-af", 3 | "sq": "Helsinki-NLP/opus-mt-en-sq", 4 | "ar": "Helsinki-NLP/opus-mt-en-ar", 5 | "bn": "", 6 | "bs": "", 7 | "bg": "Helsinki-NLP/opus-mt-en-bg", 8 | "ca": "Helsinki-NLP/opus-mt-en-ca", 9 | "zh": "Helsinki-NLP/opus-mt-en-zh", 10 | "hr": "", 11 | "cs": "Helsinki-NLP/opus-mt-en-cs", 12 | "da": "Helsinki-NLP/opus-mt-en-da", 13 | "nl": "Helsinki-NLP/opus-mt-en-nl", 14 | "et": "Helsinki-NLP/opus-mt-en-et", 15 | "tl": "Helsinki-NLP/opus-mt-en-tl", 16 | "fi": "Helsinki-NLP/opus-mt-en-fi", 17 | "fr": "Helsinki-NLP/opus-mt-en-fr", 18 | "de": "Helsinki-NLP/opus-mt-en-de", 19 | "el": "Helsinki-NLP/opus-mt-en-el", 20 | "he": "Helsinki-NLP/opus-mt-en-he", 21 | "hi": "Helsinki-NLP/opus-mt-en-hi", 22 | "hu": "Helsinki-NLP/opus-mt-en-hu", 23 | "id": "Helsinki-NLP/opus-mt-en-id", 24 | "it": "Helsinki-NLP/opus-mt-en-it", 25 | "ja": "Helsinki-NLP/opus-mt-en-jap", 26 | "lt": "", 27 | "mk": "Helsinki-NLP/opus-mt-en-mk", 28 | "ms": "", 29 | "no": "", 30 | "fa": "", 31 | "pl": "", 32 | "pt": "", 33 | "ro": "Helsinki-NLP/opus-mt-en-ro", 34 | "ru": "Helsinki-NLP/opus-mt-en-ru", 35 | "sk": "Helsinki-NLP/opus-mt-en-sk", 36 | "sl": "", 37 | "es": "Helsinki-NLP/opus-mt-en-es", 38 | "sv": "Helsinki-NLP/opus-mt-en-sv", 39 | "ta": "", 40 | "tr": "Helsinki-NLP/opus-mt-en-trk", 41 | "uk": "Helsinki-NLP/opus-mt-en-uk", 42 | "vi": "Helsinki-NLP/opus-mt-en-vi" 43 | }
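The translation scripts that consume this map (scripts/translate_tweeteval.py and scripts/translate_xnli.py) are not included in this dump, so the following is only a plausible sketch of how the map can be used; note that empty strings mark languages without an OPUS-MT checkpoint and must be skipped:

    import json
    from transformers import pipeline

    lg_to_model = json.load(open('scripts/language-opus-nmt-map.json'))
    for lg, model_name in lg_to_model.items():
        if not model_name:  # no OPUS-MT model available for this language
            continue
        translator = pipeline('translation', model=model_name)
        print(lg, translator('The quick brown fox jumps over the lazy dog.')[0]['translation_text'])
        break  # demo: stop after the first available language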
-------------------------------------------------------------------------------- /scripts/run_sentiment_analysis.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2020 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Finetuning the library models for sequence classification on GLUE.""" 17 | # You can also adapt this script on your own text classification task. Pointers for this are left as comments. 18 | 19 | import logging 20 | import os 21 | import random 22 | import sys 23 | from dataclasses import dataclass, field 24 | from typing import Optional 25 | 26 | import datasets 27 | import numpy as np 28 | from datasets import load_dataset, load_metric, DatasetDict 29 | 30 | import transformers 31 | from transformers import ( 32 | AutoConfig, 33 | AutoModelForSequenceClassification, 34 | AutoTokenizer, 35 | DataCollatorWithPadding, 36 | EvalPrediction, 37 | HfArgumentParser, 38 | PretrainedConfig, 39 | Trainer, 40 | TrainingArguments, 41 | default_data_collator, 42 | set_seed, 43 | ) 44 | from transformers.trainer_utils import get_last_checkpoint 45 | from transformers.utils import check_min_version 46 | from transformers.utils.versions import require_version 47 | 48 | 49 | # + 50 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 51 | #check_min_version("4.12.0.dev0") 52 | # - 53 | 54 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") 55 | 56 | task_to_keys = { 57 | "cola": ("sentence", None), 58 | "mnli": ("premise", "hypothesis"), 59 | "mrpc": ("sentence1", "sentence2"), 60 | "qnli": ("question", "sentence"), 61 | "qqp": ("question1", "question2"), 62 | "rte": ("sentence1", "sentence2"), 63 | "sst2": ("sentence", None), 64 | "stsb": ("sentence1", "sentence2"), 65 | "wnli": ("sentence1", "sentence2"), 66 | "tweet_eval": ("text", None), 67 | } 68 | 69 | logger = logging.getLogger(__name__) 70 | 71 | 72 | @dataclass 73 | class DataTrainingArguments: 74 | """ 75 | Arguments pertaining to what data we are going to input our model for training and eval. 76 | 77 | Using `HfArgumentParser` we can turn this class 78 | into argparse arguments to be able to specify them on 79 | the command line. 
80 | """ 81 | 82 | task_name: Optional[str] = field( 83 | default=None, 84 | metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())}, 85 | ) 86 | dataset_name: Optional[str] = field( 87 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 88 | ) 89 | dataset_config_name: Optional[str] = field( 90 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 91 | ) 92 | max_seq_length: int = field( 93 | default=128, 94 | metadata={ 95 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 96 | "than this will be truncated, sequences shorter will be padded." 97 | }, 98 | ) 99 | overwrite_cache: bool = field( 100 | default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} 101 | ) 102 | pad_to_max_length: bool = field( 103 | default=True, 104 | metadata={ 105 | "help": "Whether to pad all samples to `max_seq_length`. " 106 | "If False, will pad the samples dynamically when batching to the maximum length in the batch." 107 | }, 108 | ) 109 | max_train_samples: Optional[int] = field( 110 | default=None, 111 | metadata={ 112 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 113 | "value if set." 114 | }, 115 | ) 116 | max_eval_samples: Optional[int] = field( 117 | default=None, 118 | metadata={ 119 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 120 | "value if set." 121 | }, 122 | ) 123 | max_predict_samples: Optional[int] = field( 124 | default=None, 125 | metadata={ 126 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " 127 | "value if set." 128 | }, 129 | ) 130 | train_file: Optional[str] = field( 131 | default=None, metadata={"help": "A csv or a json file containing the training data."} 132 | ) 133 | validation_file: Optional[str] = field( 134 | default=None, metadata={"help": "A csv or a json file containing the validation data."} 135 | ) 136 | test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."}) 137 | 138 | def __post_init__(self): 139 | if self.task_name is not None: 140 | self.task_name = self.task_name.lower() 141 | if self.task_name not in task_to_keys.keys(): 142 | raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys())) 143 | elif self.dataset_name is not None: 144 | pass 145 | elif self.train_file is None and self.validation_file is None: 146 | raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.") 147 | elif self.train_file: 148 | train_extension = self.train_file.split(".")[-1] 149 | assert train_extension in ["tsv", "csv", "json"], "`train_file` should be a csv or a json file." 150 | elif self.validation_file: 151 | validation_extension = self.validation_file.split(".")[-1] 152 | assert validation_extension in ["tsv", "csv", "json"] 153 | #assert ( 154 | # validation_extension == train_extension 155 | #), "`validation_file` should have the same extension (csv or json) as `train_file`." 156 | 157 | 158 | @dataclass 159 | class ModelArguments: 160 | """ 161 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
162 | """ 163 | 164 | model_name_or_path: Optional[str] = field( 165 | default=None, metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 166 | ) 167 | config_name: Optional[str] = field( 168 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 169 | ) 170 | tokenizer_name: Optional[str] = field( 171 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 172 | ) 173 | cache_dir: Optional[str] = field( 174 | default=None, 175 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 176 | ) 177 | use_fast_tokenizer: bool = field( 178 | default=True, 179 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 180 | ) 181 | model_revision: str = field( 182 | default="main", 183 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 184 | ) 185 | use_auth_token: bool = field( 186 | default=False, 187 | metadata={ 188 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 189 | "with private models)." 190 | }, 191 | ) 192 | 193 | 194 | def main(): 195 | # See all possible arguments in src/transformers/training_args.py 196 | # or by passing the --help flag to this script. 197 | # We now keep distinct sets of args, for a cleaner separation of concerns. 198 | 199 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 200 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 201 | # If we pass only one argument to the script and it's the path to a json file, 202 | # let's parse it to get our arguments. 203 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 204 | else: 205 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 206 | 207 | # Setup logging 208 | logging.basicConfig( 209 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 210 | datefmt="%m/%d/%Y %H:%M:%S", 211 | handlers=[logging.StreamHandler(sys.stdout)], 212 | ) 213 | 214 | log_level = training_args.get_process_log_level() 215 | logger.setLevel(log_level) 216 | datasets.utils.logging.set_verbosity(log_level) 217 | transformers.utils.logging.set_verbosity(log_level) 218 | transformers.utils.logging.enable_default_handler() 219 | transformers.utils.logging.enable_explicit_format() 220 | 221 | # Log on each process the small summary: 222 | logger.warning( 223 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 224 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 225 | ) 226 | logger.info(f"Training/evaluation parameters {training_args}") 227 | 228 | # Detecting last checkpoint. 229 | last_checkpoint = None 230 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 231 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 232 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 233 | raise ValueError( 234 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 235 | "Use --overwrite_output_dir to overcome." 
236 | ) 237 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 238 | logger.info( 239 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 240 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 241 | ) 242 | 243 | # Set seed before initializing model. 244 | set_seed(training_args.seed) 245 | 246 | # Get the datasets: you can either provide your own CSV/TSV/JSON training and evaluation files (see below) 247 | # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub). 248 | # 249 | # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the 250 | # sentences in columns called 'sentence1' and 'sentence2' if such columns exist, or the first two columns not 251 | # named 'label' if at least two columns are provided. 252 | # 253 | # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this 254 | # single column. You can easily tweak this behavior (see below). 255 | # 256 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently 257 | # download the dataset. 258 | if data_args.task_name is not None: 259 | # Downloading and loading a dataset from the hub. 260 | raw_datasets = load_dataset("glue", data_args.task_name, cache_dir=model_args.cache_dir) 261 | elif data_args.dataset_name is not None: 262 | # Downloading and loading a dataset from the hub. 263 | raw_datasets = load_dataset( 264 | data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir 265 | ) 266 | else: 267 | # Loading a dataset from your local files. 268 | # CSV/TSV/JSON training and evaluation files are needed. 269 | data_files = {"train": data_args.train_file, "validation": data_args.validation_file} 270 | 271 | # Get the test dataset: you can provide your own CSV/TSV/JSON test file (see below) 272 | # when you use `do_predict` without specifying a GLUE benchmark task. 273 | if training_args.do_predict: 274 | if data_args.test_file is not None: 275 | train_extension = data_args.train_file.split(".")[-1] 276 | test_extension = data_args.test_file.split(".")[-1] 277 | assert ( 278 | test_extension == train_extension 279 | ), "`test_file` should have the same extension (tsv, csv or json) as `train_file`."
280 | data_files["test"] = data_args.test_file 281 | else: 282 | raise ValueError("Need either a GLUE task or a test file for `do_predict`.") 283 | 284 | for key in data_files.keys(): 285 | logger.info(f"Loading a local file for {key}: {data_files[key]}") 286 | 287 | if (data_args.train_file and data_args.train_file.endswith(".csv")) or (data_args.validation_file and data_args.validation_file.endswith(".csv")): 288 | # Loading a dataset from local csv files 289 | raw_datasets = load_dataset("csv", data_files=data_files, cache_dir=model_args.cache_dir) 290 | elif (data_args.train_file and data_args.train_file.endswith(".tsv")) or (data_args.validation_file and data_args.validation_file.endswith(".tsv")): 291 | # Loading a dataset from local tsv files (the csv loader with a tab delimiter) 292 | raw_datasets = DatasetDict({"train": load_dataset("csv", delimiter="\t", data_files=data_args.train_file, cache_dir=model_args.cache_dir, split='train'), 293 | "validation": load_dataset("csv", delimiter="\t", data_files=data_args.validation_file, cache_dir=model_args.cache_dir, split='train')}) 294 | else: 295 | # Loading a dataset from local json files 296 | raw_datasets = load_dataset("json", data_files=data_files, cache_dir=model_args.cache_dir) 297 | # See more about loading any type of standard or custom dataset at 298 | # https://huggingface.co/docs/datasets/loading_datasets.html. 299 | 300 | # Labels 301 | if data_args.task_name is not None: 302 | is_regression = data_args.task_name == "stsb" 303 | if not is_regression: 304 | label_list = raw_datasets["train"].features["label"].names 305 | num_labels = len(label_list) 306 | else: 307 | num_labels = 1 308 | else: 309 | # Trying to have good defaults here, don't hesitate to tweak to your needs. 310 | is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"] 311 | if is_regression: 312 | num_labels = 1 313 | else: 314 | # A useful fast method: 315 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique 316 | label_list = raw_datasets["train"].unique("label") # e.g., ["negative", "neutral", "positive"] 317 | label_list.sort() # Let's sort it for determinism 318 | num_labels = len(label_list) 319 | 320 | # Load pretrained model and tokenizer 321 | # 322 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently 323 | # download model & vocab.
324 | config = AutoConfig.from_pretrained( 325 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 326 | num_labels=num_labels, 327 | finetuning_task=data_args.task_name, 328 | cache_dir=model_args.cache_dir, 329 | revision=model_args.model_revision, 330 | use_auth_token=True if model_args.use_auth_token else None, 331 | ) 332 | tokenizer = AutoTokenizer.from_pretrained( 333 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 334 | cache_dir=model_args.cache_dir, 335 | use_fast=model_args.use_fast_tokenizer, 336 | revision=model_args.model_revision, 337 | use_auth_token=True if model_args.use_auth_token else None, 338 | ) 339 | if model_args.model_name_or_path: 340 | model = AutoModelForSequenceClassification.from_pretrained( 341 | model_args.model_name_or_path, 342 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 343 | config=config, 344 | cache_dir=model_args.cache_dir, 345 | revision=model_args.model_revision, 346 | use_auth_token=True if model_args.use_auth_token else None, 347 | ) 348 | else: 349 | logger.info("Training new model from scratch") 350 | model = AutoModelForSequenceClassification.from_config(config) 351 | 352 | # Preprocessing the raw_datasets 353 | if data_args.task_name is not None: 354 | sentence1_key, sentence2_key = task_to_keys[data_args.task_name] 355 | else: 356 | sentence1_key, sentence2_key = 'text', None 357 | '''# Again, we try to have some nice defaults but don't hesitate to tweak to your use case. 358 | non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"] 359 | if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names: 360 | sentence1_key, sentence2_key = "sentence1", "sentence2" 361 | else: 362 | if len(non_label_column_names) >= 2: 363 | sentence1_key, sentence2_key = non_label_column_names[:2] 364 | else: 365 | sentence1_key, sentence2_key = non_label_column_names[0], None''' 366 | 367 | # Padding strategy 368 | if data_args.pad_to_max_length: 369 | padding = "max_length" 370 | else: 371 | # We will pad later, dynamically at batch creation, to the max sequence length in each batch 372 | padding = False 373 | 374 | # Some models have set the order of the labels to use, so let's make sure we do use it. 375 | label_to_id = None 376 | if ( 377 | model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id 378 | and data_args.task_name is not None 379 | and not is_regression 380 | ): 381 | # Some have all caps in their config, some don't. 382 | label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()} 383 | if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)): 384 | label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)} 385 | else: 386 | logger.warning( 387 | "Your model seems to have been trained with labels, but they don't match the dataset: " 388 | f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
389 | "\nIgnoring the model labels as a result.", 390 | ) 391 | elif data_args.task_name is None and not is_regression and isinstance(label_list[0], str): 392 | label_to_id = {v: i for i, v in enumerate(label_list)} 393 | 394 | if label_to_id is not None: 395 | model.config.label2id = label_to_id 396 | model.config.id2label = {id: label for label, id in config.label2id.items()} 397 | elif data_args.task_name is not None and not is_regression: 398 | model.config.label2id = {l: i for i, l in enumerate(label_list)} 399 | model.config.id2label = {id: label for label, id in config.label2id.items()} 400 | 401 | if data_args.max_seq_length > tokenizer.model_max_length: 402 | logger.warning( 403 | f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" 404 | f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." 405 | ) 406 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) 407 | 408 | def preprocess_function(examples): 409 | # Tokenize the texts 410 | args = ( 411 | (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key]) 412 | ) 413 | result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True) 414 | 415 | # Map labels to IDs (not necessary for GLUE tasks) 416 | if label_to_id is not None and "label" in examples and isinstance(examples["label"][0], str): 417 | result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]] 418 | return result 419 | 420 | with training_args.main_process_first(desc="dataset map pre-processing"): 421 | raw_datasets = raw_datasets.map( 422 | preprocess_function, 423 | batched=True, 424 | load_from_cache_file=not data_args.overwrite_cache, 425 | desc="Running tokenizer on dataset", 426 | ) 427 | if training_args.do_train: 428 | if "train" not in raw_datasets: 429 | raise ValueError("--do_train requires a train dataset") 430 | train_dataset = raw_datasets["train"] 431 | if data_args.max_train_samples is not None: 432 | train_dataset = train_dataset.select(range(data_args.max_train_samples)) 433 | 434 | if training_args.do_eval: 435 | if "validation" not in raw_datasets and "validation_matched" not in raw_datasets: 436 | raise ValueError("--do_eval requires a validation dataset") 437 | eval_dataset = raw_datasets["validation_matched" if data_args.task_name == "mnli" else "validation"] 438 | if data_args.max_eval_samples is not None: 439 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) 440 | 441 | if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None: 442 | if "test" not in raw_datasets and "test_matched" not in raw_datasets: 443 | raise ValueError("--do_predict requires a test dataset") 444 | predict_dataset = raw_datasets["test_matched" if data_args.task_name == "mnli" else "test"] 445 | if data_args.max_predict_samples is not None: 446 | predict_dataset = predict_dataset.select(range(data_args.max_predict_samples)) 447 | 448 | # Log a few random samples from the training set: 449 | if training_args.do_train: 450 | for index in random.sample(range(len(train_dataset)), 3): 451 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 452 | 453 | # Get the metric function 454 | if data_args.task_name is not None: 455 | metric = load_metric("glue", data_args.task_name) 456 | else: 457 | #metric = load_metric("accuracy") 458 | metric = load_metric("f1") 459 | 460 | # You can define your 
custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a 461 | # predictions and label_ids field) and has to return a dictionary mapping strings to floats. 462 | def compute_metrics(p: EvalPrediction): 463 | preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions 464 | preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1) 465 | if data_args.task_name is not None: 466 | result = metric.compute(predictions=preds, references=p.label_ids) 467 | if len(result) > 1: 468 | result["combined_score"] = np.mean(list(result.values())).item() 469 | return result 470 | elif is_regression: 471 | return {"mse": ((preds - p.label_ids) ** 2).mean().item()} 472 | else: 473 | result = metric.compute(predictions=preds, references=p.label_ids, average="macro") 474 | if len(result) > 1: 475 | result["macroF1"] = np.mean(list(result.values())).item() 476 | return result 477 | #return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()} 478 | 479 | # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding. 480 | if data_args.pad_to_max_length: 481 | data_collator = default_data_collator 482 | elif training_args.fp16: 483 | data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) 484 | else: 485 | data_collator = None 486 | 487 | # Initialize our Trainer 488 | trainer = Trainer( 489 | model=model, 490 | args=training_args, 491 | train_dataset=train_dataset if training_args.do_train else None, 492 | eval_dataset=eval_dataset if training_args.do_eval else None, 493 | compute_metrics=compute_metrics, 494 | tokenizer=tokenizer, 495 | data_collator=data_collator, 496 | ) 497 | 498 | # Training 499 | if training_args.do_train: 500 | checkpoint = None 501 | if training_args.resume_from_checkpoint is not None: 502 | checkpoint = training_args.resume_from_checkpoint 503 | elif last_checkpoint is not None: 504 | checkpoint = last_checkpoint 505 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 506 | metrics = train_result.metrics 507 | max_train_samples = ( 508 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 509 | ) 510 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 511 | 512 | trainer.save_model() # Saves the tokenizer too for easy upload 513 | 514 | trainer.log_metrics("train", metrics) 515 | trainer.save_metrics("train", metrics) 516 | trainer.save_state() 517 | 518 | # Evaluation 519 | if training_args.do_eval: 520 | logger.info("*** Evaluate ***") 521 | 522 | # Loop to handle MNLI double evaluation (matched, mis-matched) 523 | tasks = [data_args.task_name] 524 | eval_datasets = [eval_dataset] 525 | if data_args.task_name == "mnli": 526 | tasks.append("mnli-mm") 527 | eval_datasets.append(raw_datasets["validation_mismatched"]) 528 | 529 | for eval_dataset, task in zip(eval_datasets, tasks): 530 | metrics = trainer.evaluate(eval_dataset=eval_dataset) 531 | 532 | max_eval_samples = ( 533 | data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 534 | ) 535 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 536 | 537 | trainer.log_metrics("eval", metrics) 538 | trainer.save_metrics("eval", metrics) 539 | 540 | if training_args.do_predict: 541 | logger.info("*** Predict ***") 542 | 543 | # Loop to handle MNLI double evaluation (matched, mis-matched) 544 | tasks = [data_args.task_name] 545 | predict_datasets = [predict_dataset] 546 | if
data_args.task_name == "mnli": 547 | tasks.append("mnli-mm") 548 | predict_datasets.append(raw_datasets["test_mismatched"]) 549 | 550 | for predict_dataset, task in zip(predict_datasets, tasks): 551 | # Removing the `label` columns because it contains -1 and Trainer won't like that. 552 | predict_dataset = predict_dataset.remove_columns("label") 553 | predictions = trainer.predict(predict_dataset, metric_key_prefix="predict").predictions 554 | predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1) 555 | 556 | output_predict_file = os.path.join(training_args.output_dir, f"predict_results_{task}.txt") 557 | if trainer.is_world_process_zero(): 558 | with open(output_predict_file, "w") as writer: 559 | logger.info(f"***** Predict results {task} *****") 560 | writer.write("index\tprediction\n") 561 | for index, item in enumerate(predictions): 562 | if is_regression: 563 | writer.write(f"{index}\t{item:3.3f}\n") 564 | else: 565 | item = label_list[item] 566 | writer.write(f"{index}\t{item}\n") 567 | 568 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-classification"} 569 | if data_args.task_name is not None: 570 | kwargs["language"] = "en" 571 | kwargs["dataset_tags"] = "glue" 572 | kwargs["dataset_args"] = data_args.task_name 573 | kwargs["dataset"] = f"GLUE {data_args.task_name.upper()}" 574 | 575 | #if training_args.push_to_hub: 576 | # trainer.push_to_hub(**kwargs) 577 | #else: 578 | # trainer.create_model_card(**kwargs) 579 | 580 | 581 | def _mp_fn(index): 582 | # For xla_spawn (TPUs) 583 | main() 584 | 585 | 586 | if __name__ == "__main__": 587 | main() 588 | -------------------------------------------------------------------------------- /scripts/translate_tweeteval.py: -------------------------------------------------------------------------------- 1 | from transformers import MarianMTModel, MarianTokenizer 2 | from datasets import load_dataset 3 | from tqdm import tqdm 4 | 5 | translator_paths = {'es': 'Helsinki-NLP/opus-mt-en-es', 'hi': 'Helsinki-NLP/opus-mt-en-hi'} 6 | translators = {lg: {"tokenizer": MarianTokenizer.from_pretrained(path), 7 | "model": MarianMTModel.from_pretrained(path).cuda()} 8 | for lg, path in translator_paths.items()} 9 | 10 | print('Translators loaded.') 11 | 12 | 13 | def translate(sentence, tokenizer, model): 14 | inputs = tokenizer(sentence, return_tensors='pt') 15 | outputs = model.generate(inputs.input_ids.cuda(), num_beams=5, early_stopping=True,num_return_sequences=1) 16 | return [tokenizer.decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=False) for output in outputs] 17 | 18 | 19 | data = load_dataset("tweet_eval", "sentiment") 20 | 21 | def translate_add_columns(example, translators): 22 | for lg in translators.keys(): 23 | tokenizer = translators[lg]['tokenizer'] 24 | model = translators[lg]['model'] 25 | example[lg] = translate(example['text'], tokenizer, model)[0] 26 | return example 27 | 28 | print(data['test'].select([0]).map(translate_add_columns, fn_kwargs={'translators': translators})[0]) 29 | 30 | data['test'].map(translate_add_columns, fn_kwargs={'translators': translators}).save_to_disk('data/tweeteval_sentiment_translated_test') 31 | 32 | data['validation'].map(translate_add_columns, fn_kwargs={'translators': translators}).save_to_disk('data/tweeteval_sentiment_translated_validation') 33 | -------------------------------------------------------------------------------- /scripts/translate_xnli.py: 
-------------------------------------------------------------------------------- 1 | from transformers import MarianMTModel, MarianTokenizer 2 | import json, argparse, os 3 | from tqdm import tqdm 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--data", "-d", default=None, type=str, required=True, help="The input data file, e.g., 'data/XNLI/test.tsv'.") 7 | parser.add_argument("--output_dir", "-o", default=None, type=str, required=True, help="The output directory.") 8 | parser.add_argument("--include_header", "-i", default=True, type=lambda x: str(x).lower() not in ('false', '0', 'no'), help="Whether to include the header line in the output tsv. Set to false if you want to concatenate the output with the original tsv.") 9 | args = parser.parse_args() 10 | 11 | 12 | # XNLI already covers these fifteen languages, so only translate into the remaining mapped languages. 13 | xnli_languages = {'en','fr','es','de','zh','el','bg','ru','tr','ar','vi','th','hi','ur','sw'} 14 | translator_paths = json.load(open('language-opus-nmt-map.json')) 15 | translators = {lg: {"tokenizer": MarianTokenizer.from_pretrained(path), 16 | "model": MarianMTModel.from_pretrained(path)} 17 | for lg, path in translator_paths.items() if path and lg not in xnli_languages} 18 | 19 | print('Translators loaded.') 20 | 21 | 22 | def translate(sentence, tokenizer, model): 23 | inputs = tokenizer([sentence], return_tensors='pt') 24 | outputs = model.generate(**inputs, num_beams=5, early_stopping=True, num_return_sequences=1) 25 | return [tokenizer.decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=False) for output in outputs] 26 | 27 | 28 | out_dir = os.path.join(args.output_dir, 'xnli-opus-hf') 29 | out_file_path = os.path.join(out_dir, 'test_matched.tsv') 30 | 31 | if not os.path.exists(out_dir): 32 | os.makedirs(out_dir) 33 | 34 | with open(args.data, 'r') as fin, open(out_file_path, 'w') as fout: 35 | headerline = next(fin).strip() 36 | header = {col:i for i, col in enumerate(headerline.split('\t'))} 37 | if args.include_header: 38 | fout.write(headerline+'\n') 39 | for i, line in enumerate(tqdm(fin, total=5010)): 40 | en_cells = line.rstrip('\n').split('\t') 41 | for lg in translators.keys(): 42 | new_cells = en_cells.copy() 43 | tokenizer = translators[lg]['tokenizer'] 44 | model = translators[lg]['model'] 45 | new_cells[header['language']] = lg 46 | new_cells[header['sentence1_tokenized']] = '' 47 | new_cells[header['sentence2_tokenized']] = '' 48 | new_cells[header['sentence1']] = translate(en_cells[header['sentence1']], tokenizer, model)[0] 49 | new_cells[header['sentence2']] = translate(en_cells[header['sentence2']], tokenizer, model)[0] 50 | fout.write('\t'.join(new_cells)+'\n') 51 | --------------------------------------------------------------------------------
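Illustrative invocations of the scripts above (the model name and file paths are placeholders; the flags come from the scripts' own argument parsers and from Hugging Face's TrainingArguments):

# Translate the English XNLI test split into the extra OPUS-MT languages
python scripts/translate_xnli.py --data data/XNLI/test.tsv --output_dir data

# Fine-tune a sentiment classifier on local tsv files
python scripts/run_sentiment_analysis.py \
    --model_name_or_path xlm-roberta-base \
    --train_file data/train.tsv \
    --validation_file data/validation.tsv \
    --do_train --do_eval \
    --max_seq_length 128 \
    --output_dir runs/sentiment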