├── README.md ├── parabart.py ├── parabart_qqpeval.py ├── parabart_senteval.py ├── synt_vocab.pkl ├── train_parabart.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | ## ParaBART 2 | 3 | Code for our NAACL-2021 paper ["Disentangling Semantics and Syntax in Sentence Embeddings with Pre-trained Language Models"](https://arxiv.org/abs/2104.05115). 4 | 5 | If you find this repository useful, please consider citing our paper. 6 | ``` 7 | @inproceedings{huang2021disentangling, 8 | title = {Disentangling Semantics and Syntax in Sentence Embeddings with Pre-trained Language Models}, 9 | author = {Huang, James Y. and Huang, Kuan-Hao and Chang, Kai-Wei}, 10 | booktitle = {NAACL}, 11 | year = {2021} 12 | } 13 | ``` 14 | 15 | ### Dependencies 16 | 17 | - Python==3.7.6 18 | - PyTorch==1.6.0 19 | - Transformers==3.0.2 20 | 21 | ### Pre-trained Models 22 | 23 | Our pre-trained ParaBART model is available [here](https://drive.google.com/file/d/1Ev9iB2bIekEp1yYTCJPkngzZSRWOS-cz/view?usp=sharing) 24 | 25 | ### Training 26 | 27 | - Download the [dataset](https://drive.google.com/file/d/1Pv_RB47BD_zLhmQUhFpiEdI6UHDbb-wX/view?usp=sharing) and put it under `./data/` 28 | - Run the following command to train ParaBART 29 | ``` 30 | python train_parabart.py --data_dir ./data/ 31 | ``` 32 | 33 | ### Evaluation 34 | 35 | - Download the [SentEval](https://github.com/facebookresearch/SentEval) toolkit and datasets 36 | - Name your trained model `model.pt` and put it under `./model/` 37 | - Run the following command to evaluate ParaBART on semantic textual similarity and syntactic probing tasks 38 | ``` 39 | python parabart_senteval.py --senteval_dir ../SentEval --model_dir ./model/ 40 | ``` 41 | - Download QQP-Easy and QQP-Hard datasets [here](https://drive.google.com/file/d/1am502GkMU-9h-5chZ7RVt-7l0FAGvfH2/view?usp=sharing) 42 | - Run the following command to evaluate ParaBART on QQP datasets 43 | ``` 44 | python parabart_qqpeval.py 45 | ``` 46 | 47 | ### Author 48 | 49 | James Yipeng Huang / [@jyhuang36](https://github.com/jyhuang36) 50 | -------------------------------------------------------------------------------- /parabart.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from transformers.modeling_bart import ( 6 | PretrainedBartModel, 7 | LayerNorm, 8 | EncoderLayer, 9 | DecoderLayer, 10 | LearnedPositionalEmbedding, 11 | _prepare_bart_decoder_inputs, 12 | _make_linear_from_emb 13 | ) 14 | 15 | class ParaBart(PretrainedBartModel): 16 | def __init__(self, config): 17 | super().__init__(config) 18 | 19 | self.shared = nn.Embedding(config.vocab_size, config.d_model, config.pad_token_id) 20 | 21 | self.encoder = ParaBartEncoder(config, self.shared) 22 | self.decoder = ParaBartDecoder(config, self.shared) 23 | 24 | self.linear = nn.Linear(config.d_model, config.vocab_size) 25 | 26 | self.adversary = Discriminator(config) 27 | 28 | self.init_weights() 29 | 30 | def forward( 31 | self, 32 | input_ids, 33 | decoder_input_ids, 34 | attention_mask=None, 35 | decoder_padding_mask=None, 36 | encoder_outputs=None, 37 | return_encoder_outputs=False, 38 | ): 39 | if attention_mask is None: 40 | attention_mask = input_ids == self.config.pad_token_id 41 | 42 | if encoder_outputs is None: 43 | encoder_outputs = self.encoder(input_ids, attention_mask=attention_mask) 44 | 45 | if return_encoder_outputs: 46 | return encoder_outputs 47 | 48 | 
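# NOTE: encoder_outputs is a pair (hidden_states, sent_embeds): hidden_states covers the semantic token positions followed by the syntax positions; sent_embeds is the mean-pooled semantic sentence embedding.
# The decoder below attends only to sent_embeds concatenated with the syntax hidden states, so it never sees the individual semantic token states.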
assert encoder_outputs is not None 49 | assert decoder_input_ids is not None 50 | 51 | decoder_input_ids = decoder_input_ids[:, :-1] 52 | 53 | _, decoder_padding_mask, decoder_causal_mask = _prepare_bart_decoder_inputs( 54 | self.config, 55 | input_ids=None, 56 | decoder_input_ids=decoder_input_ids, 57 | decoder_padding_mask=decoder_padding_mask, 58 | causal_mask_dtype=self.shared.weight.dtype, 59 | ) 60 | 61 | attention_mask2 = torch.cat((torch.zeros(input_ids.shape[0], 1).bool().cuda(), attention_mask[:, self.config.max_sent_len+2:]), dim=1) 62 | 63 | # decoder 64 | decoder_outputs = self.decoder( 65 | decoder_input_ids, 66 | torch.cat((encoder_outputs[1], encoder_outputs[0][:, self.config.max_sent_len+2:]), dim=1), 67 | decoder_padding_mask=decoder_padding_mask, 68 | decoder_causal_mask=decoder_causal_mask, 69 | encoder_attention_mask=attention_mask2, 70 | )[0] 71 | 72 | 73 | batch_size = decoder_outputs.shape[0] 74 | outputs = self.linear(decoder_outputs.contiguous().view(-1, self.config.d_model)) 75 | outputs = outputs.view(batch_size, -1, self.config.vocab_size) 76 | 77 | # discriminator 78 | for p in self.adversary.parameters(): 79 | p.requires_grad = False 80 | adv_outputs = self.adversary(encoder_outputs[1]) 81 | 82 | return outputs, adv_outputs 83 | 84 | def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs): 85 | assert past is not None, "past has to be defined for encoder_outputs" 86 | 87 | encoder_outputs = past[0] 88 | return { 89 | "input_ids": None, # encoder_outputs is defined. input_ids not needed 90 | "encoder_outputs": encoder_outputs, 91 | "decoder_input_ids": torch.cat((decoder_input_ids, torch.zeros((decoder_input_ids.shape[0], 1), dtype=torch.long).cuda()), 1), 92 | "attention_mask": attention_mask, 93 | } 94 | 95 | def get_encoder(self): 96 | return self.encoder 97 | 98 | def get_output_embeddings(self): 99 | return _make_linear_from_emb(self.shared) 100 | 101 | def get_input_embeddings(self): 102 | return self.shared 103 | 104 | @staticmethod 105 | def _reorder_cache(past, beam_idx): 106 | enc_out = past[0][0] 107 | 108 | new_enc_out = enc_out.index_select(0, beam_idx) 109 | 110 | past = ((new_enc_out, ), ) 111 | return past 112 | 113 | def forward_adv( 114 | self, 115 | input_token_ids, 116 | attention_mask=None, 117 | decoder_padding_mask=None 118 | ): 119 | for p in self.adversary.parameters(): 120 | p.requires_grad = True 121 | sent_embeds = self.encoder.embed(input_token_ids, attention_mask=attention_mask).detach() 122 | adv_outputs = self.adversary(sent_embeds) 123 | 124 | return adv_outputs 125 | 126 | 127 | class ParaBartEncoder(nn.Module): 128 | def __init__(self, config, embed_tokens): 129 | super().__init__() 130 | self.config = config 131 | 132 | self.dropout = config.dropout 133 | self.embed_tokens = embed_tokens 134 | 135 | self.embed_synt = nn.Embedding(77, config.d_model, config.pad_token_id) 136 | self.embed_synt.weight.data.normal_(mean=0.0, std=config.init_std) 137 | self.embed_synt.weight.data[config.pad_token_id].zero_() 138 | 139 | self.embed_positions = LearnedPositionalEmbedding( 140 | config.max_position_embeddings, config.d_model, config.pad_token_id, config.extra_pos_embeddings 141 | ) 142 | 143 | self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)]) 144 | self.synt_layers = nn.ModuleList([EncoderLayer(config) for _ in range(1)]) 145 | 146 | self.layernorm_embedding = LayerNorm(config.d_model) 147 | 148 | self.synt_layernorm_embedding =
LayerNorm(config.d_model) 149 | 150 | self.pooling = MeanPooling(config) 151 | 152 | 153 | def forward(self, input_ids, attention_mask): 154 | 155 | input_token_ids, input_synt_ids = torch.split(input_ids, [self.config.max_sent_len+2, self.config.max_synt_len+2], dim=1) 156 | input_token_mask, input_synt_mask = torch.split(attention_mask, [self.config.max_sent_len+2, self.config.max_synt_len+2], dim=1) 157 | 158 | x = self.forward_token(input_token_ids, input_token_mask) 159 | y = self.forward_synt(input_synt_ids, input_synt_mask) 160 | 161 | encoder_outputs = torch.cat((x,y), dim=1) 162 | 163 | sent_embeds = self.pooling(x, input_token_ids) 164 | 165 | return encoder_outputs, sent_embeds 166 | 167 | def forward_token(self, input_token_ids, attention_mask): 168 | if self.training: 169 | drop_mask = torch.bernoulli(self.config.word_dropout*torch.ones(input_token_ids.shape)).bool().cuda() 170 | input_token_ids = input_token_ids.masked_fill(drop_mask, 50264) 171 | 172 | input_token_embeds = self.embed_tokens(input_token_ids) + self.embed_positions(input_token_ids) 173 | x = self.layernorm_embedding(input_token_embeds) 174 | x = F.dropout(x, p=self.dropout, training=self.training) 175 | 176 | x = x.transpose(0, 1) 177 | 178 | for encoder_layer in self.layers: 179 | x, _ = encoder_layer(x, encoder_padding_mask=attention_mask) 180 | 181 | x = x.transpose(0, 1) 182 | return x 183 | 184 | def forward_synt(self, input_synt_ids, attention_mask): 185 | input_synt_embeds = self.embed_synt(input_synt_ids) + self.embed_positions(input_synt_ids) 186 | y = self.synt_layernorm_embedding(input_synt_embeds) 187 | y = F.dropout(y, p=self.dropout, training=self.training) 188 | 189 | # B x T x C -> T x B x C 190 | y = y.transpose(0, 1) 191 | 192 | for encoder_synt_layer in self.synt_layers: 193 | y, _ = encoder_synt_layer(y, encoder_padding_mask=attention_mask) 194 | 195 | # T x B x C -> B x T x C 196 | y = y.transpose(0, 1) 197 | return y 198 | 199 | 200 | def embed(self, input_token_ids, attention_mask=None, pool='mean'): 201 | if attention_mask is None: 202 | attention_mask = input_token_ids == self.config.pad_token_id 203 | 204 | x = self.forward_token(input_token_ids, attention_mask) 205 | 206 | sent_embeds = self.pooling(x, input_token_ids) 207 | return sent_embeds 208 | 209 | class MeanPooling(nn.Module): 210 | def __init__(self, config): 211 | super().__init__() 212 | self.config = config 213 | 214 | def forward(self, x, input_token_ids): 215 | mask = input_token_ids != self.config.pad_token_id 216 | mean_mask = mask.float()/mask.float().sum(1, keepdim=True) 217 | x = (x*mean_mask.unsqueeze(2)).sum(1, keepdim=True) 218 | return x 219 | 220 | 221 | class ParaBartDecoder(nn.Module): 222 | def __init__(self, config, embed_tokens): 223 | super().__init__() 224 | 225 | self.dropout = config.dropout 226 | 227 | self.embed_tokens = embed_tokens 228 | 229 | self.embed_positions = LearnedPositionalEmbedding( 230 | config.max_position_embeddings, config.d_model, config.pad_token_id, config.extra_pos_embeddings 231 | ) 232 | 233 | self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(1)]) 234 | self.layernorm_embedding = LayerNorm(config.d_model) 235 | 236 | def forward( 237 | self, 238 | decoder_input_ids, 239 | encoder_hidden_states, 240 | decoder_padding_mask, 241 | decoder_causal_mask, 242 | encoder_attention_mask 243 | ): 244 | 245 | x = self.embed_tokens(decoder_input_ids) + self.embed_positions(decoder_input_ids) 246 | x = self.layernorm_embedding(x) 247 | x = F.dropout(x, p=self.dropout, 
training=self.training) 248 | 249 | x = x.transpose(0, 1) 250 | encoder_hidden_states = encoder_hidden_states.transpose(0, 1) 251 | 252 | for idx, decoder_layer in enumerate(self.layers): 253 | x, _, _ = decoder_layer( 254 | x, 255 | encoder_hidden_states, 256 | encoder_attn_mask=encoder_attention_mask, 257 | decoder_padding_mask=decoder_padding_mask, 258 | causal_mask=decoder_causal_mask) 259 | 260 | x = x.transpose(0, 1) 261 | 262 | return x, 263 | 264 | 265 | class Discriminator(nn.Module): 266 | def __init__(self, config): 267 | super().__init__() 268 | self.sent_layernorm_embedding = LayerNorm(config.d_model, elementwise_affine=False) 269 | self.adv = nn.Linear(config.d_model, 74) 270 | 271 | def forward(self, sent_embeds): 272 | x = self.sent_layernorm_embedding(sent_embeds).squeeze(1) 273 | x = self.adv(x) 274 | return x 275 | -------------------------------------------------------------------------------- /parabart_qqpeval.py: -------------------------------------------------------------------------------- 1 | import sys, io 2 | import numpy as np 3 | import torch 4 | from transformers import BartTokenizer, BartConfig, BartModel 5 | from tqdm import tqdm 6 | from sklearn.metrics import f1_score, roc_auc_score 7 | import pickle, random 8 | from parabart import ParaBart 9 | 10 | 11 | 12 | print("==== loading model ====") 13 | config = BartConfig.from_pretrained('facebook/bart-base', cache_dir='../para-data/bart-base') 14 | 15 | model = ParaBart(config) 16 | 17 | tokenizer = BartTokenizer.from_pretrained('facebook/bart-base', cache_dir='../para-data/bart-base') 18 | 19 | model.load_state_dict(torch.load("./model/model.pt", map_location='cpu')) 20 | 21 | model = model.cuda() 22 | 23 | def build_embeddings(model, tokenizer, sents): 24 | model.eval() 25 | embeddings = torch.ones((len(sents), model.config.d_model)) 26 | with torch.no_grad(): 27 | for i, sent in enumerate(sents): 28 | sent_inputs = tokenizer(sent, return_tensors="pt") 29 | sent_token_ids = sent_inputs['input_ids'] 30 | 31 | sent_embed = model.encoder.embed(sent_token_ids.cuda()) 32 | embeddings[i] = sent_embed.detach().cpu().clone() 33 | return embeddings 34 | 35 | def cosine(u, v): 36 | return np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)) 37 | 38 | 39 | 40 | scores = [] 41 | labels = [] 42 | with open("qqp.pkl", "rb") as f: 43 | para_split = pickle.load(f) 44 | 45 | 46 | pos_hard = para_split['pos_hard'] 47 | pos = para_split['pos'] 48 | neg = para_split['neg'] 49 | 50 | easy = pos + neg 51 | hard = pos_hard + neg 52 | 53 | scores = [] 54 | for i in tqdm(range(len(easy))): 55 | embeds = build_embeddings(model, tokenizer, [easy[i][0], easy[i][1]]) 56 | score = cosine(embeds[0], embeds[1]) 57 | scores.append(score) 58 | 59 | scores_hard = [] 60 | for i in tqdm(range(len(hard))): 61 | embeds = build_embeddings(model, tokenizer, [hard[i][0], hard[i][1]]) 62 | score = cosine(embeds[0], embeds[1]) 63 | scores_hard.append(score) 64 | 65 | 66 | 67 | 68 | best_acc = 0.0 69 | best_thres = 0.0 70 | scores = np.asarray(scores) 71 | labels = [1]*len(pos) + [0]*len(neg) 72 | labels = np.asarray(labels) 73 | for thres in range(-100, 100, 1): 74 | thres = thres / 100.0 75 | preds = scores > thres 76 | acc = sum(labels == preds)/len(labels) 77 | if acc > best_acc: 78 | best_acc = acc 79 | best_thres = thres 80 | print('easy acc:', best_acc) 81 | 82 | 83 | best_acc = 0.0 84 | best_thres = 0.0 85 | scores_hard = np.asarray(scores_hard) 86 | labels_hard = [1]*len(pos_hard) + [0]*len(neg) 87 | labels_hard = np.asarray(labels_hard) 88 
| for thres in range(-100, 100, 1): 89 | thres = thres / 100.0 90 | preds = scores_hard > thres 91 | acc = sum(labels_hard == preds)/len(labels_hard) 92 | if acc > best_acc: 93 | best_acc = acc 94 | best_thres = thres 95 | print('hard acc:', best_acc) 96 | 97 | -------------------------------------------------------------------------------- /parabart_senteval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, unicode_literals 2 | import json 3 | import os 4 | import sys 5 | import numpy as np 6 | import logging 7 | import pickle 8 | import torch 9 | import argparse 10 | from transformers import BartTokenizer, BartConfig, BartModel 11 | 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--model_dir', type=str, default="./model/") 15 | parser.add_argument('--cache_dir', type=str, default="./bart-base/") 16 | parser.add_argument('--senteval_dir', type=str, default="../SentEval/") 17 | args = parser.parse_args() 18 | 19 | 20 | # import SentEval 21 | sys.path.insert(0, args.senteval_dir) 22 | import senteval 23 | 24 | sys.path.insert(0, args.model_dir) 25 | from parabart import ParaBart 26 | 27 | 28 | # SentEval prepare and batcher 29 | def prepare(params, samples): 30 | pass 31 | 32 | def batcher(params, batch): 33 | batch = [' '.join(sent) if sent != [] else '.' for sent in batch] 34 | embeddings = build_embeddings(embed_model, tokenizer, batch) 35 | return embeddings 36 | 37 | def build_embeddings(model, tokenizer, sents): 38 | model.eval() 39 | embeddings = torch.ones((len(sents), model.config.d_model)) 40 | with torch.no_grad(): 41 | for i, sent in enumerate(sents): 42 | sent_inputs = tokenizer(sent, return_tensors="pt") 43 | sent_token_ids = sent_inputs['input_ids'] 44 | 45 | sent_embed = model.encoder.embed(sent_token_ids.cuda()) 46 | embeddings[i] = sent_embed.detach().cpu().clone() 47 | return embeddings 48 | 49 | 50 | print("==== loading model ====") 51 | config = BartConfig.from_pretrained('facebook/bart-base', cache_dir=args.cache_dir) 52 | 53 | embed_model = ParaBart(config) 54 | 55 | tokenizer = BartTokenizer.from_pretrained('facebook/bart-base', cache_dir=args.cache_dir) 56 | 57 | embed_model.load_state_dict(torch.load(os.path.join(args.model_dir, "model.pt"), map_location='cpu')) 58 | 59 | embed_model = embed_model.cuda() 60 | 61 | # Set params for SentEval 62 | params = {'task_path': os.path.join(args.senteval_dir, 'data'), 'usepytorch': True, 'kfold': 10} 63 | params['classifier'] = {'nhid': 50, 'optim': 'adam', 'batch_size': 64, 64 | 'tenacity': 5, 'epoch_size': 4} 65 | 66 | # Set up logger 67 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG) 68 | 69 | if __name__ == "__main__": 70 | se = senteval.engine.SE(params, batcher, prepare) 71 | 72 | transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 73 | 'BigramShift', 'Depth', 'TopConstituents'] 74 | 75 | results = se.eval(transfer_tasks) 76 | print(results) 77 | 78 | 79 | -------------------------------------------------------------------------------- /synt_vocab.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uclanlp/ParaBART/09afbc09e565fb72f5c9f98653002e626e2b150b/synt_vocab.pkl -------------------------------------------------------------------------------- /train_parabart.py: -------------------------------------------------------------------------------- 1 | import os, argparse, pickle, h5py 
2 | import pandas as pd 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | from torch.utils.data import DataLoader, random_split 8 | 9 | from utils import Timer, make_path, deleaf 10 | from pprint import pprint 11 | from tqdm import tqdm 12 | from transformers import BartTokenizer, BartConfig, BartModel 13 | from parabart import ParaBart 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--model_dir', type=str, default="./model/") 17 | parser.add_argument('--cache_dir', type=str, default="./bart-base/") 18 | parser.add_argument('--data_dir', type=str, default="./data/") 19 | parser.add_argument('--max_sent_len', type=int, default=40) 20 | parser.add_argument('--max_synt_len', type=int, default=160) 21 | parser.add_argument('--word_dropout', type=float, default=0.2) 22 | parser.add_argument('--n_epoch', type=int, default=10) 23 | parser.add_argument('--train_batch_size', type=int, default=64) 24 | parser.add_argument('--accumulation_steps', type=int, default=1) 25 | parser.add_argument('--valid_batch_size', type=int, default=16) 26 | parser.add_argument('--lr', type=float, default=2e-5) 27 | parser.add_argument('--fast_lr', type=float, default=1e-4) 28 | parser.add_argument('--weight_decay', type=float, default=1e-2) 29 | parser.add_argument('--log_interval', type=int, default=1000) 30 | parser.add_argument('--temp', type=float, default=0.5) 31 | parser.add_argument('--seed', type=int, default=0) 32 | args = parser.parse_args() 33 | pprint(vars(args)) 34 | print() 35 | 36 | # fix random seed 37 | np.random.seed(args.seed) 38 | torch.manual_seed(args.seed) 39 | torch.cuda.manual_seed(args.seed) 40 | torch.backends.cudnn.deterministic = True 41 | 42 | def train(epoch, dataset, model, tokenizer, optimizer, args): 43 | timer = Timer() 44 | n_it = len(train_loader) 45 | optimizer.zero_grad() 46 | 47 | for it, idxs in enumerate(train_loader): 48 | total_loss = 0.0 49 | adv_total_loss = 0.0 50 | model.train() 51 | 52 | sent1_token_ids = dataset['sent1'][idxs].cuda() 53 | synt1_token_ids = dataset['synt1'][idxs].cuda() 54 | sent2_token_ids = dataset['sent2'][idxs].cuda() 55 | synt2_token_ids = dataset['synt2'][idxs].cuda() 56 | synt1_bow = dataset['synt1bow'][idxs].cuda() 57 | synt2_bow = dataset['synt2bow'][idxs].cuda() 58 | 59 | # optimize adv 60 | # sent1 adv 61 | outputs = model.forward_adv(sent1_token_ids) 62 | targs = synt1_bow 63 | loss = adv_criterion(outputs, targs) 64 | loss.backward() 65 | adv_total_loss += loss.item() 66 | 67 | 68 | # sent2 adv 69 | outputs = model.forward_adv(sent2_token_ids) 70 | targs = synt2_bow 71 | loss = adv_criterion(outputs, targs) 72 | loss.backward() 73 | adv_total_loss += loss.item() 74 | 75 | if (it+1) % args.accumulation_steps == 0: 76 | nn.utils.clip_grad_norm_(model.parameters(), 1.0) 77 | if epoch > 1: 78 | adv_optimizer.step() 79 | adv_optimizer.zero_grad() 80 | 81 | # optimize model 82 | # sent1->sent2 para & sent1 adv 83 | outputs, adv_outputs = model(torch.cat((sent1_token_ids, synt2_token_ids), 1), sent2_token_ids) 84 | targs = sent2_token_ids[:, 1:].contiguous().view(-1) 85 | outputs = outputs.contiguous().view(-1, outputs.size(-1)) 86 | adv_targs = synt1_bow 87 | loss = para_criterion(outputs, targs) 88 | if epoch > 1: 89 | loss -= 0.1 * adv_criterion(adv_outputs, adv_targs) 90 | loss.backward() 91 | total_loss += loss.item() 92 | 93 | # sent2->sent1 para & sent2 adv 94 | outputs, adv_outputs = model(torch.cat((sent2_token_ids, synt1_token_ids), 1), sent1_token_ids) 95 | targs 
= sent1_token_ids[:, 1:].contiguous().view(-1) 96 | outputs = outputs.contiguous().view(-1, outputs.size(-1)) 97 | adv_targs = synt2_bow 98 | loss = para_criterion(outputs, targs) 99 | if epoch > 1: 100 | loss -= 0.1 * adv_criterion(adv_outputs, adv_targs) 101 | loss.backward() 102 | total_loss += loss.item() 103 | 104 | 105 | if (it+1) % args.accumulation_steps == 0: 106 | nn.utils.clip_grad_norm_(model.parameters(), 1.0) 107 | optimizer.step() 108 | optimizer.zero_grad() 109 | 110 | if (it+1) % args.log_interval == 0 or it == 0: 111 | para_1_2_loss, para_2_1_loss, adv_1_loss, adv_2_loss = evaluate(model, tokenizer, args) 112 | valid_loss = para_1_2_loss + para_2_1_loss - 0.1 * adv_1_loss - 0.1 * adv_2_loss 113 | print("| ep {:2d}/{} | it {:3d}/{} | {:5.2f} s | adv loss {:.4f} | loss {:.4f} | para 1-2 loss {:.4f} | para 2-1 loss {:.4f} | adv 1 loss {:.4f} | adv 2 loss {:.4f} | valid loss {:.4f} |".format( 114 | epoch, args.n_epoch, it+1, n_it, timer.get_time_from_last(), adv_total_loss, total_loss, para_1_2_loss, para_2_1_loss, adv_1_loss, adv_2_loss, valid_loss)) 115 | 116 | 117 | 118 | def evaluate(model, tokenizer, args): 119 | model.eval() 120 | para_1_2_loss = 0.0 121 | para_2_1_loss = 0.0 122 | adv_1_loss = 0.0 123 | adv_2_loss = 0.0 124 | with torch.no_grad(): 125 | for idxs in valid_loader: 126 | 127 | sent1_token_ids = dataset['sent1'][idxs].cuda() 128 | synt1_token_ids = dataset['synt1'][idxs].cuda() 129 | sent2_token_ids = dataset['sent2'][idxs].cuda() 130 | synt2_token_ids = dataset['synt2'][idxs].cuda() 131 | synt1_bow = dataset['synt1bow'][idxs].cuda() 132 | synt2_bow = dataset['synt2bow'][idxs].cuda() 133 | 134 | outputs, adv_outputs = model(torch.cat((sent1_token_ids, synt2_token_ids), 1), sent2_token_ids) 135 | targs = sent2_token_ids[:, 1:].contiguous().view(-1) 136 | outputs = outputs.contiguous().view(-1, outputs.size(-1)) 137 | adv_targs = synt1_bow 138 | para_1_2_loss += para_criterion(outputs, targs) 139 | adv_1_loss += adv_criterion(adv_outputs, adv_targs) 140 | 141 | outputs, adv_outputs = model(torch.cat((sent2_token_ids, synt1_token_ids), 1), sent1_token_ids) 142 | targs = sent1_token_ids[:, 1:].contiguous().view(-1) 143 | outputs = outputs.contiguous().view(-1, outputs.size(-1)) 144 | adv_targs = synt2_bow 145 | para_2_1_loss += para_criterion(outputs, targs) 146 | adv_2_loss += adv_criterion(adv_outputs, adv_targs) 147 | 148 | return para_1_2_loss / len(valid_loader), para_2_1_loss / len(valid_loader), adv_1_loss / len(valid_loader), adv_2_loss / len(valid_loader) 149 | 150 | 151 | def prepare_dataset(para_data, tokenizer, num): 152 | sents1 = list(para_data['train_sents1'][:num]) 153 | synts1 = list(para_data['train_synts1'][:num]) 154 | sents2 = list(para_data['train_sents2'][:num]) 155 | synts2 = list(para_data['train_synts2'][:num]) 156 | 157 | sent1_token_ids = torch.ones((num, args.max_sent_len+2), dtype=torch.long) 158 | sent2_token_ids = torch.ones((num, args.max_sent_len+2), dtype=torch.long) 159 | synt1_token_ids = torch.ones((num, args.max_synt_len+2), dtype=torch.long) 160 | synt2_token_ids = torch.ones((num, args.max_synt_len+2), dtype=torch.long) 161 | synt1_bow = torch.ones((num, 74)) 162 | synt2_bow = torch.ones((num, 74)) 163 | 164 | bsz = 64 165 | 166 | for i in tqdm(range(0, num, bsz)): 167 | sent1_inputs = tokenizer(sents1[i:i+bsz], padding='max_length', truncation=True, max_length=args.max_sent_len+2, return_tensors="pt") 168 | sent2_inputs = tokenizer(sents2[i:i+bsz], padding='max_length', truncation=True, max_length=args.max_sent_len+2, 
return_tensors="pt") 169 | sent1_token_ids[i:i+bsz] = sent1_inputs['input_ids'] 170 | sent2_token_ids[i:i+bsz] = sent2_inputs['input_ids'] 171 | 172 | for i in tqdm(range(num)): 173 | synt1 = ['<s>'] + deleaf(synts1[i]) + ['</s>'] 174 | synt1_token_ids[i, :len(synt1)] = torch.tensor([synt_vocab[tag] for tag in synt1])[:args.max_synt_len+2] 175 | synt2 = ['<s>'] + deleaf(synts2[i]) + ['</s>'] 176 | synt2_token_ids[i, :len(synt2)] = torch.tensor([synt_vocab[tag] for tag in synt2])[:args.max_synt_len+2] 177 | 178 | for tag in synt1: 179 | if tag != '<s>' and tag != '</s>': 180 | synt1_bow[i][synt_vocab[tag]-3] += 1 181 | for tag in synt2: 182 | if tag != '<s>' and tag != '</s>': 183 | synt2_bow[i][synt_vocab[tag]-3] += 1 184 | 185 | synt1_bow /= synt1_bow.sum(1, keepdim=True) 186 | synt2_bow /= synt2_bow.sum(1, keepdim=True) 187 | 188 | sum = 0 189 | for i in range(num): 190 | if torch.equal(synt1_bow[i], synt2_bow[i]): 191 | sum += 1 192 | 193 | return {'sent1':sent1_token_ids, 'sent2':sent2_token_ids, 'synt1': synt1_token_ids, 'synt2': synt2_token_ids, 194 | 'synt1bow': synt1_bow, 'synt2bow': synt2_bow} 195 | 196 | print("==== loading data ====") 197 | num = 1000000 198 | para_data = h5py.File(os.path.join(args.data_dir, 'data.h5'), 'r') 199 | 200 | train_idxs, valid_idxs = random_split(range(num), [num-5000, 5000], generator=torch.Generator().manual_seed(args.seed)) 201 | 202 | print(f"number of train examples: {len(train_idxs)}") 203 | print(f"number of valid examples: {len(valid_idxs)}") 204 | 205 | train_loader = DataLoader(train_idxs, batch_size=args.train_batch_size, shuffle=True) 206 | valid_loader = DataLoader(valid_idxs, batch_size=args.valid_batch_size, shuffle=False) 207 | 208 | print("==== preparing data ====") 209 | make_path(args.cache_dir) 210 | tokenizer = BartTokenizer.from_pretrained('facebook/bart-base', cache_dir=args.cache_dir) 211 | 212 | with open('synt_vocab.pkl', 'rb') as f: 213 | synt_vocab = pickle.load(f) 214 | 215 | dataset = prepare_dataset(para_data, tokenizer, num) 216 | 217 | print("==== loading model ====") 218 | config = BartConfig.from_pretrained('facebook/bart-base', cache_dir=args.cache_dir) 219 | config.word_dropout = args.word_dropout 220 | config.max_sent_len = args.max_sent_len 221 | config.max_synt_len = args.max_synt_len 222 | 223 | bart = BartModel.from_pretrained('facebook/bart-base', cache_dir=args.cache_dir) 224 | model = ParaBart(config) 225 | model.load_state_dict(bart.state_dict(), strict=False) 226 | model.zero_grad() 227 | del bart 228 | 229 | 230 | no_decay_params = [] 231 | no_decay_fast_params = [] 232 | fast_params = [] 233 | all_other_params = [] 234 | adv_no_decay_params = [] 235 | adv_all_other_params = [] 236 | 237 | for n, p in model.named_parameters(): 238 | if 'adv' in n: 239 | if 'norm' in n or 'bias' in n: 240 | adv_no_decay_params.append(p) 241 | else: 242 | adv_all_other_params.append(p) 243 | elif 'linear' in n or 'synt' in n or 'decoder' in n: 244 | if 'bias' in n: 245 | no_decay_fast_params.append(p) 246 | else: 247 | fast_params.append(p) 248 | elif 'norm' in n or 'bias' in n: 249 | no_decay_params.append(p) 250 | else: 251 | all_other_params.append(p) 252 | 253 | optimizer = optim.AdamW([ 254 | {'params': fast_params, 'lr': args.fast_lr}, 255 | {'params': no_decay_fast_params, 'lr': args.fast_lr, 'weight_decay': 0.0}, 256 | {'params': no_decay_params, 'weight_decay': 0.0}, 257 | {'params': all_other_params} 258 | ], lr=args.lr, weight_decay=args.weight_decay) 259 | 260 | adv_optimizer = optim.AdamW([ 261 | {'params': adv_no_decay_params,
'weight_decay': 0.0}, 262 | {'params': adv_all_other_params} 263 | ], lr=args.lr, weight_decay=args.weight_decay) 264 | 265 | para_criterion = nn.CrossEntropyLoss(ignore_index=model.config.pad_token_id).cuda() 266 | adv_criterion = nn.BCEWithLogitsLoss().cuda() 267 | 268 | model = model.cuda() 269 | 270 | make_path(args.model_dir) 271 | 272 | print("==== start training ====") 273 | 274 | for epoch in range(1, args.n_epoch+1): 275 | train(epoch, dataset, model, tokenizer, optimizer, args) 276 | torch.save(model.state_dict(), os.path.join(args.model_dir, "model_epoch{:02d}.pt".format(epoch))) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os, errno 2 | import numpy as np 3 | from datetime import datetime 4 | 5 | 6 | def make_path(path): 7 | try: 8 | os.makedirs(path) 9 | except OSError as exc: 10 | if exc.errno == errno.EEXIST and os.path.isdir(path): 11 | pass 12 | else: raise 13 | 14 | class Timer: 15 | def __init__(self): 16 | self.start_time = datetime.now() 17 | self.last_time = self.start_time 18 | 19 | def get_time_from_last(self, update=True): 20 | now_time = datetime.now() 21 | diff_time = now_time - self.last_time 22 | if update: 23 | self.last_time = now_time 24 | return diff_time.total_seconds() 25 | 26 | def get_time_from_start(self, update=True): 27 | now_time = datetime.now() 28 | diff_time = now_time - self.start_time 29 | if update: 30 | self.last_time = now_time 31 | return diff_time.total_seconds() 32 | 33 | 34 | def is_paren(tok): 35 | return tok == ")" or tok == "(" 36 | 37 | def deleaf(tree): 38 | nonleaves = '' 39 | for w in tree.replace('\n', '').split(): 40 | w = w.replace('(', '( ').replace(')', ' )') 41 | nonleaves += w + ' ' 42 | 43 | arr = nonleaves.split() 44 | for n, i in enumerate(arr): 45 | if n + 1 < len(arr): 46 | tok1 = arr[n] 47 | tok2 = arr[n + 1] 48 | if not is_paren(tok1) and not is_paren(tok2): 49 | arr[n + 1] = "" 50 | 51 | nonleaves = " ".join(arr) 52 | return nonleaves.split() --------------------------------------------------------------------------------
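For quick reference, the sketch below shows how the released checkpoint can be used to embed a pair of sentences and score their semantic similarity; it follows the same `model.encoder.embed()` plus cosine-similarity pattern as `build_embeddings()` in `parabart_qqpeval.py`. It assumes the trained weights are stored at `./model/model.pt` (as in the README's evaluation setup), that a GPU is available, and the two example sentences are arbitrary placeholders.

```python
import torch
from transformers import BartTokenizer, BartConfig
from parabart import ParaBart

# Load the bart-base config/tokenizer and the trained ParaBART weights
config = BartConfig.from_pretrained('facebook/bart-base')
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
model = ParaBart(config)
model.load_state_dict(torch.load("./model/model.pt", map_location='cpu'))
model = model.cuda()
model.eval()

def embed(sent):
    # Tokenize one sentence and mean-pool its encoder states into a d_model-dim vector
    token_ids = tokenizer(sent, return_tensors="pt")['input_ids'].cuda()
    with torch.no_grad():
        return model.encoder.embed(token_ids).squeeze().cpu()

u = embed("How can I improve my English?")
v = embed("What is the best way to improve my English?")
print(torch.nn.functional.cosine_similarity(u, v, dim=0).item())
```

This pooled encoder embedding is the semantic sentence representation evaluated by both `parabart_senteval.py` and `parabart_qqpeval.py`; in the QQP scripts, pairs whose cosine score exceeds a tuned threshold are predicted to be paraphrases.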