├── data ├── README.md ├── train_labels_demo.json └── train_demo.json ├── commands ├── README.md ├── train.sh ├── train_t5.sh ├── train_t5_prompt.sh ├── train_bert_prompt.sh ├── train_t5_bitfit.sh └── train_bert_bitfit.sh ├── data_ultis.py ├── bert_model.py ├── t5_model.py ├── data_gain.py ├── utils.py ├── README.md ├── prompt_t5.py ├── test.py ├── prompt_test.py ├── prompt_bert.py ├── inference.py ├── prompt_train.py └── train.py /data/README.md: -------------------------------------------------------------------------------- 1 | # Data 2 | -------------------------------------------------------------------------------- /commands/README.md: -------------------------------------------------------------------------------- 1 | # Commands 2 | -------------------------------------------------------------------------------- /data/train_labels_demo.json: -------------------------------------------------------------------------------- 1 | [1, 1, 1, 1, 1] -------------------------------------------------------------------------------- /commands/train.sh: -------------------------------------------------------------------------------- 1 | cd ../ 2 | python train.py --train_data_path ./data/train.json \ 3 | --train_label_path ./data/train_labels.json \ 4 | --dev_data_path ./data/dev.json \ 5 | --dev_label_path ./data/dev_labels.json \ 6 | --model_name_or_path bert \ 7 | --output_dir ./checkpoint 8 | -------------------------------------------------------------------------------- /commands/train_t5.sh: -------------------------------------------------------------------------------- 1 | cd ../ 2 | python train.py --train_data_path ./data/train.json \ 3 | --train_label_path ./data/train_labels.json \ 4 | --dev_data_path ./data/dev.json \ 5 | --dev_label_path ./data/dev_labels.json \ 6 | --model_name_or_path t5 \ 7 | --output_dir ./checkpoint 8 | -------------------------------------------------------------------------------- /commands/train_t5_prompt.sh: -------------------------------------------------------------------------------- 1 | cd ../ 2 | python prompt_train.py --train_data_path ./data/train.json \ 3 | --train_label_path ./data/train_labels.json \ 4 | --dev_data_path ./data/dev.json \ 5 | --dev_label_path ./data/dev_labels.json \ 6 | --model_name_or_path t5 \ 7 | --output_dir ./checkpoint 8 | -------------------------------------------------------------------------------- /commands/train_bert_prompt.sh: -------------------------------------------------------------------------------- 1 | cd ../ 2 | python prompt_train.py --train_data_path ./data/train.json \ 3 | --train_label_path ./data/train_labels.json \ 4 | --dev_data_path ./data/dev.json \ 5 | --dev_label_path ./data/dev_labels.json \ 6 | --model_name_or_path bert \ 7 | --output_dir ./checkpoint 8 | -------------------------------------------------------------------------------- /commands/train_t5_bitfit.sh: -------------------------------------------------------------------------------- 1 | cd ../ 2 | python train.py --train_data_path ./data/train.json \ 3 | --train_label_path ./data/train_labels.json \ 4 | --dev_data_path ./data/dev.json \ 5 | --dev_label_path ./data/dev_labels.json \ 6 | --model_name_or_path t5 \ 7 | --output_dir ./checkpoint \ 8 | --bitfit 9 | -------------------------------------------------------------------------------- /commands/train_bert_bitfit.sh: -------------------------------------------------------------------------------- 1 | cd ../ 2 | python train.py --train_data_path ./data/train.json \ 3 | 
--train_label_path ./data/train_labels.json \ 4 | --dev_data_path ./data/dev.json \ 5 | --dev_label_path ./data/dev_labels.json \ 6 | --model_name_or_path bert \ 7 | --output_dir ./checkpoint \ 8 | --bitfit 9 | -------------------------------------------------------------------------------- /data_ultis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | from torch.utils.data import Dataset 4 | 5 | class ScData(Dataset): 6 | def __init__(self, args, data_path, label_path, tokenizer): 7 | self.data = json.load(open(data_path, 'r')) 8 | self.label = json.load(open(label_path, 'r')) 9 | self.tokenizer = tokenizer 10 | self.max_len = args.max_len 11 | 12 | def __len__(self): 13 | return len(self.data) 14 | 15 | def __getitem__(self, index): 16 | input = self.tokenizer.encode_plus(self.data[index], max_length=self.max_len, padding='max_length', truncation=True) 17 | label = self.label[index] 18 | input_ids = torch.tensor(input.input_ids, dtype=torch.long) 19 | attention_mask = torch.tensor(input.attention_mask, dtype=torch.long) 20 | label = torch.tensor(label, dtype=torch.long) 21 | data_sample = (input_ids, attention_mask, label) 22 | return data_sample -------------------------------------------------------------------------------- /bert_model.py: -------------------------------------------------------------------------------- 1 | from transformers import BertModel, RobertaModel 2 | from torch import nn 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | class CLSModel(nn.Module): 8 | def __init__(self, args): 9 | super(CLSModel, self).__init__() 10 | self.encode_proj = nn.Linear(args.hidden_size, args.project_dim) 11 | self.model = BertModel.from_pretrained('bert-base-uncased') 12 | # self.model = RobertaModel.from_pretrained("roberta-base") 13 | 14 | def forward(self, input_ids, attention_mask, labels=None): 15 | outputs = self.model(input_ids=input_ids, attention_mask=attention_mask) 16 | pooler_output = outputs.pooler_output 17 | logits = self.encode_proj(pooler_output) 18 | probs = F.log_softmax(logits, -1) 19 | if labels != None: 20 | loss = F.nll_loss(probs, labels.to(probs.device), reduction='mean') 21 | return loss, logits 22 | else: 23 | return logits -------------------------------------------------------------------------------- /t5_model.py: -------------------------------------------------------------------------------- 1 | from transformers import T5ForConditionalGeneration 2 | from torch import nn 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | class T5CLSModel(nn.Module): 8 | def __init__(self, args): 9 | super(T5CLSModel, self).__init__() 10 | # self.encode_proj = nn.Linear(args.hidden_size, args.project_dim) 11 | self.model = T5ForConditionalGeneration.from_pretrained('t5-base') 12 | 13 | def forward(self, input_ids, attention_mask, labels=None): 14 | batch_size = input_ids.shape[0] 15 | decoder_input_ids = torch.zeros(batch_size, 1, dtype=int).to(input_ids.device) 16 | outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, 17 | decoder_input_ids=decoder_input_ids, return_dict=True) 18 | logits = outputs['logits'][:, 0, [7163, 11213, 27635, 2971, 3922, 24784, 4158]] 19 | probs = F.log_softmax(logits, -1) 20 | if labels != None: 21 | loss = F.nll_loss(probs, labels.to(probs.device), reduction='mean') 22 | return loss, logits 23 | else: 24 | return logits 25 | -------------------------------------------------------------------------------- /data_gain.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | import requests 3 | from tqdm import tqdm 4 | 5 | url = 'http://api.niutrans.com/NiuTransServer/translation?' 6 | 7 | sents = json.load(open('data/mydata/train.json')) 8 | lable = json.load(open('data/mydata/train_labels.json')) 9 | # print(a) 10 | new_train = [] 11 | new_train_label = [] 12 | for i in tqdm(range(len(sents))): 13 | if lable[i] == 0: 14 | new_train.append(sents[i]) 15 | new_train_label.append(lable[i]) 16 | else: 17 | new_train.append(sents[i]) 18 | new_train_label.append(lable[i]) 19 | data = {"from": 'en', "to": 'zh', "apikey": '', "src_text": sents[i]} 20 | res = requests.post(url, data=data).json() 21 | new_data = {"from": 'zh', "to": 'en', "apikey": '', "src_text": str(res['tgt_text'])} 22 | new_res = requests.post(url, data=new_data).json() 23 | new_sent = new_res['tgt_text'] 24 | new_train.append(new_sent) 25 | new_train_label.append(lable[i]) 26 | # json.dump(new_train, open('data/mydata/new_train.json', 'w')) 27 | # json.dump(new_train_label, open('data/mydata/new_train_labels.json', 'w')) 28 | # a = json.load(open('data/mydata/new_train.json', 'r')) 29 | # print(len(a)) 30 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import random 4 | import numpy as np 5 | from tqdm import tqdm 6 | from sklearn.metrics import accuracy_score, f1_score 7 | import torch 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | def set_seed(args): 12 | random.seed(args.seed) 13 | np.random.seed(args.seed) 14 | torch.manual_seed(args.seed) 15 | 16 | def get_metric(preds, golds): 17 | 18 | label_list = ['neutral', 'anger', 'disgust', 'fear', 'joy', 'sadness', 'surprise'] 19 | golds = [label_list[gold] for gold in golds] 20 | preds = [label_list[pred] for pred in preds] 21 | acc = accuracy_score(golds, preds) 22 | f1_macro = f1_score(golds, preds, average='macro') 23 | 24 | return {'acc': acc, 'f1': f1_macro} 25 | 26 | def eval_model(model, valid_dataloader): 27 | model.eval() 28 | predict_list = [] 29 | golden_list = [] 30 | with torch.no_grad(): 31 | for batch in tqdm(valid_dataloader): 32 | input_ids = batch[0].cuda() 33 | attention_mask = batch[1].cuda() 34 | labels = batch[2] 35 | outputs = model(input_ids, attention_mask) 36 | max_score, max_idxs = torch.max(outputs, 1) 37 | predict_idxs = max_idxs.view(-1).tolist() 38 | predict_list.extend(predict_idxs) 39 | golden_idxs = labels.view(-1).tolist() 40 | golden_list.extend(golden_idxs) 41 | evaluation_results = get_metric(predict_list, golden_list) 42 | return evaluation_results 43 | 44 | 45 | def set_env(args, run_type): 46 | if not os.path.exists(args.output_dir): 47 | os.mkdir(args.output_dir) 48 | handlers = [logging.FileHandler(os.path.abspath(args.output_dir) + '/' + run_type + '_log.txt'), logging.StreamHandler()] 49 | logging.basicConfig( 50 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 51 | datefmt="%m/%d/%Y %H:%M:%S", 52 | level=logging.INFO, 53 | handlers=handlers) 54 | set_seed(args) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Prompt Tuning For Sentiment Classification Base On Pre-trained Language Models 2 | 3 | ## **Code for the Internship at NEU-NLP Lab** 4 | 5 | ``` 6 | |--checkpoint/ 7 | |--data/ 8 | 
|------train_demo.json 9 | |------train_labels_demo.json 10 | |--commands/ 11 | |------train.sh 12 | |------train_t5.sh 13 | |------train_bert_bitfit.sh 14 | |------train_bert_prompt.sh 15 | |------train_t5_bitfit.sh 16 | |------train_t5_prompt.sh 17 | |--output/ 18 | |--bert_model.py 19 | |--data_gain.py 20 | |--data_ultis.py 21 | |--inference.py 22 | |--KB_prompt.py 23 | |--prompt_bert.py 24 | |--prompt_t5.py 25 | |--prompt_test.py 26 | |--prompt_train.py 27 | |--t5_model.py 28 | |--test.py 29 | |--train.py 30 | |--utils.py 31 | |--README.md 32 | ``` 33 | 34 | ## Data 35 | 36 | The data for this project is internal. The data format follows the examples in `./data/train_demo.json` and `./data/train_labels_demo.json`. 37 | 38 | The dev set and the test set have the same format as the train set. 39 | 40 | Before running the code, process your own dataset into the following six files in `./data/` (the two `*_demo.json` files are already included as format examples): 41 | 42 | ``` 43 | |--data/ 44 | |------train_demo.json 45 | |------train_labels_demo.json 46 | |------train.json 47 | |------train_labels.json 48 | |------dev.json 49 | |------dev_labels.json 50 | |------test.json 51 | |------test_labels.json 52 | ``` 53 | 54 | ## Experiment 55 | 56 | We compare the performance of pre-trained language models (BERT and T5) under full fine-tuning, bias-term fine-tuning (BitFit), and prompt-tuning. We use ***accuracy*** as the main evaluation metric. 57 | 58 | | Method | ACC | 59 | | ------------------ | --------- | 60 | | BERT-fine-tuning | 85.66 | 61 | | BERT-BitFit | 83.41 | 62 | | BERT-hard-P-tuning | 84.11 | 63 | | BERT-soft-P-tuning | 85.09 | 64 | | T5-fine-tuning | **87.03** | 65 | | T5-BitFit | 83.55 | 66 | | T5-hard-P-tuning | 85.55 | 67 | | T5-soft-P-tuning | **86.48** | 68 | 69 | ## Commands 70 | 71 | ```shell 72 | cd ./commands 73 | ``` 74 | 75 | ### 1. Full fine-tuning 76 | 77 | ```shell 78 | # For BERT 79 | bash train.sh 80 | # For T5 81 | bash train_t5.sh 82 | ``` 83 | 84 | ### 2. Bias-term Fine-tuning (BitFit) 85 | 86 | ```shell 87 | # For BERT 88 | bash train_bert_bitfit.sh 89 | # For T5 90 | bash train_t5_bitfit.sh 91 | ``` 92 | 93 | ### 3. Prompt-tuning 94 | 95 | ```shell 96 | # For BERT 97 | bash train_bert_prompt.sh 98 | # For T5 99 | bash train_t5_prompt.sh 100 | ``` 101 | 102 | ### 4. 
Inference for test 103 | 104 | ```shell 105 | # Fine tuning or BitFit 106 | python test.py --model_name bert/t5 107 | # prompt tuning 108 | python prompt_test.py --model_name bert/t5 109 | ``` 110 | 111 | ## Others 112 | 113 | checkpoints will be saved in ./checkpoints 114 | 115 | train and test logs will be saved in ./checkpoints 116 | 117 | If you have questions, suggestions, and bug reports, please email: 118 | 119 | ``` 120 | lvyuanhuiyi@foxmail.com 121 | ``` 122 | 123 | -------------------------------------------------------------------------------- /prompt_t5.py: -------------------------------------------------------------------------------- 1 | from transformers import T5ForConditionalGeneration, T5Config 2 | from torch import nn 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | class PromptT5CLSModel(nn.Module): 8 | def __init__(self, args): 9 | super(PromptT5CLSModel, self).__init__() 10 | self.config = T5Config.from_pretrained('t5-base') 11 | self.t5 = T5ForConditionalGeneration.from_pretrained('t5-base') 12 | self.soft_embedding_layer = None 13 | self.normal_embedding_layer = self.t5.get_input_embeddings() 14 | # self.proj_linear = nn.Linear(22, 7) 15 | 16 | # self.prefix_soft_index, self.suffix_soft_index = [3, 27569, 10], [31484, 17, 10, 1] 17 | # self.prefix_soft_index, self.suffix_soft_index = [8, 6493, 13], [19, 1] 18 | self.prefix_soft_index, self.suffix_soft_index = [8, 6493, 13], [31484, 17, 10, 1] 19 | self.p_num, self.s_num = len(self.prefix_soft_index), len(self.suffix_soft_index) 20 | self.prefix_soft_embedding_layer = nn.Embedding( 21 | self.p_num, self.config.hidden_size 22 | ) 23 | self.suffix_soft_embedding_layer = nn.Embedding( 24 | self.s_num, self.config.hidden_size 25 | ) 26 | self.prefix_soft_embedding_layer.weight.data = torch.stack( 27 | [self.normal_embedding_layer.weight.data[i, :].clone().detach().requires_grad_(True) for i in 28 | self.prefix_soft_index] 29 | ) 30 | self.suffix_soft_embedding_layer.weight.data = torch.stack( 31 | [self.normal_embedding_layer.weight.data[i, :].clone().detach().requires_grad_(True) for i in 32 | self.suffix_soft_index] 33 | ) 34 | self.prefix_soft_ids = torch.tensor(range(self.p_num)) 35 | self.suffix_soft_ids = torch.tensor(range(self.s_num)) 36 | for param in self.t5.parameters(): 37 | param.requires_grad_(False) 38 | 39 | def forward(self, input_ids, attention_mask, labels=None): 40 | batch_size = input_ids.shape[0] 41 | decoder_input_ids = torch.zeros(batch_size, 1, dtype=int).to(input_ids.device) 42 | prefix_soft_ids = torch.stack([self.prefix_soft_ids for i in range(batch_size)]).to(input_ids.device) 43 | suffix_soft_ids = torch.stack([self.suffix_soft_ids for i in range(batch_size)]).to(input_ids.device) 44 | 45 | prefix_soft_embeddings = self.prefix_soft_embedding_layer(prefix_soft_ids) 46 | suffix_soft_embeddings = self.suffix_soft_embedding_layer(suffix_soft_ids) 47 | 48 | text_embeddings = self.normal_embedding_layer(input_ids) 49 | 50 | input_embeddings = torch.cat( 51 | [prefix_soft_embeddings, text_embeddings, suffix_soft_embeddings], 52 | dim=1 53 | ) 54 | 55 | prefix_soft_attention_mask = torch.ones(batch_size, self.p_num).to(input_ids.device) 56 | suffix_soft_attention_mask = torch.ones(batch_size, self.s_num).to(input_ids.device) 57 | attention_mask = torch.cat( 58 | [prefix_soft_attention_mask, attention_mask, suffix_soft_attention_mask], 59 | dim=1 60 | ) 61 | output = self.t5( 62 | inputs_embeds=input_embeddings, 63 | decoder_input_ids=decoder_input_ids, 64 | 
attention_mask=attention_mask, 65 | return_dict=True 66 | ) 67 | logits = output['logits'][:, 0, [7163, 11213, 27635, 2971, 3922, 24784, 4158]] 68 | # logits = output['logits'][:, 0, [7163, 16, 25880, 11213, 1080, 12603, 27635, 13006, 5591, 69 | # 2971, 7403, 6541, 15, 3922, 1095, 5010, 24784, 26887, 70 | # 10875, 4158, 12914, 7544]] 71 | # logits = self.proj_linear(logits) 72 | probs = F.log_softmax(logits, -1) 73 | if labels != None: 74 | loss = F.nll_loss(probs, labels.to(probs.device), reduction='mean') 75 | return loss, logits 76 | else: 77 | return logits 78 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from transformers import BertTokenizer, T5Tokenizer 2 | import os 3 | import argparse 4 | import logging 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from data_ultis import ScData 8 | from bert_model import CLSModel 9 | from t5_model import T5CLSModel 10 | from utils import set_env, eval_model 11 | from tqdm import tqdm 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def inference(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument( 19 | "--model_name", 20 | default=None, 21 | type=str, 22 | required=True, 23 | ) 24 | parser.add_argument( 25 | "--test_data_path", 26 | default='./data/test.json', 27 | type=str, 28 | help="The testing data path.", 29 | ) 30 | parser.add_argument( 31 | "--test_label_path", 32 | default='./data/test_labels.json', 33 | type=str, 34 | help="The testing label path.", 35 | ) 36 | parser.add_argument( 37 | "--max_len", 38 | default=150, 39 | type=int, 40 | help="The maximum total input sequence length after tokenization.", 41 | ) 42 | parser.add_argument( 43 | "--test_batch_size", 44 | default=128, 45 | type=int, 46 | help="Batch size per GPU/CPU for training.", 47 | ) 48 | parser.add_argument( 49 | "--hidden_size", 50 | default=768, 51 | type=int, 52 | help="Hidden size.", 53 | ) 54 | parser.add_argument( 55 | "--project_dim", 56 | default=7, 57 | type=int, 58 | help="Project Dim.", 59 | ) 60 | parser.add_argument( 61 | "--bert_checkpoint_dir", 62 | default='./checkpoint/save_model_best.pt', 63 | type=str, 64 | help="the path of fine-tuned bert checkpoint.", 65 | ) 66 | parser.add_argument( 67 | "--t5_checkpoint_dir", 68 | default='./checkpoint/save_t5_model_best.pt', 69 | type=str, 70 | help="the path of fine-tuned t5 checkpoint.", 71 | ) 72 | parser.add_argument( 73 | "--output_dir", 74 | default='./checkpoint', 75 | type=str, 76 | help="The output directory where the model predictions and checkpoints will be written.", 77 | ) 78 | parser.add_argument( 79 | "--seed", 80 | type=int, 81 | default=1234, 82 | help="random seed for initialization", 83 | ) 84 | args = parser.parse_args() 85 | log_name = args.model_name + '_test' 86 | set_env(args, log_name) 87 | if args.model_name == 'bert': 88 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 89 | # tokenizer = RobertaTokenizer.from_pretrained('roberta-base') 90 | model = CLSModel(args) 91 | model.load_state_dict(torch.load(args.bert_checkpoint_dir)['model']) 92 | model.cuda() 93 | else: 94 | tokenizer = T5Tokenizer.from_pretrained('t5-base') 95 | model = T5CLSModel(args) 96 | # model.load_state_dict(torch.load(args.t5_checkpoint_dir)['model']) 97 | model.cuda() 98 | test_data = ScData(args, args.test_data_path, args.test_label_path, tokenizer) 99 | test_reader = DataLoader(dataset=test_data, num_workers=0, 100 | 
batch_size=args.test_batch_size, shuffle=False) 101 | 102 | res = [] 103 | with torch.no_grad(): 104 | for batch in tqdm(test_reader): 105 | input_ids = batch[0].cuda() 106 | attention_mask = batch[1].cuda() 107 | outputs = model(input_ids, attention_mask) 108 | max_score, max_idxs = torch.max(outputs, 1) 109 | predict_idxs = max_idxs.view(-1).tolist() 110 | res.extend(predict_idxs) 111 | # result = eval_model(model, test_reader) 112 | # logger.info('test acc: {0}, F1: {1}'.format(result['acc'], result['f1'])) 113 | print('length of test:', len(res)) 114 | with open('output/20195199_3.csv', 'w', encoding='utf-8') as fw: 115 | for item in res: 116 | fw.write(str(item)+'\n') 117 | print('saved!') 118 | 119 | 120 | if __name__ == "__main__": 121 | inference() 122 | -------------------------------------------------------------------------------- /prompt_test.py: -------------------------------------------------------------------------------- 1 | from transformers import BertTokenizer, T5Tokenizer, RobertaTokenizer 2 | import os 3 | import argparse 4 | import logging 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from data_ultis import ScData 8 | from prompt_bert import PromptBertModel 9 | from prompt_t5 import PromptT5CLSModel 10 | from utils import set_env, eval_model 11 | from tqdm import tqdm 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def inference(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument( 19 | "--model_name", 20 | default=None, 21 | type=str, 22 | required=True, 23 | ) 24 | parser.add_argument( 25 | "--test_data_path", 26 | default='./data/test.json', 27 | type=str, 28 | help="The testing data path.", 29 | ) 30 | parser.add_argument( 31 | "--test_label_path", 32 | default='./data/test_labels.json', 33 | type=str, 34 | help="The testing label path.", 35 | ) 36 | parser.add_argument( 37 | "--max_len", 38 | default=150, 39 | type=int, 40 | help="The maximum total input sequence length after tokenization.", 41 | ) 42 | parser.add_argument( 43 | "--test_batch_size", 44 | default=128, 45 | type=int, 46 | help="Batch size per GPU/CPU for training.", 47 | ) 48 | parser.add_argument( 49 | "--hidden_size", 50 | default=768, 51 | type=int, 52 | help="Hidden size.", 53 | ) 54 | parser.add_argument( 55 | "--project_dim", 56 | default=7, 57 | type=int, 58 | help="Project Dim.", 59 | ) 60 | parser.add_argument( 61 | "--bert_checkpoint_dir", 62 | default='./checkpoint/save_prompt_bert_model_best.pt', 63 | type=str, 64 | help="the path of fine-tuned bert checkpoint.", 65 | ) 66 | parser.add_argument( 67 | "--t5_checkpoint_dir", 68 | default='./checkpoint/save_prompt_t5_model_best.pt', 69 | type=str, 70 | help="the path of fine-tuned t5 checkpoint.", 71 | ) 72 | parser.add_argument( 73 | "--output_dir", 74 | default='./checkpoint', 75 | type=str, 76 | help="The output directory where the model predictions and checkpoints will be written.", 77 | ) 78 | parser.add_argument( 79 | "--seed", 80 | type=int, 81 | default=1234, 82 | help="random seed for initialization", 83 | ) 84 | args = parser.parse_args() 85 | log_name = args.model_name + '_test_prompt' 86 | set_env(args, log_name) 87 | if args.model_name == 'bert': 88 | tokenizer = BertTokenizer.from_pretrained('bert-large-uncased') 89 | # tokenizer = RobertaTokenizer.from_pretrained('roberta-large') 90 | model = PromptBertModel(args) 91 | model.load_state_dict(torch.load(args.bert_checkpoint_dir)['model']) 92 | model.cuda() 93 | else: 94 | tokenizer = T5Tokenizer.from_pretrained('t5-base') 95 | 
model = PromptT5CLSModel(args) 96 | model.load_state_dict(torch.load(args.t5_checkpoint_dir)['model']) 97 | model.cuda() 98 | test_data = ScData(args, args.test_data_path, args.test_label_path, tokenizer) 99 | test_reader = DataLoader(dataset=test_data, num_workers=0, 100 | batch_size=args.test_batch_size, shuffle=False) 101 | 102 | res = [] 103 | with torch.no_grad(): 104 | for batch in tqdm(test_reader): 105 | input_ids = batch[0].cuda() 106 | attention_mask = batch[1].cuda() 107 | outputs = model(input_ids, attention_mask) 108 | max_score, max_idxs = torch.max(outputs, 1) 109 | predict_idxs = max_idxs.view(-1).tolist() 110 | res.extend(predict_idxs) 111 | # result = eval_model(model, test_reader) 112 | # logger.info('test acc: {0}, F1: {1}'.format(result['acc'], result['f1'])) 113 | print('length of test:', len(res)) 114 | with open('output/20195199_1.csv', 'w', encoding='utf-8') as fw: 115 | for item in res: 116 | fw.write(str(item) + '\n') 117 | print('saved!') 118 | 119 | if __name__ == "__main__": 120 | inference() 121 | -------------------------------------------------------------------------------- /prompt_bert.py: -------------------------------------------------------------------------------- 1 | from transformers import BertForMaskedLM, BertConfig, BertTokenizer, RobertaTokenizer, RobertaConfig, RobertaForMaskedLM 2 | from torch import nn 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | class PromptBertModel(nn.Module): 8 | def __init__(self, args): 9 | super(PromptBertModel, self).__init__() 10 | self.config = BertConfig.from_pretrained('bert-large-uncased') 11 | self.tokenizer = BertTokenizer.from_pretrained('bert-large-uncased') 12 | self.model = BertForMaskedLM.from_pretrained('bert-large-uncased') 13 | # self.config = RobertaConfig.from_pretrained('roberta-large') 14 | # self.tokenizer = RobertaTokenizer.from_pretrained('roberta-large') 15 | # self.model = RobertaForMaskedLM.from_pretrained('roberta-large') 16 | 17 | self.prefix_soft_index, self.suffix_soft_index = [3, 27569, 10], [11167, 10] 18 | self.p_num, self.s_num = len(self.prefix_soft_index), len(self.suffix_soft_index) 19 | self.prefix_soft_embedding_layer = nn.Embedding( 20 | self.p_num, self.config.hidden_size 21 | ) 22 | self.suffix_soft_embedding_layer = nn.Embedding( 23 | self.s_num, self.config.hidden_size 24 | ) 25 | self.normal_embedding_layer = self.model.get_input_embeddings() 26 | self.prefix_soft_embedding_layer.weight.data = torch.stack( 27 | [self.normal_embedding_layer.weight.data[i, :].clone().detach().requires_grad_(True) for i in 28 | self.prefix_soft_index] 29 | ) 30 | self.suffix_soft_embedding_layer.weight.data = torch.stack( 31 | [self.normal_embedding_layer.weight.data[i, :].clone().detach().requires_grad_(True) for i in 32 | self.suffix_soft_index] 33 | ) 34 | self.prefix_soft_ids = torch.tensor(range(self.p_num)) 35 | self.suffix_soft_ids = torch.tensor(range(self.s_num)) 36 | self.mask_ids = torch.tensor([self.tokenizer.mask_token_id]) 37 | for param in self.model.parameters(): 38 | param.requires_grad_(False) 39 | 40 | def forward(self, input_ids, attention_mask, labels=None): 41 | batch_size = input_ids.shape[0] 42 | prefix_soft_ids = torch.stack([self.prefix_soft_ids for i in range(batch_size)]).to(input_ids.device) 43 | mask_ids = torch.stack([self.mask_ids for i in range(batch_size)]).to(input_ids.device) 44 | suffix_soft_ids = torch.stack([self.suffix_soft_ids for i in range(batch_size)]).to(input_ids.device) 45 | 46 | prefix_soft_embeddings = 
self.prefix_soft_embedding_layer(prefix_soft_ids) 47 | suffix_soft_embeddings = self.suffix_soft_embedding_layer(suffix_soft_ids) 48 | 49 | text_embeddings = self.normal_embedding_layer(input_ids) 50 | mask_embeddings = self.normal_embedding_layer(mask_ids) 51 | input_embeddings = torch.cat( 52 | [prefix_soft_embeddings, text_embeddings, suffix_soft_embeddings, mask_embeddings], 53 | dim=1 54 | ) 55 | prefix_soft_attention_mask = torch.ones(batch_size, self.p_num).to(input_ids.device) 56 | mask_attention_mask = torch.ones(batch_size, 1).to(input_ids.device) 57 | suffix_soft_attention_mask = torch.ones(batch_size, self.s_num).to(input_ids.device) 58 | 59 | attention_mask = torch.cat( 60 | [prefix_soft_attention_mask, attention_mask, suffix_soft_attention_mask, mask_attention_mask], 61 | dim=1 62 | ) 63 | outputs = self.model(inputs_embeds=input_embeddings, attention_mask=attention_mask)[0] 64 | # masked_token_pos = torch.full(masked_token_pos.shape, 50 + self.p_num).to(input_ids.device) 65 | # vocab_size = outputs.shape[2] 66 | # masked_token_pos = torch.unsqueeze(masked_token_pos, 1) 67 | # masked_token_pos = torch.unsqueeze(masked_token_pos, 2) 68 | # masked_token_pos = torch.stack([masked_token_pos] * vocab_size, 2) 69 | # masked_token_pos = torch.squeeze(masked_token_pos, 3) 70 | # masked_token_logits = torch.gather(outputs, 1, masked_token_pos) 71 | # 72 | # masked_token_logits = masked_token_logits.reshape(-1, vocab_size) 73 | logits = outputs[:, -1, [8699, 4963, 12721, 3571, 6569, 12039, 4474]] 74 | 75 | probs = F.log_softmax(logits, -1) 76 | if labels != None: 77 | loss = F.nll_loss(probs, labels.to(probs.device), reduction='mean') 78 | return loss, logits 79 | else: 80 | return logits -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from tqdm import tqdm 4 | from torch.utils.data import Dataset, DataLoader 5 | from transformers import BertTokenizer 6 | from transformers import logging 7 | from bert_model import CLSModel 8 | from t5_model import T5CLSModel 9 | import argparse 10 | 11 | class SCDataset(Dataset): 12 | def __init__(self, dataset): 13 | self.dataset = dataset 14 | self.data_size = len(dataset) 15 | 16 | def __len__(self): 17 | return self.data_size 18 | 19 | def __getitem__(self, index): 20 | return self.dataset[index] 21 | 22 | 23 | #os.environ["CUDA_VISIBLE_DEVICES"] = "1" 24 | logging.set_verbosity_error() 25 | 26 | def coffate_fn_test(examples): 27 | inputs, targets = [], [] 28 | for sent in examples: 29 | inputs.append(sent) 30 | targets.append(-1) 31 | inputs = tokenizer(inputs, 32 | padding=True, 33 | truncation=True, 34 | return_tensors="pt", 35 | max_length=512) 36 | targets = torch.tensor(targets) 37 | return inputs, targets 38 | 39 | if __name__ == '__main__': 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument( 42 | "--model_name", 43 | default=None, 44 | type=str, 45 | required=True, 46 | ) 47 | parser.add_argument( 48 | "--test_data_path", 49 | default='./data/mydata/dev.json', 50 | type=str, 51 | help="The testing data path.", 52 | ) 53 | parser.add_argument( 54 | "--test_label_path", 55 | default='./data/mydata/dev_labels.json', 56 | type=str, 57 | help="The testing label path.", 58 | ) 59 | parser.add_argument( 60 | "--max_len", 61 | default=256, 62 | type=int, 63 | help="The maximum total input sequence length after tokenization.", 64 | ) 65 | parser.add_argument( 66 | 
"--test_batch_size", 67 | default=128, 68 | type=int, 69 | help="Batch size per GPU/CPU for training.", 70 | ) 71 | parser.add_argument( 72 | "--hidden_size", 73 | default=768, 74 | type=int, 75 | help="Hidden size.", 76 | ) 77 | parser.add_argument( 78 | "--project_dim", 79 | default=7, 80 | type=int, 81 | help="Project Dim.", 82 | ) 83 | parser.add_argument( 84 | "--bert_checkpoint_dir", 85 | default='./checkpoint/save_model_best.pt', 86 | type=str, 87 | help="the path of fine-tuned bert checkpoint.", 88 | ) 89 | parser.add_argument( 90 | "--t5_checkpoint_dir", 91 | default='./checkpoint/save_t5_model_best.pt', 92 | type=str, 93 | help="the path of fine-tuned t5 checkpoint.", 94 | ) 95 | parser.add_argument( 96 | "--output_dir", 97 | default='./checkpoint', 98 | type=str, 99 | help="The output directory where the model predictions and checkpoints will be written.", 100 | ) 101 | parser.add_argument( 102 | "--seed", 103 | type=int, 104 | default=1234, 105 | help="random seed for initialization", 106 | ) 107 | args = parser.parse_args() 108 | test_data_path = "./data/mydata/test.tsv" 109 | test_data = [] 110 | with open(test_data_path, 'r', encoding="utf-8") as fr: 111 | for line in fr.readlines(): 112 | sentence = line.strip() 113 | test_data.append(sentence) 114 | test_dataset = SCDataset(test_data) 115 | test_dataloader = DataLoader(test_dataset, 116 | batch_size=1, 117 | collate_fn=coffate_fn_test) 118 | 119 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 120 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 121 | model = CLSModel(args) 122 | model.load_state_dict(torch.load(args.bert_checkpoint_dir)['model']) 123 | model.cuda() 124 | 125 | res = [] 126 | for batch in tqdm(test_dataloader, desc=f"Testing"): 127 | inputs, targets = [x.to(device) for x in batch] 128 | with torch.no_grad(): 129 | bert_output = model(inputs) 130 | #print(bert_output.argmax(dim=1).data.item()) 131 | res.append(str(bert_output.argmax(dim=1).data.item())) 132 | 133 | with open('test_res.tsv', 'w', encoding='utf-8') as fw: 134 | for item in res: 135 | fw.write(item+'\n') -------------------------------------------------------------------------------- /data/train_demo.json: -------------------------------------------------------------------------------- 1 | ["A film that is so much a 30\\'s Warners film in an era when each studio had a particular look and style to their output, unlike today where simply getting audiences is the object. Curitz was one of the quintessential Warners house directors working with tight economy and great efficiency whilst creating quality, working methods that were very much the requirements of a director at Warners, a studio that was one of the \"big five\" majors in this era producing quality films for their large chains of theatres. Even though we have a setting of the upper classes on Long Island there is the generic Warners style embedded here with a narrative that could have been \"torn from the headlines\". Another example is the when the photographers comment on the girls legs early in the film and she comments that \"They\\'re not the trophies\" gives the film a more working mans, down to earth feel, for these were the audiences that Warners were targeting in the great depression. 
( ironically Columbia and Universal were the two minors under these five majors until the 50\\'s when their involvement in television changed their fortunes - they would have made something like this very cheaply and without the polish and great talent ) Curtiz has created from an excellent script a film that moves along at a rapid pace whilst keeping the viewer with great camera angles and swift editing. Thank heavens there is no soppy love interest sub-plot so the fun can just keep rolling along.", "Hey guys, i have been looking every where to find these two movies and i can't find them anywhere in my local area. ( I am Australian ) . Could You please help me and tell me where i can buy it from. In General Home Ward Bound 1 and 2 are the best movies i have ever seen and are good for people of all ages. It was my favourite movie wen i was 5 and it still is even now when i am a teenager. It is a great movie for the whole family. My entire family loves this movie except for my younger sister because i have watched it that many times that she is sick of it. I love this movie and i cant wait till i can buy it again on DVD. Sally", "I have complained to ABC about the cancellation of six degrees. If enough people do the same then it could be enough to bring this fabulous show back to life!! Just go onto the official site and the rest is simple enough. I do not understand why this show has been cancelled. What a fantastic show, cast and characters. The whole concept is gripping viewing! I am astounded that my favourite show is over after just one series. Why is this? Six degrees is phenomenal, it's better than so many other TV programmes out there! Until I heard they were stopping it from a friend it hadn't even occurred to me that this might happen.", "Paul Greengrass definitely saved the best Bourne for last! I've heard a lot of people complain about they way he filmed this movie, and some have even compared the camera style to the Blair Witch Project. All I have to say to that is...are you kidding me? Come on it was not that bad at all. I think it helps the action scenes to feel more realistic, which I would prefer over highly stylized stunt choreography. As for the rest of the movie I really didn't even notice it. You can tell that Damon has really gotten comfortable with the role of Jason Bourne. Sometimes that can be a bad thing, but in this case its a really good thing. He really becomes Jason Bourne in this installment. Damon also has a great supporting cast in Joan Allen, Ezra Kramer, and Julia Stiles. David Strathairn was a great addition to the cast, as he added more depth to the secret CIA organization. Even though the movie is filled with great car chases and nonstop action, they managed to stick a fair amount of character development in their with all of that going on. This film stands far above the other two Bourne movies, and is definitely one of the best movies of the 2007 summer season!", "People forget that there have been several King Kong ripoffs- Congo, King Kong Vs. Godzilla, King Kong ( 1976 ) , they all ripoff one another, but YETI stands on its own. It only borrows one element from King Kong and that is the animal\\'s attraction with one female. The YETI myth is based on Bigfoot ( not like King Kong ) and archeologists have been fascinated it, at one time they did exist,but there is no scientific data to prove it. This movie is hard to find ,but its worth watching it. The first time I watched it was on \"Elvira\\'s Mistress of the Dark Shows\" in the early 1980\\'s. 
It sent chills down my spine as a kid, especially when the YETI got mad. I saw it again, around 1:00am on ABC about 2 to 3yrs ago. Seeing it again made me appreciate it more, it has some overall good effects ( for its time ) and the story involves a mute boy and his dog, and an evil businessman person who wants to kill the YETI for his own purposes. Also the music is pretty cool,its very YETI like. :- ) Gianfranco Parolini and the Yetians creates a great monster like atmosphere. Vote 7 and half out of 10."] -------------------------------------------------------------------------------- /prompt_train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from transformers import ( 3 | AdamW, 4 | get_linear_schedule_with_warmup, 5 | BertTokenizer, 6 | T5Tokenizer, 7 | RobertaTokenizer 8 | ) 9 | import os 10 | import argparse 11 | import logging 12 | import torch 13 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset, DistributedSampler 14 | from data_ultis import ScData 15 | from prompt_bert import PromptBertModel 16 | from prompt_t5 import PromptT5CLSModel 17 | from utils import set_env, set_seed, get_metric, eval_model 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | def train(args, model, train_dataloader, valid_dataloader): 23 | real_batch_size = args.train_batch_size * args.gradient_accumulation_steps 24 | t_total = train_dataloader.dataset.__len__() // real_batch_size * args.num_train_epochs 25 | 26 | param_optimizer = list(model.named_parameters()) 27 | optimizer_grouped_parameters = [{'params': [p for n, p in param_optimizer], 'weight_decay': 0.01}] 28 | optimizer = AdamW(optimizer_grouped_parameters, 29 | lr=args.learning_rate, eps=args.adam_epsilon) 30 | scheduler = get_linear_schedule_with_warmup( 31 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) 32 | 33 | logger.info("***** Running training *****") 34 | logger.info(" Num Epochs = %d", args.num_train_epochs) 35 | logger.info(" Instantaneous batch size per GPU = %d", 36 | args.train_batch_size) 37 | logger.info(" Gradient Accumulation steps = %d", 38 | args.gradient_accumulation_steps) 39 | logger.info(" Total optimization steps = %d", t_total) 40 | global_step = 0 41 | tr_loss = 0.0 42 | best_acc = 0.0 43 | model.zero_grad() 44 | for epoch in range(int(args.num_train_epochs)): 45 | for step, batch in enumerate(train_dataloader): 46 | model.train() 47 | input_ids = batch[0].cuda() 48 | attention_mask = batch[1].cuda() 49 | labels = batch[2].cuda() 50 | outputs = model(input_ids, attention_mask, labels) 51 | loss = outputs[0] 52 | if args.gradient_accumulation_steps > 1: 53 | loss = loss / args.gradient_accumulation_steps 54 | loss.backward() 55 | tr_loss += loss.item() 56 | if (step + 1) % args.gradient_accumulation_steps == 0: 57 | global_step += 1 58 | torch.nn.utils.clip_grad_norm_( 59 | model.parameters(), args.max_grad_norm) 60 | logger.info( 61 | 'Epoch: {}, Step: {}, Loss: {:.4f}, lr: {:.6f}'.format(epoch, global_step, (tr_loss / global_step), 62 | optimizer.param_groups[0]["lr"])) 63 | optimizer.step() 64 | scheduler.step() 65 | model.zero_grad() 66 | if global_step % args.eval_steps == 0: 67 | logger.info('Start eval!') 68 | evaluation_results = eval_model(model, valid_dataloader) 69 | acc = evaluation_results["acc"] 70 | f1 = evaluation_results["f1"] 71 | logger.info('Dev acc: {0}, F1: {1}'.format(acc, f1)) 72 | if acc >= best_acc: 73 | best_acc = acc 74 | if args.model_name_or_path == 
'bert': 75 | torch.save({'epoch': epoch, 76 | 'model': model.state_dict()}, 77 | os.path.join(args.output_dir, "save_prompt_bert_model_best.pt")) 78 | else: 79 | torch.save({'epoch': epoch, 80 | 'model': model.state_dict()}, 81 | os.path.join(args.output_dir, "save_prompt_t5_model_best.pt")) 82 | logger.info("Saved best epoch {0}, best acc {1}".format(epoch, best_acc)) 83 | 84 | 85 | def main(): 86 | parser = argparse.ArgumentParser() 87 | parser.add_argument( 88 | "--train_data_path", 89 | default=None, 90 | type=str, 91 | required=True, 92 | help="The training data path.", 93 | ) 94 | parser.add_argument( 95 | "--dev_data_path", 96 | default=None, 97 | type=str, 98 | required=True, 99 | help="The validation data path.", 100 | ) 101 | parser.add_argument( 102 | "--train_label_path", 103 | default=None, 104 | type=str, 105 | required=True, 106 | help="The training label path.", 107 | ) 108 | parser.add_argument( 109 | "--dev_label_path", 110 | default=None, 111 | type=str, 112 | required=True, 113 | help="The validation label path.", 114 | ) 115 | parser.add_argument( 116 | "--model_name_or_path", 117 | default=None, 118 | type=str, 119 | required=True, 120 | ) 121 | parser.add_argument( 122 | "--output_dir", 123 | default=None, 124 | type=str, 125 | required=True, 126 | help="The output directory where the model predictions and checkpoints will be written.", 127 | ) 128 | parser.add_argument( 129 | "--max_len", 130 | default=150, 131 | type=int, 132 | help="The maximum total input sequence length after tokenization.", 133 | ) 134 | parser.add_argument( 135 | "--train_batch_size", 136 | default=128, 137 | type=int, 138 | help="Batch size per GPU/CPU for training.", 139 | ) 140 | parser.add_argument( 141 | "--dev_batch_size", 142 | default=128, 143 | type=int, 144 | help="Batch size per GPU/CPU for training.", 145 | ) 146 | parser.add_argument( 147 | "--gradient_accumulation_steps", 148 | type=int, 149 | default=1, 150 | help="Number of updates steps to accumulate before performing a backward/update pass.", 151 | ) 152 | parser.add_argument( 153 | "--learning_rate", 154 | default=1e-5, 155 | type=float, 156 | help="The initial learning rate for Adam.", 157 | ) 158 | parser.add_argument( 159 | "--weight_decay", 160 | default=0.0, 161 | type=float, 162 | help="Weight decay if we apply some.", 163 | ) 164 | parser.add_argument( 165 | "--adam_epsilon", 166 | default=1e-8, 167 | type=float, 168 | help="Epsilon for Adam optimizer.", 169 | ) 170 | parser.add_argument( 171 | "--max_grad_norm", 172 | default=1.0, 173 | type=float, 174 | help="Max gradient norm.", 175 | ) 176 | parser.add_argument( 177 | "--bitfit", 178 | default=False, 179 | action="store_true", 180 | ) 181 | parser.add_argument( 182 | "--num_train_epochs", 183 | default=10, 184 | type=float, 185 | help="Total number of training epochs to perform.", 186 | ) 187 | parser.add_argument( 188 | "--warmup_steps", 189 | default=0, 190 | type=int, 191 | help="Linear warmup over warmup_steps.", 192 | ) 193 | parser.add_argument( 194 | "--hidden_size", 195 | default=768, 196 | type=int, 197 | help="Hidden size.", 198 | ) 199 | parser.add_argument( 200 | "--project_dim", 201 | default=7, 202 | type=int, 203 | help="Project Dim.", 204 | ) 205 | parser.add_argument( 206 | "--eval_steps", 207 | type=int, 208 | default=500, 209 | help="eval model every X updates steps.", 210 | ) 211 | parser.add_argument( 212 | "--seed", 213 | type=int, 214 | default=1234, 215 | help="random seed for initialization", 216 | ) 217 | parser.add_argument( 218 | 
"--prefix", 219 | default=None, 220 | type=str 221 | ) 222 | parser.add_argument( 223 | "--suffix", 224 | default=None, 225 | type=str 226 | ) 227 | args = parser.parse_args() 228 | log_name = args.model_name_or_path + '_train_prompt' 229 | set_env(args, log_name) 230 | if args.model_name_or_path == 'bert': 231 | tokenizer = BertTokenizer.from_pretrained('bert-large-uncased') 232 | # tokenizer = RobertaTokenizer.from_pretrained('roberta-large') 233 | model = PromptBertModel(args) 234 | else: 235 | tokenizer = T5Tokenizer.from_pretrained('t5-base') 236 | model = PromptT5CLSModel(args) 237 | model.cuda() 238 | logger.info("Loading training set.") 239 | train_data = ScData(args, args.train_data_path, args.train_label_path, tokenizer) 240 | train_sampler = RandomSampler(train_data) 241 | train_reader = DataLoader(dataset=train_data, sampler=train_sampler, num_workers=0, 242 | batch_size=args.train_batch_size) 243 | dev_data = ScData(args, args.dev_data_path, args.dev_label_path, tokenizer) 244 | dev_reader = DataLoader(dataset=dev_data, num_workers=0, 245 | batch_size=args.dev_batch_size, shuffle=False) 246 | train(args, model, train_reader, dev_reader) 247 | 248 | 249 | if __name__ == "__main__": 250 | main() 251 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from transformers import ( 3 | AdamW, 4 | get_linear_schedule_with_warmup, 5 | BertTokenizer, 6 | RobertaTokenizer, 7 | T5Tokenizer 8 | ) 9 | import os 10 | import argparse 11 | import logging 12 | import torch 13 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset, DistributedSampler 14 | from data_ultis import ScData 15 | from bert_model import CLSModel 16 | from t5_model import T5CLSModel 17 | from utils import set_env, set_seed, get_metric, eval_model 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | def train(args, model, train_dataloader, valid_dataloader): 23 | real_batch_size = args.train_batch_size * args.gradient_accumulation_steps 24 | t_total = train_dataloader.dataset.__len__() // real_batch_size * args.num_train_epochs 25 | if args.bitfit: 26 | trainable_parameters = [] 27 | trainable_names = [] 28 | trainable_components = ['encode_proj', 'bias'] 29 | for name, param in model.named_parameters(): 30 | param.requires_grad = False 31 | for component in trainable_components: 32 | if component in name: 33 | trainable_parameters.append(param) 34 | trainable_names.append(name) 35 | param.requires_grad = True 36 | break 37 | logger.info("*** BitFit Training *****") 38 | logger.info(trainable_names) 39 | 40 | optimizer_grouped_parameters = [ 41 | {'params': trainable_parameters, 'weight_decay': 0.0} 42 | ] 43 | else: 44 | param_optimizer = list(model.named_parameters()) 45 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 46 | # no_decay = [] 47 | optimizer_grouped_parameters = [ 48 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 49 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 50 | ] 51 | 52 | optimizer = AdamW(optimizer_grouped_parameters, 53 | lr=args.learning_rate, eps=args.adam_epsilon) 54 | scheduler = get_linear_schedule_with_warmup( 55 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) 56 | 57 | logger.info("***** Running training *****") 58 | logger.info(" Num Epochs = %d", 
args.num_train_epochs) 59 | logger.info(" Instantaneous batch size per GPU = %d", 60 | args.train_batch_size) 61 | logger.info(" Gradient Accumulation steps = %d", 62 | args.gradient_accumulation_steps) 63 | logger.info(" Total optimization steps = %d", t_total) 64 | global_step = 0 65 | tr_loss = 0.0 66 | best_acc = 0.0 67 | model.zero_grad() 68 | for epoch in range(int(args.num_train_epochs)): 69 | for step, batch in enumerate(train_dataloader): 70 | model.train() 71 | input_ids = batch[0].cuda() 72 | attention_mask = batch[1].cuda() 73 | labels = batch[2].cuda() 74 | outputs = model(input_ids, attention_mask, labels) 75 | loss = outputs[0] 76 | if args.gradient_accumulation_steps > 1: 77 | loss = loss / args.gradient_accumulation_steps 78 | loss.backward() 79 | tr_loss += loss.item() 80 | if (step + 1) % args.gradient_accumulation_steps == 0: 81 | global_step += 1 82 | torch.nn.utils.clip_grad_norm_( 83 | model.parameters(), args.max_grad_norm) 84 | logger.info( 85 | 'Epoch: {}, Step: {}, Loss: {:.4f}, lr: {:.6f}'.format(epoch, global_step, (tr_loss / global_step), 86 | optimizer.param_groups[0]["lr"])) 87 | optimizer.step() 88 | scheduler.step() 89 | model.zero_grad() 90 | if global_step % args.eval_steps == 0: 91 | logger.info('Start eval!') 92 | evaluation_results = eval_model(model, valid_dataloader) 93 | acc = evaluation_results["acc"] 94 | f1 = evaluation_results["f1"] 95 | logger.info('Dev acc: {0}, F1: {1}'.format(acc, f1)) 96 | if acc >= best_acc: 97 | best_acc = acc 98 | if args.model_name_or_path == 'bert': 99 | torch.save({'epoch': epoch, 100 | 'model': model.state_dict()}, 101 | os.path.join(args.output_dir, "save_model_best.pt")) 102 | else: 103 | torch.save({'epoch': epoch, 104 | 'model': model.state_dict()}, 105 | os.path.join(args.output_dir, "save_t5_model_best.pt")) 106 | logger.info("Saved best epoch {0}, best acc {1}".format(epoch, best_acc)) 107 | 108 | 109 | def main(): 110 | parser = argparse.ArgumentParser() 111 | parser.add_argument( 112 | "--train_data_path", 113 | default=None, 114 | type=str, 115 | required=True, 116 | help="The training data path.", 117 | ) 118 | parser.add_argument( 119 | "--dev_data_path", 120 | default=None, 121 | type=str, 122 | required=True, 123 | help="The validation data path.", 124 | ) 125 | parser.add_argument( 126 | "--train_label_path", 127 | default=None, 128 | type=str, 129 | required=True, 130 | help="The training label path.", 131 | ) 132 | parser.add_argument( 133 | "--dev_label_path", 134 | default=None, 135 | type=str, 136 | required=True, 137 | help="The validation label path.", 138 | ) 139 | parser.add_argument( 140 | "--model_name_or_path", 141 | default=None, 142 | type=str, 143 | required=True, 144 | ) 145 | parser.add_argument( 146 | "--output_dir", 147 | default=None, 148 | type=str, 149 | required=True, 150 | help="The output directory where the model predictions and checkpoints will be written.", 151 | ) 152 | parser.add_argument( 153 | "--max_len", 154 | default=150, 155 | type=int, 156 | help="The maximum total input sequence length after tokenization.", 157 | ) 158 | parser.add_argument( 159 | "--train_batch_size", 160 | default=128, 161 | type=int, 162 | help="Batch size per GPU/CPU for training.", 163 | ) 164 | parser.add_argument( 165 | "--dev_batch_size", 166 | default=128, 167 | type=int, 168 | help="Batch size per GPU/CPU for training.", 169 | ) 170 | parser.add_argument( 171 | "--gradient_accumulation_steps", 172 | type=int, 173 | default=1, 174 | help="Number of updates steps to accumulate before 
performing a backward/update pass.", 175 | ) 176 | parser.add_argument( 177 | "--learning_rate", 178 | default=1e-5, 179 | type=float, 180 | help="The initial learning rate for Adam.", 181 | ) 182 | parser.add_argument( 183 | "--weight_decay", 184 | default=0.0, 185 | type=float, 186 | help="Weight decay if we apply some.", 187 | ) 188 | parser.add_argument( 189 | "--adam_epsilon", 190 | default=1e-8, 191 | type=float, 192 | help="Epsilon for Adam optimizer.", 193 | ) 194 | parser.add_argument( 195 | "--max_grad_norm", 196 | default=1.0, 197 | type=float, 198 | help="Max gradient norm.", 199 | ) 200 | parser.add_argument( 201 | "--bitfit", 202 | default=False, 203 | action="store_true", 204 | ) 205 | parser.add_argument( 206 | "--num_train_epochs", 207 | default=10, 208 | type=float, 209 | help="Total number of training epochs to perform.", 210 | ) 211 | parser.add_argument( 212 | "--warmup_steps", 213 | default=0, 214 | type=int, 215 | help="Linear warmup over warmup_steps.", 216 | ) 217 | parser.add_argument( 218 | "--hidden_size", 219 | default=768, 220 | type=int, 221 | help="Hidden size.", 222 | ) 223 | parser.add_argument( 224 | "--project_dim", 225 | default=7, 226 | type=int, 227 | help="Project Dim.", 228 | ) 229 | parser.add_argument( 230 | "--eval_steps", 231 | type=int, 232 | default=1000, 233 | help="eval model every X updates steps.", 234 | ) 235 | parser.add_argument( 236 | "--seed", 237 | type=int, 238 | default=1234, 239 | help="random seed for initialization", 240 | ) 241 | parser.add_argument( 242 | "--prefix", 243 | default=None, 244 | type=str 245 | ) 246 | parser.add_argument( 247 | "--suffix", 248 | default=None, 249 | type=str 250 | ) 251 | args = parser.parse_args() 252 | log_name = args.model_name_or_path + '_train' 253 | set_env(args, log_name) 254 | if args.model_name_or_path == 'bert': 255 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 256 | # tokenizer = RobertaTokenizer.from_pretrained('roberta-base') 257 | model = CLSModel(args) 258 | else: 259 | tokenizer = T5Tokenizer.from_pretrained('t5-base') 260 | model = T5CLSModel(args) 261 | model.cuda() 262 | logger.info("Loading training set.") 263 | train_data = ScData(args, args.train_data_path, args.train_label_path, tokenizer) 264 | train_sampler = RandomSampler(train_data) 265 | train_reader = DataLoader(dataset=train_data, sampler=train_sampler, num_workers=0, 266 | batch_size=args.train_batch_size) 267 | dev_data = ScData(args, args.dev_data_path, args.dev_label_path, tokenizer) 268 | dev_reader = DataLoader(dataset=dev_data, num_workers=0, 269 | batch_size=args.dev_batch_size, shuffle=False) 270 | train(args, model, train_reader, dev_reader) 271 | 272 | 273 | if __name__ == "__main__": 274 | main() 275 | --------------------------------------------------------------------------------
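
train.py and prompt_train.py save their best checkpoints under `./checkpoint/` with the weights stored under the `'model'` key (for example `save_model_best.pt` for fine-tuned BERT). For a quick sanity check outside `test.py`/`inference.py`, the sketch below loads that BERT checkpoint and labels a single sentence; it is only an illustrative sketch: the example sentence, the `argparse.Namespace` shortcut, and running on CPU are assumptions, not part of the repository.

```python
# Minimal, hypothetical usage sketch (not shipped with the repository):
# load the fine-tuned BERT classifier saved by train.py and label one sentence.
import argparse
import torch
from transformers import BertTokenizer
from bert_model import CLSModel

# Same label order as utils.get_metric
label_list = ['neutral', 'anger', 'disgust', 'fear', 'joy', 'sadness', 'surprise']

# CLSModel only reads hidden_size and project_dim from args
args = argparse.Namespace(hidden_size=768, project_dim=7)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = CLSModel(args)
state = torch.load('./checkpoint/save_model_best.pt', map_location='cpu')
model.load_state_dict(state['model'])
model.eval()

# The sentence is an arbitrary example; max_length matches the training default (150)
enc = tokenizer("I can't wait to watch it again!", max_length=150,
                padding='max_length', truncation=True, return_tensors='pt')
with torch.no_grad():
    logits = model(enc['input_ids'], enc['attention_mask'])  # labels=None -> logits only
print(label_list[logits.argmax(dim=-1).item()])
```

The prompt-tuned checkpoints (`save_prompt_bert_model_best.pt`, `save_prompt_t5_model_best.pt`) can be loaded the same way through `PromptBertModel` and `PromptT5CLSModel`, which is what `prompt_test.py` does.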