├── logs └── init.txt ├── ckpts └── init.txt ├── datasets └── init.txt ├── results └── dummy.txt ├── verbalizer_template ├── conll2003_vb.txt ├── drop_vb.txt ├── mrqa_vb.txt ├── newsqa_vb.txt ├── race_vb.txt ├── record_vb.txt ├── semevalrc_vb.txt ├── squad_vb.txt ├── squadqa_vb.txt ├── tacred_vb.txt ├── webnlg_vb.txt ├── xsum_vb.txt ├── multitask_vb.txt ├── boolq_vb.txt ├── copa_vb.txt ├── rte_vb.txt ├── sst-2_vb.txt ├── wic_vb.txt ├── wsc_vb.txt ├── imdb_vb.txt ├── multirc_vb.txt ├── rte_vb_old.txt ├── wsc_vb_old.txt ├── amazon_vb.txt ├── sst-2_tem.txt ├── amazon_tem.txt ├── copa_vb_old.txt ├── webnlg_tem.txt ├── conll2003_tem.txt ├── wsc_vb_old2.txt ├── ag_news_tem.txt ├── tacred_tem.txt ├── xsum_tem.txt ├── semevalrc_tem.txt ├── anli_vb.txt ├── cb_vb.txt ├── mnli_vb.txt ├── multitask_tem.txt ├── rte_tem.txt ├── drop_tem.txt ├── mrqa_tem.txt ├── imdb_tem.txt ├── newsqa_tem.txt ├── squad_tem.txt ├── squadqa_tem.txt ├── wsc_tem.txt ├── ag_news_vb.txt ├── dbpedia_tem.txt ├── cb_tem.txt ├── anli_tem.txt ├── mnli_tem.txt ├── wic_tem.txt ├── copa_tem.txt ├── race_tem.txt ├── record_tem.txt ├── boolq_tem.txt ├── multirc_tem.txt └── dbpedia_vb.txt ├── vip.png ├── src ├── __pycache__ │ ├── early_stopping.cpython-36.pyc │ └── prompt_gen_module.cpython-36.pyc ├── early_stopping.py ├── conll_text2bio.py ├── prompt_gen_module.py └── train.py ├── requirements.txt ├── LICENSE ├── cfgs └── cb.yaml ├── LICENCE_dependencies.md └── README.md /logs/init.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ckpts/init.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/init.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /results/dummy.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /verbalizer_template/conll2003_vb.txt: -------------------------------------------------------------------------------- 1 | {"meta": "label"} -------------------------------------------------------------------------------- /verbalizer_template/drop_vb.txt: -------------------------------------------------------------------------------- 1 | {"meta": "answers"} -------------------------------------------------------------------------------- /verbalizer_template/mrqa_vb.txt: -------------------------------------------------------------------------------- 1 | {"meta": "answers"} -------------------------------------------------------------------------------- /verbalizer_template/newsqa_vb.txt: -------------------------------------------------------------------------------- 1 | {"meta": "answers"} -------------------------------------------------------------------------------- /verbalizer_template/race_vb.txt: -------------------------------------------------------------------------------- 1 | {"meta": "answers"} -------------------------------------------------------------------------------- /verbalizer_template/record_vb.txt: -------------------------------------------------------------------------------- 1 | {"meta": "answers"} -------------------------------------------------------------------------------- /verbalizer_template/semevalrc_vb.txt: 
-------------------------------------------------------------------------------- 1 | {"meta": "label"} -------------------------------------------------------------------------------- /verbalizer_template/squad_vb.txt: -------------------------------------------------------------------------------- 1 | {"meta": "answers"} -------------------------------------------------------------------------------- /verbalizer_template/squadqa_vb.txt: -------------------------------------------------------------------------------- 1 | {"meta": "answers"} -------------------------------------------------------------------------------- /verbalizer_template/tacred_vb.txt: -------------------------------------------------------------------------------- 1 | {"meta": "label"} -------------------------------------------------------------------------------- /verbalizer_template/webnlg_vb.txt: -------------------------------------------------------------------------------- 1 | {"meta": "label"} -------------------------------------------------------------------------------- /verbalizer_template/xsum_vb.txt: -------------------------------------------------------------------------------- 1 | {"meta": "answers"} -------------------------------------------------------------------------------- /verbalizer_template/multitask_vb.txt: -------------------------------------------------------------------------------- 1 | {"meta": "answers"} -------------------------------------------------------------------------------- /verbalizer_template/boolq_vb.txt: -------------------------------------------------------------------------------- 1 | {"text": "A"} 2 | {"text": "B"} -------------------------------------------------------------------------------- /verbalizer_template/copa_vb.txt: -------------------------------------------------------------------------------- 1 | {"text": "A"} 2 | {"text": "B"} -------------------------------------------------------------------------------- /verbalizer_template/rte_vb.txt: -------------------------------------------------------------------------------- 1 | {"text": "A"} 2 | {"text": "B"} -------------------------------------------------------------------------------- /verbalizer_template/sst-2_vb.txt: -------------------------------------------------------------------------------- 1 | {"text": "0"} 2 | {"text": "1"} -------------------------------------------------------------------------------- /verbalizer_template/wic_vb.txt: -------------------------------------------------------------------------------- 1 | {"text": "A"} 2 | {"text": "B"} -------------------------------------------------------------------------------- /verbalizer_template/wsc_vb.txt: -------------------------------------------------------------------------------- 1 | {"text": "A"} 2 | {"text": "B"} -------------------------------------------------------------------------------- /verbalizer_template/imdb_vb.txt: -------------------------------------------------------------------------------- 1 | {"text": "bad"} 2 | {"text": "good"} -------------------------------------------------------------------------------- /verbalizer_template/multirc_vb.txt: -------------------------------------------------------------------------------- 1 | {"text": "A"} 2 | {"text": "B"} -------------------------------------------------------------------------------- /verbalizer_template/rte_vb_old.txt: -------------------------------------------------------------------------------- 1 | {"text": "A"} 2 | {"text": "B"} 
-------------------------------------------------------------------------------- /verbalizer_template/wsc_vb_old.txt: -------------------------------------------------------------------------------- 1 | {"text": "A"} 2 | {"text": "B"} -------------------------------------------------------------------------------- /vip.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/declare-lab/VIP/HEAD/vip.png -------------------------------------------------------------------------------- /verbalizer_template/amazon_vb.txt: -------------------------------------------------------------------------------- 1 | {"text": "bad"} 2 | {"text": "good"} -------------------------------------------------------------------------------- /verbalizer_template/sst-2_tem.txt: -------------------------------------------------------------------------------- 1 | {"placeholder": "text_a"} {"mask"} 2 | -------------------------------------------------------------------------------- /verbalizer_template/amazon_tem.txt: -------------------------------------------------------------------------------- 1 | {"placeholder": "text_a"}. It is {"mask"} 2 | -------------------------------------------------------------------------------- /verbalizer_template/copa_vb_old.txt: -------------------------------------------------------------------------------- 1 | {"meta":"choice1"} 2 | {"meta":"choice2"} -------------------------------------------------------------------------------- /verbalizer_template/webnlg_tem.txt: -------------------------------------------------------------------------------- 1 | {"meta": "sentence", "shortenable":True} {"mask"} 2 | -------------------------------------------------------------------------------- /verbalizer_template/conll2003_tem.txt: -------------------------------------------------------------------------------- 1 | {"meta": "sentence", "shortenable":True} {"mask"} 2 | -------------------------------------------------------------------------------- /verbalizer_template/wsc_vb_old2.txt: -------------------------------------------------------------------------------- 1 | {"text": "Another word"} 2 | {"meta": "span1_text"} -------------------------------------------------------------------------------- /verbalizer_template/ag_news_tem.txt: -------------------------------------------------------------------------------- 1 | {"placeholder": "text_a"} {"placeholder": "text_b"} {"mask"} -------------------------------------------------------------------------------- /verbalizer_template/tacred_tem.txt: -------------------------------------------------------------------------------- 1 | sentence: {"meta": "sentence", "shortenable":True} {"mask"} 2 | -------------------------------------------------------------------------------- /verbalizer_template/xsum_tem.txt: -------------------------------------------------------------------------------- 1 | context: {"meta": "passage", "shortenable":True}. 
{"mask"} 2 | -------------------------------------------------------------------------------- /verbalizer_template/semevalrc_tem.txt: -------------------------------------------------------------------------------- 1 | sentence: {"meta": "sentence", "shortenable":True} {"mask"} 2 | -------------------------------------------------------------------------------- /verbalizer_template/anli_vb.txt: -------------------------------------------------------------------------------- 1 | {"text": "entailment"} 2 | {"text": "neutral"} 3 | {"text": "contradiction"} -------------------------------------------------------------------------------- /verbalizer_template/cb_vb.txt: -------------------------------------------------------------------------------- 1 | {"text": "entailment"} 2 | {"text": "contradiction"} 3 | {"text": "neutral"} -------------------------------------------------------------------------------- /verbalizer_template/mnli_vb.txt: -------------------------------------------------------------------------------- 1 | {"text": "contradiction"} 2 | {"text": "entailment"} 3 | {"text": "neutral"} -------------------------------------------------------------------------------- /verbalizer_template/multitask_tem.txt: -------------------------------------------------------------------------------- 1 | {"meta":"task"}: {"meta": "text", "shortenable":True} {"mask"} 2 | -------------------------------------------------------------------------------- /verbalizer_template/rte_tem.txt: -------------------------------------------------------------------------------- 1 | sentence1: {"placeholder":"text_a"} sentence2: {"placeholder":"text_b"} {"mask"} -------------------------------------------------------------------------------- /verbalizer_template/drop_tem.txt: -------------------------------------------------------------------------------- 1 | query: {"meta":"query"} context: {"meta": "passage", "shortenable":True} {"mask"} 2 | -------------------------------------------------------------------------------- /verbalizer_template/mrqa_tem.txt: -------------------------------------------------------------------------------- 1 | query: {"meta":"query"} context: {"meta": "passage", "shortenable":True} {"mask"} 2 | -------------------------------------------------------------------------------- /verbalizer_template/imdb_tem.txt: -------------------------------------------------------------------------------- 1 | {"placeholder": "text_a", "shortenable":True}. In summary, the movie was {"mask"}. 
2 | -------------------------------------------------------------------------------- /verbalizer_template/newsqa_tem.txt: -------------------------------------------------------------------------------- 1 | query: {"meta":"query"} context: {"meta": "passage", "shortenable":True} {"mask"} 2 | -------------------------------------------------------------------------------- /verbalizer_template/squad_tem.txt: -------------------------------------------------------------------------------- 1 | query: {"meta":"query"} context: {"meta": "passage", "shortenable":True} {"mask"} 2 | -------------------------------------------------------------------------------- /verbalizer_template/squadqa_tem.txt: -------------------------------------------------------------------------------- 1 | query: {"meta":"query"} context: {"meta": "passage", "shortenable":True} {"mask"} 2 | -------------------------------------------------------------------------------- /verbalizer_template/wsc_tem.txt: -------------------------------------------------------------------------------- 1 | {"placeholder":"text_a"} "{"meta":"span2_text"}" refers to "{"meta":"span1_text"}" {"mask"} -------------------------------------------------------------------------------- /verbalizer_template/ag_news_vb.txt: -------------------------------------------------------------------------------- 1 | {"text": "world"} 2 | {"text": "sports"} 3 | {"text": "business"} 4 | {"text": "technology"} -------------------------------------------------------------------------------- /verbalizer_template/dbpedia_tem.txt: -------------------------------------------------------------------------------- 1 | {"placeholder": "text_a"} {"placeholder": "text_b"} {"placeholder": "text_a"} is a {"mask"}. 2 | -------------------------------------------------------------------------------- /src/__pycache__/early_stopping.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/declare-lab/VIP/HEAD/src/__pycache__/early_stopping.cpython-36.pyc -------------------------------------------------------------------------------- /src/__pycache__/prompt_gen_module.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/declare-lab/VIP/HEAD/src/__pycache__/prompt_gen_module.cpython-36.pyc -------------------------------------------------------------------------------- /verbalizer_template/cb_tem.txt: -------------------------------------------------------------------------------- 1 | hypothesis: {"placeholder":"text_b","post_processing": lambda x:x+"."} premise: {"placeholder":"text_a"} {"mask"} 2 | -------------------------------------------------------------------------------- /verbalizer_template/anli_tem.txt: -------------------------------------------------------------------------------- 1 | hypothesis: {"placeholder":"text_b","post_processing": lambda x:x+"."} premise: {"placeholder":"text_a"} {"mask"} 2 | -------------------------------------------------------------------------------- /verbalizer_template/mnli_tem.txt: -------------------------------------------------------------------------------- 1 | hypothesis: {"placeholder":"text_b","post_processing": lambda x:x+"."} premise: {"placeholder":"text_a"} {"mask"} 2 | -------------------------------------------------------------------------------- /verbalizer_template/wic_tem.txt: -------------------------------------------------------------------------------- 1 | sentence1: 
{"placeholder":"text_a"} sentence2: {"placeholder":"text_b"} word: {"meta":"word", "shortenable": False} {"mask"} -------------------------------------------------------------------------------- /verbalizer_template/copa_tem.txt: -------------------------------------------------------------------------------- 1 | choice1: {"meta":"choice1"} choice2: {"meta":"choice2"} premise: {"placeholder":"text_a"} question: {"meta":"question"} {"mask"} 2 | -------------------------------------------------------------------------------- /verbalizer_template/race_tem.txt: -------------------------------------------------------------------------------- 1 | query: {"meta":"query"} context: {"meta": "passage", "shortenable":True} entities: {"meta":"entities", "shortenable":False} {"mask"} 2 | -------------------------------------------------------------------------------- /verbalizer_template/record_tem.txt: -------------------------------------------------------------------------------- 1 | query: {"meta":"query"} context: {"meta": "passage", "shortenable":True} entities: {"meta":"entities", "shortenable":True} {"mask"} 2 | -------------------------------------------------------------------------------- /verbalizer_template/boolq_tem.txt: -------------------------------------------------------------------------------- 1 | hypothesis: {"placeholder":"text_b", "shortenable":False, "post_processing": lambda x:x+"."} premise: {"placeholder":"text_a"} {"mask"} 2 | -------------------------------------------------------------------------------- /verbalizer_template/multirc_tem.txt: -------------------------------------------------------------------------------- 1 | question: {"placeholder":"text_b", "shortenable":False} answer: {"meta":"answer", "shortenable":False, "post_processing": lambda x:x+"."} paragraph: {"placeholder":"text_a"} {"mask"} -------------------------------------------------------------------------------- /verbalizer_template/dbpedia_vb.txt: -------------------------------------------------------------------------------- 1 | {"text": "company"} 2 | {"text": "school"} 3 | {"text": "artist"} 4 | {"text": "athlete"} 5 | {"text": "politics"} 6 | {"text": "transportation"} 7 | {"text": "building"} 8 | {"text": "river"} 9 | {"text": "village"} 10 | {"text": "animal"} 11 | {"text": "plant"} 12 | {"text": "album"} 13 | {"text": "film"} 14 | {"text": "book"} -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | #pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html 2 | numpy==1.19.2 3 | scikit-learn==0.24.2 4 | tqdm==4.62.3 5 | transformers==4.14.1 6 | rouge==1.0.1 7 | seqeval 8 | sentencepiece 9 | sacrebleu==2.0.0 10 | datasets==1.17.0 11 | pandas==1.1.5 12 | NERDA==1.0.0 13 | yacs 14 | tensorboardX 15 | dill 16 | pyarrow 17 | nltk 18 | scipy 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Deep Cognition and Language Research (DeCLaRe) Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, 
distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /cfgs/cb.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | dataset_name: "cb" 3 | metric_set: 4 | - "ACC" 5 | - "Classification-F1" 6 | label_tokens: 'entailment contradiction neutral' 7 | max_seq_l: 480 8 | data_dir: "../datasets/" 9 | dataset_decoder_max_length: 10 10 | data_processor: "super_glue.cb" 11 | model: 12 | model: "t5-lm" 13 | model_name_or_path: "google/t5-base-lm-adapt" 14 | use_cuda: True 15 | model_parallelize: True 16 | tune_plm: False 17 | plm_eval_mode: True 18 | template: "../verbalizer_template/cb_tem.txt" 19 | template_id: 0 20 | verbalizer: "../verbalizer_template/cb_vb.txt" 21 | verbalizer_choice: 0 22 | test: 23 | eval_on_test: False 24 | batch_size: 32 25 | shuffle_data: false 26 | prompt: 27 | num_soft_tokens: 0 28 | num_cq_tokens: 100 29 | init_from_vocab: True 30 | CQ: 31 | temp: 100 32 | num_codebook_samples: 10 33 | commitment_cost: 0.1 34 | identifier: "D32L2H4F64" #just the name that gets appended to ckpts file as identifier 35 | train: 36 | seed: 100 37 | batch_size: 32 38 | num_training_steps: 30000 39 | shuffle_data: true 40 | lr_soft_prompt: 0.3 41 | lr_cq_prompt: 0.0001 42 | eval_every_steps: -1 #-1 denotes eval after every epoch 43 | num_codes: -1 #means 10*num_cq_prompts 44 | early_stop: 20 45 | optimizer: Adafactor 46 | result: 47 | result_path: "../results/cb.txt" 48 | 49 | 50 | -------------------------------------------------------------------------------- /LICENCE_dependencies.md: -------------------------------------------------------------------------------- 1 | | Name | Meta | Classifier | 2 | |---------------|-------------------------------------------|----------------------------------------------------| 3 | | NERDA | MIT license | OSI Approved::MIT License | 4 | | seqeval | MIT license | OSI Approved::MIT License | 5 | | tensorboardX | MIT license | OSI Approved::MIT License | 6 | | numpy | BSD 3-Clause "New" or "Revised" License | OSI Approved::BSD License | 7 | | scipy | BSD 3-Clause "New" or "Revised" License | OSI Approved::BSD License | 8 | | pandas | BSD 3-Clause "New" or "Revised" License | OSI Approved::BSD License | 9 | | scikit-learn | BSD 3-Clause "New" or "Revised" License | OSI Approved::BSD License | 10 | | dill | BSD 3-Clause "New" or "Revised" License | OSI Approved::BSD License | 11 | | transformers | Apache License 2.0 | OSI Approved::Apache Software License | 12 | | yacs | Apache License 2.0 | OSI Approved::Apache Software License | 13 | | pyarrow | Apache License 2.0 | OSI Approved::Apache Software License | 14 | | sentencepiece | Apache License 2.0 | OSI Approved::Apache Software 
License | 15 | | sacrebleu | Apache License 2.0 | OSI Approved::Apache Software License | 16 | | rouge | Apache License 2.0 | OSI Approved::Apache Software License | 17 | | nltk | Apache License 2.0 | OSI Approved::Apache Software License | 18 | | datasets | Apache License 2.0 | OSI Approved::Apache Software License | 19 | | openprompt | Apache License 2.0 | OSI Approved::Apache Software License | 20 | | tqdm | MPLv2.0, MIT Licences | OSI Approved::Mozilla Public License 2.0 (MPL 2.0) | 21 | -------------------------------------------------------------------------------- /src/early_stopping.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This code has been adopted from (https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py) 3 | ''' 4 | 5 | import numpy as np 6 | import torch 7 | 8 | class EarlyStopping: 9 | """Early stops the training if validation loss doesn't improve after a given patience.""" 10 | def __init__(self, patience=7, verbose=False, delta=1e-5, path='checkpoint.pt', only_save_prompt_params=False, trace_func=print): 11 | """ 12 | Args: 13 | patience (int): How long to wait after last time validation loss improved. 14 | Default: 7 15 | verbose (bool): If True, prints a message for each validation loss improvement. 16 | Default: False 17 | delta (float): Minimum change in the monitored quantity to qualify as an improvement. 18 | Default: 0 19 | path (str): Path for the checkpoint to be saved to. 20 | Default: 'checkpoint.pt' 21 | trace_func (function): trace print function. 22 | Default: print 23 | """ 24 | self.patience = patience 25 | self.verbose = verbose 26 | self.counter = 0 27 | self.best_score = None 28 | self.early_stop = False 29 | self.val_loss_min = np.Inf 30 | self.delta = delta 31 | self.path = path 32 | self.trace_func = trace_func 33 | self.only_save_prompt_params = only_save_prompt_params 34 | def __call__(self, val_loss, glb_step, model): 35 | 36 | score = val_loss 37 | 38 | if self.best_score is None: 39 | self.best_score = score 40 | self.save_checkpoint(val_loss, model) 41 | elif score < self.best_score + self.delta: 42 | self.counter += 1 43 | self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}') 44 | self.trace_func(f'\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t[best accuracy uptil now: {self.best_score}] at step: {glb_step}') 45 | if self.counter >= self.patience: 46 | self.early_stop = True 47 | else: 48 | self.best_score = score 49 | self.save_checkpoint(val_loss, model) 50 | self.counter = 0 51 | 52 | def save_checkpoint(self, val_loss, model): 53 | '''Saves model when validation loss decrease.''' 54 | if self.verbose: 55 | self.trace_func(f'Validation acc increased ({self.val_loss_min:.6f} --> {val_loss:.6f}). 
Saving model ...') 56 | if self.only_save_prompt_params: 57 | torch.save(model.template.state_dict(), self.path) 58 | else: 59 | torch.save(model.state_dict(), self.path) 60 | self.val_loss_min = val_loss 61 | -------------------------------------------------------------------------------- /src/conll_text2bio.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def getonebatchresult(sen,target,preds): 4 | #typedic = {"org": "ORG", "location": "LOC", "person": "PER", "mix": "MISC"} 5 | typedic = {"org": "ORG", "money": "MONEY", "country": "GPE", "time": "TIME", "law": "LAW", "fact": "FAC", 6 | "thing": "EVENT", "measure": "QUANTITY", 7 | "order": "ORDINAL", "art": "WORK_OF_ART", "location": "LOC", "language": "LANGUAGE", "person": "PER", 8 | "product": "PRODUCT", "num": "CARDINAL", "national": "NORP", "date": "DATE", "per": "PERCENT", "mix": "MISC"} 9 | sennum = len(sen) 10 | restar = [] 11 | respred = [] 12 | for i in range(sennum): 13 | thissen, thistar, thispred = sen[i], target[i], preds[i] 14 | 15 | thissenlow = thissen.lower() 16 | 17 | sensplit = thissen.split(' ') 18 | sensplitlow = thissenlow.split(' ') 19 | 20 | tarres = ['O' for j in range(len(sensplit))] 21 | predres = ['O' for j in range(len(sensplit))] 22 | 23 | if thistar == 'end' and thispred == 'end': 24 | restar.append(tarres) 25 | respred.append(predres) 26 | continue 27 | 28 | if len(thistar) > 0 and thistar[-1] == ';': 29 | thistar = thistar[:-1] 30 | 31 | tarsplit1 = thistar.split(';') 32 | 33 | if thistar != 'end': 34 | for j in range(len(tarsplit1)): 35 | tarsplit2 = tarsplit1[j].split('!') 36 | if len(tarsplit2) != 2: 37 | continue 38 | entity = tarsplit2[0].strip(' ') 39 | entitylow = entity.lower() 40 | type = tarsplit2[1].strip(' ') 41 | if type not in typedic: 42 | continue 43 | if thissenlow.find(entitylow) == -1: 44 | continue 45 | trueindex = -100 46 | entitysplit = entitylow.split(' ') 47 | for k in range(len(sensplit)): 48 | if sensplitlow[k] == entitysplit[0] or entitysplit[0] in sensplitlow[k]: 49 | iftrue = True 50 | for l in range(1, len(entitysplit)): 51 | if sensplitlow[k + l] != entitysplit[l] and (entitysplit[0] not in sensplitlow[k]): 52 | iftrue = False 53 | break 54 | if iftrue: 55 | trueindex = k 56 | break 57 | if trueindex == -100: 58 | continue 59 | for k in range(trueindex, trueindex + len(entitysplit)): 60 | if k == trueindex: 61 | tarres[k] = 'B-' + typedic[type] 62 | else: 63 | tarres[k] = 'I-' + typedic[type] 64 | 65 | if len(thispred) > 0 and thispred[-1] == ';': 66 | thispred = thispred[:-1] 67 | 68 | tarsplit3 = thispred.split(';') 69 | 70 | if thispred != "end": 71 | for j in range(len(tarsplit3)): 72 | tarsplit4 = tarsplit3[j].split('!') 73 | if len(tarsplit4) != 2: 74 | continue 75 | entity = tarsplit4[0].strip(' ') 76 | entitylow = entity.lower() 77 | type = tarsplit4[1].strip(' ') 78 | if type not in typedic: 79 | continue 80 | if thissenlow.find(entitylow) == -1: 81 | continue 82 | trueindex = -100 83 | entitysplit = entitylow.split(' ') 84 | for k in range(len(sensplit)): 85 | if sensplitlow[k] == entitysplit[0] or entitysplit[0] in sensplitlow[k]: 86 | iftrue = True 87 | for l in range(1, len(entitysplit)): 88 | if sensplitlow[k + l] != entitysplit[l] and (entitysplit[0] not in sensplitlow[k]): 89 | iftrue = False 90 | break 91 | if iftrue: 92 | trueindex = k 93 | break 94 | if trueindex == -100: 95 | continue 96 | else: 97 | for k in range(trueindex, trueindex + len(entitysplit)): 98 | if k == trueindex: 99 | predres[k] 
= 'B-' + typedic[type] 100 | else: 101 | predres[k] = 'I-' + typedic[type] 102 | restar.append(tarres) 103 | respred.append(predres) 104 | return restar, respred -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VIP-Based Prompting for Parameter-Efficient Learning 2 | 3 | This repository contains the code for the paper 4 | 5 | `Vector-Quantized Input-Contextualized Soft Prompts for Natural Language Understanding` (EMNLP 2022) 6 | 7 | _link to the draft:_ [here](https://arxiv.org/abs/2205.11024) 8 | 9 | 10 | ## Motivation 11 | 12 | Prompt Tuning has been largely successful as a parameter-efficient method of conditioning large-scale pre-trained language models to perform downstream tasks. Thus far, soft prompt tuning learns a fixed set of task-specific continuous vectors, i.e., soft tokens that remain static across the task samples. A fixed prompt, however, may not generalize well to the diverse kinds of inputs the task comprises. In order to address this, we propose Vector-quantized Input-contextualized Prompts (VIP) as an extension to the soft prompt tuning framework. VIP particularly focuses on two aspects---contextual prompts that learn input-specific contextualization of the soft prompt tokens through a small-scale sentence encoder, and quantized prompts that map the contextualized prompts to a set of learnable codebook vectors through a vector quantization network. On various language understanding tasks like SuperGLUE, QA, relation classification, NER and NLI, VIP outperforms the soft prompt tuning (PT) baseline by an average margin of 1.19%. Further, our generalization studies show that VIP learns more robust prompt representations, surpassing PT by a margin of 0.6%-5.3% on out-of-domain QA and NLI tasks respectively, and by 0.75% on a multi-task setup over 4 tasks spanning 12 domains. 13 | 14 | Here is the architecture summarizing the approach: 15 | 16 | ![alt text](./vip.png) 17 | 18 | ## Installation 19 | 20 | We will first install torch and then the other dependencies specific to this repository as well as the [OpenPrompt](https://raw.githubusercontent.com/thunlp/OpenPrompt/main) repository. 21 | 22 | - Torch 23 | 24 | ``` 25 | pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html 26 | ``` 27 | 28 | - Other dependencies 29 | ``` 30 | pip install -r ./requirements.txt 31 | ``` 32 | 33 | ## Quick check of setup 34 | Let's check if everything is working fine by running an example script on Commitment Bank (CB). 35 | 36 | Run the following command inside src/: 37 | 38 | ```console 39 | python train.py --cfg '../cfgs/cb.yaml' 40 | ``` 41 | 42 | ## How to run 43 | If the above command runs fine, let's see how to specify datasets and arguments. 44 | 45 | ```console 46 | python train.py [mandatory: --cfg ] [optional: --parameter1 --parameter2 ...] 47 | ``` 48 | 49 | - [x] Configuration: It is **mandatory** to specify the path to the config file cfg. It sets up the prompting system, including the plm, the number of soft and cq prompts, etc. We have put all the config files in the directory cfgs/. For reference, please look at 'cfgs/cb.yaml'. 50 | 51 | 52 | - [x] Arguments: It is **optional** to specify the following parameters as arguments. For instance, 53 | 54 | Note: Arguments overwrite parameters specified by the config file. 
55 | 56 | ```console 57 | python train.py --cfg cfgs/cb.yaml --num_soft_tokens 0 --num_cq_tokens 100 58 | ``` 59 | 60 | #### Parameters 61 | * cfg: configuration file 62 | * model: t5-lm (model class) 63 | * model_name_or_path: google/t5-base-lm-adapt (pretrained model) 64 | * model_parallelize: True if you want to parallelize the model across GPUs 65 | * tune_plm: False if you want to keep the plm frozen 66 | * plm_eval_mode: turns off the dropout in the frozen model 67 | * num_soft_tokens: number of soft tokens 68 | * num_cq_tokens: number of quantized tokens 69 | * init_from_vocab: whether prompt tokens are initialized from the vocabulary 70 | * template: path to template file, for instance, "../verbalizer_template/cb_tem.txt" 71 | * verbalizer: path to verbalizer file, for instance, "../verbalizer_template/cb_vb.txt" 72 | * dataset: name of the dataset (used to download data from huggingface) 73 | * eval_on_test: True evaluates on the test set; False evaluates on the validation set 74 | * batch_tr: training batch size 75 | * batch_te: evaluation batch size 76 | * lr_soft_prompt: learning rate for the standard soft prompts 77 | * lr_cq_prompt: learning rate for the quantized prompt setup, i.e., the CQ module 78 | * eval_every_steps: evaluate after this many steps; -1 denotes epoch-wise evaluation 79 | * max_steps: maximum number of gradient steps to run the training 80 | * num_codes: number of codebook vectors; -1 denotes num_codes = 10 * num_cq_tokens 81 | * temp: temperature parameter to normalize distances of codebook vectors from contextual prompts 82 | * num_codebook_samples: number of multinomial samples of codebook vectors 83 | * commitment_cost: commitment of the sentence encoder's output to the codebook vectors 84 | * identifier: identifier of the particular setting; it is appended to the saved checkpoint file name 85 | * result_path: path where results are saved, for instance, ../results/cb.txt 86 | * use_cuda: True 87 | 88 | 89 | ## Supported tasks and benchmark datasets 90 | - SuperGLUE 91 | - MRQA 92 | - ANLI 93 | - CoNLL 94 | - TACRED 95 | - SemEval 96 | 97 | 98 | 99 | ## CQ module 100 | _[Please use this to refer to/modify the source code of the CQ module]_ 101 | - The CQ module is implemented in `src/prompt_gen_module.py` 102 | - CQ is called by soft_template in `OpenPrompt/openprompt/prompts/soft_template.py` (an illustrative sketch of the quantization step is given at the end of this README) 103 | 104 | 105 | ## Citation 106 | ```bibtex 107 | @article{vip2022prompts, 108 | title={Vector-Quantized Input-Contextualized Soft Prompts for Natural Language Understanding}, 109 | author={Rishabh Bhardwaj and Amrita Saha and Steven C.H. Hoi and Soujanya Poria}, 110 | conference={EMNLP}, 111 | year={2022} 112 | } 113 | ``` 114 | 115 | **Note**: Please cite our paper if you find this repository useful. The latest version is available [here](https://arxiv.org/abs/2205.11024). 
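For readers who want the gist of the CQ module before reading `src/prompt_gen_module.py` in full, below is a minimal, illustrative sketch of the quantization step. It is a simplified stand-in under stated assumptions, not the repository's implementation: the class and variable names (e.g. `ToyCQ`) are invented for illustration, and the real module's `fc_in`/`fc_out` projections, attention masking, EMA codebook updates, and perplexity tracking are omitted.

```python
# Illustrative sketch only (names and defaults are assumptions, not the repo's API).
# Idea mirrored from src/prompt_gen_module.py: contextualize the prompt tokens with a
# small Transformer encoder, then soft-quantize them against a learnable codebook by
# drawing multinomial samples over temperature-scaled negative distances, and pass
# gradients through with a straight-through estimator.
import torch
import torch.nn as nn
from torch.distributions.multinomial import Multinomial


class ToyCQ(nn.Module):
    def __init__(self, d_model=768, num_cq_tokens=10, num_codes=100,
                 num_samples=10, temp=100.0, commitment_cost=0.1):
        super().__init__()
        self.num_cq_tokens = num_cq_tokens
        self.num_samples = num_samples
        self.temp = temp
        self.commitment_cost = commitment_cost
        # small sentence encoder run over [prompt tokens; input tokens]
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=4,
                                           dim_feedforward=2 * d_model)
        self.encoder = nn.TransformerEncoder(layer, num_layers=2)
        self.codebook = nn.Embedding(num_codes, d_model)

    def forward(self, prompt_and_input):  # (batch, seq_len, d_model)
        # input-contextualized prompt tokens
        ctx = self.encoder(prompt_and_input.transpose(0, 1)).transpose(0, 1)
        ctx = ctx[:, :self.num_cq_tokens, :]
        flat = ctx.reshape(-1, ctx.size(-1))
        # squared Euclidean distance from each contextualized token to every code
        dist = (flat.pow(2).sum(1, keepdim=True)
                + self.codebook.weight.pow(2).sum(1)
                - 2 * flat @ self.codebook.weight.t())
        # sample code assignments: closer codes receive higher probability
        samples = Multinomial(total_count=self.num_samples,
                              logits=-dist / self.temp).sample()
        quantized = (samples @ self.codebook.weight).view_as(ctx) / self.num_samples
        # commitment loss keeps the encoder output close to the selected codes
        aux_loss = self.commitment_cost * (quantized.detach() - ctx).pow(2).mean()
        # straight-through estimator: gradients bypass the discrete sampling
        quantized = ctx + (quantized - ctx).detach()
        return quantized, aux_loss


if __name__ == "__main__":
    x = torch.randn(2, 30, 768)   # 2 sequences of 30 token embeddings
    q, aux = ToyCQ()(x)
    print(q.shape, aux.item())    # torch.Size([2, 10, 768]) and a scalar loss
```

In the actual pipeline, the quantized tokens take the place of the static soft-prompt embeddings inside OpenPrompt's soft template, and the corresponding auxiliary loss (`PromptGen.loss_aux`) is added to the generation loss at every step in `src/train.py`.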
116 | 117 | -------------------------------------------------------------------------------- /src/prompt_gen_module.py: -------------------------------------------------------------------------------- 1 | #add the absolute path of openprompt library 2 | import sys, os 3 | import math 4 | 5 | sys.path.append('../OpenPrompt/') 6 | 7 | from torch.utils.data.sampler import RandomSampler 8 | from transformers.configuration_utils import PretrainedConfig 9 | from transformers.generation_utils import GenerationMixin 10 | import torch 11 | import torch.nn as nn 12 | from torch.utils.data import Dataset 13 | from typing import * 14 | from transformers.tokenization_utils import PreTrainedTokenizer 15 | from transformers.utils.dummy_pt_objects import PreTrainedModel 16 | import numpy as np 17 | from torch.utils.data import DataLoader 18 | from transformers import AdamW, get_linear_schedule_with_warmup 19 | import torch.nn.functional as F 20 | from torch.nn.modules.transformer import TransformerEncoder, TransformerEncoderLayer 21 | from transformers import T5Tokenizer, T5EncoderModel 22 | from torch.distributions.multinomial import Multinomial 23 | 24 | 25 | class PromptGenerator(nn.Module): 26 | ''' 27 | CQ-module 28 | ''' 29 | def __init__(self, 30 | plm: PreTrainedModel, 31 | num_codes: int, 32 | num_cq_tokens: int, 33 | num_samples: int, 34 | temp=float, 35 | padding_idx = None, 36 | commitment_cost = float, 37 | ema_decay=0.99, 38 | epsilon=1e-2, 39 | centroid_warm_up=False, 40 | calc_k_means=False 41 | ): 42 | 43 | super(PromptGenerator, self).__init__() 44 | 45 | self.raw_embedding = plm.get_input_embeddings() 46 | self.embedding_dim = self.raw_embedding.weight.size(1) 47 | self.num_cq_tokens = num_cq_tokens 48 | self.num_codes = num_codes 49 | self.cls_embeds = None 50 | 51 | self.d_model = self.embedding_dim 52 | self.codebook_size = num_codes 53 | self.padding_idx = padding_idx 54 | 55 | #average perplexity 56 | self.smooth_avg_perplexity = [] 57 | 58 | #transformer encoder 59 | self.fc_in = nn.Linear(768, 32, bias=True) 60 | self.fc_out = nn.Linear(32, 768, bias=True) 61 | self.encoder_layer = TransformerEncoderLayer(d_model=32, 62 | nhead=4, 63 | dim_feedforward=2*32) #2048 64 | self.sentence_encoder = TransformerEncoder(self.encoder_layer, num_layers=2) 65 | 66 | self.codebook = nn.Embedding.from_pretrained(self.raw_embedding.weight[self.num_cq_tokens:self.num_cq_tokens+self.codebook_size].clone().detach(), freeze=False) 67 | self.codebook.weight.data = 100*torch.nn.functional.normalize(self.codebook.weight.data, dim=-1) 68 | 69 | if padding_idx is not None: 70 | self.codebook.weight.data[padding_idx] = 0 71 | 72 | self.commitment_cost = commitment_cost 73 | self.temp = temp 74 | self.num_samples = num_samples 75 | 76 | self.register_buffer('_ema_cluster_size', torch.ones(self.codebook_size)/self.num_codes) 77 | 78 | self._decay = ema_decay 79 | self._epsilon = epsilon 80 | self.discard_ema_cluster_sizes = False 81 | 82 | self.loss_aux = torch.tensor(0.0) 83 | self.noise_contrastive_loss = False 84 | self.dedicated_codebook = False 85 | 86 | def forward(self, xcq, commitment_cost=None, attn_mask=None, temp=None): 87 | device = xcq.device 88 | 89 | #encode the input 90 | batch_size = xcq.size(0) 91 | 92 | #input text 93 | xc_down = self.fc_in(xcq) 94 | xc_down = xc_down.transpose(0,1) 95 | 96 | attn_mask_se = torch.cat([torch.ones(batch_size, self.num_cq_tokens).to(device=xcq.device), attn_mask], dim=-1).clone().detach() 97 | attn_mask_se = attn_mask_se.bool() 98 | 99 | xc_up = 
self.sentence_encoder( src=xc_down , src_key_padding_mask = ~attn_mask_se ).transpose(0,1) 100 | xc_out = self.fc_out(xc_up[:,:self.num_cq_tokens,:]) 101 | 102 | #quantizer 103 | if commitment_cost is None: #a hyperparameter need to tune 104 | commitment_cost = self.commitment_cost 105 | 106 | if temp is None: 107 | temp = self.temp 108 | 109 | xc_out_shape = xc_out.size() 110 | xc_out_dims = xc_out.dim() 111 | 112 | # Flatten input 113 | flat_xc_out = xc_out.reshape(-1, self.d_model) 114 | 115 | # calculate distances 116 | if self.dedicated_codebook: 117 | for d in range(self.num_cq_tokens): 118 | distances = distances.view(xc_out_shape[0], xc_out_dims[1], self.num_codes) 119 | distances[:, d, :] += 1e5 120 | distances[:, d, 10*d : 10*(d+1)] -= 1e5 121 | distances = distances.reshape(-1, self.num_codes) 122 | else: 123 | distances = (torch.sum(flat_xc_out**2, dim=1, keepdim=True) 124 | + torch.sum(self.codebook.weight**2, dim=1) 125 | - 2 * torch.matmul(flat_xc_out, self.codebook.weight.t())) 126 | 127 | # Define multinomial distribution and sample from it 128 | multi = Multinomial(total_count=self.num_samples, logits=(-distances-1e-5)/temp) 129 | samples = multi.sample().to(device) 130 | 131 | # Soft-quantize and unflatten 132 | xc_quantized = torch.matmul(samples, self.codebook.weight).view(xc_out_shape) / self.num_samples 133 | 134 | # Loss 135 | e_latent_loss = torch.mean((xc_quantized.detach() - xc_out)**2) 136 | loss = commitment_cost * e_latent_loss 137 | 138 | # Use EMA to update the embedding vectors 139 | if self.training: 140 | if self.discard_ema_cluster_sizes: 141 | self._ema_cluster_size = torch.sum(samples, 0) / self.num_samples 142 | self.discard_ema_cluster_sizes = False 143 | else: 144 | self._ema_cluster_size = self._ema_cluster_size * self._decay + \ 145 | (1 - self._decay) * \ 146 | (torch.sum(samples, 0) / self.num_samples) 147 | 148 | # Laplace smoothing of the cluster size 149 | n = torch.sum(self._ema_cluster_size.data) 150 | self._ema_cluster_size = ( 151 | (self._ema_cluster_size + self._epsilon) 152 | / (n + self.codebook_size * self._epsilon) * n) 153 | 154 | dw = torch.matmul(samples.t(), flat_xc_out) / self.num_samples 155 | normalized_ema_w = self.codebook.weight * self._decay + (1 - self._decay) * (dw/self._ema_cluster_size.unsqueeze(1)) #option-1 156 | 157 | if self.padding_idx is not None: 158 | normalized_ema_w[self.padding_idx] = 0 159 | self.codebook.weight = nn.Parameter(normalized_ema_w) 160 | 161 | xc_quantized = xc_out + (xc_quantized - xc_out).detach() 162 | avg_probs = torch.mean(samples, dim=0) / self.num_samples 163 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))# e^(10*log_e(10))) 164 | 165 | samples = samples.reshape(list(xc_out_shape[:xc_out_dims - 1]) + [self.codebook_size]) 166 | 167 | #check the perplexity over the entire epoch 168 | self.smooth_avg_perplexity.append(perplexity.detach()) 169 | print("-->e_latent_loss", e_latent_loss.item()) 170 | print("-->perplexity:", (sum(self.smooth_avg_perplexity[-100:]).detach()/len(self.smooth_avg_perplexity[-100:])).item() ) 171 | self.loss_aux = loss 172 | 173 | if self.noise_contrastive_loss: 174 | xc_out2 = self.sentence_encoder( src=xc_down, src_key_padding_mask = ~attn_mask_se ).transpose(0,1) 175 | xc_out2 = self.fc_out(xc_out2[:,:self.num_cq_tokens,:]) 176 | dist = torch.cdist(xc_out, xc_out2, p=2) 177 | loss_contrastive = torch.softmax(-dist, dim=-1) 178 | loss_contrastive = -torch.diagonal(loss_contrastive, dim1=-2, dim2=-1) 179 | loss_contrastive = 
torch.mean(loss_contrastive) 180 | print("-->se loss contrastive:", loss_contrastive.item()) 181 | self.loss_aux += loss_contrastive 182 | 183 | return xc_quantized, xcq 184 | 185 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../OpenPrompt') 3 | 4 | import time 5 | import os 6 | import re 7 | import copy 8 | import argparse 9 | import itertools 10 | from tqdm import tqdm 11 | 12 | import torch 13 | import math 14 | import numpy as np 15 | import yaml 16 | 17 | import prompt_gen_module 18 | from sklearn.model_selection import train_test_split 19 | 20 | from openprompt.pipeline_base import PromptForGeneration 21 | from openprompt.prompts.generation_verbalizer import GenerationVerbalizer 22 | from openprompt.data_utils import PROCESSORS 23 | from openprompt.data_utils.utils import InputExample 24 | from openprompt import PromptDataLoader 25 | from openprompt.prompts import SoftTemplate 26 | 27 | from openprompt.utils.crossfit_metrics import evaluate as crossfit_evaluate 28 | from openprompt.utils.crossfit_metrics import METRICS 29 | 30 | 31 | ''' 32 | Model settings 33 | ''' 34 | parser_cfg = argparse.ArgumentParser("") 35 | 36 | parser_cfg.add_argument("--cfg", type=str) 37 | args_cfg, _ = parser_cfg.parse_known_args() 38 | 39 | with open(args_cfg.cfg) as file: 40 | cfg = yaml.load(file, Loader=yaml.FullLoader) 41 | 42 | parser = argparse.ArgumentParser("") 43 | parser.add_argument("--cfg", type=str) 44 | parser.add_argument("--seed", type=int, default=cfg['train']['seed']) 45 | parser.add_argument("--model", type=str, default=cfg['model']['model'], help="We test both t5 and t5-lm in this scripts, the corresponding tokenizerwrapper will be automatically loaded.") 46 | parser.add_argument("--model_name_or_path", default=cfg['model']['model_name_or_path']) 47 | parser.add_argument("--plm_eval_mode", type=bool, default=cfg['model']['plm_eval_mode'], help="whether to turn off the dropout in the freezed model.") 48 | parser.add_argument("--use_cuda", type=bool, default=cfg['model']['use_cuda']) 49 | parser.add_argument("--model_parallelize", type=bool, default=cfg['model']['model_parallelize']) 50 | parser.add_argument("--tune_plm", type=bool, default=cfg['model']['tune_plm']) 51 | parser.add_argument("--verbalizer", type=str, default=cfg['model']['verbalizer']) 52 | parser.add_argument("--template", type=str, default=cfg['model']['template']) 53 | parser.add_argument("--template_id", type=int, default=cfg['model']['template_id']) 54 | parser.add_argument("--data_dir", type=str, default=cfg['dataset']['data_dir']) # sometimes, huggingface datasets can not be automatically downloaded due to network issue, please refer to 0_basic.py line 15 for solutions. 
55 | parser.add_argument("--dataset",type=str, default=cfg['dataset']['dataset_name']) 56 | parser.add_argument("--data_processor",type=str,default=cfg['dataset']['data_processor']) 57 | parser.add_argument("--max_steps", type=int, default=cfg['train']['num_training_steps']) 58 | parser.add_argument("--batch_tr", type=int, default=cfg['train']['batch_size']) 59 | parser.add_argument("--batch_te", type=int, default=cfg['test']['batch_size']) 60 | parser.add_argument("--eval_on_test", type=bool, default=cfg['test']['eval_on_test']) 61 | parser.add_argument("--lr_soft_prompt", type=float, default=cfg['train']['lr_soft_prompt']) 62 | parser.add_argument("--lr_cq_prompt", type=float, default=cfg['train']['lr_cq_prompt']) 63 | parser.add_argument("--eval_every_steps", type=int, default=cfg['train']['eval_every_steps']) 64 | parser.add_argument("--optimizer", type=str, default=cfg['train']['optimizer']) 65 | parser.add_argument("--num_codes", type=int, default=cfg['train']['num_codes']) 66 | parser.add_argument("--early_stop", type=int, default=cfg['train']['early_stop']) 67 | parser.add_argument("--num_soft_tokens", type=int, default=cfg['prompt']['num_soft_tokens']) 68 | parser.add_argument("--num_cq_tokens", type=int, default=cfg['prompt']['num_cq_tokens']) 69 | parser.add_argument("--init_from_vocab", type=bool, default=cfg['prompt']['init_from_vocab']) 70 | parser.add_argument("--result_path", type=str, default=cfg['result']['result_path']) 71 | parser.add_argument("--comment", type=str, default="") 72 | parser.add_argument("--temp", type=float, default=cfg['CQ']['temp']) 73 | parser.add_argument("--num_codebook_samples", type=int, default=cfg['CQ']['num_codebook_samples']) 74 | parser.add_argument("--commitment_cost", type=int, default=cfg['CQ']['commitment_cost']) 75 | parser.add_argument("--identifier", type=str, default=cfg['CQ']['identifier']) 76 | 77 | args = parser.parse_args() 78 | 79 | if args.num_codes == -1: 80 | args.num_codes = args.num_cq_tokens*10 81 | 82 | content_write = "" 83 | content_write += f"config file: {args.cfg}\t" 84 | content_write += f"seed: {args.seed}\t" 85 | content_write += f"model: {args.model}\t" 86 | content_write += f"model_name_or_path: {args.model_name_or_path}\t" 87 | content_write += f"use_cuda: {args.use_cuda}\t" 88 | content_write += f"model_parallelize: {args.model_parallelize}\t" 89 | content_write += f"plm_eval_mode: {args.plm_eval_mode}\t" 90 | content_write += f"tune_plm: {args.tune_plm}\t" 91 | content_write += f"verbalizer: {args.verbalizer}\t" 92 | content_write += f"init_from_vocab: {args.init_from_vocab}\t" 93 | content_write += f"eval_every_steps: {args.eval_every_steps}\t" 94 | content_write += f"lr_soft_prompt: {args.lr_soft_prompt}\t" 95 | content_write += f"lr_cq_prompt: {args.lr_cq_prompt}\t" 96 | content_write += f"batch tr: {args.batch_tr}\t" 97 | content_write += f"batch te: {args.batch_te}\t" 98 | content_write += f"optimizer: {args.optimizer}\t" 99 | content_write += f"num_soft_tokens: {args.num_soft_tokens}\t" 100 | content_write += f"num_cq_tokens: {args.num_cq_tokens}\t" 101 | content_write += f"codebook vectors: {args.num_codes}\t" 102 | content_write += f"early stopping patience: {args.early_stop}\t" 103 | content_write += f"number of codebook samples: {args.num_codebook_samples}\t" 104 | content_write += f"identifier: {args.identifier}\t" 105 | content_write += f"comment: {args.comment}" 106 | content_write += "\n" 107 | 108 | print("="*20) 109 | print("Configuration:") 110 | print(content_write.replace('\t','\n->')) 111 | 
112 | 113 | #seed for reproduciblity 114 | from openprompt.utils.reproduciblity import set_seed 115 | set_seed(args.seed) 116 | 117 | 118 | 119 | ''' 120 | Initialize data-specifc items 121 | ''' 122 | dataset = {} 123 | Processor = PROCESSORS[args.data_processor] 124 | dataset['train'] = Processor().get_train_examples(args.data_dir) 125 | dataset['validation'] = Processor().get_dev_examples(args.data_dir) 126 | dataset['test'] = Processor().get_test_examples(args.data_dir) 127 | class_labels =Processor().get_labels() 128 | label_tokens = [cfg['dataset']['label_tokens']] 129 | max_seq_l = cfg['dataset']['max_seq_l'] 130 | metric_set = cfg['dataset']['metric_set'] 131 | dataset_decoder_max_length = cfg['dataset']['dataset_decoder_max_length'] 132 | 133 | print(f"\nTrain len:{len(dataset['train'])}; Valid len:{len(dataset['validation'])}; Test len:{len(dataset['test'])}") 134 | 135 | 136 | 137 | 138 | ''' 139 | Model 140 | ''' 141 | 142 | # use lm-adapted version or t5-v1.1 checkpoint. Note that the originial t5 checkpoint has been pretrained 143 | # on part of GLUE dataset, thus should not be used. 144 | from openprompt.plms.seq2seq import T5TokenizerWrapper, T5LMTokenizerWrapper 145 | from transformers import T5Config, T5Tokenizer, T5ForConditionalGeneration 146 | from openprompt.plms import load_plm 147 | 148 | 149 | # pre-trained LM such as T5, GPT, etc. 150 | plm, tokenizer, model_config, WrapperClass = load_plm(args.model, args.model_name_or_path) 151 | 152 | 153 | prompt_generator = None 154 | if args.num_cq_tokens > 0: 155 | prompt_generator = prompt_gen_module.PromptGenerator 156 | 157 | 158 | # template 159 | args.template = os.path.normpath(os.path.join(os.getcwd(), args.template)) 160 | mytemplate = SoftTemplate(model=plm, 161 | tokenizer=tokenizer, 162 | num_soft_tokens=args.num_soft_tokens, 163 | initialize_from_vocab=args.init_from_vocab, 164 | label_tokens=label_tokens, 165 | num_cq_tokens=args.num_cq_tokens, 166 | prompt_generator=prompt_generator, 167 | task_tokens=[f"{args.dataset}"], 168 | num_codes=args.num_codes, 169 | temp = args.temp, 170 | commitment_cost= args.commitment_cost, 171 | num_codebook_samples=args.num_codebook_samples).from_file(args.template, choice=args.template_id) 172 | 173 | 174 | # verbalizer 175 | args.verbalizer = os.path.normpath(os.path.join(os.getcwd(), args.verbalizer)) 176 | myverbalizer = GenerationVerbalizer(tokenizer, classes=class_labels, is_rule=True).from_file(args.verbalizer) 177 | 178 | 179 | # prompt model: plug-in everything in a complete network 180 | prompt_model = PromptForGeneration(plm=plm,template=mytemplate, freeze_plm=(not args.tune_plm), plm_eval_mode=args.plm_eval_mode) 181 | 182 | 183 | # shift prompt model to gpu 184 | if args.use_cuda: 185 | prompt_model= prompt_model.cuda() 186 | 187 | if args.model_parallelize: 188 | prompt_model.parallelize() 189 | 190 | 191 | ''' 192 | data loaders 193 | ''' 194 | train_dataloader = PromptDataLoader(dataset=dataset["train"], template=mytemplate, verbalizer=myverbalizer, tokenizer=tokenizer, # be sure to add verbalizer 195 | tokenizer_wrapper_class=WrapperClass, max_seq_length=max_seq_l, decoder_max_length=dataset_decoder_max_length, # be sure to use larger decoder_max_length for teacher forcing. 
196 | batch_size=args.batch_tr,shuffle=True, teacher_forcing=True, predict_eos_token=True, # be sure to use teacher_forcing and predict_eos_token=True 197 | truncate_method="tail") 198 | 199 | 200 | validation_dataloader = PromptDataLoader(dataset=dataset["validation"], template=mytemplate, verbalizer=myverbalizer, tokenizer=tokenizer, 201 | tokenizer_wrapper_class=WrapperClass, max_seq_length=max_seq_l, decoder_max_length=dataset_decoder_max_length, 202 | batch_size=args.batch_te,shuffle=False, teacher_forcing=False, predict_eos_token=False, # predict_eos_token=True or False are both ok 203 | truncate_method="tail") 204 | 205 | 206 | if args.eval_on_test: #false for SuperGLUE datasets 207 | test_dataloader = PromptDataLoader(dataset=dataset["test"], template=mytemplate, verbalizer=myverbalizer, tokenizer=tokenizer, 208 | tokenizer_wrapper_class=WrapperClass, max_seq_length=max_seq_l, decoder_max_length=dataset_decoder_max_length, 209 | batch_size=args.batch_te,shuffle=False, teacher_forcing=False, predict_eos_token=False, # predict_eos_token=True or False are both ok 210 | truncate_method="tail") 211 | 212 | 213 | generation_arguments = { 214 | "max_length": dataset_decoder_max_length, 215 | } 216 | 217 | 218 | 219 | 220 | 221 | ''' 222 | evaluation function 223 | ''' 224 | if args.eval_every_steps == -1: 225 | args.eval_every_steps = math.ceil(len(dataset["train"])/args.batch_tr) 226 | print(f"\n\n\n We will do epoch wise evaluation, i.e., at steps: {args.eval_every_steps}") 227 | else: 228 | print(f"\n\n\n We will evaluate at steps: {args.eval_every_steps}") 229 | 230 | def evaluate(prompt_model, dataloader, dataset, cluster_mode=False, phase='val'): 231 | 232 | prompt_model.eval() 233 | predictions = [] 234 | ground_truths = [] 235 | 236 | for step, inputs in enumerate(dataloader): 237 | if cluster_mode and step > 100: 238 | break 239 | if args.use_cuda: 240 | inputs = inputs.cuda() 241 | _, output_sentence = prompt_model.generate(inputs, **generation_arguments, verbose=False) 242 | 243 | predictions.extend(output_sentence) 244 | 245 | if ('meteor' in metric_set) or ('bleu' in metric_set): 246 | inputs['tgt_text'] = [k.split("-$$-") for k in inputs['tgt_text']] 247 | 248 | ground_truths.extend(inputs['tgt_text']) 249 | 250 | assert len(predictions)==len(ground_truths), (len(predictions), len(ground_truths)) 251 | 252 | predictions = [prediction.strip() for prediction in predictions] 253 | 254 | if args.dataset == 'conll2003': 255 | from conll_text2bio import getonebatchresult 256 | inp_sentences = [e.meta['sentence'] for e in dataset] 257 | print(f"predictions {predictions[0:1]}, \nground_truths {ground_truths[0:1]}") 258 | ground_truths, predictions = getonebatchresult(inp_sentences, ground_truths, predictions) 259 | # shown one example 260 | print(f"predictions {predictions[0:1]}, \nground_truths {ground_truths[0:1]}") 261 | 262 | elif ('meteor' not in metric_set) and ('bleu' not in metric_set): 263 | ground_truths = [ground_truth.strip() for ground_truth in ground_truths] 264 | # shown one example 265 | print(f"predictions {predictions[0]}, ground_truths {ground_truths[0]}") 266 | 267 | else: 268 | # shown one example 269 | print(f"predictions {predictions[0]}, ground_truths {ground_truths[0]}") 270 | 271 | scores = dict() 272 | for metric in metric_set: 273 | score = crossfit_evaluate(predictions, ground_truths, metric=metric) 274 | if args.dataset == 'conll2003': 275 | score = score['overall_f1'] 276 | scores[metric] = score 277 | else: 278 | scores[metric] = score 279 | 280 | 
torch.cuda.empty_cache() 281 | return scores 282 | 283 | 284 | 285 | ''' 286 | optimizer 287 | 288 | [Note:when lr is 0.3 with adafactor, it is the same as the configuration of https://arxiv.org/abs/2104.08691] 289 | ''' 290 | 291 | from transformers import AdamW, get_linear_schedule_with_warmup, get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup 292 | from transformers.optimization import Adafactor, AdafactorSchedule 293 | 294 | tot_step = args.max_steps 295 | 296 | optimizer_soft = None 297 | optimizer_cq = None 298 | 299 | 300 | if args.optimizer.lower() == "adafactor": 301 | soft_optimizer_parameters = [{'params': [p for name, p in prompt_model.template.named_parameters() if (('raw_embedding' not in name) and ('PromptGen' not in name))]}] # remove the raw_embedding manually from the optimization 302 | optimizer_soft = Adafactor(soft_optimizer_parameters, 303 | lr=args.lr_soft_prompt, 304 | relative_step=False, 305 | scale_parameter=False, 306 | warmup_init=False) 307 | scheduler_soft = get_constant_schedule_with_warmup(optimizer_soft, num_warmup_steps=0) # when num_warmup_steps is 0, it is the same as the configuration of https://arxiv.org/abs/2104.08691 308 | 309 | if args.num_cq_tokens > 0: 310 | cq_optimizer_parameters = [{'params': [p for name, p in prompt_model.template.PromptGen.named_parameters() if (('raw_embedding' not in name))]}] 311 | 312 | optimizer_cq = Adafactor(cq_optimizer_parameters, 313 | lr=args.lr_cq_prompt, 314 | relative_step=False, 315 | scale_parameter=False, 316 | warmup_init=False) 317 | 318 | scheduler_cq = get_constant_schedule_with_warmup(optimizer_cq, num_warmup_steps=0) # when num_warmup_steps is 0, it is the same as the configuration of https://arxiv.org/abs/2104.08691 319 | 320 | 321 | 322 | 323 | ''' 324 | training 325 | 326 | ''' 327 | from early_stopping import EarlyStopping 328 | delta = 1e-5 329 | save_code_name = f"{args.dataset}_sd{args.seed}_s{args.num_soft_tokens}_v{args.num_cq_tokens}_lS{args.lr_soft_prompt}_lV{args.lr_cq_prompt}_bTr{args.batch_tr}_bTe{args.batch_te}_nC{args.num_codes}_nCS{args.num_codebook_samples}_mS{args.max_steps}_evS{args.eval_every_steps}_eS{args.early_stop}_t{args.temp}_cC{args.commitment_cost}_sE:{args.identifier}" 330 | early_stopping = EarlyStopping(patience=args.early_stop, delta=delta, verbose=True, path=f"../ckpts/{save_code_name}.ckpt", only_save_prompt_params=True) 331 | 332 | log_path = f"../logs/{save_code_name}.txt" 333 | mode = 'a' if os.path.exists(log_path) else 'w' 334 | 335 | with open(log_path, mode) as f: 336 | f.write("\n\n\n\n\n") 337 | f.write("="*50+"\n") 338 | f.write("Configuration:\n") 339 | f.write(content_write.replace('\t','\n->')+"\n") 340 | 341 | 342 | # variables to keep track of training process 343 | best_val_acc = 0 344 | 345 | best_val_score_dict = dict() 346 | loss_list = [] 347 | best_val_acc_list = [] 348 | 349 | glb_step = 1 350 | best_glb_step = 0 351 | 352 | val_step = 0 353 | best_val_step = 0 354 | 355 | tot_train_time = 0 356 | 357 | leave_training = False 358 | 359 | best_prompt_model = None 360 | 361 | # epochs 362 | for epoch in range(1000000): 363 | 364 | #print(f"Begin epoch {epoch}") 365 | 366 | for step, inputs in enumerate(train_dataloader): 367 | 368 | #number of backward pass 369 | glb_step += 1 370 | 371 | if args.use_cuda: 372 | inputs = inputs.cuda() 373 | 374 | tot_train_time -= time.time() 375 | 376 | prompt_model.train() 377 | 378 | 379 | print('\n\ngradient step: ', glb_step) 380 | 381 | loss = prompt_model(inputs) 382 | 383 | if 
384 |             print("-->Entropy loss:", loss.item())
385 |             print("-->Aux loss:", prompt_model.template.PromptGen.loss_aux.item())
386 |             loss += prompt_model.template.PromptGen.loss_aux
387 |             print("-->Total loss:", loss.item())
388 |         print("Best val till now: ", best_val_acc, " at step: ", best_glb_step)
389 |
390 |         loss.backward()
391 |
392 |         if optimizer_soft is not None:
393 |             optimizer_soft.step()
394 |             optimizer_soft.zero_grad()
395 |             scheduler_soft.step()
396 |
397 |         if optimizer_cq is not None:
398 |             optimizer_cq.step()
399 |             optimizer_cq.zero_grad()
400 |             scheduler_cq.step()
401 |
402 |         tot_train_time += time.time()
403 |
404 |         if glb_step % args.eval_every_steps == 0:
405 |
406 |             # number of validation rounds
407 |             val_step += 1
408 |
409 |             print('\n\n\n validating...')
410 |             torch.cuda.empty_cache()
411 |             val_score_dict = evaluate(prompt_model, validation_dataloader, dataset['validation'])
412 |             val_acc = sum(val_score_dict.values())/len(val_score_dict)
413 |
414 |             print(f"\n\t\t\t\t[val acc at step {glb_step}: {val_acc}]\n")
415 |
416 |             if val_acc > best_val_acc + delta:
417 |                 best_val_acc_list.append(f"{glb_step}:{round(val_acc,4)}")
418 |                 best_glb_step = glb_step
419 |
420 |                 best_val_step = val_step
421 |                 best_val_acc = val_acc
422 |                 best_val_score_dict = val_score_dict
423 |                 best_prompt_model = copy.deepcopy(prompt_model)
424 |
425 |             with open(log_path, 'a') as f:
426 |                 f.write(f"\t+ Entropy loss: {loss.item()}\n")
427 |                 if args.num_cq_tokens > 0:
428 |                     f.write(f"\t+ commitment_cost: {prompt_model.template.PromptGen.loss_aux.item()}\n")
429 |                 f.write(f"\t+ Val acc at step {glb_step}: {val_acc}\n")
430 |                 f.write(f"\t+ Best val till now: {best_val_acc} at step: {best_glb_step}\n\n")
431 |
432 |             # early stopping
433 |             early_stopping(val_acc, best_glb_step, prompt_model)
434 |
435 |             if early_stopping.early_stop:
436 |                 leave_training = True
437 |                 print("Early stopping...")
438 |                 break
439 |
440 |         if glb_step > args.max_steps:
441 |             leave_training = True
442 |             break
443 |
444 |     if leave_training:
445 |         break
446 |
447 |
448 |
449 | '''
450 | testing
451 |
452 | '''
453 | test_acc = 0
454 | val_chck = None
455 | test_score_dict = dict()
456 | val_score_dict = dict()
457 | if args.eval_on_test:
458 |     print("testing...")
459 |     del prompt_model  # free the last-step model; evaluate the reloaded best model instead
460 |     best_prompt_model = best_prompt_model.cuda()
461 |     best_prompt_model.parallelize()
462 |     val_acc_to_check = evaluate(best_prompt_model, validation_dataloader, dataset['validation'])
463 |     val_chck = sum(val_acc_to_check.values())/len(val_acc_to_check)  # sanity-check score of the best model, written to the results file
464 |     test_score_dict = evaluate(best_prompt_model, test_dataloader, dataset['test'], phase='test')
465 |     test_acc = sum(test_score_dict.values())/len(test_score_dict)
466 |     print("Test accuracy:", test_acc)
467 |
468 |
469 |
470 |
471 | '''
472 | save results
473 |
474 | '''
475 | print('Best score is...')
476 | print(f"best train step: {best_glb_step} | best val step: {best_val_step} | Best Valid Acc: {best_val_acc:.3f} | Num Soft: {args.num_soft_tokens} | Num CQ: {args.num_cq_tokens}")
477 | print(f"saved in {args.result_path}")
478 |
479 | import os
480 | import datetime
481 |
482 | mode = 'a' if os.path.exists(args.result_path) else 'w'
483 | with open(args.result_path, mode) as f:
484 |     if args.comment != "":
485 |         f.write(f"---> data: {args.dataset} | date-time: {datetime.datetime.now()} | Num Soft: {args.num_soft_tokens} | Num CQ: {args.num_cq_tokens} | Num Codes: {args.num_codes} | Seed: {args.seed} | lr_soft: {args.lr_soft_prompt} | lr_cq: {args.lr_cq_prompt} | val_chck: {val_chck} | Comment: {args.comment}\n")
486 |     else:
487 |         f.write(f"---> data: {args.dataset} | date-time: {datetime.datetime.now()} | Num Soft: {args.num_soft_tokens} | Num CQ: {args.num_cq_tokens} | Num Codes: {args.num_codes} | Seed: {args.seed} | lr_soft: {args.lr_soft_prompt} | lr_cq: {args.lr_cq_prompt} | val_chck: {val_chck} |\n")
488 |     f.write(f"best train step: {best_glb_step} | best val step: {best_val_step} | Val dict: {best_val_score_dict} | Test dict: {test_score_dict}\n")
489 |     f.write(f"Best Valid Acc: {best_val_acc:.3f}% | Test Acc: {test_acc:.3f}%\n")
490 |     f.write(f"Performance trajectory: {best_val_acc_list}\n")
491 |     f.write(f"file_name: {save_code_name}\n\n")
492 |
--------------------------------------------------------------------------------