├── LICENSE
├── README.md
├── fine_tune.py
├── msp_eval.py
├── perplexity.py
├── requirements.txt
├── roberta_fine_tune.py
├── run_language_modeling.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Udit Arora
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ood-text-emnlp
2 | Code for EMNLP'21 paper "Types of Out-of-Distribution Texts and How to Detect Them".
3 | 
4 | Paper:
5 | 
6 | - arXiv: https://arxiv.org/abs/2109.06827
7 | - Official: https://aclanthology.org/2021.emnlp-main.835/
8 | 
9 | ## Files
10 | - `fine_tune.py` is used to fine-tune the GPT-2 models, and `roberta_fine_tune.py` is used to fine-tune the RoBERTa models.
11 | - `perplexity.py` and `msp_eval.py` are used to find the PPLs and MSPs of a dataset pair's examples using the fine-tuned model.
12 | 
13 | ## How to run
14 | These steps show how to train both the density estimation and calibration models on the MNLI dataset and evaluate them against SNLI.
15 | 
16 | A different dataset pair can be used by updating the appropriate `dataset_name` or `id_data`/`ood_data` values as shown below:
17 | 
18 | ### Training the Density Estimation Model (GPT-2)
19 | Two options:
20 | 1. Using HF Datasets -
21 | ```
22 | python fine_tune.py --dataset_name glue --dataset_config_name mnli --key premise --key2 hypothesis
23 | ```
24 | This also generates a txt train file corresponding to the dataset's text.
25 | 2. Using a previously generated txt file -
26 | ```
27 | python fine_tune.py --train_file data/glue_mnli_train.txt --fname glue_mnli
28 | ```
29 | 
30 | ### Finding Perplexity (PPL)
31 | This uses the txt files generated after running `fine_tune.py` to find the perplexity of the ID model on both the ID and OOD validation sets -
32 | ```
33 | id_data="glue_mnli"
34 | ood_data="snli"
35 | python perplexity.py --model_path ckpts/gpt2-$id_data/ --dataset_path data/${ood_data}_val.txt --fname ${id_data}_$ood_data
36 | 
37 | python perplexity.py --model_path ckpts/gpt2-$id_data/ --dataset_path data/${id_data}_val.txt --fname ${id_data}_$id_data
38 | ```
39 | 
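Each `perplexity.py` run saves a NumPy array of per-example perplexities (plus a `_lls.pkl` file with token-level scores) under `output/gpt2/`, named after the `--fname` argument. A quick sanity check of the OOD output for the pair above might look like this (just a sketch; the path reflects the default save location used by `perplexity.py`):
```
import numpy as np

# one perplexity value per SNLI validation example
ood_pps = np.load('output/gpt2/glue_mnli_snli_pps.npy')
print(ood_pps.shape, ood_pps.mean(), ood_pps.std())
```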
40 | ### Training the Calibration Model (RoBERTa)
41 | Two options:
42 | 1. Using HF Datasets -
43 | ```
44 | id_data="mnli"
45 | python roberta_fine_tune.py --task_name $id_data --output_dir roberta_ckpts/roberta-$id_data --fname ${id_data}_$id_data
46 | ```
47 | 
48 | 2. Using a txt file generated earlier -
49 | ```
50 | id_data="mnli"
51 | python roberta_fine_tune.py --train_file data/mnli/${id_data}_conditional_train.txt --val_file data/mnli/${id_data}_val.txt --output_dir roberta_ckpts/roberta-$id_data --fname ${id_data}_$id_data
52 | ```
53 | The `*_conditional_train.txt` file contains both the labels and the text.
54 | 
55 | ### Finding Maximum Softmax Probability (MSP)
56 | Two options:
57 | 1. Using HF Datasets -
58 | ```
59 | id_data="mnli"
60 | ood_data="snli"
61 | python msp_eval.py --model_path roberta_ckpts/roberta-$id_data --dataset_name $ood_data --fname ${id_data}_$ood_data
62 | ```
63 | 2. Using a txt file generated earlier -
64 | ```
65 | id_data="mnli"
66 | ood_data="snli"
67 | python msp_eval.py --model_path roberta_ckpts/roberta-$id_data --val_file data/${ood_data}_val.txt --fname ${id_data}_$ood_data --save_msp True
68 | ```
69 | 
70 | ### Evaluating AUROC
71 | 1. Compute AUROC of PPL using `compute_auroc` in `utils.py` -
72 | ```
73 | id_data = 'glue_mnli'
74 | ood_data = 'snli'
75 | id_pps = utils.read_model_out(f'output/gpt2/{id_data}_{id_data}_pps.npy')
76 | ood_pps = utils.read_model_out(f'output/gpt2/{id_data}_{ood_data}_pps.npy')
77 | score = utils.compute_auroc(id_pps, ood_pps)
78 | print(score)
79 | ```
80 | 
81 | 2. Compute AUROC of MSP -
82 | ```
83 | id_data = 'mnli'
84 | ood_data = 'snli'
85 | id_msp = utils.read_model_out(f'output/roberta/{id_data}_{id_data}_msp.npy')
86 | ood_msp = utils.read_model_out(f'output/roberta/{id_data}_{ood_data}_msp.npy')
87 | score = utils.compute_auroc(-id_msp, -ood_msp)
88 | print(score)
89 | ```
90 | 
91 | ## Citation and authors
92 | 
93 | ### Bibtex
94 | 
95 | ```
96 | @inproceedings{arora-etal-2021-types,
97 | title = "Types of Out-of-Distribution Texts and How to Detect Them",
98 | author = "Arora, Udit and
99 | Huang, William and
100 | He, He",
101 | booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing",
102 | month = nov,
103 | year = "2021",
104 | address = "Online and Punta Cana, Dominican Republic",
105 | publisher = "Association for Computational Linguistics",
106 | url = "https://aclanthology.org/2021.emnlp-main.835",
107 | pages = "10687--10701",
108 | abstract = "Despite agreement on the importance of detecting out-of-distribution (OOD) examples, there is little consensus on the formal definition of the distribution shifts of OOD examples and how to best detect them. We categorize these examples as exhibiting a background shift or semantic shift, and find that the two major approaches to OOD detection, calibration and density estimation (language modeling for text), have distinct behavior on these types of OOD data. Across 14 pairs of in-distribution and OOD English natural language understanding datasets, we find that density estimation methods consistently beat calibration methods in background shift settings and perform worse in semantic shift settings. In addition, we find that both methods generally fail to detect examples from challenge data, indicating that these examples constitute a different type of OOD data.
Overall, while the categorization we apply explains many of the differences between the two methods, our results call for a more explicit definition of OOD to create better benchmarks and build detectors that can target the type of OOD data expected at test time.", 109 | } 110 | ``` 111 | 112 | ### Authors 113 | 114 | [Udit Arora](https://uditarora.com) 115 | 116 | [William Huang](https://wh629.github.io/) 117 | 118 | [He He](https://hhexiy.github.io) 119 | -------------------------------------------------------------------------------- /fine_tune.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | import os 3 | import fire 4 | 5 | if not os.path.exists('data'): 6 | os.makedirs('data') 7 | 8 | NUM_EPOCHS = 1 9 | 10 | CKPT_DIR = 'ckpts' 11 | if not os.path.exists(CKPT_DIR): 12 | os.makedirs(CKPT_DIR) 13 | 14 | def main(train_file=None, val_file=None, dataset_name=None, dataset_config_name=None, key='text', val_name='validation', key2=None, conditional=False, do_val=False, cache_dir='cache/huggingface/transformers', fname=None, version='gpt2'): 15 | if train_file is None: 16 | if dataset_config_name is not None: 17 | fname = dataset_name + '_' + dataset_config_name 18 | else: 19 | fname = dataset_name 20 | 21 | if do_val: 22 | train, val = load_dataset(dataset_name, dataset_config_name, split=['train', val_name], cache_dir=cache_dir) 23 | else: 24 | train = load_dataset(dataset_name, dataset_config_name, split='train', cache_dir=cache_dir) 25 | 26 | print(f"Processing dataset {fname}...") 27 | 28 | train_str = "" 29 | for ex in train: 30 | if conditional: 31 | line = f"{ex['label']} " 32 | else: 33 | line = "" 34 | line += f"{ex[key]}" 35 | if key2 is not None: 36 | line += f" {ex[key2]}" 37 | train_str += f"{line} <|endoftext|>\n" 38 | 39 | if do_val: 40 | val_str = "" 41 | for ex in val: 42 | if conditional: 43 | line = f"{ex['label']} " 44 | else: 45 | line = "" 46 | line += f"{ex[key]}" 47 | if key2 is not None: 48 | line += f" {ex[key2]}" 49 | val_str += f"{line} <|endoftext|>\n" 50 | 51 | if conditional: 52 | fname_train = f'data/{fname}_conditional_train.txt' 53 | fname_val = f'data/{fname}_conditional_val.txt' 54 | else: 55 | fname_train = f'data/{fname}_train.txt' 56 | fname_val = f'data/{fname}_val.txt' 57 | 58 | with open (fname_train, 'w') as f: 59 | f.write(train_str) 60 | 61 | if do_val: 62 | with open (fname_val, 'w') as f: 63 | f.write(val_str) 64 | else: 65 | fname_train = train_file 66 | fname_val = val_file 67 | 68 | print(f"Running fine-tuning from {fname_train}...") 69 | 70 | if conditional == False: 71 | output_dir = f'--output_dir {CKPT_DIR}/{version}-{fname} ' 72 | else: 73 | output_dir = f'--output_dir {CKPT_DIR}/{version}-{fname}-conditional ' 74 | 75 | if do_val: 76 | cmd = 'python run_language_modeling.py ' + \ 77 | f'--train_data_file {fname_train} ' + \ 78 | f'--eval_data_file {fname_val} ' + \ 79 | output_dir + \ 80 | f'--model_type {version} ' + \ 81 | f'--model_name_or_path {version} ' + \ 82 | '--save_total_limit 1 ' + \ 83 | f'--num_train_epochs {NUM_EPOCHS} ' + \ 84 | '--do_train \ 85 | --evaluate_during_training \ 86 | --logging_steps 500 \ 87 | --save_steps 500 \ 88 | --do_eval \ 89 | --per_gpu_train_batch_size 8 \ 90 | --per_gpu_eval_batch_size 8 \ 91 | --line_by_line \ 92 | --gradient_accumulation_steps 1' 93 | else: 94 | cmd = 'python run_language_modeling.py ' + \ 95 | f'--train_data_file {fname_train} ' + \ 96 | output_dir + \ 97 | f'--model_type {version} ' + \ 98 | 
f'--model_name_or_path {version} ' + \ 99 | '--save_total_limit 1 ' + \ 100 | f'--num_train_epochs {NUM_EPOCHS} ' + \ 101 | '--do_train \ 102 | --per_gpu_train_batch_size 8 \ 103 | --per_gpu_eval_batch_size 8 \ 104 | --line_by_line \ 105 | --gradient_accumulation_steps 1' 106 | 107 | if cache_dir is not None: 108 | cmd += f' --cache_dir {cache_dir}' 109 | 110 | cmd += ' --overwrite_output_dir' 111 | 112 | os.system(cmd) 113 | 114 | if __name__ == '__main__': 115 | fire.Fire(main) 116 | print("\n\n--------DONE--------") 117 | -------------------------------------------------------------------------------- /msp_eval.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch 3 | import numpy as np 4 | import pandas as pd 5 | import os 6 | from tqdm import tqdm 7 | from transformers import AutoTokenizer, AutoModelForSequenceClassification 8 | from datasets import load_dataset 9 | import fire 10 | import time 11 | from roberta_fine_tune import eval, process_hf_dataset, process_lm_dataset, process_custom_dataset 12 | 13 | SAVE_PATH = 'output/msp/' 14 | if not os.path.exists(SAVE_PATH): 15 | os.makedirs(SAVE_PATH) 16 | 17 | def process_entailment(dataset, tokenizer, key1='sentence1', key2='sentence2'): 18 | dataset_texts = [] 19 | for ex in dataset: 20 | dataset_texts.append(ex[key1] + ' ' + ex[key2]) 21 | return [encode(tokenizer, text) for text in dataset_texts] 22 | 23 | def encode(tokenizer, text): 24 | return tokenizer.encode_plus( 25 | text, 26 | add_special_tokens=True, # Add '[CLS]' and '[SEP]' 27 | return_token_type_ids=False, 28 | max_length=150, 29 | pad_to_max_length=True, 30 | return_attention_mask=True, 31 | return_tensors='pt', # Return PyTorch tensors 32 | ) 33 | 34 | def process_msp(all_encodings, model): 35 | scores = [] 36 | for encoding in tqdm(all_encodings): 37 | input_ids, attention_mask = encoding['input_ids'], encoding['attention_mask'] 38 | out = model(input_ids, attention_mask)[0] 39 | score = F.softmax(out[0], dim=0) 40 | scores.append(score.detach().cpu().numpy()) 41 | max_probs = np.max(np.array(scores), axis=1) 42 | return max_probs 43 | 44 | def main(model_path, val_file=None, dataset_name=None, dataset_config_name=None, split='eval', batch_size=16, max_length=None, n=None, fname='sample', cache_dir='/scratch/ua388/cache/huggingface/datasets', save_msp=True, alpha=None): 45 | if alpha is not None: 46 | global SAVE_PATH 47 | SAVE_PATH = os.path.join(SAVE_PATH, f'alpha_{alpha}') 48 | if not os.path.exists(SAVE_PATH): 49 | os.makedirs(SAVE_PATH, exist_ok=True) 50 | 51 | print("Loading model...") 52 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 53 | tokenizer = AutoTokenizer.from_pretrained(model_path) 54 | model = AutoModelForSequenceClassification.from_pretrained(model_path).to(device) 55 | padding = 'max_length' 56 | glue = ['sst2', 'mnli'] 57 | 58 | if cache_dir == 'None': 59 | cache_dir = None 60 | 61 | if val_file is None: 62 | if 'glue' in dataset_name: 63 | dataset_name, dataset_config_name = dataset_name.split('_') 64 | elif dataset_name in glue: 65 | dataset_name, dataset_config_name = 'glue', dataset_name 66 | 67 | dataset = load_dataset(dataset_name, dataset_config_name, cache_dir=cache_dir) 68 | if dataset_config_name is not None: 69 | task_name = dataset_config_name 70 | else: 71 | task_name = dataset_name 72 | dataloader = process_hf_dataset(dataset, split, task_name, tokenizer, padding, max_length=max_length, batch_size=batch_size, n=n, 
shuffle=False) 73 | with_labels = True 74 | else: 75 | # Check for file type and process either .tsv or .txt 76 | if '.tsv' in val_file: 77 | df = pd.read_table(val_file) 78 | label_key = 'label' 79 | if 'mnli' in val_file: 80 | task_name = 'mnli' 81 | label_key = 'label' 82 | elif 'imdb' in val_file: 83 | task_name = 'counterfactual-imdb' 84 | label_key = 'Sentiment' 85 | else: #TODO: Support other tasks 86 | task_name = 'none' 87 | return 88 | 89 | if label_key in df: 90 | with_labels = True 91 | df = df[df[label_key] != -1] 92 | else: 93 | with_labels = False 94 | # num_labels = len(np.unique(pd.Categorical(df['label'], ordered=True))) 95 | dataloader = process_custom_dataset(df, task_name, tokenizer, padding, max_length, batch_size, n=n, shuffle=False) 96 | else: 97 | dataloader = process_lm_dataset(val_file, tokenizer, padding, max_length, batch_size, n=n, num_label_chars=0, shuffle=False) 98 | with_labels = False 99 | 100 | print('Evaluating model') 101 | start_time = time.time() 102 | probs = eval(model, dataloader, device, with_labels=with_labels) 103 | end_time = time.time() 104 | print("MSP runtime:", end_time - start_time) 105 | np.save(os.path.join(SAVE_PATH, f'{fname}_probs'), probs) 106 | if save_msp: 107 | msp = np.max(probs, axis=1) 108 | np.save(os.path.join(SAVE_PATH, f'{fname}_msp'), msp) 109 | 110 | if __name__ == '__main__': 111 | fire.Fire(main) 112 | print("\n\n--------DONE--------") 113 | -------------------------------------------------------------------------------- /perplexity.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from transformers import AutoTokenizer, AutoModelWithLMHead 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | import os 7 | import pickle 8 | import fire 9 | import time 10 | import pandas as pd 11 | 12 | SAVE_PATH = 'output/' 13 | if not os.path.exists(SAVE_PATH): 14 | os.makedirs(SAVE_PATH) 15 | 16 | def compute_all(model, all_encodings, fname, save_path=SAVE_PATH, device='cpu'): 17 | start_time = time.time() 18 | print(f"Finding perplexities for {fname}") 19 | perplexities, lls = [], [] 20 | # all_encodings = all_encodings[:n] 21 | pbar = tqdm(total=len(all_encodings)) 22 | for idx, encodings in enumerate(all_encodings): 23 | try: 24 | pp, ll = compute_perplexity(model, encodings, device=device) 25 | perplexities.append(pp) 26 | lls.append(ll) 27 | except Exception as e: 28 | print("Exception at idx", idx) 29 | print(e) 30 | continue 31 | finally: 32 | pbar.update(1) 33 | 34 | pbar.close() 35 | end_time = time.time() 36 | print("PPL runtime:", end_time-start_time) 37 | 38 | perplexities = np.array(perplexities) 39 | np.save(f'{save_path}/{fname}_pps.npy', perplexities) 40 | print(f"\nMean: {perplexities.mean()}, Std: {perplexities.std()}") 41 | 42 | with open(f'{save_path}/{fname}_lls.pkl', 'wb') as fw: 43 | pickle.dump(lls, fw) 44 | 45 | return perplexities 46 | 47 | def compute_perplexity(model, encodings, stride=None, device='cuda'): 48 | max_length = model.config.n_positions 49 | lls = [] 50 | if stride is None: 51 | # stride = max(1, encodings.input_ids.size(1) // 100) 52 | # stride = max_length 53 | # print('Stride:', stride) 54 | stride = 1 55 | 56 | for i in range(1, encodings.input_ids.size(1), stride): 57 | begin_loc = max(i + stride - max_length, 0) 58 | end_loc = i + stride 59 | input_ids = encodings.input_ids[:,begin_loc:end_loc].to(device) 60 | target_ids = input_ids.clone() 61 | target_ids[:,:-stride] = -100 62 | 63 | with 
torch.no_grad(): 64 | outputs = model(input_ids, labels=target_ids) 65 | log_likelihood = outputs[0] * stride 66 | 67 | lls.append(log_likelihood) 68 | 69 | ppl = torch.exp(torch.stack(lls).sum() / i) 70 | return ppl.item(), torch.stack(lls).detach().cpu().numpy() 71 | 72 | def setup(path='lvwerra/gpt2-imdb'): 73 | if torch.cuda.is_available(): 74 | device = 'cuda' 75 | else: 76 | device = 'cpu' 77 | n = None 78 | model = AutoModelWithLMHead.from_pretrained(path).to(device) 79 | tokenizer = AutoTokenizer.from_pretrained(path) 80 | return model, tokenizer, device, n 81 | 82 | def get_encoding(tokenizer, text, label=None, add_eot=True): 83 | if label is not None: 84 | prefix = f'{label} ' 85 | else: 86 | prefix = '' 87 | if add_eot: 88 | return tokenizer(prefix + text + ' <|endoftext|>', return_tensors='pt') 89 | else: 90 | return tokenizer(prefix + text, return_tensors='pt') 91 | 92 | def process_hf_dataset(dataset_name, model, tokenizer, device, n=None, key='text', configs=None, fname=None, conditional=False): 93 | print(f"\n------Processing perplexity for dataset: {dataset_name}-------") 94 | 95 | if configs is None: 96 | dataset = [load_dataset(dataset_name, split='test')] 97 | else: 98 | dataset = [load_dataset(dataset_name, config, split='test') for config in configs] 99 | 100 | if fname is None: 101 | fname = dataset_name 102 | 103 | print("Tokenizing data") 104 | if not conditional: 105 | all_encodings = [get_encoding(tokenizer, text) for _dataset in dataset for text in _dataset[key][:n]] 106 | return compute_all(model, all_encodings, fname) 107 | else: 108 | all_encodings_0 = [get_encoding(tokenizer, text, 0) for _dataset in dataset for text in _dataset[key][:n]] 109 | all_encodings_1 = [get_encoding(tokenizer, text, 1) for _dataset in dataset for text in _dataset[key][:n]] 110 | fname_0 = fname + '_conditional_0' 111 | fname_1 = fname + '_conditional_1' 112 | 113 | compute_all(model, all_encodings_0, fname_0) 114 | compute_all(model, all_encodings_1, fname_1) 115 | 116 | def process_entailment(dataset_name, model, tokenizer, device, n=None, dataset_subname=None, fname=None, key1='premise', key2='hypothesis', conditional=False): 117 | print(f"\n------Processing perplexity for dataset: {dataset_name}_{dataset_subname}-------") 118 | 119 | if dataset_subname is None: 120 | dataset = load_dataset(dataset_name, split='validation') 121 | else: 122 | dataset = load_dataset(dataset_name, dataset_subname, split='validation') 123 | dataset_texts = [] 124 | for ex in dataset: 125 | dataset_texts.append(ex[key1] + ' ' + ex[key2]) 126 | 127 | if fname is None: 128 | fname = dataset_name 129 | 130 | print("Tokenizing data") 131 | if not conditional: 132 | all_encodings = [get_encoding(tokenizer, text) for text in dataset_texts[:n]] 133 | return compute_all(model, all_encodings, fname) 134 | else: 135 | all_encodings_0 = all_encodings = [get_encoding(tokenizer, text, 0) for text in dataset_texts[:n]] 136 | all_encodings_1 = all_encodings = [get_encoding(tokenizer, text, 1) for text in dataset_texts[:n]] 137 | fname_0 = fname + '_conditional_0' 138 | fname_1 = fname + '_conditional_1' 139 | 140 | compute_all(model, all_encodings_0, fname_0) 141 | compute_all(model, all_encodings_1, fname_1) 142 | 143 | def process_tsv(dataset_path, n=None, key='Text', key2=None): 144 | print('Loading data...') 145 | dataset = pd.read_table(dataset_path) 146 | if key2 is None: 147 | return dataset[key][:n] 148 | else: 149 | series = dataset[key][:n] + ' ' + dataset[key2][:n] 150 | return series 151 | 152 | def 
process_txt(dataset_path, n=None): 153 | print('Loading data...') 154 | with open(dataset_path) as f: 155 | dataset = f.readlines() 156 | return dataset[:n] 157 | 158 | def process_label(model, tokenizer, device, dataset, label, fname, save_path, n=None): 159 | print(f"Evaluating conditional for label {label}") 160 | all_encodings_curr = [get_encoding(tokenizer, text, label, add_eot=False) for text in dataset[:n]] 161 | fname_curr = fname + f'_conditional_{label}' 162 | compute_all(model, all_encodings_curr, fname_curr, save_path=save_path, device=device) 163 | 164 | def process_dataset(dataset, model, tokenizer, device, fname, save_path, n=None, add_eot=False, conditional=False, num_classes=2, class_num=None): 165 | if not conditional: 166 | all_encodings = [get_encoding(tokenizer, text, add_eot=add_eot) for text in dataset[:n]] 167 | compute_all(model, all_encodings, fname, save_path=save_path, device=device) 168 | else: 169 | if class_num is None: 170 | for label in range(num_classes): 171 | process_label(model, tokenizer, device, dataset, label, fname, save_path, n=n) 172 | else: 173 | process_label(model, tokenizer, device, dataset, class_num, fname, save_path, n=n) 174 | 175 | def main(dataset_path, model_path='/scratch/ua388/ckpts/gpt2-glue_sst2', fname=None, n=None, conditional=False, add_eot=False, num_classes=2, class_num=None, key='sentence1', key2='sentence2', num_splits=1, split_idx=None): 176 | model, tokenizer, device, _ = setup(model_path) 177 | if conditional: 178 | save_path = os.path.join(SAVE_PATH, 'gpt2_conditional') 179 | else: 180 | save_path = os.path.join(SAVE_PATH, 'gpt2') 181 | if not os.path.exists(save_path): 182 | os.makedirs(save_path) 183 | 184 | if '.txt' in dataset_path: 185 | dataset = process_txt(dataset_path, n) 186 | # Split dataset examples 187 | if split_idx is not None: 188 | fname = f'{fname}_{split_idx}' 189 | per_split = len(dataset) // num_splits 190 | start_idx = split_idx * per_split 191 | if split_idx + 1 < num_splits: 192 | end_idx = (split_idx + 1) * per_split 193 | else: 194 | end_idx = len(dataset) 195 | dataset = dataset[start_idx:end_idx] 196 | print(f"Taking split {split_idx+1}/{num_splits} with start_idx: {start_idx} and end_idx: {end_idx} of size {len(dataset)}") 197 | elif '.tsv' in dataset_path: 198 | dataset = process_tsv(dataset_path, n=n, key=key, key2=key2) 199 | else: 200 | dataset = None 201 | print("Invalid dataset path:", dataset_path) 202 | return 203 | 204 | if '<|endoftext|>' in dataset[0]: 205 | add_eot = False 206 | else: 207 | add_eot = True 208 | 209 | process_dataset(dataset, model, tokenizer, device, fname, save_path, n=n, add_eot=add_eot, conditional=conditional, num_classes=num_classes, class_num=class_num) 210 | 211 | if __name__ == '__main__': 212 | print("Loading model...") 213 | 214 | # To evaluate HF datasets, use these 215 | # model, tokenizer, device, n = setup('ckpts/gpt2-glue_sst2') 216 | # process_hf_dataset('imdb', model, tokenizer, device, n=2, key='text', conditional=True) 217 | # process_hf_dataset('yelp_polarity', model, tokenizer, device, n=3000, key='text') 218 | # process_hf_dataset('glue', model, tokenizer, device, configs=['sst2'], key='sentence', fname='sst2') 219 | # process_entailment('glue', model, tokenizer, device, dataset_subname='rte', fname='rte', key1='sentence1', key2='sentence2', conditional=True) 220 | # process_entailment('snli', model, tokenizer, device, key1='premise', key2='hypothesis', conditional=True) 221 | 222 | fire.Fire(main) 223 | 224 | print("\n\n--------DONE--------") 
225 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torchvision 2 | pandas 3 | datasets==1.0.2 4 | numpy==1.19.1 5 | transformers==3.1.0 6 | torch==1.6.0 7 | tqdm==4.49.0 8 | dataclasses==0.8 9 | scikit_learn==0.23.2 10 | -------------------------------------------------------------------------------- /roberta_fine_tune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import fire 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | from torch.utils.data import (DataLoader, RandomSampler, TensorDataset, SequentialSampler) 11 | import torch.nn.functional as F 12 | from tqdm import tqdm, trange 13 | 14 | from transformers import (RobertaConfig, RobertaTokenizer, RobertaForSequenceClassification) 15 | from transformers import set_seed 16 | 17 | from datasets import load_dataset 18 | 19 | SAVE_PATH = 'output/roberta/' 20 | if not os.path.exists(SAVE_PATH): 21 | os.makedirs(SAVE_PATH) 22 | 23 | def get_dataloader(tokenizer_args, tokenizer, padding, max_length, batch_size, truncation=True, labels=None, shuffle=True): 24 | features = tokenizer(*tokenizer_args, padding=padding, max_length=max_length, truncation=truncation) 25 | all_input_ids = torch.tensor([f for f in features.input_ids], dtype=torch.long) 26 | all_attention_mask = torch.tensor([f for f in features.attention_mask], dtype=torch.long) 27 | 28 | if labels is not None: 29 | all_labels = torch.tensor([f for f in labels], dtype=torch.long) 30 | tensor_dataset = TensorDataset(all_input_ids, all_attention_mask, all_labels) 31 | else: 32 | tensor_dataset = TensorDataset(all_input_ids, all_attention_mask) 33 | 34 | if shuffle: 35 | sampler = RandomSampler(tensor_dataset) 36 | else: 37 | sampler = SequentialSampler(tensor_dataset) 38 | 39 | dataloader = DataLoader(tensor_dataset, sampler=sampler, batch_size=batch_size) 40 | return dataloader 41 | 42 | def process_hf_dataset(dataset, split, task_name, tokenizer, padding, max_length, batch_size, truncation=True, n=None, shuffle=True): 43 | eval_split_keys = { 44 | 'imdb': 'test', 45 | 'sst2': 'validation', 46 | 'yelp_polarity': 'test', 47 | 'mnli': 'validation_matched', 48 | 'hans': 'validation', 49 | 'snli': 'validation', 50 | 'rte': 'validation' 51 | } 52 | if split == 'eval': 53 | split = eval_split_keys[task_name] 54 | 55 | tasks_to_keys = { 56 | 'imdb': ('text', None), 57 | 'yelp_polarity': ('text', None), 58 | 'sst2': ('sentence', None), 59 | 'mnli': ('premise', 'hypothesis'), 60 | 'hans': ('premise', 'hypothesis'), 61 | 'snli': ('premise', 'hypothesis'), 62 | 'rte': ('sentence1', 'sentence2') 63 | } 64 | 65 | sentence1_key, sentence2_key = tasks_to_keys[task_name] 66 | args = ((dataset[split][sentence1_key][:n],) if sentence2_key is None else (dataset[split][sentence1_key][:n], dataset[split][sentence2_key][:n])) 67 | labels = dataset[split]['label'][:n] 68 | return get_dataloader(args, tokenizer, padding, max_length, batch_size, truncation=truncation, labels=labels, shuffle=shuffle) 69 | 70 | def process_custom_dataset(dataset, task_name, tokenizer, padding, max_length, batch_size, truncation=True, n=None, shuffle=True): 71 | tasks_to_keys = { 72 | 'counterfactual-imdb': ('Text', None, 'Sentiment'), 73 | 'mnli': ('sentence1', 'sentence2', 'label') 74 | } 75 | 76 | sentence1_key, sentence2_key, 
labels_key = tasks_to_keys[task_name] 77 | 78 | dataset[sentence1_key] = dataset[sentence1_key].astype(str) 79 | if sentence2_key is not None: 80 | dataset[sentence2_key] = dataset[sentence2_key].astype(str) 81 | 82 | args = ((dataset[sentence1_key].tolist()[:n],) if sentence2_key is None else (dataset[sentence1_key].tolist()[:n], dataset[sentence2_key].tolist()[:n])) 83 | 84 | features = tokenizer(*args, padding=padding, max_length=max_length, truncation=truncation) 85 | if task_name != 'mnli': 86 | labels = pd.Categorical(dataset[labels_key], ordered=True).codes.tolist()[:n] 87 | else: 88 | labels = dataset[labels_key].tolist()[:n] 89 | 90 | all_input_ids = torch.tensor([f for f in features.input_ids], dtype=torch.long) 91 | all_attention_mask = torch.tensor([f for f in features.attention_mask], dtype=torch.long) 92 | all_labels = torch.tensor([f for f in labels], dtype=torch.long) 93 | 94 | tensor_dataset = TensorDataset(all_input_ids, all_attention_mask, all_labels) 95 | if shuffle: 96 | sampler = RandomSampler(tensor_dataset) 97 | else: 98 | sampler = SequentialSampler(tensor_dataset) 99 | dataloader = DataLoader(tensor_dataset, sampler=sampler, batch_size=batch_size, shuffle=shuffle) 100 | 101 | return dataloader 102 | 103 | def process_lm_dataset(dataset_path, tokenizer, padding, max_length, batch_size, truncation=True, num_label_chars=1, n=None, shuffle=True): 104 | # label in first column, and text in rest of the columns 105 | dataset_texts, labels = [], [] 106 | with open(dataset_path) as f: 107 | for idx, line in enumerate(f): 108 | line = line.strip() 109 | if num_label_chars > 0: 110 | # print(f'{idx}: {line[:num_label_chars]}') 111 | try: 112 | labels.append(int(line[:num_label_chars])) 113 | except Exception as e: 114 | print(e) 115 | print(idx, line) 116 | dataset_texts.append(line[num_label_chars:].replace(' <|endoftext|>', '').lstrip()) 117 | dataset_texts, labels = dataset_texts[:n], labels[:n] 118 | args = ((dataset_texts,)) 119 | if len(labels) == 0: 120 | labels = None 121 | return get_dataloader(args, tokenizer, padding, max_length, batch_size, truncation=truncation, labels=labels, shuffle=shuffle) 122 | 123 | def train(model, tokenizer, optimizer, criterion, device, train_loader, num_epochs, output_dir): 124 | losses = [] 125 | train_iterator = trange(int(num_epochs), desc='Epoch') 126 | for _ in train_iterator: 127 | tr_loss = 0 128 | step = None 129 | epoch_iterator = tqdm(train_loader, desc='Iteration') 130 | for step, batch in enumerate(epoch_iterator): 131 | model.train() 132 | batch = tuple(t.to(device) for t in batch) 133 | inputs = {'input_ids': batch[0].to(device), 'attention_mask': batch[1].to(device), 'labels': batch[2].to(device)} 134 | labels = batch[2].to(device) 135 | 136 | optimizer.zero_grad() 137 | 138 | out = model(**inputs)[1].double().to(device) 139 | 140 | loss = criterion(out, labels) 141 | loss.backward() 142 | optimizer.step() 143 | 144 | tr_loss += loss.item() 145 | losses.append(tr_loss/(step+1)) 146 | print('train loss: {}'.format(tr_loss/(step+1))) 147 | 148 | # save model and tokenizer 149 | print('Saving model and tokenizer') 150 | 151 | model.save_pretrained(output_dir) 152 | tokenizer.save_pretrained(output_dir) 153 | 154 | def eval(model, eval_loader, device, criterion=nn.CrossEntropyLoss(), with_labels=True): 155 | probs = None 156 | gold_labels = None 157 | 158 | eval_loss = 0 159 | step = None 160 | eval_iterator = tqdm(eval_loader, desc='Evaluating') 161 | for step, batch in enumerate(eval_iterator): 162 | model.eval() 163 | 
batch = tuple(t.to(device) for t in batch) 164 | 165 | with torch.no_grad(): 166 | inputs = {'input_ids': batch[0].to(device), 'attention_mask': batch[1].to(device)} 167 | # out = model(**inputs)[0].double().to(device) 168 | 169 | out = model(**inputs)[0].double() 170 | out = F.softmax(out, dim=1) 171 | 172 | if with_labels: 173 | # inputs['labels'] = batch[2].to(device) 174 | labels = batch[2].to(device) 175 | loss = criterion(out, labels) 176 | 177 | if probs is None: 178 | probs = out.detach().cpu().numpy() 179 | if with_labels: 180 | gold_labels = labels.detach().cpu().numpy() 181 | else: 182 | probs = np.append(probs, out.detach().cpu().numpy(), axis=0) 183 | if with_labels: 184 | gold_labels = np.append(gold_labels, labels.detach().cpu().numpy(), axis=0) 185 | 186 | if with_labels: 187 | eval_loss += loss.item() 188 | 189 | if with_labels: 190 | eval_loss /= (step+1) 191 | print('eval loss: {}'.format(eval_loss)) 192 | 193 | # compute accuracy 194 | preds = np.argmax(probs, axis=1) 195 | accuracy = np.sum(preds == gold_labels)/len(preds) 196 | print('eval accuracy: {}'.format(accuracy)) 197 | 198 | return probs 199 | 200 | def main(): 201 | # create argument parser 202 | parser = argparse.ArgumentParser() 203 | 204 | parser.add_argument('--task_name', help='Task to fine-tune RoBERTa on', default='sst2') 205 | parser.add_argument('--roberta_version', type=str, default='roberta-base', help='Version of RoBERTa to use') 206 | parser.add_argument('--cache_dir_data', type=str, default='cache/huggingface/datasets', help='Path to cache directory') 207 | parser.add_argument('--cache_dir', type=str, default='cache/huggingface/transformers', help='Path to cache directory') 208 | parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs to fine-tune') 209 | parser.add_argument('--max_seq_length', type=int, default=None, help='Maximum sequence length of the inputs') 210 | parser.add_argument('--batch_size', type=int, default=16, help='Batch size') 211 | parser.add_argument('--learning_rate', type=float, default=1e-5, help='Adam learning rate') 212 | parser.add_argument('--output_dir', type=str, default='roberta_ckpts/', help='Directory to save fine-tuned models') 213 | parser.add_argument('--seed', type=int, default=42, help='Random seed for initialization') 214 | parser.add_argument('--file_format', type=str, default='.tsv', help='Data file format for tasks not available for download at HuggingFace Datasets') 215 | parser.add_argument('--train_file', type=str, default=None, help='LM txt file') 216 | parser.add_argument('--val_file', type=str, default=None, help='LM txt file') 217 | parser.add_argument('--num_labels', type=int, default=2, help='Number of labels in training data') 218 | parser.add_argument('--fname', type=str, default=None, help='MSP output file') 219 | parser.add_argument('--n', type=int, default=None, help='Number of examples to process (for debugging)') 220 | 221 | args = parser.parse_args() 222 | 223 | # set device 224 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 225 | 226 | if not torch.cuda.is_available(): 227 | args.cache_dir = args.cache_dir_data = None 228 | 229 | # huggingface and glue datasets 230 | hf_datasets = ['imdb', 'sst2', 'mnli', 'hans', 'snli', 'rte'] 231 | glue = ['sst2', 'mnli', 'rte'] 232 | 233 | # custom dataset label keys 234 | label_keys = { 235 | 'counterfactual-imdb': 'Sentiment', 236 | } 237 | 238 | # load dataset 239 | print('Loading dataset') 240 | 241 | if args.train_file is None: 242 | if 
args.task_name in hf_datasets: 243 | dataset = load_dataset(args.task_name, cache_dir=args.cache_dir_data) if args.task_name not in glue else load_dataset('glue', args.task_name, cache_dir=args.cache_dir_data) 244 | num_labels = dataset['train'].features['label'].num_classes 245 | print("num_labels =", num_labels) 246 | elif args.file_format == '.tsv': 247 | train_df = pd.read_table(os.path.join(os.getcwd(), args.task_name + '_train' + args.file_format)) 248 | eval_df = pd.read_table(os.path.join(os.getcwd(), args.task_name + '_val' + args.file_format)) 249 | num_labels = len(np.unique(pd.Categorical(train_df[label_keys.get(args.task_name, 'label')], ordered=True))) 250 | else: 251 | num_labels = args.num_labels 252 | 253 | # set seed 254 | set_seed(args.seed) 255 | 256 | # load RoBERTa tokenizer and model 257 | print('Loading RoBERTa tokenizer and model') 258 | 259 | config = RobertaConfig.from_pretrained(args.roberta_version, num_labels=num_labels, cache_dir=args.cache_dir) 260 | tokenizer = RobertaTokenizer.from_pretrained(args.roberta_version, cache_dir=args.cache_dir) 261 | model = RobertaForSequenceClassification.from_pretrained(args.roberta_version, config=config, cache_dir=args.cache_dir).to(device) 262 | 263 | # process dataset 264 | print('Processing dataset') 265 | 266 | padding = 'max_length' 267 | with_labels = True 268 | if args.train_file is None: 269 | if args.task_name in hf_datasets: 270 | train_loader = process_hf_dataset(dataset, 'train', args.task_name, tokenizer, padding, args.max_seq_length, args.batch_size, n=args.n) 271 | # train_loader = None 272 | eval_loader = process_hf_dataset(dataset, 'eval', args.task_name, tokenizer, padding, args.max_seq_length, args.batch_size, n=args.n) 273 | elif args.file_format == '.tsv': 274 | task_name = args.task_name 275 | if 'mnli' in args.task_name: 276 | task_name = 'mnli' 277 | train_loader = process_custom_dataset(train_df, task_name, tokenizer, padding, args.max_seq_length, args.batch_size, n=args.n) 278 | eval_loader = process_custom_dataset(eval_df, task_name, tokenizer, padding, args.max_seq_length, args.batch_size, n=args.n) 279 | else: 280 | train_loader = process_lm_dataset(args.train_file, tokenizer, padding, args.max_seq_length, args.batch_size, n=args.n) 281 | if args.val_file is not None: 282 | eval_loader = process_lm_dataset(args.val_file, tokenizer, padding, args.max_seq_length, args.batch_size, n=args.n, num_label_chars=0) 283 | else: 284 | eval_loader = None 285 | with_labels = False 286 | 287 | # instantiate optimizer and criterion 288 | optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) 289 | criterion = nn.CrossEntropyLoss() 290 | 291 | # fine-tune model 292 | if train_loader is not None: 293 | print('Fine-tuning model') 294 | train(model, tokenizer, optimizer, criterion, device, train_loader, args.num_epochs, args.output_dir) 295 | 296 | # evaluate model 297 | if eval_loader is not None: 298 | print('Evaluating model') 299 | probs = eval(model, eval_loader, device, criterion, with_labels=with_labels) 300 | np.save(os.path.join(SAVE_PATH, f'{args.fname}_probs'), probs) 301 | msp = np.max(probs, axis=1) 302 | if args.fname is not None: 303 | np.save(os.path.join(SAVE_PATH, f'{args.fname}_msp'), msp) 304 | 305 | 306 | if __name__ == '__main__': 307 | main() 308 | print("\n\n--------DONE--------") 309 | -------------------------------------------------------------------------------- /run_language_modeling.py: -------------------------------------------------------------------------------- 1 | # 
coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa). 18 | GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned 19 | using a masked language modeling (MLM) loss. 20 | """ 21 | 22 | 23 | import logging 24 | import math 25 | import os 26 | from dataclasses import dataclass, field 27 | from typing import Optional 28 | 29 | from transformers import ( 30 | CONFIG_MAPPING, 31 | MODEL_WITH_LM_HEAD_MAPPING, 32 | AutoConfig, 33 | AutoModelWithLMHead, 34 | AutoTokenizer, 35 | DataCollatorForLanguageModeling, 36 | HfArgumentParser, 37 | LineByLineTextDataset, 38 | PreTrainedTokenizer, 39 | TextDataset, 40 | Trainer, 41 | TrainingArguments, 42 | set_seed, 43 | ) 44 | 45 | 46 | logger = logging.getLogger(__name__) 47 | 48 | 49 | MODEL_CONFIG_CLASSES = list(MODEL_WITH_LM_HEAD_MAPPING.keys()) 50 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 51 | 52 | 53 | @dataclass 54 | class ModelArguments: 55 | """ 56 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. 57 | """ 58 | 59 | model_name_or_path: Optional[str] = field( 60 | default=None, 61 | metadata={ 62 | "help": "The model checkpoint for weights initialization. Leave None if you want to train a model from scratch." 63 | }, 64 | ) 65 | model_type: Optional[str] = field( 66 | default=None, 67 | metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, 68 | ) 69 | config_name: Optional[str] = field( 70 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 71 | ) 72 | tokenizer_name: Optional[str] = field( 73 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 74 | ) 75 | cache_dir: Optional[str] = field( 76 | default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} 77 | ) 78 | 79 | 80 | @dataclass 81 | class DataTrainingArguments: 82 | """ 83 | Arguments pertaining to what data we are going to input our model for training and eval. 
84 | """ 85 | 86 | train_data_file: Optional[str] = field( 87 | default=None, metadata={"help": "The input training data file (a text file)."} 88 | ) 89 | eval_data_file: Optional[str] = field( 90 | default=None, 91 | metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, 92 | ) 93 | line_by_line: bool = field( 94 | default=False, 95 | metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."}, 96 | ) 97 | 98 | mlm: bool = field( 99 | default=False, metadata={"help": "Train with masked-language modeling loss instead of language modeling."} 100 | ) 101 | mlm_probability: float = field( 102 | default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"} 103 | ) 104 | 105 | block_size: int = field( 106 | default=-1, 107 | metadata={ 108 | "help": "Optional input sequence length after tokenization." 109 | "The training dataset will be truncated in block of this size for training." 110 | "Default to the model max input length for single sentence inputs (take into account special tokens)." 111 | }, 112 | ) 113 | overwrite_cache: bool = field( 114 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 115 | ) 116 | 117 | 118 | def get_dataset(args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, evaluate=False): 119 | file_path = args.eval_data_file if evaluate else args.train_data_file 120 | if args.line_by_line: 121 | return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size) 122 | else: 123 | return TextDataset( 124 | tokenizer=tokenizer, file_path=file_path, block_size=args.block_size, overwrite_cache=args.overwrite_cache 125 | ) 126 | 127 | 128 | def main(): 129 | # See all possible arguments in src/transformers/training_args.py 130 | # or by passing the --help flag to this script. 131 | # We now keep distinct sets of args, for a cleaner separation of concerns. 132 | 133 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 134 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 135 | 136 | if data_args.eval_data_file is None and training_args.do_eval: 137 | raise ValueError( 138 | "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file " 139 | "or remove the --do_eval argument." 140 | ) 141 | 142 | if ( 143 | os.path.exists(training_args.output_dir) 144 | and os.listdir(training_args.output_dir) 145 | and training_args.do_train 146 | and not training_args.overwrite_output_dir 147 | ): 148 | raise ValueError( 149 | f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." 
150 | ) 151 | 152 | # Setup logging 153 | logging.basicConfig( 154 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 155 | datefmt="%m/%d/%Y %H:%M:%S", 156 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 157 | ) 158 | logger.warning( 159 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 160 | training_args.local_rank, 161 | training_args.device, 162 | training_args.n_gpu, 163 | bool(training_args.local_rank != -1), 164 | training_args.fp16, 165 | ) 166 | logger.info("Training/evaluation parameters %s", training_args) 167 | 168 | # Set seed 169 | set_seed(training_args.seed) 170 | 171 | # Load pretrained model and tokenizer 172 | # 173 | # Distributed training: 174 | # The .from_pretrained methods guarantee that only one local process can concurrently 175 | # download model & vocab. 176 | 177 | if model_args.config_name: 178 | config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir) 179 | elif model_args.model_name_or_path: 180 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir) 181 | else: 182 | config = CONFIG_MAPPING[model_args.model_type]() 183 | logger.warning("You are instantiating a new config instance from scratch.") 184 | 185 | if model_args.tokenizer_name: 186 | tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, cache_dir=model_args.cache_dir) 187 | elif model_args.model_name_or_path: 188 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir) 189 | else: 190 | raise ValueError( 191 | "You are instantiating a new tokenizer from scratch. This is not supported, but you can do it from another script, save it," 192 | "and load it from here, using --tokenizer_name" 193 | ) 194 | 195 | if model_args.model_name_or_path: 196 | model = AutoModelWithLMHead.from_pretrained( 197 | model_args.model_name_or_path, 198 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 199 | config=config, 200 | cache_dir=model_args.cache_dir, 201 | ) 202 | else: 203 | logger.info("Training new model from scratch") 204 | model = AutoModelWithLMHead.from_config(config) 205 | 206 | special_tokens_dict = {'bos_token': '', 'eos_token': '', 'pad_token': ''} 207 | num_added_toks = tokenizer.add_special_tokens(special_tokens_dict) 208 | model.resize_token_embeddings(len(tokenizer)) 209 | 210 | if config.model_type in ["bert", "roberta", "distilbert", "camembert"] and not data_args.mlm: 211 | raise ValueError( 212 | "BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the --mlm " 213 | "flag (masked language modeling)." 
214 | ) 215 | 216 | if data_args.block_size <= 0: 217 | data_args.block_size = tokenizer.max_len 218 | # Our input block size will be the max possible for the model 219 | else: 220 | data_args.block_size = min(data_args.block_size, tokenizer.max_len) 221 | 222 | # Get datasets 223 | 224 | train_dataset = get_dataset(data_args, tokenizer=tokenizer) if training_args.do_train else None 225 | eval_dataset = get_dataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None 226 | data_collator = DataCollatorForLanguageModeling( 227 | tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability 228 | ) 229 | 230 | # Initialize our Trainer 231 | trainer = Trainer( 232 | model=model, 233 | args=training_args, 234 | data_collator=data_collator, 235 | train_dataset=train_dataset, 236 | eval_dataset=eval_dataset, 237 | prediction_loss_only=True, 238 | ) 239 | 240 | # Training 241 | if training_args.do_train: 242 | model_path = ( 243 | model_args.model_name_or_path 244 | if model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path) 245 | else None 246 | ) 247 | trainer.train(model_path=model_path) 248 | trainer.save_model() 249 | # For convenience, we also re-save the tokenizer to the same directory, 250 | # so that you can share your model easily on huggingface.co/models =) 251 | if trainer.is_world_master(): 252 | tokenizer.save_pretrained(training_args.output_dir) 253 | 254 | # Evaluation 255 | results = {} 256 | if training_args.do_eval: 257 | logger.info("*** Evaluate ***") 258 | 259 | eval_output = trainer.evaluate() 260 | 261 | perplexity = math.exp(eval_output["eval_loss"]) 262 | result = {"perplexity": perplexity} 263 | 264 | output_eval_file = os.path.join(training_args.output_dir, "eval_results_lm.txt") 265 | if trainer.is_world_master(): 266 | with open(output_eval_file, "w") as writer: 267 | logger.info("***** Eval results *****") 268 | for key in sorted(result.keys()): 269 | logger.info(" %s = %s", key, str(result[key])) 270 | writer.write("%s = %s\n" % (key, str(result[key]))) 271 | 272 | results.update(result) 273 | 274 | return results 275 | 276 | 277 | def _mp_fn(index): 278 | # For xla_spawn (TPUs) 279 | main() 280 | 281 | 282 | if __name__ == "__main__": 283 | main() 284 | 285 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import roc_auc_score, roc_curve 2 | import numpy as np 3 | import pickle 4 | 5 | def compute_auroc(id_pps, ood_pps, normalize=False, return_curve=False): 6 | y = np.concatenate((np.ones_like(ood_pps), np.zeros_like(id_pps))) 7 | scores = np.concatenate((ood_pps, id_pps)) 8 | if normalize: 9 | scores = (scores - scores.min()) / (scores.max() - scores.min()) 10 | if return_curve: 11 | return roc_curve(y, scores) 12 | else: 13 | return 100*roc_auc_score(y, scores) 14 | 15 | def compute_far(id_pps, ood_pps, rate=5): 16 | incorrect = len(id_pps[id_pps > np.percentile(ood_pps, rate)]) 17 | return 100*incorrect / len(id_pps) 18 | 19 | def compute_px(ppl, lls): 20 | lengths = np.array([len(ll) for ll in lls]) 21 | logpx = np.log(ppl) * lengths * -1 22 | return logpx 23 | 24 | def compute_ppl(logpx, lls): 25 | lengths = np.array([len(ll) for ll in lls]) 26 | log_ppl = - logpx / lengths 27 | return np.exp(log_ppl) 28 | 29 | def compute_conditional_prior(pps_labels, lls_labels, probs): 30 | # p(x) = \sum_y p(x|y) p(y) 31 | # log p(x) = \sum_y 
log p(x|y) + log p(y) 32 | px_labels = {} 33 | combined_px = None 34 | for label in pps_labels: 35 | px_labels[label] = compute_px(pps_labels[label], lls_labels[label]) 36 | if combined_px is None: 37 | combined_px = px_labels[label] + np.log(probs[label]) 38 | else: 39 | combined_px += px_labels[label] + np.log(probs[label]) 40 | 41 | combined_pps = compute_ppl(combined_px, lls_labels[0]) 42 | return combined_pps, combined_px 43 | 44 | def compute_conditional(pps_labels, lls_labels, probs): 45 | # p(x) = \sum_y p(x|y) p(y|x) 46 | # log p(x) = \sum_y log p(x|y) + log p(y|x) 47 | px_labels = {} 48 | combined_px = None 49 | for label in pps_labels: 50 | px_labels[label] = compute_px(pps_labels[label], lls_labels[label]) 51 | if combined_px is None: 52 | combined_px = px_labels[label] + np.log(probs[:, label]) 53 | else: 54 | combined_px += px_labels[label] + np.log(probs[:, label]) 55 | 56 | combined_pps = compute_ppl(combined_px, lls_labels[0]) 57 | return combined_pps, combined_px 58 | 59 | def compute_lm_metric(id_pps, id_lls, ood_pps, ood_lls, id_px=None, ood_px=None, metric='auroc', do_print=False, conditional=False): 60 | if metric == 'auroc': 61 | compute_fn = compute_auroc 62 | else: 63 | compute_fn = compute_far 64 | 65 | if id_px is None: 66 | id_px = compute_px(id_pps, id_lls) 67 | if ood_px is None: 68 | ood_px = compute_px(ood_pps, ood_lls) 69 | 70 | score_px = compute_fn(-id_px, -ood_px) 71 | score_ppl = compute_fn(id_pps, ood_pps) 72 | if do_print: 73 | if conditional: 74 | ctext = 'Conditional ' 75 | else: 76 | ctext = '' 77 | print(f"{ctext}P(x): {score_px:.3f}") 78 | print(f"{ctext}Perplexity: {score_ppl:.3f}") 79 | scores = { 80 | 'p_x': score_px, 81 | 'ppl': score_ppl 82 | } 83 | return scores 84 | 85 | def compute_auroc_all(id_msp, id_px, id_ppl, ood_msp, ood_px, ood_ppl, do_print=False): 86 | score_px = compute_auroc(-id_px, -ood_px) 87 | score_py = compute_auroc(-id_msp, -ood_msp) 88 | score_ppl = compute_auroc(id_ppl, ood_ppl) 89 | if do_print: 90 | print(f"P(x): {score_px:.3f}") 91 | print(f"P(y | x): {score_py:.3f}") 92 | print(f"Perplexity: {score_ppl:.3f}") 93 | scores = { 94 | 'p_x': score_px, 95 | 'p_y': score_py, 96 | 'ppl': score_ppl 97 | } 98 | return scores 99 | 100 | def compute_metric_all_old(id_pps, id_lls, id_msp, ood_pps, ood_lls, ood_msp, metric='auroc', do_print=False): 101 | id_px = compute_px(id_pps, id_lls) 102 | ood_px = compute_px(ood_pps, ood_lls) 103 | if metric == 'auroc': 104 | score_px = compute_auroc(-id_px, -ood_px) 105 | score_py = compute_auroc(-id_msp, -ood_msp) 106 | score_ppl = compute_auroc(id_pps, ood_pps) 107 | elif metric == 'far': 108 | score_px = compute_far(-id_px, -ood_px) 109 | score_py = compute_far(-id_msp, -ood_msp) 110 | score_ppl = compute_far(id_pps, ood_pps) 111 | else: 112 | raise Exception('Invalid metric name') 113 | 114 | if do_print: 115 | print(f"Metric {metric}:") 116 | print(f"P(x): {score_px:.3f}") 117 | print(f"P(y | x): {score_py:.3f}") 118 | print(f"Perplexity: {score_ppl:.3f}\n") 119 | 120 | scores = { 121 | 'p_x': score_px, 122 | 'p_y': score_py, 123 | 'ppl': score_ppl 124 | } 125 | return scores 126 | 127 | def compute_metric_all(id_pps, id_lls, id_msp, id_pps_cond, id_lls_cond, 128 | ood_pps, ood_lls, ood_msp, ood_pps_cond, ood_lls_cond, 129 | metric='auroc', do_print=False): 130 | if metric == 'auroc': 131 | compute_fn = compute_auroc 132 | else: 133 | compute_fn = compute_far 134 | 135 | scores_lm = compute_lm_metric(id_pps, id_lls, ood_pps, ood_lls, metric=metric, do_print=do_print) 136 | if 
id_pps_cond is not None:
137 |         scores_lm_cond = compute_lm_metric(id_pps_cond, id_lls_cond, ood_pps_cond, ood_lls_cond, metric=metric, do_print=do_print, conditional=True)
138 |     else:
139 |         scores_lm_cond = None
140 | 
141 |     score_py = compute_fn(-id_msp, -ood_msp)
142 | 
143 |     if do_print:
144 |         print(f"P(y | x): {score_py:.3f}")
145 | 
146 |     scores = {
147 |         'p_x': scores_lm['p_x'],
148 |         'ppl': scores_lm['ppl'],
149 |         'p_y': score_py
150 |     }
151 | 
152 |     if scores_lm_cond is not None:
153 |         scores['p_x_cond'] = scores_lm_cond['p_x']
154 |         scores['ppl_cond'] = scores_lm_cond['ppl']
155 | 
156 |     return scores
157 | 
158 | def read_model_out(fname):
159 |     if '.pkl' in fname:
160 |         with open(fname, 'rb') as f:
161 |             return pickle.load(f)
162 |     elif '.npy' in fname:
163 |         return np.load(fname)
164 |     else:
165 |         raise KeyError(f'File type of {fname} not supported')
166 | 
--------------------------------------------------------------------------------