├── README.md
├── parabart.py
├── parabart_qqpeval.py
├── parabart_senteval.py
├── synt_vocab.pkl
├── train_parabart.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | ## ParaBART
2 |
3 | Code for our NAACL-2021 paper ["Disentangling Semantics and Syntax in Sentence Embeddings with Pre-trained Language Models"](https://arxiv.org/abs/2104.05115).
4 |
5 | If you find this repository useful, please consider citing our paper.
6 | ```
7 | @inproceedings{huang2021disentangling,
8 | title = {Disentangling Semantics and Syntax in Sentence Embeddings with Pre-trained Language Models},
9 | author = {Huang, James Y. and Huang, Kuan-Hao and Chang, Kai-Wei},
10 | booktitle = {NAACL},
11 | year = {2021}
12 | }
13 | ```
14 |
15 | ### Dependencies
16 |
17 | - Python==3.7.6
18 | - PyTorch==1.6.0
19 | - Transformers==3.0.2
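
A minimal setup sketch, assuming a Python 3.7 environment is already available (the package names below are the PyPI names for the pinned versions):
```
pip install torch==1.6.0 transformers==3.0.2
```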
20 |
21 | ### Pre-trained Models
22 |
23 | Our pre-trained ParaBART model is available [here](https://drive.google.com/file/d/1Ev9iB2bIekEp1yYTCJPkngzZSRWOS-cz/view?usp=sharing).
24 |
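A minimal loading sketch, mirroring `parabart_senteval.py` and `parabart_qqpeval.py` (it assumes the downloaded checkpoint is saved as `./model/model.pt`):
```
import torch
from transformers import BartConfig
from parabart import ParaBart

# build the architecture, then load the released weights
config = BartConfig.from_pretrained('facebook/bart-base')
model = ParaBart(config)
model.load_state_dict(torch.load("./model/model.pt", map_location='cpu'))
model = model.cuda().eval()
```
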
25 | ### Training
26 |
27 | - Download the [dataset](https://drive.google.com/file/d/1Pv_RB47BD_zLhmQUhFpiEdI6UHDbb-wX/view?usp=sharing) and put it under `./data/` (the training script reads `./data/data.h5`)
28 | - Run the following command to train ParaBART
29 | ```
30 | python train_parabart.py --data_dir ./data/
31 | ```
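
The defaults below are the ones set in the argument parser of `train_parabart.py`; checkpoints are saved to `--model_dir` as `model_epoch01.pt`, `model_epoch02.pt`, and so on:
```
python train_parabart.py --data_dir ./data/ --model_dir ./model/ \
    --n_epoch 10 --train_batch_size 64 --lr 2e-5 --fast_lr 1e-4 --word_dropout 0.2
```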
32 |
33 | ### Evaluation
34 |
35 | - Download the [SentEval](https://github.com/facebookresearch/SentEval) toolkit and datasets
36 | - Name your trained model `model.pt` and put it under `./model/`
37 | - Run the following command to evaluate ParaBART on semantic textual similarity and syntactic probing tasks
38 | ```
39 | python parabart_senteval.py --senteval_dir ../SentEval --model_dir ./model/
40 | ```
41 | - Download the QQP-Easy and QQP-Hard datasets [here](https://drive.google.com/file/d/1am502GkMU-9h-5chZ7RVt-7l0FAGvfH2/view?usp=sharing) and put `qqp.pkl` in the repository root (the evaluation script loads it from the working directory)
42 | - Run the following command to evaluate ParaBART on QQP datasets
43 | ```
44 | python parabart_qqpeval.py
45 | ```
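
For reference, the QQP evaluation compares mean-pooled encoder embeddings with cosine similarity. A sketch of that pattern, assuming `model` has been loaded as in the snippet under Pre-trained Models (the example sentences are just placeholders):
```
import numpy as np
import torch
from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

def embed_sentence(sent):
    # mean-pooled sentence embedding from ParaBART's semantic encoder
    with torch.no_grad():
        ids = tokenizer(sent, return_tensors="pt")['input_ids'].cuda()
        return model.encoder.embed(ids).squeeze().cpu().numpy()

u = embed_sentence("How do I learn Python?")
v = embed_sentence("What is the best way to learn Python?")
similarity = np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))
```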
46 |
47 | ### Author
48 |
49 | James Yipeng Huang / [@jyhuang36](https://github.com/jyhuang36)
50 |
--------------------------------------------------------------------------------
/parabart.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from transformers.modeling_bart import (
6 | PretrainedBartModel,
7 | LayerNorm,
8 | EncoderLayer,
9 | DecoderLayer,
10 | LearnedPositionalEmbedding,
11 | _prepare_bart_decoder_inputs,
12 | _make_linear_from_emb
13 | )
14 |
15 | class ParaBart(PretrainedBartModel):
16 | def __init__(self, config):
17 | super().__init__(config)
18 |
19 | self.shared = nn.Embedding(config.vocab_size, config.d_model, config.pad_token_id)
20 |
21 | self.encoder = ParaBartEncoder(config, self.shared)
22 | self.decoder = ParaBartDecoder(config, self.shared)
23 |
24 | self.linear = nn.Linear(config.d_model, config.vocab_size)
25 |
26 | self.adversary = Discriminator(config)
27 |
28 | self.init_weights()
29 |
30 | def forward(
31 | self,
32 | input_ids,
33 | decoder_input_ids,
34 | attention_mask=None,
35 | decoder_padding_mask=None,
36 | encoder_outputs=None,
37 | return_encoder_outputs=False,
38 | ):
39 | if attention_mask is None:
40 | attention_mask = input_ids == self.config.pad_token_id
41 |
42 | if encoder_outputs is None:
43 | encoder_outputs = self.encoder(input_ids, attention_mask=attention_mask)
44 |
45 | if return_encoder_outputs:
46 | return encoder_outputs
47 |
48 | assert encoder_outputs is not None
49 | assert decoder_input_ids is not None
50 |
51 | decoder_input_ids = decoder_input_ids[:, :-1]
52 |
53 | _, decoder_padding_mask, decoder_causal_mask = _prepare_bart_decoder_inputs(
54 | self.config,
55 | input_ids=None,
56 | decoder_input_ids=decoder_input_ids,
57 | decoder_padding_mask=decoder_padding_mask,
58 | causal_mask_dtype=self.shared.weight.dtype,
59 | )
60 |
61 |         attention_mask2 = torch.cat((torch.zeros(input_ids.shape[0], 1).bool().cuda(), attention_mask[:, self.config.max_sent_len+2:]), dim=1)  # keep the pooled sentence embedding (position 0) visible; reuse the syntax-token padding mask
62 |
63 |         # decoder: cross-attends to the pooled sentence embedding plus the syntax hidden states
64 | decoder_outputs = self.decoder(
65 | decoder_input_ids,
66 | torch.cat((encoder_outputs[1], encoder_outputs[0][:, self.config.max_sent_len+2:]), dim=1),
67 | decoder_padding_mask=decoder_padding_mask,
68 | decoder_causal_mask=decoder_causal_mask,
69 | encoder_attention_mask=attention_mask2,
70 | )[0]
71 |
72 |
73 | batch_size = decoder_outputs.shape[0]
74 | outputs = self.linear(decoder_outputs.contiguous().view(-1, self.config.d_model))
75 | outputs = outputs.view(batch_size, -1, self.config.vocab_size)
76 |
77 |         # discriminator: predicts the syntactic bag-of-words from the sentence embedding (kept frozen here)
78 |         for p in self.adversary.parameters():
79 |             p.requires_grad = False
80 | adv_outputs = self.adversary(encoder_outputs[1])
81 |
82 | return outputs, adv_outputs
83 |
84 | def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs):
85 | assert past is not None, "past has to be defined for encoder_outputs"
86 |
87 | encoder_outputs = past[0]
88 | return {
89 | "input_ids": None, # encoder_outputs is defined. input_ids not needed
90 | "encoder_outputs": encoder_outputs,
91 | "decoder_input_ids": torch.cat((decoder_input_ids, torch.zeros((decoder_input_ids.shape[0], 1), dtype=torch.long).cuda()), 1),
92 | "attention_mask": attention_mask,
93 | }
94 |
95 | def get_encoder(self):
96 | return self.encoder
97 |
98 | def get_output_embeddings(self):
99 | return _make_linear_from_emb(self.shared)
100 |
101 | def get_input_embeddings(self):
102 | return self.shared
103 |
104 | @staticmethod
105 | def _reorder_cache(past, beam_idx):
106 | enc_out = past[0][0]
107 |
108 | new_enc_out = enc_out.index_select(0, beam_idx)
109 |
110 | past = ((new_enc_out, ), )
111 | return past
112 |
113 | def forward_adv(
114 | self,
115 | input_token_ids,
116 | attention_mask=None,
117 | decoder_padding_mask=None
118 | ):
119 | for p in self.adversary.parameters():
120 |             p.requires_grad = True
121 | sent_embeds = self.encoder.embed(input_token_ids, attention_mask=attention_mask).detach()
122 | adv_outputs = self.adversary(sent_embeds)
123 |
124 | return adv_outputs
125 |
126 |
127 | class ParaBartEncoder(nn.Module):
128 | def __init__(self, config, embed_tokens):
129 | super().__init__()
130 | self.config = config
131 |
132 | self.dropout = config.dropout
133 | self.embed_tokens = embed_tokens
134 |
135 | self.embed_synt = nn.Embedding(77, config.d_model, config.pad_token_id)
136 | self.embed_synt.weight.data.normal_(mean=0.0, std=config.init_std)
137 | self.embed_synt.weight.data[config.pad_token_id].zero_()
138 |
139 | self.embed_positions = LearnedPositionalEmbedding(
140 | config.max_position_embeddings, config.d_model, config.pad_token_id, config.extra_pos_embeddings
141 | )
142 |
143 | self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
144 | self.synt_layers = nn.ModuleList([EncoderLayer(config) for _ in range(1)])
145 |
146 | self.layernorm_embedding = LayerNorm(config.d_model)
147 |
148 | self.synt_layernorm_embedding = LayerNorm(config.d_model)
149 |
150 | self.pooling = MeanPooling(config)
151 |
152 |
153 | def forward(self, input_ids, attention_mask):
154 |
155 | input_token_ids, input_synt_ids = torch.split(input_ids, [self.config.max_sent_len+2, self.config.max_synt_len+2], dim=1)
156 | input_token_mask, input_synt_mask = torch.split(attention_mask, [self.config.max_sent_len+2, self.config.max_synt_len+2], dim=1)
157 |
158 | x = self.forward_token(input_token_ids, input_token_mask)
159 | y = self.forward_synt(input_synt_ids, input_synt_mask)
160 |
161 | encoder_outputs = torch.cat((x,y), dim=1)
162 |
163 | sent_embeds = self.pooling(x, input_token_ids)
164 |
165 | return encoder_outputs, sent_embeds
166 |
167 | def forward_token(self, input_token_ids, attention_mask):
168 | if self.training:
169 | drop_mask = torch.bernoulli(self.config.word_dropout*torch.ones(input_token_ids.shape)).bool().cuda()
170 |             input_token_ids = input_token_ids.masked_fill(drop_mask, 50264)  # word dropout: replace with BART's <mask> token id
171 |
172 | input_token_embeds = self.embed_tokens(input_token_ids) + self.embed_positions(input_token_ids)
173 | x = self.layernorm_embedding(input_token_embeds)
174 | x = F.dropout(x, p=self.dropout, training=self.training)
175 |
176 | x = x.transpose(0, 1)
177 |
178 | for encoder_layer in self.layers:
179 | x, _ = encoder_layer(x, encoder_padding_mask=attention_mask)
180 |
181 | x = x.transpose(0, 1)
182 | return x
183 |
184 | def forward_synt(self, input_synt_ids, attention_mask):
185 | input_synt_embeds = self.embed_synt(input_synt_ids) + self.embed_positions(input_synt_ids)
186 | y = self.synt_layernorm_embedding(input_synt_embeds)
187 | y = F.dropout(y, p=self.dropout, training=self.training)
188 |
189 | # B x T x C -> T x B x C
190 | y = y.transpose(0, 1)
191 |
192 | for encoder_synt_layer in self.synt_layers:
193 | y, _ = encoder_synt_layer(y, encoder_padding_mask=attention_mask)
194 |
195 | # T x B x C -> B x T x C
196 | y = y.transpose(0, 1)
197 | return y
198 |
199 |
200 | def embed(self, input_token_ids, attention_mask=None, pool='mean'):
201 | if attention_mask is None:
202 | attention_mask = input_token_ids == self.config.pad_token_id
203 |
204 | x = self.forward_token(input_token_ids, attention_mask)
205 |
206 | sent_embeds = self.pooling(x, input_token_ids)
207 | return sent_embeds
208 |
209 | class MeanPooling(nn.Module):
210 | def __init__(self, config):
211 | super().__init__()
212 | self.config = config
213 |
214 | def forward(self, x, input_token_ids):
215 | mask = input_token_ids != self.config.pad_token_id
216 | mean_mask = mask.float()/mask.float().sum(1, keepdim=True)
217 | x = (x*mean_mask.unsqueeze(2)).sum(1, keepdim=True)
218 | return x
219 |
220 |
221 | class ParaBartDecoder(nn.Module):
222 | def __init__(self, config, embed_tokens):
223 | super().__init__()
224 |
225 | self.dropout = config.dropout
226 |
227 | self.embed_tokens = embed_tokens
228 |
229 | self.embed_positions = LearnedPositionalEmbedding(
230 | config.max_position_embeddings, config.d_model, config.pad_token_id, config.extra_pos_embeddings
231 | )
232 |
233 | self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(1)])
234 | self.layernorm_embedding = LayerNorm(config.d_model)
235 |
236 | def forward(
237 | self,
238 | decoder_input_ids,
239 | encoder_hidden_states,
240 | decoder_padding_mask,
241 | decoder_causal_mask,
242 | encoder_attention_mask
243 | ):
244 |
245 | x = self.embed_tokens(decoder_input_ids) + self.embed_positions(decoder_input_ids)
246 | x = self.layernorm_embedding(x)
247 | x = F.dropout(x, p=self.dropout, training=self.training)
248 |
249 | x = x.transpose(0, 1)
250 | encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
251 |
252 | for idx, decoder_layer in enumerate(self.layers):
253 | x, _, _ = decoder_layer(
254 | x,
255 | encoder_hidden_states,
256 | encoder_attn_mask=encoder_attention_mask,
257 | decoder_padding_mask=decoder_padding_mask,
258 | causal_mask=decoder_causal_mask)
259 |
260 | x = x.transpose(0, 1)
261 |
262 | return x,
263 |
264 |
265 | class Discriminator(nn.Module):
266 | def __init__(self, config):
267 | super().__init__()
268 | self.sent_layernorm_embedding = LayerNorm(config.d_model, elementwise_affine=False)
269 | self.adv = nn.Linear(config.d_model, 74)
270 |
271 | def forward(self, sent_embeds):
272 | x = self.sent_layernorm_embedding(sent_embeds).squeeze(1)
273 | x = self.adv(x)
274 | return x
275 |
--------------------------------------------------------------------------------
/parabart_qqpeval.py:
--------------------------------------------------------------------------------
1 | import sys, io
2 | import numpy as np
3 | import torch
4 | from transformers import BartTokenizer, BartConfig, BartModel
5 | from tqdm import tqdm
6 | from sklearn.metrics import f1_score, roc_auc_score
7 | import pickle, random
8 | from parabart import ParaBart
9 |
10 |
11 |
12 | print("==== loading model ====")
13 | config = BartConfig.from_pretrained('facebook/bart-base', cache_dir='../para-data/bart-base')
14 |
15 | model = ParaBart(config)
16 |
17 | tokenizer = BartTokenizer.from_pretrained('facebook/bart-base', cache_dir='../para-data/bart-base')
18 |
19 | model.load_state_dict(torch.load("./model/model.pt", map_location='cpu'))
20 |
21 | model = model.cuda()
22 |
23 | def build_embeddings(model, tokenizer, sents):
24 | model.eval()
25 | embeddings = torch.ones((len(sents), model.config.d_model))
26 | with torch.no_grad():
27 | for i, sent in enumerate(sents):
28 | sent_inputs = tokenizer(sent, return_tensors="pt")
29 | sent_token_ids = sent_inputs['input_ids']
30 |
31 | sent_embed = model.encoder.embed(sent_token_ids.cuda())
32 | embeddings[i] = sent_embed.detach().cpu().clone()
33 | return embeddings
34 |
35 | def cosine(u, v):
36 | return np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))
37 |
38 |
39 |
40 | scores = []
41 | labels = []
42 | with open("qqp.pkl", "rb") as f:
43 | para_split = pickle.load(f)
44 |
45 |
46 | pos_hard = para_split['pos_hard']
47 | pos = para_split['pos']
48 | neg = para_split['neg']
49 |
50 | easy = pos + neg
51 | hard = pos_hard + neg
52 |
53 | scores = []
54 | for i in tqdm(range(len(easy))):
55 | embeds = build_embeddings(model, tokenizer, [easy[i][0], easy[i][1]])
56 | score = cosine(embeds[0], embeds[1])
57 | scores.append(score)
58 |
59 | scores_hard = []
60 | for i in tqdm(range(len(hard))):
61 | embeds = build_embeddings(model, tokenizer, [hard[i][0], hard[i][1]])
62 | score = cosine(embeds[0], embeds[1])
63 | scores_hard.append(score)
64 |
65 |
66 |
67 |
68 | best_acc = 0.0
69 | best_thres = 0.0
70 | scores = np.asarray(scores)
71 | labels = [1]*len(pos) + [0]*len(neg)
72 | labels = np.asarray(labels)
73 | for thres in range(-100, 100, 1):
74 | thres = thres / 100.0
75 | preds = scores > thres
76 | acc = sum(labels == preds)/len(labels)
77 | if acc > best_acc:
78 | best_acc = acc
79 | best_thres = thres
80 | print('easy acc:', best_acc)
81 |
82 |
83 | best_acc = 0.0
84 | best_thres = 0.0
85 | scores_hard = np.asarray(scores_hard)
86 | labels_hard = [1]*len(pos_hard) + [0]*len(neg)
87 | labels_hard = np.asarray(labels_hard)
88 | for thres in range(-100, 100, 1):
89 | thres = thres / 100.0
90 | preds = scores_hard > thres
91 | acc = sum(labels_hard == preds)/len(labels_hard)
92 | if acc > best_acc:
93 | best_acc = acc
94 | best_thres = thres
95 | print('hard acc:', best_acc)
96 |
97 |
--------------------------------------------------------------------------------
/parabart_senteval.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, unicode_literals
2 | import json
3 | import os
4 | import sys
5 | import numpy as np
6 | import logging
7 | import pickle
8 | import torch
9 | import argparse
10 | from transformers import BartTokenizer, BartConfig, BartModel
11 |
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--model_dir', type=str, default="./model/")
15 | parser.add_argument('--cache_dir', type=str, default="./bart-base/")
16 | parser.add_argument('--senteval_dir', type=str, default="../SentEval/")
17 | args = parser.parse_args()
18 |
19 |
20 | # import SentEval
21 | sys.path.insert(0, args.senteval_dir)
22 | import senteval
23 |
24 | sys.path.insert(0, args.model_dir)
25 | from parabart import ParaBart
26 |
27 |
28 | # SentEval prepare and batcher
29 | def prepare(params, samples):
30 | pass
31 |
32 | def batcher(params, batch):
33 | batch = [' '.join(sent) if sent != [] else '.' for sent in batch]
34 | embeddings = build_embeddings(embed_model, tokenizer, batch)
35 | return embeddings
36 |
37 | def build_embeddings(model, tokenizer, sents):
38 | model.eval()
39 | embeddings = torch.ones((len(sents), model.config.d_model))
40 | with torch.no_grad():
41 | for i, sent in enumerate(sents):
42 | sent_inputs = tokenizer(sent, return_tensors="pt")
43 | sent_token_ids = sent_inputs['input_ids']
44 |
45 | sent_embed = model.encoder.embed(sent_token_ids.cuda())
46 | embeddings[i] = sent_embed.detach().cpu().clone()
47 | return embeddings
48 |
49 |
50 | print("==== loading model ====")
51 | config = BartConfig.from_pretrained('facebook/bart-base', cache_dir=args.cache_dir)
52 |
53 | embed_model = ParaBart(config)
54 |
55 | tokenizer = BartTokenizer.from_pretrained('facebook/bart-base', cache_dir=args.cache_dir)
56 |
57 | embed_model.load_state_dict(torch.load(os.path.join(args.model_dir, "model.pt"), map_location='cpu'))
58 |
59 | embed_model = embed_model.cuda()
60 |
61 | # Set params for SentEval
62 | params = {'task_path': os.path.join(args.senteval_dir, 'data'), 'usepytorch': True, 'kfold': 10}
63 | params['classifier'] = {'nhid': 50, 'optim': 'adam', 'batch_size': 64,
64 | 'tenacity': 5, 'epoch_size': 4}
65 |
66 | # Set up logger
67 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)
68 |
69 | if __name__ == "__main__":
70 | se = senteval.engine.SE(params, batcher, prepare)
71 |
72 | transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark',
73 | 'BigramShift', 'Depth', 'TopConstituents']
74 |
75 | results = se.eval(transfer_tasks)
76 | print(results)
77 |
78 |
79 |
--------------------------------------------------------------------------------
/synt_vocab.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uclanlp/ParaBART/09afbc09e565fb72f5c9f98653002e626e2b150b/synt_vocab.pkl
--------------------------------------------------------------------------------
/train_parabart.py:
--------------------------------------------------------------------------------
1 | import os, argparse, pickle, h5py
2 | import pandas as pd
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | from torch.utils.data import DataLoader, random_split
8 |
9 | from utils import Timer, make_path, deleaf
10 | from pprint import pprint
11 | from tqdm import tqdm
12 | from transformers import BartTokenizer, BartConfig, BartModel
13 | from parabart import ParaBart
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('--model_dir', type=str, default="./model/")
17 | parser.add_argument('--cache_dir', type=str, default="./bart-base/")
18 | parser.add_argument('--data_dir', type=str, default="./data/")
19 | parser.add_argument('--max_sent_len', type=int, default=40)
20 | parser.add_argument('--max_synt_len', type=int, default=160)
21 | parser.add_argument('--word_dropout', type=float, default=0.2)
22 | parser.add_argument('--n_epoch', type=int, default=10)
23 | parser.add_argument('--train_batch_size', type=int, default=64)
24 | parser.add_argument('--accumulation_steps', type=int, default=1)
25 | parser.add_argument('--valid_batch_size', type=int, default=16)
26 | parser.add_argument('--lr', type=float, default=2e-5)
27 | parser.add_argument('--fast_lr', type=float, default=1e-4)
28 | parser.add_argument('--weight_decay', type=float, default=1e-2)
29 | parser.add_argument('--log_interval', type=int, default=1000)
30 | parser.add_argument('--temp', type=float, default=0.5)
31 | parser.add_argument('--seed', type=int, default=0)
32 | args = parser.parse_args()
33 | pprint(vars(args))
34 | print()
35 |
36 | # fix random seed
37 | np.random.seed(args.seed)
38 | torch.manual_seed(args.seed)
39 | torch.cuda.manual_seed(args.seed)
40 | torch.backends.cudnn.deterministic = True
41 |
42 | def train(epoch, dataset, model, tokenizer, optimizer, args):
43 | timer = Timer()
44 | n_it = len(train_loader)
45 | optimizer.zero_grad()
46 |
47 | for it, idxs in enumerate(train_loader):
48 | total_loss = 0.0
49 | adv_total_loss = 0.0
50 | model.train()
51 |
52 | sent1_token_ids = dataset['sent1'][idxs].cuda()
53 | synt1_token_ids = dataset['synt1'][idxs].cuda()
54 | sent2_token_ids = dataset['sent2'][idxs].cuda()
55 | synt2_token_ids = dataset['synt2'][idxs].cuda()
56 | synt1_bow = dataset['synt1bow'][idxs].cuda()
57 | synt2_bow = dataset['synt2bow'][idxs].cuda()
58 |
59 | # optimize adv
60 | # sent1 adv
61 | outputs = model.forward_adv(sent1_token_ids)
62 | targs = synt1_bow
63 | loss = adv_criterion(outputs, targs)
64 | loss.backward()
65 | adv_total_loss += loss.item()
66 |
67 |
68 | # sent2 adv
69 | outputs = model.forward_adv(sent2_token_ids)
70 | targs = synt2_bow
71 | loss = adv_criterion(outputs, targs)
72 | loss.backward()
73 | adv_total_loss += loss.item()
74 |
75 | if (it+1) % args.accumulation_steps == 0:
76 | nn.utils.clip_grad_norm_(model.parameters(), 1.0)
77 | if epoch > 1:
78 | adv_optimizer.step()
79 | adv_optimizer.zero_grad()
80 |
81 | # optimize model
82 | # sent1->sent2 para & sent1 adv
83 | outputs, adv_outputs = model(torch.cat((sent1_token_ids, synt2_token_ids), 1), sent2_token_ids)
84 | targs = sent2_token_ids[:, 1:].contiguous().view(-1)
85 | outputs = outputs.contiguous().view(-1, outputs.size(-1))
86 | adv_targs = synt1_bow
87 | loss = para_criterion(outputs, targs)
88 | if epoch > 1:
89 | loss -= 0.1 * adv_criterion(adv_outputs, adv_targs)
90 | loss.backward()
91 | total_loss += loss.item()
92 |
93 | # sent2->sent1 para & sent2 adv
94 | outputs, adv_outputs = model(torch.cat((sent2_token_ids, synt1_token_ids), 1), sent1_token_ids)
95 | targs = sent1_token_ids[:, 1:].contiguous().view(-1)
96 | outputs = outputs.contiguous().view(-1, outputs.size(-1))
97 | adv_targs = synt2_bow
98 | loss = para_criterion(outputs, targs)
99 | if epoch > 1:
100 | loss -= 0.1 * adv_criterion(adv_outputs, adv_targs)
101 | loss.backward()
102 | total_loss += loss.item()
103 |
104 |
105 | if (it+1) % args.accumulation_steps == 0:
106 | nn.utils.clip_grad_norm_(model.parameters(), 1.0)
107 | optimizer.step()
108 | optimizer.zero_grad()
109 |
110 | if (it+1) % args.log_interval == 0 or it == 0:
111 | para_1_2_loss, para_2_1_loss, adv_1_loss, adv_2_loss = evaluate(model, tokenizer, args)
112 | valid_loss = para_1_2_loss + para_2_1_loss - 0.1 * adv_1_loss - 0.1 * adv_2_loss
113 | print("| ep {:2d}/{} | it {:3d}/{} | {:5.2f} s | adv loss {:.4f} | loss {:.4f} | para 1-2 loss {:.4f} | para 2-1 loss {:.4f} | adv 1 loss {:.4f} | adv 2 loss {:.4f} | valid loss {:.4f} |".format(
114 | epoch, args.n_epoch, it+1, n_it, timer.get_time_from_last(), adv_total_loss, total_loss, para_1_2_loss, para_2_1_loss, adv_1_loss, adv_2_loss, valid_loss))
115 |
116 |
117 |
118 | def evaluate(model, tokenizer, args):
119 | model.eval()
120 | para_1_2_loss = 0.0
121 | para_2_1_loss = 0.0
122 | adv_1_loss = 0.0
123 | adv_2_loss = 0.0
124 | with torch.no_grad():
125 | for idxs in valid_loader:
126 |
127 | sent1_token_ids = dataset['sent1'][idxs].cuda()
128 | synt1_token_ids = dataset['synt1'][idxs].cuda()
129 | sent2_token_ids = dataset['sent2'][idxs].cuda()
130 | synt2_token_ids = dataset['synt2'][idxs].cuda()
131 | synt1_bow = dataset['synt1bow'][idxs].cuda()
132 | synt2_bow = dataset['synt2bow'][idxs].cuda()
133 |
134 | outputs, adv_outputs = model(torch.cat((sent1_token_ids, synt2_token_ids), 1), sent2_token_ids)
135 | targs = sent2_token_ids[:, 1:].contiguous().view(-1)
136 | outputs = outputs.contiguous().view(-1, outputs.size(-1))
137 | adv_targs = synt1_bow
138 | para_1_2_loss += para_criterion(outputs, targs)
139 | adv_1_loss += adv_criterion(adv_outputs, adv_targs)
140 |
141 | outputs, adv_outputs = model(torch.cat((sent2_token_ids, synt1_token_ids), 1), sent1_token_ids)
142 | targs = sent1_token_ids[:, 1:].contiguous().view(-1)
143 | outputs = outputs.contiguous().view(-1, outputs.size(-1))
144 | adv_targs = synt2_bow
145 | para_2_1_loss += para_criterion(outputs, targs)
146 | adv_2_loss += adv_criterion(adv_outputs, adv_targs)
147 |
148 | return para_1_2_loss / len(valid_loader), para_2_1_loss / len(valid_loader), adv_1_loss / len(valid_loader), adv_2_loss / len(valid_loader)
149 |
150 |
151 | def prepare_dataset(para_data, tokenizer, num):
152 | sents1 = list(para_data['train_sents1'][:num])
153 | synts1 = list(para_data['train_synts1'][:num])
154 | sents2 = list(para_data['train_sents2'][:num])
155 | synts2 = list(para_data['train_synts2'][:num])
156 |
157 | sent1_token_ids = torch.ones((num, args.max_sent_len+2), dtype=torch.long)
158 | sent2_token_ids = torch.ones((num, args.max_sent_len+2), dtype=torch.long)
159 | synt1_token_ids = torch.ones((num, args.max_synt_len+2), dtype=torch.long)
160 | synt2_token_ids = torch.ones((num, args.max_synt_len+2), dtype=torch.long)
161 | synt1_bow = torch.ones((num, 74))
162 | synt2_bow = torch.ones((num, 74))
163 |
164 | bsz = 64
165 |
166 | for i in tqdm(range(0, num, bsz)):
167 | sent1_inputs = tokenizer(sents1[i:i+bsz], padding='max_length', truncation=True, max_length=args.max_sent_len+2, return_tensors="pt")
168 | sent2_inputs = tokenizer(sents2[i:i+bsz], padding='max_length', truncation=True, max_length=args.max_sent_len+2, return_tensors="pt")
169 | sent1_token_ids[i:i+bsz] = sent1_inputs['input_ids']
170 | sent2_token_ids[i:i+bsz] = sent2_inputs['input_ids']
171 |
172 | for i in tqdm(range(num)):
173 |         synt1 = ['<s>'] + deleaf(synts1[i]) + ['</s>']
174 | synt1_token_ids[i, :len(synt1)] = torch.tensor([synt_vocab[tag] for tag in synt1])[:args.max_synt_len+2]
175 |         synt2 = ['<s>'] + deleaf(synts2[i]) + ['</s>']
176 | synt2_token_ids[i, :len(synt2)] = torch.tensor([synt_vocab[tag] for tag in synt2])[:args.max_synt_len+2]
177 |
178 | for tag in synt1:
179 |             if tag != '<s>' and tag != '</s>':
180 | synt1_bow[i][synt_vocab[tag]-3] += 1
181 | for tag in synt2:
182 |             if tag != '<s>' and tag != '</s>':
183 | synt2_bow[i][synt_vocab[tag]-3] += 1
184 |
185 | synt1_bow /= synt1_bow.sum(1, keepdim=True)
186 | synt2_bow /= synt2_bow.sum(1, keepdim=True)
187 |
188 | sum = 0
189 | for i in range(num):
190 | if torch.equal(synt1_bow[i], synt2_bow[i]):
191 | sum += 1
192 |
193 | return {'sent1':sent1_token_ids, 'sent2':sent2_token_ids, 'synt1': synt1_token_ids, 'synt2': synt2_token_ids,
194 | 'synt1bow': synt1_bow, 'synt2bow': synt2_bow}
195 |
196 | print("==== loading data ====")
197 | num = 1000000
198 | para_data = h5py.File(os.path.join(args.data_dir, 'data.h5'), 'r')
199 |
200 | train_idxs, valid_idxs = random_split(range(num), [num-5000, 5000], generator=torch.Generator().manual_seed(args.seed))
201 |
202 | print(f"number of train examples: {len(train_idxs)}")
203 | print(f"number of valid examples: {len(valid_idxs)}")
204 |
205 | train_loader = DataLoader(train_idxs, batch_size=args.train_batch_size, shuffle=True)
206 | valid_loader = DataLoader(valid_idxs, batch_size=args.valid_batch_size, shuffle=False)
207 |
208 | print("==== preparing data ====")
209 | make_path(args.cache_dir)
210 | tokenizer = BartTokenizer.from_pretrained('facebook/bart-base', cache_dir=args.cache_dir)
211 |
212 | with open('synt_vocab.pkl', 'rb') as f:
213 | synt_vocab = pickle.load(f)
214 |
215 | dataset = prepare_dataset(para_data, tokenizer, num)
216 |
217 | print("==== loading model ====")
218 | config = BartConfig.from_pretrained('facebook/bart-base', cache_dir=args.cache_dir)
219 | config.word_dropout = args.word_dropout
220 | config.max_sent_len = args.max_sent_len
221 | config.max_synt_len = args.max_synt_len
222 |
223 | bart = BartModel.from_pretrained('facebook/bart-base', cache_dir=args.cache_dir)
224 | model = ParaBart(config)
225 | model.load_state_dict(bart.state_dict(), strict=False)
226 | model.zero_grad()
227 | del bart
228 |
229 |
230 | no_decay_params = []
231 | no_decay_fast_params = []
232 | fast_params = []
233 | all_other_params = []
234 | adv_no_decay_params = []
235 | adv_all_other_params = []
236 |
237 | for n, p in model.named_parameters():
238 | if 'adv' in n:
239 | if 'norm' in n or 'bias' in n:
240 | adv_no_decay_params.append(p)
241 | else:
242 | adv_all_other_params.append(p)
243 | elif 'linear' in n or 'synt' in n or 'decoder' in n:
244 | if 'bias' in n:
245 | no_decay_fast_params.append(p)
246 | else:
247 | fast_params.append(p)
248 | elif 'norm' in n or 'bias' in n:
249 | no_decay_params.append(p)
250 | else:
251 | all_other_params.append(p)
252 |
253 | optimizer = optim.AdamW([
254 | {'params': fast_params, 'lr': args.fast_lr},
255 | {'params': no_decay_fast_params, 'lr': args.fast_lr, 'weight_decay': 0.0},
256 | {'params': no_decay_params, 'weight_decay': 0.0},
257 | {'params': all_other_params}
258 | ], lr=args.lr, weight_decay=args.weight_decay)
259 |
260 | adv_optimizer = optim.AdamW([
261 | {'params': adv_no_decay_params, 'weight_decay': 0.0},
262 | {'params': adv_all_other_params}
263 | ], lr=args.lr, weight_decay=args.weight_decay)
264 |
265 | para_criterion = nn.CrossEntropyLoss(ignore_index=model.config.pad_token_id).cuda()
266 | adv_criterion = nn.BCEWithLogitsLoss().cuda()
267 |
268 | model = model.cuda()
269 |
270 | make_path(args.model_dir)
271 |
272 | print("==== start training ====")
273 |
274 | for epoch in range(1, args.n_epoch+1):
275 | train(epoch, dataset, model, tokenizer, optimizer, args)
276 | torch.save(model.state_dict(), os.path.join(args.model_dir, "model_epoch{:02d}.pt".format(epoch)))
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os, errno
2 | import numpy as np
3 | from datetime import datetime
4 |
5 |
6 | def make_path(path):
7 | try:
8 | os.makedirs(path)
9 | except OSError as exc:
10 | if exc.errno == errno.EEXIST and os.path.isdir(path):
11 | pass
12 | else: raise
13 |
14 | class Timer:
15 | def __init__(self):
16 | self.start_time = datetime.now()
17 | self.last_time = self.start_time
18 |
19 | def get_time_from_last(self, update=True):
20 | now_time = datetime.now()
21 | diff_time = now_time - self.last_time
22 | if update:
23 | self.last_time = now_time
24 | return diff_time.total_seconds()
25 |
26 | def get_time_from_start(self, update=True):
27 | now_time = datetime.now()
28 | diff_time = now_time - self.start_time
29 | if update:
30 | self.last_time = now_time
31 | return diff_time.total_seconds()
32 |
33 |
34 | def is_paren(tok):
35 | return tok == ")" or tok == "("
36 |
37 | def deleaf(tree):
38 | nonleaves = ''
39 | for w in tree.replace('\n', '').split():
40 | w = w.replace('(', '( ').replace(')', ' )')
41 | nonleaves += w + ' '
42 |
43 | arr = nonleaves.split()
44 | for n, i in enumerate(arr):
45 | if n + 1 < len(arr):
46 | tok1 = arr[n]
47 | tok2 = arr[n + 1]
48 | if not is_paren(tok1) and not is_paren(tok2):
49 | arr[n + 1] = ""
50 |
51 | nonleaves = " ".join(arr)
52 | return nonleaves.split()
--------------------------------------------------------------------------------