├── Code
│   ├── run.sh
│   ├── train.py
│   ├── util.py
│   ├── info.py
│   ├── model.py
│   └── test.py
├── LICENSE
└── README.md
/Code/run.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Run one sweep of experiments per task: train six deterministic ("det") and
# six stochastic ("sto") pipelines (one per seed), then evaluate all of them.

# train.py / test.py are invoked by relative path, and info.py resolves
# ../Data and ../Result against the working directory, so always run from the
# directory that contains this script (this makes `bash Code/run.sh` work).
cd "$(dirname "$0")" || exit 1

cuda=0
model_name="bert_base-ce"

for task_id in "Task1"
do
    # All "det" seeds are trained first, then all "sto" seeds.
    for version in "det" "sto"
    do
        for seed in 0 1 2 3 4 5
        do
            CUDA_VISIBLE_DEVICES=${cuda} python3 train.py --task_id=${task_id} --model_name=${model_name} --version="${version}" --seed=${seed}
        done
    done
    CUDA_VISIBLE_DEVICES=${cuda} python3 test.py --task_id=${task_id} --model_name=${model_name}
done
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 xiaoyuxin1002
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Code/train.py:
--------------------------------------------------------------------------------
1 | import math
2 | import dill as pk
3 |
4 | import torch
5 | from torch.nn.utils import clip_grad_norm_
6 |
7 | from info import Info
8 | from util import parse_args_train, set_seed, myprint, load, prepare, iter_batch
9 |
10 |
def test(info, idx_epoch, inputs_dev, model):
    """Evaluate ``model`` on the dev split and log its accuracy.

    Args:
        info: experiment configuration (batch size, log-file path, ...).
        idx_epoch: index of the epoch being evaluated; only used in the log line.
        inputs_dev: ``(inputs, labels)`` pair, iterated with ``iter_batch``.
        model: model exposing ``infer(batch_inputs)`` that returns per-class scores.

    Returns:
        Dev-set accuracy as a Python float.
    """
    model.eval()
    predictions = []
    targets = []
    with torch.no_grad():
        for batch_inputs, batch_labels in iter_batch(info, inputs_dev, if_shuffle=False):
            predictions.append(model.infer(batch_inputs).argmax(dim=1))
            targets.append(batch_labels)

        # Fraction of examples whose arg-max prediction matches the gold label.
        correct = torch.cat(predictions) == torch.cat(targets)
        accuracy = correct.float().mean().item()
        myprint(f'Finish Testing Epoch {idx_epoch} | Accuracy {accuracy:.4f}', info.FILE_STDOUT)

    return accuracy
27 |
28 |
def train(info, idx_epoch, inputs_train, model, optimizer, scheduler):
    """Run one training epoch over ``inputs_train``.

    Args:
        info: experiment configuration (batch size, gradient-clip norm, log file).
        idx_epoch: index of the current epoch; only used in log lines.
        inputs_train: ``(inputs, labels)`` pair; ``inputs_train[1]`` sizes the epoch.
        model: model exposing ``learn(batch_inputs, batch_labels)`` returning a scalar loss.
        optimizer: optimizer over ``model.parameters()``.
        scheduler: learning-rate scheduler, stepped once per batch.
    """
    model.train()
    num_batch = math.ceil(inputs_train[1].shape[0] / info.HP_BATCH_SIZE)
    # Log about five times per epoch. Clamp to 1 so that epochs with fewer
    # than five batches do not hit a modulo-by-zero below.
    report_batch = max(1, num_batch // 5)
    for idx_batch, (batch_inputs, batch_labels) in enumerate(iter_batch(info, inputs_train, if_shuffle=True)):

        batch_loss = model.learn(batch_inputs, batch_labels)
        optimizer.zero_grad()
        batch_loss.backward()
        # Clip the global gradient norm before the parameter update.
        clip_grad_norm_(model.parameters(), info.HP_MAX_GRAD_NORM)
        optimizer.step()
        # Per-batch step: util.prepare sizes the schedule in batches, not epochs.
        scheduler.step()

        if idx_batch % report_batch == 0:
            myprint(f'Finish Training Epoch {idx_epoch} | Batch {idx_batch} | Loss {batch_loss.item():.4f}', info.FILE_STDOUT)
    myprint('-'*20, info.FILE_STDOUT)
46 |
47 |
def main():
    """Train one model for the task / version / seed given on the command line.

    After every epoch the model is scored on the dev split. Whenever the dev
    accuracy matches or beats the best seen so far, the model is pickled to
    ``info.FILE_MODEL``.
    """
    args = parse_args_train()
    info = Info(args)

    myprint('='*20, info.FILE_STDOUT)
    myprint(f'Start {args.stage}ing {args.version}-{args.model_name} for {args.task_id}', info.FILE_STDOUT)
    myprint('-'*20, info.FILE_STDOUT)

    set_seed(args.seed)
    inputs_train, inputs_dev = load(args, info, args.stage)
    model, optimizer, scheduler = prepare(info, inputs_train)

    best_accuracy = 0
    for idx_epoch in range(info.HP_NUM_EPOCH):

        train(info, idx_epoch, inputs_train, model, optimizer, scheduler)
        epoch_accuracy = test(info, idx_epoch, inputs_dev, model)

        # ">=" means that on ties the later epoch's checkpoint wins.
        if epoch_accuracy >= best_accuracy:
            best_accuracy = epoch_accuracy
            # Close the checkpoint file deterministically so the pickle is fully flushed.
            with open(info.FILE_MODEL, 'wb') as f_model:
                pk.dump(model, f_model, -1)
            myprint(f'This is the Best Performing Epoch by far - Epoch {idx_epoch} Accuracy {epoch_accuracy:.4f}', info.FILE_STDOUT)
        else:
            myprint(f'Not the Best Performing Epoch by far - Epoch {idx_epoch} Accuracy {epoch_accuracy:.4f} vs Best Accuracy {best_accuracy:.4f}', info.FILE_STDOUT)
        myprint('-'*20, info.FILE_STDOUT)

    myprint(f'Finish {args.stage}ing {args.version}-{args.model_name} for {args.task_id}', info.FILE_STDOUT)
    myprint('='*20, info.FILE_STDOUT)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # UQ-PLM
2 |
3 | Code for Uncertainty Quantification with Pre-trained Language Models: An Empirical Analysis (EMNLP 2022 Findings).
4 |
5 | ## Requirements
6 |
7 | ```
8 | PyTorch = 1.10.1
9 | Bayesian-Torch = 0.1
10 | HuggingFace Transformers = 4.11.1
11 | ```
12 |
13 | ## Data
14 |
15 | Our empirical analysis consists of the following three NLP (natural language processing) classification tasks:
16 |
17 | **task_id** | Task | In-Domain Dataset | Out-Of-Domain Dataset
18 | --- | --- | --- | ---
19 | **Task1** | Sentiment Analysis | IMDb | Yelp
20 | **Task2** | Natural Language Inference | MNLI | SNLI
21 | **Task3** | Commonsense Reasoning | SWAG | HellaSWAG
22 |
23 | You can download our input data here and unzip it into the repository root, so that the data ends up in a top-level **Data** directory next to **Code**.
24 |
25 | Then the corresponding data splits of each task are stored in **Data/{task_id}/Original**:
26 | - **train.pkl**, **dev.pkl**, and **test_in.pkl** come from the in-domain dataset.
27 | - **test_out.pkl** comes from the out-of-domain dataset.
28 |
29 | ## Run
30 |
31 | Specify the target `model_name` and `task_id` in **Code/run.sh**:
32 | - `model_name` is specified in the format of `{PLM}_{size}-{loss}`.
33 | - `{PLM}` (Pre-trained Language Model) can be chosen from `bert`, `xlnet`, `electra`, `roberta`, and `deberta`.
34 | - `{size}` can be chosen from `base` and `large`.
35 | - `{loss}` can be chosen from `br` (Brier loss), `fl` (focal loss), `ce` (cross-entropy), `ls` (label smoothing), and `mm` (maximum mean calibration error).
36 | - `task_id` can be chosen from `Task1` (Sentiment Analysis), `Task2` (Natural Language Inference), and `Task3` (Commonsense Reasoning).
37 |
38 | Other hyperparameters are defined in **Code/info.py** (e.g., learning rate, batch size, and training epoch).
39 |
40 | Use the command `bash Code/run.sh` to run one sweep of experiments:
41 | 1. Transform the original data input in **Data/{task_id}/Original** to the model-specific data input in **Data/{task_id}/{model_name}**.
42 | 1. Train six deterministic (version=`det`) PLM-based pipelines (used for `Vanilla`, `Temp Scaling` (temperature scaling), `MC Dropout` (Monte Carlo dropout), and `Ensemble`) stored in **Result/{task_id}/{model_name}**.
43 | 1. Train six stochastic (version=`sto`) PLM-based pipelines (used for `LL SVI` (last-layer stochastic variational inference)) stored in **Result/{task_id}/{model_name}**.
44 | 1. Test the above pipelines with five kinds of uncertainty quantifiers (`Vanilla`, `Temp Scaling`, `MC Dropout`, `Ensemble`, and `LL SVI`) under two domain settings (`test_in` and `test_out`) based on four metrics (`ERR` (prediction error), `ECE` (expected calibration error), `RPP` (reversed pair proportion), and `FAR95` (false alarm rate at 95% recall)).
45 | 1. The evaluation of each (uncertainty quantifier, domain setting, metric) combination consists of six trials, and the results are stored in **Result/{task_id}/{model_name}/result_score.pkl**.
46 | 1. The ground truth labels and raw probability outputs are stored in **Result/{task_id}/{model_name}/result_prob.pkl**.
47 | 1. All the training and testing stdouts are stored in **Result/{task_id}/{model_name}/**.
48 |
49 | ## Result
50 |
51 | We store our empirical observations in **results.pkl**. You can download this dictionary here.
52 | - The key is in the format of `({task}, {model}, {quantifier}, {domain}, {metric})`.
53 | - `{task}` can be chosen from `Sentiment Analysis`, `Natural Language Inference`, and `Commonsense Reasoning`.
54 | - `{model}` can be chosen from `bert_base-br`, `bert_base-ce`, `bert_base-fl`, `bert_base-ls`, `bert_base-mm`, `bert_large-ce`, `deberta_base-ce`, `deberta_large-ce`, `electra_base-ce`, `electra_large-ce`, `roberta_base-ce`, `roberta_large-ce`, `xlnet_base-ce`, and `xlnet_large-ce`.
55 | - `{quantifier}` can be chosen from `Vanilla`, `Temp Scaling`, `MC Dropout`, `Ensemble`, and `LL SVI`.
56 | - `{domain}` can be chosen from `test_in` and `test_out`.
57 | - `{metric}` can be chosen from `ERR`, `ECE`, `RPP`, and `FAR95`. Note that `FAR95` only works with the domain setting of `test_out`.
58 | - The value is in the format of `(mean, standard error)`, which are calculated based on six trials with different seeds.
59 |
60 | ## Citation
61 |
62 | ```
63 | @inproceedings{xiao2022uncertainty,
64 | title={Uncertainty Quantification with Pre-trained Language Models: An Empirical Analysis},
65 | author={Xiao, Yuxin and Liang, Paul Pu and Bhatt, Umang and Neiswanger, Willie and Salakhutdinov, Ruslan and Morency, Louis-Philippe},
66 | booktitle={Findings of EMNLP},
67 | year={2022}
68 | }
69 | ```
70 |
--------------------------------------------------------------------------------
/Code/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import time
4 | import random
5 | import argparse
6 | import dill as pk
7 | import numpy as np
8 |
9 | import torch
10 | from transformers import AutoTokenizer, AutoConfig, AutoModel
11 | from transformers.optimization import AdamW, get_linear_schedule_with_warmup
12 |
13 | from model import Model
14 |
15 |
def parse_args_train(argv=None):
    """Parse the command-line arguments of ``train.py``.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        ``argparse.Namespace`` with ``task_id``, ``model_name``, ``stage``,
        ``version`` and ``seed``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--task_id', type=str, default='Task1')
    # Format '{PLM}_{size}-{loss}': Info splits on '-' and the loss suffix must
    # be one of Info's LOSS_* codes ('br', 'ce', 'fl', 'ls', 'mm').
    parser.add_argument('--model_name', type=str, default='bert_base-ce')
    parser.add_argument('--stage', type=str, default='train')
    parser.add_argument('--version', type=str, default='det')
    parser.add_argument('--seed', type=int, default=0)

    args = parser.parse_args(argv)

    return args
29 |
30 |
def parse_args_test(argv=None):
    """Parse the command-line arguments of ``test.py``.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        ``argparse.Namespace`` with ``task_id``, ``model_name``, ``stage`` and ``seed``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--task_id', type=str, default='Task1')
    # Format '{PLM}_{size}-{loss}': Info splits on '-' and the loss suffix must
    # be one of Info's LOSS_* codes ('br', 'ce', 'fl', 'ls', 'mm').
    parser.add_argument('--model_name', type=str, default='bert_base-ce')
    parser.add_argument('--stage', type=str, default='test')
    parser.add_argument('--seed', type=int, default=0)

    args = parser.parse_args(argv)

    return args
43 |
44 |
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and current CUDA device) RNGs with ``seed``."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seeder in seeders:
        seeder(seed)
51 |
52 |
def myprint(text, file):
    """Append ``text`` as one line to the log file at ``file``, prefixed with a local timestamp.

    Args:
        text: message to log.
        file: path of the log file; it is opened in append mode on every call.
    """
    stamp = time.strftime("%Y %b %d %a, %H:%M:%S: ", time.localtime())
    # The context manager closes the handle even if writing raises.
    with open(file, 'a') as handle:
        print(stamp + text, file=handle, flush=True)
58 |
59 |
def load(args, info, stage):
    """Load (and cache) the tokenized inputs and labels for ``stage``.

    The training stage yields the train and dev splits. The test stage yields
    the dev, in-domain test and out-of-domain test splits. Each split is read
    from its model-specific cache file when one exists. Otherwise the original
    pickle is tokenized with the PLM's tokenizer and cached for later runs.

    Args:
        args: parsed arguments; only ``args.task_id`` is read here.
        info: ``Info`` instance providing file paths, class layout and device.
        stage: ``info.STAGE_TRAIN`` or ``info.STAGE_TEST``.

    Returns:
        List of ``(inputs, labels)`` pairs, one per split, in the order above,
        moved to ``info.DEVICE_GPU``.

    Raises:
        ValueError: if ``stage`` is neither the train nor the test stage.
    """
    if stage == info.STAGE_TRAIN:
        name_list = [info.TYPE_TRAIN, info.TYPE_DEV]
        ori_list = [info.FILE_ORI_TRAIN, info.FILE_ORI_DEV]
        input_list = [info.FILE_INPUT_TRAIN, info.FILE_INPUT_DEV]
    elif stage == info.STAGE_TEST:
        name_list = [info.TYPE_DEV, info.TYPE_TEST_IN, info.TYPE_TEST_OUT]
        ori_list = [info.FILE_ORI_DEV, info.FILE_ORI_TEST_IN, info.FILE_ORI_TEST_OUT]
        input_list = [info.FILE_INPUT_DEV, info.FILE_INPUT_TEST_IN, info.FILE_INPUT_TEST_OUT]
    else:
        raise ValueError(f'Unknown stage {stage!r}; expected {info.STAGE_TRAIN!r} or {info.STAGE_TEST!r}')

    data_list = []
    tokenizer = AutoTokenizer.from_pretrained(info.PLM_NAME)
    for name, ori_file, input_file in zip(name_list, ori_list, input_list):
        myprint(f'Load Data from {name}', info.FILE_STDOUT)

        if os.path.isfile(input_file):
            # Reuse the cached tokenization. NOTE: dill, like pickle, can run
            # arbitrary code on load, so only point this at trusted files.
            with open(input_file, 'rb') as f_input:
                all_inputs, all_labels = pk.load(f_input)

        else:
            with open(ori_file, 'rb') as f_ori:
                ori_data = pk.load(f_ori)
            text1s, text2s, labels = [], [], []
            for row in ori_data:

                if args.task_id in ['Task1']:
                    # Row layout: (text, label).
                    text1s.append(row[0])
                    labels.append(row[1])

                elif args.task_id in ['Task2']:
                    # Row layout: (text1, text2, label).
                    text1s.append(row[0])
                    text2s.append(row[1])
                    labels.append(row[2])

                elif args.task_id in ['Task3']:
                    # row[0] is paired with each of row[1:1+NUM_CLASS[1]], giving
                    # NUM_CLASS[1] consecutive tokenized rows per example;
                    # row[-1] is the label.
                    text1s += [row[0]] * info.NUM_CLASS[1]
                    text2s += row[1:1+info.NUM_CLASS[1]]
                    labels.append(row[-1])

            # Single-text tasks: tokenize without a second sequence.
            if not text2s:
                text2s = None
            all_inputs = tokenizer(text1s, text2s, return_tensors='pt', padding=True, truncation=True, max_length=512)
            all_labels = torch.Tensor(labels).long()
            with open(input_file, 'wb') as f_cache:
                pk.dump((all_inputs, all_labels), f_cache, -1)

        data_list.append((all_inputs.to(info.DEVICE_GPU), all_labels.to(info.DEVICE_GPU)))
    myprint('-'*20, info.FILE_STDOUT)

    return data_list
107 |
108 |
def prepare(info, inputs_train):
    """Build the model, its AdamW optimizer and a linear warmup/decay LR schedule.

    Args:
        info: ``Info`` instance (PLM name, class layout, learning rate, ...).
        inputs_train: ``(inputs, labels)`` training pair; only the number of
            labels is used, to size the schedule.

    Returns:
        ``(model, optimizer, scheduler)``.
    """
    plm_config = AutoConfig.from_pretrained(info.PLM_NAME, num_labels=info.NUM_CLASS[0])
    plm = AutoModel.from_pretrained(info.PLM_NAME)
    model = Model(info, plm_config, plm).to(info.DEVICE_GPU)

    optimizer = AdamW(list(model.parameters()), lr=info.HP_LR)

    # One optimizer step per batch for every epoch; the first
    # HP_WARMUP_RATIO fraction of those steps is linear warmup.
    steps_per_epoch = math.ceil(inputs_train[1].shape[0] / info.HP_BATCH_SIZE)
    total_steps = steps_per_epoch * info.HP_NUM_EPOCH
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(total_steps * info.HP_WARMUP_RATIO),
        num_training_steps=total_steps,
    )

    return model, optimizer, scheduler
123 |
124 |
def iter_batch(info, inputs, if_shuffle=False):
    """Yield ``(batch_inputs, batch_labels)`` mini-batches of ``info.HP_BATCH_SIZE`` examples.

    When ``info.NUM_CLASS[1] != 1``, each example owns ``NUM_CLASS[1]``
    consecutive rows of every input tensor. The example indices are then
    expanded to those row indices before the inputs are sliced; labels stay
    one per example.

    Args:
        info: provides ``HP_BATCH_SIZE`` and ``NUM_CLASS``.
        inputs: ``(all_inputs, all_labels)`` where ``all_inputs`` is a mapping
            of indexable tensors.
        if_shuffle: shuffle the example order with NumPy's global RNG.
    """
    all_inputs, all_labels = inputs
    rows_per_example = info.NUM_CLASS[1]
    batch_size = info.HP_BATCH_SIZE

    order = np.arange(all_labels.shape[0])
    if if_shuffle:
        np.random.shuffle(order)

    for start in range(0, order.shape[0], batch_size):
        example_ids = order[start:start + batch_size]
        batch_labels = all_labels[example_ids]

        row_ids = example_ids
        if rows_per_example != 1:
            # Example e covers rows e*k, e*k+1, ..., e*k+k-1 (k = rows_per_example).
            row_ids = (np.repeat(example_ids * rows_per_example, rows_per_example)
                       + np.tile(np.arange(rows_per_example), example_ids.shape[0]))

        yield {key: value[row_ids] for key, value in all_inputs.items()}, batch_labels
--------------------------------------------------------------------------------
/Code/info.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from collections import defaultdict
4 |
5 |
class Info:
    """Experiment configuration: global constants plus per-run paths and settings.

    ``metadata`` sets values that are identical for every run. ``individual``
    derives file paths, the PLM name, the learning rate and the class layout
    from the parsed command-line ``args``.
    """

    def __init__(self, args):

        self.metadata()
        self.individual(args)


    def metadata(self):
        """Set the constants that do not depend on the command-line arguments."""

        self.DEVICE_CPU = 'cpu'
        self.DEVICE_GPU = 'cuda'

        self.STAGE_TRAIN = 'train'
        self.STAGE_TEST = 'test'

        self.VERSION_DET = 'det' # plain linear classification head
        self.VERSION_STO = 'sto' # LinearReparameterization (Bayesian) head

        self.LOSS_BR = 'br' # Brier Loss
        self.LOSS_CE = 'ce' # Cross Entropy
        self.LOSS_FL = 'fl' # Focal Loss
        self.LOSS_LS = 'ls' # Label Smoothing
        self.LOSS_MM = 'mm' # Max Mean Calibration Error

        self.METHOD_VANILLA = 'Vanilla'
        self.METHOD_TEMP_SCALING = 'Temp Scaling'
        self.METHOD_MC_DROPOUT = 'MC Dropout'
        self.METHOD_ENSEMBLE = 'Ensemble'
        self.METHOD_LL_SVI = 'LL SVI'
        self.METHODS = [self.METHOD_VANILLA, self.METHOD_TEMP_SCALING, self.METHOD_MC_DROPOUT, self.METHOD_ENSEMBLE, self.METHOD_LL_SVI]

        self.METRIC_ERR = 'ERR'
        self.METRIC_ECE = 'ECE'
        self.METRICS_EXPLICIT = [self.METRIC_ERR, self.METRIC_ECE]
        self.METRIC_RPP = 'RPP'
        self.METRIC_FAR = 'FAR95'
        self.METRICS_IMPLICIT = [self.METRIC_RPP, self.METRIC_FAR]
        # Metric groups; test.py prints each group on its own line.
        self.METRICS = [self.METRICS_EXPLICIT, self.METRICS_IMPLICIT]

        self.TYPE_TRAIN = 'train'
        self.TYPE_DEV = 'dev'
        self.TYPE_TEST_IN = 'test_in'
        self.TYPE_TEST_OUT = 'test_out'
        # test_in must come first: its confidences are reused for FAR95 on test_out.
        self.TYPE_TESTS = [self.TYPE_TEST_IN, self.TYPE_TEST_OUT]

        # task_id -> (NUM_CLASS[0], NUM_CLASS[1]). NUM_CLASS[0] is the head's
        # num_labels; NUM_CLASS[1] is the number of tokenized rows per example
        # (Task3 scores each of its 4 rows with a single logit).
        self.TASK2NCLASS = {'Task1':(2,1), 'Task2':(3,1), 'Task3':(1,4)}

        # Short model key -> HuggingFace checkpoint name.
        self.MODEL2NAME = {'bert_base':'bert-base-cased', 'bert_large':'bert-large-cased',
                           'xlnet_base':'xlnet-base-cased', 'xlnet_large':'xlnet-large-cased',
                           'electra_base':'google/electra-base-discriminator', 'electra_large':'google/electra-large-discriminator',
                           'roberta_base':'roberta-base', 'roberta_large':'roberta-large',
                           'deberta_base':'microsoft/deberta-base', 'deberta_large':'microsoft/deberta-large'}

        # Loss hyperparameters consumed by model.Loss. HP_LOSS_FL is presumably
        # (probability threshold, gamma below it, gamma above it) -- TODO confirm.
        self.HP_LOSS_FL = (0.2, 5, 3)
        self.HP_LOSS_LS = 0.1
        self.HP_LOSS_MM = 1

        self.HP_NUM_DROPOUT_MC = 10 # forward passes for MC Dropout in test.py
        self.HP_NUM_ENSEMBLE = 5 # checkpoints combined per ensemble in test.py
        self.HP_NUM_SVI_MC = 50 # samples drawn from the 'sto' head per forward pass

        self.HP_WARMUP_RATIO = 0.1
        self.HP_MAX_GRAD_NORM = 1.0
        self.HP_BATCH_SIZE = 16
        self.HP_NUM_EPOCH = 5
        self.HP_MODEL2LR = {'bert_base':2e-5, 'xlnet_base':2e-5, 'electra_base':2e-5, 'roberta_base':2e-5, 'deberta_base':2e-5,
                            'bert_large':5e-6, 'xlnet_large':5e-6, 'electra_large':5e-6, 'roberta_large':5e-6, 'deberta_large':5e-6}

        # Data/Result are resolved against the working directory, so the
        # scripts are expected to be launched from the Code directory.
        self.DIR_CURR = os.getcwd()
        self.DIR_DATA = os.path.join(self.DIR_CURR, '../Data')
        self.DIR_RESULT = os.path.join(self.DIR_CURR, '../Result')


    def individual(self, args):
        """Set the paths and settings derived from ``args``.

        Creates the model-specific input directory and the output directory
        if they do not exist yet.

        Raises:
            ValueError: if ``args.stage`` is neither ``'train'`` nor ``'test'``.
        """

        self.DIR_ORI = os.path.join(self.DIR_DATA, args.task_id, 'Original')
        self.FILE_ORI_TRAIN = os.path.join(self.DIR_ORI, f'{self.TYPE_TRAIN}.pkl')
        self.FILE_ORI_DEV = os.path.join(self.DIR_ORI, f'{self.TYPE_DEV}.pkl')
        self.FILE_ORI_TEST_IN = os.path.join(self.DIR_ORI, f'{self.TYPE_TEST_IN}.pkl')
        self.FILE_ORI_TEST_OUT = os.path.join(self.DIR_ORI, f'{self.TYPE_TEST_OUT}.pkl')

        self.DIR_INPUT = os.path.join(self.DIR_DATA, args.task_id, args.model_name)
        Path(self.DIR_INPUT).mkdir(parents=True, exist_ok=True)
        self.FILE_INPUT_TRAIN = os.path.join(self.DIR_INPUT, f'{self.TYPE_TRAIN}.pkl')
        self.FILE_INPUT_DEV = os.path.join(self.DIR_INPUT, f'{self.TYPE_DEV}.pkl')
        self.FILE_INPUT_TEST_IN = os.path.join(self.DIR_INPUT, f'{self.TYPE_TEST_IN}.pkl')
        self.FILE_INPUT_TEST_OUT = os.path.join(self.DIR_INPUT, f'{self.TYPE_TEST_OUT}.pkl')

        self.DIR_OUTPUT = os.path.join(self.DIR_RESULT, args.task_id, args.model_name)
        Path(self.DIR_OUTPUT).mkdir(parents=True, exist_ok=True)

        if args.stage == self.STAGE_TRAIN:
            self.FILE_STDOUT = os.path.join(self.DIR_OUTPUT, f'stdout_{args.stage}_{args.version}_{args.seed}.txt')
            self.FILE_MODEL = os.path.join(self.DIR_OUTPUT, f'model_{args.version}_{args.seed}.pkl')
            self.VERSION_MODE = args.version
        elif args.stage == self.STAGE_TEST:
            self.FILE_STDOUT = os.path.join(self.DIR_OUTPUT, f'stdout_{args.stage}.txt')
            self.FILE_SCORE = os.path.join(self.DIR_OUTPUT, 'result_score.pkl')
            self.FILE_PROB = os.path.join(self.DIR_OUTPUT, 'result_prob.pkl')

            # Group checkpoints 'model_{version}_{seed}.pkl' by version. Sorting
            # makes the model order (and therefore model ids and ensemble
            # membership in test.py) independent of os.listdir's arbitrary order.
            self.FILE_MODELS = defaultdict(list)
            for file in sorted(os.listdir(self.DIR_OUTPUT)):
                if file.startswith('model'):
                    self.FILE_MODELS[file.split('_')[1]].append(os.path.join(self.DIR_OUTPUT, file))
        else:
            raise ValueError(f'Unknown stage {args.stage!r}; expected {self.STAGE_TRAIN!r} or {self.STAGE_TEST!r}')

        # model_name is '{PLM}_{size}-{loss}'.
        model_name, self.LOSS_MODE = args.model_name.split('-')
        self.PLM_NAME = self.MODEL2NAME[model_name]
        self.HP_LR = self.HP_MODEL2LR[model_name]
        self.NUM_CLASS = self.TASK2NCLASS[args.task_id]
--------------------------------------------------------------------------------
/Code/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from bayesian_torch.layers import LinearReparameterization
5 |
6 |
class Classifier(nn.Module):
    """Two-layer classification head applied to the PLM representations.

    'det' version: dropout -> linear1 -> tanh -> dropout -> linear2.
    'sto' version: the same stack built from ``LinearReparameterization``
    layers, sampled ``info.HP_NUM_SVI_MC`` times per forward pass.
    """

    def __init__(self, info, config):
        super(Classifier, self).__init__()

        self.info = info

        self.activation = nn.Tanh()
        # Configs expose the dropout rate as either 'dropout' or 'hidden_dropout_prob'.
        self.dropout = nn.Dropout(config.dropout if hasattr(config, 'dropout') else config.hidden_dropout_prob)

        if info.VERSION_MODE == info.VERSION_DET:
            self.linear1 = nn.Linear(config.hidden_size, config.hidden_size)
            self.linear2 = nn.Linear(config.hidden_size, config.num_labels)
        elif info.VERSION_MODE == info.VERSION_STO:
            self.linear1 = LinearReparameterization(config.hidden_size, config.hidden_size)
            self.linear2 = LinearReparameterization(config.hidden_size, config.num_labels)

    def forward(self, batch_reps):
        """Map representations to logits.

        Args:
            batch_reps: tensor whose last dimension is ``config.hidden_size``
                (assumed ``(rows, hidden_size)`` -- TODO confirm against Model).

        Returns:
            ``(batch_kls, batch_logits)``. ``batch_logits`` has a leading sample
            dimension: 1 for 'det', ``HP_NUM_SVI_MC`` for 'sto'. When
            ``NUM_CLASS[1] != 1`` the per-row logits are regrouped into
            ``NUM_CLASS[1]`` columns per example. ``batch_kls`` is 0 for 'det';
            for 'sto' it is the summed KL terms divided by
            ``HP_NUM_SVI_MC * batch_reps.shape[0]``.
        """
        if self.info.VERSION_MODE == self.info.VERSION_DET:
            batch_kls = 0

            batch_logits = self.dropout(batch_reps)
            batch_logits = self.linear1(batch_logits)
            batch_logits = self.activation(batch_logits)
            # Feed the tanh output (not the raw batch_reps) to the second layer,
            # matching the 'sto' branch; otherwise linear1 is bypassed entirely.
            batch_logits = self.dropout(batch_logits)
            batch_logits = self.linear2(batch_logits)

            if self.info.NUM_CLASS[1] != 1: batch_logits = batch_logits.view(-1, self.info.NUM_CLASS[1])
            batch_logits = batch_logits.unsqueeze(0)

        elif self.info.VERSION_MODE == self.info.VERSION_STO:
            batch_kls, batch_logits = 0, []
            for _ in range(self.info.HP_NUM_SVI_MC):

                round_logits = self.dropout(batch_reps)
                # LinearReparameterization returns (output, kl).
                round_logits, round_kls = self.linear1(round_logits)
                batch_kls += round_kls
                round_logits = self.activation(round_logits)
                round_logits = self.dropout(round_logits)
                round_logits, round_kls = self.linear2(round_logits)
                batch_kls += round_kls

                if self.info.NUM_CLASS[1] != 1: round_logits = round_logits.view(-1, self.info.NUM_CLASS[1])
                batch_logits.append(round_logits)
            batch_kls /= (self.info.HP_NUM_SVI_MC * batch_reps.shape[0])
            batch_logits = torch.stack(batch_logits)

        return batch_kls, batch_logits
56 |
57 |
58 | class Loss(nn.Module):
59 |
60 | def __init__(self, info):
61 | super(Loss, self).__init__()
62 |
63 | self.info = info
64 |
65 | def brier_loss(self, batch_logits, batch_labels):
66 |
67 | batch_labels = F.one_hot(batch_labels, self.info.NUM_CLASS[0] * self.info.NUM_CLASS[1])
68 | batch_loss = (batch_logits.flatten(0,1) - batch_labels.tile((batch_logits.shape[0],1))).square().sum(-1).mean(0)
69 | return batch_loss
70 |
71 | def cross_entropy(self, batch_logits, batch_labels):
72 |
73 | batch_loss = F.cross_entropy(batch_logits.flatten(0,1), batch_labels.tile((batch_logits.shape[0],)))
74 | return batch_loss
75 |
76 | def focal_loss(self, batch_logits, batch_labels):
77 |
78 | batch_labels = batch_labels.tile((batch_logits.shape[0],)).unsqueeze(-1)
79 | batch_probs = F.softmax(batch_logits, dim=-1).flatten(0, 1).gather(1, batch_labels).flatten()
80 | batch_gammas = torch.where(batch_probs 0:
63 | bin_accs = accuracy[bin_in].mean()
64 | bin_confs = confidence[bin_in].mean()
65 | ece += np.abs(bin_accs - bin_confs) * bin_ratio
66 |
67 | return ece
68 |
69 |
def selective_prediction(all_accs, all_confs):
    """Return the reversed pair proportion (RPP) of the confidences.

    A pair is reversed when a correctly predicted example has a lower
    confidence than an incorrectly predicted one. The number of reversed
    pairs is divided by ``n ** 2``.

    Args:
        all_accs: per-example correctness flags (bool, or 0/1 integers).
        all_confs: per-example confidences, same length as ``all_accs``.

    Returns:
        RPP as a float.
    """
    # Cast to bool: on an integer 0/1 array ``~`` is bitwise NOT, and the mask
    # below would silently turn into negative fancy indices.
    accs = np.asarray(all_accs, dtype=bool)
    conf_accs = accs[np.argsort(all_confs)]
    # In ascending-confidence order, the cumsum at each incorrect position
    # counts the correct examples ranked before it.
    conf_rpp = np.cumsum(conf_accs)[~conf_accs].sum() / (accs.shape[0]**2)

    return conf_rpp
76 |
77 |
def out_detection(in_confs, out_confs):
    """Return FAR95 for out-of-domain detection.

    FAR95 is the fraction of in-domain examples whose confidence falls
    strictly below the 95th percentile of the out-of-domain confidences.
    """
    threshold = np.percentile(out_confs, 95)
    false_alarms = in_confs < threshold
    return false_alarms.mean()
83 |
84 |
def evaluate(args, info, type, all_probs, all_labels, in_confs=None):
    """Score predicted probabilities with the uncertainty metrics.

    Args:
        args: not used by this function.
        info: ``Info`` instance providing metric names, split names and devices.
        type: split name. FAR95 is added only when it is not ``info.TYPE_TEST_IN``.
        all_probs: probability tensor whose last dimension indexes classes.
        all_labels: gold labels, compared element-wise with the arg-max predictions.
        in_confs: in-domain confidences (NumPy array), used for FAR95.

    Returns:
        ``(all_confs, all_scores)``: per-example maximum probabilities as a
        NumPy array, and a dict mapping metric name to score.
    """
    max_probs, preds = all_probs.max(-1)
    confs = max_probs.to(info.DEVICE_CPU).numpy()
    correct = (preds == all_labels).to(info.DEVICE_CPU).numpy()

    scores = {
        info.METRIC_ERR: 1 - correct.mean(),
        info.METRIC_ECE: get_ece(correct, confs),
        info.METRIC_RPP: selective_prediction(correct, confs),
    }
    if type != info.TYPE_TEST_IN:
        scores[info.METRIC_FAR] = out_detection(in_confs, confs)

    return confs, scores
98 |
99 |
def feed(info, type, method, all_labels, all_probs, all_scores, each_labels, each_probs, each_scores):
    """Record one (split, method) evaluation into the running accumulators.

    ``all_labels[type]`` is overwritten with the NumPy labels.
    ``all_probs[(type, method)]`` gets the NumPy probabilities appended.
    ``all_scores[(type, method, metric)]`` gets each metric score appended.
    The latter two must therefore default to lists (e.g. ``defaultdict(list)``).
    """
    cpu = info.DEVICE_CPU
    all_labels[type] = each_labels.to(cpu).numpy()
    all_probs[(type, method)].append(each_probs.to(cpu).numpy())
    for metric in each_scores:
        all_scores[(type, method, metric)].append(each_scores[metric])
106 |
107 |
def main():
    """Evaluate every trained checkpoint with all uncertainty quantifiers.

    The evaluation proceeds in four stages:

    - For each deterministic checkpoint: Vanilla, Temp Scaling (temperature fit
      on the dev split) and MC Dropout.
    - Ensembles over every ``info.HP_NUM_ENSEMBLE``-subset of the
      deterministic checkpoints, reusing their Vanilla probabilities.
    - LL SVI for each stochastic checkpoint.
    - Logging of per split/method/metric mean and standard error.

    Finally, the raw scores and probabilities are pickled to
    ``info.FILE_SCORE`` and ``info.FILE_PROB``.
    """
    args = parse_args_test()
    info = Info(args)

    myprint('='*20, info.FILE_STDOUT)
    myprint(f'Start {args.stage}ing {args.model_name} for {args.task_id}', info.FILE_STDOUT)
    myprint('-'*20, info.FILE_STDOUT)

    set_seed(args.seed)
    # For the test stage util.load returns [dev, test_in, test_out].
    inputs_list = load(args, info, args.stage)
    all_labels, all_probs, all_scores = {}, defaultdict(list), defaultdict(list)

    for model_id, model_file in enumerate(info.FILE_MODELS[info.VERSION_DET]):
        myprint(f'Load {info.VERSION_DET} Model {model_id}', info.FILE_STDOUT)
        # NOTE: dill can execute arbitrary code on load; only load trusted checkpoints.
        with open(model_file, 'rb') as f_model:
            model = pk.load(f_model).to(info.DEVICE_GPU)
        myprint('-'*20, info.FILE_STDOUT)

        myprint(f'Calculate Temperature for {info.VERSION_DET} Model {model_id}', info.FILE_STDOUT)
        dev_logits, dev_labels = process(info, info.METHOD_VANILLA, inputs_list[0], model, if_eval=True, num_mc=1)
        temperature = recalibrate(info, dev_logits, dev_labels)
        myprint(f'Temperature = {temperature:.4f}', info.FILE_STDOUT)
        myprint('-'*20, info.FILE_STDOUT)

        # in_confs carries the test_in confidences forward so FAR95 can be
        # computed on test_out (TYPE_TESTS lists test_in first).
        in_confs = None; vanilla_logits = {}
        for data_type, inputs in zip(info.TYPE_TESTS, inputs_list[1:]):
            myprint(f'Uncertainty for {data_type} Data via {info.VERSION_DET} Model {model_id} and {info.METHOD_VANILLA}', info.FILE_STDOUT)
            each_logits, each_labels = process(info, info.METHOD_VANILLA, inputs, model, if_eval=True, num_mc=1)
            each_probs = F.softmax(each_logits, dim=-1); vanilla_logits[data_type] = each_logits
            each_confs, each_scores = evaluate(args, info, data_type, each_probs, each_labels, in_confs=in_confs)
            feed(info, data_type, info.METHOD_VANILLA, all_labels, all_probs, all_scores, each_labels, each_probs, each_scores)
            if data_type == info.TYPE_TEST_IN: in_confs = each_confs
            myprint('-'*20, info.FILE_STDOUT)

        # Temperature scaling reuses the Vanilla logits; no new forward pass.
        in_confs = None
        for data_type, inputs in zip(info.TYPE_TESTS, inputs_list[1:]):
            myprint(f'Uncertainty for {data_type} Data via {info.VERSION_DET} Model {model_id} and {info.METHOD_TEMP_SCALING}', info.FILE_STDOUT)
            each_probs = F.softmax(vanilla_logits[data_type] / temperature, dim=-1)
            each_labels = torch.from_numpy(all_labels[data_type]).long().to(info.DEVICE_GPU)
            each_confs, each_scores = evaluate(args, info, data_type, each_probs, each_labels, in_confs=in_confs)
            feed(info, data_type, info.METHOD_TEMP_SCALING, all_labels, all_probs, all_scores, each_labels, each_probs, each_scores)
            if data_type == info.TYPE_TEST_IN: in_confs = each_confs
            myprint('-'*20, info.FILE_STDOUT)

        in_confs = None
        for data_type, inputs in zip(info.TYPE_TESTS, inputs_list[1:]):
            myprint(f'Uncertainty for {data_type} Data via {info.VERSION_DET} Model {model_id} and {info.METHOD_MC_DROPOUT}', info.FILE_STDOUT)
            # if_eval=False keeps dropout active for the MC passes.
            each_probs, each_labels = process(info, info.METHOD_MC_DROPOUT, inputs, model, if_eval=False, num_mc=info.HP_NUM_DROPOUT_MC)
            each_confs, each_scores = evaluate(args, info, data_type, each_probs, each_labels, in_confs=in_confs)
            feed(info, data_type, info.METHOD_MC_DROPOUT, all_labels, all_probs, all_scores, each_labels, each_probs, each_scores)
            if data_type == info.TYPE_TEST_IN: in_confs = each_confs
            myprint('-'*20, info.FILE_STDOUT)

    # Each ensemble averages the stored Vanilla probabilities of one subset of checkpoints.
    for ensemble_id, model_ids in enumerate(combinations(np.arange(len(info.FILE_MODELS[info.VERSION_DET])), info.HP_NUM_ENSEMBLE)):
        in_confs = None
        for data_type, inputs in zip(info.TYPE_TESTS, inputs_list[1:]):
            myprint(f'Uncertainty for {data_type} Data via Ensemble {ensemble_id}', info.FILE_STDOUT)
            each_probs = torch.from_numpy(np.mean([all_probs[(data_type, info.METHOD_VANILLA)][member_id] for member_id in model_ids], axis=0)).float().to(info.DEVICE_GPU)
            each_labels = torch.from_numpy(all_labels[data_type]).long().to(info.DEVICE_GPU)
            each_confs, each_scores = evaluate(args, info, data_type, each_probs, each_labels, in_confs=in_confs)
            feed(info, data_type, info.METHOD_ENSEMBLE, all_labels, all_probs, all_scores, each_labels, each_probs, each_scores)
            if data_type == info.TYPE_TEST_IN: in_confs = each_confs
            myprint('-'*20, info.FILE_STDOUT)

    for model_id, model_file in enumerate(info.FILE_MODELS[info.VERSION_STO]):
        myprint(f'Load {info.VERSION_STO} Model {model_id}', info.FILE_STDOUT)
        with open(model_file, 'rb') as f_model:
            model = pk.load(f_model).to(info.DEVICE_GPU)
        myprint('-'*20, info.FILE_STDOUT)

        in_confs = None
        for data_type, inputs in zip(info.TYPE_TESTS, inputs_list[1:]):
            myprint(f'Uncertainty for {data_type} Data via {info.VERSION_STO} Model {model_id} and {info.METHOD_LL_SVI}', info.FILE_STDOUT)
            each_probs, each_labels = process(info, info.METHOD_LL_SVI, inputs, model, if_eval=True, num_mc=1)
            each_confs, each_scores = evaluate(args, info, data_type, each_probs, each_labels, in_confs=in_confs)
            feed(info, data_type, info.METHOD_LL_SVI, all_labels, all_probs, all_scores, each_labels, each_probs, each_scores)
            if data_type == info.TYPE_TEST_IN: in_confs = each_confs
            myprint('-'*20, info.FILE_STDOUT)

    for data_type in info.TYPE_TESTS:
        for method in info.METHODS:
            myprint(f'Data Type: {data_type} & Uncertainty Method: {method}', info.FILE_STDOUT)
            for metrics in info.METRICS:
                result = []
                for metric in metrics:
                    # Skip metrics that were never computed (e.g. FAR95 on test_in).
                    if (data_type, method, metric) not in all_scores: continue
                    scores = all_scores[(data_type, method, metric)]
                    mean, sem = np.mean(scores), stats.sem(scores)
                    result.append(f'{metric}: {mean:.4f}±{sem:.4f}')
                if len(result) != 0: myprint(' | '.join(result), info.FILE_STDOUT)
            myprint('-'*20, info.FILE_STDOUT)

    # Close the result files deterministically so the pickles are fully flushed.
    with open(info.FILE_SCORE, 'wb') as f_score:
        pk.dump(all_scores, f_score, -1)
    with open(info.FILE_PROB, 'wb') as f_prob:
        pk.dump((all_labels, all_probs), f_prob, -1)

    myprint(f'Finish {args.stage}ing {args.model_name} for {args.task_id}', info.FILE_STDOUT)
    myprint('='*20, info.FILE_STDOUT)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------