├── Code ├── run.sh ├── train.py ├── util.py ├── info.py ├── model.py └── test.py ├── LICENSE └── README.md /Code/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Launch from inside Code/: train.py and test.py are invoked by relative path, and Info resolves ../Data and ../Result against the current directory. 3 | cuda=0 # GPU index exported as CUDA_VISIBLE_DEVICES for every command below 4 | model_name="bert_base-ce" # format {PLM}_{size}-{loss}; Info splits it on '-' 5 | 6 | for task_id in "Task1" 7 | do 8 | for seed in 0 1 2 3 4 5 # deterministic ('det') models, one per seed 9 | do 10 | CUDA_VISIBLE_DEVICES=${cuda} python3 train.py --task_id=${task_id} --model_name=${model_name} --version="det" --seed=${seed} 11 | done 12 | for seed in 0 1 2 3 4 5 # stochastic ('sto', last-layer SVI) models, one per seed 13 | do 14 | CUDA_VISIBLE_DEVICES=${cuda} python3 train.py --task_id=${task_id} --model_name=${model_name} --version="sto" --seed=${seed} 15 | done 16 | CUDA_VISIBLE_DEVICES=${cuda} python3 test.py --task_id=${task_id} --model_name=${model_name} # evaluate every saved model for this task 17 | done -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 xiaoyuxin1002 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Code/train.py: -------------------------------------------------------------------------------- 1 | import math 2 | import dill as pk 3 | 4 | import torch 5 | from torch.nn.utils import clip_grad_norm_ 6 | 7 | from info import Info 8 | from util import parse_args_train, set_seed, myprint, load, prepare, iter_batch 9 | 10 | 11 | def test(info, idx_epoch, inputs_dev, model): 12 | 13 | model.eval() 14 | with torch.no_grad(): 15 | 16 | all_preds, all_labels = [], [] 17 | for idx_batch, (batch_inputs, batch_labels) in enumerate(iter_batch(info, inputs_dev, if_shuffle=False)): 18 | 19 | batch_probs = model.infer(batch_inputs) 20 | all_preds.append(batch_probs.argmax(dim=1)) 21 | all_labels.append(batch_labels) 22 | 23 | all_accuracy = (torch.cat(all_preds) == torch.cat(all_labels)).float().mean().item() 24 | myprint(f'Finish Testing Epoch {idx_epoch} | Accuracy {all_accuracy:.4f}', info.FILE_STDOUT) 25 | 26 | return all_accuracy 27 | 28 | 29 | def train(info, idx_epoch, inputs_train, model, optimizer, scheduler): 30 | 31 | model.train() 32 | num_batch = math.ceil(inputs_train[1].shape[0] / info.HP_BATCH_SIZE) 33 | report_batch = num_batch // 5 34 | for idx_batch, (batch_inputs, batch_labels) in enumerate(iter_batch(info, inputs_train, if_shuffle=True)): 35 | 36 | batch_loss = model.learn(batch_inputs, batch_labels) 37 | optimizer.zero_grad() 38 | batch_loss.backward() 39 | clip_grad_norm_(model.parameters(), info.HP_MAX_GRAD_NORM) 40 | optimizer.step() 41 | scheduler.step() 42 | 43 | if idx_batch % report_batch == 0: 44 | myprint(f'Finish Training Epoch {idx_epoch} | Batch {idx_batch} | Loss 
{batch_loss.item():.4f}', info.FILE_STDOUT) 45 | myprint('-'*20, info.FILE_STDOUT) 46 | 47 | 48 | def main(): 49 | 50 | args = parse_args_train() 51 | info = Info(args) 52 | 53 | myprint('='*20, info.FILE_STDOUT) 54 | myprint(f'Start {args.stage}ing {args.version}-{args.model_name} for {args.task_id}', info.FILE_STDOUT) 55 | myprint('-'*20, info.FILE_STDOUT) 56 | 57 | set_seed(args.seed) 58 | inputs_train, inputs_dev = load(args, info, args.stage) 59 | model, optimizer, scheduler = prepare(info, inputs_train) 60 | 61 | best_accuracy, best_epoch = 0, 0 62 | for idx_epoch in range(info.HP_NUM_EPOCH): 63 | 64 | train(info, idx_epoch, inputs_train, model, optimizer, scheduler) 65 | epoch_accuracy = test(info, idx_epoch, inputs_dev, model) 66 | 67 | if epoch_accuracy >= best_accuracy: 68 | best_accuracy, best_epoch = epoch_accuracy, idx_epoch 69 | pk.dump(model, open(info.FILE_MODEL, 'wb'), -1) 70 | myprint(f'This is the Best Performing Epoch by far - Epoch {idx_epoch} Accuracy {epoch_accuracy:.4f}', info.FILE_STDOUT) 71 | else: 72 | myprint(f'Not the Best Performing Epoch by far - Epoch {idx_epoch} Accuracy {epoch_accuracy:.4f} vs Best Accuracy {best_accuracy:.4f}', info.FILE_STDOUT) 73 | myprint('-'*20, info.FILE_STDOUT) 74 | 75 | myprint(f'Finish {args.stage}ing {args.version}-{args.model_name} for {args.task_id}', info.FILE_STDOUT) 76 | myprint('='*20, info.FILE_STDOUT) 77 | 78 | 79 | if __name__=='__main__': 80 | main() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # UQ-PLM 2 | 3 | Code for Uncertainty Quantification with Pre-trained Language Models: An Empirical Analysis (EMNLP 2022 Findings). 
4 | 5 | ## Requirements 6 | 7 | ``` 8 | PyTorch = 1.10.1 9 | Bayesian-Torch = 0.1 10 | HuggingFace Transformers = 4.11.1 11 | ``` 12 | 13 | ## Data 14 | 15 | Our empirical analysis consists of the following three NLP (natural language processing) classification tasks: 16 | 17 | **task_id** | Task | In-Domain Dataset | Out-Of-Domain Dataset 18 | --- | --- | --- | --- 19 | **Task1** | Sentiment Analysis | IMDb | Yelp 20 | **Task2** | Natural Language Inference | MNLI | SNLI 21 | **Task3** | Commonsense Reasoning | SWAG | HellaSwag 22 | 23 | You can download our input data here and unzip it in the repository root, so that the **Data/** directory sits next to **Code/**. 24 | 25 | The data splits of each task are then stored in **Data/{task_id}/Original**: 26 | - **train.pkl**, **dev.pkl**, and **test_in.pkl** come from the in-domain dataset. 27 | - **test_out.pkl** comes from the out-of-domain dataset. 28 | 29 | ## Run 30 | 31 | Specify the target `model_name` and `task_id` in **Code/run.sh**: 32 | - `model_name` is specified in the format `{PLM}_{size}-{loss}`. 33 | - `{PLM}` (Pre-trained Language Model) can be chosen from `bert`, `xlnet`, `electra`, `roberta`, and `deberta`. 34 | - `{size}` can be chosen from `base` and `large`. 35 | - `{loss}` can be chosen from `br` (Brier loss), `fl` (focal loss), `ce` (cross-entropy), `ls` (label smoothing), and `mm` (maximum mean calibration error). 36 | - `task_id` can be chosen from `Task1` (Sentiment Analysis), `Task2` (Natural Language Inference), and `Task3` (Commonsense Reasoning). 37 | 38 | Other hyperparameters (e.g., learning rate, batch size, and number of training epochs) are defined in **Code/info.py**. 39 | 40 | From inside the **Code/** directory, use the command `bash run.sh` to run one sweep of experiments (the scripts locate **../Data** and **../Result** relative to the working directory): 41 | 1. Transform the original data input in **Data/{task_id}/Original** to the model-specific data input in **Data/{task_id}/{model_name}**. 42 | 1.
Train six deterministic (version=`det`) PLM-based pipelines (used for `Vanilla`, `Temp Scaling` (temperature scaling), `MC Dropout` (Monte Carlo dropout), and `Ensemble`) stored in **Result/{task_id}/{model_name}**. 43 | 1. Train six stochastic (version=`sto`) PLM-based pipelines (used for `LL SVI` (last-layer stochastic variational inference)) stored in **Result/{task_id}/{model_name}**. 44 | 1. Test the above pipelines with five kinds of uncertainty quantifiers (`Vanilla`, `Temp Scaling`, `MC Dropout`, `Ensemble`, and `LL SVI`) under two domain settings (`test_in` and `test_out`) based on four metrics (`ERR` (prediction error), `ECE` (expected calibration error), `RPP` (reversed pair proportion), and `FAR95` (false alarm rate at 95% recall)). 45 | 1. The evaluation of each (uncertainty quantifier, domain setting, metric) combination consists of six trials, and the results are stored in **Result/{task_id}/{model_name}/result_score.pkl**. 46 | 1. The ground-truth labels and raw probability outputs are stored in **Result/{task_id}/{model_name}/result_prob.pkl**. 47 | 1. All training and testing logs (`stdout_*.txt`) are stored in **Result/{task_id}/{model_name}/**. 48 | 49 | ## Result 50 | 51 | We store our empirical observations in **results.pkl**. You can download this dictionary here. 52 | - The key is in the format of `({task}, {model}, {quantifier}, {domain}, {metric})`. 53 | - `{task}` can be chosen from `Sentiment Analysis`, `Natural Language Inference`, and `Commonsense Reasoning`. 54 | - `{model}` can be chosen from `bert_base-br`, `bert_base-ce`, `bert_base-fl`, `bert_base-ls`, `bert_base-mm`, `bert_large-ce`, `deberta_base-ce`, `deberta_large-ce`, `electra_base-ce`, `electra_large-ce`, `roberta_base-ce`, `roberta_large-ce`, `xlnet_base-ce`, and `xlnet_large-ce`. 55 | - `{quantifier}` can be chosen from `Vanilla`, `Temp Scaling`, `MC Dropout`, `Ensemble`, and `LL SVI`. 56 | - `{domain}` can be chosen from `test_in` and `test_out`.
57 | - `{metric}` can be chosen from `ERR`, `ECE`, `RPP`, and `FAR95`. Note that `FAR95` only works with the domain setting of `test_out`. 58 | - The value is in the format of `(mean, standard error)`, which are calculated based on six trials with different seeds. 59 | 60 | ## Citation 61 | 62 | ``` 63 | @inproceedings{xiao2022uncertainty, 64 | title={Uncertainty Quantification with Pre-trained Language Models: An Empirical Analysis}, 65 | author={Xiao, Yuxin and Liang, Paul Pu and Bhatt, Umang and Neiswanger, Willie and Salakhutdinov, Ruslan and Morency, Louis-Philippe}, 66 | booktitle={Findings of EMNLP}, 67 | year={2022} 68 | } 69 | ``` 70 | -------------------------------------------------------------------------------- /Code/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import time 4 | import random 5 | import argparse 6 | import dill as pk 7 | import numpy as np 8 | 9 | import torch 10 | from transformers import AutoTokenizer, AutoConfig, AutoModel 11 | from transformers.optimization import AdamW, get_linear_schedule_with_warmup 12 | 13 | from model import Model 14 | 15 | 16 | def parse_args_train(): 17 | 18 | parser = argparse.ArgumentParser() 19 | 20 | parser.add_argument('--task_id', type=str, default='Task1') 21 | parser.add_argument('--model_name', type=str, default='bert_base-model_ce') 22 | parser.add_argument('--stage', type=str, default='train') 23 | parser.add_argument('--version', type=str, default='det') 24 | parser.add_argument('--seed', type=int, default=0) 25 | 26 | args = parser.parse_args() 27 | 28 | return args 29 | 30 | 31 | def parse_args_test(): 32 | 33 | parser = argparse.ArgumentParser() 34 | 35 | parser.add_argument('--task_id', type=str, default='Task1') 36 | parser.add_argument('--model_name', type=str, default='bert_base-model_ce') 37 | parser.add_argument('--stage', type=str, default='test') 38 | parser.add_argument('--seed', type=int, default=0) 39 | 40 
| args = parser.parse_args() 41 | 42 | return args 43 | 44 | 45 | def set_seed(seed): 46 | 47 | random.seed(seed) 48 | np.random.seed(seed) 49 | torch.manual_seed(seed) 50 | torch.cuda.manual_seed(seed) 51 | 52 | 53 | def myprint(text, file): 54 | 55 | file = open(file, 'a') 56 | print(time.strftime("%Y %b %d %a, %H:%M:%S: ", time.localtime()) + text, file=file, flush=True) 57 | file.close() 58 | 59 | 60 | def load(args, info, stage): 61 | 62 | if stage == info.STAGE_TRAIN: 63 | name_list = [info.TYPE_TRAIN, info.TYPE_DEV] 64 | ori_list = [info.FILE_ORI_TRAIN, info.FILE_ORI_DEV] 65 | input_list = [info.FILE_INPUT_TRAIN, info.FILE_INPUT_DEV] 66 | elif stage == info.STAGE_TEST: 67 | name_list = [info.TYPE_DEV, info.TYPE_TEST_IN, info.TYPE_TEST_OUT] 68 | ori_list = [info.FILE_ORI_DEV, info.FILE_ORI_TEST_IN, info.FILE_ORI_TEST_OUT] 69 | input_list = [info.FILE_INPUT_DEV, info.FILE_INPUT_TEST_IN, info.FILE_INPUT_TEST_OUT] 70 | 71 | data_list = [] 72 | tokenizer = AutoTokenizer.from_pretrained(info.PLM_NAME) 73 | for name, ori_file, input_file in zip(name_list, ori_list, input_list): 74 | myprint(f'Load Data from {name}', info.FILE_STDOUT) 75 | 76 | if os.path.isfile(input_file): 77 | all_inputs, all_labels = pk.load(open(input_file, 'rb')) 78 | 79 | else: 80 | ori_data = pk.load(open(ori_file, 'rb')) 81 | text1s, text2s, labels = [], [], [] 82 | for row in ori_data: 83 | 84 | if args.task_id in ['Task1']: 85 | text1s.append(row[0]) 86 | labels.append(row[1]) 87 | 88 | elif args.task_id in ['Task2']: 89 | text1s.append(row[0]) 90 | text2s.append(row[1]) 91 | labels.append(row[2]) 92 | 93 | elif args.task_id in ['Task3']: 94 | text1s += [row[0]] * info.NUM_CLASS[1] 95 | text2s += row[1:1+info.NUM_CLASS[1]] 96 | labels.append(row[-1]) 97 | 98 | if len(text2s) == 0: text2s = None 99 | all_inputs = tokenizer(text1s, text2s, return_tensors='pt', padding=True, truncation=True, max_length=512) 100 | all_labels = torch.Tensor(labels).long() 101 | pk.dump((all_inputs, 
all_labels), open(input_file, 'wb'), -1) 102 | 103 | data_list.append((all_inputs.to(info.DEVICE_GPU), all_labels.to(info.DEVICE_GPU))) 104 | myprint('-'*20, info.FILE_STDOUT) 105 | 106 | return data_list 107 | 108 | 109 | def prepare(info, inputs_train): 110 | 111 | config = AutoConfig.from_pretrained(info.PLM_NAME, num_labels=info.NUM_CLASS[0]) 112 | transformer = AutoModel.from_pretrained(info.PLM_NAME) 113 | 114 | model = Model(info, config, transformer).to(info.DEVICE_GPU) 115 | parameters = list(model.parameters()) 116 | optimizer = AdamW(parameters, lr=info.HP_LR) 117 | 118 | num_updates = math.ceil(inputs_train[1].shape[0] / info.HP_BATCH_SIZE) * info.HP_NUM_EPOCH 119 | num_warmups = int(num_updates * info.HP_WARMUP_RATIO) 120 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmups, num_training_steps=num_updates) 121 | 122 | return model, optimizer, scheduler 123 | 124 | 125 | def iter_batch(info, inputs, if_shuffle=False): 126 | 127 | all_inputs, all_labels = inputs 128 | batch_seq = np.arange(all_labels.shape[0]) 129 | if if_shuffle: np.random.shuffle(batch_seq) 130 | num_batch = math.ceil(batch_seq.shape[0] / info.HP_BATCH_SIZE) 131 | 132 | for idx_batch in range(num_batch): 133 | batch_indices = batch_seq[idx_batch*info.HP_BATCH_SIZE : (idx_batch+1)*info.HP_BATCH_SIZE] 134 | 135 | batch_labels = all_labels[batch_indices] 136 | if info.NUM_CLASS[1] != 1: batch_indices = np.repeat(batch_indices*info.NUM_CLASS[1], info.NUM_CLASS[1]) + np.tile(np.arange(info.NUM_CLASS[1]), batch_indices.shape[0]) 137 | batch_inputs = {k:v[batch_indices] for k,v in all_inputs.items()} 138 | 139 | yield batch_inputs, batch_labels -------------------------------------------------------------------------------- /Code/info.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from collections import defaultdict 4 | 5 | 6 | class Info: 7 | 8 | def __init__(self, args): 9 | 
10 | self.metadata() 11 | self.individual(args) 12 | 13 | 14 | def metadata(self): # Run-independent constants: devices, stage/version tags, loss codes, method and metric names, PLM checkpoints, hyperparameters and directories. 15 | 16 | self.DEVICE_CPU = 'cpu' 17 | self.DEVICE_GPU = 'cuda' 18 | 19 | self.STAGE_TRAIN = 'train' 20 | self.STAGE_TEST = 'test' 21 | 22 | self.VERSION_DET = 'det' 23 | self.VERSION_STO = 'sto' 24 | 25 | self.LOSS_BR = 'br' # Brier Loss 26 | self.LOSS_CE = 'ce' # Cross Entropy 27 | self.LOSS_FL = 'fl' # Focal Loss 28 | self.LOSS_LS = 'ls' # Label Smoothing 29 | self.LOSS_MM = 'mm' # Maximum Mean Calibration Error (MMCE) 30 | 31 | self.METHOD_VANILLA = 'Vanilla' 32 | self.METHOD_TEMP_SCALING = 'Temp Scaling' 33 | self.METHOD_MC_DROPOUT = 'MC Dropout' 34 | self.METHOD_ENSEMBLE = 'Ensemble' 35 | self.METHOD_LL_SVI = 'LL SVI' 36 | self.METHODS = [self.METHOD_VANILLA, self.METHOD_TEMP_SCALING, self.METHOD_MC_DROPOUT, self.METHOD_ENSEMBLE, self.METHOD_LL_SVI] 37 | 38 | self.METRIC_ERR = 'ERR' 39 | self.METRIC_ECE = 'ECE' 40 | self.METRICS_EXPLICIT = [self.METRIC_ERR, self.METRIC_ECE] 41 | self.METRIC_RPP = 'RPP' 42 | self.METRIC_FAR = 'FAR95' 43 | self.METRICS_IMPLICIT = [self.METRIC_RPP, self.METRIC_FAR] 44 | self.METRICS = [self.METRICS_EXPLICIT, self.METRICS_IMPLICIT] # test.py prints each group on its own line 45 | 46 | self.TYPE_TRAIN = 'train' 47 | self.TYPE_DEV = 'dev' 48 | self.TYPE_TEST_IN = 'test_in' 49 | self.TYPE_TEST_OUT = 'test_out' 50 | self.TYPE_TESTS = [self.TYPE_TEST_IN, self.TYPE_TEST_OUT] 51 | 52 | self.TASK2NCLASS = {'Task1':(2,1), 'Task2':(3,1), 'Task3':(1,4)} # (classifier outputs per input row, input rows per example); Task3 scores 4 candidates with one logit each 53 | 54 | self.MODEL2NAME = {'bert_base':'bert-base-cased', 'bert_large':'bert-large-cased', 55 | 'xlnet_base':'xlnet-base-cased', 'xlnet_large':'xlnet-large-cased', 56 | 'electra_base':'google/electra-base-discriminator', 'electra_large':'google/electra-large-discriminator', 57 | 'roberta_base':'roberta-base', 'roberta_large':'roberta-large', 58 | 'deberta_base':'microsoft/deberta-base', 'deberta_large':'microsoft/deberta-large'} 59 | 60 | self.HP_LOSS_FL = (0.2, 5, 3) # NOTE(review): presumably (probability threshold, gamma below it, gamma above it) for focal loss — confirm in Loss.focal_loss 61 | self.HP_LOSS_LS = 0.1 62 | self.HP_LOSS_MM = 1 63 | 64 | self.HP_NUM_DROPOUT_MC = 10 # stochastic passes used by MC Dropout in test.py 65 |
self.HP_NUM_ENSEMBLE = 5 # members per ensemble; test.py evaluates every HP_NUM_ENSEMBLE-sized combination of the det models 66 | self.HP_NUM_SVI_MC = 50 # Monte Carlo passes per forward call of the 'sto' classifier 67 | 68 | self.HP_WARMUP_RATIO = 0.1 # fraction of all updates used for LR warmup 69 | self.HP_MAX_GRAD_NORM = 1.0 70 | self.HP_BATCH_SIZE = 16 71 | self.HP_NUM_EPOCH = 5 72 | self.HP_MODEL2LR = {'bert_base':2e-5, 'xlnet_base':2e-5, 'electra_base':2e-5, 'roberta_base':2e-5, 'deberta_base':2e-5, 73 | 'bert_large':5e-6, 'xlnet_large':5e-6, 'electra_large':5e-6, 'roberta_large':5e-6, 'deberta_large':5e-6} 74 | 75 | self.DIR_CURR = os.getcwd() # data/result paths are relative to the working directory, so the scripts are expected to run from Code/ 76 | self.DIR_DATA = os.path.join(self.DIR_CURR, '../Data') 77 | self.DIR_RESULT = os.path.join(self.DIR_CURR, '../Result') 78 | 79 | 80 | def individual(self, args): # Run-specific paths and settings derived from the CLI arguments; creates the input and output directories. 81 | 82 | self.DIR_ORI = os.path.join(self.DIR_DATA, args.task_id, 'Original') 83 | self.FILE_ORI_TRAIN = os.path.join(self.DIR_ORI, f'{self.TYPE_TRAIN}.pkl') 84 | self.FILE_ORI_DEV = os.path.join(self.DIR_ORI, f'{self.TYPE_DEV}.pkl') 85 | self.FILE_ORI_TEST_IN = os.path.join(self.DIR_ORI, f'{self.TYPE_TEST_IN}.pkl') 86 | self.FILE_ORI_TEST_OUT = os.path.join(self.DIR_ORI, f'{self.TYPE_TEST_OUT}.pkl') 87 | 88 | self.DIR_INPUT = os.path.join(self.DIR_DATA, args.task_id, args.model_name) 89 | Path(self.DIR_INPUT).mkdir(parents=True, exist_ok=True) 90 | self.FILE_INPUT_TRAIN = os.path.join(self.DIR_INPUT, f'{self.TYPE_TRAIN}.pkl') 91 | self.FILE_INPUT_DEV = os.path.join(self.DIR_INPUT, f'{self.TYPE_DEV}.pkl') 92 | self.FILE_INPUT_TEST_IN = os.path.join(self.DIR_INPUT, f'{self.TYPE_TEST_IN}.pkl') 93 | self.FILE_INPUT_TEST_OUT = os.path.join(self.DIR_INPUT, f'{self.TYPE_TEST_OUT}.pkl') 94 | 95 | self.DIR_OUTPUT = os.path.join(self.DIR_RESULT, args.task_id, args.model_name) 96 | Path(self.DIR_OUTPUT).mkdir(parents=True, exist_ok=True) 97 | 98 | if args.stage == self.STAGE_TRAIN: 99 | self.FILE_STDOUT = os.path.join(self.DIR_OUTPUT, f'stdout_{args.stage}_{args.version}_{args.seed}.txt') 100 | self.FILE_MODEL = os.path.join(self.DIR_OUTPUT, f'model_{args.version}_{args.seed}.pkl') 101 | self.VERSION_MODE = args.version # only set when training; Classifier reads it to choose the det/sto head 102 | elif args.stage == self.STAGE_TEST: 103
| self.FILE_STDOUT = os.path.join(self.DIR_OUTPUT, f'stdout_{args.stage}.txt') 104 | self.FILE_SCORE = os.path.join(self.DIR_OUTPUT, f'result_score.pkl') 105 | self.FILE_PROB = os.path.join(self.DIR_OUTPUT, f'result_prob.pkl') 106 | 107 | self.FILE_MODELS = defaultdict(list) 108 | for file in os.listdir(self.DIR_OUTPUT): 109 | if file.startswith('model'): self.FILE_MODELS[file.split('_')[1]].append(os.path.join(self.DIR_OUTPUT, file)) 110 | 111 | model_name, self.LOSS_MODE = args.model_name.split('-') 112 | self.PLM_NAME = self.MODEL2NAME[model_name] 113 | self.HP_LR = self.HP_MODEL2LR[model_name] 114 | self.NUM_CLASS = self.TASK2NCLASS[args.task_id] -------------------------------------------------------------------------------- /Code/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from bayesian_torch.layers import LinearReparameterization 5 | 6 | 7 | class Classifier(nn.Module): 8 | 9 | def __init__(self, info, config): 10 | super(Classifier, self).__init__() 11 | 12 | self.info = info 13 | 14 | self.activation = nn.Tanh() 15 | self.dropout = nn.Dropout(config.dropout if hasattr(config, 'dropout') else config.hidden_dropout_prob) 16 | 17 | if info.VERSION_MODE == info.VERSION_DET: 18 | self.linear1 = nn.Linear(config.hidden_size, config.hidden_size) 19 | self.linear2 = nn.Linear(config.hidden_size, config.num_labels) 20 | elif info.VERSION_MODE == info.VERSION_STO: 21 | self.linear1 = LinearReparameterization(config.hidden_size, config.hidden_size) 22 | self.linear2 = LinearReparameterization(config.hidden_size, config.num_labels) 23 | 24 | def forward(self, batch_reps): 25 | 26 | if self.info.VERSION_MODE == self.info.VERSION_DET: 27 | batch_kls = 0 28 | 29 | batch_logits = self.dropout(batch_reps) 30 | batch_logits = self.linear1(batch_logits) 31 | batch_logits = self.activation(batch_logits) 32 | batch_logits = self.dropout(batch_reps) 
33 | batch_logits = self.linear2(batch_logits) 34 | 35 | if self.info.NUM_CLASS[1] != 1: batch_logits = batch_logits.view(-1, self.info.NUM_CLASS[1]) # multiple-choice: regroup the per-candidate-row logits into (examples, NUM_CLASS[1]) 36 | batch_logits = batch_logits.unsqueeze(0) # leading sample axis of size 1, matching the stacked output of the 'sto' branch 37 | 38 | elif self.info.VERSION_MODE == self.info.VERSION_STO: 39 | batch_kls, batch_logits = 0, [] 40 | for _ in range(self.info.HP_NUM_SVI_MC): # one stochastic pass per Monte Carlo sample 41 | 42 | round_logits = self.dropout(batch_reps) 43 | round_logits, round_kls = self.linear1(round_logits) # LinearReparameterization returns (output, kl) 44 | batch_kls += round_kls 45 | round_logits = self.activation(round_logits) 46 | round_logits = self.dropout(round_logits) 47 | round_logits, round_kls = self.linear2(round_logits) 48 | batch_kls += round_kls 49 | 50 | if self.info.NUM_CLASS[1] != 1: round_logits = round_logits.view(-1, self.info.NUM_CLASS[1]) 51 | batch_logits.append(round_logits) 52 | batch_kls /= (self.info.HP_NUM_SVI_MC * batch_reps.shape[0]) # average the accumulated KL over MC samples and input rows 53 | batch_logits = torch.stack(batch_logits) # leading axis of size HP_NUM_SVI_MC 54 | 55 | return batch_kls, batch_logits 56 | 57 | 58 | class Loss(nn.Module): # Training losses over logits with a leading sample axis; labels are tiled once per sample to line up with flatten(0,1). 59 | 60 | def __init__(self, info): 61 | super(Loss, self).__init__() 62 | 63 | self.info = info 64 | 65 | def brier_loss(self, batch_logits, batch_labels): # NOTE(review): subtracts the one-hot targets from batch_logits directly, whereas focal_loss applies softmax first — confirm the caller passes probabilities here, otherwise this is not a Brier score 66 | 67 | batch_labels = F.one_hot(batch_labels, self.info.NUM_CLASS[0] * self.info.NUM_CLASS[1]) 68 | batch_loss = (batch_logits.flatten(0,1) - batch_labels.tile((batch_logits.shape[0],1))).square().sum(-1).mean(0) 69 | return batch_loss 70 | 71 | def cross_entropy(self, batch_logits, batch_labels): 72 | 73 | batch_loss = F.cross_entropy(batch_logits.flatten(0,1), batch_labels.tile((batch_logits.shape[0],))) 74 | return batch_loss 75 | 76 | def focal_loss(self, batch_logits, batch_labels): # gathers the softmax probability of the true class for every (sample, example) pair 77 | 78 | batch_labels = batch_labels.tile((batch_logits.shape[0],)).unsqueeze(-1) 79 | batch_probs = F.softmax(batch_logits, dim=-1).flatten(0, 1).gather(1, batch_labels).flatten() 80 | batch_gammas = torch.where(batch_probs 0: 63 | bin_accs = accuracy[bin_in].mean() 64 | bin_confs = confidence[bin_in].mean() 65 | ece += np.abs(bin_accs -
bin_confs) * bin_ratio 66 | 67 | return ece 68 | 69 | 70 | def selective_prediction(all_accs, all_confs): 71 | 72 | conf_accs = all_accs[np.argsort(all_confs)] 73 | conf_rpp = np.cumsum(conf_accs)[~conf_accs].sum() / (all_accs.shape[0]**2) 74 | 75 | return conf_rpp 76 | 77 | 78 | def out_detection(in_confs, out_confs): 79 | 80 | conf_far95 = (in_confs < np.percentile(out_confs, 95)).mean() 81 | 82 | return conf_far95 83 | 84 | 85 | def evaluate(args, info, type, all_probs, all_labels, in_confs=None): 86 | 87 | all_confs, all_preds = all_probs.max(-1) 88 | all_confs = all_confs.to(info.DEVICE_CPU).numpy() 89 | all_accs = (all_preds == all_labels).to(info.DEVICE_CPU).numpy() 90 | 91 | all_scores = {} 92 | all_scores[info.METRIC_ERR] = 1 - all_accs.mean() 93 | all_scores[info.METRIC_ECE] = get_ece(all_accs, all_confs) 94 | all_scores[info.METRIC_RPP] = selective_prediction(all_accs, all_confs) 95 | if type != info.TYPE_TEST_IN: all_scores[info.METRIC_FAR] = out_detection(in_confs, all_confs) 96 | 97 | return all_confs, all_scores 98 | 99 | 100 | def feed(info, type, method, all_labels, all_probs, all_scores, each_labels, each_probs, each_scores): 101 | 102 | all_labels[type] = each_labels.to(info.DEVICE_CPU).numpy() 103 | all_probs[(type, method)].append(each_probs.to(info.DEVICE_CPU).numpy()) 104 | for metric, value in each_scores.items(): 105 | all_scores[(type, method, metric)].append(value) 106 | 107 | 108 | def main(): 109 | 110 | args = parse_args_test() 111 | info = Info(args) 112 | 113 | myprint('='*20, info.FILE_STDOUT) 114 | myprint(f'Start {args.stage}ing {args.model_name} for {args.task_id}', info.FILE_STDOUT) 115 | myprint('-'*20, info.FILE_STDOUT) 116 | 117 | set_seed(args.seed) 118 | inputs_list = load(args, info, args.stage) 119 | all_labels, all_probs, all_scores = {}, defaultdict(list), defaultdict(list) 120 | 121 | for model_id, model_file in enumerate(info.FILE_MODELS[info.VERSION_DET]): 122 | myprint(f'Load {info.VERSION_DET} Model {model_id}', 
info.FILE_STDOUT) 123 | model = pk.load(open(model_file, 'rb')).to(info.DEVICE_GPU) 124 | myprint('-'*20, info.FILE_STDOUT) 125 | 126 | myprint(f'Calculate Temperature for {info.VERSION_DET} Model {model_id}', info.FILE_STDOUT) 127 | dev_logits, dev_labels = process(info, info.METHOD_VANILLA, inputs_list[0], model, if_eval=True, num_mc=1) 128 | temperature = recalibrate(info, dev_logits, dev_labels) 129 | myprint(f'Temperature = {temperature:.4f}', info.FILE_STDOUT) 130 | myprint('-'*20, info.FILE_STDOUT) 131 | 132 | in_confs = None; vanilla_logits = {} 133 | for type, inputs in zip(info.TYPE_TESTS, inputs_list[1:]): 134 | myprint(f'Uncertainty for {type} Data via {info.VERSION_DET} Model {model_id} and {info.METHOD_VANILLA}', info.FILE_STDOUT) 135 | each_logits, each_labels = process(info, info.METHOD_VANILLA, inputs, model, if_eval=True, num_mc=1) 136 | each_probs = F.softmax(each_logits, dim=-1); vanilla_logits[type] = each_logits 137 | each_confs, each_scores = evaluate(args, info, type, each_probs, each_labels, in_confs=in_confs) 138 | feed(info, type, info.METHOD_VANILLA, all_labels, all_probs, all_scores, each_labels, each_probs, each_scores) 139 | if type == info.TYPE_TEST_IN: in_confs = each_confs 140 | myprint('-'*20, info.FILE_STDOUT) 141 | 142 | in_confs = None 143 | for type, inputs in zip(info.TYPE_TESTS, inputs_list[1:]): 144 | myprint(f'Uncertainty for {type} Data via {info.VERSION_DET} Model {model_id} and {info.METHOD_TEMP_SCALING}', info.FILE_STDOUT) 145 | each_probs = F.softmax(vanilla_logits[type] / temperature, dim=-1) 146 | each_labels = torch.from_numpy(all_labels[type]).long().to(info.DEVICE_GPU) 147 | each_confs, each_scores = evaluate(args, info, type, each_probs, each_labels, in_confs=in_confs) 148 | feed(info, type, info.METHOD_TEMP_SCALING, all_labels, all_probs, all_scores, each_labels, each_probs, each_scores) 149 | if type == info.TYPE_TEST_IN: in_confs = each_confs 150 | myprint('-'*20, info.FILE_STDOUT) 151 | 152 | in_confs = None 
153 | for type, inputs in zip(info.TYPE_TESTS, inputs_list[1:]): 154 | myprint(f'Uncertainty for {type} Data via {info.VERSION_DET} Model {model_id} and {info.METHOD_MC_DROPOUT}', info.FILE_STDOUT) 155 | each_probs, each_labels = process(info, info.METHOD_MC_DROPOUT, inputs, model, if_eval=False, num_mc=info.HP_NUM_DROPOUT_MC) 156 | each_confs, each_scores = evaluate(args, info, type, each_probs, each_labels, in_confs=in_confs) 157 | feed(info, type, info.METHOD_MC_DROPOUT, all_labels, all_probs, all_scores, each_labels, each_probs, each_scores) 158 | if type == info.TYPE_TEST_IN: in_confs = each_confs 159 | myprint('-'*20, info.FILE_STDOUT) 160 | 161 | for ensemble_id, model_ids in enumerate(combinations(np.arange(len(info.FILE_MODELS[info.VERSION_DET])), info.HP_NUM_ENSEMBLE)): 162 | in_confs = None 163 | for type, inputs in zip(info.TYPE_TESTS, inputs_list[1:]): 164 | myprint(f'Uncertainty for {type} Data via Ensemble {ensemble_id}', info.FILE_STDOUT) 165 | each_probs = torch.from_numpy(np.mean([all_probs[(type, info.METHOD_VANILLA)][model_id] for model_id in model_ids], axis=0)).float().to(info.DEVICE_GPU) 166 | each_labels = torch.from_numpy(all_labels[type]).long().to(info.DEVICE_GPU) 167 | each_confs, each_scores = evaluate(args, info, type, each_probs, each_labels, in_confs=in_confs) 168 | feed(info, type, info.METHOD_ENSEMBLE, all_labels, all_probs, all_scores, each_labels, each_probs, each_scores) 169 | if type == info.TYPE_TEST_IN: in_confs = each_confs 170 | myprint('-'*20, info.FILE_STDOUT) 171 | 172 | for model_id, model_file in enumerate(info.FILE_MODELS[info.VERSION_STO]): 173 | myprint(f'Load {info.VERSION_STO} Model {model_id}', info.FILE_STDOUT) 174 | model = pk.load(open(model_file, 'rb')).to(info.DEVICE_GPU) 175 | myprint('-'*20, info.FILE_STDOUT) 176 | 177 | in_confs = None 178 | for type, inputs in zip(info.TYPE_TESTS, inputs_list[1:]): 179 | myprint(f'Uncertainty for {type} Data via {info.VERSION_STO} Model {model_id} and 
{info.METHOD_LL_SVI}', info.FILE_STDOUT) 180 | each_probs, each_labels = process(info, info.METHOD_LL_SVI, inputs, model, if_eval=True, num_mc=1) 181 | each_confs, each_scores = evaluate(args, info, type, each_probs, each_labels, in_confs=in_confs) 182 | feed(info, type, info.METHOD_LL_SVI, all_labels, all_probs, all_scores, each_labels, each_probs, each_scores) 183 | if type == info.TYPE_TEST_IN: in_confs = each_confs 184 | myprint('-'*20, info.FILE_STDOUT) 185 | 186 | for type in info.TYPE_TESTS: 187 | for method in info.METHODS: 188 | myprint(f'Data Type: {type} & Uncertainty Method: {method}', info.FILE_STDOUT) 189 | for metrics in info.METRICS: 190 | result = [] 191 | for metric in metrics: 192 | if (type, method, metric) not in all_scores: continue 193 | scores = all_scores[(type, method, metric)] 194 | mean, sem = np.mean(scores), stats.sem(scores) 195 | result.append(f'{metric}: {mean:.4f}±{sem:.4f}') 196 | if len(result) != 0: myprint(' | '.join(result), info.FILE_STDOUT) 197 | myprint('-'*20, info.FILE_STDOUT) 198 | 199 | pk.dump(all_scores, open(info.FILE_SCORE, 'wb'), -1) 200 | pk.dump((all_labels, all_probs), open(info.FILE_PROB, 'wb'), -1) 201 | 202 | myprint(f'Finish {args.stage}ing {args.model_name} for {args.task_id}', info.FILE_STDOUT) 203 | myprint('='*20, info.FILE_STDOUT) 204 | 205 | 206 | if __name__=='__main__': 207 | main() --------------------------------------------------------------------------------