├── .gitignore ├── CODEOWNERS ├── CODE_OF_CONDUCT.md ├── LICENSE.txt ├── README.md ├── SECURITY.md ├── __init__.py ├── examples └── train_warp_mnli.py ├── figures └── overture_logo.png ├── img ├── eval_acc_warp_roberta_mnli.png └── train_loss_warp_roberta_mnli.png ├── models ├── __init__.py ├── modeling_bert.py ├── modeling_roberta.py └── modeling_xlm_roberta.py ├── requirements.txt ├── soft_prompts.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Others 132 | *.DS_Store 133 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing. 
2 | #ECCN:Open Source 3 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Salesforce Open Source Community Code of Conduct 2 | 3 | ## About the Code of Conduct 4 | 5 | Equality is a core value at Salesforce. We believe a diverse and inclusive 6 | community fosters innovation and creativity, and are committed to building a 7 | culture where everyone feels included. 8 | 9 | Salesforce open-source projects are committed to providing a friendly, safe, and 10 | welcoming environment for all, regardless of gender identity and expression, 11 | sexual orientation, disability, physical appearance, body size, ethnicity, nationality, 12 | race, age, religion, level of experience, education, socioeconomic status, or 13 | other similar personal characteristics. 14 | 15 | The goal of this code of conduct is to specify a baseline standard of behavior so 16 | that people with different social values and communication styles can work 17 | together effectively, productively, and respectfully in our open source community. 18 | It also establishes a mechanism for reporting issues and resolving conflicts. 19 | 20 | All questions and reports of abusive, harassing, or otherwise unacceptable behavior 21 | in a Salesforce open-source project may be reported by contacting the Salesforce 22 | Open Source Conduct Committee at ossconduct@salesforce.com. 23 | 24 | ## Our Pledge 25 | 26 | In the interest of fostering an open and welcoming environment, we as 27 | contributors and maintainers pledge to making participation in our project and 28 | our community a harassment-free experience for everyone, regardless of gender 29 | identity and expression, sexual orientation, disability, physical appearance, 30 | body size, ethnicity, nationality, race, age, religion, level of experience, education, 31 | socioeconomic status, or other similar personal characteristics. 32 | 33 | ## Our Standards 34 | 35 | Examples of behavior that contributes to creating a positive environment 36 | include: 37 | 38 | * Using welcoming and inclusive language 39 | * Being respectful of differing viewpoints and experiences 40 | * Gracefully accepting constructive criticism 41 | * Focusing on what is best for the community 42 | * Showing empathy toward other community members 43 | 44 | Examples of unacceptable behavior by participants include: 45 | 46 | * The use of sexualized language or imagery and unwelcome sexual attention or 47 | advances 48 | * Personal attacks, insulting/derogatory comments, or trolling 49 | * Public or private harassment 50 | * Publishing, or threatening to publish, others' private information—such as 51 | a physical or electronic address—without explicit permission 52 | * Other conduct which could reasonably be considered inappropriate in a 53 | professional setting 54 | * Advocating for or encouraging any of the above behaviors 55 | 56 | ## Our Responsibilities 57 | 58 | Project maintainers are responsible for clarifying the standards of acceptable 59 | behavior and are expected to take appropriate and fair corrective action in 60 | response to any instances of unacceptable behavior. 
61 | 62 | Project maintainers have the right and responsibility to remove, edit, or 63 | reject comments, commits, code, wiki edits, issues, and other contributions 64 | that are not aligned with this Code of Conduct, or to ban temporarily or 65 | permanently any contributor for other behaviors that they deem inappropriate, 66 | threatening, offensive, or harmful. 67 | 68 | ## Scope 69 | 70 | This Code of Conduct applies both within project spaces and in public spaces 71 | when an individual is representing the project or its community. Examples of 72 | representing a project or community include using an official project email 73 | address, posting via an official social media account, or acting as an appointed 74 | representative at an online or offline event. Representation of a project may be 75 | further defined and clarified by project maintainers. 76 | 77 | ## Enforcement 78 | 79 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 80 | reported by contacting the Salesforce Open Source Conduct Committee 81 | at ossconduct@salesforce.com. All complaints will be reviewed and investigated 82 | and will result in a response that is deemed necessary and appropriate to the 83 | circumstances. The committee is obligated to maintain confidentiality with 84 | regard to the reporter of an incident. Further details of specific enforcement 85 | policies may be posted separately. 86 | 87 | Project maintainers who do not follow or enforce the Code of Conduct in good 88 | faith may face temporary or permanent repercussions as determined by other 89 | members of the project's leadership and the Salesforce Open Source Conduct 90 | Committee. 91 | 92 | ## Attribution 93 | 94 | This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home], 95 | version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html. 96 | It includes adaptions and additions from [Go Community Code of Conduct][golang-coc], 97 | [CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc]. 98 | 99 | This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us]. 100 | 101 | [contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/) 102 | [golang-coc]: https://golang.org/conduct 103 | [cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md 104 | [microsoft-coc]: https://opensource.microsoft.com/codeofconduct/ 105 | [cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/ 106 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2021, Salesforce.com, Inc. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 11 | 12 | 3. 
Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Project Overture - A Prompt-Tuning Library for Researchers 2 | > Why name it **Overture**? An overture in music is an orchestral piece at the beginning which sets the mood and tone for what's about to come. We think of prompt tuning as analogous to that; in the types of prompt tuning methods we consider, a prompt is prepended to the input that sets the tone for the downstream task. 3 |

4 | 5 |

6 | 7 | # Introduction 8 | Prompt Tuning has recently become an important research direction in Natural Language Processing. In contrast to classical fine-tuning, which optimizes the weights of the entire network, (one style of) prompt tuning keeps the large language model (a.k.a. the "backbone") frozen and instead prepends a few learnable vectors to each input, which are learned to accomplish the downstream task. This brings the number of trainable parameters down from O(millions) to a few thousand while still achieving similar levels of performance. The research community has also found other benefits of prompt-tuned models compared to classically trained models. 9 | 10 | # Methods Supported 11 | The repository leverages the HuggingFace Transformers library, and we currently support [WARP-like](https://arxiv.org/abs/2101.00121) prompt-tuning for masked language modeling (MLM), text classification, and extractive question answering (e.g., SQuAD). We plan on adding support for [Seq2Seq prompt-tuning](https://arxiv.org/abs/2104.08691v1) soon. If there is another algorithm or method that you would like us to prioritize, please write to us or file a feature request. Finally, we refer the interested reader to the [excellent survey](http://pretrain.nlpedia.ai/) on the topic for the various types of prompt tuning methods and their history. 12 | 13 | # Some Potential Extensions 14 | Here are some research ideas one could explore with our codebase. Since the community is evolving rapidly, it is entirely possible that some of these ideas have already been studied. Please file an issue if that is the case, or if you want to contribute more ideas. 15 | 16 | 1. Does prompt tuning on a multilingual backbone (e.g., mBERT or XLM) lead to models that can perform cross-lingual zero-shot transfer? 17 | 2. How can we make the prompts more interpretable? Could adding a loss that keeps the prompt vectors close to existing word embeddings help? 18 | 3. Can prompts learned for BERT-Large help learn prompts for RoBERTa-Large? 19 | 20 | # Design Choices & Other Similar Libraries 21 | 22 | Fundamentally, we designed the repository for researchers to easily experiment with ideas within the realm of prompt-tuning. As such, we intentionally do not abstract away the sub-components. The repository is intended to be a fork-and-edit library and is designed to be easily extensible for the kinds of projects we anticipate it being used for. 23 | 24 | A recently released library, [OpenPrompt](https://github.com/thunlp/OpenPrompt), is also intended as a library for prompt tuning, and we refer interested practitioners to their repository for further exploration and comparisons. OpenPrompt may be a better fit for those who seek greater abstraction. 25 | 26 | # How to Use 27 | Inside the examples folder, we provide training code for a RoBERTa-Large model on the MNLI dataset (in the style of [WARP](https://arxiv.org/abs/2101.00121)). To start training: 28 | ```bash 29 | CUDA_VISIBLE_DEVICES=0 python train_warp_mnli.py --save_prompts_path dir_to_save_prompts --save_classifier_path dir_to_save_classifier 30 | ``` 31 | 32 | After training, the user should expect an accuracy of 87-89%, which matches the results of the original [WARP](https://arxiv.org/abs/2101.00121) paper. A minimal inference sketch follows, and the training-loss and validation-accuracy curves from one run are shown below.
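The prompts and classifier saved by the training script can be reloaded for inference through the `pretrained_prompts_path` and `pretrained_classifier_path` constructor arguments (see the API section below). The sketch below is illustrative rather than part of the repository: the example sentence pair and the `premise? <mask> hypothesis` template are assumptions, while the directory names mirror the training command above. Note that the input text must contain the tokenizer's `<mask>` token, because the classification head reads the hidden state at the mask position.

```python
# Minimal inference sketch (illustrative): reload trained soft prompts + classifier and classify one pair.
import torch
from transformers import AutoTokenizer

from models.modeling_roberta import WARPPromptedRobertaForSequenceClassification

n_prompts = 20  # must match the number of prompts the checkpoint was trained with (script default: 20)
tokenizer = AutoTokenizer.from_pretrained("roberta-large")
model = WARPPromptedRobertaForSequenceClassification(
    pretrained_backbone_path = "roberta-large",
    n_prompts = n_prompts,
    seed_token_id_for_prompts_embeddings = 50264,  # token id for "<mask>"
    mask_token_id = 50264,
    token_ids_for_classification_head = [1342, 12516, 10800],  # 'ent', 'neutral', 'cont'
    pretrained_prompts_path = "dir_to_save_prompts",        # written via --save_prompts_path
    pretrained_classifier_path = "dir_to_save_classifier",  # written via --save_classifier_path
)
model.eval()

# the template below is an assumption; a <mask> token must appear in the input
premise = "A soccer game with multiple males playing"
hypothesis = "Some men are playing a sport"
features = tokenizer([premise + "? <mask> " + hypothesis], return_tensors="pt", truncation=True, padding=True)

# prepad placeholder ids (and attention) for the soft prompts, as in the training script
features["input_ids"] = torch.cat(
    [torch.full((features["input_ids"].shape[0], n_prompts), 0), features["input_ids"]], 1
)
features["attention_mask"] = torch.cat(
    [torch.full((features["attention_mask"].shape[0], n_prompts), 1), features["attention_mask"]], 1
)

with torch.no_grad():
    logits = model(features["input_ids"], features["attention_mask"])
prediction = ("entailment", "neutral", "contradiction")[logits.argmax(1).item()]
print(prediction)
```

The label order follows `token_ids_for_classification_head`: index 0 is entailment, 1 is neutral, and 2 is contradiction, matching the training script.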
33 | 34 | Training Loss | Evaluation Accuracy on Validation Set 35 | :-------------------------:|:-------------------------: 36 | ![train_loss_curve](./img/train_loss_warp_roberta_mnli.png) | ![eval_validation_value](./img/eval_acc_warp_roberta_mnli.png) 37 | 38 | ### Dev environment 39 | - Python 3.8.5 40 | - A100 GPU, CUDA Version: 11.0 41 | - Other dependencies: [requirements.txt](./requirements.txt) 42 | 43 | ### API 44 | ```python 45 | # importing RoBERTa based API 46 | from models.modeling_roberta import WARPPromptedRobertaForMaskedLM, WARPPromptedRobertaForSequenceClassification, WARPPromptedRobertaForQuestionAnswering 47 | # importing Bert based API 48 | from models.modeling_bert import WARPPromptedBertForMaskedLM, WARPPromptedBertForSequenceClassification, WARPPromptedBertForQuestionAnswering 49 | # importing XLM-RoBERTa based API 50 | from models.modeling_xlm_roberta import WARPPromptedXLMRobertaForMaskedLM, WARPPromptedXLMRobertaForSequenceClassification, WARPPromptedXLMRobertaForQuestionAnswering 51 | # importing function for randomly masking inputs 52 | from utils import random_mask_input_ids 53 | 54 | # initialize model for MNLI task 55 | model = WARPPromptedRobertaForSequenceClassification( 56 | pretrained_backbone_path = "roberta-large", 57 | n_prompts = 8, 58 | seed_token_id_for_prompts_embeddings = 50264, # token id for "<mask>" 59 | mask_token_id = 50264, 60 | token_ids_for_classification_head = [1342, 12516, 10800], # 'ent', 'neutral', 'cont' 61 | pretrained_prompts_path = None, 62 | pretrained_classifier_path = None 63 | ) 64 | 65 | # initialize model for masked language modeling (MLM) 66 | model = WARPPromptedRobertaForMaskedLM( 67 | pretrained_backbone_path = "roberta-large", 68 | n_prompts = 8, 69 | seed_token_id_for_prompts_embeddings = 50264, 70 | pretrained_prompts_path = None 71 | ) 72 | 73 | # prepad input ids before feeding into model 74 | features = tokenizer([str_1, str_2, ..., str_n], return_tensors='pt', truncation=True, padding=True) 75 | features["input_ids"] = torch.cat([torch.full((features["input_ids"].shape[0], n_prompts), 0), features['input_ids']], 1) 76 | 77 | # randomly mask input ids for MLM task (token ids listed in `exceptions` are never masked) 78 | features['input_ids'] = random_mask_input_ids(features['input_ids'], mask_token_id, exceptions = tokenizer.all_special_ids, prob = .15) 79 | 80 | # initialize model for question answering (QA) 81 | model = WARPPromptedRobertaForQuestionAnswering( 82 | pretrained_backbone_path = "roberta-large", 83 | n_prompts = 4, 84 | seed_token_id_for_prompts_embeddings = 50264, 85 | pretrained_prompts_path = None, 86 | freeze_qa_outputs_layer = False, 87 | ) 88 | ``` 89 | 90 | # Reference 91 | - [WARP: Word-level Adversarial ReProgramming](https://aclanthology.org/2021.acl-long.381.pdf) 92 | - [The Power of Scale for Parameter-Efficient Prompt Tuning](https://arxiv.org/abs/2104.08691v1) 93 | 94 | 95 | # Contact 96 | Please contact [Jin Qu](mailto:jqu@salesforce.com) if you are interested in collaboration, internship opportunities, or discussions. Feel free to create issues if you discover a bug or want to request new features for future releases. 97 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | ## Security 2 | 3 | Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com) 4 | as soon as it is discovered.
This library limits its runtime dependencies in 5 | order to reduce the total cost of ownership as much as can be, but all consumers 6 | should remain vigilant and have their security stakeholders review all third-party 7 | products (3PP) like this one and their dependencies. 8 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/Overture/6ea974b495f35ba586540f737791204dce8cda35/__init__.py -------------------------------------------------------------------------------- /examples/train_warp_mnli.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | 3 | from torch.utils.data import DataLoader 4 | from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoModel 5 | import tqdm 6 | import logging 7 | import math 8 | import pandas as pd 9 | import argparse 10 | import random 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | from torch.utils.tensorboard import SummaryWriter 16 | 17 | import copy 18 | import json 19 | import sys 20 | 21 | sys.path.append("../") 22 | from utils import set_seed 23 | from models.modeling_roberta import ( 24 | WARPPromptedRobertaForMaskedLM, 25 | WARPPromptedRobertaForSequenceClassification, 26 | ) 27 | 28 | logging.basicConfig( 29 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 30 | datefmt="%m/%d/%Y %H:%M:%S", 31 | level=logging.INFO, 32 | ) 33 | logger = logging.getLogger(__name__) 34 | 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument("--n_prompts", default=20, type=int) 37 | parser.add_argument("--backbone_model", default="roberta-large", type=str) 38 | parser.add_argument("--pretrained_prompts_path", type=str) 39 | parser.add_argument("--pretrained_classifier_path", type=str) 40 | parser.add_argument("--save_prompts_path", type=str) 41 | parser.add_argument("--save_classifier_path", type=str) 42 | parser.add_argument("--seed", default=42, type=int) 43 | parser.add_argument("--train_batch_size", default=16, type=int) 44 | parser.add_argument("--eval_batch_size", default=32, type=int) 45 | parser.add_argument("--total_train_steps", default=250000, type=int) 46 | parser.add_argument( 47 | "--seed_token_id_for_prompts_embeddings", default=50264, type=int 48 | ) # used token id for "" 49 | parser.add_argument("--mask_token_id", default=50264, type=int) 50 | parser.add_argument( 51 | "--token_ids_for_classification_head", 52 | action="append", 53 | dest="token_ids_for_classification_head", 54 | default=[1342, 12516, 10800], 55 | ) #'ent', 'neutral', 'cont' 56 | parser.add_argument("--cycle", default=50000, type=int) 57 | parser.add_argument("--warmup_proportion", default=0.06, type=float) 58 | parser.add_argument("--patience", default=50, type=int) 59 | parser.add_argument("--learning_rate", default=0.003, type=float) 60 | parser.add_argument("--max_seq_length", default=512, type=int) 61 | parser.add_argument( 62 | "--gradient_accumulation_steps", 63 | type=int, 64 | default=1, 65 | help="Number of updates steps to accumulate before performing a backward/update pass.", 66 | ) 67 | parser.add_argument( 68 | "--no_cuda", action="store_true", help="Whether not to use CUDA when available" 69 | ) 70 | 71 | parser.add_argument("--use_tensorboard", action="store_true") 72 | parser.add_argument("--tensorboard_dir", type=str, default="./runs_mnli_cls") 73 | 74 | 
parser.add_argument("--tmult", default=1, type=int) 75 | parser.add_argument("--comment", default="", type=str) 76 | args = parser.parse_args() 77 | 78 | # set seed 79 | set_seed(args.seed) 80 | 81 | # load data 82 | dataset = load_dataset("multi_nli", split="train") 83 | dataset = pd.DataFrame.from_dict( 84 | { 85 | "premise": [d["premise"] for d in dataset], 86 | "hypothesis": [d["hypothesis"] for d in dataset], 87 | "gold_label": [ 88 | ("entailment", "neutral", "contradiction")[d["label"]] for d in dataset 89 | ], 90 | } 91 | ) 92 | valid_dataset = load_dataset("multi_nli", split="validation_matched") 93 | valid_dataset = pd.DataFrame.from_dict( 94 | { 95 | "premise": [d["premise"] for d in valid_dataset], 96 | "hypothesis": [d["hypothesis"] for d in valid_dataset], 97 | "gold_label": [ 98 | ("entailment", "neutral", "contradiction")[d["label"]] 99 | for d in valid_dataset 100 | ], 101 | } 102 | ) 103 | 104 | label_dict = {"entailment": 0, "neutral": 1, "contradiction": 2} 105 | 106 | # load tokenizer 107 | tokenizer = AutoTokenizer.from_pretrained(args.backbone_model) 108 | 109 | # set device 110 | device = torch.device( 111 | "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" 112 | ) 113 | print(f"device: {device}") 114 | 115 | # initialize model 116 | model = WARPPromptedRobertaForSequenceClassification( 117 | pretrained_backbone_path=args.backbone_model, 118 | n_prompts=args.n_prompts, 119 | seed_token_id_for_prompts_embeddings=args.seed_token_id_for_prompts_embeddings, 120 | mask_token_id=args.mask_token_id, 121 | token_ids_for_classification_head=args.token_ids_for_classification_head, 122 | pretrained_prompts_path=args.pretrained_prompts_path, 123 | pretrained_classifier_path=args.pretrained_classifier_path, 124 | ) 125 | 126 | 127 | # move model to device 128 | model.to(device) 129 | 130 | # set up scheduler 131 | optim = torch.optim.Adam( 132 | list(filter(lambda x: x.requires_grad, model.parameters())), lr=args.learning_rate 133 | ) 134 | warmup = int(args.warmup_proportion * args.cycle) 135 | scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( 136 | optim, args.cycle, T_mult=args.tmult 137 | ) 138 | 139 | middle_string = " " 140 | 141 | 142 | def validate(): 143 | model.eval() 144 | total_seen = 0 145 | total_correct = 0 146 | with torch.no_grad(): 147 | for _ in tqdm.tqdm(range(0, valid_dataset.shape[0], args.eval_batch_size)): 148 | exs = list(range(_, _ + args.eval_batch_size)) 149 | if exs[-1] >= valid_dataset.shape[0] or len(exs) < 2: 150 | print("BREAKING!") 151 | break 152 | total_seen += len(exs) 153 | 154 | valid_features = tokenizer( 155 | [ 156 | valid_dataset.premise.iloc[i] 157 | + "?" 
158 | + middle_string 159 | + valid_dataset.hypothesis.iloc[i] 160 | for i in exs 161 | ], 162 | return_tensors="pt", 163 | truncation=True, 164 | padding=True, 165 | max_length=args.max_seq_length - args.n_prompts, 166 | ) 167 | valid_features["input_ids"] = torch.cat( 168 | [ 169 | torch.full( 170 | (valid_features["input_ids"].shape[0], args.n_prompts), 0 171 | ), 172 | valid_features["input_ids"], 173 | ], 174 | 1, 175 | ) 176 | valid_features["input_ids"] = valid_features["input_ids"].to(device) 177 | valid_features["attention_mask"] = torch.cat( 178 | [ 179 | torch.full( 180 | (valid_features["input_ids"].shape[0], args.n_prompts), 1 181 | ), 182 | valid_features["attention_mask"], 183 | ], 184 | 1, 185 | ) 186 | valid_features["attention_mask"] = valid_features["attention_mask"].to( 187 | device 188 | ) 189 | 190 | lbl = ( 191 | torch.Tensor( 192 | [label_dict[valid_dataset.gold_label.iloc[i]] for i in exs] 193 | ) 194 | .type(torch.int64) 195 | .cuda() 196 | ) 197 | bleh = model(**valid_features) 198 | total_correct += torch.sum(bleh.argmax(1) == lbl).item() 199 | acc = total_correct / total_seen * 100.0 200 | print({"val_acc": acc}) 201 | return acc 202 | 203 | 204 | # set up tensorboard 205 | if args.use_tensorboard: 206 | writer = SummaryWriter(log_dir=f"{args.tensorboard_dir}/{args.comment}") 207 | 208 | # training loop 209 | model.eval() 210 | optim.zero_grad() 211 | best_model = None 212 | best_acc = 0.0 213 | cnt = 0 214 | 215 | for step in tqdm.tqdm(range(args.total_train_steps)): 216 | exs = random.sample(range(0, dataset.shape[0]), args.train_batch_size) 217 | train_features = tokenizer( 218 | [ 219 | dataset.premise.iloc[i] + "?" + middle_string + dataset.hypothesis.iloc[i] 220 | for i in exs 221 | ], 222 | return_tensors="pt", 223 | truncation=True, 224 | padding=True, 225 | max_length=args.max_seq_length - args.n_prompts, 226 | ) 227 | train_features["input_ids"] = torch.cat( 228 | [ 229 | torch.full((train_features["input_ids"].shape[0], args.n_prompts), 0), 230 | train_features["input_ids"], 231 | ], 232 | 1, 233 | ) 234 | train_features["input_ids"] = train_features["input_ids"].to(device) 235 | train_features["attention_mask"] = torch.cat( 236 | [ 237 | torch.full((train_features["input_ids"].shape[0], args.n_prompts), 1), 238 | train_features["attention_mask"], 239 | ], 240 | 1, 241 | ) 242 | train_features["attention_mask"] = train_features["attention_mask"].to(device) 243 | 244 | lbl = ( 245 | torch.Tensor([label_dict[dataset.gold_label.iloc[i]] for i in exs]) 246 | .type(torch.int64) 247 | .to(device) 248 | ) 249 | 250 | bleh = model(**train_features) 251 | 252 | loss = ( 253 | torch.nn.functional.cross_entropy(bleh, lbl) / args.gradient_accumulation_steps 254 | ) 255 | loss.backward() 256 | if (step % args.gradient_accumulation_steps) == 0: 257 | optim.step() 258 | scheduler.step() 259 | optim.zero_grad() 260 | 261 | if step % 1000 == 0: 262 | acc = validate() 263 | if acc > best_acc: 264 | best_acc = acc 265 | cnt = 0 266 | best_soft_embedding = copy.deepcopy( 267 | model.backbone.roberta.embeddings.word_embeddings 268 | ) 269 | best_classification_head = copy.deepcopy(model.classification_head) 270 | else: 271 | cnt += 1 272 | 273 | # log into tensorboard 274 | if args.use_tensorboard: 275 | writer.add_scalar( 276 | "avg train loss", loss.item() / args.train_batch_size, step 277 | ) 278 | writer.add_scalar("acc val", acc, step) 279 | 280 | if cnt >= args.patience: 281 | best_soft_embedding.save_pretrained_soft_prompts(args.save_prompts_path) 282 | 
best_classification_head.save_pretrained_classifier( 283 | args.save_classifier_path 284 | ) 285 | raise Exception("running out of patience!") 286 | -------------------------------------------------------------------------------- /figures/overture_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/Overture/6ea974b495f35ba586540f737791204dce8cda35/figures/overture_logo.png -------------------------------------------------------------------------------- /img/eval_acc_warp_roberta_mnli.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/Overture/6ea974b495f35ba586540f737791204dce8cda35/img/eval_acc_warp_roberta_mnli.png -------------------------------------------------------------------------------- /img/train_loss_warp_roberta_mnli.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/Overture/6ea974b495f35ba586540f737791204dce8cda35/img/train_loss_warp_roberta_mnli.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/Overture/6ea974b495f35ba586540f737791204dce8cda35/models/__init__.py -------------------------------------------------------------------------------- /models/modeling_bert.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import logging 4 | import torch 5 | import torch.nn as nn 6 | from transformers import ( 7 | AutoModelForMaskedLM, 8 | AutoTokenizer, 9 | AutoModelForQuestionAnswering, 10 | ) 11 | 12 | import sys 13 | 14 | sys.path.append("../") 15 | from soft_prompts import PromptedWordEmbeddings 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | def gelu(x): 21 | return ( 22 | 0.5 23 | * x 24 | * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 25 | ) 26 | 27 | 28 | # 1, can be used for MLM 29 | # 2, can also be used for MLM-style classification without additional classification layer, like "The Power of Scale for Parameter-Efficient Prompt Tuning" paper 30 | class WARPPromptedBertForMaskedLM(nn.Module): 31 | def __init__( 32 | self, 33 | pretrained_backbone_path, 34 | n_prompts, 35 | seed_token_id_for_prompts_embeddings, 36 | pretrained_prompts_path=None, 37 | ): 38 | """ 39 | pretrained_backbone_path: str, path to or name of backbone model, e.g. bert-large-uncased; 40 | n_prompts: int, number of prompts; 41 | seed_token_id_for_prompts_embeddings: int, use embedding of a specific token to initialize prompts weights, usually use mask token. 
42 | """ 43 | super(WARPPromptedBertForMaskedLM, self).__init__() 44 | self.backbone = AutoModelForMaskedLM.from_pretrained(pretrained_backbone_path) 45 | self.n_prompts = n_prompts 46 | # freeze backbone model 47 | for _, p in self.backbone.named_parameters(): 48 | p.requires_grad = False 49 | 50 | hidden_size = self.backbone.config.hidden_size 51 | original_word_embeddings = self.backbone.bert.embeddings.word_embeddings 52 | prompted_word_embeddings = PromptedWordEmbeddings( 53 | original_word_embeddings, 54 | n_prompts, 55 | hidden_size, 56 | seed_token_id_for_prompts_embeddings, 57 | ) 58 | if pretrained_prompts_path is not None: 59 | prompted_word_embeddings.load_from_pretrained_soft_prompts( 60 | pretrained_prompts_path 61 | ) 62 | logger.info( 63 | f"loaded pretrained soft prompts from: {pretrained_prompts_path}" 64 | ) 65 | 66 | self.backbone.bert.embeddings.word_embeddings = prompted_word_embeddings 67 | 68 | def forward(self, input_ids, attention_mask, token_type_ids, labels=None): 69 | return self.backbone(input_ids, attention_mask, token_type_ids, labels=labels) 70 | 71 | 72 | # classification head modified from https://github.com/huggingface/transformers/blob/5e3b4a70d3d17f2482d50aea230f7ed42b3a8fd0/src/transformers/models/bert/modeling_bert.py#L663 73 | class BertClassificationHead(nn.Module): 74 | """Bert Head for masked language modeling.""" 75 | 76 | def __init__(self, ori_lm_head, weight_tensors, hidden_size, prediction_dim): 77 | """ 78 | ori_lm_head: original lm_head from bert model, can be accessed by model.cls; 79 | weight_tensors: initialize final classifier layer with the specified weight tensors, usually from verbalzier token embeddings; 80 | hidden_size, int, backbone model hidden size; 81 | prediction_dim, int, output dimension of classifier layer. 
82 | """ 83 | super().__init__() 84 | 85 | self.classifier = torch.nn.Linear(hidden_size, prediction_dim, bias=True) 86 | self.classifier.weight = weight_tensors 87 | 88 | self.ori_lm_head = ori_lm_head 89 | self.ori_lm_head.predictions.decoder = self.classifier 90 | 91 | def load_from_pretrained_classifier(self, pretrained_classifier_path): 92 | path = os.path.join(pretrained_classifier_path, "classifier.pt") 93 | pretrained_classifier = torch.load(path) 94 | if ( 95 | pretrained_classifier.weight.shape == self.classifier.weight.shape 96 | and pretrained_classifier.bias.shape == self.classifier.bias.shape 97 | ): 98 | self.classifier = pretrained_classifier 99 | logger.info( 100 | f"loaded pretrained classifier from {pretrained_classifier_path}" 101 | ) 102 | else: 103 | raise Exception( 104 | f"pretrained classifier weights dimension: {pretrained_classifier.weight.shape}, bias dimension: {pretrained_classifier.bias.shape} \ 105 | but classifier initialized with {self.classifier.weight.shape} and {self.classifier.bias.shape}" 106 | ) 107 | 108 | def save_pretrained_classifier(self, save_directory): 109 | path = os.path.join(save_directory, "classifier.pt") 110 | if not os.path.isdir(save_directory): 111 | os.mkdir(save_directory) 112 | torch.save(self.classifier, path) 113 | logger.info(f"saved trained classifier at: {save_directory}") 114 | 115 | def forward(self, features, **kwargs): 116 | x = self.classifier(features) 117 | return x 118 | 119 | 120 | class WARPPromptedBertForSequenceClassification(nn.Module): 121 | def __init__( 122 | self, 123 | pretrained_backbone_path, 124 | n_prompts, 125 | seed_token_id_for_prompts_embeddings, 126 | mask_token_id, 127 | token_ids_for_classification_head, 128 | pretrained_prompts_path=None, 129 | pretrained_classifier_path=None, 130 | ): 131 | """ 132 | pretrained_backbone_path: str, path to or name of backbone model, e.g. bert-large-uncased; 133 | n_prompts: int, number of prompts; 134 | seed_token_id_for_prompts_embeddings: int, use embedding of a specific token to initialize prompts weights, usually use mask token; 135 | mask_token_id: int, token id for mask token, 103 for huggingface bert-large-uncased model; 136 | token_ids_for_classification_head: list of int, used for initilize classifier weights; 137 | pretrained_prompts_path: str or None, path to pretrained prompts; 138 | pretrained_classifier_path: str or None, path to pretrained classifier layer. 
139 | """ 140 | super(WARPPromptedBertForSequenceClassification, self).__init__() 141 | self.backbone = AutoModelForMaskedLM.from_pretrained(pretrained_backbone_path) 142 | self.n_prompts = n_prompts 143 | self.mask_token_id = mask_token_id 144 | # freeze backbone model 145 | for _, p in self.backbone.named_parameters(): 146 | p.requires_grad = False 147 | 148 | # modify embedding layer for soft prompts 149 | hidden_size = self.backbone.config.hidden_size 150 | original_word_embeddings = self.backbone.bert.embeddings.word_embeddings 151 | prompted_word_embeddings = PromptedWordEmbeddings( 152 | original_word_embeddings, 153 | n_prompts, 154 | hidden_size, 155 | seed_token_id_for_prompts_embeddings, 156 | ) 157 | if pretrained_prompts_path is not None: 158 | prompted_word_embeddings.load_from_pretrained_soft_prompts( 159 | pretrained_prompts_path 160 | ) 161 | 162 | self.backbone.bert.embeddings.word_embeddings = prompted_word_embeddings 163 | 164 | # classification head 165 | weights4lm_head = torch.nn.Parameter( 166 | self.backbone.bert.embeddings.word_embeddings.ori_emb.weight[ 167 | token_ids_for_classification_head 168 | ] 169 | ) 170 | prediction_dim = len(token_ids_for_classification_head) 171 | self.classification_head = BertClassificationHead( 172 | self.backbone.cls, weights4lm_head, hidden_size, prediction_dim 173 | ) 174 | if pretrained_classifier_path is not None: 175 | self.classification_head.load_from_pretrained_classifier( 176 | pretrained_classifier_path 177 | ) 178 | 179 | # remove original lm_head 180 | del self.backbone.cls 181 | 182 | def forward(self, input_ids, attention_mask, token_type_ids): 183 | before_classifier = self.backbone.bert( 184 | input_ids, attention_mask, token_type_ids 185 | )[0] 186 | mask_token_locations = torch.where(input_ids == self.mask_token_id) 187 | return self.classification_head(before_classifier[mask_token_locations]) 188 | 189 | 190 | class WARPPromptedBertForQuestionAnswering(nn.Module): 191 | def __init__( 192 | self, 193 | pretrained_backbone_path, 194 | n_prompts, 195 | seed_token_id_for_prompts_embeddings, 196 | pretrained_prompts_path=None, 197 | freeze_qa_outputs_layer=True, 198 | ): 199 | """ 200 | pretrained_backbone_path: str, path to or name of backbone model, e.g. bert-large-uncased; 201 | n_prompts: int, number of prompts; 202 | seed_token_id_for_prompts_embeddings: int, use embedding of a specific token to initialize prompts weights, usually use mask token. 
203 | """ 204 | super(WARPPromptedBertForQuestionAnswering, self).__init__() 205 | self.backbone = AutoModelForQuestionAnswering.from_pretrained( 206 | pretrained_backbone_path 207 | ) 208 | self.n_prompts = n_prompts 209 | # freeze backbone model and/except the final qa_output layer 210 | for n, p in self.backbone.named_parameters(): 211 | p.requires_grad = False 212 | if "qa_outputs" in n and not freeze_qa_outputs_layer: 213 | p.requires_grad = True 214 | 215 | hidden_size = self.backbone.config.hidden_size 216 | original_word_embeddings = self.backbone.roberta.embeddings.word_embeddings 217 | prompted_word_embeddings = PromptedWordEmbeddings( 218 | original_word_embeddings, 219 | n_prompts, 220 | hidden_size, 221 | seed_token_id_for_prompts_embeddings, 222 | ) 223 | if pretrained_prompts_path is not None: 224 | prompted_word_embeddings.load_from_pretrained_soft_prompts( 225 | pretrained_prompts_path 226 | ) 227 | logger.info( 228 | f"loaded pretrained soft prompts from: {pretrained_prompts_path}" 229 | ) 230 | 231 | self.backbone.bert.embeddings.word_embeddings = prompted_word_embeddings 232 | 233 | def forward( 234 | self, 235 | input_ids, 236 | attention_mask, 237 | token_type_ids, 238 | start_positions=None, 239 | end_positions=None, 240 | ): 241 | return self.backbone( 242 | input_ids, 243 | attention_mask, 244 | token_type_ids, 245 | start_positions=start_positions, 246 | end_positions=end_positions, 247 | ) 248 | -------------------------------------------------------------------------------- /models/modeling_roberta.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import logging 4 | import torch 5 | import torch.nn as nn 6 | from transformers import ( 7 | AutoModelForMaskedLM, 8 | AutoTokenizer, 9 | AutoModelForQuestionAnswering, 10 | ) 11 | 12 | import sys 13 | 14 | sys.path.append("../") 15 | from soft_prompts import PromptedWordEmbeddings 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | def gelu(x): 21 | return ( 22 | 0.5 23 | * x 24 | * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 25 | ) 26 | 27 | 28 | # 1, can be used for MLM 29 | # 2, can also be used for MLM-style classification without additional classification layer, like "The Power of Scale for Parameter-Efficient Prompt Tuning" paper 30 | class WARPPromptedRobertaForMaskedLM(nn.Module): 31 | def __init__( 32 | self, 33 | pretrained_backbone_path, 34 | n_prompts, 35 | seed_token_id_for_prompts_embeddings, 36 | pretrained_prompts_path=None, 37 | ): 38 | """ 39 | pretrained_backbone_path: str, path to or name of backbone model, e.g. roberta-large; 40 | n_prompts: int, number of prompts; 41 | seed_token_id_for_prompts_embeddings: int, use embedding of a specific token to initialize prompts weights, usually use mask token. 
42 | """ 43 | super(WARPPromptedRobertaForMaskedLM, self).__init__() 44 | self.backbone = AutoModelForMaskedLM.from_pretrained(pretrained_backbone_path) 45 | self.n_prompts = n_prompts 46 | # freeze backbone model 47 | for _, p in self.backbone.named_parameters(): 48 | p.requires_grad = False 49 | 50 | hidden_size = self.backbone.config.hidden_size 51 | original_word_embeddings = self.backbone.roberta.embeddings.word_embeddings 52 | prompted_word_embeddings = PromptedWordEmbeddings( 53 | original_word_embeddings, 54 | n_prompts, 55 | hidden_size, 56 | seed_token_id_for_prompts_embeddings, 57 | ) 58 | if pretrained_prompts_path is not None: 59 | prompted_word_embeddings.load_from_pretrained_soft_prompts( 60 | pretrained_prompts_path 61 | ) 62 | logger.info( 63 | f"loaded pretrained soft prompts from: {pretrained_prompts_path}" 64 | ) 65 | 66 | self.backbone.roberta.embeddings.word_embeddings = prompted_word_embeddings 67 | 68 | def forward(self, input_ids, attention_mask, labels=None): 69 | return self.backbone(input_ids, attention_mask, labels=labels) 70 | 71 | 72 | # classification head modified from https://github.com/huggingface/transformers/blob/5e3b4a70d3d17f2482d50aea230f7ed42b3a8fd0/src/transformers/models/roberta/modeling_roberta.py#L1123 73 | class RobertaClassificationHead(nn.Module): 74 | """Roberta Head for masked language modeling.""" 75 | 76 | def __init__(self, ori_lm_head, weight_tensors, hidden_size, prediction_dim): 77 | """ 78 | ori_lm_head: original lm_head from roberta model, can be accessed by model.lm_head; 79 | weight_tensors: initialize final classifier layer with the specified weight tensors, usually from verbalzier token embeddings; 80 | hidden_size, int, backbone model hidden size; 81 | prediction_dim, int, output dimension of classifier layer. 
82 | """ 83 | super().__init__() 84 | self.dense = ori_lm_head.dense 85 | self.layer_norm = ori_lm_head.layer_norm 86 | self.bias = ori_lm_head.bias 87 | 88 | self.classifier = torch.nn.Linear(hidden_size, prediction_dim, bias=True) 89 | self.classifier.weight = weight_tensors 90 | 91 | def load_from_pretrained_classifier(self, pretrained_classifier_path): 92 | path = os.path.join(pretrained_classifier_path, "classifier.pt") 93 | pretrained_classifier = torch.load(path) 94 | if ( 95 | pretrained_classifier.weight.shape == self.classifier.weight.shape 96 | and pretrained_classifier.bias.shape == self.classifier.bias.shape 97 | ): 98 | self.classifier = pretrained_classifier 99 | logger.info( 100 | f"loaded pretrained classifier from {pretrained_classifier_path}" 101 | ) 102 | else: 103 | raise Exception( 104 | f"pretrained classifier weights dimension: {pretrained_classifier.weight.shape}, bias dimension: {pretrained_classifier.bias.shape} \ 105 | but classifier initialized with {self.classifier.weight.shape} and {self.classifier.bias.shape}" 106 | ) 107 | 108 | def save_pretrained_classifier(self, save_directory): 109 | path = os.path.join(save_directory, "classifier.pt") 110 | if not os.path.isdir(save_directory): 111 | os.mkdir(save_directory) 112 | torch.save(self.classifier, path) 113 | logger.info(f"saved trained classifier at: {save_directory}") 114 | 115 | def forward(self, features, **kwargs): 116 | x = self.dense(features) 117 | x = gelu(x) 118 | x = self.layer_norm(x) 119 | return self.classifier(x) 120 | 121 | def _tie_weights(self): 122 | # To tie those two weights if they get disconnected (on TPU or when the bias is resized) 123 | self.bias = self.decoder.bias 124 | 125 | 126 | class WARPPromptedRobertaForSequenceClassification(nn.Module): 127 | def __init__( 128 | self, 129 | pretrained_backbone_path, 130 | n_prompts, 131 | seed_token_id_for_prompts_embeddings, 132 | mask_token_id, 133 | token_ids_for_classification_head, 134 | pretrained_prompts_path=None, 135 | pretrained_classifier_path=None, 136 | ): 137 | """ 138 | pretrained_backbone_path: str, path to or name of backbone model, e.g. roberta-large; 139 | n_prompts: int, number of prompts; 140 | seed_token_id_for_prompts_embeddings: int, use embedding of a specific token to initialize prompts weights, usually use mask token; 141 | mask_token_id: int, token id for mask token, 50264 for huggingface roberta model; 142 | token_ids_for_classification_head: list of int, used for initilize classifier weights; 143 | pretrained_prompts_path: str or None, path to pretrained prompts; 144 | pretrained_classifier_path: str or None, path to pretrained classifier layer. 
145 | """ 146 | super(WARPPromptedRobertaForSequenceClassification, self).__init__() 147 | self.backbone = AutoModelForMaskedLM.from_pretrained(pretrained_backbone_path) 148 | self.n_prompts = n_prompts 149 | self.mask_token_id = mask_token_id 150 | # freeze backbone model 151 | for _, p in self.backbone.named_parameters(): 152 | p.requires_grad = False 153 | 154 | # modify embedding layer for soft prompts 155 | hidden_size = self.backbone.config.hidden_size 156 | original_word_embeddings = self.backbone.roberta.embeddings.word_embeddings 157 | prompted_word_embeddings = PromptedWordEmbeddings( 158 | original_word_embeddings, 159 | n_prompts, 160 | hidden_size, 161 | seed_token_id_for_prompts_embeddings, 162 | ) 163 | if pretrained_prompts_path is not None: 164 | prompted_word_embeddings.load_from_pretrained_soft_prompts( 165 | pretrained_prompts_path 166 | ) 167 | 168 | self.backbone.roberta.embeddings.word_embeddings = prompted_word_embeddings 169 | 170 | # classification head 171 | weights4lm_head = torch.nn.Parameter( 172 | self.backbone.roberta.embeddings.word_embeddings.ori_emb.weight[ 173 | token_ids_for_classification_head 174 | ] 175 | ) 176 | prediction_dim = len(token_ids_for_classification_head) 177 | self.classification_head = RobertaClassificationHead( 178 | self.backbone.lm_head, weights4lm_head, hidden_size, prediction_dim 179 | ) 180 | if pretrained_classifier_path is not None: 181 | self.classification_head.load_from_pretrained_classifier( 182 | pretrained_classifier_path 183 | ) 184 | 185 | # remove original lm_head 186 | del self.backbone.lm_head 187 | 188 | def forward(self, input_ids, attention_mask): 189 | before_classifier = self.backbone.roberta(input_ids, attention_mask)[0] 190 | mask_token_locations = torch.where(input_ids == self.mask_token_id) 191 | return self.classification_head(before_classifier[mask_token_locations]) 192 | 193 | 194 | class WARPPromptedRobertaForQuestionAnswering(nn.Module): 195 | def __init__( 196 | self, 197 | pretrained_backbone_path, 198 | n_prompts, 199 | seed_token_id_for_prompts_embeddings, 200 | pretrained_prompts_path=None, 201 | freeze_qa_outputs_layer=True, 202 | ): 203 | """ 204 | pretrained_backbone_path: str, path to or name of backbone model, e.g. roberta-large; 205 | n_prompts: int, number of prompts; 206 | seed_token_id_for_prompts_embeddings: int, use embedding of a specific token to initialize prompts weights, usually use mask token. 
207 | """ 208 | super(WARPPromptedRobertaForQuestionAnswering, self).__init__() 209 | self.backbone = AutoModelForQuestionAnswering.from_pretrained( 210 | pretrained_backbone_path 211 | ) 212 | self.n_prompts = n_prompts 213 | # freeze backbone model and/except the final qa_output layer 214 | for n, p in self.backbone.named_parameters(): 215 | p.requires_grad = False 216 | if "qa_outputs" in n and not freeze_qa_outputs_layer: 217 | p.requires_grad = True 218 | 219 | hidden_size = self.backbone.config.hidden_size 220 | original_word_embeddings = self.backbone.roberta.embeddings.word_embeddings 221 | prompted_word_embeddings = PromptedWordEmbeddings( 222 | original_word_embeddings, 223 | n_prompts, 224 | hidden_size, 225 | seed_token_id_for_prompts_embeddings, 226 | ) 227 | if pretrained_prompts_path is not None: 228 | prompted_word_embeddings.load_from_pretrained_soft_prompts( 229 | pretrained_prompts_path 230 | ) 231 | logger.info( 232 | f"loaded pretrained soft prompts from: {pretrained_prompts_path}" 233 | ) 234 | 235 | self.backbone.roberta.embeddings.word_embeddings = prompted_word_embeddings 236 | 237 | def forward( 238 | self, input_ids, attention_mask, start_positions=None, end_positions=None 239 | ): 240 | return self.backbone( 241 | input_ids, 242 | attention_mask, 243 | start_positions=start_positions, 244 | end_positions=end_positions, 245 | ) 246 | -------------------------------------------------------------------------------- /models/modeling_xlm_roberta.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import logging 4 | import torch 5 | import torch.nn as nn 6 | from transformers import ( 7 | AutoModelForMaskedLM, 8 | AutoTokenizer, 9 | AutoModelForQuestionAnswering, 10 | ) 11 | 12 | import sys 13 | 14 | sys.path.append("../") 15 | from soft_prompts import PromptedWordEmbeddings 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | def gelu(x): 21 | return ( 22 | 0.5 23 | * x 24 | * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 25 | ) 26 | 27 | 28 | # 1, can be used for MLM 29 | # 2, can also be used for MLM-style classification without additional classification layer, like "The Power of Scale for Parameter-Efficient Prompt Tuning" paper 30 | class WARPPromptedXLMRobertaForMaskedLM(nn.Module): 31 | def __init__( 32 | self, 33 | pretrained_backbone_path, 34 | n_prompts, 35 | seed_token_id_for_prompts_embeddings, 36 | pretrained_prompts_path=None, 37 | ): 38 | """ 39 | pretrained_backbone_path: str, path to or name of backbone model, e.g. xlm-roberta-large; 40 | n_prompts: int, number of prompts; 41 | seed_token_id_for_prompts_embeddings: int, use embedding of a specific token to initialize prompts weights, usually use mask token. 
42 | """ 43 | super(WARPPromptedXLMRobertaForMaskedLM, self).__init__() 44 | self.backbone = AutoModelForMaskedLM.from_pretrained(pretrained_backbone_path) 45 | self.n_prompts = n_prompts 46 | # freeze backbone model 47 | for _, p in self.backbone.named_parameters(): 48 | p.requires_grad = False 49 | 50 | hidden_size = self.backbone.config.hidden_size 51 | original_word_embeddings = self.backbone.roberta.embeddings.word_embeddings 52 | prompted_word_embeddings = PromptedWordEmbeddings( 53 | original_word_embeddings, 54 | n_prompts, 55 | hidden_size, 56 | seed_token_id_for_prompts_embeddings, 57 | ) 58 | if pretrained_prompts_path is not None: 59 | prompted_word_embeddings.load_from_pretrained_soft_prompts( 60 | pretrained_prompts_path 61 | ) 62 | logger.info( 63 | f"loaded pretrained soft prompts from: {pretrained_prompts_path}" 64 | ) 65 | 66 | self.backbone.roberta.embeddings.word_embeddings = prompted_word_embeddings 67 | 68 | def forward(self, input_ids, attention_mask, labels=None): 69 | return self.backbone(input_ids, attention_mask, labels=labels) 70 | 71 | 72 | # classification head modified from https://github.com/huggingface/transformers/blob/5e3b4a70d3d17f2482d50aea230f7ed42b3a8fd0/src/transformers/models/roberta/modeling_roberta.py#L1123 73 | class XLMRobertaClassificationHead(nn.Module): 74 | """XLMRoberta Head for masked language modeling.""" 75 | 76 | def __init__(self, ori_lm_head, weight_tensors, hidden_size, prediction_dim): 77 | """ 78 | ori_lm_head: original lm_head from roberta model, can be accessed by model.lm_head; 79 | weight_tensors: initialize final classifier layer with the specified weight tensors, usually from verbalzier token embeddings; 80 | hidden_size, int, backbone model hidden size; 81 | prediction_dim, int, output dimension of classifier layer. 
82 | """ 83 | super().__init__() 84 | self.dense = ori_lm_head.dense 85 | self.layer_norm = ori_lm_head.layer_norm 86 | self.bias = ori_lm_head.bias 87 | 88 | self.classifier = torch.nn.Linear(hidden_size, prediction_dim, bias=True) 89 | self.classifier.weight = weight_tensors 90 | 91 | def load_from_pretrained_classifier(self, pretrained_classifier_path): 92 | path = os.path.join(pretrained_classifier_path, "classifier.pt") 93 | pretrained_classifier = torch.load(path) 94 | if ( 95 | pretrained_classifier.weight.shape == self.classifier.weight.shape 96 | and pretrained_classifier.bias.shape == self.classifier.bias.shape 97 | ): 98 | self.classifier = pretrained_classifier 99 | logger.info( 100 | f"loaded pretrained classifier from {pretrained_classifier_path}" 101 | ) 102 | else: 103 | raise Exception( 104 | f"pretrained classifier weights dimension: {pretrained_classifier.weight.shape}, bias dimension: {pretrained_classifier.bias.shape} \ 105 | but classifier initialized with {self.classifier.weight.shape} and {self.classifier.bias.shape}" 106 | ) 107 | 108 | def save_pretrained_classifier(self, save_directory): 109 | path = os.path.join(save_directory, "classifier.pt") 110 | if not os.path.isdir(save_directory): 111 | os.mkdir(save_directory) 112 | torch.save(self.classifier, path) 113 | logger.info(f"saved trained classifier at: {save_directory}") 114 | 115 | def forward(self, features, **kwargs): 116 | x = self.dense(features) 117 | x = gelu(x) 118 | x = self.layer_norm(x) 119 | return self.classifier(x) 120 | 121 | def _tie_weights(self): 122 | # To tie those two weights if they get disconnected (on TPU or when the bias is resized) 123 | self.bias = self.decoder.bias 124 | 125 | 126 | class WARPPromptedXLMRobertaForSequenceClassification(nn.Module): 127 | def __init__( 128 | self, 129 | pretrained_backbone_path, 130 | n_prompts, 131 | seed_token_id_for_prompts_embeddings, 132 | mask_token_id, 133 | token_ids_for_classification_head, 134 | pretrained_prompts_path=None, 135 | pretrained_classifier_path=None, 136 | ): 137 | """ 138 | pretrained_backbone_path: str, path to or name of backbone model, e.g. xlm-roberta-large; 139 | n_prompts: int, number of prompts; 140 | seed_token_id_for_prompts_embeddings: int, use embedding of a specific token to initialize prompts weights, usually use mask token; 141 | mask_token_id: int, token id for mask token, 250001 for huggingface xlm-roberta model; 142 | token_ids_for_classification_head: list of int, used for initilize classifier weights; 143 | pretrained_prompts_path: str or None, path to pretrained prompts; 144 | pretrained_classifier_path: str or None, path to pretrained classifier layer. 
145 | """ 146 | super(WARPPromptedXLMRobertaForSequenceClassification, self).__init__() 147 | self.backbone = AutoModelForMaskedLM.from_pretrained(pretrained_backbone_path) 148 | self.n_prompts = n_prompts 149 | self.mask_token_id = mask_token_id 150 | # freeze backbone model 151 | for _, p in self.backbone.named_parameters(): 152 | p.requires_grad = False 153 | 154 | # modify embedding layer for soft prompts 155 | hidden_size = self.backbone.config.hidden_size 156 | original_word_embeddings = self.backbone.roberta.embeddings.word_embeddings 157 | prompted_word_embeddings = PromptedWordEmbeddings( 158 | original_word_embeddings, 159 | n_prompts, 160 | hidden_size, 161 | seed_token_id_for_prompts_embeddings, 162 | ) 163 | if pretrained_prompts_path is not None: 164 | prompted_word_embeddings.load_from_pretrained_soft_prompts( 165 | pretrained_prompts_path 166 | ) 167 | 168 | self.backbone.roberta.embeddings.word_embeddings = prompted_word_embeddings 169 | 170 | # classification head 171 | weights4lm_head = torch.nn.Parameter( 172 | self.backbone.roberta.embeddings.word_embeddings.ori_emb.weight[ 173 | token_ids_for_classification_head 174 | ] 175 | ) 176 | prediction_dim = len(token_ids_for_classification_head) 177 | self.classification_head = XLMRobertaClassificationHead( 178 | self.backbone.lm_head, weights4lm_head, hidden_size, prediction_dim 179 | ) 180 | if pretrained_classifier_path is not None: 181 | self.classification_head.load_from_pretrained_classifier( 182 | pretrained_classifier_path 183 | ) 184 | 185 | # remove original lm_head 186 | del self.backbone.lm_head 187 | 188 | def forward(self, input_ids, attention_mask): 189 | before_classifier = self.backbone.roberta(input_ids, attention_mask)[0] 190 | mask_token_locations = torch.where(input_ids == self.mask_token_id) 191 | return self.classification_head(before_classifier[mask_token_locations]) 192 | 193 | 194 | class WARPPromptedXLMRobertaForQuestionAnswering(nn.Module): 195 | def __init__( 196 | self, 197 | pretrained_backbone_path, 198 | n_prompts, 199 | seed_token_id_for_prompts_embeddings, 200 | pretrained_prompts_path=None, 201 | freeze_qa_outputs_layer=True, 202 | ): 203 | """ 204 | pretrained_backbone_path: str, path to or name of backbone model, e.g. xlm-roberta-large; 205 | n_prompts: int, number of prompts; 206 | seed_token_id_for_prompts_embeddings: int, use embedding of a specific token to initialize prompts weights, usually use mask token. 
207 | """ 208 | super(WARPPromptedXLMRobertaForQuestionAnswering, self).__init__() 209 | self.backbone = AutoModelForQuestionAnswering.from_pretrained( 210 | pretrained_backbone_path 211 | ) 212 | self.n_prompts = n_prompts 213 | # freeze backbone model and/except the final qa_output layer 214 | for n, p in self.backbone.named_parameters(): 215 | p.requires_grad = False 216 | if "qa_outputs" in n and not freeze_qa_outputs_layer: 217 | p.requires_grad = True 218 | 219 | hidden_size = self.backbone.config.hidden_size 220 | original_word_embeddings = self.backbone.roberta.embeddings.word_embeddings 221 | prompted_word_embeddings = PromptedWordEmbeddings( 222 | original_word_embeddings, 223 | n_prompts, 224 | hidden_size, 225 | seed_token_id_for_prompts_embeddings, 226 | ) 227 | if pretrained_prompts_path is not None: 228 | prompted_word_embeddings.load_from_pretrained_soft_prompts( 229 | pretrained_prompts_path 230 | ) 231 | logger.info( 232 | f"loaded pretrained soft prompts from: {pretrained_prompts_path}" 233 | ) 234 | 235 | self.backbone.roberta.embeddings.word_embeddings = prompted_word_embeddings 236 | 237 | def forward( 238 | self, input_ids, attention_mask, start_positions=None, end_positions=None 239 | ): 240 | return self.backbone( 241 | input_ids, 242 | attention_mask, 243 | start_positions=start_positions, 244 | end_positions=end_positions, 245 | ) 246 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.11.3 2 | pandas==1.3.4 3 | tqdm==4.62.3 4 | datasets==1.13.3 5 | numpy==1.21.2 6 | torch==1.7.1 7 | scikit-learn==0.24.2 -------------------------------------------------------------------------------- /soft_prompts.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import torch 4 | import torch.nn as nn 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class PromptedWordEmbeddings(nn.Module): 10 | def __init__( 11 | self, 12 | original_word_embeddings, 13 | n_prompts, 14 | hidden_size, 15 | seed_token_id_for_prompts_embeddings, 16 | ): 17 | """ 18 | original_word_embeddings: word embedding layer from backbone transformer model; 19 | n_prompts: int, number of soft prompts; 20 | hidden_size: int, should be same as backbone transformer model hidden size for embedding; 21 | seed_token_id_for_prompts_embeddings: soft prompts will be initialized with the weights of this token embedding, usually use mask token. 
22 | """ 23 | super(PromptedWordEmbeddings, self).__init__() 24 | self.ori_emb = original_word_embeddings 25 | self.n_prompts = n_prompts 26 | self.soft_prompts = ( 27 | torch.zeros(n_prompts, hidden_size) 28 | + original_word_embeddings.weight[seed_token_id_for_prompts_embeddings] 29 | .clone() 30 | .detach() 31 | ) 32 | self.soft_prompts = nn.Parameter(self.soft_prompts, requires_grad=True) 33 | logger.info( 34 | f"initialized soft prompts with dimension: {self.soft_prompts.shape}" 35 | ) 36 | 37 | def load_from_pretrained_soft_prompts(self, pretrained_prompts_path): 38 | pretrained_soft_prompts = torch.load(f"{pretrained_prompts_path}/prompts.pt") 39 | if pretrained_soft_prompts.shape[0] == self.n_prompts: 40 | self.soft_prompts = pretrained_soft_prompts 41 | logger.info( 42 | f"loaded pretrained soft prompts from {pretrained_prompts_path}" 43 | ) 44 | else: 45 | raise Exception( 46 | f"pretrained soft prompts dimension: {pretrained_soft_prompts.shape}, but initialized with {self.soft_prompts.shape}" 47 | ) 48 | 49 | def save_pretrained_soft_prompts(self, save_directory): 50 | path = os.path.join(save_directory, "prompts.pt") 51 | if not os.path.isdir(save_directory): 52 | os.mkdir(save_directory) 53 | torch.save(self.soft_prompts, path) 54 | logger.info(f"saved trained soft prompts at {save_directory}") 55 | 56 | def forward(self, prepadded_input_ids): 57 | """ 58 | prepadded_input_ids: input_ids after tokenization + prepadded tensors as placeholder for prompts 59 | e.g. torch.cat([torch.full((features["input_ids"].shape[0], n_prompts), 0), features['input_ids']], 1) 60 | """ 61 | emb = self.ori_emb(prepadded_input_ids[:, self.n_prompts :]) 62 | expanded_prompts = self.soft_prompts.repeat(prepadded_input_ids.shape[0], 1, 1) 63 | return torch.cat([expanded_prompts, emb], 1) 64 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | 5 | 6 | def set_seed(seed=0): 7 | np.random.seed(seed) 8 | random.seed(seed) 9 | torch.manual_seed(seed) 10 | torch.backends.cudnn.deterministic = True 11 | torch.backends.cudnn.benchmark = False 12 | return True 13 | 14 | 15 | # generate randomly masked input_ids for MLM task 16 | # modified from https://towardsdatascience.com/masked-language-modelling-with-bert-7d49793e5d2c 17 | def random_mask_input_ids(input_ids, mask_token_id, exceptions, prob=0.15): 18 | """ 19 | exceptions: list, token ids that should not be masked 20 | """ 21 | probs = torch.rand(input_ids.shape) 22 | mask = probs < prob 23 | for ex_id in exceptions: 24 | mask = mask * (input_ids != ex_id) 25 | selection = [] 26 | for i in range(input_ids.shape[0]): 27 | selection.append(torch.flatten(mask[i].nonzero()).tolist()) 28 | for i in range(input_ids.shape[0]): 29 | input_ids[i, selection[i]] = mask_token_id 30 | return input_ids 31 | --------------------------------------------------------------------------------