├── .gitignore
├── CODEOWNERS
├── CODE_OF_CONDUCT.md
├── LICENSE.txt
├── README.md
├── SECURITY.md
├── __init__.py
├── examples
│   └── train_warp_mnli.py
├── figures
│   └── overture_logo.png
├── img
│   ├── eval_acc_warp_roberta_mnli.png
│   └── train_loss_warp_roberta_mnli.png
├── models
│   ├── __init__.py
│   ├── modeling_bert.py
│   ├── modeling_roberta.py
│   └── modeling_xlm_roberta.py
├── requirements.txt
├── soft_prompts.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # Others
132 | *.DS_Store
133 |
--------------------------------------------------------------------------------
/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing.
2 | #ECCN:Open Source
3 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Salesforce Open Source Community Code of Conduct
2 |
3 | ## About the Code of Conduct
4 |
5 | Equality is a core value at Salesforce. We believe a diverse and inclusive
6 | community fosters innovation and creativity, and are committed to building a
7 | culture where everyone feels included.
8 |
9 | Salesforce open-source projects are committed to providing a friendly, safe, and
10 | welcoming environment for all, regardless of gender identity and expression,
11 | sexual orientation, disability, physical appearance, body size, ethnicity, nationality,
12 | race, age, religion, level of experience, education, socioeconomic status, or
13 | other similar personal characteristics.
14 |
15 | The goal of this code of conduct is to specify a baseline standard of behavior so
16 | that people with different social values and communication styles can work
17 | together effectively, productively, and respectfully in our open source community.
18 | It also establishes a mechanism for reporting issues and resolving conflicts.
19 |
20 | All questions and reports of abusive, harassing, or otherwise unacceptable behavior
21 | in a Salesforce open-source project may be reported by contacting the Salesforce
22 | Open Source Conduct Committee at ossconduct@salesforce.com.
23 |
24 | ## Our Pledge
25 |
26 | In the interest of fostering an open and welcoming environment, we as
27 | contributors and maintainers pledge to making participation in our project and
28 | our community a harassment-free experience for everyone, regardless of gender
29 | identity and expression, sexual orientation, disability, physical appearance,
30 | body size, ethnicity, nationality, race, age, religion, level of experience, education,
31 | socioeconomic status, or other similar personal characteristics.
32 |
33 | ## Our Standards
34 |
35 | Examples of behavior that contributes to creating a positive environment
36 | include:
37 |
38 | * Using welcoming and inclusive language
39 | * Being respectful of differing viewpoints and experiences
40 | * Gracefully accepting constructive criticism
41 | * Focusing on what is best for the community
42 | * Showing empathy toward other community members
43 |
44 | Examples of unacceptable behavior by participants include:
45 |
46 | * The use of sexualized language or imagery and unwelcome sexual attention or
47 | advances
48 | * Personal attacks, insulting/derogatory comments, or trolling
49 | * Public or private harassment
50 | * Publishing, or threatening to publish, others' private information—such as
51 | a physical or electronic address—without explicit permission
52 | * Other conduct which could reasonably be considered inappropriate in a
53 | professional setting
54 | * Advocating for or encouraging any of the above behaviors
55 |
56 | ## Our Responsibilities
57 |
58 | Project maintainers are responsible for clarifying the standards of acceptable
59 | behavior and are expected to take appropriate and fair corrective action in
60 | response to any instances of unacceptable behavior.
61 |
62 | Project maintainers have the right and responsibility to remove, edit, or
63 | reject comments, commits, code, wiki edits, issues, and other contributions
64 | that are not aligned with this Code of Conduct, or to ban temporarily or
65 | permanently any contributor for other behaviors that they deem inappropriate,
66 | threatening, offensive, or harmful.
67 |
68 | ## Scope
69 |
70 | This Code of Conduct applies both within project spaces and in public spaces
71 | when an individual is representing the project or its community. Examples of
72 | representing a project or community include using an official project email
73 | address, posting via an official social media account, or acting as an appointed
74 | representative at an online or offline event. Representation of a project may be
75 | further defined and clarified by project maintainers.
76 |
77 | ## Enforcement
78 |
79 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
80 | reported by contacting the Salesforce Open Source Conduct Committee
81 | at ossconduct@salesforce.com. All complaints will be reviewed and investigated
82 | and will result in a response that is deemed necessary and appropriate to the
83 | circumstances. The committee is obligated to maintain confidentiality with
84 | regard to the reporter of an incident. Further details of specific enforcement
85 | policies may be posted separately.
86 |
87 | Project maintainers who do not follow or enforce the Code of Conduct in good
88 | faith may face temporary or permanent repercussions as determined by other
89 | members of the project's leadership and the Salesforce Open Source Conduct
90 | Committee.
91 |
92 | ## Attribution
93 |
94 | This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home],
95 | version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html.
96 | It includes adaptations and additions from [Go Community Code of Conduct][golang-coc],
97 | [CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc].
98 |
99 | This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us].
100 |
101 | [contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/)
102 | [golang-coc]: https://golang.org/conduct
103 | [cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md
104 | [microsoft-coc]: https://opensource.microsoft.com/codeofconduct/
105 | [cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/
106 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2021, Salesforce.com, Inc.
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9 |
10 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
11 |
12 | 3. Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
13 |
14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
15 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Project Overture - A Prompt-Tuning Library for Researchers
2 | > Why name it **Overture**? An overture in music is an orchestral piece at the beginning which sets the mood and tone for what's about to come. We think of prompt tuning as analogous to that; in the types of prompt tuning methods we consider, a prompt is prepended to the input that sets the tone for the downstream task.
3 |
4 |
5 |
6 |
7 | # Introduction
8 | Prompt tuning has recently become an important research direction in Natural Language Processing. In contrast to classical fine-tuning, which optimizes the weights of the entire network, (one style of) prompt tuning keeps the large language model (a.k.a. the "backbone") frozen and instead prepends a few learnable vectors to each input; only these vectors are trained to accomplish the downstream task. This reduces the number of trainable parameters from O(millions) to a few thousand while achieving similar levels of performance. The research community has also reported other benefits of prompt-tuned models over classically fine-tuned ones.
9 |
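To make the mechanics concrete, below is a minimal, self-contained sketch of the idea. It is illustrative only and assumes a `roberta-large` backbone; it does not use this repository's classes, which instead swap the backbone's word-embedding layer (see `soft_prompts.py`).

```python
import torch
import torch.nn as nn
from transformers import AutoModelForMaskedLM, AutoTokenizer

# Load a backbone and freeze it; only the prompt vectors below will be trained.
backbone = AutoModelForMaskedLM.from_pretrained("roberta-large")
for p in backbone.parameters():
    p.requires_grad = False

n_prompts, hidden_size = 8, backbone.config.hidden_size
# A few learnable vectors ("soft prompts") prepended to every input.
soft_prompts = nn.Parameter(torch.randn(n_prompts, hidden_size) * 0.02)
optimizer = torch.optim.Adam([soft_prompts], lr=3e-3)  # only the prompts are optimized

tokenizer = AutoTokenizer.from_pretrained("roberta-large")
features = tokenizer("A cat sits on the mat.", return_tensors="pt")

# Prepend the prompts in embedding space and extend the attention mask to match.
token_embeds = backbone.get_input_embeddings()(features["input_ids"])
inputs_embeds = torch.cat([soft_prompts.unsqueeze(0), token_embeds], dim=1)
attention_mask = torch.cat(
    [torch.ones(1, n_prompts, dtype=torch.long), features["attention_mask"]], dim=1
)

outputs = backbone(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
print(outputs.logits.shape)  # (1, n_prompts + seq_len, vocab_size)
```
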
10 | # Methods Supported
11 | The repository leverages the HuggingFace Transformers library. Currently, we support [WARP-like](https://arxiv.org/abs/2101.00121) prompt tuning for masked language modeling (MLM), text classification, and extractive question answering (e.g., SQuAD). We plan to add support for [Seq2Seq prompt tuning](https://arxiv.org/abs/2104.08691v1) soon. If there is any other algorithm/method that you would like us to prioritize, please write to us or file a feature request. Finally, we refer the interested reader to the [excellent survey](http://pretrain.nlpedia.ai/) on the topic for the various types of prompt-tuning methods and their history.
12 |
13 | # Some Potential Extensions
14 | Here are some research ideas one could experiment with using our codebase. Since the community is evolving rapidly, it is entirely possible that some of these ideas have already been studied. Please file an issue if that is the case, or if you would like to contribute more ideas.
15 |
16 | 1. Does prompt tuning on a multilingual backbone (e.g., mBERT or XLM) lead to models that can perform cross-lingual zero-shot transfer?
17 | 2. How can we make the prompts more interpretable? Could adding a loss that encourages the prompt vectors to stay close to existing word embeddings help?
18 | 3. Can prompts learned for BERT-Large help learn prompts for RoBERTa-Large?
19 |
20 | # Design Choices & Other Similar Libraries
21 |
22 | Fundamentally, we designed the repository so that researchers can easily experiment with ideas within the realm of prompt tuning. As such, we intentionally do not abstract away the sub-components. The repository is intended to be a fork-and-edit library and is designed to be easily extensible for the kinds of projects we anticipate people will use it for.
23 |
24 | A recently released library, [OpenPrompt](https://github.com/thunlp/OpenPrompt), also targets prompt tuning, and we refer the interested practitioner to their repository for further exploration and comparisons. OpenPrompt may be a better fit for those who seek greater abstraction.
25 |
26 | # How to Use
27 | Inside the examples folder, we provide training code for a RoBERTa-Large model on the MNLI dataset (in the style of [WARP](https://arxiv.org/abs/2101.00121)). To start training:
28 | ```bash
29 | CUDA_VISIBLE_DEVICES=0 python train_warp_mnli.py --save_prompts_path dir_to_save_prompts --save_classifier_path dir_to_save_classifier
30 | ```
31 |
32 | After training, the user should expect the model's accuracy to be 87-89%, which matches the results of the original [WARP](https://arxiv.org/abs/2101.00121) paper. The training loss and validation-set accuracy curves from one run are shown below.
33 |
34 | Training Loss | Evaluation Accuracy on Validation Set
35 | :-------------------------:|:-------------------------:
36 |  | 
37 |
38 | ### Dev environment
39 | - Python 3.8.5
40 | - A100 GPU, CUDA Version: 11.0
41 | - Other dependencies: [requirements.txt](./requirements.txt)
42 |
43 | ### API
44 | ```python
45 | # importing RoBERTa based API
46 | from models.modeling_roberta import WARPPromptedRobertaForMaskedLM, WARPPromptedRobertaForSequenceClassification, WARPPromptedRobertaForQuestionAnswering
47 | # importing Bert based API
48 | from models.modeling_bert import WARPPromptedBertForMaskedLM, WARPPromptedBertForSequenceClassification, WARPPromptedBertForQuestionAnswering
49 | # importing XLM-RoBERTa based API
50 | from models.modeling_xlm_roberta import WARPPromptedXLMRobertaForMaskedLM, WARPPromptedXLMRobertaForSequenceClassification, WARPPromptedXLMRobertaForQuestionAnswering
51 | # importing function for randomly masking inputs
52 | from utils import random_mask_input_ids
53 |
54 | # initialize model for MNLI task
55 | model = WARPPromptedRobertaForSequenceClassification(
56 | pretrained_backbone_path = "roberta-large",
57 | n_prompts = 8,
58 | seed_token_id_for_prompts_embeddings = 50264, # token id for <mask>
59 | mask_token_id = 50264,
60 | token_ids_for_classification_head = [1342, 12516, 10800], # 'ent', 'neutral', 'cont'
61 | pretrained_prompts_path = None,
62 | pretrained_classifier_path = None
63 | )
64 |
65 | # initialize model for masked language modeling (MLM)
66 | model = WARPPromptedRobertaForMaskedLM(
67 | pretrained_backbone_path = "roberta-large",
68 | n_prompts = 8,
69 | seed_token_id_for_prompts_embeddings = 50264,
70 | pretrained_prompts_path = None
71 | )
72 |
73 | # prepad input ids before feeding into model
74 | features = tokenizer([str_1, str_2, ..., str_n], return_tensors='pt', truncation=True, padding=True)
75 | features["input_ids"] = torch.cat([torch.full((features["input_ids"].shape[0], n_prompts), 0), features['input_ids']], 1)
76 |
77 | # randomly mask input ids for MLM task
78 | features['input_ids'] = random_mask_input_ids(features['input_ids'], mask_token_id, exceptions = [tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id], prob = .15) # exceptions: token ids (e.g. special tokens) that should never be masked
79 |
80 | # initialize model for question answering (QA)
81 | model = WARPPromptedRobertaForQuestionAnswering(
82 | pretrained_backbone_path = "roberta-large",
83 | n_prompts = 4,
84 | seed_token_id_for_prompts_embeddings = 50264,
85 | pretrained_prompts_path = None,
86 | freeze_qa_outputs_layer = False,
87 | )
88 | ```
89 |
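Since the backbone is frozen, only the soft prompts (and, for classification, the small classification head) receive gradients. Below is a minimal sketch of one optimization step, loosely following `examples/train_warp_mnli.py`; `features` (prepadded as above) and a tensor of integer class `labels` aligned with the model's outputs are assumed to exist.

```python
import torch

# Only parameters with requires_grad=True (the soft prompts and, for
# classification, the small head) are handed to the optimizer.
optimizer = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad], lr=3e-3
)

logits = model(**features)                                 # prepadded features from above
loss = torch.nn.functional.cross_entropy(logits, labels)   # labels: int64 class ids
loss.backward()
optimizer.step()
optimizer.zero_grad()
```
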
90 | # Reference
91 | - [WARP: Word-level Adversarial ReProgramming](https://aclanthology.org/2021.acl-long.381.pdf)
92 | - [The Power of Scale for Parameter-Efficient Prompt Tuning](https://arxiv.org/abs/2104.08691v1)
93 |
94 |
95 | # Contact
96 | Please contact [Jin Qu](mailto:jqu@salesforce.com) if you are interested in collaboration, internship opportunities, or discussions. Feel free to create an issue if you discover a bug or want to request new features for a future release.
97 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | ## Security
2 |
3 | Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com)
4 | as soon as it is discovered. This library limits its runtime dependencies in
5 | order to reduce the total cost of ownership as much as possible, but all consumers
6 | should remain vigilant and have their security stakeholders review all third-party
7 | products (3PP) like this one and their dependencies.
8 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/Overture/6ea974b495f35ba586540f737791204dce8cda35/__init__.py
--------------------------------------------------------------------------------
/examples/train_warp_mnli.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 |
3 | from torch.utils.data import DataLoader
4 | from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoModel
5 | import tqdm
6 | import logging
7 | import math
8 | import pandas as pd
9 | import argparse
10 | import random
11 |
12 | import numpy as np
13 | import torch
14 | import torch.nn as nn
15 | from torch.utils.tensorboard import SummaryWriter
16 |
17 | import copy
18 | import json
19 | import sys
20 |
21 | sys.path.append("../")
22 | from utils import set_seed
23 | from models.modeling_roberta import (
24 | WARPPromptedRobertaForMaskedLM,
25 | WARPPromptedRobertaForSequenceClassification,
26 | )
27 |
28 | logging.basicConfig(
29 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
30 | datefmt="%m/%d/%Y %H:%M:%S",
31 | level=logging.INFO,
32 | )
33 | logger = logging.getLogger(__name__)
34 |
35 | parser = argparse.ArgumentParser()
36 | parser.add_argument("--n_prompts", default=20, type=int)
37 | parser.add_argument("--backbone_model", default="roberta-large", type=str)
38 | parser.add_argument("--pretrained_prompts_path", type=str)
39 | parser.add_argument("--pretrained_classifier_path", type=str)
40 | parser.add_argument("--save_prompts_path", type=str)
41 | parser.add_argument("--save_classifier_path", type=str)
42 | parser.add_argument("--seed", default=42, type=int)
43 | parser.add_argument("--train_batch_size", default=16, type=int)
44 | parser.add_argument("--eval_batch_size", default=32, type=int)
45 | parser.add_argument("--total_train_steps", default=250000, type=int)
46 | parser.add_argument(
47 | "--seed_token_id_for_prompts_embeddings", default=50264, type=int
48 | ) # token id for <mask>
49 | parser.add_argument("--mask_token_id", default=50264, type=int)
50 | parser.add_argument(
51 | "--token_ids_for_classification_head",
52 | action="append",
53 | dest="token_ids_for_classification_head",
54 | default=[1342, 12516, 10800],
55 | ) #'ent', 'neutral', 'cont'
56 | parser.add_argument("--cycle", default=50000, type=int)
57 | parser.add_argument("--warmup_proportion", default=0.06, type=float)
58 | parser.add_argument("--patience", default=50, type=int)
59 | parser.add_argument("--learning_rate", default=0.003, type=float)
60 | parser.add_argument("--max_seq_length", default=512, type=int)
61 | parser.add_argument(
62 | "--gradient_accumulation_steps",
63 | type=int,
64 | default=1,
65 | help="Number of updates steps to accumulate before performing a backward/update pass.",
66 | )
67 | parser.add_argument(
68 | "--no_cuda", action="store_true", help="Whether not to use CUDA when available"
69 | )
70 |
71 | parser.add_argument("--use_tensorboard", action="store_true")
72 | parser.add_argument("--tensorboard_dir", type=str, default="./runs_mnli_cls")
73 |
74 | parser.add_argument("--tmult", default=1, type=int)
75 | parser.add_argument("--comment", default="", type=str)
76 | args = parser.parse_args()
77 |
78 | # set seed
79 | set_seed(args.seed)
80 |
81 | # load data
82 | dataset = load_dataset("multi_nli", split="train")
83 | dataset = pd.DataFrame.from_dict(
84 | {
85 | "premise": [d["premise"] for d in dataset],
86 | "hypothesis": [d["hypothesis"] for d in dataset],
87 | "gold_label": [
88 | ("entailment", "neutral", "contradiction")[d["label"]] for d in dataset
89 | ],
90 | }
91 | )
92 | valid_dataset = load_dataset("multi_nli", split="validation_matched")
93 | valid_dataset = pd.DataFrame.from_dict(
94 | {
95 | "premise": [d["premise"] for d in valid_dataset],
96 | "hypothesis": [d["hypothesis"] for d in valid_dataset],
97 | "gold_label": [
98 | ("entailment", "neutral", "contradiction")[d["label"]]
99 | for d in valid_dataset
100 | ],
101 | }
102 | )
103 |
104 | label_dict = {"entailment": 0, "neutral": 1, "contradiction": 2}
105 |
106 | # load tokenizer
107 | tokenizer = AutoTokenizer.from_pretrained(args.backbone_model)
108 |
109 | # set device
110 | device = torch.device(
111 | "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
112 | )
113 | print(f"device: {device}")
114 |
115 | # initialize model
116 | model = WARPPromptedRobertaForSequenceClassification(
117 | pretrained_backbone_path=args.backbone_model,
118 | n_prompts=args.n_prompts,
119 | seed_token_id_for_prompts_embeddings=args.seed_token_id_for_prompts_embeddings,
120 | mask_token_id=args.mask_token_id,
121 | token_ids_for_classification_head=args.token_ids_for_classification_head,
122 | pretrained_prompts_path=args.pretrained_prompts_path,
123 | pretrained_classifier_path=args.pretrained_classifier_path,
124 | )
125 |
126 |
127 | # move model to device
128 | model.to(device)
129 |
130 | # set up scheduler
131 | optim = torch.optim.Adam(
132 | list(filter(lambda x: x.requires_grad, model.parameters())), lr=args.learning_rate
133 | )
134 | warmup = int(args.warmup_proportion * args.cycle)
135 | scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
136 | optim, args.cycle, T_mult=args.tmult
137 | )
138 |
139 | middle_string = " "
140 |
141 |
142 | def validate():
143 | model.eval()
144 | total_seen = 0
145 | total_correct = 0
146 | with torch.no_grad():
147 | for _ in tqdm.tqdm(range(0, valid_dataset.shape[0], args.eval_batch_size)):
148 | exs = list(range(_, _ + args.eval_batch_size))
149 | if exs[-1] >= valid_dataset.shape[0] or len(exs) < 2:
150 | print("BREAKING!")
151 | break
152 | total_seen += len(exs)
153 |
154 | valid_features = tokenizer(
155 | [
156 | valid_dataset.premise.iloc[i]
157 | + "?"
158 | + middle_string
159 | + valid_dataset.hypothesis.iloc[i]
160 | for i in exs
161 | ],
162 | return_tensors="pt",
163 | truncation=True,
164 | padding=True,
165 | max_length=args.max_seq_length - args.n_prompts,
166 | )
167 | valid_features["input_ids"] = torch.cat(
168 | [
169 | torch.full(
170 | (valid_features["input_ids"].shape[0], args.n_prompts), 0
171 | ),
172 | valid_features["input_ids"],
173 | ],
174 | 1,
175 | )
176 | valid_features["input_ids"] = valid_features["input_ids"].to(device)
177 | valid_features["attention_mask"] = torch.cat(
178 | [
179 | torch.full(
180 | (valid_features["input_ids"].shape[0], args.n_prompts), 1
181 | ),
182 | valid_features["attention_mask"],
183 | ],
184 | 1,
185 | )
186 | valid_features["attention_mask"] = valid_features["attention_mask"].to(
187 | device
188 | )
189 |
190 | lbl = (
191 | torch.Tensor(
192 | [label_dict[valid_dataset.gold_label.iloc[i]] for i in exs]
193 | )
194 | .type(torch.int64)
195 |                 .to(device)
196 | )
197 | bleh = model(**valid_features)
198 | total_correct += torch.sum(bleh.argmax(1) == lbl).item()
199 | acc = total_correct / total_seen * 100.0
200 | print({"val_acc": acc})
201 | return acc
202 |
203 |
204 | # set up tensorboard
205 | if args.use_tensorboard:
206 | writer = SummaryWriter(log_dir=f"{args.tensorboard_dir}/{args.comment}")
207 |
208 | # training loop
209 | model.eval()
210 | optim.zero_grad()
211 | best_model = None
212 | best_acc = 0.0
213 | cnt = 0
214 |
215 | for step in tqdm.tqdm(range(args.total_train_steps)):
216 | exs = random.sample(range(0, dataset.shape[0]), args.train_batch_size)
217 | train_features = tokenizer(
218 | [
219 | dataset.premise.iloc[i] + "?" + middle_string + dataset.hypothesis.iloc[i]
220 | for i in exs
221 | ],
222 | return_tensors="pt",
223 | truncation=True,
224 | padding=True,
225 | max_length=args.max_seq_length - args.n_prompts,
226 | )
227 | train_features["input_ids"] = torch.cat(
228 | [
229 | torch.full((train_features["input_ids"].shape[0], args.n_prompts), 0),
230 | train_features["input_ids"],
231 | ],
232 | 1,
233 | )
234 | train_features["input_ids"] = train_features["input_ids"].to(device)
235 | train_features["attention_mask"] = torch.cat(
236 | [
237 | torch.full((train_features["input_ids"].shape[0], args.n_prompts), 1),
238 | train_features["attention_mask"],
239 | ],
240 | 1,
241 | )
242 | train_features["attention_mask"] = train_features["attention_mask"].to(device)
243 |
244 | lbl = (
245 | torch.Tensor([label_dict[dataset.gold_label.iloc[i]] for i in exs])
246 | .type(torch.int64)
247 | .to(device)
248 | )
249 |
250 | bleh = model(**train_features)
251 |
252 | loss = (
253 | torch.nn.functional.cross_entropy(bleh, lbl) / args.gradient_accumulation_steps
254 | )
255 | loss.backward()
256 | if (step % args.gradient_accumulation_steps) == 0:
257 | optim.step()
258 | scheduler.step()
259 | optim.zero_grad()
260 |
261 | if step % 1000 == 0:
262 | acc = validate()
263 | if acc > best_acc:
264 | best_acc = acc
265 | cnt = 0
266 | best_soft_embedding = copy.deepcopy(
267 | model.backbone.roberta.embeddings.word_embeddings
268 | )
269 | best_classification_head = copy.deepcopy(model.classification_head)
270 | else:
271 | cnt += 1
272 |
273 | # log into tensorboard
274 | if args.use_tensorboard:
275 | writer.add_scalar(
276 | "avg train loss", loss.item() / args.train_batch_size, step
277 | )
278 | writer.add_scalar("acc val", acc, step)
279 |
280 | if cnt >= args.patience:
281 | best_soft_embedding.save_pretrained_soft_prompts(args.save_prompts_path)
282 | best_classification_head.save_pretrained_classifier(
283 | args.save_classifier_path
284 | )
285 | raise Exception("running out of patience!")
286 |
--------------------------------------------------------------------------------
/figures/overture_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/Overture/6ea974b495f35ba586540f737791204dce8cda35/figures/overture_logo.png
--------------------------------------------------------------------------------
/img/eval_acc_warp_roberta_mnli.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/Overture/6ea974b495f35ba586540f737791204dce8cda35/img/eval_acc_warp_roberta_mnli.png
--------------------------------------------------------------------------------
/img/train_loss_warp_roberta_mnli.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/Overture/6ea974b495f35ba586540f737791204dce8cda35/img/train_loss_warp_roberta_mnli.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/Overture/6ea974b495f35ba586540f737791204dce8cda35/models/__init__.py
--------------------------------------------------------------------------------
/models/modeling_bert.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import logging
4 | import torch
5 | import torch.nn as nn
6 | from transformers import (
7 | AutoModelForMaskedLM,
8 | AutoTokenizer,
9 | AutoModelForQuestionAnswering,
10 | )
11 |
12 | import sys
13 |
14 | sys.path.append("../")
15 | from soft_prompts import PromptedWordEmbeddings
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
20 | def gelu(x):
21 | return (
22 | 0.5
23 | * x
24 | * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
25 | )
26 |
27 |
28 | # 1, can be used for MLM
29 | # 2, can also be used for MLM-style classification without additional classification layer, like "The Power of Scale for Parameter-Efficient Prompt Tuning" paper
30 | class WARPPromptedBertForMaskedLM(nn.Module):
31 | def __init__(
32 | self,
33 | pretrained_backbone_path,
34 | n_prompts,
35 | seed_token_id_for_prompts_embeddings,
36 | pretrained_prompts_path=None,
37 | ):
38 | """
39 | pretrained_backbone_path: str, path to or name of backbone model, e.g. bert-large-uncased;
40 | n_prompts: int, number of prompts;
41 | seed_token_id_for_prompts_embeddings: int, use embedding of a specific token to initialize prompts weights, usually use mask token.
42 | """
43 | super(WARPPromptedBertForMaskedLM, self).__init__()
44 | self.backbone = AutoModelForMaskedLM.from_pretrained(pretrained_backbone_path)
45 | self.n_prompts = n_prompts
46 | # freeze backbone model
47 | for _, p in self.backbone.named_parameters():
48 | p.requires_grad = False
49 |
50 | hidden_size = self.backbone.config.hidden_size
51 | original_word_embeddings = self.backbone.bert.embeddings.word_embeddings
52 | prompted_word_embeddings = PromptedWordEmbeddings(
53 | original_word_embeddings,
54 | n_prompts,
55 | hidden_size,
56 | seed_token_id_for_prompts_embeddings,
57 | )
58 | if pretrained_prompts_path is not None:
59 | prompted_word_embeddings.load_from_pretrained_soft_prompts(
60 | pretrained_prompts_path
61 | )
62 | logger.info(
63 | f"loaded pretrained soft prompts from: {pretrained_prompts_path}"
64 | )
65 |
66 | self.backbone.bert.embeddings.word_embeddings = prompted_word_embeddings
67 |
68 | def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
69 | return self.backbone(input_ids, attention_mask, token_type_ids, labels=labels)
70 |
71 |
72 | # classification head modified from https://github.com/huggingface/transformers/blob/5e3b4a70d3d17f2482d50aea230f7ed42b3a8fd0/src/transformers/models/bert/modeling_bert.py#L663
73 | class BertClassificationHead(nn.Module):
74 |     """BERT head used for MLM-style classification."""
75 |
76 | def __init__(self, ori_lm_head, weight_tensors, hidden_size, prediction_dim):
77 | """
78 | ori_lm_head: original lm_head from bert model, can be accessed by model.cls;
79 |         weight_tensors: initialize final classifier layer with the specified weight tensors, usually from verbalizer token embeddings;
80 | hidden_size, int, backbone model hidden size;
81 | prediction_dim, int, output dimension of classifier layer.
82 | """
83 | super().__init__()
84 |
85 | self.classifier = torch.nn.Linear(hidden_size, prediction_dim, bias=True)
86 | self.classifier.weight = weight_tensors
87 |
88 | self.ori_lm_head = ori_lm_head
89 | self.ori_lm_head.predictions.decoder = self.classifier
90 |
91 | def load_from_pretrained_classifier(self, pretrained_classifier_path):
92 | path = os.path.join(pretrained_classifier_path, "classifier.pt")
93 | pretrained_classifier = torch.load(path)
94 | if (
95 | pretrained_classifier.weight.shape == self.classifier.weight.shape
96 | and pretrained_classifier.bias.shape == self.classifier.bias.shape
97 | ):
98 | self.classifier = pretrained_classifier
99 | logger.info(
100 | f"loaded pretrained classifier from {pretrained_classifier_path}"
101 | )
102 | else:
103 | raise Exception(
104 | f"pretrained classifier weights dimension: {pretrained_classifier.weight.shape}, bias dimension: {pretrained_classifier.bias.shape} \
105 | but classifier initialized with {self.classifier.weight.shape} and {self.classifier.bias.shape}"
106 | )
107 |
108 | def save_pretrained_classifier(self, save_directory):
109 | path = os.path.join(save_directory, "classifier.pt")
110 | if not os.path.isdir(save_directory):
111 | os.mkdir(save_directory)
112 | torch.save(self.classifier, path)
113 | logger.info(f"saved trained classifier at: {save_directory}")
114 |
115 | def forward(self, features, **kwargs):
116 | x = self.classifier(features)
117 | return x
118 |
119 |
120 | class WARPPromptedBertForSequenceClassification(nn.Module):
121 | def __init__(
122 | self,
123 | pretrained_backbone_path,
124 | n_prompts,
125 | seed_token_id_for_prompts_embeddings,
126 | mask_token_id,
127 | token_ids_for_classification_head,
128 | pretrained_prompts_path=None,
129 | pretrained_classifier_path=None,
130 | ):
131 | """
132 | pretrained_backbone_path: str, path to or name of backbone model, e.g. bert-large-uncased;
133 | n_prompts: int, number of prompts;
134 | seed_token_id_for_prompts_embeddings: int, use embedding of a specific token to initialize prompts weights, usually use mask token;
135 | mask_token_id: int, token id for mask token, 103 for huggingface bert-large-uncased model;
136 |         token_ids_for_classification_head: list of int, used to initialize classifier weights;
137 | pretrained_prompts_path: str or None, path to pretrained prompts;
138 | pretrained_classifier_path: str or None, path to pretrained classifier layer.
139 | """
140 | super(WARPPromptedBertForSequenceClassification, self).__init__()
141 | self.backbone = AutoModelForMaskedLM.from_pretrained(pretrained_backbone_path)
142 | self.n_prompts = n_prompts
143 | self.mask_token_id = mask_token_id
144 | # freeze backbone model
145 | for _, p in self.backbone.named_parameters():
146 | p.requires_grad = False
147 |
148 | # modify embedding layer for soft prompts
149 | hidden_size = self.backbone.config.hidden_size
150 | original_word_embeddings = self.backbone.bert.embeddings.word_embeddings
151 | prompted_word_embeddings = PromptedWordEmbeddings(
152 | original_word_embeddings,
153 | n_prompts,
154 | hidden_size,
155 | seed_token_id_for_prompts_embeddings,
156 | )
157 | if pretrained_prompts_path is not None:
158 | prompted_word_embeddings.load_from_pretrained_soft_prompts(
159 | pretrained_prompts_path
160 | )
161 |
162 | self.backbone.bert.embeddings.word_embeddings = prompted_word_embeddings
163 |
164 | # classification head
165 | weights4lm_head = torch.nn.Parameter(
166 | self.backbone.bert.embeddings.word_embeddings.ori_emb.weight[
167 | token_ids_for_classification_head
168 | ]
169 | )
170 | prediction_dim = len(token_ids_for_classification_head)
171 | self.classification_head = BertClassificationHead(
172 | self.backbone.cls, weights4lm_head, hidden_size, prediction_dim
173 | )
174 | if pretrained_classifier_path is not None:
175 | self.classification_head.load_from_pretrained_classifier(
176 | pretrained_classifier_path
177 | )
178 |
179 | # remove original lm_head
180 | del self.backbone.cls
181 |
182 | def forward(self, input_ids, attention_mask, token_type_ids):
183 | before_classifier = self.backbone.bert(
184 | input_ids, attention_mask, token_type_ids
185 | )[0]
186 | mask_token_locations = torch.where(input_ids == self.mask_token_id)
187 | return self.classification_head(before_classifier[mask_token_locations])
188 |
189 |
190 | class WARPPromptedBertForQuestionAnswering(nn.Module):
191 | def __init__(
192 | self,
193 | pretrained_backbone_path,
194 | n_prompts,
195 | seed_token_id_for_prompts_embeddings,
196 | pretrained_prompts_path=None,
197 | freeze_qa_outputs_layer=True,
198 | ):
199 | """
200 | pretrained_backbone_path: str, path to or name of backbone model, e.g. bert-large-uncased;
201 | n_prompts: int, number of prompts;
202 | seed_token_id_for_prompts_embeddings: int, use embedding of a specific token to initialize prompts weights, usually use mask token.
203 | """
204 | super(WARPPromptedBertForQuestionAnswering, self).__init__()
205 | self.backbone = AutoModelForQuestionAnswering.from_pretrained(
206 | pretrained_backbone_path
207 | )
208 | self.n_prompts = n_prompts
209 |         # freeze the backbone model except, optionally, the final qa_outputs layer
210 | for n, p in self.backbone.named_parameters():
211 | p.requires_grad = False
212 | if "qa_outputs" in n and not freeze_qa_outputs_layer:
213 | p.requires_grad = True
214 |
215 | hidden_size = self.backbone.config.hidden_size
216 |         original_word_embeddings = self.backbone.bert.embeddings.word_embeddings
217 | prompted_word_embeddings = PromptedWordEmbeddings(
218 | original_word_embeddings,
219 | n_prompts,
220 | hidden_size,
221 | seed_token_id_for_prompts_embeddings,
222 | )
223 | if pretrained_prompts_path is not None:
224 | prompted_word_embeddings.load_from_pretrained_soft_prompts(
225 | pretrained_prompts_path
226 | )
227 | logger.info(
228 | f"loaded pretrained soft prompts from: {pretrained_prompts_path}"
229 | )
230 |
231 | self.backbone.bert.embeddings.word_embeddings = prompted_word_embeddings
232 |
233 | def forward(
234 | self,
235 | input_ids,
236 | attention_mask,
237 | token_type_ids,
238 | start_positions=None,
239 | end_positions=None,
240 | ):
241 | return self.backbone(
242 | input_ids,
243 | attention_mask,
244 | token_type_ids,
245 | start_positions=start_positions,
246 | end_positions=end_positions,
247 | )
248 |
--------------------------------------------------------------------------------
/models/modeling_roberta.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import logging
4 | import torch
5 | import torch.nn as nn
6 | from transformers import (
7 | AutoModelForMaskedLM,
8 | AutoTokenizer,
9 | AutoModelForQuestionAnswering,
10 | )
11 |
12 | import sys
13 |
14 | sys.path.append("../")
15 | from soft_prompts import PromptedWordEmbeddings
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
20 | def gelu(x):
21 | return (
22 | 0.5
23 | * x
24 | * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
25 | )
26 |
27 |
28 | # 1, can be used for MLM
29 | # 2, can also be used for MLM-style classification without additional classification layer, like "The Power of Scale for Parameter-Efficient Prompt Tuning" paper
30 | class WARPPromptedRobertaForMaskedLM(nn.Module):
31 | def __init__(
32 | self,
33 | pretrained_backbone_path,
34 | n_prompts,
35 | seed_token_id_for_prompts_embeddings,
36 | pretrained_prompts_path=None,
37 | ):
38 | """
39 | pretrained_backbone_path: str, path to or name of backbone model, e.g. roberta-large;
40 | n_prompts: int, number of prompts;
41 | seed_token_id_for_prompts_embeddings: int, use embedding of a specific token to initialize prompts weights, usually use mask token.
42 | """
43 | super(WARPPromptedRobertaForMaskedLM, self).__init__()
44 | self.backbone = AutoModelForMaskedLM.from_pretrained(pretrained_backbone_path)
45 | self.n_prompts = n_prompts
46 | # freeze backbone model
47 | for _, p in self.backbone.named_parameters():
48 | p.requires_grad = False
49 |
50 | hidden_size = self.backbone.config.hidden_size
51 | original_word_embeddings = self.backbone.roberta.embeddings.word_embeddings
52 | prompted_word_embeddings = PromptedWordEmbeddings(
53 | original_word_embeddings,
54 | n_prompts,
55 | hidden_size,
56 | seed_token_id_for_prompts_embeddings,
57 | )
58 | if pretrained_prompts_path is not None:
59 | prompted_word_embeddings.load_from_pretrained_soft_prompts(
60 | pretrained_prompts_path
61 | )
62 | logger.info(
63 | f"loaded pretrained soft prompts from: {pretrained_prompts_path}"
64 | )
65 |
66 | self.backbone.roberta.embeddings.word_embeddings = prompted_word_embeddings
67 |
68 | def forward(self, input_ids, attention_mask, labels=None):
69 | return self.backbone(input_ids, attention_mask, labels=labels)
70 |
71 |
72 | # classification head modified from https://github.com/huggingface/transformers/blob/5e3b4a70d3d17f2482d50aea230f7ed42b3a8fd0/src/transformers/models/roberta/modeling_roberta.py#L1123
73 | class RobertaClassificationHead(nn.Module):
74 |     """RoBERTa head used for MLM-style classification."""
75 |
76 | def __init__(self, ori_lm_head, weight_tensors, hidden_size, prediction_dim):
77 | """
78 | ori_lm_head: original lm_head from roberta model, can be accessed by model.lm_head;
79 |         weight_tensors: initialize final classifier layer with the specified weight tensors, usually from verbalizer token embeddings;
80 | hidden_size, int, backbone model hidden size;
81 | prediction_dim, int, output dimension of classifier layer.
82 | """
83 | super().__init__()
84 | self.dense = ori_lm_head.dense
85 | self.layer_norm = ori_lm_head.layer_norm
86 | self.bias = ori_lm_head.bias
87 |
88 | self.classifier = torch.nn.Linear(hidden_size, prediction_dim, bias=True)
89 | self.classifier.weight = weight_tensors
90 |
91 | def load_from_pretrained_classifier(self, pretrained_classifier_path):
92 | path = os.path.join(pretrained_classifier_path, "classifier.pt")
93 | pretrained_classifier = torch.load(path)
94 | if (
95 | pretrained_classifier.weight.shape == self.classifier.weight.shape
96 | and pretrained_classifier.bias.shape == self.classifier.bias.shape
97 | ):
98 | self.classifier = pretrained_classifier
99 | logger.info(
100 | f"loaded pretrained classifier from {pretrained_classifier_path}"
101 | )
102 | else:
103 | raise Exception(
104 | f"pretrained classifier weights dimension: {pretrained_classifier.weight.shape}, bias dimension: {pretrained_classifier.bias.shape} \
105 | but classifier initialized with {self.classifier.weight.shape} and {self.classifier.bias.shape}"
106 | )
107 |
108 | def save_pretrained_classifier(self, save_directory):
109 | path = os.path.join(save_directory, "classifier.pt")
110 | if not os.path.isdir(save_directory):
111 | os.mkdir(save_directory)
112 | torch.save(self.classifier, path)
113 | logger.info(f"saved trained classifier at: {save_directory}")
114 |
115 | def forward(self, features, **kwargs):
116 | x = self.dense(features)
117 | x = gelu(x)
118 | x = self.layer_norm(x)
119 | return self.classifier(x)
120 |
121 | def _tie_weights(self):
122 | # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
123 | self.bias = self.decoder.bias
124 |
125 |
126 | class WARPPromptedRobertaForSequenceClassification(nn.Module):
127 | def __init__(
128 | self,
129 | pretrained_backbone_path,
130 | n_prompts,
131 | seed_token_id_for_prompts_embeddings,
132 | mask_token_id,
133 | token_ids_for_classification_head,
134 | pretrained_prompts_path=None,
135 | pretrained_classifier_path=None,
136 | ):
137 | """
138 | pretrained_backbone_path: str, path to or name of backbone model, e.g. roberta-large;
139 | n_prompts: int, number of prompts;
140 | seed_token_id_for_prompts_embeddings: int, use embedding of a specific token to initialize prompts weights, usually use mask token;
141 | mask_token_id: int, token id for mask token, 50264 for huggingface roberta model;
142 |         token_ids_for_classification_head: list of int, used to initialize classifier weights;
143 | pretrained_prompts_path: str or None, path to pretrained prompts;
144 | pretrained_classifier_path: str or None, path to pretrained classifier layer.
145 | """
146 | super(WARPPromptedRobertaForSequenceClassification, self).__init__()
147 | self.backbone = AutoModelForMaskedLM.from_pretrained(pretrained_backbone_path)
148 | self.n_prompts = n_prompts
149 | self.mask_token_id = mask_token_id
150 | # freeze backbone model
151 | for _, p in self.backbone.named_parameters():
152 | p.requires_grad = False
153 |
154 | # modify embedding layer for soft prompts
155 | hidden_size = self.backbone.config.hidden_size
156 | original_word_embeddings = self.backbone.roberta.embeddings.word_embeddings
157 | prompted_word_embeddings = PromptedWordEmbeddings(
158 | original_word_embeddings,
159 | n_prompts,
160 | hidden_size,
161 | seed_token_id_for_prompts_embeddings,
162 | )
163 | if pretrained_prompts_path is not None:
164 | prompted_word_embeddings.load_from_pretrained_soft_prompts(
165 | pretrained_prompts_path
166 | )
167 |
168 | self.backbone.roberta.embeddings.word_embeddings = prompted_word_embeddings
169 |
170 | # classification head
171 | weights4lm_head = torch.nn.Parameter(
172 | self.backbone.roberta.embeddings.word_embeddings.ori_emb.weight[
173 | token_ids_for_classification_head
174 | ]
175 | )
176 | prediction_dim = len(token_ids_for_classification_head)
177 | self.classification_head = RobertaClassificationHead(
178 | self.backbone.lm_head, weights4lm_head, hidden_size, prediction_dim
179 | )
180 | if pretrained_classifier_path is not None:
181 | self.classification_head.load_from_pretrained_classifier(
182 | pretrained_classifier_path
183 | )
184 |
185 | # remove original lm_head
186 | del self.backbone.lm_head
187 |
188 | def forward(self, input_ids, attention_mask):
189 | before_classifier = self.backbone.roberta(input_ids, attention_mask)[0]
190 | mask_token_locations = torch.where(input_ids == self.mask_token_id)
191 | return self.classification_head(before_classifier[mask_token_locations])
192 |
193 |
194 | class WARPPromptedRobertaForQuestionAnswering(nn.Module):
195 | def __init__(
196 | self,
197 | pretrained_backbone_path,
198 | n_prompts,
199 | seed_token_id_for_prompts_embeddings,
200 | pretrained_prompts_path=None,
201 | freeze_qa_outputs_layer=True,
202 | ):
203 | """
204 | pretrained_backbone_path: str, path to or name of backbone model, e.g. roberta-large;
205 | n_prompts: int, number of prompts;
206 | seed_token_id_for_prompts_embeddings: int, use embedding of a specific token to initialize prompts weights, usually use mask token.
207 | """
208 | super(WARPPromptedRobertaForQuestionAnswering, self).__init__()
209 | self.backbone = AutoModelForQuestionAnswering.from_pretrained(
210 | pretrained_backbone_path
211 | )
212 | self.n_prompts = n_prompts
213 |         # freeze the backbone model except, optionally, the final qa_outputs layer
214 | for n, p in self.backbone.named_parameters():
215 | p.requires_grad = False
216 | if "qa_outputs" in n and not freeze_qa_outputs_layer:
217 | p.requires_grad = True
218 |
219 | hidden_size = self.backbone.config.hidden_size
220 | original_word_embeddings = self.backbone.roberta.embeddings.word_embeddings
221 | prompted_word_embeddings = PromptedWordEmbeddings(
222 | original_word_embeddings,
223 | n_prompts,
224 | hidden_size,
225 | seed_token_id_for_prompts_embeddings,
226 | )
227 | if pretrained_prompts_path is not None:
228 | prompted_word_embeddings.load_from_pretrained_soft_prompts(
229 | pretrained_prompts_path
230 | )
231 | logger.info(
232 | f"loaded pretrained soft prompts from: {pretrained_prompts_path}"
233 | )
234 |
235 | self.backbone.roberta.embeddings.word_embeddings = prompted_word_embeddings
236 |
237 | def forward(
238 | self, input_ids, attention_mask, start_positions=None, end_positions=None
239 | ):
240 | return self.backbone(
241 | input_ids,
242 | attention_mask,
243 | start_positions=start_positions,
244 | end_positions=end_positions,
245 | )
246 |
--------------------------------------------------------------------------------
/models/modeling_xlm_roberta.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import logging
4 | import torch
5 | import torch.nn as nn
6 | from transformers import (
7 | AutoModelForMaskedLM,
8 | AutoTokenizer,
9 | AutoModelForQuestionAnswering,
10 | )
11 |
12 | import sys
13 |
14 | sys.path.append("../")
15 | from soft_prompts import PromptedWordEmbeddings
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
20 | def gelu(x):
21 | return (
22 | 0.5
23 | * x
24 | * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
25 | )
26 |
27 |
28 | # 1, can be used for MLM
29 | # 2, can also be used for MLM-style classification without additional classification layer, like "The Power of Scale for Parameter-Efficient Prompt Tuning" paper
30 | class WARPPromptedXLMRobertaForMaskedLM(nn.Module):
31 | def __init__(
32 | self,
33 | pretrained_backbone_path,
34 | n_prompts,
35 | seed_token_id_for_prompts_embeddings,
36 | pretrained_prompts_path=None,
37 | ):
38 | """
39 | pretrained_backbone_path: str, path to or name of backbone model, e.g. xlm-roberta-large;
40 | n_prompts: int, number of prompts;
41 | seed_token_id_for_prompts_embeddings: int, use embedding of a specific token to initialize prompts weights, usually use mask token.
42 | """
43 | super(WARPPromptedXLMRobertaForMaskedLM, self).__init__()
44 | self.backbone = AutoModelForMaskedLM.from_pretrained(pretrained_backbone_path)
45 | self.n_prompts = n_prompts
46 | # freeze backbone model
47 | for _, p in self.backbone.named_parameters():
48 | p.requires_grad = False
49 |
50 | hidden_size = self.backbone.config.hidden_size
51 | original_word_embeddings = self.backbone.roberta.embeddings.word_embeddings
52 | prompted_word_embeddings = PromptedWordEmbeddings(
53 | original_word_embeddings,
54 | n_prompts,
55 | hidden_size,
56 | seed_token_id_for_prompts_embeddings,
57 | )
58 | if pretrained_prompts_path is not None:
59 | prompted_word_embeddings.load_from_pretrained_soft_prompts(
60 | pretrained_prompts_path
61 | )
62 | logger.info(
63 | f"loaded pretrained soft prompts from: {pretrained_prompts_path}"
64 | )
65 |
66 | self.backbone.roberta.embeddings.word_embeddings = prompted_word_embeddings
67 |
68 | def forward(self, input_ids, attention_mask, labels=None):
69 | return self.backbone(input_ids, attention_mask, labels=labels)
70 |
71 |
72 | # classification head modified from https://github.com/huggingface/transformers/blob/5e3b4a70d3d17f2482d50aea230f7ed42b3a8fd0/src/transformers/models/roberta/modeling_roberta.py#L1123
73 | class XLMRobertaClassificationHead(nn.Module):
74 |     """XLM-RoBERTa head used for MLM-style classification."""
75 |
76 | def __init__(self, ori_lm_head, weight_tensors, hidden_size, prediction_dim):
77 | """
78 | ori_lm_head: original lm_head from roberta model, can be accessed by model.lm_head;
79 |         weight_tensors: initialize final classifier layer with the specified weight tensors, usually from verbalizer token embeddings;
80 | hidden_size, int, backbone model hidden size;
81 | prediction_dim, int, output dimension of classifier layer.
82 | """
83 | super().__init__()
84 | self.dense = ori_lm_head.dense
85 | self.layer_norm = ori_lm_head.layer_norm
86 | self.bias = ori_lm_head.bias
87 |
88 | self.classifier = torch.nn.Linear(hidden_size, prediction_dim, bias=True)
89 | self.classifier.weight = weight_tensors
90 |
91 | def load_from_pretrained_classifier(self, pretrained_classifier_path):
92 | path = os.path.join(pretrained_classifier_path, "classifier.pt")
93 | pretrained_classifier = torch.load(path)
94 | if (
95 | pretrained_classifier.weight.shape == self.classifier.weight.shape
96 | and pretrained_classifier.bias.shape == self.classifier.bias.shape
97 | ):
98 | self.classifier = pretrained_classifier
99 | logger.info(
100 | f"loaded pretrained classifier from {pretrained_classifier_path}"
101 | )
102 | else:
103 | raise Exception(
104 | f"pretrained classifier weights dimension: {pretrained_classifier.weight.shape}, bias dimension: {pretrained_classifier.bias.shape} \
105 | but classifier initialized with {self.classifier.weight.shape} and {self.classifier.bias.shape}"
106 | )
107 |
108 | def save_pretrained_classifier(self, save_directory):
109 | path = os.path.join(save_directory, "classifier.pt")
110 | if not os.path.isdir(save_directory):
111 | os.mkdir(save_directory)
112 | torch.save(self.classifier, path)
113 | logger.info(f"saved trained classifier at: {save_directory}")
114 |
115 | def forward(self, features, **kwargs):
116 | x = self.dense(features)
117 | x = gelu(x)
118 | x = self.layer_norm(x)
119 | return self.classifier(x)
120 |
121 | def _tie_weights(self):
122 | # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
123 | self.bias = self.decoder.bias
124 |
125 |
126 | class WARPPromptedXLMRobertaForSequenceClassification(nn.Module):
127 | def __init__(
128 | self,
129 | pretrained_backbone_path,
130 | n_prompts,
131 | seed_token_id_for_prompts_embeddings,
132 | mask_token_id,
133 | token_ids_for_classification_head,
134 | pretrained_prompts_path=None,
135 | pretrained_classifier_path=None,
136 | ):
137 | """
138 | pretrained_backbone_path: str, path to or name of backbone model, e.g. xlm-roberta-large;
139 | n_prompts: int, number of prompts;
140 | seed_token_id_for_prompts_embeddings: int, use embedding of a specific token to initialize prompts weights, usually use mask token;
141 | mask_token_id: int, token id for mask token, 250001 for huggingface xlm-roberta model;
142 |         token_ids_for_classification_head: list of int, used to initialize classifier weights;
143 | pretrained_prompts_path: str or None, path to pretrained prompts;
144 | pretrained_classifier_path: str or None, path to pretrained classifier layer.
145 | """
146 | super(WARPPromptedXLMRobertaForSequenceClassification, self).__init__()
147 | self.backbone = AutoModelForMaskedLM.from_pretrained(pretrained_backbone_path)
148 | self.n_prompts = n_prompts
149 | self.mask_token_id = mask_token_id
150 | # freeze backbone model
151 | for _, p in self.backbone.named_parameters():
152 | p.requires_grad = False
153 |
154 | # modify embedding layer for soft prompts
155 | hidden_size = self.backbone.config.hidden_size
156 | original_word_embeddings = self.backbone.roberta.embeddings.word_embeddings
157 | prompted_word_embeddings = PromptedWordEmbeddings(
158 | original_word_embeddings,
159 | n_prompts,
160 | hidden_size,
161 | seed_token_id_for_prompts_embeddings,
162 | )
163 | if pretrained_prompts_path is not None:
164 | prompted_word_embeddings.load_from_pretrained_soft_prompts(
165 | pretrained_prompts_path
166 | )
167 |
168 | self.backbone.roberta.embeddings.word_embeddings = prompted_word_embeddings
169 |
170 | # classification head
171 | weights4lm_head = torch.nn.Parameter(
172 | self.backbone.roberta.embeddings.word_embeddings.ori_emb.weight[
173 | token_ids_for_classification_head
174 | ]
175 | )
176 | prediction_dim = len(token_ids_for_classification_head)
177 | self.classification_head = XLMRobertaClassificationHead(
178 | self.backbone.lm_head, weights4lm_head, hidden_size, prediction_dim
179 | )
180 | if pretrained_classifier_path is not None:
181 | self.classification_head.load_from_pretrained_classifier(
182 | pretrained_classifier_path
183 | )
184 |
185 | # remove original lm_head
186 | del self.backbone.lm_head
187 |
188 | def forward(self, input_ids, attention_mask):
189 | before_classifier = self.backbone.roberta(input_ids, attention_mask)[0]
190 | mask_token_locations = torch.where(input_ids == self.mask_token_id)
191 | return self.classification_head(before_classifier[mask_token_locations])
192 |
193 |
194 | class WARPPromptedXLMRobertaForQuestionAnswering(nn.Module):
195 | def __init__(
196 | self,
197 | pretrained_backbone_path,
198 | n_prompts,
199 | seed_token_id_for_prompts_embeddings,
200 | pretrained_prompts_path=None,
201 | freeze_qa_outputs_layer=True,
202 | ):
203 | """
204 | pretrained_backbone_path: str, path to or name of backbone model, e.g. xlm-roberta-large;
205 | n_prompts: int, number of prompts;
206 | seed_token_id_for_prompts_embeddings: int, use embedding of a specific token to initialize prompts weights, usually use mask token.
207 | """
208 | super(WARPPromptedXLMRobertaForQuestionAnswering, self).__init__()
209 | self.backbone = AutoModelForQuestionAnswering.from_pretrained(
210 | pretrained_backbone_path
211 | )
212 | self.n_prompts = n_prompts
213 |         # freeze the backbone model except, optionally, the final qa_outputs layer
214 | for n, p in self.backbone.named_parameters():
215 | p.requires_grad = False
216 | if "qa_outputs" in n and not freeze_qa_outputs_layer:
217 | p.requires_grad = True
218 |
219 | hidden_size = self.backbone.config.hidden_size
220 | original_word_embeddings = self.backbone.roberta.embeddings.word_embeddings
221 | prompted_word_embeddings = PromptedWordEmbeddings(
222 | original_word_embeddings,
223 | n_prompts,
224 | hidden_size,
225 | seed_token_id_for_prompts_embeddings,
226 | )
227 | if pretrained_prompts_path is not None:
228 | prompted_word_embeddings.load_from_pretrained_soft_prompts(
229 | pretrained_prompts_path
230 | )
231 | logger.info(
232 | f"loaded pretrained soft prompts from: {pretrained_prompts_path}"
233 | )
234 |
235 | self.backbone.roberta.embeddings.word_embeddings = prompted_word_embeddings
236 |
237 | def forward(
238 | self, input_ids, attention_mask, start_positions=None, end_positions=None
239 | ):
240 | return self.backbone(
241 | input_ids,
242 | attention_mask,
243 | start_positions=start_positions,
244 | end_positions=end_positions,
245 | )
246 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==4.11.3
2 | pandas==1.3.4
3 | tqdm==4.62.3
4 | datasets==1.13.3
5 | numpy==1.21.2
6 | torch==1.7.1
7 | scikit-learn==0.24.2
--------------------------------------------------------------------------------
/soft_prompts.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import torch
4 | import torch.nn as nn
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 | class PromptedWordEmbeddings(nn.Module):
10 | def __init__(
11 | self,
12 | original_word_embeddings,
13 | n_prompts,
14 | hidden_size,
15 | seed_token_id_for_prompts_embeddings,
16 | ):
17 | """
18 | original_word_embeddings: word embedding layer from backbone transformer model;
19 | n_prompts: int, number of soft prompts;
20 | hidden_size: int, should be same as backbone transformer model hidden size for embedding;
21 | seed_token_id_for_prompts_embeddings: soft prompts will be initialized with the weights of this token embedding, usually use mask token.
22 | """
23 | super(PromptedWordEmbeddings, self).__init__()
24 | self.ori_emb = original_word_embeddings
25 | self.n_prompts = n_prompts
26 | self.soft_prompts = (
27 | torch.zeros(n_prompts, hidden_size)
28 | + original_word_embeddings.weight[seed_token_id_for_prompts_embeddings]
29 | .clone()
30 | .detach()
31 | )
32 | self.soft_prompts = nn.Parameter(self.soft_prompts, requires_grad=True)
33 | logger.info(
34 | f"initialized soft prompts with dimension: {self.soft_prompts.shape}"
35 | )
36 |
37 | def load_from_pretrained_soft_prompts(self, pretrained_prompts_path):
38 | pretrained_soft_prompts = torch.load(f"{pretrained_prompts_path}/prompts.pt")
39 | if pretrained_soft_prompts.shape[0] == self.n_prompts:
40 | self.soft_prompts = pretrained_soft_prompts
41 | logger.info(
42 | f"loaded pretrained soft prompts from {pretrained_prompts_path}"
43 | )
44 | else:
45 | raise Exception(
46 | f"pretrained soft prompts dimension: {pretrained_soft_prompts.shape}, but initialized with {self.soft_prompts.shape}"
47 | )
48 |
49 | def save_pretrained_soft_prompts(self, save_directory):
50 | path = os.path.join(save_directory, "prompts.pt")
51 | if not os.path.isdir(save_directory):
52 | os.mkdir(save_directory)
53 | torch.save(self.soft_prompts, path)
54 | logger.info(f"saved trained soft prompts at {save_directory}")
55 |
56 | def forward(self, prepadded_input_ids):
57 | """
58 | prepadded_input_ids: input_ids after tokenization + prepadded tensors as placeholder for prompts
59 | e.g. torch.cat([torch.full((features["input_ids"].shape[0], n_prompts), 0), features['input_ids']], 1)
60 | """
61 | emb = self.ori_emb(prepadded_input_ids[:, self.n_prompts :])
62 | expanded_prompts = self.soft_prompts.repeat(prepadded_input_ids.shape[0], 1, 1)
63 | return torch.cat([expanded_prompts, emb], 1)
64 |
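# Example (mirrors models/modeling_roberta.py): swap a backbone's word-embedding
# layer for a prompted one so that the first n_prompts positions of every
# prepadded input are replaced by the learnable soft prompts:
#     from transformers import AutoModelForMaskedLM
#     backbone = AutoModelForMaskedLM.from_pretrained("roberta-large")
#     backbone.roberta.embeddings.word_embeddings = PromptedWordEmbeddings(
#         backbone.roberta.embeddings.word_embeddings,
#         n_prompts=8,
#         hidden_size=backbone.config.hidden_size,
#         seed_token_id_for_prompts_embeddings=50264,  # RoBERTa's <mask> id
#     )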
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
4 |
5 |
6 | def set_seed(seed=0):
7 | np.random.seed(seed)
8 | random.seed(seed)
9 | torch.manual_seed(seed)
10 | torch.backends.cudnn.deterministic = True
11 | torch.backends.cudnn.benchmark = False
12 | return True
13 |
14 |
15 | # generate randomly masked input_ids for MLM task
16 | # modified from https://towardsdatascience.com/masked-language-modelling-with-bert-7d49793e5d2c
17 | def random_mask_input_ids(input_ids, mask_token_id, exceptions, prob=0.15):
18 | """
19 | exceptions: list, token ids that should not be masked
20 | """
21 | probs = torch.rand(input_ids.shape)
22 | mask = probs < prob
23 | for ex_id in exceptions:
24 | mask = mask * (input_ids != ex_id)
25 | selection = []
26 | for i in range(input_ids.shape[0]):
27 | selection.append(torch.flatten(mask[i].nonzero()).tolist())
28 | for i in range(input_ids.shape[0]):
29 | input_ids[i, selection[i]] = mask_token_id
30 | return input_ids
31 |
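# Example (illustrative values for roberta-large): mask roughly 15% of tokens,
# never masking the special tokens <s>=0, <pad>=1, </s>=2; 50264 is <mask>:
#     masked_ids = random_mask_input_ids(input_ids, mask_token_id=50264,
#                                        exceptions=[0, 1, 2], prob=0.15)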
--------------------------------------------------------------------------------