├── .gitignore ├── ANLI ├── MI_estimators.py ├── README.md ├── adv_trainer.py ├── advtraining_args.py ├── datasets │ └── anli.py ├── download_glue_data.py ├── eval_anli_local.py ├── jutils.py ├── local_robust_trainer.py ├── models │ ├── bert.py │ ├── modeling_auto.py │ └── roberta.py ├── print_table.py ├── processors │ └── anli.py ├── run_anli.py ├── run_glue.py ├── setup.sh └── trainer.py ├── README.md ├── SQuAD ├── MI_estimators.py ├── README.md ├── eval_adv_squad.py ├── jutils.py ├── models │ ├── bert.py │ ├── modeling_auto.py │ └── roberta.py ├── processors │ └── squad.py ├── run_squad.py ├── run_squad_standard.py └── setup.sh └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea/ -------------------------------------------------------------------------------- /ANLI/MI_estimators.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | from jutils import * 6 | 7 | ## cubic 8 | # lowersize = 40 9 | # hiddensize = 6 10 | 11 | ## Gaussian 12 | # lowersize = 20 13 | # hiddensize = 8 14 | 15 | ## club vs l1out 16 | lowersize = 40 17 | hiddensize = 8 18 | 19 | 20 | class CLUB(nn.Module): # CLUB: Mutual Information Contrastive Learning Upper Bound 21 | def __init__(self, x_dim, y_dim, lr=1e-3, beta=0): 22 | super(CLUB, self).__init__() 23 | self.hiddensize = y_dim 24 | self.version = 0 25 | self.p_mu = nn.Sequential(nn.Linear(x_dim, self.hiddensize), 26 | nn.ReLU(), 27 | nn.Linear(self.hiddensize, y_dim)) 28 | 29 | self.p_logvar = nn.Sequential(nn.Linear(x_dim, self.hiddensize), 30 | nn.ReLU(), 31 | nn.Linear(self.hiddensize, y_dim), 32 | nn.Tanh()) 33 | 34 | self.optimizer = torch.optim.Adam(self.parameters(), lr) 35 | self.beta = beta 36 | 37 | def get_mu_logvar(self, x_samples): 38 | mu = self.p_mu(x_samples) 39 | logvar = self.p_logvar(x_samples) 40 | return mu, logvar 41 | 42 | 
def mi_est_sample(self, x_samples, y_samples): 43 | mu, logvar = self.get_mu_logvar(x_samples) 44 | 45 | sample_size = x_samples.shape[0] 46 | random_index = torch.randint(sample_size, (sample_size,)).long() 47 | 48 | positive = - (mu - y_samples) ** 2 / 2. / logvar.exp() 49 | negative = - (mu - y_samples[random_index]) ** 2 / 2. / logvar.exp() 50 | upper_bound = (positive.sum(dim=-1) - negative.sum(dim=-1)).mean() 51 | # return upper_bound/2. 52 | return upper_bound 53 | 54 | def mi_est(self, x_samples, y_samples): # [nsample, 1] 55 | mu, logvar = self.get_mu_logvar(x_samples) 56 | 57 | positive = - (mu - y_samples) ** 2 / 2. / logvar.exp() 58 | 59 | prediction_1 = mu.unsqueeze(1) # [nsample,1,dim] 60 | y_samples_1 = y_samples.unsqueeze(0) # [1,nsample,dim] 61 | negative = - ((y_samples_1 - prediction_1) ** 2).mean(dim=1) / 2. / logvar.exp() # [nsample, dim] 62 | return (positive.sum(dim=-1) - negative.sum(dim=-1)).mean() 63 | # return (positive.sum(dim = -1) - negative.sum(dim = -1)).mean(), positive.sum(dim = -1).mean(), negative.sum(dim = -1).mean() 64 | 65 | def loglikeli(self, x_samples, y_samples): 66 | mu, logvar = self.get_mu_logvar(x_samples) 67 | 68 | # return -1./2. 
* ((mu - y_samples)**2 /logvar.exp()-logvar ).sum(dim=1).mean(dim=0) 69 | return (-(mu - y_samples) ** 2 / logvar.exp() - logvar).sum(dim=1).mean(dim=0) 70 | 71 | def update(self, x_samples, y_samples): 72 | if self.version == 0: 73 | self.train() 74 | loss = - self.loglikeli(x_samples, y_samples) 75 | 76 | self.optimizer.zero_grad() 77 | loss.backward(retain_graph=True) 78 | self.optimizer.step() 79 | 80 | # self.eval() 81 | return self.mi_est_sample(x_samples, y_samples) * self.beta 82 | 83 | elif self.version == 1: 84 | self.train() 85 | x_samples = torch.reshape(x_samples, (-1, x_samples.shape[-1])) 86 | y_samples = torch.reshape(y_samples, (-1, y_samples.shape[-1])) 87 | 88 | loss = -self.loglikeli(x_samples, y_samples) 89 | 90 | self.optimizer.zero_grad() 91 | loss.backward(retain_graph=True) 92 | self.optimizer.step() 93 | upper_bound = self.mi_est_sample(x_samples, y_samples) * self.beta 94 | # self.eval() 95 | return upper_bound 96 | 97 | 98 | class CLUBv2(nn.Module): # CLUB: Mutual Information Contrastive Learning Upper Bound 99 | def __init__(self, x_dim, y_dim, lr=1e-3, beta=0): 100 | super(CLUBv2, self).__init__() 101 | self.hiddensize = y_dim 102 | self.version = 2 103 | self.beta = beta 104 | 105 | def mi_est_sample(self, x_samples, y_samples): 106 | sample_size = y_samples.shape[0] 107 | random_index = torch.randint(sample_size, (sample_size,)).long() 108 | 109 | positive = torch.zeros_like(y_samples) 110 | negative = - (y_samples - y_samples[random_index]) ** 2 / 2. 111 | upper_bound = (positive.sum(dim=-1) - negative.sum(dim=-1)).mean() 112 | # return upper_bound/2. 113 | return upper_bound 114 | 115 | def mi_est(self, x_samples, y_samples): # [nsample, 1] 116 | positive = torch.zeros_like(y_samples) 117 | 118 | prediction_1 = y_samples.unsqueeze(1) # [nsample,1,dim] 119 | y_samples_1 = y_samples.unsqueeze(0) # [1,nsample,dim] 120 | negative = - ((y_samples_1 - prediction_1) ** 2).mean(dim=1) / 2. 
# [nsample, dim] 121 | return (positive.sum(dim=-1) - negative.sum(dim=-1)).mean() 122 | # return (positive.sum(dim = -1) - negative.sum(dim = -1)).mean(), positive.sum(dim = -1).mean(), negative.sum(dim = -1).mean() 123 | 124 | def loglikeli(self, x_samples, y_samples): 125 | return 0 126 | 127 | def update(self, x_samples, y_samples, steps=None): 128 | # no performance improvement, not enabled 129 | if steps: 130 | beta = self.beta if steps > 1000 else self.beta * steps / 1000 # beta anealing 131 | else: 132 | beta = self.beta 133 | 134 | return self.mi_est_sample(x_samples, y_samples) * self.beta 135 | 136 | 137 | 138 | class MINE(nn.Module): 139 | def __init__(self, x_dim, y_dim): 140 | super(MINE, self).__init__() 141 | self.T_func = nn.Sequential(nn.Linear(x_dim + y_dim, lowersize), 142 | nn.ReLU(), 143 | nn.Linear(lowersize, 1)) 144 | 145 | def mi_est(self, x_samples, y_samples): # samples have shape [sample_size, dim] 146 | # shuffle and concatenate 147 | sample_size = y_samples.shape[0] 148 | random_index = torch.randint(sample_size, (sample_size,)).long() 149 | 150 | y_shuffle = y_samples[random_index] 151 | 152 | T0 = self.T_func(torch.cat([x_samples, y_samples], dim=-1)) 153 | T1 = self.T_func(torch.cat([x_samples, y_shuffle], dim=-1)) 154 | 155 | # lower_bound = T0.mean() - torch.log(T1.exp().mean()) 156 | lower_bound = T0.mean() - (torch.logsumexp(T1, dim=0).mean() - np.log(sample_size)) 157 | 158 | # compute the negative loss (maximise loss == minimise -loss) 159 | return lower_bound 160 | 161 | 162 | class NWJ(nn.Module): 163 | def __init__(self, x_dim, y_dim): 164 | super(NWJ, self).__init__() 165 | self.F_func = nn.Sequential(nn.Linear(x_dim + y_dim, lowersize), 166 | nn.ReLU(), 167 | nn.Linear(lowersize, 1)) 168 | 169 | def mi_est(self, x_samples, y_samples): # samples have shape [sample_size, dim] 170 | # shuffle and concatenate 171 | sample_size = y_samples.shape[0] 172 | # random_index = torch.randint(sample_size, (sample_size,)).long() 173 | 
174 | x_tile = x_samples.unsqueeze(0).repeat((sample_size, 1, 1)) 175 | y_tile = y_samples.unsqueeze(1).repeat((1, sample_size, 1)) 176 | 177 | T0 = self.F_func(torch.cat([x_samples, y_samples], dim=-1)) 178 | T1 = self.F_func(torch.cat([x_tile, y_tile], dim=-1)) - 1. # [s_size, s_size, 1] 179 | 180 | lower_bound = T0.mean() - (T1.logsumexp(dim=1) - np.log(sample_size)).exp().mean() 181 | return lower_bound 182 | 183 | 184 | class InfoNCE(nn.Module): 185 | def __init__(self, x_dim, y_dim): 186 | super(InfoNCE, self).__init__() 187 | self.lower_size = 300 188 | self.F_func = nn.Sequential(nn.Linear(x_dim + y_dim, self.lower_size), 189 | nn.ReLU(), 190 | nn.Linear(self.lower_size, 1), 191 | nn.Softplus()) 192 | 193 | def forward(self, x_samples, y_samples): # samples have shape [sample_size, dim] 194 | # shuffle and concatenate 195 | sample_size = y_samples.shape[0] 196 | random_index = torch.randint(sample_size, (sample_size,)).long() 197 | 198 | x_tile = x_samples.unsqueeze(0).repeat((sample_size, 1, 1)) 199 | y_tile = y_samples.unsqueeze(1).repeat((1, sample_size, 1)) 200 | 201 | T0 = self.F_func(torch.cat([x_samples, y_samples], dim=-1)) 202 | T1 = self.F_func(torch.cat([x_tile, y_tile], dim=-1)) # [s_size, s_size, 1] 203 | 204 | lower_bound = T0.mean() - ( 205 | T1.logsumexp(dim=1).mean() - np.log(sample_size)) # torch.log(T1.exp().mean(dim = 1)).mean() 206 | 207 | # compute the negative loss (maximise loss == minimise -loss) 208 | return lower_bound 209 | 210 | 211 | class L1OutUB(nn.Module): # naive upper bound 212 | def __init__(self, x_dim, y_dim): 213 | super(L1OutUB, self).__init__() 214 | self.p_mu = nn.Sequential(nn.Linear(x_dim, hiddensize), 215 | nn.ReLU(), 216 | nn.Linear(hiddensize, y_dim)) 217 | 218 | self.p_logvar = nn.Sequential(nn.Linear(x_dim, hiddensize), 219 | nn.ReLU(), 220 | nn.Linear(hiddensize, y_dim), 221 | nn.Tanh()) 222 | 223 | def get_mu_logvar(self, x_samples): 224 | mu = self.p_mu(x_samples) 225 | logvar = self.p_logvar(x_samples) 
226 | return mu, logvar 227 | 228 | def mi_est(self, x_samples, y_samples): # [nsample, 1] 229 | batch_size = y_samples.shape[0] 230 | mu, logvar = self.get_mu_logvar(x_samples) 231 | 232 | positive = (- (mu - y_samples) ** 2 / 2. / logvar.exp() - logvar / 2.).sum(dim=-1) # [nsample] 233 | 234 | mu_1 = mu.unsqueeze(1) # [nsample,1,dim] 235 | logvar_1 = logvar.unsqueeze(1) 236 | y_samples_1 = y_samples.unsqueeze(0) # [1,nsample,dim] 237 | all_probs = (- (y_samples_1 - mu_1) ** 2 / 2. / logvar_1.exp() - logvar_1 / 2.).sum( 238 | dim=-1) # [nsample, nsample] 239 | 240 | # diag_mask = torch.ones([batch_size, batch_size,1]).cuda() - torch.ones([batch_size]).diag().unsqueeze(-1).cuda() 241 | diag_mask = torch.ones([batch_size]).diag().unsqueeze(-1).cuda() * (-20.) 242 | 243 | # negative = (all_probs + diag_mask).logsumexp(dim = 0) - np.log(y_samples.shape[0]-1.) #[nsample] 244 | inpt = all_probs + diag_mask 245 | negative = log_sum_exp(all_probs + diag_mask, dim=0) - np.log(y_samples.shape[0] - 1.) # [nsample] 246 | return (positive - negative).mean() 247 | 248 | def loglikeli(self, x_samples, y_samples): 249 | mu, logvar = self.get_mu_logvar(x_samples) 250 | # return -1./2. 
* ((mu - y_samples)**2 /logvar.exp()-logvar ).sum(dim=1).mean(dim=0) 251 | return (-(mu - y_samples) ** 2 / logvar.exp() - logvar).sum(dim=1).mean(dim=0) 252 | 253 | 254 | class VarUB(nn.Module): # variational upper bound 255 | def __init__(self, x_dim, y_dim): 256 | super(VarUB, self).__init__() 257 | self.p_mu = nn.Sequential(nn.Linear(x_dim, hiddensize), 258 | nn.ReLU(), 259 | nn.Linear(hiddensize, y_dim)) 260 | 261 | self.p_logvar = nn.Sequential(nn.Linear(x_dim, hiddensize), 262 | nn.ReLU(), 263 | nn.Linear(hiddensize, y_dim), 264 | nn.Tanh()) 265 | 266 | def get_mu_logvar(self, x_samples): 267 | mu = self.p_mu(x_samples) 268 | logvar = self.p_logvar(x_samples) 269 | return mu, logvar 270 | 271 | def mi_est(self, x_samples, y_samples): # [nsample, 1] 272 | mu, logvar = self.get_mu_logvar(x_samples) 273 | return 1. / 2. * (mu ** 2 + logvar.exp() - 1. - logvar).mean() 274 | 275 | def loglikeli(self, x_samples, y_samples): 276 | mu, logvar = self.get_mu_logvar(x_samples) 277 | # return -1./2. * ((mu - y_samples)**2 /logvar.exp()-logvar ).sum(dim=1).mean(dim=0) 278 | return (-(mu - y_samples) ** 2 / logvar.exp() - logvar).sum(dim=1).mean(dim=0) 279 | -------------------------------------------------------------------------------- /ANLI/README.md: -------------------------------------------------------------------------------- 1 | # InfoBERT on ANLI 2 | 3 | ## Prepare your data 4 | Download the [training data and test data](https://drive.google.com/file/d/1xWwABFkzJ6fEnR1f3xr-vkMesxdO7IZm/view?usp=sharing) into `anli_data`. 
5 | The training data includes: 6 | 7 | - ANLI(train) 8 | - MNLI 9 | - SNLI 10 | - FeverNLI 11 | 12 | The testing data includes: 13 | - ANLI(dev/test) 14 | - MNLI 15 | - SNLI 16 | - adversarial MNLI datasets generated by TextFooler against BERT and RoBERTa (`infobert_mnli_matched`, `infobert_mnli_mismatched`, `roberta_mnli_matched`, `roberta_mnli_mismatched`) 17 | - adversarial SNLI datasets generated by TextFooler against BERT and RoBERTa (`infobert_snli`, `roberta_snli`) 18 | 19 | After unzipping the data, the directory structure should look like this: 20 | ``` 21 | anli_data 22 | ├── MNLI 23 | ├── SNLI 24 | ├── infobert_mnli_matched 25 | ├── infobert_mnli_mismatched 26 | ├── infobert_snli 27 | ├── nli_fever 28 | ├── roberta_mnli_matched 29 | ├── roberta_mnli_mismatched 30 | └── roberta_snli 31 | 32 | ``` 33 | 34 | Note that the ANLI dataset is downloaded automatically by the `nlp` package during training/testing. 35 | 36 | ## Train 37 | 38 | Running example: 39 | ```bash 40 | # [task-name] [custom-name] [base-model] [lr] [batch-size] [max-seq-len] [training-steps] [warmup steps] [seed] [weight-decay] [beta] [mi_version] [hdp] [adp] [adv lr] [adv mag] [anorm] [asteps] [alpha] [cl] [ch] 41 | source setup.sh && runexp anli-full infobert roberta-large 2e-5 32 128 -1 1000 42 1e-5 5e-3 6 0.1 0 4e-2 8e-2 0 3 5e-3 0.5 0.9 42 | ``` 43 | 44 | A detailed explanation of each parameter can be found in `setup.sh`. 45 | Note: 46 | - If the task name is `anli-full`, the model is trained on both the adversarial and benign datasets (ANLI(train) + SNLI + MNLI + FeverNLI). 47 | - If the task name is `anli-part`, the model is trained only on the benign datasets (SNLI + MNLI). 48 | - The default MI estimator version is 6, which means both regularizers are enabled. 49 | - By default, `setup.sh` uses 1 GPU for training. If multiple GPUs are available, modify `setup.sh` and set `--nproc_per_node=[ngpus]` accordingly.
50 | - Our script can stop at any time and automatically resume from exisitng checkpoints. 51 | 52 | ## Evaluate 53 | We provide our [InfoBERT checkpoint](https://drive.google.com/file/d/1RljMLLuIUD4lFwjciX7gxBC7RxvlGMDy/view?usp=sharing) that achieves the state-of-the-art performance on the ANLI test dataset. 54 | 55 | To evaluate your model from a checkpoint 56 | ```bash 57 | source setup.sh && evalexp [results-dir] [checkpoint-dir] 58 | ``` 59 | 60 | For example, to evaluate our checkpoint performance 61 | ```bash 62 | source setup.sh && evalexp best-infobert infobert-checkpoint 63 | ``` 64 | 65 | The output should be like 66 | 67 | ``` 68 | 01/13/2021 12:37:05 - INFO - local_robust_trainer - ***** Running Evaluation ***** 69 | 01/13/2021 12:37:05 - INFO - local_robust_trainer - Num examples = 3200 70 | 01/13/2021 12:37:05 - INFO - local_robust_trainer - Batch size = 8 71 | Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:17<00:00, 22.72it/s] 72 | {"eval_loss": 1.4944478544429876, "eval_acc": 0.5825, "eval_mi_info": 40.253395080566406, "step": null} 73 | 01/13/2021 12:37:23 - INFO - **main** - ***** Eval results anli-full-dev ***** 74 | 01/13/2021 12:37:23 - INFO - **main** - eval_loss = 1.4944478544429876 75 | 01/13/2021 12:37:23 - INFO - **main** - eval_acc = 0.5825 76 | 01/13/2021 12:37:23 - INFO - **main** - eval_mi_info = 40.253395080566406 77 | 01/13/2021 12:37:23 - INFO - local_robust_trainer - ***** Running Evaluation ***** 78 | 01/13/2021 12:37:23 - INFO - local_robust_trainer - Num examples = 3200 79 | 01/13/2021 12:37:23 - INFO - local_robust_trainer - Batch size = 8 80 | Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:17<00:00, 23.19it/s] 81 | {"eval_loss": 1.5127741410769522, 
"eval_acc": 0.583125, "eval_mi_info": 40.0756950378418, "step": null} 82 | 01/13/2021 12:37:40 - INFO - **main** - ***** Eval results anli-full-test ***** 83 | 01/13/2021 12:37:40 - INFO - **main** - eval_loss = 1.5127741410769522 84 | 01/13/2021 12:37:40 - INFO - **main** - eval_acc = 0.583125 85 | 01/13/2021 12:37:40 - INFO - **main** - eval_mi_info = 40.0756950378418 86 | 01/13/2021 12:37:40 - INFO - local_robust_trainer - ***** Running Evaluation ***** 87 | 01/13/2021 12:37:40 - INFO - local_robust_trainer - Num examples = 1000 88 | 01/13/2021 12:37:40 - INFO - local_robust_trainer - Batch size = 8 89 | Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 125/125 [00:05<00:00, 22.57it/s] 90 | {"eval_loss": 0.8243400501683354, "eval_acc": 0.764, "eval_mi_info": 41.68259048461914, "step": null} 91 | 01/13/2021 12:37:46 - INFO - **main** - ***** Eval results anli-r1-dev ***** 92 | 01/13/2021 12:37:46 - INFO - **main** - eval_loss = 0.8243400501683354 93 | 01/13/2021 12:37:46 - INFO - **main** - eval_acc = 0.764 94 | 01/13/2021 12:37:46 - INFO - **main** - eval_mi_info = 41.68259048461914 95 | 01/13/2021 12:37:46 - INFO - local_robust_trainer - ***** Running Evaluation ***** 96 | 01/13/2021 12:37:46 - INFO - local_robust_trainer - Num examples = 1000 97 | 01/13/2021 12:37:46 - INFO - local_robust_trainer - Batch size = 8 98 | Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 125/125 [00:05<00:00, 22.62it/s] 99 | {"eval_loss": 0.8654105467200279, "eval_acc": 0.755, "eval_mi_info": 41.359413146972656, "step": null} 100 | 01/13/2021 12:37:51 - INFO - **main** - ***** Eval results anli-r1-test ***** 101 | 01/13/2021 12:37:51 - INFO - **main** - eval_loss = 0.8654105467200279 102 | 01/13/2021 12:37:51 - INFO - **main** - 
eval_acc = 0.755 103 | 01/13/2021 12:37:51 - INFO - **main** - eval_mi_info = 41.359413146972656 104 | 01/13/2021 12:37:51 - INFO - local_robust_trainer - ***** Running Evaluation ***** 105 | 01/13/2021 12:37:51 - INFO - local_robust_trainer - Num examples = 1000 106 | 01/13/2021 12:37:51 - INFO - local_robust_trainer - Batch size = 8 107 | Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 125/125 [00:05<00:00, 22.82it/s] 108 | {"eval_loss": 1.6665315675735473, "eval_acc": 0.517, "eval_mi_info": 41.185367584228516, "step": null} 109 | 01/13/2021 12:37:57 - INFO - **main** - ***** Eval results anli-r2-dev ***** 110 | 01/13/2021 12:37:57 - INFO - **main** - eval_loss = 1.6665315675735473 111 | 01/13/2021 12:37:57 - INFO - **main** - eval_acc = 0.517 112 | 01/13/2021 12:37:57 - INFO - **main** - eval_mi_info = 41.185367584228516 113 | 01/13/2021 12:37:57 - INFO - local_robust_trainer - ***** Running Evaluation ***** 114 | 01/13/2021 12:37:57 - INFO - local_robust_trainer - Num examples = 1000 115 | 01/13/2021 12:37:57 - INFO - local_robust_trainer - Batch size = 8 116 | Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 125/125 [00:05<00:00, 22.74it/s] 117 | {"eval_loss": 1.682003709077835, "eval_acc": 0.514, "eval_mi_info": 41.24909210205078, "step": null} 118 | 01/13/2021 12:38:02 - INFO - **main** - ***** Eval results anli-r2-test ***** 119 | 01/13/2021 12:38:02 - INFO - **main** - eval_loss = 1.682003709077835 120 | 01/13/2021 12:38:02 - INFO - **main** - eval_acc = 0.514 121 | 01/13/2021 12:38:02 - INFO - **main** - eval_mi_info = 41.24909210205078 122 | 01/13/2021 12:38:02 - INFO - local_robust_trainer - ***** Running Evaluation ***** 123 | 01/13/2021 12:38:02 - INFO - local_robust_trainer - Num examples = 1200 124 
| 01/13/2021 12:38:02 - INFO - local_robust_trainer - Batch size = 8 125 | Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 150/150 [00:06<00:00, 23.77it/s] 126 | {"eval_loss": 1.909467930396398, "eval_acc": 0.48583333333333334, "eval_mi_info": 38.28575897216797, "step": null} 127 | 01/13/2021 12:38:08 - INFO - **main** - ***** Eval results anli-r3-dev ***** 128 | 01/13/2021 12:38:08 - INFO - **main** - eval_loss = 1.909467930396398 129 | 01/13/2021 12:38:08 - INFO - **main** - eval_acc = 0.48583333333333334 130 | 01/13/2021 12:38:08 - INFO - **main** - eval_mi_info = 38.28575897216797 131 | 01/13/2021 12:38:08 - INFO - local_robust_trainer - ***** Running Evaluation ***** 132 | 01/13/2021 12:38:08 - INFO - local_robust_trainer - Num examples = 1200 133 | 01/13/2021 12:38:08 - INFO - local_robust_trainer - Batch size = 8 134 | Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 150/150 [00:06<00:00, 23.85it/s] 135 | {"eval_loss": 1.9112191630403201, "eval_acc": 0.4975, "eval_mi_info": 38.02809524536133, "step": null} 136 | 01/13/2021 12:38:15 - INFO - **main** - ***** Eval results anli-r3-test ***** 137 | 01/13/2021 12:38:15 - INFO - **main** - eval_loss = 1.9112191630403201 138 | 01/13/2021 12:38:15 - INFO - **main** - eval_acc = 0.4975 139 | 01/13/2021 12:38:15 - INFO - **main** - eval_mi_info = 38.02809524536133 140 | 01/13/2021 12:38:15 - INFO - local_robust_trainer - ***** Running Evaluation ***** 141 | 01/13/2021 12:38:15 - INFO - local_robust_trainer - Num examples = 9815 142 | 01/13/2021 12:38:15 - INFO - local_robust_trainer - Batch size = 8 143 | Evaluation: 
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1227/1227 [00:41<00:00, 29.85it/s] 144 | {"eval_loss": 0.3075757020476829, "eval_acc": 0.9108507386653082, "eval_mi_info": 33.53987121582031, "step": null} 145 | 01/13/2021 12:38:56 - INFO - **main** - ***** Eval results mnli-dev ***** 146 | 01/13/2021 12:38:56 - INFO - **main** - eval_loss = 0.3075757020476829 147 | 01/13/2021 12:38:56 - INFO - **main** - eval_acc = 0.9108507386653082 148 | 01/13/2021 12:38:56 - INFO - **main** - eval_mi_info = 33.53987121582031 149 | 01/13/2021 12:38:56 - INFO - local_robust_trainer - ***** Running Evaluation ***** 150 | 01/13/2021 12:38:56 - INFO - local_robust_trainer - Num examples = 9832 151 | 01/13/2021 12:38:56 - INFO - local_robust_trainer - Batch size = 8 152 | Evaluation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1229/1229 [00:41<00:00, 29.60it/s] 153 | {"eval_loss": 0.325766117048493, "eval_acc": 0.9070382424735557, "eval_mi_info": 33.497161865234375, "step": null} 154 | 01/13/2021 12:39:37 - INFO - **main** - ***** Eval results mnli-mm-dev ***** 155 | 01/13/2021 12:39:37 - INFO - **main** - eval_loss = 0.325766117048493 156 | 01/13/2021 12:39:37 - INFO - **main** - eval_acc = 0.9070382424735557 157 | 01/13/2021 12:39:37 - INFO - **main** - eval_mi_info = 33.497161865234375 158 | 01/13/2021 12:39:37 - INFO - local_robust_trainer - ***** Running Evaluation ***** 159 | 01/13/2021 12:39:37 - INFO - local_robust_trainer - Num examples = 9842 160 | 01/13/2021 12:39:37 - INFO - local_robust_trainer - Batch size = 8 161 | Evaluation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1231/1231 [00:40<00:00, 30.69it/s] 162 | {"eval_loss": 
0.22175195821680835, "eval_acc": 0.9312131680552733, "eval_mi_info": 26.011133193969727, "step": null} 163 | 01/13/2021 12:40:17 - INFO - **main** - ***** Eval results snli-dev ***** 164 | 01/13/2021 12:40:17 - INFO - **main** - eval_loss = 0.22175195821680835 165 | 01/13/2021 12:40:17 - INFO - **main** - eval_acc = 0.9312131680552733 166 | 01/13/2021 12:40:17 - INFO - **main** - eval_mi_info = 26.011133193969727 167 | 01/13/2021 12:40:17 - INFO - local_robust_trainer - ***** Running Evaluation ***** 168 | 01/13/2021 12:40:17 - INFO - local_robust_trainer - Num examples = 772 169 | 01/13/2021 12:40:17 - INFO - local_robust_trainer - Batch size = 8 170 | Evaluation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 97/97 [00:03<00:00, 29.60it/s] 171 | {"eval_loss": 0.9281529402840383, "eval_acc": 0.7202072538860104, "eval_mi_info": 35.50751495361328, "step": null} 172 | 01/13/2021 12:40:21 - INFO - **main** - ***** Eval results mnli-bert-adv-dev ***** 173 | 01/13/2021 12:40:21 - INFO - **main** - eval_loss = 0.9281529402840383 174 | 01/13/2021 12:40:21 - INFO - **main** - eval_acc = 0.7202072538860104 175 | 01/13/2021 12:40:21 - INFO - **main** - eval_mi_info = 35.50751495361328 176 | 01/13/2021 12:40:21 - INFO - local_robust_trainer - ***** Running Evaluation ***** 177 | 01/13/2021 12:40:21 - INFO - local_robust_trainer - Num examples = 746 178 | 01/13/2021 12:40:21 - INFO - local_robust_trainer - Batch size = 8 179 | Evaluation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:03<00:00, 29.64it/s] 180 | {"eval_loss": 1.023387150561556, "eval_acc": 0.6983914209115282, "eval_mi_info": 35.46954345703125, "step": null} 181 | 01/13/2021 12:40:24 - INFO - **main** - ***** Eval results mnli-mm-bert-adv-dev ***** 182 | 01/13/2021 
12:40:24 - INFO - **main** - eval_loss = 1.023387150561556 183 | 01/13/2021 12:40:24 - INFO - **main** - eval_acc = 0.6983914209115282 184 | 01/13/2021 12:40:24 - INFO - **main** - eval_mi_info = 35.46954345703125 185 | 01/13/2021 12:40:24 - INFO - local_robust_trainer - ***** Running Evaluation ***** 186 | 01/13/2021 12:40:24 - INFO - local_robust_trainer - Num examples = 848 187 | 01/13/2021 12:40:24 - INFO - local_robust_trainer - Batch size = 8 188 | Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 106/106 [00:03<00:00, 30.65it/s] 189 | {"eval_loss": 1.2708790726819128, "eval_acc": 0.5955188679245284, "eval_mi_info": 28.26164436340332, "step": null} 190 | 01/13/2021 12:40:27 - INFO - **main** - ***** Eval results snli-bert-adv-dev ***** 191 | 01/13/2021 12:40:27 - INFO - **main** - eval_loss = 1.2708790726819128 192 | 01/13/2021 12:40:27 - INFO - **main** - eval_acc = 0.5955188679245284 193 | 01/13/2021 12:40:27 - INFO - **main** - eval_mi_info = 28.26164436340332 194 | 01/13/2021 12:40:27 - INFO - local_robust_trainer - ***** Running Evaluation ***** 195 | 01/13/2021 12:40:27 - INFO - local_robust_trainer - Num examples = 775 196 | 01/13/2021 12:40:27 - INFO - local_robust_trainer - Batch size = 8 197 | Evaluation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 97/97 [00:03<00:00, 29.61it/s] 198 | {"eval_loss": 1.4205392347168677, "eval_acc": 0.5367741935483871, "eval_mi_info": 35.69055938720703, "step": null} 199 | 01/13/2021 12:40:31 - INFO - **main** - ***** Eval results mnli-roberta-adv-dev ***** 200 | 01/13/2021 12:40:31 - INFO - **main** - eval_loss = 1.4205392347168677 201 | 01/13/2021 12:40:31 - INFO - **main** - eval_acc = 0.5367741935483871 202 | 01/13/2021 12:40:31 - INFO - **main** - eval_mi_info = 
35.69055938720703 203 | 01/13/2021 12:40:31 - INFO - local_robust_trainer - ***** Running Evaluation ***** 204 | 01/13/2021 12:40:31 - INFO - local_robust_trainer - Num examples = 775 205 | 01/13/2021 12:40:31 - INFO - local_robust_trainer - Batch size = 8 206 | Evaluation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 97/97 [00:03<00:00, 29.47it/s] 207 | {"eval_loss": 1.367013628949824, "eval_acc": 0.5703225806451613, "eval_mi_info": 35.5481071472168, "step": null} 208 | 01/13/2021 12:40:34 - INFO - **main** - ***** Eval results mnli-mm-roberta-adv-dev ***** 209 | 01/13/2021 12:40:34 - INFO - **main** - eval_loss = 1.367013628949824 210 | 01/13/2021 12:40:34 - INFO - **main** - eval_acc = 0.5703225806451613 211 | 01/13/2021 12:40:34 - INFO - **main** - eval_mi_info = 35.5481071472168 212 | 01/13/2021 12:40:34 - INFO - local_robust_trainer - ***** Running Evaluation ***** 213 | 01/13/2021 12:40:34 - INFO - local_robust_trainer - Num examples = 836 214 | 01/13/2021 12:40:34 - INFO - local_robust_trainer - Batch size = 8 215 | Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 105/105 [00:03<00:00, 30.71it/s] 216 | {"eval_loss": 1.7491175064018794, "eval_acc": 0.42105263157894735, "eval_mi_info": 28.207780838012695, "step": null} 217 | 01/13/2021 12:40:37 - INFO - **main** - ***** Eval results snli-roberta-adv-dev ***** 218 | 01/13/2021 12:40:37 - INFO - **main** - eval_loss = 1.7491175064018794 219 | 01/13/2021 12:40:37 - INFO - **main** - eval_acc = 0.42105263157894735 220 | 01/13/2021 12:40:37 - INFO - **main** - eval_mi_info = 28.207780838012695 221 | 222 | ``` -------------------------------------------------------------------------------- /ANLI/advtraining_args.py: 
-------------------------------------------------------------------------------- 1 | import dataclasses 2 | import json 3 | import logging 4 | from dataclasses import dataclass, field 5 | from typing import Any, Dict, Optional, Tuple 6 | 7 | from transformers.file_utils import cached_property, is_torch_available, torch_required 8 | 9 | 10 | if is_torch_available(): 11 | import torch 12 | 13 | 14 | try: 15 | import torch_xla.core.xla_model as xm 16 | 17 | _has_tpu = True 18 | except ImportError: 19 | _has_tpu = False 20 | 21 | 22 | @torch_required 23 | def is_tpu_available(): 24 | return _has_tpu 25 | 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | @dataclass 31 | class TrainingArguments: 32 | """ 33 | TrainingArguments is the subset of the arguments we use in our example scripts 34 | **which relate to the training loop itself**. 35 | 36 | Using `HfArgumentParser` we can turn this class 37 | into argparse arguments to be able to specify them on 38 | the command line. 39 | """ 40 | 41 | output_dir: str = field( 42 | metadata={"help": "The output directory where the model predictions and checkpoints will be written."} 43 | ) 44 | overwrite_output_dir: bool = field( 45 | default=False, 46 | metadata={ 47 | "help": ( 48 | "Overwrite the content of the output directory." 49 | "Use this to continue training if output_dir points to a checkpoint directory." 
50 | ) 51 | }, 52 | ) 53 | 54 | do_train: bool = field(default=False, metadata={"help": "Whether to run training."}) 55 | do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."}) 56 | do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."}) 57 | evaluate_during_training: bool = field( 58 | default=False, metadata={"help": "Run evaluation during training at each logging step."}, 59 | ) 60 | 61 | per_device_train_batch_size: int = field( 62 | default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."} 63 | ) 64 | per_device_eval_batch_size: int = field( 65 | default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."} 66 | ) 67 | 68 | per_gpu_train_batch_size: Optional[int] = field( 69 | default=None, 70 | metadata={ 71 | "help": "Deprecated, the use of `--per_device_train_batch_size` is preferred. " 72 | "Batch size per GPU/TPU core/CPU for training." 73 | }, 74 | ) 75 | per_gpu_eval_batch_size: Optional[int] = field( 76 | default=None, 77 | metadata={ 78 | "help": "Deprecated, the use of `--per_device_eval_batch_size` is preferred." 79 | "Batch size per GPU/TPU core/CPU for evaluation." 
80 | }, 81 | ) 82 | 83 | gradient_accumulation_steps: int = field( 84 | default=1, 85 | metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."}, 86 | ) 87 | 88 | learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for Adam."}) 89 | weight_decay: float = field(default=0.0, metadata={"help": "Weight decay if we apply some."}) 90 | adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for Adam optimizer."}) 91 | max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."}) 92 | 93 | num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."}) 94 | max_steps: int = field( 95 | default=-1, 96 | metadata={"help": "If > 0: set total number of training steps to perform. Override num_train_epochs."}, 97 | ) 98 | warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."}) 99 | 100 | logging_dir: Optional[str] = field(default=None, metadata={"help": "Tensorboard log dir."}) 101 | logging_first_step: bool = field(default=False, metadata={"help": "Log and eval the first global_step"}) 102 | logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."}) 103 | save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."}) 104 | save_total_limit: Optional[int] = field( 105 | default=None, 106 | metadata={ 107 | "help": ( 108 | "Limit the total amount of checkpoints." 109 | "Deletes the older checkpoints in the output_dir. 
Default is unlimited checkpoints" 110 | ) 111 | }, 112 | ) 113 | no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"}) 114 | seed: int = field(default=42, metadata={"help": "random seed for initialization"}) 115 | 116 | fp16: bool = field( 117 | default=False, 118 | metadata={"help": "Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit"}, 119 | ) 120 | fp16_opt_level: str = field( 121 | default="O1", 122 | metadata={ 123 | "help": ( 124 | "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 125 | "See details at https://nvidia.github.io/apex/amp.html" 126 | ) 127 | }, 128 | ) 129 | local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"}) 130 | 131 | tpu_num_cores: Optional[int] = field( 132 | default=None, metadata={"help": "TPU: Number of TPU cores (automatically passed by launcher script)"} 133 | ) 134 | tpu_metrics_debug: bool = field(default=False, metadata={"help": "TPU: Whether to print debug metrics"}) 135 | 136 | # adv args 137 | alpha: Optional[float] = field(default=0) 138 | adv_lr: Optional[float] = field(default=0) 139 | adv_steps: Optional[int] = field(default=1, metadata={"help": "should be at least 1"}) 140 | adv_init_mag: Optional[float] = field(default=0) 141 | norm_type: Optional[str] = field(default="l2", metadata={"help": 'choices "l2", or "linf"'}) 142 | adv_max_norm: Optional[float] = field(default=0, metadata={"help": "set to 0 to be unlimited"}) 143 | hidden_dropout_prob: Optional[float] = field(default=0.1) 144 | attention_probs_dropout_prob: Optional[float] = field(default=0) 145 | cl: Optional[float] = field(default=0.5) 146 | ch: Optional[float] = field(default=0.9) 147 | 148 | @property 149 | def train_batch_size(self) -> int: 150 | if self.per_gpu_train_batch_size: 151 | logger.warning( 152 | "Using deprecated `--per_gpu_train_batch_size` argument which will be removed in a future " 153 | 
"version. Using `--per_device_train_batch_size` is preferred." 154 | ) 155 | per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size 156 | return per_device_batch_size * max(1, self.n_gpu) 157 | 158 | @property 159 | def eval_batch_size(self) -> int: 160 | if self.per_gpu_eval_batch_size: 161 | logger.warning( 162 | "Using deprecated `--per_gpu_eval_batch_size` argument which will be removed in a future " 163 | "version. Using `--per_device_eval_batch_size` is preferred." 164 | ) 165 | per_device_batch_size = self.per_gpu_eval_batch_size or self.per_device_eval_batch_size 166 | return per_device_batch_size * max(1, self.n_gpu) 167 | 168 | @cached_property 169 | @torch_required 170 | def _setup_devices(self) -> Tuple["torch.device", int]: 171 | logger.info("PyTorch: setting up devices") 172 | if self.no_cuda: 173 | device = torch.device("cpu") 174 | n_gpu = 0 175 | elif is_tpu_available(): 176 | device = xm.xla_device() 177 | n_gpu = 0 178 | elif self.local_rank == -1: 179 | # if n_gpu is > 1 we'll use nn.DataParallel. 180 | # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0` 181 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 182 | n_gpu = torch.cuda.device_count() 183 | else: 184 | # Here, we'll use torch.distributed. 185 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 186 | torch.distributed.init_process_group(backend="nccl") 187 | device = torch.device("cuda", self.local_rank) 188 | n_gpu = 1 189 | return device, n_gpu 190 | 191 | @property 192 | @torch_required 193 | def device(self) -> "torch.device": 194 | return self._setup_devices[0] 195 | 196 | @property 197 | @torch_required 198 | def n_gpu(self): 199 | return self._setup_devices[1] 200 | 201 | def to_json_string(self): 202 | """ 203 | Serializes this instance to a JSON string. 
204 | """ 205 | return json.dumps(dataclasses.asdict(self), indent=2) 206 | 207 | def to_sanitized_dict(self) -> Dict[str, Any]: 208 | """ 209 | Sanitized serialization to use with TensorBoard’s hparams 210 | """ 211 | d = dataclasses.asdict(self) 212 | valid_types = [bool, int, float, str] 213 | if is_torch_available(): 214 | valid_types.append(torch.Tensor) 215 | return {k: v if type(v) in valid_types else str(v) for k, v in d.items()} 216 | -------------------------------------------------------------------------------- /ANLI/datasets/anli.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | from dataclasses import dataclass, field 5 | from enum import Enum 6 | from typing import List, Optional, Union 7 | 8 | import torch 9 | from filelock import FileLock 10 | from torch.utils.data.dataset import Dataset 11 | 12 | from transformers import RobertaTokenizer, RobertaTokenizerFast 13 | from transformers import PreTrainedTokenizer 14 | from transformers import XLMRobertaTokenizer 15 | from processors.anli import glue_convert_examples_to_features, glue_output_modes, glue_processors 16 | from transformers import InputFeatures 17 | 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | @dataclass 23 | class GlueDataTrainingArguments: 24 | """ 25 | Arguments pertaining to what data we are going to input our model for training and eval. 26 | 27 | Using `HfArgumentParser` we can turn this class 28 | into argparse arguments to be able to specify them on 29 | the command line. 30 | """ 31 | 32 | task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(glue_processors.keys())}) 33 | data_dir: str = field( 34 | metadata={"help": "The input data dir. 
Should contain the .tsv files (or other data files) for the task."} 35 | ) 36 | max_seq_length: int = field( 37 | default=128, 38 | metadata={ 39 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 40 | "than this will be truncated, sequences shorter will be padded." 41 | }, 42 | ) 43 | overwrite_cache: bool = field( 44 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 45 | ) 46 | 47 | def __post_init__(self): 48 | self.task_name = self.task_name.lower() 49 | 50 | 51 | class Split(Enum): 52 | train = "train" 53 | dev = "dev" 54 | test = "test" 55 | 56 | 57 | class GlueDataset(Dataset): 58 | """ 59 | This will be superseded by a framework-agnostic approach 60 | soon. 61 | """ 62 | 63 | args: GlueDataTrainingArguments 64 | output_mode: str 65 | features: List[InputFeatures] 66 | 67 | def __init__( 68 | self, 69 | args: GlueDataTrainingArguments, 70 | tokenizer: PreTrainedTokenizer, 71 | limit_length: Optional[int] = None, 72 | mode: Union[str, Split] = Split.train, 73 | ): 74 | self.args = args 75 | self.mode = mode 76 | self.processor = glue_processors[args.task_name]() 77 | self.output_mode = glue_output_modes[args.task_name] 78 | if isinstance(mode, str): 79 | try: 80 | mode = Split[mode] 81 | except KeyError: 82 | raise KeyError("mode is not a valid split name") 83 | # Load data features from cache or dataset file 84 | cached_features_file = os.path.join( 85 | args.data_dir, 86 | "cached_{}_{}_{}_{}".format( 87 | mode.value, tokenizer.__class__.__name__, str(args.max_seq_length), args.task_name, 88 | ), 89 | ) 90 | label_list = self.processor.get_labels() 91 | 92 | anli_datasets = ["mnli", "mnli-mm", "anli-r1", "anli-r2", "anli-r3", "anli-all", "anli-full", "anli-part", "snli", 93 | "mnli-bert-adv", "mnli-mm-bert-adv", "snli-bert-adv", 94 | "mnli-roberta-adv", "mnli-mm-roberta-adv", "snli-roberta-adv"] 95 | if args.task_name in anli_datasets and tokenizer.__class__ in ( 96 | 
RobertaTokenizer, 97 | RobertaTokenizerFast, 98 | XLMRobertaTokenizer, 99 | ): 100 | # HACK(label indices are swapped in RoBERTa pretrained model) 101 | label_list[1], label_list[2] = label_list[2], label_list[1] 102 | self.label_list = label_list 103 | 104 | # Make sure only the first process in distributed training processes the dataset, 105 | # and the others will use the cache. 106 | lock_path = cached_features_file + ".lock" 107 | with FileLock(lock_path): 108 | if os.path.exists(cached_features_file) and not args.overwrite_cache: 109 | start = time.time() 110 | self.features = torch.load(cached_features_file) 111 | logger.info( 112 | f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start 113 | ) 114 | else: 115 | logger.info(f"Creating features from dataset file at {args.data_dir}") 116 | 117 | if mode == Split.dev: 118 | examples = self.processor.get_dev_examples(args.data_dir) 119 | elif mode == Split.test: 120 | examples = self.processor.get_test_examples(args.data_dir) 121 | else: 122 | examples = self.processor.get_train_examples(args.data_dir) 123 | if limit_length is not None: 124 | examples = examples[:limit_length] 125 | self.features = glue_convert_examples_to_features( 126 | examples, 127 | tokenizer, 128 | max_length=args.max_seq_length, 129 | label_list=label_list, 130 | output_mode=self.output_mode, 131 | ) 132 | start = time.time() 133 | torch.save(self.features, cached_features_file) 134 | # ^ This seems to take a lot of time so I want to investigate why and how we can improve. 
135 | logger.info( 136 | "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start 137 | ) 138 | 139 | def __len__(self): 140 | return len(self.features) 141 | 142 | def __getitem__(self, i) -> InputFeatures: 143 | return self.features[i] 144 | 145 | def get_labels(self): 146 | return self.label_list 147 | -------------------------------------------------------------------------------- /ANLI/download_glue_data.py: -------------------------------------------------------------------------------- 1 | ''' Script for downloading all GLUE data. 2 | Note: for legal reasons, we are unable to host MRPC. 3 | You can either use the version hosted by the SentEval team, which is already tokenized, 4 | or you can download the original data from (https://download.microsoft.com/download/D/4/6/D46FF87A-F6B9-4252-AA8B-3604ED519838/MSRParaphraseCorpus.msi) and extract the data from it manually. 5 | For Windows users, you can run the .msi file. For Mac and Linux users, consider an external library such as 'cabextract' (see below for an example). 6 | You should then rename and place specific files in a folder (see below for an example). 7 | mkdir MRPC 8 | cabextract MSRParaphraseCorpus.msi -d MRPC 9 | cat MRPC/_2DEC3DBE877E4DB192D17C0256E90F1D | tr -d $'\r' > MRPC/msr_paraphrase_train.txt 10 | cat MRPC/_D7B391F9EAFF4B1B8BCE8F21B20B1B61 | tr -d $'\r' > MRPC/msr_paraphrase_test.txt 11 | rm MRPC/_* 12 | rm MSRParaphraseCorpus.msi 13 | 1/30/19: It looks like SentEval is no longer hosting their extracted and tokenized MRPC data, so you'll need to download the data from the original source for now. 14 | 2/11/19: It looks like SentEval actually *is* hosting the extracted data. Hooray! 
15 | ''' 16 | 17 | import os 18 | import sys 19 | import shutil 20 | import argparse 21 | import tempfile 22 | import urllib.request 23 | import zipfile 24 | 25 | TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"] 26 | TASK2PATH = {"CoLA":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4', 27 | "SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8', 28 | "MRPC":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc', 29 | "QQP":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5', 30 | "STS":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5', 31 | "MNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce', 32 | "SNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df', 33 | "QNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601', 34 | "RTE":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb', 35 | "WNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf', 36 | 
"diagnostic":'https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D'} 37 | 38 | MRPC_TRAIN = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt' 39 | MRPC_TEST = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt' 40 | 41 | def download_and_extract(task, data_dir): 42 | print("Downloading and extracting %s..." % task) 43 | data_file = "%s.zip" % task 44 | urllib.request.urlretrieve(TASK2PATH[task], data_file) 45 | with zipfile.ZipFile(data_file) as zip_ref: 46 | zip_ref.extractall(data_dir) 47 | os.remove(data_file) 48 | print("\tCompleted!") 49 | 50 | def format_mrpc(data_dir, path_to_data): 51 | print("Processing MRPC...") 52 | mrpc_dir = os.path.join(data_dir, "MRPC") 53 | if not os.path.isdir(mrpc_dir): 54 | os.mkdir(mrpc_dir) 55 | if path_to_data: 56 | mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt") 57 | mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt") 58 | else: 59 | print("Local MRPC data not specified, downloading data from %s" % MRPC_TRAIN) 60 | mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt") 61 | mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt") 62 | urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file) 63 | urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file) 64 | assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file 65 | assert 
os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file 66 | urllib.request.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv")) 67 | 68 | dev_ids = [] 69 | with open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding="utf8") as ids_fh: 70 | for row in ids_fh: 71 | dev_ids.append(row.strip().split('\t')) 72 | 73 | with open(mrpc_train_file, encoding="utf8") as data_fh, \ 74 | open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding="utf8") as train_fh, \ 75 | open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding="utf8") as dev_fh: 76 | header = data_fh.readline() 77 | train_fh.write(header) 78 | dev_fh.write(header) 79 | for row in data_fh: 80 | label, id1, id2, s1, s2 = row.strip().split('\t') 81 | if [id1, id2] in dev_ids: 82 | dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2)) 83 | else: 84 | train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2)) 85 | 86 | with open(mrpc_test_file, encoding="utf8") as data_fh, \ 87 | open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding="utf8") as test_fh: 88 | header = data_fh.readline() 89 | test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n") 90 | for idx, row in enumerate(data_fh): 91 | label, id1, id2, s1, s2 = row.strip().split('\t') 92 | test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2)) 93 | print("\tCompleted!") 94 | 95 | def download_diagnostic(data_dir): 96 | print("Downloading and extracting diagnostic...") 97 | if not os.path.isdir(os.path.join(data_dir, "diagnostic")): 98 | os.mkdir(os.path.join(data_dir, "diagnostic")) 99 | data_file = os.path.join(data_dir, "diagnostic", "diagnostic.tsv") 100 | urllib.request.urlretrieve(TASK2PATH["diagnostic"], data_file) 101 | print("\tCompleted!") 102 | return 103 | 104 | def get_tasks(task_names): 105 | task_names = task_names.split(',') 106 | if "all" in task_names: 107 | tasks = TASKS 108 | else: 109 | tasks = [] 110 | for task_name in task_names: 111 | assert task_name in TASKS, "Task %s not 
found!" % task_name 112 | tasks.append(task_name) 113 | return tasks 114 | 115 | def main(arguments): 116 | parser = argparse.ArgumentParser() 117 | parser.add_argument('--data_dir', help='directory to save data to', type=str, default='glue_data') 118 | parser.add_argument('--tasks', help='tasks to download data for as a comma separated string', 119 | type=str, default='all') 120 | parser.add_argument('--path_to_mrpc', help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_text.txt', 121 | type=str, default='') 122 | args = parser.parse_args(arguments) 123 | 124 | if not os.path.isdir(args.data_dir): 125 | os.mkdir(args.data_dir) 126 | tasks = get_tasks(args.tasks) 127 | 128 | for task in tasks: 129 | if task == 'MRPC': 130 | format_mrpc(args.data_dir, args.path_to_mrpc) 131 | elif task == 'diagnostic': 132 | download_diagnostic(args.data_dir) 133 | else: 134 | download_and_extract(task, args.data_dir) 135 | 136 | 137 | if __name__ == '__main__': 138 | sys.exit(main(sys.argv[1:])) -------------------------------------------------------------------------------- /ANLI/eval_anli_local.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """ Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet, RoBERTa, Albert, XLM-RoBERTa).""" 17 | 18 | import dataclasses 19 | import logging 20 | import os 21 | import sys 22 | from dataclasses import dataclass, field 23 | from typing import Callable, Dict, Optional 24 | 25 | import numpy as np 26 | import torch 27 | 28 | from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction 29 | 30 | from MI_estimators import CLUB, CLUBv2, InfoNCE 31 | from datasets.anli import GlueDataset, GlueDataTrainingArguments as DataTrainingArguments 32 | from processors.anli import glue_output_modes, glue_tasks_num_labels, glue_compute_metrics 33 | from local_robust_trainer import Trainer 34 | from transformers import ( 35 | HfArgumentParser, 36 | set_seed, 37 | ) 38 | from advtraining_args import TrainingArguments 39 | 40 | logger = logging.getLogger(__name__) 41 | 42 | 43 | @dataclass 44 | class ModelArguments: 45 | """ 46 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
47 | """ 48 | 49 | model_name_or_path: str = field( 50 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 51 | ) 52 | config_name: Optional[str] = field( 53 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 54 | ) 55 | tokenizer_name: Optional[str] = field( 56 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 57 | ) 58 | cache_dir: Optional[str] = field( 59 | default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} 60 | ) 61 | load: Optional[str] = field( 62 | default=None, metadata={"help": "the path to load pretrained models"} 63 | ) 64 | beta: float = field( 65 | default=0, metadata={"help": "the regularization term"} 66 | ) 67 | version: float = field( 68 | default=-1, metadata={"help": "version of MI estimator"} 69 | ) 70 | 71 | 72 | def main(): 73 | # See all possible arguments in src/transformers/advtraining_args.py 74 | # or by passing the --help flag to this script. 75 | # We now keep distinct sets of args, for a cleaner separation of concerns. 76 | 77 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 78 | 79 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 80 | # If we pass only one argument to the script and it's the path to a json file, 81 | # let's parse it to get our arguments. 82 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 83 | else: 84 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 85 | 86 | if ( 87 | os.path.exists(training_args.output_dir) 88 | and os.listdir(training_args.output_dir) 89 | and training_args.do_train 90 | and not training_args.overwrite_output_dir 91 | and training_args.local_rank in [-1, 0] 92 | ): 93 | raise ValueError( 94 | f"Output directory ({training_args.output_dir}) already exists and is not empty. 
Use --overwrite_output_dir to overcome." 95 | ) 96 | 97 | # Setup logging 98 | root_dir = training_args.output_dir 99 | if not os.path.exists(root_dir) and training_args.local_rank in [-1, 0]: 100 | os.mkdir(root_dir) 101 | 102 | logging.basicConfig( 103 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 104 | datefmt="%m/%d/%Y %H:%M:%S", 105 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 106 | handlers=[ 107 | logging.FileHandler(os.path.join(training_args.output_dir, "log.txt")), 108 | logging.StreamHandler() 109 | ] if training_args.local_rank in [-1, 0] else [logging.StreamHandler()] 110 | ) 111 | logger.warning( 112 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 113 | training_args.local_rank, 114 | training_args.device, 115 | training_args.n_gpu, 116 | bool(training_args.local_rank != -1), 117 | training_args.fp16, 118 | ) 119 | logger.info("Training/evaluation parameters %s", training_args) 120 | 121 | # Set seed 122 | set_seed(training_args.seed) 123 | 124 | try: 125 | num_labels = glue_tasks_num_labels[data_args.task_name] 126 | output_mode = glue_output_modes[data_args.task_name] 127 | except KeyError: 128 | raise ValueError("Task not found: %s" % (data_args.task_name)) 129 | 130 | # Load pretrained model and tokenizer 131 | # 132 | # Distributed training: 133 | # The .from_pretrained methods guarantee that only one local process can concurrently 134 | # download model & vocab. 
135 | 136 | config = AutoConfig.from_pretrained( 137 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 138 | num_labels=num_labels, 139 | finetuning_task=data_args.task_name, 140 | cache_dir=model_args.cache_dir, 141 | output_hidden_states=True 142 | ) 143 | tokenizer = AutoTokenizer.from_pretrained( 144 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 145 | cache_dir=model_args.cache_dir, 146 | ) 147 | model = AutoModelForSequenceClassification.from_pretrained( 148 | model_args.model_name_or_path, 149 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 150 | config=config, 151 | cache_dir=model_args.cache_dir, 152 | ) 153 | # take the embedding of the whole sentence as varaible y 154 | if model_args.version >= 0: 155 | if model_args.version == 0: 156 | mi_upper_estimator = CLUB(config.hidden_size * data_args.max_seq_length, config.hidden_size, 157 | beta=model_args.beta).to(training_args.device) 158 | mi_upper_estimator.version = 0 159 | mi_estimator = None 160 | elif model_args.version == 1: 161 | mi_upper_estimator = CLUB(config.hidden_size, config.hidden_size, beta=model_args.beta).to(training_args.device) 162 | # mi_estimator = CLUB(config.hidden_size, config.hidden_size, beta=model_args.beta) 163 | mi_upper_estimator.version = 1 164 | mi_estimator = None 165 | elif model_args.version == 2 or model_args.version == 3: 166 | mi_upper_estimator = CLUBv2(config.hidden_size, config.hidden_size, beta=model_args.beta).to(training_args.device) 167 | mi_upper_estimator.version = model_args.version 168 | mi_estimator = None 169 | elif model_args.version == 4: 170 | mi_estimator = InfoNCE(config.hidden_size, config.hidden_size).to(training_args.device) 171 | mi_upper_estimator = None 172 | elif model_args.version == 5: 173 | mi_estimator = InfoNCE(config.hidden_size, config.hidden_size).to(training_args.device) 174 | mi_upper_estimator = CLUBv2(config.hidden_size, config.hidden_size, 
beta=model_args.beta).to(training_args.device) 175 | mi_upper_estimator.version = 2 176 | elif model_args.version == 6: 177 | mi_estimator = InfoNCE(config.hidden_size, config.hidden_size).to(training_args.device) 178 | mi_upper_estimator = CLUBv2(config.hidden_size, config.hidden_size, beta=model_args.beta).to(training_args.device) 179 | mi_upper_estimator.version = 3 180 | else: 181 | mi_estimator = None 182 | mi_upper_estimator = None 183 | 184 | # Get datasets 185 | train_dataset = ( 186 | GlueDataset(data_args, tokenizer=tokenizer) if training_args.do_train else None 187 | ) 188 | eval_dataset = ( 189 | GlueDataset(data_args, tokenizer=tokenizer, mode="dev") 190 | if training_args.do_eval 191 | else None 192 | ) 193 | test_dataset = ( 194 | GlueDataset(data_args, tokenizer=tokenizer, mode="test") 195 | if training_args.do_predict 196 | else None 197 | ) 198 | 199 | def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]: 200 | def compute_metrics_fn(p: EvalPrediction): 201 | if output_mode == "classification": 202 | preds = np.argmax(p.predictions, axis=1) 203 | elif output_mode == "regression": 204 | preds = np.squeeze(p.predictions) 205 | return glue_compute_metrics(task_name, preds, p.label_ids) 206 | 207 | return compute_metrics_fn 208 | 209 | if model_args.load is not None: 210 | print(model_args.load) 211 | model.load_state_dict(torch.load(os.path.join(model_args.load, "pytorch_model.bin"), map_location=training_args.device)) 212 | if mi_estimator: # and os.path.exists(os.path.join(model_args.load, "mi_estimator.bin")) 213 | mi_estimator.load_state_dict(torch.load(os.path.join(model_args.load, "mi_estimator.bin"), map_location=training_args.device)) 214 | logger.info(f"Load successful from {model_args.load}") 215 | 216 | # Initialize our Trainer 217 | trainer = Trainer( 218 | model=model, 219 | args=training_args, 220 | train_dataset=train_dataset, 221 | eval_dataset=eval_dataset, 222 | 
compute_metrics=build_compute_metrics_fn(data_args.task_name), 223 | mi_estimator=mi_estimator, 224 | mi_upper_estimator=mi_upper_estimator 225 | ) 226 | trainer.tokenizer = tokenizer 227 | 228 | # Training 229 | if training_args.do_train: 230 | trainer.train_mi_only( 231 | model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None 232 | ) 233 | trainer.save_model() 234 | if mi_estimator: 235 | torch.save(mi_estimator.state_dict(), os.path.join(training_args.output_dir, "mi_estimator.bin")) 236 | # For convenience, we also re-save the tokenizer to the same directory, 237 | # so that you can share your model easily on huggingface.co/models =) 238 | if trainer.is_world_master(): 239 | tokenizer.save_pretrained(training_args.output_dir) 240 | 241 | # Evaluation 242 | eval_results = {} 243 | if training_args.do_eval: 244 | logger.info("*** Evaluate ***") 245 | 246 | # Loop to handle MNLI double evaluation (matched, mis-matched) 247 | eval_datasets = [eval_dataset] 248 | if data_args.task_name == "mnli": 249 | mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm") 250 | eval_datasets.append( 251 | GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="dev") 252 | ) 253 | 254 | if data_args.task_name == 'anli-full' or data_args.task_name == 'anli-part': 255 | eval_tasks = ["anli-r1", "anli-r2", "anli-r3", "mnli", "mnli-mm", "snli"] 256 | for task in eval_tasks: 257 | if "mnli" in task: 258 | task_data_dir = os.path.join(data_args.data_dir, "MNLI") 259 | elif "snli" == task: 260 | task_data_dir = os.path.join(data_args.data_dir, "SNLI") 261 | else: 262 | task_data_dir = data_args.data_dir 263 | task_data_args = dataclasses.replace(data_args, task_name=task, data_dir=task_data_dir) 264 | eval_datasets.append( 265 | GlueDataset(task_data_args, tokenizer=tokenizer, mode="dev") 266 | ) 267 | 268 | for eval_dataset in eval_datasets: 269 | trainer.compute_metrics = build_compute_metrics_fn(eval_dataset.args.task_name) 
270 | eval_result = trainer.evaluate_mi(eval_dataset=eval_dataset) 271 | 272 | output_eval_file = os.path.join( 273 | training_args.output_dir, f"eval_results_{eval_dataset.args.task_name}.txt" 274 | ) 275 | if trainer.is_world_master(): 276 | with open(output_eval_file, "w") as writer: 277 | logger.info("***** Eval results {} *****".format(eval_dataset.args.task_name)) 278 | for key, value in eval_result.items(): 279 | logger.info(" %s = %s", key, value) 280 | writer.write("%s = %s\n" % (key, value)) 281 | 282 | eval_results.update(eval_result) 283 | 284 | if training_args.do_predict: 285 | logging.info("*** Test ***") 286 | test_datasets = [test_dataset] 287 | if data_args.task_name == "mnli": 288 | mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm") 289 | test_datasets.append( 290 | GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="test", cache_dir=model_args.cache_dir) 291 | ) 292 | 293 | for test_dataset in test_datasets: 294 | predictions = trainer.predict(test_dataset=test_dataset).predictions 295 | if output_mode == "classification": 296 | predictions = np.argmax(predictions, axis=1) 297 | 298 | output_test_file = os.path.join( 299 | training_args.output_dir, f"test_results_{test_dataset.args.task_name}.txt" 300 | ) 301 | if trainer.is_world_master(): 302 | with open(output_test_file, "w") as writer: 303 | logger.info("***** Test results {} *****".format(test_dataset.args.task_name)) 304 | writer.write("index\tprediction\n") 305 | for index, item in enumerate(predictions): 306 | if output_mode == "regression": 307 | writer.write("%d\t%3.3f\n" % (index, item)) 308 | else: 309 | item = test_dataset.get_labels()[item] 310 | writer.write("%d\t%s\n" % (index, item)) 311 | return eval_results 312 | 313 | 314 | def _mp_fn(index): 315 | # For xla_spawn (TPUs) 316 | main() 317 | 318 | 319 | if __name__ == "__main__": 320 | main() -------------------------------------------------------------------------------- /ANLI/jutils.py: 
# --------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


def sample_correlated_gaussian(rho=0.5, dim=20, batch_size=128, cubic=None):
    """Generate samples from a correlated Gaussian distribution.

    Returns ``(x, y)``, each ``[batch_size, dim]``, with ``x ~ N(0, I)`` and
    ``y = rho * x + sqrt(1 - rho**2) * eps``; ``y`` is cubed when ``cubic`` is not None.
    """
    x, eps = torch.chunk(torch.randn(batch_size, 2 * dim), 2, dim=1)
    y = rho * x + torch.sqrt(torch.tensor(1. - rho**2).float()) * eps

    if cubic is not None:
        y = y ** 3

    return x, y


def rho_to_mi(dim, rho):
    """Return ``-0.5 * log(1 - rho**2) * dim`` (MI of ``dim`` Gaussian pairs with correlation ``rho``)."""
    return -0.5 * np.log(1 - rho**2) * dim


def mi_to_rho(dim, mi):
    """Inverse of ``rho_to_mi``: the correlation that yields mutual information ``mi``."""
    return np.sqrt(1 - np.exp(-2.0 / dim * mi))


def mi_schedule(n_iter):
    """Generate schedule for increasing correlation over time."""
    mis = np.round(np.linspace(0.5, 5.5 - 1e-9, n_iter)) * 2.0
    return mis.astype(np.float32)


def logmeanexp_diag(x):
    """Return ``log(mean(exp(diag(x))))``, computed on the device of ``x``."""
    batch_size = x.size(0)

    logsumexp = torch.logsumexp(x.diag(), dim=(0,))
    # Subtracting a Python float keeps the result on x's device (it used to be forced to CUDA).
    return logsumexp - float(np.log(batch_size))


def logmeanexp_nodiag(x, dim=None, device=None):
    """Return ``log(mean(exp(x)))`` over the off-diagonal entries of the square matrix ``x``.

    Args:
        x: square ``[batch, batch]`` score matrix.
        dim: dimension(s) reduced by ``logsumexp``; defaults to ``(0, 1)`` (all entries).
            An int or a 1-tuple reduces one axis, leaving ``batch - 1`` entries per slice.
        device: device for the diagonal mask and normaliser; defaults to ``x.device``.
    """
    batch_size = x.size(0)
    if dim is None:
        dim = (0, 1)
    if device is None:
        device = x.device

    # Subtracting +inf on the diagonal makes those entries -inf, i.e. exp(.) == 0.
    logsumexp = torch.logsumexp(
        x - torch.diag(np.inf * torch.ones(batch_size).to(device)), dim=dim)

    try:
        if len(dim) == 1:
            num_elem = batch_size - 1.
        else:
            num_elem = batch_size * (batch_size - 1.)
    except TypeError:
        # ``dim`` is a plain int: a single reduced axis.
        num_elem = batch_size - 1.
    return logsumexp - torch.log(torch.tensor(num_elem)).to(device)


def tuba_lower_bound(scores, log_baseline=None):
    """TUBA lower bound from a ``[batch, batch]`` critic score matrix.

    ``log_baseline`` (optional, one value per row) is subtracted from each row.
    The caller's ``scores`` tensor is not modified.
    """
    if log_baseline is not None:
        # Out-of-place: the previous ``-=`` silently modified the caller's tensor.
        scores = scores - log_baseline[:, None]

    # First term is an expectation over samples from the joint,
    # which are the diagonal elements of the scores matrix.
    joint_term = scores.diag().mean()

    # Second term is an expectation over samples from the marginal,
    # which are the off-diagonal elements of the scores matrix.
    marg_term = logmeanexp_nodiag(scores).exp()
    return 1. + joint_term - marg_term


def nwj_lower_bound(scores):
    """NWJ bound: TUBA evaluated on ``scores - 1`` with no baseline."""
    return tuba_lower_bound(scores - 1.)


def infonce_lower_bound(scores):
    """InfoNCE bound: ``log(batch) + mean_i(scores[i, i] - logsumexp_j scores[i, j])``."""
    nll = scores.diag().mean() - scores.logsumexp(dim=1)
    mi = torch.tensor(scores.size(0)).float().log() + nll
    return mi.mean()


def js_fgan_lower_bound(f):
    """Lower bound on Jensen-Shannon divergence from Nowozin et al. (2016)."""
    f_diag = f.diag()
    first_term = -F.softplus(-f_diag).mean()
    n = f.size(0)
    second_term = (torch.sum(F.softplus(f)) -
                   torch.sum(F.softplus(f_diag))) / (n * (n - 1.))
    return first_term - second_term


def js_lower_bound(f):
    """Return the NWJ value while back-propagating the JS (f-GAN) gradient."""
    nwj = nwj_lower_bound(f)
    js = js_fgan_lower_bound(f)

    with torch.no_grad():
        nwj_js = nwj - js

    return js + nwj_js


def dv_upper_lower_bound(f):
    """DV lower bound, but upper bounded by using log outside."""
    first_term = f.diag().mean()
    second_term = logmeanexp_nodiag(f)

    return first_term - second_term


def mine_lower_bound(f, buffer=None, momentum=0.9):
    """MINE-style estimate using a moving-average denominator.

    The returned value numerically equals ``mean(diag(f)) - logmeanexp_nodiag(f)``;
    the marginal term's gradient flows through ``buffer_update / buffer_new``.

    Returns:
        ``(estimate, buffer_update)`` where ``buffer_update = exp(logmeanexp_nodiag(f))``.
    """
    if buffer is None:
        # Created on f's device (previously hard-coded to CUDA).
        buffer = torch.tensor(1.0, device=f.device)
    first_term = f.diag().mean()

    buffer_update = logmeanexp_nodiag(f).exp()
    with torch.no_grad():
        second_term = logmeanexp_nodiag(f)
        buffer_new = buffer * momentum + buffer_update * (1 - momentum)
        buffer_new = torch.clamp(buffer_new, min=1e-4)
        third_term_no_grad = buffer_update / buffer_new

    third_term_grad = buffer_update / buffer_new

    return first_term - second_term - third_term_grad + third_term_no_grad, buffer_update


def regularized_dv_bound(f, l=0.0):
    """DV bound whose gradient includes ``l * (exp(logmeanexp_nodiag(f)) - 1)**2``; the value does not."""
    first_term = f.diag().mean()
    second_term = logmeanexp_nodiag(f)

    reg_term = l * (second_term.exp() - 1) ** 2

    with torch.no_grad():
        reg_term_no_grad = reg_term

    return first_term - second_term + reg_term - reg_term_no_grad


def renorm_q(f, alpha=1.0, clip=None):
    """Return ``logmeanexp_nodiag(alpha * f)``, optionally after clamping to ``[-clip, clip]``."""
    if clip is not None:
        # NOTE(review): f is scaled by alpha here AND again below, so alpha is applied
        # twice when clip is set; harmless for the default alpha=1.0 -- confirm intent.
        f = torch.clamp(f * alpha, -clip, clip)
    z = logmeanexp_nodiag(f * alpha, dim=(0, 1))
    return z


def disc_renorm_q(f):
    """Fit a scalar normaliser ``z`` with up to 10 SGD steps on a softplus objective."""
    batch_size = f.size(0)
    # Allocated on f's device (previously hard-coded to CUDA).
    z = torch.zeros(1, requires_grad=True, device=f.device)

    opt = optim.SGD([z], lr=0.001)
    for _ in range(10):
        opt.zero_grad()

        first_term = -F.softplus(z - f).diag().mean()
        st = -F.softplus(f - z)
        second_term = (st - st.diag().diag()).sum() / \
            (batch_size * (batch_size - 1.))
        total = first_term + second_term

        total.backward(retain_graph=True)
        opt.step()

        if total.item() <= -2 * np.log(2):
            break

    return z


def renorm_p(f, alpha=1.0):
    """Return ``logmeanexp_diag(-alpha * f)``."""
    z = logmeanexp_diag(-f * alpha)
    return z


def smile_lower_bound(f, alpha=1.0, clip=None):
    """Return the (optionally clipped) DV value while back-propagating the JS gradient."""
    z = renorm_q(f, alpha, clip)
    dv = f.diag().mean() - z

    js = js_fgan_lower_bound(f)

    with torch.no_grad():
        dv_js = dv - js

    return js + dv_js


def js_dv_disc_renorm_lower_bound(f):
    """DV value with a learned normaliser (``disc_renorm_q``) and the JS gradient."""
    z = disc_renorm_q(f)
    dv = f.diag().mean() - z.mean()

    js = js_fgan_lower_bound(f)

    with torch.no_grad():
        dv_js = dv - js

    return js + dv_js


def vae_lower_bound(f):
    """Bound from a pair ``(f1, f2)``: value ``mean(f1) - mean(diag(f2))``.

    Gradients flow through ``mean(f1)`` and the off-diagonal mean of ``f2`` only.
    """
    f1, f2 = f
    n = f1.size(0)
    logp = f1.mean()
    logq = (f2.sum() - f2.diag().sum()) / (n * (n - 1.))

    with torch.no_grad():
        logq_nograd = logq * 1.0
        logqd = f2.diag().mean()

    return logp - logqd + logq - logq_nograd


def js_nwj_renorm_lower_bound(f, alpha=1.0):
    """NWJ value on renormalised scores with the JS gradient."""
    z = renorm_q(f - 1.0, alpha)

    nwj = nwj_lower_bound(f - z)
    js = js_fgan_lower_bound(f)

    with torch.no_grad():
        nwj_js = nwj - js

    return js + nwj_js


def estimate_p_norm(f, alpha=1.0):
    """Return ``mean(exp(-(diag(f) - renorm_q(f, alpha))))``."""
    z = renorm_q(f, alpha)
    f = -(f - z)
    return f.diag().exp().mean()


def estimate_mutual_information(estimator, x, y, critic_fn,
                                baseline_fn=None, alpha_logit=None, **kwargs):
    """Estimate variational lower bounds on mutual information.

    Args:
      estimator: string specifying estimator, one of:
        'infonce', 'nwj', 'tuba', 'js', 'smile', 'dv_disc_normalized',
        'nwj_normalized', 'dv'
      x: [batch_size, dim_x] Tensor (moved to CUDA)
      y: [batch_size, dim_y] Tensor (moved to CUDA)
      critic_fn: callable that takes x and y as input and outputs critic scores
        output shape is a [batch_size, batch_size] matrix
      baseline_fn (optional): callable that takes y as input
        outputs a [batch_size] or [batch_size, 1] vector; only used by 'tuba'
      alpha_logit (optional): accepted for API compatibility; not used here
      **kwargs: forwarded to the 'smile', 'dv_disc_normalized' and
        'nwj_normalized' bounds; 'alpha'/'clip' also shape p_norm for 'smile'

    Returns:
      ``(mi, p_norm)``: scalar MI estimate and ``renorm_q`` of the scores.

    Raises:
      ValueError: if ``estimator`` is not one of the names above.
    """
    x, y = x.cuda(), y.cuda()
    scores = critic_fn(x, y)
    # Defined up front so 'tuba' without a baseline no longer hits a NameError.
    log_baseline = None
    if baseline_fn is not None:
        # Some baselines' output is (batch_size, 1) which we remove here.
        log_baseline = torch.squeeze(baseline_fn(y))
    if estimator == 'infonce':
        mi = infonce_lower_bound(scores)
    elif estimator == 'nwj':
        mi = nwj_lower_bound(scores)
    elif estimator == 'tuba':
        mi = tuba_lower_bound(scores, log_baseline)
    elif estimator == 'js':
        mi = js_lower_bound(scores)
    elif estimator == 'smile':
        mi = smile_lower_bound(scores, **kwargs)
    elif estimator == 'dv_disc_normalized':
        mi = js_dv_disc_renorm_lower_bound(scores, **kwargs)
    elif estimator == 'nwj_normalized':
        mi = js_nwj_renorm_lower_bound(scores, **kwargs)
    elif estimator == 'dv':
        mi = dv_upper_lower_bound(scores)
    else:
        raise ValueError(f"Unknown estimator: {estimator!r}")

    # Value comparison: ``is not`` on a string literal tested object identity.
    if estimator != 'smile':
        p_norm = renorm_q(scores)
    else:
        p_norm = renorm_q(scores, alpha=kwargs.get('alpha', 1.0),
                          clip=kwargs.get('clip', None))
    return mi, p_norm


def mlp(dim, hidden_dim, output_dim, layers, activation):
    """Build ``Linear(dim, hidden)`` + ``layers`` hidden blocks + ``Linear(hidden, output_dim)``."""
    activation = {
        'relu': nn.ReLU
    }[activation]

    seq = [nn.Linear(dim, hidden_dim), activation()]
    for _ in range(layers):
        seq += [nn.Linear(hidden_dim, hidden_dim), activation()]
    seq += [nn.Linear(hidden_dim, output_dim)]

    return nn.Sequential(*seq)


class SeparableCritic(nn.Module):
    """Critic ``scores[i, j] = h(y_i) . g(x_j)`` using two separate MLP embeddings."""

    def __init__(self, dim, hidden_dim, embed_dim, layers, activation, **extra_kwargs):
        super(SeparableCritic, self).__init__()
        self._g = mlp(dim, hidden_dim, embed_dim, layers, activation)
        self._h = mlp(dim, hidden_dim, embed_dim, layers, activation)

    def forward(self, x, y):
        scores = torch.matmul(self._h(y), self._g(x).t())
        return scores


class ConcatCritic(nn.Module):
    """Critic that scores every concatenated ``(x_i, y_j)`` pair with a single MLP."""

    def __init__(self, dim, hidden_dim, layers, activation, **extra_kwargs):
        super(ConcatCritic, self).__init__()
        # output is scalar score
        self._f = mlp(dim * 2, hidden_dim, 1, layers, activation)

    def forward(self, x, y):
        batch_size = x.size(0)
        # Tile all possible combinations of x and y
        x_tiled = torch.stack([x] * batch_size, dim=0)
        y_tiled = torch.stack([y] * batch_size, dim=1)
        # xy is [batch_size * batch_size, x_dim + y_dim]
        xy_pairs = torch.reshape(torch.cat((x_tiled, y_tiled), dim=2), [
            batch_size * batch_size, -1])
        # Compute scores for each x_i, y_j pair.
        scores = self._f(xy_pairs)
        return torch.reshape(scores, [batch_size, batch_size]).t()


def log_prob_gaussian(x):
    """Log-density of ``x`` under a standard normal, summed over the last axis."""
    return torch.sum(torch.distributions.Normal(0., 1.).log_prob(x), -1)


dim = 20


CRITICS = {
    'separable': SeparableCritic,
    'concat': ConcatCritic,
}

BASELINES = {
    'constant': lambda: None,
    'unnormalized': lambda: mlp(dim=dim, hidden_dim=512, output_dim=1, layers=2, activation='relu').cuda(),
    'gaussian': lambda: log_prob_gaussian,
}


def train_estimator(critic_params, data_params, mi_params, opt_params, **kwargs):
    """Main training loop that estimates time-varying MI.

    Returns:
        ``(estimates, p_norms)`` numpy arrays with one entry per iteration.
    """
    # Ground truth rho is only used by conditional critic
    critic = CRITICS[mi_params.get('critic', 'separable')](
        rho=None, **critic_params).cuda()
    baseline = BASELINES[mi_params.get('baseline', 'constant')]()

    opt_crit = optim.Adam(critic.parameters(), lr=opt_params['learning_rate'])
    if isinstance(baseline, nn.Module):
        opt_base = optim.Adam(baseline.parameters(),
                              lr=opt_params['learning_rate'])
    else:
        opt_base = None

    def train_step(rho, data_params, mi_params):
        # Annoying special case:
        # For the true conditional, the critic depends on the true correlation rho,
        # so we rebuild the critic at each iteration.
        opt_crit.zero_grad()
        if isinstance(baseline, nn.Module):
            opt_base.zero_grad()

        # Same default as above; indexing mi_params['critic'] raised KeyError when omitted.
        # NOTE(review): CRITICS has no 'conditional' entry, so that branch cannot succeed.
        if mi_params.get('critic', 'separable') == 'conditional':
            critic_ = CRITICS['conditional'](rho=rho).cuda()
        else:
            critic_ = critic

        x, y = sample_correlated_gaussian(
            dim=data_params['dim'], rho=rho, batch_size=data_params['batch_size'], cubic=data_params['cubic'])
        mi, p_norm = estimate_mutual_information(
            mi_params['estimator'], x, y, critic_, baseline, mi_params.get('alpha_logit', None), **kwargs)
        loss = -mi

        loss.backward()
        opt_crit.step()
        if isinstance(baseline, nn.Module):
            opt_base.step()

        return mi, p_norm

    # Schedule of correlation over iterations
    mis = mi_schedule(opt_params['iterations'])
    rhos = mi_to_rho(data_params['dim'], mis)

    estimates = []
    p_norms = []
    for i in range(opt_params['iterations']):
        mi, p_norm = train_step(
            rhos[i], data_params, mi_params)
        mi = mi.detach().cpu().numpy()
        p_norm = p_norm.detach().cpu().numpy()
        estimates.append(mi)
        p_norms.append(p_norm)

    return np.array(estimates), np.array(p_norms)

# --------------------------------------------------------------------------------
# /ANLI/print_table.py:
# --------------------------------------------------------------------------------

import os

root_dir = "results/eval-all/"
eval_list = [
    "anli-full-dev", "anli-full-test", "anli-r1-dev", "anli-r1-test", "anli-r2-dev", "anli-r2-test",
    "anli-r3-dev", "anli-r3-test", "mnli-dev", "mnli-mm-dev", "mnli-bert-adv", "mnli-mm-bert-adv",
    "mnli-roberta-adv", "mnli-mm-roberta-adv", "snli-bert-adv", "snli-dev", "snli-roberta-adv"
]


def extract_accuracy(lines):
    """Parse eval-result lines into ``{eval_item: accuracy_percent}``.

    Lines are read in records of 5; a record whose first line mentions an item of
    ``eval_list`` takes its value from the last token of the record's third line
    (scaled by 100, rounded to 1 decimal). A trailing partial record is ignored
    instead of raising IndexError.
    """
    results = {}
    for i in range(0, len(lines) - 2, 5):
        for eval_item in eval_list:
            if eval_item in lines[i]:
                results[eval_item] = round(float(lines[i + 2].split()[-1]) * 100, 1)
    return results


def print_table(results, label, items=None):
    """Print a header for ``label`` and a LaTeX-style ``a & b & ...`` row.

    Args:
        results: mapping from eval item name to accuracy (as built by ``extract_accuracy``).
        label: table name shown in the header line.
        items: eval item names (columns) to print, in order; defaults to ``eval_list``.
            Missing entries are printed as ``-``.
    """
    if items is None:
        items = eval_list
    # The keys already carry their -dev/-test/-adv suffix; appending another
    # suffix (as before) never matched, so every cell printed '-'.
    print(f"=====For {label} table=============")
    print(" & ".join(str(results.get(item, '-')) for item in items))


def main():
    """Print the ANLI and TextFooler (``*-adv``) tables for each run under ``root_dir``."""
    anli_items = [item for item in eval_list if not item.endswith("-adv")]
    textfooler_items = [item for item in eval_list if item.endswith("-adv")]
    for file in os.listdir(root_dir):
        cur_path = os.path.join(root_dir, file)
        print(file)

        with open(os.path.join(cur_path, "eval_results.txt")) as f:
            lines = f.readlines()
        results = extract_accuracy(lines)

        print_table(results, "ANLI", anli_items)
        print_table(results, "TextFooler", textfooler_items)


# Guarded so importing this module no longer scans the results directory.
if __name__ == "__main__":
    main()

# --------------------------------------------------------------------------------
# /ANLI/run_anli.py:
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet, RoBERTa, Albert, XLM-RoBERTa)."""

import dataclasses
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Callable, Dict, Optional

import numpy as np
import torch

from transformers import AutoConfig, AutoTokenizer, EvalPrediction
from models.modeling_auto import AutoModelForSequenceClassification
from MI_estimators import CLUB, CLUBv2, InfoNCE
from datasets.anli import GlueDataset, GlueDataTrainingArguments as DataTrainingArguments
from processors.anli import glue_output_modes, glue_tasks_num_labels, glue_compute_metrics
from local_robust_trainer import Trainer
from transformers import (
    HfArgumentParser,
    set_seed,
)
from advtraining_args import TrainingArguments

logger = logging.getLogger(__name__)

# Extra evaluation tasks used when training on 'anli-full' / 'anli-part'.
_ANLI_EXTRA_EVAL_TASKS = [
    "anli-r1", "anli-r2", "anli-r3", "mnli", "mnli-mm", "snli",
    "mnli-bert-adv", "mnli-mm-bert-adv", "snli-bert-adv",
    "mnli-roberta-adv", "mnli-mm-roberta-adv", "snli-roberta-adv",
]


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    load: Optional[str] = field(
        default=None, metadata={"help": "the path to load pretrained models"}
    )
    beta: float = field(
        default=0, metadata={"help": "the regularization term"}
    )
    version: float = field(
        default=-1, metadata={"help": "version of MI estimator"}
    )


def _build_mi_estimators(model_args, data_args, config, device):
    """Create the MI estimators selected by ``model_args.version``.

    Returns:
        ``(mi_estimator, mi_upper_estimator)``; either element may be ``None``.
        A negative version (the default is -1) disables both.

    Raises:
        ValueError: for a non-negative version other than 0-6, instead of leaving
            the estimators undefined for a later NameError.
    """
    version = model_args.version
    hidden_size = config.hidden_size
    if version < 0:
        return None, None
    if version == 0:
        # take the embedding of the whole sentence as variable y
        upper = CLUB(hidden_size * data_args.max_seq_length, hidden_size,
                     beta=model_args.beta).to(device)
        upper.version = 0
        return None, upper
    if version == 1:
        upper = CLUB(hidden_size, hidden_size, beta=model_args.beta).to(device)
        upper.version = 1
        return None, upper
    if version == 2 or version == 3:
        upper = CLUBv2(hidden_size, hidden_size, beta=model_args.beta).to(device)
        upper.version = version
        return None, upper
    if version == 4:
        return InfoNCE(hidden_size, hidden_size).to(device), None
    if version == 5 or version == 6:
        # InfoNCE is built first to keep the original parameter-initialisation order.
        lower = InfoNCE(hidden_size, hidden_size).to(device)
        upper = CLUBv2(hidden_size, hidden_size, beta=model_args.beta).to(device)
        # version 5 pairs InfoNCE with CLUBv2 variant 2, version 6 with variant 3
        upper.version = 2 if version == 5 else 3
        return lower, upper
    raise ValueError(f"Unsupported MI estimator version: {version}")


def _collect_eval_datasets(data_args, tokenizer, eval_dataset):
    """Return every dataset to evaluate: ``eval_dataset`` plus task-specific extras.

    - "mnli" adds the mis-matched dev set.
    - Any ANLI task adds its own test split.
    - "anli-full"/"anli-part" add the dev (and, for ANLI rounds, test) sets of
      ``_ANLI_EXTRA_EVAL_TASKS``.
    """
    eval_datasets = [eval_dataset]
    # Loop to handle MNLI double evaluation (matched, mis-matched)
    if data_args.task_name == "mnli":
        mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
        eval_datasets.append(
            GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="dev")
        )
    if 'anli' in data_args.task_name:
        eval_datasets.append(
            GlueDataset(data_args, tokenizer=tokenizer, mode="test")
        )

    if data_args.task_name == 'anli-full' or data_args.task_name == 'anli-part':
        for task in _ANLI_EXTRA_EVAL_TASKS:
            # Clean MNLI/SNLI data live in sub-directories; adversarial sets do not.
            if "mnli" in task and 'adv' not in task:
                task_data_dir = os.path.join(data_args.data_dir, "MNLI")
            elif task == "snli":
                task_data_dir = os.path.join(data_args.data_dir, "SNLI")
            else:
                task_data_dir = data_args.data_dir
            task_data_args = dataclasses.replace(data_args, task_name=task, data_dir=task_data_dir)
            eval_datasets.append(
                GlueDataset(task_data_args, tokenizer=tokenizer, mode="dev")
            )
            if 'anli' in task:
                eval_datasets.append(
                    GlueDataset(task_data_args, tokenizer=tokenizer, mode="test")
                )
    return eval_datasets


def _write_results(path, title, result):
    """Log ``title`` and each item of ``result``; write ``key = value`` lines to ``path``."""
    with open(path, "w") as writer:
        logger.info(title)
        for key, value in result.items():
            logger.info("  %s = %s", key, value)
            writer.write("%s = %s\n" % (key, value))


def main():
    """Fine-tune and/or evaluate a sequence classifier with optional MI regularisers.

    Returns:
        dict of the evaluation metrics collected when ``--do_eval`` is set.
    """
    # See all possible arguments in src/transformers/advtraining_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    is_main_process = training_args.local_rank in [-1, 0]

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
        and is_main_process
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging: only the main process creates the output dir and writes log.txt.
    if is_main_process:
        os.makedirs(training_args.output_dir, exist_ok=True)

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if is_main_process else logging.WARN,
        handlers=[
            logging.FileHandler(os.path.join(training_args.output_dir, "log.txt")),
            logging.StreamHandler()
        ] if is_main_process else [logging.StreamHandler()]
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    try:
        num_labels = glue_tasks_num_labels[data_args.task_name]
        output_mode = glue_output_modes[data_args.task_name]
    except KeyError:
        raise ValueError("Task not found: %s" % (data_args.task_name)) from None

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
        output_hidden_states=True,
        attention_probs_dropout_prob=training_args.attention_probs_dropout_prob,
        hidden_dropout_prob=training_args.hidden_dropout_prob
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
    )
    mi_estimator, mi_upper_estimator = _build_mi_estimators(
        model_args, data_args, config, training_args.device
    )

    # Get datasets
    train_dataset = (
        GlueDataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
    )
    eval_dataset = (
        GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
        if training_args.do_eval
        else None
    )
    test_dataset = (
        GlueDataset(data_args, tokenizer=tokenizer, mode="test")
        if training_args.do_predict
        else None
    )

    def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
        def compute_metrics_fn(p: EvalPrediction):
            if output_mode == "classification":
                preds = np.argmax(p.predictions, axis=1)
            elif output_mode == "regression":
                preds = np.squeeze(p.predictions)
            return glue_compute_metrics(task_name, preds, p.label_ids)

        return compute_metrics_fn

    if model_args.load is not None:
        print(model_args.load)
        model.load_state_dict(torch.load(os.path.join(model_args.load, "pytorch_model.bin")))
        if mi_estimator:
            mi_estimator.load_state_dict(torch.load(os.path.join(model_args.load, "mi_estimator.bin")))
        logger.info(f"Load successful from {model_args.load}")

    if os.path.isdir(model_args.model_name_or_path):
        if mi_estimator:
            mi_estimator.load_state_dict(torch.load(os.path.join(model_args.model_name_or_path, "mi_estimator.bin")))
            logger.info(f"Load mi estimator successful from {model_args.model_name_or_path}")

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=build_compute_metrics_fn(data_args.task_name),
        mi_estimator=mi_estimator,
        mi_upper_estimator=mi_upper_estimator
    )
    trainer.tokenizer = tokenizer

    # Training
    if training_args.do_train:
        trainer.train(
            model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
        )
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)
            # Like the tokenizer, these artefacts are written by one process only.
            if mi_estimator:
                torch.save(mi_estimator.state_dict(), os.path.join(training_args.output_dir, "mi_estimator.bin"))
            torch.save(trainer.eval_hist, os.path.join(training_args.output_dir, 'eval_hist.bin'))

    # Evaluation
    eval_results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        eval_datasets = _collect_eval_datasets(data_args, tokenizer, eval_dataset)

        for dataset in eval_datasets:
            trainer.compute_metrics = build_compute_metrics_fn(dataset.args.task_name)
            eval_result = trainer.evaluate(eval_dataset=dataset)

            output_eval_file = os.path.join(
                training_args.output_dir, f"eval_results_{dataset.args.task_name}-{dataset.mode}.txt"
            )
            if trainer.is_world_master():
                _write_results(
                    output_eval_file,
                    f"***** Eval results {dataset.args.task_name}-{dataset.mode} *****",
                    eval_result,
                )

            eval_results.update(eval_result)

        if trainer.eval_hist:
            # First record with the highest accuracy (same tie-break as a strict '>' scan).
            best_eval = max(trainer.eval_hist, key=lambda record: record['eval_acc'])
            if trainer.is_world_master():
                _write_results(
                    os.path.join(training_args.output_dir, f"best_eval_results_{data_args.task_name}_.txt"),
                    "***** Best Eval results {} *****".format(data_args.task_name),
                    best_eval,
                )

            # Free the current model before re-evaluating the best checkpoint.
            del trainer.model
            torch.cuda.empty_cache()
            checkpoint = os.path.join(training_args.output_dir, f"checkpoint-{best_eval['step']}")
            trainer.model = AutoModelForSequenceClassification.from_pretrained(checkpoint).to(training_args.device)
            logger.info(f"successfully load from {checkpoint}")

            for dataset in eval_datasets:
                trainer.compute_metrics = build_compute_metrics_fn(dataset.args.task_name)
                eval_result = trainer.evaluate(eval_dataset=dataset)

                output_eval_file = os.path.join(
                    training_args.output_dir, f"best_eval_results_{dataset.args.task_name}-{dataset.mode}.txt"
                )
                if trainer.is_world_master():
                    _write_results(
                        output_eval_file,
                        f"***** Best Eval results {dataset.args.task_name}--{dataset.mode} *****",
                        eval_result,
                    )

    if training_args.do_predict:
        logger.info("*** Test ***")
        test_datasets = [test_dataset]
        if data_args.task_name == "mnli":
            mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
            # NOTE(review): no other GlueDataset call here passes cache_dir -- confirm the local class accepts it.
            test_datasets.append(
                GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="test", cache_dir=model_args.cache_dir)
            )

        for dataset in test_datasets:
            predictions = trainer.predict(test_dataset=dataset).predictions
            if output_mode == "classification":
                predictions = np.argmax(predictions, axis=1)

            output_test_file = os.path.join(
                training_args.output_dir, f"test_results_{dataset.args.task_name}.txt"
            )
            if trainer.is_world_master():
                with open(output_test_file, "w") as writer:
                    logger.info("***** Test results {} *****".format(dataset.args.task_name))
                    writer.write("index\tprediction\n")
                    for index, item in enumerate(predictions):
                        if output_mode == "regression":
                            writer.write("%d\t%3.3f\n" % (index, item))
                        else:
                            item = dataset.get_labels()[item]
                            writer.write("%d\t%s\n" % (index, item))
    return eval_results


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()

# --------------------------------------------------------------------------------
# /ANLI/run_glue.py:
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet, RoBERTa, Albert, XLM-RoBERTa).""" 17 | 18 | 19 | import dataclasses 20 | import logging 21 | import os 22 | import sys 23 | from dataclasses import dataclass, field 24 | from typing import Callable, Dict, Optional 25 | 26 | import numpy as np 27 | import torch 28 | 29 | from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction, GlueDataset 30 | from transformers import GlueDataTrainingArguments as DataTrainingArguments 31 | from transformers import ( 32 | HfArgumentParser, 33 | Trainer, 34 | TrainingArguments, 35 | glue_compute_metrics, 36 | glue_output_modes, 37 | glue_tasks_num_labels, 38 | set_seed, 39 | ) 40 | 41 | 42 | logger = logging.getLogger(__name__) 43 | 44 | 45 | @dataclass 46 | class ModelArguments: 47 | """ 48 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
49 | """ 50 | 51 | model_name_or_path: str = field( 52 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 53 | ) 54 | config_name: Optional[str] = field( 55 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 56 | ) 57 | tokenizer_name: Optional[str] = field( 58 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 59 | ) 60 | cache_dir: Optional[str] = field( 61 | default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} 62 | ) 63 | load: Optional[str] = field( 64 | default=None, metadata={"help": "the path to load pretrained models"} 65 | ) 66 | 67 | 68 | def main(): 69 | # See all possible arguments in src/transformers/advtraining_args.py 70 | # or by passing the --help flag to this script. 71 | # We now keep distinct sets of args, for a cleaner separation of concerns. 72 | 73 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 74 | 75 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 76 | # If we pass only one argument to the script and it's the path to a json file, 77 | # let's parse it to get our arguments. 78 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 79 | else: 80 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 81 | 82 | 83 | if ( 84 | os.path.exists(training_args.output_dir) 85 | and os.listdir(training_args.output_dir) 86 | and training_args.do_train 87 | and not training_args.overwrite_output_dir 88 | ): 89 | raise ValueError( 90 | f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." 
91 | ) 92 | 93 | # Setup logging 94 | root_dir = training_args.output_dir 95 | if not os.path.exists(root_dir): 96 | os.mkdir(root_dir) 97 | logging.basicConfig( 98 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 99 | datefmt="%m/%d/%Y %H:%M:%S", 100 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 101 | handlers=[ 102 | logging.FileHandler(os.path.join(training_args.output_dir, "log.txt")), 103 | logging.StreamHandler() 104 | ] 105 | ) 106 | logger.warning( 107 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 108 | training_args.local_rank, 109 | training_args.device, 110 | training_args.n_gpu, 111 | bool(training_args.local_rank != -1), 112 | training_args.fp16, 113 | ) 114 | logger.info("Training/evaluation parameters %s", training_args) 115 | 116 | # Set seed 117 | set_seed(training_args.seed) 118 | 119 | try: 120 | num_labels = glue_tasks_num_labels[data_args.task_name] 121 | output_mode = glue_output_modes[data_args.task_name] 122 | except KeyError: 123 | raise ValueError("Task not found: %s" % (data_args.task_name)) 124 | 125 | # Load pretrained model and tokenizer 126 | # 127 | # Distributed training: 128 | # The .from_pretrained methods guarantee that only one local process can concurrently 129 | # download model & vocab. 
130 | 131 | config = AutoConfig.from_pretrained( 132 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 133 | num_labels=num_labels, 134 | finetuning_task=data_args.task_name, 135 | cache_dir=model_args.cache_dir, 136 | ) 137 | tokenizer = AutoTokenizer.from_pretrained( 138 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 139 | cache_dir=model_args.cache_dir, 140 | ) 141 | model = AutoModelForSequenceClassification.from_pretrained( 142 | model_args.model_name_or_path, 143 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 144 | config=config, 145 | cache_dir=model_args.cache_dir, 146 | ) 147 | 148 | # Get datasets 149 | train_dataset = ( 150 | GlueDataset(data_args, tokenizer=tokenizer) if training_args.do_train else None 151 | ) 152 | eval_dataset = ( 153 | GlueDataset(data_args, tokenizer=tokenizer, mode="dev") 154 | if training_args.do_eval 155 | else None 156 | ) 157 | test_dataset = ( 158 | GlueDataset(data_args, tokenizer=tokenizer, mode="test") 159 | if training_args.do_predict 160 | else None 161 | ) 162 | 163 | def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]: 164 | def compute_metrics_fn(p: EvalPrediction): 165 | if output_mode == "classification": 166 | preds = np.argmax(p.predictions, axis=1) 167 | elif output_mode == "regression": 168 | preds = np.squeeze(p.predictions) 169 | return glue_compute_metrics(task_name, preds, p.label_ids) 170 | 171 | return compute_metrics_fn 172 | 173 | if model_args.load is not None: 174 | model.load_state_dict(torch.load(model_args.load)) 175 | logger.info(f"Load successful from {model_args.load}") 176 | 177 | # Initialize our Trainer 178 | trainer = Trainer( 179 | model=model, 180 | args=training_args, 181 | train_dataset=train_dataset, 182 | eval_dataset=eval_dataset, 183 | compute_metrics=build_compute_metrics_fn(data_args.task_name), 184 | ) 185 | 186 | # Training 187 | if training_args.do_train: 188 
| trainer.train( 189 | model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None 190 | ) 191 | trainer.save_model() 192 | # For convenience, we also re-save the tokenizer to the same directory, 193 | # so that you can share your model easily on huggingface.co/models =) 194 | if trainer.is_world_master(): 195 | tokenizer.save_pretrained(training_args.output_dir) 196 | 197 | # Evaluation 198 | eval_results = {} 199 | if training_args.do_eval: 200 | logger.info("*** Evaluate ***") 201 | 202 | # Loop to handle MNLI double evaluation (matched, mis-matched) 203 | eval_datasets = [eval_dataset] 204 | if data_args.task_name == "mnli": 205 | mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm") 206 | eval_datasets.append( 207 | GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="dev") 208 | ) 209 | 210 | for eval_dataset in eval_datasets: 211 | trainer.compute_metrics = build_compute_metrics_fn(eval_dataset.args.task_name) 212 | eval_result = trainer.evaluate(eval_dataset=eval_dataset) 213 | 214 | output_eval_file = os.path.join( 215 | training_args.output_dir, f"eval_results_{eval_dataset.args.task_name}.txt" 216 | ) 217 | if trainer.is_world_master(): 218 | with open(output_eval_file, "w") as writer: 219 | logger.info("***** Eval results {} *****".format(eval_dataset.args.task_name)) 220 | for key, value in eval_result.items(): 221 | logger.info(" %s = %s", key, value) 222 | writer.write("%s = %s\n" % (key, value)) 223 | 224 | eval_results.update(eval_result) 225 | 226 | if training_args.do_predict: 227 | logging.info("*** Test ***") 228 | test_datasets = [test_dataset] 229 | if data_args.task_name == "mnli": 230 | mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm") 231 | test_datasets.append( 232 | GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="test", cache_dir=model_args.cache_dir) 233 | ) 234 | 235 | for test_dataset in test_datasets: 236 | predictions = 
trainer.predict(test_dataset=test_dataset).predictions 237 | if output_mode == "classification": 238 | predictions = np.argmax(predictions, axis=1) 239 | 240 | output_test_file = os.path.join( 241 | training_args.output_dir, f"test_results_{test_dataset.args.task_name}.txt" 242 | ) 243 | if trainer.is_world_master(): 244 | with open(output_test_file, "w") as writer: 245 | logger.info("***** Test results {} *****".format(test_dataset.args.task_name)) 246 | writer.write("index\tprediction\n") 247 | for index, item in enumerate(predictions): 248 | if output_mode == "regression": 249 | writer.write("%d\t%3.3f\n" % (index, item)) 250 | else: 251 | item = test_dataset.get_labels()[item] 252 | writer.write("%d\t%s\n" % (index, item)) 253 | return eval_results 254 | 255 | 256 | def _mp_fn(index): 257 | # For xla_spawn (TPUs) 258 | main() 259 | 260 | 261 | if __name__ == "__main__": 262 | main() -------------------------------------------------------------------------------- /ANLI/setup.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | function runexp { 4 | 5 | export GLUE_DIR=anli_data 6 | export TASK_NAME=${1} 7 | 8 | custom=${2} # Custom name 9 | mname=${3} # Model name 10 | lr=${4} # Learning rate for model parameters 11 | bsize=${5} # Batch size 12 | seqlen=${6} # Maximum sequence length 13 | ts=${7} # Number of training steps (counted as parameter updates) 14 | ws=${8} # Learning rate warm-up steps 15 | seed=${9} # Seed for randomness 16 | wd=${10} # Weight decay 17 | beta=${11} # regularizer coefficient 18 | version=${12} # mi estimator version 19 | hdp=${13} # Hidden layer dropouts for ALBERT 20 | adp=${14} # Attention dropouts for ALBERT 21 | alr=${15} # Step size of gradient ascent 22 | amag=${16} # Magnitude of initial (adversarial?) 
perturbation 23 | anorm=${17} # Maximum norm of adversarial perturbation 24 | asteps=${18} # Number of gradient ascent steps for the adversary 25 | alpha=${19} # alpha for controlling local robust regularizer 26 | cl=${20} # lower threshold 27 | ch=${21} # higher threshold 28 | 29 | if [ -z "${20}" ] ;then 30 | cl=0.5 31 | fi 32 | 33 | if [ -z "${21}" ] ;then 34 | ch=0.9 35 | fi 36 | 37 | expname=${custom}-${mname}-${TASK_NAME}-sl${seqlen}-lr${lr}-bs${bsize}-ts${ts}-ws${ws}-wd${wd}-seed${seed}-beta${beta}-alpha${alpha}--cl${cl}-ch${ch}-alr${alr}-amag${amag}-anm${anorm}-as${asteps}-hdp${hdp}-adp${adp}-version${version} 38 | 39 | max=-1 40 | for file in ${expname}/checkpoint-* 41 | do 42 | fname=$(basename ${file}) 43 | num=${fname:11} 44 | [[ $num -gt $max ]] && max=$num 45 | done 46 | 47 | if [ $max -eq -1 ] ; then 48 | echo "Train from stratch" 49 | else 50 | FILE="${expname}/checkpoint-$max/eval_hist.bin" 51 | if [[ -f "$FILE" ]]; then 52 | echo "$FILE exists." 53 | mname="${expname}/checkpoint-$max" && echo "Resume Training from checkpoint $mname" 54 | else 55 | echo "$FILE does not exists." 56 | if [ $max -eq 100 ]; then 57 | max=-1 58 | echo "Train from stratch" 59 | else 60 | max=$(($max-100)) 61 | FILE="${expname}/checkpoint-$max/eval_hist.bin" 62 | echo "use $FILE instead." 
63 | mname="${expname}/checkpoint-$max" && echo "Resume Training from checkpoint $mname" 64 | fi 65 | fi 66 | fi 67 | 68 | 69 | port=$(($RANDOM + 1024)) 70 | echo "Master port: ${port}" 71 | python -m torch.distributed.launch --nproc_per_node=1 --master_port ${port} ./run_anli.py \ 72 | --model_name_or_path ${mname} \ 73 | --task_name $TASK_NAME \ 74 | --do_train \ 75 | --do_eval \ 76 | --data_dir $GLUE_DIR \ 77 | --max_seq_length ${seqlen} \ 78 | --per_device_train_batch_size ${bsize} \ 79 | --learning_rate ${lr} \ 80 | --max_steps ${ts} \ 81 | --warmup_steps ${ws} \ 82 | --weight_decay ${wd} \ 83 | --seed ${seed} \ 84 | --beta ${beta} \ 85 | --logging_dir ${expname} \ 86 | --output_dir ${expname} \ 87 | --version ${version} --evaluate_during_training\ 88 | --logging_steps 500 --save_steps 500 \ 89 | --hidden_dropout_prob ${hdp} --attention_probs_dropout_prob ${adp} --overwrite_output_dir \ 90 | --adv_lr ${alr} --adv_init_mag ${amag} --adv_max_norm ${anorm} --adv_steps ${asteps} --alpha ${alpha} \ 91 | --cl ${cl} --ch ${ch} 92 | } 93 | 94 | 95 | 96 | 97 | function evalexp { 98 | #export NCCL_DEBUG=INFO 99 | #export NCCL_IB_CUDA_SUPPORT=0 100 | #export NCCL_P2P_DISABLE=0 101 | #export NCCL_IB_DISABLE=1 102 | #export NCCL_NET_GDR_LEVEL=3 103 | #export NCCL_NET_GDR_READ=0 104 | #export NCCL_SHM_DISABLE=0 105 | 106 | export GLUE_DIR=anli_data 107 | export TASK_NAME="anli-full" 108 | 109 | mname=${2} # Model name 110 | custom=${1} # Custom name 111 | lr=5e-3 # Learning rate for model parameters 112 | bsize=32 # Batch size 113 | seqlen=128 # Maximum sequence length 114 | ts=0 # Number of training steps (counted as parameter updates) 115 | ws=0 # Learning rate warm-up steps 116 | seed=42 # Seed for randomness 117 | wd=1e-5 # Weight decay 118 | beta=0 # regularizer coefficient 119 | version=3 # mi estimator version 120 | hdp=0 # Hidden layer dropouts for ALBERT 121 | adp=0 # Attention dropouts for ALBERT 122 | alr=0 # Step size of gradient ascent 123 | amag=0 # Magnitude 
of initial (adversarial?) perturbation 124 | anorm=0 # Maximum norm of adversarial perturbation 125 | asteps=1 # Number of gradient ascent steps for the adversary 126 | alpha=0 # alpha for controlling local robust regularizer 127 | cl=0 # lower threshold 128 | ch=0 # higher threshold 129 | 130 | expname=${custom}-load 131 | port=$(($RANDOM + 1024)) 132 | echo "Master port: ${port}" 133 | 134 | python -m torch.distributed.launch --nproc_per_node=1 --master_port ${port} ./run_anli.py \ 135 | --model_name_or_path ${mname} \ 136 | --task_name $TASK_NAME \ 137 | --do_eval \ 138 | --data_dir $GLUE_DIR \ 139 | --max_seq_length ${seqlen} \ 140 | --per_device_train_batch_size ${bsize} \ 141 | --learning_rate 3e-5 \ 142 | --max_steps ${ts} \ 143 | --warmup_steps ${ws} \ 144 | --weight_decay ${wd} \ 145 | --seed ${seed} \ 146 | --beta ${beta} \ 147 | --logging_dir ${expname} \ 148 | --output_dir ${expname} \ 149 | --version ${version} \ 150 | --logging_steps 100 --save_steps 100 \ 151 | --hidden_dropout_prob ${hdp} --attention_probs_dropout_prob ${adp} --evaluate_during_training \ 152 | --adv_lr ${alr} --adv_init_mag ${amag} --adv_max_norm ${anorm} --adv_steps ${asteps} --overwrite_output_dir \ 153 | --alpha ${alpha} --cl ${cl} --ch ${ch} 154 | 155 | tail -n +1 "${expname}"/eval_results*.txt > "${expname}"/eval_results.txt 156 | cat "${expname}"/eval_results.txt 157 | } 158 | 159 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # InfoBERT: Improving Robustness of Language Models from An Information Theoretic Perspective 2 | 3 | This is the official code base for our ICLR 2021 paper: 4 | 5 | ["InfoBERT: Improving Robustness of Language Models from An Information Theoretic Perspective".](https://openreview.net/forum?id=hpH98mK5Puk) 6 | 7 | Boxin Wang, Shuohang Wang, Yu Cheng, Zhe Gan, Ruoxi Jia, Bo Li, Jingjing Liu 8 | 9 | ## Usage 10 | ### Prepare 
your environment 11 | 12 | Download required packages 13 | ``` 14 | pip install -r requirements.txt 15 | ``` 16 | ### ANLI and TextFooler 17 | To run ANLI and TextFooler experiments, refer to [README](https://github.com/AI-secure/InfoBERT/tree/master/ANLI) in the `ANLI` directory. 18 | 19 | ### SQuAD 20 | To run SQuAD experiments, refer to [README](https://github.com/AI-secure/InfoBERT/tree/master/SQuAD) in the `SQuAD` directory. 21 | 22 | ## Citation 23 | ``` 24 | @inproceedings{ 25 | wang2021infobert, 26 | title={InfoBERT: Improving Robustness of Language Models from An Information Theoretic Perspective}, 27 | author={Wang, Boxin and Wang, Shuohang and Cheng, Yu and Gan, Zhe and Jia, Ruoxi and Li, Bo and Liu, Jingjing}, 28 | booktitle={International Conference on Learning Representations}, 29 | year={2021}} 30 | ``` -------------------------------------------------------------------------------- /SQuAD/MI_estimators.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | from jutils import * 6 | 7 | ## cubic 8 | # lowersize = 40 9 | # hiddensize = 6 10 | 11 | ## Gaussian 12 | # lowersize = 20 13 | # hiddensize = 8 14 | 15 | ## club vs l1out 16 | lowersize = 40 17 | hiddensize = 8 18 | 19 | 20 | class CLUB(nn.Module): # CLUB: Mutual Information Contrastive Learning Upper Bound 21 | def __init__(self, x_dim, y_dim, lr=1e-3, beta=0): 22 | super(CLUB, self).__init__() 23 | self.hiddensize = y_dim 24 | self.version = 0 25 | self.p_mu = nn.Sequential(nn.Linear(x_dim, self.hiddensize), 26 | nn.ReLU(), 27 | nn.Linear(self.hiddensize, y_dim)) 28 | 29 | self.p_logvar = nn.Sequential(nn.Linear(x_dim, self.hiddensize), 30 | nn.ReLU(), 31 | nn.Linear(self.hiddensize, y_dim), 32 | nn.Tanh()) 33 | 34 | self.optimizer = torch.optim.Adam(self.parameters(), lr) 35 | self.beta = beta 36 | 37 | def get_mu_logvar(self, x_samples): 38 | mu = self.p_mu(x_samples) 39 | logvar = 
self.p_logvar(x_samples) 40 | return mu, logvar 41 | 42 | def mi_est_sample(self, x_samples, y_samples): 43 | mu, logvar = self.get_mu_logvar(x_samples) 44 | 45 | sample_size = x_samples.shape[0] 46 | random_index = torch.randint(sample_size, (sample_size,)).long() 47 | 48 | positive = - (mu - y_samples) ** 2 / 2. / logvar.exp() 49 | negative = - (mu - y_samples[random_index]) ** 2 / 2. / logvar.exp() 50 | upper_bound = (positive.sum(dim=-1) - negative.sum(dim=-1)).mean() 51 | # return upper_bound/2. 52 | return upper_bound 53 | 54 | def mi_est(self, x_samples, y_samples): # [nsample, 1] 55 | mu, logvar = self.get_mu_logvar(x_samples) 56 | 57 | positive = - (mu - y_samples) ** 2 / 2. / logvar.exp() 58 | 59 | prediction_1 = mu.unsqueeze(1) # [nsample,1,dim] 60 | y_samples_1 = y_samples.unsqueeze(0) # [1,nsample,dim] 61 | negative = - ((y_samples_1 - prediction_1) ** 2).mean(dim=1) / 2. / logvar.exp() # [nsample, dim] 62 | return (positive.sum(dim=-1) - negative.sum(dim=-1)).mean() 63 | # return (positive.sum(dim = -1) - negative.sum(dim = -1)).mean(), positive.sum(dim = -1).mean(), negative.sum(dim = -1).mean() 64 | 65 | def loglikeli(self, x_samples, y_samples): 66 | mu, logvar = self.get_mu_logvar(x_samples) 67 | 68 | # return -1./2. 
* ((mu - y_samples)**2 /logvar.exp()-logvar ).sum(dim=1).mean(dim=0) 69 | return (-(mu - y_samples) ** 2 / logvar.exp() - logvar).sum(dim=1).mean(dim=0) 70 | 71 | def update(self, x_samples, y_samples): 72 | if self.version == 0: 73 | self.train() 74 | loss = - self.loglikeli(x_samples, y_samples) 75 | 76 | self.optimizer.zero_grad() 77 | loss.backward(retain_graph=True) 78 | self.optimizer.step() 79 | 80 | # self.eval() 81 | return self.mi_est_sample(x_samples, y_samples) * self.beta 82 | 83 | elif self.version == 1: 84 | self.train() 85 | x_samples = torch.reshape(x_samples, (-1, x_samples.shape[-1])) 86 | y_samples = torch.reshape(y_samples, (-1, y_samples.shape[-1])) 87 | 88 | loss = -self.loglikeli(x_samples, y_samples) 89 | 90 | self.optimizer.zero_grad() 91 | loss.backward(retain_graph=True) 92 | self.optimizer.step() 93 | upper_bound = self.mi_est_sample(x_samples, y_samples) * self.beta 94 | # self.eval() 95 | return upper_bound 96 | 97 | 98 | class CLUBv2(nn.Module): # CLUB: Mutual Information Contrastive Learning Upper Bound 99 | def __init__(self, x_dim, y_dim, lr=1e-3, beta=0): 100 | super(CLUBv2, self).__init__() 101 | self.hiddensize = y_dim 102 | self.version = 2 103 | self.beta = beta 104 | 105 | def mi_est_sample(self, x_samples, y_samples): 106 | sample_size = y_samples.shape[0] 107 | random_index = torch.randint(sample_size, (sample_size,)).long() 108 | 109 | positive = torch.zeros_like(y_samples) 110 | negative = - (y_samples - y_samples[random_index]) ** 2 / 2. 111 | upper_bound = (positive.sum(dim=-1) - negative.sum(dim=-1)).mean() 112 | # return upper_bound/2. 113 | return upper_bound 114 | 115 | def mi_est(self, x_samples, y_samples): # [nsample, 1] 116 | positive = torch.zeros_like(y_samples) 117 | 118 | prediction_1 = y_samples.unsqueeze(1) # [nsample,1,dim] 119 | y_samples_1 = y_samples.unsqueeze(0) # [1,nsample,dim] 120 | negative = - ((y_samples_1 - prediction_1) ** 2).mean(dim=1) / 2. 
# [nsample, dim] 121 | return (positive.sum(dim=-1) - negative.sum(dim=-1)).mean() 122 | # return (positive.sum(dim = -1) - negative.sum(dim = -1)).mean(), positive.sum(dim = -1).mean(), negative.sum(dim = -1).mean() 123 | 124 | def loglikeli(self, x_samples, y_samples): 125 | return 0 126 | 127 | def update(self, x_samples, y_samples, steps=None): 128 | # no performance improvement, not enabled 129 | if steps: 130 | beta = self.beta if steps > 1000 else self.beta * steps / 1000 # beta anealing 131 | else: 132 | beta = self.beta 133 | 134 | return self.mi_est_sample(x_samples, y_samples) * self.beta 135 | 136 | 137 | 138 | class InfoNCE(nn.Module): 139 | def __init__(self, x_dim, y_dim): 140 | super(InfoNCE, self).__init__() 141 | self.lower_size = 300 142 | self.F_func = nn.Sequential(nn.Linear(x_dim + y_dim, self.lower_size), 143 | nn.ReLU(), 144 | nn.Linear(self.lower_size, 1), 145 | nn.Softplus()) 146 | 147 | def forward(self, x_samples, y_samples): # samples have shape [sample_size, dim] 148 | # shuffle and concatenate 149 | sample_size = y_samples.shape[0] 150 | random_index = torch.randint(sample_size, (sample_size,)).long() 151 | 152 | x_tile = x_samples.unsqueeze(0).repeat((sample_size, 1, 1)) 153 | y_tile = y_samples.unsqueeze(1).repeat((1, sample_size, 1)) 154 | 155 | T0 = self.F_func(torch.cat([x_samples, y_samples], dim=-1)) 156 | T1 = self.F_func(torch.cat([x_tile, y_tile], dim=-1)) # [s_size, s_size, 1] 157 | 158 | lower_bound = T0.mean() - ( 159 | T1.logsumexp(dim=1).mean() - np.log(sample_size)) # torch.log(T1.exp().mean(dim = 1)).mean() 160 | 161 | # compute the negative loss (maximise loss == minimise -loss) 162 | return lower_bound 163 | 164 | 165 | class NWJ(nn.Module): 166 | def __init__(self, x_dim, y_dim): 167 | super(NWJ, self).__init__() 168 | self.F_func = nn.Sequential(nn.Linear(x_dim + y_dim, lowersize), 169 | nn.ReLU(), 170 | nn.Linear(lowersize, 1)) 171 | 172 | def mi_est(self, x_samples, y_samples): # samples have shape [sample_size, 
dim] 173 | # shuffle and concatenate 174 | sample_size = y_samples.shape[0] 175 | # random_index = torch.randint(sample_size, (sample_size,)).long() 176 | 177 | x_tile = x_samples.unsqueeze(0).repeat((sample_size, 1, 1)) 178 | y_tile = y_samples.unsqueeze(1).repeat((1, sample_size, 1)) 179 | 180 | T0 = self.F_func(torch.cat([x_samples, y_samples], dim=-1)) 181 | T1 = self.F_func(torch.cat([x_tile, y_tile], dim=-1)) - 1. # [s_size, s_size, 1] 182 | 183 | lower_bound = T0.mean() - (T1.logsumexp(dim=1) - np.log(sample_size)).exp().mean() 184 | return lower_bound 185 | 186 | 187 | 188 | class L1OutUB(nn.Module): # naive upper bound 189 | def __init__(self, x_dim, y_dim): 190 | super(L1OutUB, self).__init__() 191 | self.p_mu = nn.Sequential(nn.Linear(x_dim, hiddensize), 192 | nn.ReLU(), 193 | nn.Linear(hiddensize, y_dim)) 194 | 195 | self.p_logvar = nn.Sequential(nn.Linear(x_dim, hiddensize), 196 | nn.ReLU(), 197 | nn.Linear(hiddensize, y_dim), 198 | nn.Tanh()) 199 | 200 | def get_mu_logvar(self, x_samples): 201 | mu = self.p_mu(x_samples) 202 | logvar = self.p_logvar(x_samples) 203 | return mu, logvar 204 | 205 | def mi_est(self, x_samples, y_samples): # [nsample, 1] 206 | batch_size = y_samples.shape[0] 207 | mu, logvar = self.get_mu_logvar(x_samples) 208 | 209 | positive = (- (mu - y_samples) ** 2 / 2. / logvar.exp() - logvar / 2.).sum(dim=-1) # [nsample] 210 | 211 | mu_1 = mu.unsqueeze(1) # [nsample,1,dim] 212 | logvar_1 = logvar.unsqueeze(1) 213 | y_samples_1 = y_samples.unsqueeze(0) # [1,nsample,dim] 214 | all_probs = (- (y_samples_1 - mu_1) ** 2 / 2. / logvar_1.exp() - logvar_1 / 2.).sum( 215 | dim=-1) # [nsample, nsample] 216 | 217 | # diag_mask = torch.ones([batch_size, batch_size,1]).cuda() - torch.ones([batch_size]).diag().unsqueeze(-1).cuda() 218 | diag_mask = torch.ones([batch_size]).diag().unsqueeze(-1).cuda() * (-20.) 219 | 220 | # negative = (all_probs + diag_mask).logsumexp(dim = 0) - np.log(y_samples.shape[0]-1.) 
#[nsample] 221 | inpt = all_probs + diag_mask 222 | negative = log_sum_exp(all_probs + diag_mask, dim=0) - np.log(y_samples.shape[0] - 1.) # [nsample] 223 | return (positive - negative).mean() 224 | 225 | def loglikeli(self, x_samples, y_samples): 226 | mu, logvar = self.get_mu_logvar(x_samples) 227 | # return -1./2. * ((mu - y_samples)**2 /logvar.exp()-logvar ).sum(dim=1).mean(dim=0) 228 | return (-(mu - y_samples) ** 2 / logvar.exp() - logvar).sum(dim=1).mean(dim=0) 229 | 230 | 231 | class VarUB(nn.Module): # variational upper bound 232 | def __init__(self, x_dim, y_dim): 233 | super(VarUB, self).__init__() 234 | self.p_mu = nn.Sequential(nn.Linear(x_dim, hiddensize), 235 | nn.ReLU(), 236 | nn.Linear(hiddensize, y_dim)) 237 | 238 | self.p_logvar = nn.Sequential(nn.Linear(x_dim, hiddensize), 239 | nn.ReLU(), 240 | nn.Linear(hiddensize, y_dim), 241 | nn.Tanh()) 242 | 243 | def get_mu_logvar(self, x_samples): 244 | mu = self.p_mu(x_samples) 245 | logvar = self.p_logvar(x_samples) 246 | return mu, logvar 247 | 248 | def mi_est(self, x_samples, y_samples): # [nsample, 1] 249 | mu, logvar = self.get_mu_logvar(x_samples) 250 | return 1. / 2. * (mu ** 2 + logvar.exp() - 1. - logvar).mean() 251 | 252 | def loglikeli(self, x_samples, y_samples): 253 | mu, logvar = self.get_mu_logvar(x_samples) 254 | # return -1./2. 
* ((mu - y_samples)**2 /logvar.exp()-logvar ).sum(dim=1).mean(dim=0) 255 | return (-(mu - y_samples) ** 2 / logvar.exp() - logvar).sum(dim=1).mean(dim=0) 256 | -------------------------------------------------------------------------------- /SQuAD/README.md: -------------------------------------------------------------------------------- 1 | # InfoBERT on SQuAD 2 | 3 | ## Train 4 | 5 | Running standard SQuAD training: 6 | ```bash 7 | # [runstandardsquad] [custom] [gpu] [bszie] [beta] [version] [mname] 8 | source setup.sh && runstandardsquad squad-s 4 8 0 3 roberta-large 9 | ``` 10 | 11 | Running InfoBERT for SQuAD: 12 | 13 | ```bash 14 | ## runsquad [custom] [gpu] [bszie] [beta] [version] [hdp] [adp] [alr] [amag] [anorm] [asteps] [mname] [alpha] [cl] [ch] 15 | source setup.sh && runsquad roberta 4 8 5e-5 6 0.1 0.1 2e-2 2e-2 4e-2 2 roberta-large 5e-3 0.75 0.95 16 | ``` 17 | 18 | ## Evaluate 19 | 20 | Run standard evaluation for SQuAD models 21 | 22 | ```bash 23 | ## [evalsquad] [custom] [gpu] [bszie] [beta] [version] [hdp] [adp] [alr] [amag] [anorm] [asteps] [mname] 24 | #source setup.sh && evalsquad squad 4 10 0 3 0.1 0 4e-2 8e-2 0 3 squad-bert-large-uncased-whole-word-masking-sl384-lr3e-5-bs6-beta0-alr4e-2-amag8e-2-anm0-as3-hdp0.1-adp0-version3 25 | ``` -------------------------------------------------------------------------------- /SQuAD/eval_adv_squad.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | 3 | """Do an adversarial evaluation of SQuAD.""" 4 | import argparse 5 | from collections import Counter, OrderedDict, defaultdict 6 | import json 7 | import re 8 | import string 9 | import sys 10 | 11 | OPTS = None 12 | 13 | ### BEGIN: official SQuAD code version 1.1 14 | ### See https://rajpurkar.github.io/SQuAD-explorer/ 15 | def normalize_answer(s): 16 | """Lower text and remove punctuation, articles and extra whitespace.""" 17 | def remove_articles(text): 18 | return re.sub(r'\b(a|an|the)\b', ' 
', text) 19 | 20 | def white_space_fix(text): 21 | return ' '.join(text.split()) 22 | 23 | def remove_punc(text): 24 | exclude = set(string.punctuation) 25 | return ''.join(ch for ch in text if ch not in exclude) 26 | 27 | def lower(text): 28 | return text.lower() 29 | 30 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 31 | 32 | 33 | def f1_score(prediction, ground_truth): 34 | prediction_tokens = normalize_answer(prediction).split() 35 | ground_truth_tokens = normalize_answer(ground_truth).split() 36 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 37 | num_same = sum(common.values()) 38 | if num_same == 0: 39 | return 0 40 | precision = 1.0 * num_same / len(prediction_tokens) 41 | recall = 1.0 * num_same / len(ground_truth_tokens) 42 | f1 = (2 * precision * recall) / (precision + recall) 43 | return f1 44 | 45 | 46 | def exact_match_score(prediction, ground_truth): 47 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 48 | 49 | 50 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 51 | scores_for_ground_truths = [] 52 | for ground_truth in ground_truths: 53 | score = metric_fn(prediction, ground_truth) 54 | scores_for_ground_truths.append(score) 55 | return max(scores_for_ground_truths) 56 | ### END: official SQuAD code 57 | 58 | def strip_id(id_str): 59 | return id_str.split('-')[0] 60 | 61 | def highlight_after(s, n): 62 | return s[:n] + colored(s[n:], 'cyan') 63 | 64 | def get_answer_color(pred, answers): 65 | ans_texts = [a['text'] for a in answers] 66 | exact = metric_max_over_ground_truths(exact_match_score, pred, ans_texts) 67 | if exact: return 'green' 68 | f1 = metric_max_over_ground_truths(f1_score, pred, ans_texts) 69 | if f1: return 'yellow' 70 | return 'red' 71 | 72 | 73 | def print_details(dataset, predictions, adv_ids): 74 | id_to_paragraph = {} 75 | for article in dataset: 76 | for paragraph in article['paragraphs']: 77 | for qa in paragraph['qas']: 78 | 
id_to_paragraph[qa['id']] = paragraph['context'] 79 | for article in dataset: 80 | for paragraph in article['paragraphs']: 81 | for qa in paragraph['qas']: 82 | orig_id = strip_id(qa['id']) 83 | if orig_id != qa['id']: continue # Skip the mutated ones 84 | adv_id = adv_ids[orig_id] 85 | print 'Title: %s' % article['title'].encode('utf-8') 86 | print 'Paragraph: %s' % paragraph['context'].encode('utf-8') 87 | print 'Question: %s' % qa['question'].encode('utf-8') 88 | print 'Answers: [%s]' % ', '.join(a['text'].encode('utf-8') 89 | for a in qa['answers']) 90 | orig_color = get_answer_color(predictions[orig_id], qa['answers']) 91 | print 'Predicted: %s' % colored( 92 | predictions[orig_id], orig_color).encode('utf-8') 93 | print 'Adversary succeeded?: %s' % (adv_id != orig_id) 94 | if adv_id != orig_id: 95 | print 'Adversarial Paragraph: %s' % highlight_after( 96 | id_to_paragraph[adv_id], len(paragraph['context'])).encode('utf-8') 97 | # highlight_after is a hack that only works when mutations append stuff. 98 | adv_color = get_answer_color(predictions[adv_id], qa['answers']) 99 | print 'Prediction under Adversary: %s' % colored( 100 | predictions[adv_id], adv_color).encode('utf-8') 101 | print 102 | 103 | 104 | 105 | def evaluate_adversarial(dataset, predictions, verbose=False, id_set=None): 106 | orig_f1_score = 0.0 107 | orig_exact_match_score = 0.0 108 | adv_f1_scores = {} # Map from original ID to F1 score 109 | adv_exact_match_scores = {} # Map from original ID to exact match score 110 | adv_ids = {} 111 | all_ids = set() # Set of all original IDs 112 | f1 = exact_match = 0 113 | for article in dataset: 114 | for paragraph in article['paragraphs']: 115 | for qa in paragraph['qas']: 116 | orig_id = qa['id'].split('-')[0] 117 | if id_set and orig_id not in id_set: continue 118 | all_ids.add(orig_id) 119 | if qa['id'] not in predictions: 120 | message = 'Unanswered question ' + qa['id'] + ' will receive score 0.' 
121 | print >> sys.stderr, message 122 | continue 123 | ground_truths = list(map(lambda x: x['text'], qa['answers'])) 124 | prediction = predictions[qa['id']] 125 | cur_exact_match = metric_max_over_ground_truths(exact_match_score, 126 | prediction, ground_truths) 127 | cur_f1 = metric_max_over_ground_truths(f1_score, prediction, ground_truths) 128 | if orig_id == qa['id']: 129 | # This is an original example 130 | orig_f1_score += cur_f1 131 | orig_exact_match_score += cur_exact_match 132 | if orig_id not in adv_f1_scores: 133 | # Haven't seen adversarial example yet, so use original for adversary 134 | adv_ids[orig_id] = orig_id 135 | adv_f1_scores[orig_id] = cur_f1 136 | adv_exact_match_scores[orig_id] = cur_exact_match 137 | else: 138 | # This is an adversarial example 139 | if (orig_id not in adv_f1_scores or adv_ids[orig_id] == orig_id 140 | or adv_f1_scores[orig_id] > cur_f1): 141 | # Always override if currently adversary currently using orig_id 142 | adv_ids[orig_id] = qa['id'] 143 | adv_f1_scores[orig_id] = cur_f1 144 | adv_exact_match_scores[orig_id] = cur_exact_match 145 | if verbose: 146 | print_details(dataset, predictions, adv_ids) 147 | orig_f1 = 100.0 * orig_f1_score / len(all_ids) 148 | orig_exact_match = 100.0 * orig_exact_match_score / len(all_ids) 149 | adv_exact_match = 100.0 * sum(adv_exact_match_scores.values()) / len(all_ids) 150 | adv_f1 = 100.0 * sum(adv_f1_scores.values()) / len(all_ids) 151 | return OrderedDict([ 152 | ('orig_exact_match', orig_exact_match), 153 | ('orig_f1', orig_f1), 154 | ('adv_exact_match', adv_exact_match), 155 | ('adv_f1', adv_f1), 156 | ]) 157 | 158 | def split_by_attempted(dataset): 159 | all_ids = set() 160 | attempted_ids = set() 161 | for article in dataset: 162 | for paragraph in article['paragraphs']: 163 | for qa in paragraph['qas']: 164 | orig_id = qa['id'].split('-')[0] 165 | all_ids.add(orig_id) 166 | if orig_id != qa['id']: 167 | attempted_ids.add(orig_id) 168 | not_attempted_ids = all_ids - 
attempted_ids 169 | return attempted_ids, not_attempted_ids 170 | 171 | def evaluate_by_attempted(dataset, predictions): 172 | attempted, not_attempted = split_by_attempted(dataset) 173 | total_num = len(attempted) + len(not_attempted) 174 | results_attempted = evaluate_adversarial(dataset, predictions, 175 | id_set=attempted) 176 | print 'Attempted %d/%d = %.2f%%' % ( 177 | len(attempted), total_num, 100.0 * len(attempted) / total_num) 178 | print json.dumps(results_attempted) 179 | results_not_attempted = evaluate_adversarial(dataset, predictions, 180 | id_set=not_attempted) 181 | print 'Did not attempt %d/%d = %.2f%%' % ( 182 | len(not_attempted), total_num, 100.0 * len(not_attempted) / total_num) 183 | print json.dumps(results_not_attempted) 184 | 185 | 186 | if __name__ == '__main__': 187 | expected_version = '1.1' 188 | parser = argparse.ArgumentParser( 189 | description='Adverarial evaluation for SQuAD ' + expected_version) 190 | parser.add_argument('dataset_file', help='Dataset file') 191 | parser.add_argument('prediction_file', help='Prediction File') 192 | parser.add_argument('--out-file', '-o', default=None, 193 | help='Write JSON output to this file (default is stdout).') 194 | parser.add_argument('--verbose', '-v', default=False, action='store_true', 195 | help='Enable verbose logging.') 196 | parser.add_argument('--split-by-attempted', default=False, action='store_true', 197 | help='Split by whether adversary attempted the example.') 198 | args = parser.parse_args() 199 | with open(args.dataset_file) as dataset_file: 200 | dataset_json = json.load(dataset_file) 201 | # if (dataset_json['version'] != expected_version): 202 | # print >> sys.stderr, ( 203 | # 'Evaluation expects v-' + expected_version + 204 | # ', but got dataset with v-' + dataset_json['version']) 205 | dataset = dataset_json['data'] 206 | with open(args.prediction_file) as prediction_file: 207 | predictions = json.load(prediction_file) 208 | if args.verbose: 209 | from termcolor import 
colored 210 | results = evaluate_adversarial(dataset, predictions, verbose=args.verbose) 211 | if args.out_file: 212 | with open(args.out_file, 'wb') as f: 213 | json.dump(results, f) 214 | else: 215 | print json.dumps(results) 216 | if args.split_by_attempted: 217 | evaluate_by_attempted(dataset, predictions) -------------------------------------------------------------------------------- /SQuAD/jutils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | 7 | 8 | def sample_correlated_gaussian(rho=0.5, dim=20, batch_size=128, cubic=None): 9 | """Generate samples from a correlated Gaussian distribution.""" 10 | x, eps = torch.chunk(torch.randn(batch_size, 2 * dim), 2, dim=1) 11 | y = rho * x + torch.sqrt(torch.tensor(1. - rho**2).float()) * eps 12 | 13 | if cubic is not None: 14 | y = y ** 3 15 | 16 | return x, y 17 | 18 | 19 | def rho_to_mi(dim, rho): 20 | return -0.5 * np.log(1-rho**2) * dim 21 | 22 | 23 | def mi_to_rho(dim, mi): 24 | return np.sqrt(1-np.exp(-2.0 / dim * mi)) 25 | 26 | 27 | def mi_schedule(n_iter): 28 | """Generate schedule for increasing correlation over time.""" 29 | mis = np.round(np.linspace(0.5, 5.5-1e-9, n_iter)) * 2.0 30 | return mis.astype(np.float32) 31 | 32 | 33 | def logmeanexp_diag(x): 34 | batch_size = x.size(0) 35 | 36 | logsumexp = torch.logsumexp(x.diag(), dim=(0,)) 37 | num_elem = batch_size 38 | 39 | return logsumexp - torch.log(torch.tensor(num_elem).float()).cuda() 40 | 41 | 42 | def logmeanexp_nodiag(x, dim=None, device='cuda'): 43 | batch_size = x.size(0) 44 | if dim is None: 45 | dim = (0, 1) 46 | 47 | logsumexp = torch.logsumexp( 48 | x - torch.diag(np.inf * torch.ones(batch_size).to(device)), dim=dim) 49 | 50 | try: 51 | if len(dim) == 1: 52 | num_elem = batch_size - 1. 53 | else: 54 | num_elem = batch_size * (batch_size - 1.) 
55 | except: 56 | num_elem = batch_size - 1 57 | return logsumexp - torch.log(torch.tensor(num_elem)).to(device) 58 | 59 | 60 | def tuba_lower_bound(scores, log_baseline=None): 61 | if log_baseline is not None: 62 | scores -= log_baseline[:, None] 63 | batch_size = scores.size(0) 64 | 65 | # First term is an expectation over samples from the joint, 66 | # which are the diagonal elmements of the scores matrix. 67 | joint_term = scores.diag().mean() 68 | 69 | # Second term is an expectation over samples from the marginal, 70 | # which are the off-diagonal elements of the scores matrix. 71 | marg_term = logmeanexp_nodiag(scores).exp() 72 | return 1. + joint_term - marg_term 73 | 74 | 75 | def nwj_lower_bound(scores): 76 | return tuba_lower_bound(scores - 1.) 77 | 78 | 79 | def infonce_lower_bound(scores): 80 | nll = scores.diag().mean() - scores.logsumexp(dim=1) 81 | # Alternative implementation: 82 | # nll = -tf.nn.sparse_softmax_cross_entropy_with_logits(logits=scores, labels=tf.range(batch_size)) 83 | mi = torch.tensor(scores.size(0)).float().log() + nll 84 | mi = mi.mean() 85 | return mi 86 | 87 | 88 | def js_fgan_lower_bound(f): 89 | """Lower bound on Jensen-Shannon divergence from Nowozin et al. 
(2016).""" 90 | f_diag = f.diag() 91 | first_term = -F.softplus(-f_diag).mean() 92 | n = f.size(0) 93 | second_term = (torch.sum(F.softplus(f)) - 94 | torch.sum(F.softplus(f_diag))) / (n * (n - 1.)) 95 | return first_term - second_term 96 | 97 | 98 | def js_lower_bound(f): 99 | nwj = nwj_lower_bound(f) 100 | js = js_fgan_lower_bound(f) 101 | 102 | with torch.no_grad(): 103 | nwj_js = nwj - js 104 | 105 | return js + nwj_js 106 | 107 | 108 | def dv_upper_lower_bound(f): 109 | """DV lower bound, but upper bounded by using log outside.""" 110 | first_term = f.diag().mean() 111 | second_term = logmeanexp_nodiag(f) 112 | 113 | return first_term - second_term 114 | 115 | 116 | def mine_lower_bound(f, buffer=None, momentum=0.9): 117 | if buffer is None: 118 | buffer = torch.tensor(1.0).cuda() 119 | first_term = f.diag().mean() 120 | 121 | buffer_update = logmeanexp_nodiag(f).exp() 122 | with torch.no_grad(): 123 | second_term = logmeanexp_nodiag(f) 124 | buffer_new = buffer * momentum + buffer_update * (1 - momentum) 125 | buffer_new = torch.clamp(buffer_new, min=1e-4) 126 | third_term_no_grad = buffer_update / buffer_new 127 | 128 | third_term_grad = buffer_update / buffer_new 129 | 130 | return first_term - second_term - third_term_grad + third_term_no_grad, buffer_update 131 | 132 | 133 | def regularized_dv_bound(f, l=0.0): 134 | first_term = f.diag().mean() 135 | second_term = logmeanexp_nodiag(f) 136 | 137 | reg_term = l * (second_term.exp() - 1) ** 2 138 | 139 | with torch.no_grad(): 140 | reg_term_no_grad = reg_term 141 | 142 | return first_term - second_term + reg_term - reg_term_no_grad 143 | 144 | 145 | def renorm_q(f, alpha=1.0, clip=None): 146 | if clip is not None: 147 | f = torch.clamp(f * alpha, -clip, clip) 148 | z = logmeanexp_nodiag(f * alpha, dim=(0, 1)) 149 | return z 150 | 151 | 152 | def disc_renorm_q(f): 153 | batch_size = f.size(0) 154 | z = torch.zeros(1, requires_grad=True, device='cuda') 155 | 156 | opt = optim.SGD([z], lr=0.001) 157 | for i in 
range(10): 158 | opt.zero_grad() 159 | 160 | first_term = -F.softplus(z - f).diag().mean() 161 | st = -F.softplus(f - z) 162 | second_term = (st - st.diag().diag()).sum() / \ 163 | (batch_size * (batch_size - 1.)) 164 | total = first_term + second_term 165 | 166 | total.backward(retain_graph=True) 167 | opt.step() 168 | 169 | if total.item() <= -2 * np.log(2): 170 | break 171 | 172 | return z 173 | 174 | 175 | def renorm_p(f, alpha=1.0): 176 | z = logmeanexp_diag(-f * alpha) 177 | return z 178 | 179 | 180 | def smile_lower_bound(f, alpha=1.0, clip=None): 181 | z = renorm_q(f, alpha, clip) 182 | dv = f.diag().mean() - z 183 | 184 | js = js_fgan_lower_bound(f) 185 | 186 | with torch.no_grad(): 187 | dv_js = dv - js 188 | 189 | return js + dv_js 190 | 191 | 192 | def js_dv_disc_renorm_lower_bound(f): 193 | z = disc_renorm_q(f) 194 | dv = f.diag().mean() - z.mean() 195 | 196 | js = js_fgan_lower_bound(f) 197 | 198 | with torch.no_grad(): 199 | dv_js = dv - js 200 | 201 | return js + dv_js 202 | 203 | 204 | def vae_lower_bound(f): 205 | f1, f2 = f 206 | n = f1.size(0) 207 | logp = f1.mean() 208 | logq = (f2.sum() - f2.diag().sum()) / (n * (n-1.)) 209 | 210 | with torch.no_grad(): 211 | logq_nograd = logq * 1.0 212 | logqd = f2.diag().mean() 213 | logp_nograd = logp 214 | 215 | return logp - logqd + logq - logq_nograd # logp - logqd + logq - logq_nograd 216 | 217 | 218 | def js_nwj_renorm_lower_bound(f, alpha=1.0): 219 | z = renorm_q(f - 1.0, alpha) 220 | 221 | nwj = nwj_lower_bound(f - z) 222 | js = js_fgan_lower_bound(f) 223 | 224 | with torch.no_grad(): 225 | nwj_js = nwj - js 226 | 227 | return js + nwj_js 228 | 229 | 230 | def estimate_p_norm(f, alpha=1.0): 231 | z = renorm_q(f, alpha) 232 | # f = renorm_p(f, alpha) 233 | # f = renorm_q(f, alpha) 234 | f = f - z 235 | f = -f 236 | 237 | return f.diag().exp().mean() 238 | 239 | 240 | def estimate_mutual_information(estimator, x, y, critic_fn, 241 | baseline_fn=None, alpha_logit=None, **kwargs): 242 | """Estimate 
variational lower bounds on mutual information. 243 | 244 | Args: 245 | estimator: string specifying estimator, one of: 246 | 'nwj', 'infonce', 'tuba', 'js', 'interpolated' 247 | x: [batch_size, dim_x] Tensor 248 | y: [batch_size, dim_y] Tensor 249 | critic_fn: callable that takes x and y as input and outputs critic scores 250 | output shape is a [batch_size, batch_size] matrix 251 | baseline_fn (optional): callable that takes y as input 252 | outputs a [batch_size] or [batch_size, 1] vector 253 | alpha_logit (optional): logit(alpha) for interpolated bound 254 | 255 | Returns: 256 | scalar estimate of mutual information 257 | """ 258 | x, y = x.cuda(), y.cuda() 259 | scores = critic_fn(x, y) 260 | if baseline_fn is not None: 261 | # Some baselines' output is (batch_size, 1) which we remove here. 262 | log_baseline = torch.squeeze(baseline_fn(y)) 263 | if estimator == 'infonce': 264 | mi = infonce_lower_bound(scores) 265 | elif estimator == 'nwj': 266 | mi = nwj_lower_bound(scores) 267 | elif estimator == 'tuba': 268 | mi = tuba_lower_bound(scores, log_baseline) 269 | elif estimator == 'js': 270 | mi = js_lower_bound(scores) 271 | elif estimator == 'smile': 272 | mi = smile_lower_bound(scores, **kwargs) 273 | elif estimator == 'dv_disc_normalized': 274 | mi = js_dv_disc_renorm_lower_bound(scores, **kwargs) 275 | elif estimator == 'nwj_normalized': 276 | mi = js_nwj_renorm_lower_bound(scores, **kwargs) 277 | elif estimator == 'dv': 278 | mi = dv_upper_lower_bound(scores) 279 | # p_norm = estimate_p_norm(scores * kwargs.get('alpha', 1.0)) 280 | if estimator is not 'smile': 281 | p_norm = renorm_q(scores) 282 | else: 283 | p_norm = renorm_q(scores, alpha=kwargs.get( 284 | 'alpha', 1.0), clip=kwargs.get('clip', None)) 285 | return mi, p_norm 286 | 287 | 288 | def mlp(dim, hidden_dim, output_dim, layers, activation): 289 | activation = { 290 | 'relu': nn.ReLU 291 | }[activation] 292 | 293 | seq = [nn.Linear(dim, hidden_dim), activation()] 294 | for _ in range(layers): 
295 | seq += [nn.Linear(hidden_dim, hidden_dim), activation()] 296 | seq += [nn.Linear(hidden_dim, output_dim)] 297 | 298 | return nn.Sequential(*seq) 299 | 300 | 301 | class SeparableCritic(nn.Module): 302 | def __init__(self, dim, hidden_dim, embed_dim, layers, activation, **extra_kwargs): 303 | super(SeparableCritic, self).__init__() 304 | self._g = mlp(dim, hidden_dim, embed_dim, layers, activation) 305 | self._h = mlp(dim, hidden_dim, embed_dim, layers, activation) 306 | 307 | def forward(self, x, y): 308 | scores = torch.matmul(self._h(y), self._g(x).t()) 309 | return scores 310 | 311 | 312 | class ConcatCritic(nn.Module): 313 | def __init__(self, dim, hidden_dim, layers, activation, **extra_kwargs): 314 | super(ConcatCritic, self).__init__() 315 | # output is scalar score 316 | self._f = mlp(dim * 2, hidden_dim, 1, layers, activation) 317 | 318 | def forward(self, x, y): 319 | batch_size = x.size(0) 320 | # Tile all possible combinations of x and y 321 | x_tiled = torch.stack([x] * batch_size, dim=0) 322 | y_tiled = torch.stack([y] * batch_size, dim=1) 323 | # xy is [batch_size * batch_size, x_dim + y_dim] 324 | xy_pairs = torch.reshape(torch.cat((x_tiled, y_tiled), dim=2), [ 325 | batch_size * batch_size, -1]) 326 | # Compute scores for each x_i, y_j pair. 
327 | scores = self._f(xy_pairs) 328 | return torch.reshape(scores, [batch_size, batch_size]).t() 329 | 330 | 331 | def log_prob_gaussian(x): 332 | return torch.sum(torch.distributions.Normal(0., 1.).log_prob(x), -1) 333 | 334 | 335 | dim = 20 336 | 337 | 338 | CRITICS = { 339 | 'separable': SeparableCritic, 340 | 'concat': ConcatCritic, 341 | } 342 | 343 | BASELINES = { 344 | 'constant': lambda: None, 345 | 'unnormalized': lambda: mlp(dim=dim, hidden_dim=512, output_dim=1, layers=2, activation='relu').cuda(), 346 | 'gaussian': lambda: log_prob_gaussian, 347 | } 348 | 349 | 350 | def train_estimator(critic_params, data_params, mi_params, opt_params, **kwargs): 351 | """Main training loop that estimates time-varying MI.""" 352 | # Ground truth rho is only used by conditional critic 353 | critic = CRITICS[mi_params.get('critic', 'separable')]( 354 | rho=None, **critic_params).cuda() 355 | baseline = BASELINES[mi_params.get('baseline', 'constant')]() 356 | 357 | opt_crit = optim.Adam(critic.parameters(), lr=opt_params['learning_rate']) 358 | if isinstance(baseline, nn.Module): 359 | opt_base = optim.Adam(baseline.parameters(), 360 | lr=opt_params['learning_rate']) 361 | else: 362 | opt_base = None 363 | 364 | def train_step(rho, data_params, mi_params): 365 | # Annoying special case: 366 | # For the true conditional, the critic depends on the true correlation rho, 367 | # so we rebuild the critic at each iteration. 
368 | opt_crit.zero_grad() 369 | if isinstance(baseline, nn.Module): 370 | opt_base.zero_grad() 371 | 372 | if mi_params['critic'] == 'conditional': 373 | critic_ = CRITICS['conditional'](rho=rho).cuda() 374 | else: 375 | critic_ = critic 376 | 377 | x, y = sample_correlated_gaussian( 378 | dim=data_params['dim'], rho=rho, batch_size=data_params['batch_size'], cubic=data_params['cubic']) 379 | mi, p_norm = estimate_mutual_information( 380 | mi_params['estimator'], x, y, critic_, baseline, mi_params.get('alpha_logit', None), **kwargs) 381 | loss = -mi 382 | 383 | loss.backward() 384 | opt_crit.step() 385 | if isinstance(baseline, nn.Module): 386 | opt_base.step() 387 | 388 | return mi, p_norm 389 | 390 | # Schedule of correlation over iterations 391 | mis = mi_schedule(opt_params['iterations']) 392 | rhos = mi_to_rho(data_params['dim'], mis) 393 | 394 | estimates = [] 395 | p_norms = [] 396 | for i in range(opt_params['iterations']): 397 | mi, p_norm = train_step( 398 | rhos[i], data_params, mi_params) 399 | mi = mi.detach().cpu().numpy() 400 | p_norm = p_norm.detach().cpu().numpy() 401 | estimates.append(mi) 402 | p_norms.append(p_norm) 403 | 404 | return np.array(estimates), np.array(p_norms) 405 | -------------------------------------------------------------------------------- /SQuAD/processors/squad.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from functools import partial 5 | from multiprocessing import Pool, cpu_count 6 | 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | from transformers.file_utils import is_tf_available, is_torch_available 11 | from transformers.tokenization_bert import whitespace_tokenize 12 | from transformers.data.processors.utils import DataProcessor 13 | 14 | 15 | if is_torch_available(): 16 | import torch 17 | from torch.utils.data import TensorDataset 18 | 19 | if is_tf_available(): 20 | import tensorflow as tf 21 | 22 | logger = 
logging.getLogger(__name__) 23 | 24 | 25 | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text): 26 | """Returns tokenized answer spans that better match the annotated answer.""" 27 | tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text)) 28 | 29 | for new_start in range(input_start, input_end + 1): 30 | for new_end in range(input_end, new_start - 1, -1): 31 | text_span = " ".join(doc_tokens[new_start : (new_end + 1)]) 32 | if text_span == tok_answer_text: 33 | return (new_start, new_end) 34 | 35 | return (input_start, input_end) 36 | 37 | 38 | def _check_is_max_context(doc_spans, cur_span_index, position): 39 | """Check if this is the 'max context' doc span for the token.""" 40 | best_score = None 41 | best_span_index = None 42 | for (span_index, doc_span) in enumerate(doc_spans): 43 | end = doc_span.start + doc_span.length - 1 44 | if position < doc_span.start: 45 | continue 46 | if position > end: 47 | continue 48 | num_left_context = position - doc_span.start 49 | num_right_context = end - position 50 | score = min(num_left_context, num_right_context) + 0.01 * doc_span.length 51 | if best_score is None or score > best_score: 52 | best_score = score 53 | best_span_index = span_index 54 | 55 | return cur_span_index == best_span_index 56 | 57 | 58 | def _new_check_is_max_context(doc_spans, cur_span_index, position): 59 | """Check if this is the 'max context' doc span for the token.""" 60 | # if len(doc_spans) == 1: 61 | # return True 62 | best_score = None 63 | best_span_index = None 64 | for (span_index, doc_span) in enumerate(doc_spans): 65 | end = doc_span["start"] + doc_span["length"] - 1 66 | if position < doc_span["start"]: 67 | continue 68 | if position > end: 69 | continue 70 | num_left_context = position - doc_span["start"] 71 | num_right_context = end - position 72 | score = min(num_left_context, num_right_context) + 0.01 * doc_span["length"] 73 | if best_score is None or score > best_score: 74 | best_score 
= score 75 | best_span_index = span_index 76 | 77 | return cur_span_index == best_span_index 78 | 79 | 80 | def _is_whitespace(c): 81 | if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: 82 | return True 83 | return False 84 | 85 | 86 | def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_query_length, is_training): 87 | features = [] 88 | if is_training and not example.is_impossible: 89 | # Get start and end position 90 | start_position = example.start_position 91 | end_position = example.end_position 92 | 93 | # If the answer cannot be found in the text, then skip this example. 94 | actual_text = " ".join(example.doc_tokens[start_position : (end_position + 1)]) 95 | cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text)) 96 | if actual_text.find(cleaned_answer_text) == -1: 97 | logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text) 98 | return [] 99 | 100 | tok_to_orig_index = [] 101 | orig_to_tok_index = [] 102 | all_doc_tokens = [] 103 | for (i, token) in enumerate(example.doc_tokens): 104 | orig_to_tok_index.append(len(all_doc_tokens)) 105 | sub_tokens = tokenizer.tokenize(token) 106 | for sub_token in sub_tokens: 107 | tok_to_orig_index.append(i) 108 | all_doc_tokens.append(sub_token) 109 | 110 | if is_training and not example.is_impossible: 111 | tok_start_position = orig_to_tok_index[example.start_position] 112 | if example.end_position < len(example.doc_tokens) - 1: 113 | tok_end_position = orig_to_tok_index[example.end_position + 1] - 1 114 | else: 115 | tok_end_position = len(all_doc_tokens) - 1 116 | 117 | (tok_start_position, tok_end_position) = _improve_answer_span( 118 | all_doc_tokens, tok_start_position, tok_end_position, tokenizer, example.answer_text 119 | ) 120 | 121 | spans = [] 122 | 123 | truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length) 124 | sequence_added_tokens = ( 125 | 
tokenizer.max_len - tokenizer.max_len_single_sentence + 1 126 | if "roberta" in str(type(tokenizer)) or "camembert" in str(type(tokenizer)) 127 | else tokenizer.max_len - tokenizer.max_len_single_sentence 128 | ) 129 | sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair 130 | 131 | span_doc_tokens = all_doc_tokens 132 | while len(spans) * doc_stride < len(all_doc_tokens): 133 | 134 | encoded_dict = tokenizer.encode_plus( 135 | truncated_query if tokenizer.padding_side == "right" else span_doc_tokens, 136 | span_doc_tokens if tokenizer.padding_side == "right" else truncated_query, 137 | max_length=max_seq_length, 138 | return_overflowing_tokens=True, 139 | pad_to_max_length=True, 140 | stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens, 141 | truncation_strategy="only_second" if tokenizer.padding_side == "right" else "only_first", 142 | return_token_type_ids=True, 143 | ) 144 | 145 | paragraph_len = min( 146 | len(all_doc_tokens) - len(spans) * doc_stride, 147 | max_seq_length - len(truncated_query) - sequence_pair_added_tokens, 148 | ) 149 | 150 | if tokenizer.pad_token_id in encoded_dict["input_ids"]: 151 | if tokenizer.padding_side == "right": 152 | non_padded_ids = encoded_dict["input_ids"][: encoded_dict["input_ids"].index(tokenizer.pad_token_id)] 153 | else: 154 | last_padding_id_position = ( 155 | len(encoded_dict["input_ids"]) - 1 - encoded_dict["input_ids"][::-1].index(tokenizer.pad_token_id) 156 | ) 157 | non_padded_ids = encoded_dict["input_ids"][last_padding_id_position + 1 :] 158 | 159 | else: 160 | non_padded_ids = encoded_dict["input_ids"] 161 | 162 | tokens = tokenizer.convert_ids_to_tokens(non_padded_ids) 163 | 164 | token_to_orig_map = {} 165 | for i in range(paragraph_len): 166 | index = len(truncated_query) + sequence_added_tokens + i if tokenizer.padding_side == "right" else i 167 | token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i] 168 | 169 | 
encoded_dict["paragraph_len"] = paragraph_len 170 | encoded_dict["tokens"] = tokens 171 | encoded_dict["token_to_orig_map"] = token_to_orig_map 172 | encoded_dict["truncated_query_with_special_tokens_length"] = len(truncated_query) + sequence_added_tokens 173 | encoded_dict["token_is_max_context"] = {} 174 | encoded_dict["start"] = len(spans) * doc_stride 175 | encoded_dict["length"] = paragraph_len 176 | 177 | spans.append(encoded_dict) 178 | 179 | if "overflowing_tokens" not in encoded_dict: 180 | break 181 | span_doc_tokens = encoded_dict["overflowing_tokens"] 182 | 183 | for doc_span_index in range(len(spans)): 184 | for j in range(spans[doc_span_index]["paragraph_len"]): 185 | is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j) 186 | index = ( 187 | j 188 | if tokenizer.padding_side == "left" 189 | else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j 190 | ) 191 | spans[doc_span_index]["token_is_max_context"][index] = is_max_context 192 | 193 | for span in spans: 194 | # Identify the position of the CLS token 195 | cls_index = span["input_ids"].index(tokenizer.cls_token_id) 196 | 197 | # p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer) 198 | # Original TF implem also keep the classification token (set to 0) 199 | p_mask = np.ones_like(span["token_type_ids"]) 200 | if tokenizer.padding_side == "right": 201 | p_mask[len(truncated_query) + sequence_added_tokens :] = 0 202 | else: 203 | p_mask[-len(span["tokens"]) : -(len(truncated_query) + sequence_added_tokens)] = 0 204 | 205 | pad_token_indices = np.where(span["input_ids"] == tokenizer.pad_token_id) 206 | special_token_indices = np.asarray( 207 | tokenizer.get_special_tokens_mask(span["input_ids"], already_has_special_tokens=True) 208 | ).nonzero() 209 | 210 | p_mask[pad_token_indices] = 1 211 | p_mask[special_token_indices] = 1 212 | 213 | # Set the cls index to 0: the CLS index can be used 
for impossible answers 214 | p_mask[cls_index] = 0 215 | 216 | span_is_impossible = example.is_impossible 217 | start_position = 0 218 | end_position = 0 219 | if is_training and not span_is_impossible: 220 | # For training, if our document chunk does not contain an annotation 221 | # we throw it out, since there is nothing to predict. 222 | doc_start = span["start"] 223 | doc_end = span["start"] + span["length"] - 1 224 | out_of_span = False 225 | 226 | if not (tok_start_position >= doc_start and tok_end_position <= doc_end): 227 | out_of_span = True 228 | 229 | if out_of_span: 230 | start_position = cls_index 231 | end_position = cls_index 232 | span_is_impossible = True 233 | else: 234 | if tokenizer.padding_side == "left": 235 | doc_offset = 0 236 | else: 237 | doc_offset = len(truncated_query) + sequence_added_tokens 238 | 239 | start_position = tok_start_position - doc_start + doc_offset 240 | end_position = tok_end_position - doc_start + doc_offset 241 | 242 | features.append( 243 | SquadFeatures( 244 | span["input_ids"], 245 | span["attention_mask"], 246 | span["token_type_ids"], 247 | cls_index, 248 | p_mask.tolist(), 249 | example_index=0, # Can not set unique_id and example_index here. They will be set after multiple processing. 
250 | unique_id=0, 251 | paragraph_len=span["paragraph_len"], 252 | token_is_max_context=span["token_is_max_context"], 253 | tokens=span["tokens"], 254 | token_to_orig_map=span["token_to_orig_map"], 255 | start_position=start_position, 256 | end_position=end_position, 257 | is_impossible=span_is_impossible, 258 | qas_id=example.qas_id, 259 | ) 260 | ) 261 | return features 262 | 263 | 264 | def squad_convert_example_to_features_init(tokenizer_for_convert): 265 | global tokenizer 266 | tokenizer = tokenizer_for_convert 267 | 268 | 269 | def squad_convert_examples_to_features( 270 | examples, 271 | tokenizer, 272 | max_seq_length, 273 | doc_stride, 274 | max_query_length, 275 | is_training, 276 | return_dataset=False, 277 | threads=1, 278 | tqdm_enabled=True, 279 | ): 280 | """ 281 | Converts a list of examples into a list of features that can be directly given as input to a model. 282 | It is model-dependant and takes advantage of many of the tokenizer's features to create the model's inputs. 283 | 284 | Args: 285 | examples: list of :class:`~transformers.data.processors.squad.SquadExample` 286 | tokenizer: an instance of a child of :class:`~transformers.PreTrainedTokenizer` 287 | max_seq_length: The maximum sequence length of the inputs. 288 | doc_stride: The stride used when the context is too large and is split across several features. 289 | max_query_length: The maximum length of the query. 290 | is_training: whether to create features for model evaluation or model training. 291 | return_dataset: Default False. Either 'pt' or 'tf'. 
292 | if 'pt': returns a torch.data.TensorDataset, 293 | if 'tf': returns a tf.data.Dataset 294 | threads: multiple processing threadsa-smi 295 | 296 | 297 | Returns: 298 | list of :class:`~transformers.data.processors.squad.SquadFeatures` 299 | 300 | Example:: 301 | 302 | processor = SquadV2Processor() 303 | examples = processor.get_dev_examples(data_dir) 304 | 305 | features = squad_convert_examples_to_features( 306 | examples=examples, 307 | tokenizer=tokenizer, 308 | max_seq_length=args.max_seq_length, 309 | doc_stride=args.doc_stride, 310 | max_query_length=args.max_query_length, 311 | is_training=not evaluate, 312 | ) 313 | """ 314 | 315 | # Defining helper methods 316 | features = [] 317 | threads = min(threads, cpu_count()) 318 | with Pool(threads, initializer=squad_convert_example_to_features_init, initargs=(tokenizer,)) as p: 319 | annotate_ = partial( 320 | squad_convert_example_to_features, 321 | max_seq_length=max_seq_length, 322 | doc_stride=doc_stride, 323 | max_query_length=max_query_length, 324 | is_training=is_training, 325 | ) 326 | features = list( 327 | tqdm( 328 | p.imap(annotate_, examples, chunksize=32), 329 | total=len(examples), 330 | desc="convert squad examples to features", 331 | disable=not tqdm_enabled, 332 | ) 333 | ) 334 | new_features = [] 335 | unique_id = 1000000000 336 | example_index = 0 337 | for example_features in tqdm( 338 | features, total=len(features), desc="add example index and unique id", disable=not tqdm_enabled 339 | ): 340 | if not example_features: 341 | continue 342 | for example_feature in example_features: 343 | example_feature.example_index = example_index 344 | example_feature.unique_id = unique_id 345 | new_features.append(example_feature) 346 | unique_id += 1 347 | example_index += 1 348 | features = new_features 349 | del new_features 350 | if return_dataset == "pt": 351 | if not is_torch_available(): 352 | raise RuntimeError("PyTorch must be installed to return a PyTorch dataset.") 353 | 354 | # Convert 
to Tensors and build dataset 355 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 356 | all_attention_masks = torch.tensor([f.attention_mask for f in features], dtype=torch.long) 357 | all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long) 358 | all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long) 359 | all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float) 360 | all_is_impossible = torch.tensor([f.is_impossible for f in features], dtype=torch.float) 361 | 362 | if not is_training: 363 | all_feature_index = torch.arange(all_input_ids.size(0), dtype=torch.long) 364 | dataset = TensorDataset( 365 | all_input_ids, all_attention_masks, all_token_type_ids, all_feature_index, all_cls_index, all_p_mask 366 | ) 367 | else: 368 | all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long) 369 | all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long) 370 | dataset = TensorDataset( 371 | all_input_ids, 372 | all_attention_masks, 373 | all_token_type_ids, 374 | all_start_positions, 375 | all_end_positions, 376 | all_cls_index, 377 | all_p_mask, 378 | all_is_impossible, 379 | ) 380 | 381 | return features, dataset 382 | elif return_dataset == "tf": 383 | if not is_tf_available(): 384 | raise RuntimeError("TensorFlow must be installed to return a TensorFlow dataset.") 385 | 386 | def gen(): 387 | for i, ex in enumerate(features): 388 | yield ( 389 | { 390 | "input_ids": ex.input_ids, 391 | "attention_mask": ex.attention_mask, 392 | "token_type_ids": ex.token_type_ids, 393 | "feature_index": i, 394 | "qas_id": ex.qas_id, 395 | }, 396 | { 397 | "start_position": ex.start_position, 398 | "end_position": ex.end_position, 399 | "cls_index": ex.cls_index, 400 | "p_mask": ex.p_mask, 401 | "is_impossible": ex.is_impossible, 402 | }, 403 | ) 404 | 405 | # Why have we split the batch into a tuple? 
PyTorch just has a list of tensors. 406 | train_types = ( 407 | { 408 | "input_ids": tf.int32, 409 | "attention_mask": tf.int32, 410 | "token_type_ids": tf.int32, 411 | "feature_index": tf.int64, 412 | "qas_id": tf.string, 413 | }, 414 | { 415 | "start_position": tf.int64, 416 | "end_position": tf.int64, 417 | "cls_index": tf.int64, 418 | "p_mask": tf.int32, 419 | "is_impossible": tf.int32, 420 | }, 421 | ) 422 | 423 | train_shapes = ( 424 | { 425 | "input_ids": tf.TensorShape([None]), 426 | "attention_mask": tf.TensorShape([None]), 427 | "token_type_ids": tf.TensorShape([None]), 428 | "feature_index": tf.TensorShape([]), 429 | "qas_id": tf.TensorShape([]), 430 | }, 431 | { 432 | "start_position": tf.TensorShape([]), 433 | "end_position": tf.TensorShape([]), 434 | "cls_index": tf.TensorShape([]), 435 | "p_mask": tf.TensorShape([None]), 436 | "is_impossible": tf.TensorShape([]), 437 | }, 438 | ) 439 | 440 | return tf.data.Dataset.from_generator(gen, train_types, train_shapes) 441 | else: 442 | return features 443 | 444 | 445 | class SquadProcessor(DataProcessor): 446 | """ 447 | Processor for the SQuAD data set. 448 | Overridden by SquadV1Processor and SquadV2Processor, which set ``train_file`` / ``dev_file`` for version 1.1 and version 2.0 of SQuAD, respectively; the file loaders raise ValueError if these are left unset. Answerable training examples keep only the first gold answer (``answer_text`` / ``start_position_character``), while evaluation examples keep every gold answer in ``answers`` (an empty string when a SQuAD JSON question has no ``answers`` key) and leave the training fields as None.
449 | """ 450 | 451 | train_file = None 452 | dev_file = None 453 | 454 | def _get_example_from_tensor_dict(self, tensor_dict, evaluate=False): 455 | if not evaluate: 456 | answer = tensor_dict["answers"]["text"][0].numpy().decode("utf-8") 457 | answer_start = tensor_dict["answers"]["answer_start"][0].numpy() 458 | answers = [] 459 | else: 460 | answers = [ 461 | {"answer_start": start.numpy(), "text": text.numpy().decode("utf-8")} 462 | for start, text in zip(tensor_dict["answers"]["answer_start"], tensor_dict["answers"]["text"]) 463 | ] 464 | 465 | answer = None 466 | answer_start = None 467 | 468 | return SquadExample( 469 | qas_id=tensor_dict["id"].numpy().decode("utf-8"), 470 | question_text=tensor_dict["question"].numpy().decode("utf-8"), 471 | context_text=tensor_dict["context"].numpy().decode("utf-8"), 472 | answer_text=answer, 473 | start_position_character=answer_start, 474 | title=tensor_dict["title"].numpy().decode("utf-8"), 475 | answers=answers, 476 | ) 477 | 478 | def get_examples_from_dataset(self, dataset, evaluate=False): 479 | """ 480 | Creates a list of :class:`~transformers.data.processors.squad.SquadExample` using a TFDS dataset.
481 | 482 | Args: 483 | dataset: The tfds dataset loaded from `tensorflow_datasets.load("squad")` 484 | evaluate: if True, read the `validation` split (every gold answer kept); otherwise read the `train` split (first gold answer only) 485 | 486 | Returns: 487 | List of SquadExample 488 | 489 | Examples:: 490 | 491 | import tensorflow_datasets as tfds 492 | dataset = tfds.load("squad") 493 | processor = SquadV1Processor() 494 | training_examples = processor.get_examples_from_dataset(dataset, evaluate=False) 495 | evaluation_examples = processor.get_examples_from_dataset(dataset, evaluate=True) 496 | """ 497 | 498 | if evaluate: 499 | dataset = dataset["validation"] 500 | else: 501 | dataset = dataset["train"] 502 | 503 | examples = [] 504 | for tensor_dict in tqdm(dataset): 505 | examples.append(self._get_example_from_tensor_dict(tensor_dict, evaluate=evaluate)) 506 | 507 | return examples 508 | 509 | def get_train_examples(self, data_dir, filename=None): 510 | """ 511 | Returns the training examples from the data directory. 512 | 513 | Args: 514 | data_dir: Directory containing the data files used for training and evaluating. ``None`` is treated as the current directory. 515 | filename: None by default, specify this if the training file has a different name than the original one 516 | which is `train-v1.1.json` and `train-v2.0.json` for squad versions 1.1 and 2.0 respectively. 517 | Returns: list of SquadExample. Raises ValueError when ``train_file`` is unset (i.e. the base SquadProcessor is used directly). 518 | """ 519 | if data_dir is None: 520 | data_dir = "" 521 | 522 | if self.train_file is None: 523 | raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor") 524 | 525 | with open( 526 | os.path.join(data_dir, self.train_file if filename is None else filename), "r", encoding="utf-8" 527 | ) as reader: 528 | input_data = json.load(reader)["data"] 529 | return self._create_examples(input_data, "train") 530 | 531 | def get_dev_examples(self, data_dir, filename=None): 532 | """ 533 | Returns the evaluation examples from the data directory. 534 | 535 | Args: 536 | data_dir: Directory containing the data files used for training and evaluating. ``None`` is treated as the current directory.
537 | filename: None by default, specify this if the evaluation file has a different name than the original one 538 | which is `dev-v1.1.json` and `dev-v2.0.json` for squad versions 1.1 and 2.0 respectively. Returns a list of SquadExample; raises ValueError when ``dev_file`` is unset. 539 | """ 540 | if data_dir is None: 541 | data_dir = "" 542 | 543 | if self.dev_file is None: 544 | raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor") 545 | 546 | with open( 547 | os.path.join(data_dir, self.dev_file if filename is None else filename), "r", encoding="utf-8" 548 | ) as reader: 549 | input_data = json.load(reader)["data"] 550 | return self._create_examples(input_data, "dev") 551 | 552 | def _create_examples(self, input_data, set_type): 553 | is_training = set_type == "train" 554 | examples = [] 555 | for entry in tqdm(input_data): 556 | title = entry["title"] 557 | for paragraph in entry["paragraphs"]: 558 | context_text = paragraph["context"] 559 | for qa in paragraph["qas"]: 560 | qas_id = qa["id"] 561 | question_text = qa["question"] 562 | start_position_character = None 563 | answer_text = None 564 | answers = [] 565 | 566 | if "is_impossible" in qa: 567 | is_impossible = qa["is_impossible"] 568 | else: 569 | is_impossible = False 570 | 571 | if not is_impossible: 572 | if is_training: 573 | answer = qa["answers"][0] 574 | answer_text = answer["text"] 575 | start_position_character = answer["answer_start"] 576 | else: 577 | if 'answers' in qa: 578 | answers = qa["answers"] 579 | else: 580 | answers = '' 581 | 582 | example = SquadExample( 583 | qas_id=qas_id, 584 | question_text=question_text, 585 | context_text=context_text, 586 | answer_text=answer_text, 587 | start_position_character=start_position_character, 588 | title=title, 589 | is_impossible=is_impossible, 590 | answers=answers, 591 | ) 592 | 593 | examples.append(example) 594 | return examples 595 | 596 | 597 | class SquadV1Processor(SquadProcessor): 598 | train_file = "train-v1.1.json" 599 | dev_file = "dev-v1.1.json" 600 | 601 |
602 | class AdvSquadV1Processor(SquadV1Processor): 603 | adv_sent_file = "sample1k-HCVerifyAll.json" 604 | adv_one_sent_file = "sample1k-HCVerifySample.json" 605 | # Adversarial SQuAD evaluation files: setup.sh scores sample1k-HCVerifyAll.json as "add sent" and sample1k-HCVerifySample.json as "add one sent". 606 | 607 | class ChecklistV1Processor(AdvSquadV1Processor): 608 | checklist_file = "squad.json" 609 | # regex for processing the prediction file ^.+"[0-9]+": " 610 | 611 | 612 | class SquadV2Processor(SquadProcessor): 613 | train_file = "train-v2.0.json" 614 | dev_file = "dev-v2.0.json" 615 | 616 | 617 | class SquadExample(object): 618 | """ 619 | A single training/test example for the Squad dataset, as loaded from disk. 620 | 621 | Args: 622 | qas_id: The example's unique identifier 623 | question_text: The question string 624 | context_text: The context string 625 | answer_text: The answer string 626 | start_position_character: The character position of the start of the answer (None when no answer span is given) 627 | title: The title of the example 628 | answers: Empty list by default; used during evaluation. Holds answers as well as their start positions. NOTE(review): the default ``[]`` is one shared list object -- it is only stored here, so callers must not mutate ``example.answers`` in place. 629 | is_impossible: False by default, set to True if the example has no possible answer. Derived attributes: doc_tokens (whitespace-split context), char_to_word_offset, and start_position / end_position (indices into doc_tokens of the answer span, 0 when unset). 630 | """ 631 | 632 | def __init__( 633 | self, 634 | qas_id, 635 | question_text, 636 | context_text, 637 | answer_text, 638 | start_position_character, 639 | title, 640 | answers=[], 641 | is_impossible=False, 642 | ): 643 | self.qas_id = qas_id 644 | self.question_text = question_text 645 | self.context_text = context_text 646 | self.answer_text = answer_text 647 | self.title = title 648 | self.is_impossible = is_impossible 649 | self.answers = answers 650 | 651 | self.start_position, self.end_position = 0, 0 652 | 653 | doc_tokens = [] 654 | char_to_word_offset = [] 655 | prev_is_whitespace = True 656 | 657 | # Split on whitespace so that different tokens may be attributed to their original position. char_to_word_offset[i] is the index in doc_tokens of the token containing character i (for whitespace: the preceding token, or -1 before the first token).
658 | for c in self.context_text: 659 | if _is_whitespace(c): 660 | prev_is_whitespace = True 661 | else: 662 | if prev_is_whitespace: 663 | doc_tokens.append(c) 664 | else: 665 | doc_tokens[-1] += c 666 | prev_is_whitespace = False 667 | char_to_word_offset.append(len(doc_tokens) - 1) 668 | 669 | self.doc_tokens = doc_tokens 670 | self.char_to_word_offset = char_to_word_offset 671 | 672 | # Start and end positions (indices into doc_tokens) only have a value when a start character is given and the example is answerable, i.e. for training examples; otherwise both stay 0. The end character is clamped to the last character of the context. 673 | if start_position_character is not None and not is_impossible: 674 | self.start_position = char_to_word_offset[start_position_character] 675 | self.end_position = char_to_word_offset[ 676 | min(start_position_character + len(answer_text) - 1, len(char_to_word_offset) - 1) 677 | ] 678 | 679 | 680 | class SquadFeatures(object): 681 | """ 682 | Single squad example features to be fed to a model. 683 | Those features are model-specific and can be crafted from :class:`~transformers.data.processors.squad.SquadExample` 684 | using the :func:`~transformers.data.processors.squad.squad_convert_examples_to_features` function. 685 | 686 | Args: 687 | input_ids: Indices of input sequence tokens in the vocabulary. 688 | attention_mask: Mask to avoid performing attention on padding token indices. 689 | token_type_ids: Segment token indices to indicate first and second portions of the inputs. 690 | cls_index: the index of the CLS token. 691 | p_mask: Mask identifying tokens that can be answers vs. tokens that cannot. 692 | Mask with 1 for tokens that cannot be in the answer and 0 for tokens that can be in an answer 693 | example_index: the index of the example 694 | unique_id: The unique Feature identifier 695 | paragraph_len: The length of the context 696 | token_is_max_context: List of booleans identifying which tokens have their maximum context in this feature object.
697 | If a token does not have their maximum context in this feature object, it means that another feature object 698 | has more information related to that token and should be prioritized over this feature for that token. 699 | tokens: list of tokens corresponding to the input ids 700 | token_to_orig_map: mapping between the tokens and the original text, needed in order to identify the answer. 701 | start_position: start of the answer token index 702 | end_position: end of the answer token index. is_impossible: stored as given (presumably copied from the source SquadExample -- confirm in the feature converter). qas_id: optional string id of the source question, None by default. 703 | """ 704 | 705 | def __init__( 706 | self, 707 | input_ids, 708 | attention_mask, 709 | token_type_ids, 710 | cls_index, 711 | p_mask, 712 | example_index, 713 | unique_id, 714 | paragraph_len, 715 | token_is_max_context, 716 | tokens, 717 | token_to_orig_map, 718 | start_position, 719 | end_position, 720 | is_impossible, 721 | qas_id: str = None, 722 | ): 723 | self.input_ids = input_ids 724 | self.attention_mask = attention_mask 725 | self.token_type_ids = token_type_ids 726 | self.cls_index = cls_index 727 | self.p_mask = p_mask 728 | 729 | self.example_index = example_index 730 | self.unique_id = unique_id 731 | self.paragraph_len = paragraph_len 732 | self.token_is_max_context = token_is_max_context 733 | self.tokens = tokens 734 | self.token_to_orig_map = token_to_orig_map 735 | 736 | self.start_position = start_position 737 | self.end_position = end_position 738 | self.is_impossible = is_impossible 739 | self.qas_id = qas_id 740 | 741 | 742 | class SquadResult(object): 743 | """ 744 | Constructs a SquadResult which can be used to evaluate a model's output on the SQuAD dataset. 745 | 746 | Args: 747 | unique_id: The unique identifier corresponding to that example.
748 | start_logits: The logits corresponding to the start of the answer 749 | end_logits: The logits corresponding to the end of the answer. start_top_index, end_top_index, cls_logits: optional extra outputs, stored as attributes only when ``start_top_index`` is truthy (otherwise the attributes do not exist). NOTE(review): a falsy ``start_top_index`` such as 0 or an empty list also skips all three; ``is not None`` may have been intended. 750 | """ 751 | 752 | def __init__(self, unique_id, start_logits, end_logits, start_top_index=None, end_top_index=None, cls_logits=None): 753 | self.start_logits = start_logits 754 | self.end_logits = end_logits 755 | self.unique_id = unique_id 756 | 757 | if start_top_index: 758 | self.start_top_index = start_top_index 759 | self.end_top_index = end_top_index 760 | self.cls_logits = cls_logits 761 | -------------------------------------------------------------------------------- /SQuAD/setup.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | function runstandardsquad { # Fine-tune with run_squad_standard.py (only --beta/--version extras), resuming from the newest checkpoint, then score on the two adversarial SQuAD files. Args: custom gpus batch_size beta version model 3 | 4 | export SQUAD_DIR=squad_data 5 | 6 | custom=${1} # Custom name 7 | gpu=${2} # number of GPU 8 | bsize=${3} # Batch size 9 | beta=${4} # regularizer coefficient 10 | version=${5} # mi estimator version 11 | mname=${6} # Model name or path; a name containing "roberta" selects model_type=roberta 12 | #beta=${11} # regularizer coefficient 13 | #version=${12} # mi estimator version 14 | #hdp=${13} # Hidden layer dropouts for ALBERT 15 | #adp=${14} # Attention dropouts for ALBERT 16 | #alr=${15} # Step size of gradient ascent 17 | #amag=${16} # Magnitude of initial (adversarial?)
perturbation 18 | #anorm=${17} # Maximum norm of adversarial perturbation 19 | #asteps=${18} # Number of gradient ascent steps for the adversary 20 | export seqlen=384 21 | export lr=3e-5 # seqlen/lr only feed expname; the launch command below passes 384 and 3e-5 literally 22 | 23 | if [[ ${mname} == *"roberta"* ]]; then 24 | model_type=roberta 25 | else 26 | model_type=bert 27 | fi 28 | 29 | expname=${custom}-${mname}-sl${seqlen}-lr${lr}-bs${bsize}-beta${beta}-version${version} 30 | 31 | max=-1 # Resume point: highest N among ${expname}/checkpoint-N (${fname:11} drops the 11-char "checkpoint-" prefix); stays -1 when none exist. NOTE(review): with no match the glob stays literal, num becomes "*" and the [[ ]] test prints an arithmetic error (max stays -1); shopt -s nullglob would avoid that 32 | for file in ${expname}/checkpoint-* 33 | do 34 | fname=$(basename ${file}) 35 | num=${fname:11} 36 | [[ $num -gt $max ]] && max=$num 37 | done 38 | 39 | if [ $max -eq -1 ] 40 | then echo "Train from stratch" 41 | else mname="${expname}/checkpoint-$max" && echo "Resume Training from checkpoint $mname" 42 | fi 43 | 44 | python -m torch.distributed.launch --nproc_per_node=${gpu} ./run_squad_standard.py \ 45 | --model_type ${model_type} --evaluate_during_training --overwrite_output_dir \ 46 | --model_name_or_path ${mname} \ 47 | --do_train \ 48 | --do_eval \ 49 | --do_lower_case \ 50 | --learning_rate 3e-5 \ 51 | --num_train_epochs 2 \ 52 | --max_seq_length 384 \ 53 | --doc_stride 128 \ 54 | --logging_steps 100 --save_steps 5000 \ 55 | --output_dir ${expname} --data_dir ${SQUAD_DIR}\ 56 | --per_gpu_eval_batch_size=${bsize} \ 57 | --per_gpu_train_batch_size=${bsize} \ 58 | --beta ${beta} \ 59 | --version ${version} \ 60 | && echo "add sent" && python2 eval_adv_squad.py squad_data/sample1k-HCVerifyAll.json \ 61 | ${expname}/predictions_add_sent.json \ 62 | && echo "add one sent" && python2 eval_adv_squad.py squad_data/sample1k-HCVerifySample.json \ 63 | ${expname}/predictions_add_one_sent.json 64 | } 65 | 66 | 67 | 68 | 69 | function runsquad { # Same flow with run_squad.py plus dropout and adversarial-training flags; positional args 1-15 are named below 70 | 71 | export SQUAD_DIR=squad_data 72 | 73 | custom=${1} # Custom name 74 | gpu=${2} # number of GPU 75 | bsize=${3} # Batch size 76 | beta=${4} # regularizer coefficient 77 | version=${5} # mi estimator version 78 | hdp=${6} # Hidden-layer dropout probability (--hidden_dropout_prob) 79 | adp=${7} # Attention dropout probability (--attention_probs_dropout_prob) 80 |
alr=${8} # Step size of gradient ascent (--adv-lr) 81 | amag=${9} # Magnitude of the initial adversarial perturbation (--adv-init-mag) 82 | anorm=${10} # Maximum norm of adversarial perturbation (--adv-max-norm) 83 | asteps=${11} # Number of gradient ascent steps for the adversary (--adv-steps) 84 | mname=${12} # Model name or path; a name containing "roberta" selects model_type=roberta 85 | alpha=${13} # Forwarded as --alpha (semantics defined in run_squad.py) 86 | cl=${14} # Forwarded as --cl; 0.5 when omitted 87 | ch=${15} # Forwarded as --ch; 0.9 when omitted 88 | if [ -z "${14}" ] ;then 89 | cl=0.5 90 | fi 91 | 92 | if [ -z "${15}" ] ;then 93 | ch=0.9 94 | fi 95 | #beta=${11} # regularizer coefficient 96 | #version=${12} # mi estimator version 97 | #hdp=${13} # Hidden layer dropouts for ALBERT 98 | #adp=${14} # Attention dropouts for ALBERT 99 | #alr=${15} # Step size of gradient ascent 100 | #amag=${16} # Magnitude of initial (adversarial?) perturbation 101 | #anorm=${17} # Maximum norm of adversarial perturbation 102 | #asteps=${18} # Number of gradient ascent steps for the adversary 103 | export seqlen=384 104 | export lr=3e-5 # seqlen/lr only feed expname; the launch command below passes 384 and 3e-5 literally 105 | #export mname=bert-large-uncased-whole-word-masking 106 | 107 | #mname=${3} # Model name 108 | #lr=${4} # Learning rate for model parameters 109 | #seqlen=${6} # Maximum sequence length 110 | #ts=${7} # Number of training steps (counted as parameter updates) 111 | #ws=${8} # Learning rate warm-up steps 112 | #seed=${9} # Seed for randomness 113 | #wd=${10} # Weight decay 114 | if [[ ${mname} == *"roberta"* ]]; then 115 | model_type=roberta 116 | else 117 | model_type=bert 118 | fi 119 | 120 | #expname=${custom}-${mname}-${TASK_NAME}-sl${seqlen}-lr${lr}-bs${bsize}-ts${ts}-ws${ws}-wd${wd}-seed${seed}-beta${beta}-alr${alr}-amag${amag}-anm${anorm}-as${asteps}-hdp${hdp}-adp${adp}-version${version} 121 | expname=${custom}-${mname}-sl${seqlen}-lr${lr}-bs${bsize}-beta${beta}-alr${alr}-amag${amag}-anm${anorm}-as${asteps}-hdp${hdp}-adp${adp}-alpha${alpha}-cl${cl}-ch${ch}-version${version} 122 | 123 | max=-1 # Resume point: highest N among ${expname}/checkpoint-N; stays -1 when none exist 124 | for file in ${expname}/checkpoint-* 125 | do 126 | fname=$(basename ${file}) 127 | num=${fname:11} 128 | [[ $num -gt $max ]] && max=$num 129 | done 130 | 131 | if [ $max -eq -1 ] 132 | then echo
"Train from stratch" 133 | else mname="${expname}/checkpoint-$max" && echo "Resume Training from checkpoint $mname" 134 | fi 135 | 136 | python -m torch.distributed.launch --nproc_per_node=${gpu} ./run_squad.py \ 137 | --model_type ${model_type} --evaluate_during_training --overwrite_output_dir \ 138 | --model_name_or_path ${mname} \ 139 | --do_train \ 140 | --do_eval \ 141 | --do_lower_case \ 142 | --learning_rate 3e-5 \ 143 | --num_train_epochs 2 \ 144 | --max_seq_length 384 \ 145 | --doc_stride 128 \ 146 | --logging_steps 100 --save_steps 5000 \ 147 | --output_dir ${expname} --data_dir ${SQUAD_DIR}\ 148 | --per_gpu_eval_batch_size=${bsize} \ 149 | --per_gpu_train_batch_size=${bsize} \ 150 | --hidden_dropout_prob ${hdp} --attention_probs_dropout_prob ${adp} \ 151 | --adv-lr ${alr} --adv-init-mag ${amag} --adv-max-norm ${anorm} --adv-steps ${asteps} \ 152 | --beta ${beta} --alpha ${alpha} --cl ${cl} --ch ${ch} \ 153 | --version ${version} \ 154 | && echo "add sent" && python2 eval_adv_squad.py squad_data/sample1k-HCVerifyAll.json \ 155 | ${expname}/predictions_add_sent.json \ 156 | && echo "add one sent" && python2 eval_adv_squad.py squad_data/sample1k-HCVerifySample.json \ 157 | ${expname}/predictions_add_one_sent.json 158 | } 159 | 160 | function runptsquad { # Variant of runsquad that reads data from ${PT_DATA_DIR} and writes under ${PT_OUTPUT_DIR} 161 | 162 | export SQUAD_DIR=${PT_DATA_DIR} # NOTE(review): the eval_adv_squad.py calls below still read squad_data/, not ${PT_DATA_DIR} 163 | 164 | custom=${1} # Custom name 165 | gpu=${2} # number of GPU 166 | bsize=${3} # Batch size 167 | beta=${4} # regularizer coefficient 168 | version=${5} # mi estimator version 169 | hdp=${6} # Hidden-layer dropout probability (--hidden_dropout_prob) 170 | adp=${7} # Attention dropout probability (--attention_probs_dropout_prob) 171 | alr=${8} # Step size of gradient ascent (--adv-lr) 172 | amag=${9} # Magnitude of the initial adversarial
perturbation (--adv-init-mag) 173 | anorm=${10} # Maximum norm of adversarial perturbation (--adv-max-norm) 174 | asteps=${11} # Number of gradient ascent steps for the adversary (--adv-steps) 175 | mname=${12} # NOTE(review): overwritten by the export mname=... line below 176 | alpha=${13} # Forwarded as --alpha (semantics defined in run_squad.py) 177 | cl=${14} # Forwarded as --cl; 0.5 when omitted 178 | ch=${15} # Forwarded as --ch; 0.9 when omitted 179 | if [ -z "${14}" ] ;then 180 | cl=0.5 181 | fi 182 | 183 | if [ -z "${15}" ] ;then 184 | ch=0.9 185 | fi 186 | #beta=${11} # regularizer coefficient 187 | #version=${12} # mi estimator version 188 | #hdp=${13} # Hidden layer dropouts for ALBERT 189 | #adp=${14} # Attention dropouts for ALBERT 190 | #alr=${15} # Step size of gradient ascent 191 | #amag=${16} # Magnitude of initial (adversarial?) perturbation 192 | #anorm=${17} # Maximum norm of adversarial perturbation 193 | #asteps=${18} # Number of gradient ascent steps for the adversary 194 | export seqlen=384 195 | export lr=3e-5 # seqlen/lr only feed expname; the launch command below passes 384 and 3e-5 literally 196 | export mname=bert-large-uncased-whole-word-masking # NOTE(review): overrides mname=${12}; the launch command below also hard-codes --model_type bert and --model_name_or_path, so a resumed checkpoint path stored in mname is never passed to run_squad.py -- confirm intended 197 | 198 | #mname=${3} # Model name 199 | #lr=${4} # Learning rate for model parameters 200 | #seqlen=${6} # Maximum sequence length 201 | #ts=${7} # Number of training steps (counted as parameter updates) 202 | #ws=${8} # Learning rate warm-up steps 203 | #seed=${9} # Seed for randomness 204 | #wd=${10} # Weight decay 205 | 206 | 207 | #expname=${custom}-${mname}-${TASK_NAME}-sl${seqlen}-lr${lr}-bs${bsize}-ts${ts}-ws${ws}-wd${wd}-seed${seed}-beta${beta}-alr${alr}-amag${amag}-anm${anorm}-as${asteps}-hdp${hdp}-adp${adp}-version${version} 208 | expname=${PT_OUTPUT_DIR}/${custom}-${mname}-sl${seqlen}-lr${lr}-bs${bsize}-beta${beta}-alr${alr}-amag${amag}-anm${anorm}-as${asteps}-hdp${hdp}-adp${adp}-alpha${alpha}-cl${cl}-ch${ch}-version${version} 209 | 210 | max=-1 # Resume point: highest N among ${expname}/checkpoint-N; stays -1 when none exist 211 | for file in ${expname}/checkpoint-* 212 | do 213 | fname=$(basename ${file}) 214 | num=${fname:11} 215 | [[ $num -gt $max ]] && max=$num 216 | done 217 | 218 | if [ $max -eq -1 ] 219 | then echo "Train from stratch" 220 | else mname="${expname}/checkpoint-$max" && echo "Resume Training from checkpoint $mname" 221 | fi 222 | 223 | python -m
torch.distributed.launch --nproc_per_node=${gpu} ./run_squad.py \ 224 | --model_type bert \ 225 | --model_name_or_path bert-large-uncased-whole-word-masking --overwrite_output_dir \ 226 | --do_train \ 227 | --do_eval \ 228 | --do_lower_case \ 229 | --learning_rate 3e-5 \ 230 | --num_train_epochs 2 \ 231 | --max_seq_length 384 \ 232 | --doc_stride 128 \ 233 | --logging_steps 100 --save_steps 100 \ 234 | --output_dir ${expname} --data_dir ${SQUAD_DIR}\ 235 | --per_gpu_eval_batch_size=${bsize} \ 236 | --per_gpu_train_batch_size=${bsize} \ 237 | --hidden_dropout_prob ${hdp} --attention_probs_dropout_prob ${adp} \ 238 | --adv-lr ${alr} --adv-init-mag ${amag} --adv-max-norm ${anorm} --adv-steps ${asteps} \ 239 | --beta ${beta} --alpha ${alpha} --cl ${cl} --ch ${ch} \ 240 | --version ${version} \ 241 | && echo "add sent" && python2 eval_adv_squad.py squad_data/sample1k-HCVerifyAll.json \ 242 | ${expname}/predictions_add_sent.json \ 243 | && echo "add one sent" && python2 eval_adv_squad.py squad_data/sample1k-HCVerifySample.json \ 244 | ${expname}/predictions_add_one_sent.json 245 | } 246 | 247 | function evalsquad { # Evaluation only (--do_eval, single process, no distributed launch) of model ${12}, then adversarial SQuAD scoring 248 | 249 | 250 | export SQUAD_DIR=squad_data 251 | 252 | custom=${1} # Custom name 253 | gpu=${2} # number of GPU 254 | bsize=${3} # Batch size 255 | beta=${4} # regularizer coefficient 256 | version=${5} # mi estimator version 257 | hdp=${6} # Hidden-layer dropout probability (--hidden_dropout_prob) 258 | adp=${7} # Attention dropout probability (--attention_probs_dropout_prob) 259 | alr=${8} # Step size of gradient ascent (--adv-lr) 260 | amag=${9} # Magnitude of the initial adversarial
perturbation (--adv-init-mag) 261 | anorm=${10} # Maximum norm of adversarial perturbation (--adv-max-norm) 262 | asteps=${11} # Number of gradient ascent steps for the adversary (--adv-steps) 263 | mname=${12} # Model name or checkpoint path to evaluate; a name containing "roberta" selects model_type=roberta 264 | alpha=${13} # Forwarded as --alpha; unlike runsquad, --cl/--ch are not passed here 265 | #beta=${11} # regularizer coefficient 266 | #version=${12} # mi estimator version 267 | #hdp=${13} # Hidden layer dropouts for ALBERT 268 | #adp=${14} # Attention dropouts for ALBERT 269 | #alr=${15} # Step size of gradient ascent 270 | #amag=${16} # Magnitude of initial (adversarial?) perturbation 271 | #anorm=${17} # Maximum norm of adversarial perturbation 272 | #asteps=${18} # Number of gradient ascent steps for the adversary 273 | export seqlen=384 274 | export lr=3e-5 275 | #export mname=bert-large-uncased-whole-word-masking 276 | 277 | #mname=${3} # Model name 278 | #lr=${4} # Learning rate for model parameters 279 | #seqlen=${6} # Maximum sequence length 280 | #ts=${7} # Number of training steps (counted as parameter updates) 281 | #ws=${8} # Learning rate warm-up steps 282 | #seed=${9} # Seed for randomness 283 | #wd=${10} # Weight decay 284 | if [[ ${mname} == *"roberta"* ]]; then 285 | model_type=roberta 286 | else 287 | model_type=bert 288 | fi 289 | 290 | #expname=${custom}-${mname}-${TASK_NAME}-sl${seqlen}-lr${lr}-bs${bsize}-ts${ts}-ws${ws}-wd${wd}-seed${seed}-beta${beta}-alr${alr}-amag${amag}-anm${anorm}-as${asteps}-hdp${hdp}-adp${adp}-version${version} 291 | expname=${custom}-${mname}-load # Output directory (--output_dir); eval_adv_squad.py reads the predictions_*.json files written there 292 | 293 | #python -m torch.distributed.launch --nproc_per_node=${gpu} ./run_squad.py \ 294 | python run_squad.py \ 295 | --model_type ${model_type} \ 296 | --model_name_or_path ${mname} \ 297 | --do_eval \ 298 | --do_lower_case \ 299 | --learning_rate 3e-5 \ 300 | --num_train_epochs 2 \ 301 | --max_seq_length 384 \ 302 | --doc_stride 128 \ 303 | --logging_steps 100 --save_steps 5000 \ 304 | --output_dir ${expname} --data_dir ${SQUAD_DIR}\ 305 | --per_gpu_eval_batch_size=${bsize} \ 306 | --per_gpu_train_batch_size=${bsize} \ 307 | --hidden_dropout_prob ${hdp}
--attention_probs_dropout_prob ${adp} \ 308 | --adv-lr ${alr} --adv-init-mag ${amag} --adv-max-norm ${anorm} --adv-steps ${asteps} \ 309 | --beta ${beta} --alpha ${alpha} \ 310 | --version ${version} \ 311 | && echo "add sent" && python2 eval_adv_squad.py squad_data/sample1k-HCVerifyAll.json \ 312 | ${expname}/predictions_add_sent.json \ 313 | && echo "add one sent" && python2 eval_adv_squad.py squad_data/sample1k-HCVerifySample.json \ 314 | ${expname}/predictions_add_one_sent.json 315 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch~=1.4.0 2 | filelock~=3.0.12 3 | nlp~=0.4.0 4 | transformers~=2.11.0 5 | numpy~=1.18.4 6 | scikit-learn~=0.21.3 7 | packaging~=20.3 8 | tqdm~=4.46.0 9 | tensorboardX~=1.9 10 | matplotlib~=3.1.3 --------------------------------------------------------------------------------