├── LICENSE
├── README.md
├── custom_datasets.py
├── paper_scripts
│   ├── cross.sh
│   ├── gpt3.sh
│   ├── main.sh
│   ├── n_perturb.sh
│   ├── scale.sh
│   └── supervised.sh
├── requirements.txt
└── run.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 Eric Mitchell
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DetectGPT: Zero-Shot Machine-Generated Text Detection using Probability Curvature
2 | 
3 | ## Official implementation of the experiments in the [DetectGPT paper](https://arxiv.org/abs/2301.11305v1).
4 | 
5 | An interactive demo of DetectGPT can be found [here](https://detectgpt.ericmitchell.ai).
6 | 
7 | ## Instructions
8 | 
9 | First, install the Python dependencies:
10 | 
11 |     python3 -m venv env
12 |     source env/bin/activate
13 |     pip install -r requirements.txt
14 | 
15 | Second, run any of the scripts (or just individual commands) in `paper_scripts/`.
16 | 
17 | If you'd like to run the WritingPrompts experiments, you'll need to download the WritingPrompts data from [here](https://www.kaggle.com/datasets/ratthachat/writing-prompts). Save the data into a directory `data/writingPrompts`.
18 | 
19 | **Note: Intermediate results are saved in `tmp_results/`. If your experiment completes successfully, the results will be moved into the `results/` directory.**
20 | 
21 | ## Citing the paper
22 | If our work is useful for your own research, you can cite us with the following BibTeX entry:
23 | 
24 |     @misc{mitchell2023detectgpt,
25 |         url = {https://arxiv.org/abs/2301.11305},
26 |         author = {Mitchell, Eric and Lee, Yoonho and Khazatsky, Alexander and Manning, Christopher D. and Finn, Chelsea},
27 |         title = {DetectGPT: Zero-Shot Machine-Generated Text Detection using Probability Curvature},
28 |         publisher = {arXiv},
29 |         year = {2023},
30 |     }
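To sanity-check an installation before launching a full sweep, a single reduced experiment can be distilled from the first command of `paper_scripts/main.sh`. The settings below are a sketch, not a paper configuration: the smaller `--n_samples`, `--n_perturbation_list`, and mask-filling model trade detection quality for speed.

    python run.py --output_name quickstart --base_model_name gpt2-xl \
        --mask_filling_model_name t5-large --n_perturbation_list 1,10 \
        --n_samples 50 --pct_words_masked 0.3 --span_length 2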
--------------------------------------------------------------------------------
/custom_datasets.py:
--------------------------------------------------------------------------------
1 | import random
2 | import datasets
3 | 
4 | SEPARATOR = '<<<SEP>>>'
5 | 
6 | 
7 | DATASETS = ['writing', 'english', 'german', 'pubmed']
8 | 
9 | 
10 | def load_pubmed(cache_dir):
11 |     data = datasets.load_dataset('pubmed_qa', 'pqa_labeled', split='train', cache_dir=cache_dir)
12 | 
13 |     # combine question and long_answer
14 |     data = [f'Question: {q} Answer:{SEPARATOR}{a}' for q, a in zip(data['question'], data['long_answer'])]
15 | 
16 |     return data
17 | 
18 | 
19 | def process_prompt(prompt):
20 |     return prompt.replace('[ WP ]', '').replace('[ OT ]', '')
21 | 
22 | 
23 | def process_spaces(story):
24 |     return story.replace(
25 |         ' ,', ',').replace(
26 |         ' .', '.').replace(
27 |         ' ?', '?').replace(
28 |         ' !', '!').replace(
29 |         ' ;', ';').replace(
30 |         ' \'', '\'').replace(
31 |         ' ’ ', '\'').replace(
32 |         ' :', ':').replace(
33 |         '<newline>', '\n').replace(
34 |         '`` ', '"').replace(
35 |         ' \'\'', '"').replace(
36 |         '\'\'', '"').replace(
37 |         '.. ', '... ').replace(
38 |         ' )', ')').replace(
39 |         '( ', '(').replace(
40 |         ' n\'t', 'n\'t').replace(
41 |         ' i ', ' I ').replace(
42 |         ' i\'', ' I\'').replace(
43 |         '\\\'', '\'').replace(
44 |         '\n ', '\n').strip()
45 | 
46 | 
47 | def load_writing(cache_dir=None):
48 |     writing_path = 'data/writingPrompts'
49 | 
50 |     with open(f'{writing_path}/valid.wp_source', 'r') as f:
51 |         prompts = f.readlines()
52 |     with open(f'{writing_path}/valid.wp_target', 'r') as f:
53 |         stories = f.readlines()
54 | 
55 |     prompts = [process_prompt(prompt) for prompt in prompts]
56 |     joined = [process_spaces(prompt + " " + story) for prompt, story in zip(prompts, stories)]
57 |     filtered = [story for story in joined if 'nsfw' not in story and 'NSFW' not in story]
58 | 
59 |     random.seed(0)
60 |     random.shuffle(filtered)
61 | 
62 |     return filtered
63 | 
64 | 
65 | def load_language(language, cache_dir):
66 |     # load either the english or german portion of the wmt16 dataset
67 |     assert language in ['en', 'de']
68 |     d = datasets.load_dataset('wmt16', 'de-en', split='train', cache_dir=cache_dir)
69 |     docs = d['translation']
70 |     desired_language_docs = [d[language] for d in docs]
71 |     lens = [len(d.split()) for d in desired_language_docs]
72 |     sub = [d for d, l in zip(desired_language_docs, lens) if l > 100 and l < 150]
73 |     return sub
74 | 
75 | 
76 | def load_german(cache_dir):
77 |     return load_language('de', cache_dir)
78 | 
79 | 
80 | def load_english(cache_dir):
81 |     return load_language('en', cache_dir)
82 | 
83 | 
84 | def load(name, cache_dir, **kwargs):
85 |     if name in DATASETS:
86 |         load_fn = globals()[f'load_{name}']
87 |         return load_fn(cache_dir=cache_dir, **kwargs)
88 |     else:
89 |         raise ValueError(f'Unknown dataset {name}')
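# Example usage (a sketch; run.py reaches this module through
# custom_datasets.load(dataset, cache_dir) in generate_data, and the cache
# path below is an arbitrary placeholder):
#
#     import custom_datasets
#     docs = custom_datasets.load('english', cache_dir='/tmp/hf_cache')
#     # -> WMT16 English documents between 100 and 150 words long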
--------------------------------------------------------------------------------
/paper_scripts/cross.sh:
--------------------------------------------------------------------------------
1 | python run.py --output_name cross --base_model_name gpt2-xl --mask_filling_model_name t5-3b --n_perturbation_list 50 --n_samples 200 --pct_words_masked 0.3 --span_length 2
2 | python run.py --output_name cross --base_model_name gpt2-xl --scoring_model_name EleutherAI/gpt-neo-2.7B --mask_filling_model_name t5-3b --n_perturbation_list 50 --n_samples 200 --pct_words_masked 0.3 --span_length 2
3 | python run.py --output_name cross --base_model_name gpt2-xl --scoring_model_name EleutherAI/gpt-j-6B --mask_filling_model_name t5-3b --n_perturbation_list 50 --n_samples 200 --pct_words_masked 0.3 --span_length 2
4 | 
5 | python run.py --output_name cross --base_model_name EleutherAI/gpt-neo-2.7B --mask_filling_model_name t5-3b --n_perturbation_list 50 --n_samples 200 --pct_words_masked 0.3 --span_length 2
6 | python run.py --output_name cross --base_model_name EleutherAI/gpt-neo-2.7B --scoring_model_name gpt2-xl --mask_filling_model_name t5-3b --n_perturbation_list 50 --n_samples 200 --pct_words_masked 0.3 --span_length 2
7 | python run.py --output_name cross --base_model_name EleutherAI/gpt-neo-2.7B --scoring_model_name EleutherAI/gpt-j-6B --mask_filling_model_name t5-3b --n_perturbation_list 50 --n_samples 200 --pct_words_masked 0.3 --span_length 2
8 | 
9 | python run.py --output_name cross --base_model_name EleutherAI/gpt-j-6B --mask_filling_model_name t5-3b --n_perturbation_list 50 --n_samples 200 --pct_words_masked 0.3 --span_length 2
10 | python run.py --output_name cross --base_model_name EleutherAI/gpt-j-6B --scoring_model_name gpt2-xl --mask_filling_model_name t5-3b --n_perturbation_list 50 --n_samples 200 --pct_words_masked 0.3 --span_length 2
11 | python run.py --output_name cross --base_model_name EleutherAI/gpt-j-6B --scoring_model_name EleutherAI/gpt-neo-2.7B --mask_filling_model_name t5-3b --n_perturbation_list 50 --n_samples 200 --pct_words_masked 0.3 --span_length 2
12 | 
13 | 
14 | 
15 | python run.py --output_name cross --base_model_name gpt2-xl --mask_filling_model_name t5-3b --n_perturbation_list 50 --n_samples 200 --pct_words_masked 0.3 --span_length 2 --dataset writing
16 | python run.py --output_name cross --base_model_name gpt2-xl --scoring_model_name EleutherAI/gpt-neo-2.7B --mask_filling_model_name t5-3b --n_perturbation_list 50 --n_samples 200 --pct_words_masked 0.3 --span_length 2 --dataset writing
17 | python run.py --output_name cross --base_model_name gpt2-xl --scoring_model_name EleutherAI/gpt-j-6B --mask_filling_model_name t5-3b --n_perturbation_list 50 --n_samples 200 --pct_words_masked 0.3 --span_length 2 --dataset writing
18 | 
19 | python run.py --output_name cross --base_model_name EleutherAI/gpt-neo-2.7B --mask_filling_model_name t5-3b --n_perturbation_list 50 --n_samples 200 --pct_words_masked 0.3 --span_length 2 --dataset writing
20 | python run.py --output_name cross --base_model_name EleutherAI/gpt-neo-2.7B --scoring_model_name gpt2-xl --mask_filling_model_name t5-3b --n_perturbation_list 50 --n_samples 200 --pct_words_masked 0.3 --span_length 2 --dataset writing
21 | python run.py --output_name cross --base_model_name EleutherAI/gpt-neo-2.7B --scoring_model_name EleutherAI/gpt-j-6B --mask_filling_model_name t5-3b --n_perturbation_list 50 --n_samples 200 --pct_words_masked 0.3 --span_length 2 --dataset writing
22 | 
23 | python run.py --output_name cross --base_model_name EleutherAI/gpt-j-6B --mask_filling_model_name t5-3b --n_perturbation_list 50 --n_samples 200 --pct_words_masked 0.3 --span_length 2 --dataset writing
24 | python run.py --output_name cross --base_model_name EleutherAI/gpt-j-6B --scoring_model_name gpt2-xl --mask_filling_model_name t5-3b --n_perturbation_list 50 --n_samples 200 --pct_words_masked 0.3 --span_length 2 --dataset writing
25 | python run.py --output_name cross --base_model_name EleutherAI/gpt-j-6B --scoring_model_name EleutherAI/gpt-neo-2.7B --mask_filling_model_name t5-3b --n_perturbation_list 50 --n_samples 200 --pct_words_masked 0.3 --span_length 2 --dataset writing
26 | 
27 | 
28 | 
29 | python run.py --output_name cross --base_model_name gpt2-xl --mask_filling_model_name t5-3b --n_perturbation_list 50 --n_samples 200 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
30 | python run.py --output_name cross --base_model_name gpt2-xl --scoring_model_name EleutherAI/gpt-neo-2.7B --mask_filling_model_name t5-3b --n_perturbation_list 50 --n_samples 200 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
31 | python run.py --output_name cross --base_model_name gpt2-xl --scoring_model_name EleutherAI/gpt-j-6B --mask_filling_model_name t5-3b --n_perturbation_list 50 --n_samples 200 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
32 | 
33 | python run.py --output_name cross --base_model_name EleutherAI/gpt-neo-2.7B --mask_filling_model_name t5-3b --n_perturbation_list 50 --n_samples 200 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
34 | python run.py --output_name cross --base_model_name EleutherAI/gpt-neo-2.7B --scoring_model_name gpt2-xl --mask_filling_model_name t5-3b --n_perturbation_list 50 --n_samples 200 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
35 | python run.py --output_name cross --base_model_name EleutherAI/gpt-neo-2.7B --scoring_model_name EleutherAI/gpt-j-6B --mask_filling_model_name t5-3b --n_perturbation_list 50 --n_samples 200 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
36 | 
37 | python run.py --output_name cross --base_model_name EleutherAI/gpt-j-6B --mask_filling_model_name t5-3b --n_perturbation_list 50 --n_samples 200 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
38 | python run.py --output_name cross --base_model_name EleutherAI/gpt-j-6B --scoring_model_name gpt2-xl --mask_filling_model_name t5-3b --n_perturbation_list 50 --n_samples 200 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
39 | python run.py --output_name cross --base_model_name EleutherAI/gpt-j-6B --scoring_model_name EleutherAI/gpt-neo-2.7B --mask_filling_model_name t5-3b --n_perturbation_list 50 --n_samples 200 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
40 | 
--------------------------------------------------------------------------------
/paper_scripts/gpt3.sh:
--------------------------------------------------------------------------------
1 | python run.py --output_name openai --batch_size 5 --openai_model davinci --mask_filling_model_name t5-11b --n_perturbation_list 1,10,75 --n_samples 150 --pct_words_masked 0.3 --span_length 2 --do_top_p --top_p 0.9 --mask_top_p 0.95 --dataset pubmed
2 | python run.py --output_name openai --batch_size 5 --openai_model davinci --mask_filling_model_name t5-11b --n_perturbation_list 1,10,75 --n_samples 150 --pct_words_masked 0.3 --span_length 2 --do_top_p --top_p 0.9 --mask_top_p 0.95 --dataset writing
3 | python run.py --output_name openai --batch_size 5 --openai_model davinci --mask_filling_model_name t5-11b --n_perturbation_list 1,10,75 --n_samples 150 --pct_words_masked 0.3 --span_length 2 --do_top_p --top_p 0.9 --mask_top_p 0.95
4 | 
--------------------------------------------------------------------------------
/paper_scripts/main.sh:
--------------------------------------------------------------------------------
1 | python run.py --output_name main --base_model_name gpt2-xl --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 500 --pct_words_masked 0.3 --span_length 2
2 | python run.py --output_name main --base_model_name EleutherAI/gpt-neo-2.7B --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 500 --pct_words_masked 0.3 --span_length 2
3 | python run.py --output_name main --base_model_name EleutherAI/gpt-j-6B --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 500 --pct_words_masked 0.3 --span_length 2
4 | python run.py --output_name main --base_model_name facebook/opt-2.7b --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 500 --pct_words_masked 0.3 --span_length 2
5 | python run.py --output_name main --batch_size 20 --base_model_name EleutherAI/gpt-neox-20b --mask_filling_model_name t5-11b --n_perturbation_list 1,10,100 --n_samples 500 --pct_words_masked 0.3 --span_length 2
6 | 
7 | python run.py --output_name main_top_p --base_model_name gpt2-xl --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 500 --pct_words_masked 0.3 --span_length 2 --do_top_p
8 | python run.py --output_name main_top_p --base_model_name EleutherAI/gpt-neo-2.7B --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 500 --pct_words_masked 0.3 --span_length 2 --do_top_p
9 | python run.py --output_name main_top_p --base_model_name EleutherAI/gpt-j-6B --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 500 --pct_words_masked 0.3 --span_length 2 --do_top_p
10 | python run.py --output_name main_top_p --base_model_name facebook/opt-2.7b --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 500 --pct_words_masked 0.3 --span_length 2 --do_top_p
11 | python run.py --output_name main_top_p --batch_size 20 --base_model_name EleutherAI/gpt-neox-20b --mask_filling_model_name t5-11b --n_perturbation_list 1,10,100 --n_samples 500 --pct_words_masked 0.3 --span_length 2 --do_top_p
12 | 
13 | python run.py --output_name main_top_k --base_model_name gpt2-xl --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 500 --pct_words_masked 0.3 --span_length 2 --do_top_k
14 | python run.py --output_name main_top_k --base_model_name EleutherAI/gpt-neo-2.7B --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 500 --pct_words_masked 0.3 --span_length 2 --do_top_k
15 | python run.py --output_name main_top_k --base_model_name EleutherAI/gpt-j-6B --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 500 --pct_words_masked 0.3 --span_length 2 --do_top_k
16 | python run.py --output_name main_top_k --base_model_name facebook/opt-2.7b --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 500 --pct_words_masked 0.3 --span_length 2 --do_top_k
17 | python run.py --output_name main_top_k --batch_size 20 --base_model_name EleutherAI/gpt-neox-20b --mask_filling_model_name t5-11b --n_perturbation_list 1,10,100 --n_samples 500 --pct_words_masked 0.3 --span_length 2 --do_top_k
18 | 
19 | 
20 | 
21 | python run.py --output_name main --base_model_name gpt2-xl --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
22 | python run.py --output_name main --base_model_name EleutherAI/gpt-neo-2.7B --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
23 | python run.py --output_name main --base_model_name EleutherAI/gpt-j-6B --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
24 | python run.py --output_name main --base_model_name facebook/opt-2.7b --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
25 | python run.py --output_name main --batch_size 20 --base_model_name EleutherAI/gpt-neox-20b --mask_filling_model_name t5-11b --n_perturbation_list 1,10,100 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
26 | 
27 | python run.py --output_name main_top_p --base_model_name gpt2-xl --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --do_top_p --dataset squad --dataset_key context
28 | python run.py --output_name main_top_p --base_model_name EleutherAI/gpt-neo-2.7B --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --do_top_p --dataset squad --dataset_key context
29 | python run.py --output_name main_top_p --base_model_name EleutherAI/gpt-j-6B --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --do_top_p --dataset squad --dataset_key context
30 | python run.py --output_name main_top_p --base_model_name facebook/opt-2.7b --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --do_top_p --dataset squad --dataset_key context
31 | python run.py --output_name main_top_p --batch_size 20 --base_model_name EleutherAI/gpt-neox-20b --mask_filling_model_name t5-11b --n_perturbation_list 1,10,100 --n_samples 312 --pct_words_masked 0.3 --do_top_p --span_length 2 --dataset squad --dataset_key context
32 | 
33 | python run.py --output_name main_top_k --base_model_name gpt2-xl --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --do_top_k --dataset squad --dataset_key context
34 | python run.py --output_name main_top_k --base_model_name EleutherAI/gpt-neo-2.7B --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --do_top_k --dataset squad --dataset_key context
35 | python run.py --output_name main_top_k --base_model_name EleutherAI/gpt-j-6B --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --do_top_k --dataset squad --dataset_key context
36 | python run.py --output_name main_top_k --base_model_name facebook/opt-2.7b --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --do_top_k --dataset squad --dataset_key context
37 | python run.py --output_name main_top_k --batch_size 20 --base_model_name EleutherAI/gpt-neox-20b --mask_filling_model_name t5-11b --n_perturbation_list 1,10,100 --n_samples 312 --pct_words_masked 0.3 --do_top_k --span_length 2 --dataset squad --dataset_key context
38 | 
39 | 
40 | 
41 | python run.py --output_name main --base_model_name gpt2-xl --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 500 --pct_words_masked 0.3 --span_length 2 --dataset writing
42 | python run.py --output_name main --base_model_name EleutherAI/gpt-neo-2.7B --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 500 --pct_words_masked 0.3 --span_length 2 --dataset writing
43 | python run.py --output_name main --base_model_name EleutherAI/gpt-j-6B --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 500 --pct_words_masked 0.3 --span_length 2 --dataset writing
44 | python run.py --output_name main --base_model_name facebook/opt-2.7b --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 500 --pct_words_masked 0.3 --span_length 2 --dataset writing
45 | python run.py --output_name main --batch_size 20 --base_model_name EleutherAI/gpt-neox-20b --mask_filling_model_name t5-11b --n_perturbation_list 1,10,100 --n_samples 500 --pct_words_masked 0.3 --span_length 2 --dataset writing
46 | 
47 | python run.py --output_name main_top_p --base_model_name gpt2-xl --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 500 --pct_words_masked 0.3 --span_length 2 --do_top_p --dataset writing
48 | python run.py --output_name main_top_p --base_model_name EleutherAI/gpt-neo-2.7B --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 500 --pct_words_masked 0.3 --span_length 2 --do_top_p --dataset writing
49 | python run.py --output_name main_top_p --base_model_name EleutherAI/gpt-j-6B --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 500 --pct_words_masked 0.3 --span_length 2 --do_top_p --dataset writing
50 | python run.py --output_name main_top_p --base_model_name facebook/opt-2.7b --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 500 --pct_words_masked 0.3 --span_length 2 --do_top_p --dataset writing
51 | python run.py --output_name main_top_p --batch_size 20 --base_model_name EleutherAI/gpt-neox-20b --mask_filling_model_name t5-11b --n_perturbation_list 1,10,100 --n_samples 500 --pct_words_masked 0.3 --span_length 2 --do_top_p --dataset writing
52 | 
53 | python run.py --output_name main_top_k --base_model_name gpt2-xl --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 500 --pct_words_masked 0.3 --span_length 2 --do_top_k --dataset writing
54 | python run.py --output_name main_top_k --base_model_name EleutherAI/gpt-neo-2.7B --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 500 --pct_words_masked 0.3 --span_length 2 --do_top_k --dataset writing
55 | python run.py --output_name main_top_k --base_model_name EleutherAI/gpt-j-6B --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 500 --pct_words_masked 0.3 --span_length 2 --do_top_k --dataset writing
56 | python run.py --output_name main_top_k --base_model_name facebook/opt-2.7b --mask_filling_model_name t5-3b --n_perturbation_list 1,10,100 --n_samples 500 --pct_words_masked 0.3 --span_length 2 --do_top_k --dataset writing
57 | python run.py --output_name main_top_k --batch_size 20 --base_model_name EleutherAI/gpt-neox-20b --mask_filling_model_name t5-11b --n_perturbation_list 1,10,100 --n_samples 500 --pct_words_masked 0.3 --span_length 2 --do_top_k --dataset writing
58 | 
--------------------------------------------------------------------------------
/paper_scripts/n_perturb.sh:
--------------------------------------------------------------------------------
1 | python run.py --output_name n_perturb --base_model_name gpt2-xl --mask_filling_model_name t5-large --n_perturbation_list 1,10,100,1000 --n_samples 100 --pct_words_masked 0.3 --span_length 2
2 | python run.py --output_name n_perturb --base_model_name gpt2-xl --mask_filling_model_name t5-large --n_perturbation_list 1,10,100,1000 --n_samples 100 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
3 | python run.py --output_name n_perturb --base_model_name gpt2-xl --mask_filling_model_name t5-large --n_perturbation_list 1,10,100,1000 --n_samples 100 --pct_words_masked 0.3 --span_length 2 --dataset writing
4 | 
5 | python run.py --output_name n_perturb --base_model_name EleutherAI/gpt-j-6B --mask_filling_model_name t5-large --n_perturbation_list 1,10,100,1000 --n_samples 100 --pct_words_masked 0.3 --span_length 2
6 | python run.py --output_name n_perturb --base_model_name EleutherAI/gpt-j-6B --mask_filling_model_name t5-large --n_perturbation_list 1,10,100,1000 --n_samples 100 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
7 | python run.py --output_name n_perturb --base_model_name EleutherAI/gpt-j-6B --mask_filling_model_name t5-large --n_perturbation_list 1,10,100,1000 --n_samples 100 --pct_words_masked 0.3 --span_length 2 --dataset writing
8 | 
9 | python run.py --output_name n_perturb --base_model_name EleutherAI/gpt-neo-2.7B --mask_filling_model_name t5-large --n_perturbation_list 1,10,100,1000 --n_samples 100 --pct_words_masked 0.3 --span_length 2
10 | python run.py --output_name n_perturb --base_model_name EleutherAI/gpt-neo-2.7B --mask_filling_model_name t5-large --n_perturbation_list 1,10,100,1000 --n_samples 100 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
11 | python run.py --output_name n_perturb --base_model_name EleutherAI/gpt-neo-2.7B --mask_filling_model_name t5-large --n_perturbation_list 1,10,100,1000 --n_samples 100 --pct_words_masked 0.3 --span_length 2 --dataset writing
12 | 
13 | python run.py --output_name n_perturb --base_model_name facebook/opt-2.7b --mask_filling_model_name t5-large --n_perturbation_list 1,10,100,1000 --n_samples 100 --pct_words_masked 0.3 --span_length 2
14 | python run.py --output_name n_perturb --base_model_name facebook/opt-2.7b --mask_filling_model_name t5-large --n_perturbation_list 1,10,100,1000 --n_samples 100 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
15 | python run.py --output_name n_perturb --base_model_name facebook/opt-2.7b --mask_filling_model_name t5-large --n_perturbation_list 1,10,100,1000 --n_samples 100 --pct_words_masked 0.3 --span_length 2 --dataset writing
16 | 
--------------------------------------------------------------------------------
/paper_scripts/scale.sh:
--------------------------------------------------------------------------------
1 | python run.py --output_name scale --base_model_name gpt2-xl --mask_filling_model_name t5-small --n_perturbation_list 5,25 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context --random_fills --random_fills_tokens
2 | python run.py --output_name scale --base_model_name gpt2-xl --mask_filling_model_name t5-small --n_perturbation_list 5,25 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
3 | python run.py --output_name scale --base_model_name gpt2-xl --mask_filling_model_name t5-base --n_perturbation_list 5,25 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
4 | python run.py --output_name scale --base_model_name gpt2-xl --mask_filling_model_name t5-large --n_perturbation_list 5,25 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
5 | python run.py --output_name scale --base_model_name gpt2-xl --mask_filling_model_name t5-3b --n_perturbation_list 5,25 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
6 | 
7 | python run.py --output_name scale --base_model_name gpt2-large --mask_filling_model_name t5-small --n_perturbation_list 5,25 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context --random_fills --random_fills_tokens
8 | python run.py --output_name scale --base_model_name gpt2-large --mask_filling_model_name t5-small --n_perturbation_list 5,25 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
9 | python run.py --output_name scale --base_model_name gpt2-large --mask_filling_model_name t5-base --n_perturbation_list 5,25 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
10 | python run.py --output_name scale --base_model_name gpt2-large --mask_filling_model_name t5-large --n_perturbation_list 5,25 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
11 | python run.py --output_name scale --base_model_name gpt2-large --mask_filling_model_name t5-3b --n_perturbation_list 5,25 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
12 | 
13 | python run.py --output_name scale --base_model_name gpt2-medium --mask_filling_model_name t5-small --n_perturbation_list 5,25 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context --random_fills --random_fills_tokens
14 | python run.py --output_name scale --base_model_name gpt2-medium --mask_filling_model_name t5-small --n_perturbation_list 5,25 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
15 | python run.py --output_name scale --base_model_name gpt2-medium --mask_filling_model_name t5-base --n_perturbation_list 5,25 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
16 | python run.py --output_name scale --base_model_name gpt2-medium --mask_filling_model_name t5-large --n_perturbation_list 5,25 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
17 | python run.py --output_name scale --base_model_name gpt2-medium --mask_filling_model_name t5-3b --n_perturbation_list 5,25 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
18 | 
19 | python run.py --output_name scale --base_model_name gpt2 --mask_filling_model_name t5-small --n_perturbation_list 5,25 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context --random_fills --random_fills_tokens
20 | python run.py --output_name scale --base_model_name gpt2 --mask_filling_model_name t5-small --n_perturbation_list 5,25 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
21 | python run.py --output_name scale --base_model_name gpt2 --mask_filling_model_name t5-base --n_perturbation_list 5,25 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
22 | python run.py --output_name scale --base_model_name gpt2 --mask_filling_model_name t5-large --n_perturbation_list 5,25 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
23 | python run.py --output_name scale --base_model_name gpt2 --mask_filling_model_name t5-3b --n_perturbation_list 5,25 --n_samples 312 --pct_words_masked 0.3 --span_length 2 --dataset squad --dataset_key context
24 | 
25 | 
--------------------------------------------------------------------------------
/paper_scripts/supervised.sh:
--------------------------------------------------------------------------------
1 | python run.py --output_name supervised --base_model_name sberbank-ai/mGPT --mask_filling_model_name google/mt5-xl --n_perturbation_list 1,10,100 --n_samples 200 --pct_words_masked 0.3 --span_length 2 --dataset english
2 | python run.py --output_name supervised --base_model_name sberbank-ai/mGPT --mask_filling_model_name google/mt5-xl --n_perturbation_list 1,10,100 --n_samples 200 --pct_words_masked 0.3 --span_length 2 --dataset german
3 | python run.py --output_name supervised --base_model_name stanford-crfm/pubmedgpt --mask_filling_model_name t5-11b --n_perturbation_list 1,10,100 --n_samples 200 --pct_words_masked 0.3 --span_length 2 --dataset pubmed
4 | python run.py --output_name supervised --base_model_name gpt2-xl --mask_filling_model_name t5-11b --n_perturbation_list 1,10,100 --n_samples 200 --pct_words_masked 0.3 --span_length 2
5 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | numpy
3 | transformers
4 | datasets
5 | matplotlib
6 | tqdm
7 | scikit-learn
8 | openai
9 | 
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import datasets
4 | import transformers
5 | import re
6 | import torch
7 | import torch.nn.functional as F
8 | import tqdm
9 | import random
10 | from sklearn.metrics import roc_curve, precision_recall_curve, auc
11 | import argparse
12 | import datetime
13 | import os
14 | import json
15 | import functools
16 | import custom_datasets
17 | from multiprocessing.pool import ThreadPool
18 | import time
19 | 
20 | 
21 | 
22 | # 15 colorblind-friendly colors
23 | COLORS = ["#0072B2", "#009E73", "#D55E00", "#CC79A7", "#F0E442",
24 |           "#56B4E9", "#E69F00", "#000000", "#0072B2", "#009E73",
25 |           "#D55E00", "#CC79A7", "#F0E442", "#56B4E9", "#E69F00"]
26 | 
27 | # define regex to match all <extra_id_*> tokens, where * is an integer
28 | pattern = re.compile(r"<extra_id_\d+>")
29 | 
30 | 
31 | def load_base_model():
32 |     print('MOVING BASE MODEL TO GPU...', end='', flush=True)
33 |     start = time.time()
34 |     try:
35 |         mask_model.cpu()
36 |     except NameError:
37 |         pass
38 |     if args.openai_model is None:
39 |         base_model.to(DEVICE)
40 |     print(f'DONE ({time.time() - start:.2f}s)')
41 | 
42 | 
43 | def load_mask_model():
44 |     print('MOVING MASK MODEL TO GPU...', end='', flush=True)
45 |     start = time.time()
46 | 
47 |     if args.openai_model is None:
48 |         base_model.cpu()
49 |     if not args.random_fills:
50 |         mask_model.to(DEVICE)
51 |     print(f'DONE ({time.time() - start:.2f}s)')
52 | 
53 | 
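# A minimal sketch of how the two loaders above are interleaved (mirroring
# get_perturbation_results below); assumes the module-level globals (args,
# DEVICE, base_model, mask_model) have been initialized:
#
#     load_mask_model()            # base model -> CPU, T5 mask model -> GPU
#     perturbed = perturb_texts(texts, span_length=2, pct=0.3)
#     load_base_model()            # mask model -> CPU, base model -> GPU
#     lls = get_lls(perturbed)     # score perturbations under the base model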
54 | def tokenize_and_mask(text, span_length, pct, ceil_pct=False):
55 |     tokens = text.split(' ')
56 |     mask_string = '<<<mask>>>'
57 | 
58 |     n_spans = pct * len(tokens) / (span_length + args.buffer_size * 2)
59 |     if ceil_pct:
60 |         n_spans = np.ceil(n_spans)
61 |     n_spans = int(n_spans)
62 | 
63 |     n_masks = 0
64 |     while n_masks < n_spans:
65 |         start = np.random.randint(0, len(tokens) - span_length)
66 |         end = start + span_length
67 |         search_start = max(0, start - args.buffer_size)
68 |         search_end = min(len(tokens), end + args.buffer_size)
69 |         if mask_string not in tokens[search_start:search_end]:
70 |             tokens[start:end] = [mask_string]
71 |             n_masks += 1
72 | 
73 |     # replace each occurrence of mask_string with <extra_id_NUM>, where NUM increments
74 |     num_filled = 0
75 |     for idx, token in enumerate(tokens):
76 |         if token == mask_string:
77 |             tokens[idx] = f'<extra_id_{num_filled}>'
78 |             num_filled += 1
79 |     assert num_filled == n_masks, f"num_filled {num_filled} != n_masks {n_masks}"
80 |     text = ' '.join(tokens)
81 |     return text
82 | 
83 | 
84 | def count_masks(texts):
85 |     return [len([x for x in text.split() if x.startswith("<extra_id_")]) for text in texts]
86 | 
87 | 
88 | # replace each masked span with a sample from T5 mask_model
89 | def replace_masks(texts):
90 |     n_expected = count_masks(texts)
91 |     stop_id = mask_tokenizer.encode(f"<extra_id_{max(n_expected)}>")[0]
92 |     tokens = mask_tokenizer(texts, return_tensors="pt", padding=True).to(DEVICE)
93 |     outputs = mask_model.generate(**tokens, max_length=150, do_sample=True, top_p=args.mask_top_p, num_return_sequences=1, eos_token_id=stop_id)
94 |     return mask_tokenizer.batch_decode(outputs, skip_special_tokens=False)
95 | 
96 | 
97 | def extract_fills(texts):
98 |     # remove <pad> from beginning of each text
99 |     texts = [x.replace("<pad>", "").replace("</s>", "").strip() for x in texts]
100 | 
101 |     # return the text in between each matched mask token
102 |     extracted_fills = [pattern.split(x)[1:-1] for x in texts]
103 | 
104 |     # remove whitespace around each fill
105 |     extracted_fills = [[y.strip() for y in x] for x in extracted_fills]
106 | 
107 |     return extracted_fills
108 | 
109 | 
110 | def apply_extracted_fills(masked_texts, extracted_fills):
111 |     # split masked text into tokens, only splitting on spaces (not newlines)
112 |     tokens = [x.split(' ') for x in masked_texts]
113 | 
114 |     n_expected = count_masks(masked_texts)
115 | 
116 |     # replace each mask token with the corresponding fill
117 |     for idx, (text, fills, n) in enumerate(zip(tokens, extracted_fills, n_expected)):
118 |         if len(fills) < n:
119 |             tokens[idx] = []
120 |         else:
121 |             for fill_idx in range(n):
122 |                 text[text.index(f"<extra_id_{fill_idx}>")] = fills[fill_idx]
123 | 
124 |     # join tokens back into text
125 |     texts = [" ".join(x) for x in tokens]
126 |     return texts
127 | 
128 | 
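# Worked round trip through the four helpers above (the fill is hypothetical;
# replace_masks samples from T5, and very short inputs can receive zero masks
# because n_spans scales with passage length):
#
#     masked = 'The quick <extra_id_0> jumps over the lazy dog'
#     replace_masks([masked])    # -> ['<pad><extra_id_0> red fox<extra_id_1></s>']
#     extract_fills(...)         # -> [['red fox']]
#     apply_extracted_fills([masked], [['red fox']])
#                                # -> ['The quick red fox jumps over the lazy dog']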
129 | def perturb_texts_(texts, span_length, pct, ceil_pct=False):
130 |     if not args.random_fills:
131 |         masked_texts = [tokenize_and_mask(x, span_length, pct, ceil_pct) for x in texts]
132 |         raw_fills = replace_masks(masked_texts)
133 |         extracted_fills = extract_fills(raw_fills)
134 |         perturbed_texts = apply_extracted_fills(masked_texts, extracted_fills)
135 | 
136 |         # Handle the fact that sometimes the model doesn't generate the right number of fills and we have to try again
137 |         attempts = 1
138 |         while '' in perturbed_texts:
139 |             idxs = [idx for idx, x in enumerate(perturbed_texts) if x == '']
140 |             print(f'WARNING: {len(idxs)} texts have no fills. Trying again [attempt {attempts}].')
141 |             masked_texts = [tokenize_and_mask(x, span_length, pct, ceil_pct) for idx, x in enumerate(texts) if idx in idxs]
142 |             raw_fills = replace_masks(masked_texts)
143 |             extracted_fills = extract_fills(raw_fills)
144 |             new_perturbed_texts = apply_extracted_fills(masked_texts, extracted_fills)
145 |             for idx, x in zip(idxs, new_perturbed_texts):
146 |                 perturbed_texts[idx] = x
147 |             attempts += 1
148 |     else:
149 |         if args.random_fills_tokens:
150 |             # tokenize base_tokenizer
151 |             tokens = base_tokenizer(texts, return_tensors="pt", padding=True).to(DEVICE)
152 |             valid_tokens = tokens.input_ids != base_tokenizer.pad_token_id
153 |             replace_pct = args.pct_words_masked * (args.span_length / (args.span_length + 2 * args.buffer_size))
154 | 
155 |             # replace replace_pct of input_ids with random tokens
156 |             random_mask = torch.rand(tokens.input_ids.shape, device=DEVICE) < replace_pct
157 |             random_mask &= valid_tokens
158 |             random_tokens = torch.randint(0, base_tokenizer.vocab_size, (random_mask.sum(),), device=DEVICE)
159 |             # while any of the random tokens are special tokens, replace them with random non-special tokens
160 |             while any(base_tokenizer.decode(x) in base_tokenizer.all_special_tokens for x in random_tokens):
161 |                 random_tokens = torch.randint(0, base_tokenizer.vocab_size, (random_mask.sum(),), device=DEVICE)
162 |             tokens.input_ids[random_mask] = random_tokens
163 |             perturbed_texts = base_tokenizer.batch_decode(tokens.input_ids, skip_special_tokens=True)
164 |         else:
165 |             masked_texts = [tokenize_and_mask(x, span_length, pct, ceil_pct) for x in texts]
166 |             perturbed_texts = masked_texts
167 |             # replace each <extra_id_*> with args.span_length random words from FILL_DICTIONARY
168 |             for idx, text in enumerate(perturbed_texts):
169 |                 filled_text = text
170 |                 for fill_idx in range(count_masks([text])[0]):
171 |                     fill = random.sample(FILL_DICTIONARY, span_length)
172 |                     filled_text = filled_text.replace(f"<extra_id_{fill_idx}>", " ".join(fill))
173 |                 assert count_masks([filled_text])[0] == 0, "Failed to replace all masks"
174 |                 perturbed_texts[idx] = filled_text
175 | 
176 |     return perturbed_texts
177 | 
178 | 
179 | def perturb_texts(texts, span_length, pct, ceil_pct=False):
180 |     chunk_size = args.chunk_size
181 |     if '11b' in mask_filling_model_name:
182 |         chunk_size //= 2
183 | 
184 |     outputs = []
185 |     for i in tqdm.tqdm(range(0, len(texts), chunk_size), desc="Applying perturbations"):
186 |         outputs.extend(perturb_texts_(texts[i:i + chunk_size], span_length, pct, ceil_pct=ceil_pct))
187 |     return outputs
188 | 
189 | 
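# These perturbations feed the DetectGPT statistic (computed in
# run_perturbation_experiment below). For a passage x with perturbed copies
# x_1..x_n produced by perturb_texts:
#
#     'd' criterion:  d(x) = log p(x) - mean_i log p(x_i)
#     'z' criterion:  z(x) = d(x) / std_i log p(x_i)
#
# Model-generated text tends to sit near a local maximum of log p, so d is
# typically larger for base-model samples than for human-written text;
# thresholding d or z gives the zero-shot detector.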
190 | def drop_last_word(text):
191 |     return ' '.join(text.split(' ')[:-1])
192 | 
193 | 
194 | def _openai_sample(p):
195 |     if args.dataset != 'pubmed':  # keep Answer: prefix for pubmed
196 |         p = drop_last_word(p)
197 | 
198 |     # sample from the openai model
199 |     kwargs = { "engine": args.openai_model, "max_tokens": 200 }
200 |     if args.do_top_p:
201 |         kwargs['top_p'] = args.top_p
202 | 
203 |     r = openai.Completion.create(prompt=f"{p}", **kwargs)
204 |     return p + r['choices'][0].text
205 | 
206 | 
207 | # sample from base_model using ****only**** the first 30 tokens in each example as context
208 | def sample_from_model(texts, min_words=55, prompt_tokens=30):
209 |     # encode each text as a list of token ids
210 |     if args.dataset == 'pubmed':
211 |         texts = [t[:t.index(custom_datasets.SEPARATOR)] for t in texts]
212 |         all_encoded = base_tokenizer(texts, return_tensors="pt", padding=True).to(DEVICE)
213 |     else:
214 |         all_encoded = base_tokenizer(texts, return_tensors="pt", padding=True).to(DEVICE)
215 |         all_encoded = {key: value[:, :prompt_tokens] for key, value in all_encoded.items()}
216 | 
217 |     if args.openai_model:
218 |         # decode the prefixes back into text
219 |         prefixes = base_tokenizer.batch_decode(all_encoded['input_ids'], skip_special_tokens=True)
220 |         pool = ThreadPool(args.batch_size)
221 | 
222 |         decoded = pool.map(_openai_sample, prefixes)
223 |     else:
224 |         decoded = ['' for _ in range(len(texts))]
225 | 
226 |         # sample from the model until we get a sample with at least min_words words for each example
227 |         # this is an inefficient way to do this (since we regenerate for all inputs if just one is too short), but it works
228 |         tries = 0
229 |         while (m := min(len(x.split()) for x in decoded)) < min_words:
230 |             if tries != 0:
231 |                 print()
232 |                 print(f"min words: {m}, needed {min_words}, regenerating (try {tries})")
233 | 
234 |             sampling_kwargs = {}
235 |             if args.do_top_p:
236 |                 sampling_kwargs['top_p'] = args.top_p
237 |             elif args.do_top_k:
238 |                 sampling_kwargs['top_k'] = args.top_k
239 |             min_length = 50 if args.dataset in ['pubmed'] else 150
240 |             outputs = base_model.generate(**all_encoded, min_length=min_length, max_length=200, do_sample=True, **sampling_kwargs, pad_token_id=base_tokenizer.eos_token_id, eos_token_id=base_tokenizer.eos_token_id)
241 |             decoded = base_tokenizer.batch_decode(outputs, skip_special_tokens=True)
242 |             tries += 1
243 | 
244 |     if args.openai_model:
245 |         global API_TOKEN_COUNTER
246 | 
247 |         # count total number of tokens with GPT2_TOKENIZER
248 |         total_tokens = sum(len(GPT2_TOKENIZER.encode(x)) for x in decoded)
249 |         API_TOKEN_COUNTER += total_tokens
250 | 
251 |     return decoded
252 | 
253 | 
254 | def get_likelihood(logits, labels):
255 |     assert logits.shape[0] == 1
256 |     assert labels.shape[0] == 1
257 | 
258 |     logits = logits.view(-1, logits.shape[-1])[:-1]
259 |     labels = labels.view(-1)[1:]
260 |     log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
261 |     log_likelihood = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
262 |     return log_likelihood.mean()
263 | 
264 | 
265 | # Get the log likelihood of each text under the base_model
266 | def get_ll(text):
267 |     if args.openai_model:
268 |         kwargs = { "engine": args.openai_model, "temperature": 0, "max_tokens": 0, "echo": True, "logprobs": 0}
269 |         r = openai.Completion.create(prompt=f"<|endoftext|>{text}", **kwargs)
270 |         result = r['choices'][0]
271 |         tokens, logprobs = result["logprobs"]["tokens"][1:], result["logprobs"]["token_logprobs"][1:]
272 | 
273 |         assert len(tokens) == len(logprobs), f"Expected {len(tokens)} logprobs, got {len(logprobs)}"
274 | 
275 |         return np.mean(logprobs)
276 |     else:
277 |         with torch.no_grad():
278 |             tokenized = base_tokenizer(text, return_tensors="pt").to(DEVICE)
279 |             labels = tokenized.input_ids
280 |             return -base_model(**tokenized, labels=labels).loss.item()
281 | 
282 | 
283 | def get_lls(texts):
284 |     if not args.openai_model:
285 |         return [get_ll(text) for text in texts]
286 |     else:
287 |         global API_TOKEN_COUNTER
288 | 
289 |         # use GPT2_TOKENIZER to get total number of tokens
290 |         total_tokens = sum(len(GPT2_TOKENIZER.encode(text)) for text in texts)
291 |         API_TOKEN_COUNTER += total_tokens * 2  # multiply by two because OpenAI double-counts echo_prompt tokens
292 | 
293 |         pool = ThreadPool(args.batch_size)
294 |         return pool.map(get_ll, texts)
295 | 
296 | 
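# Note: in the local-model branch, get_ll returns the mean per-token
# log-likelihood (the negated LM loss), matching the np.mean(logprobs) of the
# OpenAI branch, so scores are comparable across passage lengths.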
297 | # get the average rank of each observed token sorted by model likelihood
298 | def get_rank(text, log=False):
299 |     assert args.openai_model is None, "get_rank not implemented for OpenAI models"
300 | 
301 |     with torch.no_grad():
302 |         tokenized = base_tokenizer(text, return_tensors="pt").to(DEVICE)
303 |         logits = base_model(**tokenized).logits[:,:-1]
304 |         labels = tokenized.input_ids[:,1:]
305 | 
306 |         # get rank of each label token in the model's likelihood ordering
307 |         matches = (logits.argsort(-1, descending=True) == labels.unsqueeze(-1)).nonzero()
308 | 
309 |         assert matches.shape[1] == 3, f"Expected 3 dimensions in matches tensor, got {matches.shape}"
310 | 
311 |         ranks, timesteps = matches[:,-1], matches[:,-2]
312 | 
313 |         # make sure we got exactly one match for each timestep in the sequence
314 |         assert (timesteps == torch.arange(len(timesteps)).to(timesteps.device)).all(), "Expected one match per timestep"
315 | 
316 |         ranks = ranks.float() + 1  # convert to 1-indexed rank
317 |         if log:
318 |             ranks = torch.log(ranks)
319 | 
320 |         return ranks.float().mean().item()
321 | 
322 | 
323 | # get average entropy of each token in the text
324 | def get_entropy(text):
325 |     assert args.openai_model is None, "get_entropy not implemented for OpenAI models"
326 | 
327 |     with torch.no_grad():
328 |         tokenized = base_tokenizer(text, return_tensors="pt").to(DEVICE)
329 |         logits = base_model(**tokenized).logits[:,:-1]
330 |         neg_entropy = F.softmax(logits, dim=-1) * F.log_softmax(logits, dim=-1)
331 |         return -neg_entropy.sum(-1).mean().item()
332 | 
333 | 
334 | def get_roc_metrics(real_preds, sample_preds):
335 |     fpr, tpr, _ = roc_curve([0] * len(real_preds) + [1] * len(sample_preds), real_preds + sample_preds)
336 |     roc_auc = auc(fpr, tpr)
337 |     return fpr.tolist(), tpr.tolist(), float(roc_auc)
338 | 
339 | 
340 | def get_precision_recall_metrics(real_preds, sample_preds):
341 |     precision, recall, _ = precision_recall_curve([0] * len(real_preds) + [1] * len(sample_preds), real_preds + sample_preds)
342 |     pr_auc = auc(recall, precision)
343 |     return precision.tolist(), recall.tolist(), float(pr_auc)
344 | 
345 | 
346 | # save the ROC curve for each experiment, given a list of output dictionaries, one for each experiment, using colorblind-friendly colors
347 | def save_roc_curves(experiments):
348 |     # first, clear plt
349 |     plt.clf()
350 | 
351 |     for experiment, color in zip(experiments, COLORS):
352 |         metrics = experiment["metrics"]
353 |         plt.plot(metrics["fpr"], metrics["tpr"], label=f"{experiment['name']}, roc_auc={metrics['roc_auc']:.3f}", color=color)
354 |         # print roc_auc for this experiment
355 |         print(f"{experiment['name']} roc_auc: {metrics['roc_auc']:.3f}")
356 |     plt.plot([0, 1], [0, 1], color='black', lw=2, linestyle='--')
357 |     plt.xlim([0.0, 1.0])
358 |     plt.ylim([0.0, 1.05])
359 |     plt.xlabel('False Positive Rate')
360 |     plt.ylabel('True Positive Rate')
361 |     plt.title(f'ROC Curves ({base_model_name} - {args.mask_filling_model_name})')
362 |     plt.legend(loc="lower right", fontsize=6)
363 |     plt.savefig(f"{SAVE_FOLDER}/roc_curves.png")
364 | 
365 | 
366 | # save the histogram of log likelihoods in two side-by-side plots, one for real and real perturbed, and one for sampled and sampled perturbed
367 | def save_ll_histograms(experiments):
368 |     # first, clear plt
369 |     plt.clf()
370 | 
371 |     for experiment in experiments:
372 |         try:
373 |             results = experiment["raw_results"]
374 |             # plot histogram of sampled/perturbed sampled on left, original/perturbed original on right
375 |             plt.figure(figsize=(20, 6))
376 |             plt.subplot(1, 2, 1)
377 |             plt.hist([r["sampled_ll"] for r in results], alpha=0.5, bins='auto', label='sampled')
378 |             plt.hist([r["perturbed_sampled_ll"] for r in results], alpha=0.5, bins='auto', label='perturbed sampled')
379 |             plt.xlabel("log likelihood")
380 |             plt.ylabel('count')
381 |             plt.legend(loc='upper right')
382 |             plt.subplot(1, 2, 2)
383 |             plt.hist([r["original_ll"] for r in results], alpha=0.5, bins='auto', label='original')
384 |             plt.hist([r["perturbed_original_ll"] for r in results], alpha=0.5, bins='auto', label='perturbed original')
385 |             plt.xlabel("log likelihood")
386 |             plt.ylabel('count')
387 |             plt.legend(loc='upper right')
388 |             plt.savefig(f"{SAVE_FOLDER}/ll_histograms_{experiment['name']}.png")
389 |         except Exception:
390 |             pass
391 | 
392 | 
393 | # save the histograms of log likelihood ratios in two side-by-side plots, one for real and real perturbed, and one for sampled and sampled perturbed
394 | def save_llr_histograms(experiments):
395 |     # first, clear plt
396 |     plt.clf()
397 | 
398 |     for experiment in experiments:
399 |         try:
400 |             results = experiment["raw_results"]
401 |             # plot histogram of sampled/perturbed sampled on left, original/perturbed original on right
402 |             plt.figure(figsize=(20, 6))
403 |             plt.subplot(1, 2, 1)
404 | 
405 |             # compute the log likelihood ratio for each result
406 |             for r in results:
407 |                 r["sampled_llr"] = r["sampled_ll"] - r["perturbed_sampled_ll"]
408 |                 r["original_llr"] = r["original_ll"] - r["perturbed_original_ll"]
409 | 
410 |             plt.hist([r["sampled_llr"] for r in results], alpha=0.5, bins='auto', label='sampled')
411 |             plt.hist([r["original_llr"] for r in results], alpha=0.5, bins='auto', label='original')
412 |             plt.xlabel("log likelihood ratio")
413 |             plt.ylabel('count')
414 |             plt.legend(loc='upper right')
415 |             plt.savefig(f"{SAVE_FOLDER}/llr_histograms_{experiment['name']}.png")
416 |         except Exception:
417 |             pass
418 | 
419 | 
res["all_perturbed_sampled_ll"] = p_sampled_ll 459 | res["all_perturbed_original_ll"] = p_original_ll 460 | res["perturbed_sampled_ll"] = np.mean(p_sampled_ll) 461 | res["perturbed_original_ll"] = np.mean(p_original_ll) 462 | res["perturbed_sampled_ll_std"] = np.std(p_sampled_ll) if len(p_sampled_ll) > 1 else 1 463 | res["perturbed_original_ll_std"] = np.std(p_original_ll) if len(p_original_ll) > 1 else 1 464 | 465 | return results 466 | 467 | 468 | def run_perturbation_experiment(results, criterion, span_length=10, n_perturbations=1, n_samples=500): 469 | # compute diffs with perturbed 470 | predictions = {'real': [], 'samples': []} 471 | for res in results: 472 | if criterion == 'd': 473 | predictions['real'].append(res['original_ll'] - res['perturbed_original_ll']) 474 | predictions['samples'].append(res['sampled_ll'] - res['perturbed_sampled_ll']) 475 | elif criterion == 'z': 476 | if res['perturbed_original_ll_std'] == 0: 477 | res['perturbed_original_ll_std'] = 1 478 | print("WARNING: std of perturbed original is 0, setting to 1") 479 | print(f"Number of unique perturbed original texts: {len(set(res['perturbed_original']))}") 480 | print(f"Original text: {res['original']}") 481 | if res['perturbed_sampled_ll_std'] == 0: 482 | res['perturbed_sampled_ll_std'] = 1 483 | print("WARNING: std of perturbed sampled is 0, setting to 1") 484 | print(f"Number of unique perturbed sampled texts: {len(set(res['perturbed_sampled']))}") 485 | print(f"Sampled text: {res['sampled']}") 486 | predictions['real'].append((res['original_ll'] - res['perturbed_original_ll']) / res['perturbed_original_ll_std']) 487 | predictions['samples'].append((res['sampled_ll'] - res['perturbed_sampled_ll']) / res['perturbed_sampled_ll_std']) 488 | 489 | fpr, tpr, roc_auc = get_roc_metrics(predictions['real'], predictions['samples']) 490 | p, r, pr_auc = get_precision_recall_metrics(predictions['real'], predictions['samples']) 491 | name = f'perturbation_{n_perturbations}_{criterion}' 492 | print(f"{name} ROC AUC: {roc_auc}, PR AUC: {pr_auc}") 493 | return { 494 | 'name': name, 495 | 'predictions': predictions, 496 | 'info': { 497 | 'pct_words_masked': args.pct_words_masked, 498 | 'span_length': span_length, 499 | 'n_perturbations': n_perturbations, 500 | 'n_samples': n_samples, 501 | }, 502 | 'raw_results': results, 503 | 'metrics': { 504 | 'roc_auc': roc_auc, 505 | 'fpr': fpr, 506 | 'tpr': tpr, 507 | }, 508 | 'pr_metrics': { 509 | 'pr_auc': pr_auc, 510 | 'precision': p, 511 | 'recall': r, 512 | }, 513 | 'loss': 1 - pr_auc, 514 | } 515 | 516 | 517 | def run_baseline_threshold_experiment(criterion_fn, name, n_samples=500): 518 | torch.manual_seed(0) 519 | np.random.seed(0) 520 | 521 | results = [] 522 | for batch in tqdm.tqdm(range(n_samples // batch_size), desc=f"Computing {name} criterion"): 523 | original_text = data["original"][batch * batch_size:(batch + 1) * batch_size] 524 | sampled_text = data["sampled"][batch * batch_size:(batch + 1) * batch_size] 525 | 526 | for idx in range(len(original_text)): 527 | results.append({ 528 | "original": original_text[idx], 529 | "original_crit": criterion_fn(original_text[idx]), 530 | "sampled": sampled_text[idx], 531 | "sampled_crit": criterion_fn(sampled_text[idx]), 532 | }) 533 | 534 | # compute prediction scores for real/sampled passages 535 | predictions = { 536 | 'real': [x["original_crit"] for x in results], 537 | 'samples': [x["sampled_crit"] for x in results], 538 | } 539 | 540 | fpr, tpr, roc_auc = get_roc_metrics(predictions['real'], predictions['samples']) 541 | p, r, 
517 | def run_baseline_threshold_experiment(criterion_fn, name, n_samples=500):
518 |     torch.manual_seed(0)
519 |     np.random.seed(0)
520 | 
521 |     results = []
522 |     for batch in tqdm.tqdm(range(n_samples // batch_size), desc=f"Computing {name} criterion"):
523 |         original_text = data["original"][batch * batch_size:(batch + 1) * batch_size]
524 |         sampled_text = data["sampled"][batch * batch_size:(batch + 1) * batch_size]
525 | 
526 |         for idx in range(len(original_text)):
527 |             results.append({
528 |                 "original": original_text[idx],
529 |                 "original_crit": criterion_fn(original_text[idx]),
530 |                 "sampled": sampled_text[idx],
531 |                 "sampled_crit": criterion_fn(sampled_text[idx]),
532 |             })
533 | 
534 |     # compute prediction scores for real/sampled passages
535 |     predictions = {
536 |         'real': [x["original_crit"] for x in results],
537 |         'samples': [x["sampled_crit"] for x in results],
538 |     }
539 | 
540 |     fpr, tpr, roc_auc = get_roc_metrics(predictions['real'], predictions['samples'])
541 |     p, r, pr_auc = get_precision_recall_metrics(predictions['real'], predictions['samples'])
542 |     print(f"{name}_threshold ROC AUC: {roc_auc}, PR AUC: {pr_auc}")
543 |     return {
544 |         'name': f'{name}_threshold',
545 |         'predictions': predictions,
546 |         'info': {
547 |             'n_samples': n_samples,
548 |         },
549 |         'raw_results': results,
550 |         'metrics': {
551 |             'roc_auc': roc_auc,
552 |             'fpr': fpr,
553 |             'tpr': tpr,
554 |         },
555 |         'pr_metrics': {
556 |             'pr_auc': pr_auc,
557 |             'precision': p,
558 |             'recall': r,
559 |         },
560 |         'loss': 1 - pr_auc,
561 |     }
562 | 
563 | 
564 | # strip newlines from each example; replace one or more newlines with a single space
565 | def strip_newlines(text):
566 |     return ' '.join(text.split())
567 | 
568 | 
569 | # trim to shorter length
570 | def trim_to_shorter_length(texta, textb):
571 |     # truncate to shorter of o and s
572 |     shorter_length = min(len(texta.split(' ')), len(textb.split(' ')))
573 |     texta = ' '.join(texta.split(' ')[:shorter_length])
574 |     textb = ' '.join(textb.split(' ')[:shorter_length])
575 |     return texta, textb
576 | 
577 | 
578 | def truncate_to_substring(text, substring, idx_occurrence):
579 |     # truncate everything after the idx_occurrence occurrence of substring
580 |     assert idx_occurrence > 0, 'idx_occurrence must be > 0'
581 |     idx = -1
582 |     for _ in range(idx_occurrence):
583 |         idx = text.find(substring, idx + 1)
584 |         if idx == -1:
585 |             return text
586 |     return text[:idx]
587 | 
588 | 
['writing', 'squad', 'xsum']: 645 | long_data = [x for x in data if len(x.split()) > 250] 646 | if len(long_data) > 0: 647 | data = long_data 648 | 649 | random.seed(0) 650 | random.shuffle(data) 651 | 652 | data = data[:5_000] 653 | 654 | # keep only examples with <= 512 tokens according to mask_tokenizer 655 | # this step has the extra effect of removing examples with low-quality/garbage content 656 | tokenized_data = preproc_tokenizer(data) 657 | data = [x for x, y in zip(data, tokenized_data["input_ids"]) if len(y) <= 512] 658 | 659 | # print stats about remaining data 660 | print(f"Total number of samples: {len(data)}") 661 | print(f"Average number of words: {np.mean([len(x.split()) for x in data])}") 662 | 663 | return generate_samples(data[:n_samples], batch_size=batch_size) 664 | 665 | 666 | def load_base_model_and_tokenizer(name): 667 | if args.openai_model is None: 668 | print(f'Loading BASE model {args.base_model_name}...') 669 | base_model_kwargs = {} 670 | if 'gpt-j' in name or 'neox' in name: 671 | base_model_kwargs.update(dict(torch_dtype=torch.float16)) 672 | if 'gpt-j' in name: 673 | base_model_kwargs.update(dict(revision='float16')) 674 | base_model = transformers.AutoModelForCausalLM.from_pretrained(name, **base_model_kwargs, cache_dir=cache_dir) 675 | else: 676 | base_model = None 677 | 678 | optional_tok_kwargs = {} 679 | if "facebook/opt-" in name: 680 | print("Using non-fast tokenizer for OPT") 681 | optional_tok_kwargs['use_fast'] = False # AutoTokenizer's kwarg for selecting the slow (non-fast) tokenizer 682 | if args.dataset in ['pubmed']: 683 | optional_tok_kwargs['padding_side'] = 'left' 684 | base_tokenizer = transformers.AutoTokenizer.from_pretrained(name, **optional_tok_kwargs, cache_dir=cache_dir) 685 | base_tokenizer.pad_token_id = base_tokenizer.eos_token_id 686 | 687 | return base_model, base_tokenizer 688 | 689 | 690 | def eval_supervised(data, model): 691 | print(f'Beginning supervised evaluation with {model}...') 692 | detector = transformers.AutoModelForSequenceClassification.from_pretrained(model, cache_dir=cache_dir).to(DEVICE) 693 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, cache_dir=cache_dir) 694 | 695 | real, fake = data['original'], data['sampled'] 696 | 697 | with torch.no_grad(): 698 | # get predictions for real 699 | real_preds = [] 700 | for batch in tqdm.tqdm(range(len(real) // batch_size), desc="Evaluating real"): 701 | batch_real = real[batch * batch_size:(batch + 1) * batch_size] 702 | batch_real = tokenizer(batch_real, padding=True, truncation=True, max_length=512, return_tensors="pt").to(DEVICE) 703 | real_preds.extend(detector(**batch_real).logits.softmax(-1)[:,0].tolist()) 704 | 705 | # get predictions for fake 706 | fake_preds = [] 707 | for batch in tqdm.tqdm(range(len(fake) // batch_size), desc="Evaluating fake"): 708 | batch_fake = fake[batch * batch_size:(batch + 1) * batch_size] 709 | batch_fake = tokenizer(batch_fake, padding=True, truncation=True, max_length=512, return_tensors="pt").to(DEVICE) 710 | fake_preds.extend(detector(**batch_fake).logits.softmax(-1)[:,0].tolist()) 711 | 712 | predictions = { 713 | 'real': real_preds, 714 | 'samples': fake_preds, 715 | } 716 | 717 | fpr, tpr, roc_auc = get_roc_metrics(real_preds, fake_preds) 718 | p, r, pr_auc = get_precision_recall_metrics(real_preds, fake_preds) 719 | print(f"{model} ROC AUC: {roc_auc}, PR AUC: {pr_auc}") 720 | 721 | # free GPU memory 722 | del detector 723 | torch.cuda.empty_cache() 724 | 725 | return { 726 | 'name': model, 727 | 'predictions': predictions, 728 | 'info': { 729 | 'n_samples': n_samples, 730 | },
731 | 'metrics': { 732 | 'roc_auc': roc_auc, 733 | 'fpr': fpr, 734 | 'tpr': tpr, 735 | }, 736 | 'pr_metrics': { 737 | 'pr_auc': pr_auc, 738 | 'precision': p, 739 | 'recall': r, 740 | }, 741 | 'loss': 1 - pr_auc, 742 | } 743 | 744 | 745 | if __name__ == '__main__': 746 | DEVICE = "cuda" 747 | 748 | parser = argparse.ArgumentParser() 749 | parser.add_argument('--dataset', type=str, default="xsum") 750 | parser.add_argument('--dataset_key', type=str, default="document") 751 | parser.add_argument('--pct_words_masked', type=float, default=0.3) # pct masked is actually pct_words_masked * (span_length / (span_length + 2 * buffer_size)) 752 | parser.add_argument('--span_length', type=int, default=2) 753 | parser.add_argument('--n_samples', type=int, default=200) 754 | parser.add_argument('--n_perturbation_list', type=str, default="1,10") 755 | parser.add_argument('--n_perturbation_rounds', type=int, default=1) 756 | parser.add_argument('--base_model_name', type=str, default="gpt2-medium") 757 | parser.add_argument('--scoring_model_name', type=str, default="") 758 | parser.add_argument('--mask_filling_model_name', type=str, default="t5-large") 759 | parser.add_argument('--batch_size', type=int, default=50) 760 | parser.add_argument('--chunk_size', type=int, default=20) 761 | parser.add_argument('--n_similarity_samples', type=int, default=20) 762 | parser.add_argument('--int8', action='store_true') 763 | parser.add_argument('--half', action='store_true') 764 | parser.add_argument('--base_half', action='store_true') 765 | parser.add_argument('--do_top_k', action='store_true') 766 | parser.add_argument('--top_k', type=int, default=40) 767 | parser.add_argument('--do_top_p', action='store_true') 768 | parser.add_argument('--top_p', type=float, default=0.96) 769 | parser.add_argument('--output_name', type=str, default="") 770 | parser.add_argument('--openai_model', type=str, default=None) 771 | parser.add_argument('--openai_key', type=str) 772 | parser.add_argument('--baselines_only', action='store_true') 773 | parser.add_argument('--skip_baselines', action='store_true') 774 | parser.add_argument('--buffer_size', type=int, default=1) 775 | parser.add_argument('--mask_top_p', type=float, default=1.0) 776 | parser.add_argument('--pre_perturb_pct', type=float, default=0.0) 777 | parser.add_argument('--pre_perturb_span_length', type=int, default=5) 778 | parser.add_argument('--random_fills', action='store_true') 779 | parser.add_argument('--random_fills_tokens', action='store_true') 780 | parser.add_argument('--cache_dir', type=str, default="~/.cache") 781 | args = parser.parse_args() 782 | 783 | API_TOKEN_COUNTER = 0 784 | 785 | if args.openai_model is not None: 786 | import openai 787 | assert args.openai_key is not None, "Must provide OpenAI API key as --openai_key" 788 | openai.api_key = args.openai_key 789 | 790 | START_DATE = datetime.datetime.now().strftime('%Y-%m-%d') 791 | START_TIME = datetime.datetime.now().strftime('%H-%M-%S-%f') 792 | 793 | # define SAVE_FOLDER as the timestamp - base model name - mask filling model name 794 | # create it if it doesn't exist 795 | precision_string = "int8" if args.int8 else ("fp16" if args.half else "fp32") 796 | sampling_string = "top_k" if args.do_top_k else ("top_p" if args.do_top_p else "temp") 797 | output_subfolder = f"{args.output_name}/" if args.output_name else "" 798 | if args.openai_model is None: 799 | base_model_name = args.base_model_name.replace('/', '_') 800 | else: 801 | base_model_name = "openai-" + args.openai_model.replace('/', '_') 802 | 
scoring_model_string = (f"-{args.scoring_model_name}" if args.scoring_model_name else "").replace('/', '_') 803 | SAVE_FOLDER = f"tmp_results/{output_subfolder}{base_model_name}{scoring_model_string}-{args.mask_filling_model_name}-{sampling_string}/{START_DATE}-{START_TIME}-{precision_string}-{args.pct_words_masked}-{args.n_perturbation_rounds}-{args.dataset}-{args.n_samples}" 804 | if not os.path.exists(SAVE_FOLDER): 805 | os.makedirs(SAVE_FOLDER) 806 | print(f"Saving results to absolute path: {os.path.abspath(SAVE_FOLDER)}") 807 | 808 | # write args to file 809 | with open(os.path.join(SAVE_FOLDER, "args.json"), "w") as f: 810 | json.dump(args.__dict__, f, indent=4) 811 | 812 | mask_filling_model_name = args.mask_filling_model_name 813 | n_samples = args.n_samples 814 | batch_size = args.batch_size 815 | n_perturbation_list = [int(x) for x in args.n_perturbation_list.split(",")] 816 | n_perturbation_rounds = args.n_perturbation_rounds 817 | n_similarity_samples = args.n_similarity_samples 818 | 819 | cache_dir = args.cache_dir 820 | os.environ["XDG_CACHE_HOME"] = cache_dir 821 | if not os.path.exists(cache_dir): 822 | os.makedirs(cache_dir) 823 | print(f"Using cache dir {cache_dir}") 824 | 825 | GPT2_TOKENIZER = transformers.GPT2Tokenizer.from_pretrained('gpt2', cache_dir=cache_dir) 826 | 827 | # generic generative model 828 | base_model, base_tokenizer = load_base_model_and_tokenizer(args.base_model_name) 829 | 830 | # mask filling t5 model 831 | if not args.baselines_only and not args.random_fills: 832 | int8_kwargs = {} 833 | half_kwargs = {} 834 | if args.int8: 835 | int8_kwargs = dict(load_in_8bit=True, device_map='auto', torch_dtype=torch.bfloat16) 836 | elif args.half: 837 | half_kwargs = dict(torch_dtype=torch.bfloat16) 838 | print(f'Loading mask filling model {mask_filling_model_name}...') 839 | mask_model = transformers.AutoModelForSeq2SeqLM.from_pretrained(mask_filling_model_name, **int8_kwargs, **half_kwargs, cache_dir=cache_dir) 840 | try: 841 | n_positions = mask_model.config.n_positions 842 | except AttributeError: 843 | n_positions = 512 844 | else: 845 | n_positions = 512 846 | preproc_tokenizer = transformers.AutoTokenizer.from_pretrained('t5-small', model_max_length=512, cache_dir=cache_dir) 847 | mask_tokenizer = transformers.AutoTokenizer.from_pretrained(mask_filling_model_name, model_max_length=n_positions, cache_dir=cache_dir) 848 | if args.dataset in ['english', 'german']: 849 | preproc_tokenizer = mask_tokenizer 850 | 851 | load_base_model() 852 | 853 | print(f'Loading dataset {args.dataset}...') 854 | data = generate_data(args.dataset, args.dataset_key) 855 | if args.random_fills: 856 | FILL_DICTIONARY = set() 857 | for texts in data.values(): 858 | for text in texts: 859 | FILL_DICTIONARY.update(text.split()) 860 | FILL_DICTIONARY = sorted(list(FILL_DICTIONARY)) 861 | 862 | if args.scoring_model_name: 863 | print(f'Loading SCORING model {args.scoring_model_name}...') 864 | del base_model 865 | del base_tokenizer 866 | torch.cuda.empty_cache() 867 | base_model, base_tokenizer = load_base_model_and_tokenizer(args.scoring_model_name) 868 | load_base_model() # Load again because we've deleted/replaced the old model 869 | 870 | # write the data to a json file in the save folder 871 | with open(os.path.join(SAVE_FOLDER, "raw_data.json"), "w") as f: 872 | print(f"Writing raw data to {os.path.join(SAVE_FOLDER, 'raw_data.json')}") 873 | json.dump(data, f) 874 | 875 | if not args.skip_baselines: 876 | baseline_outputs = [run_baseline_threshold_experiment(get_ll, 
"likelihood", n_samples=n_samples)] 877 | if args.openai_model is None: 878 | rank_criterion = lambda text: -get_rank(text, log=False) 879 | baseline_outputs.append(run_baseline_threshold_experiment(rank_criterion, "rank", n_samples=n_samples)) 880 | logrank_criterion = lambda text: -get_rank(text, log=True) 881 | baseline_outputs.append(run_baseline_threshold_experiment(logrank_criterion, "log_rank", n_samples=n_samples)) 882 | entropy_criterion = lambda text: get_entropy(text) 883 | baseline_outputs.append(run_baseline_threshold_experiment(entropy_criterion, "entropy", n_samples=n_samples)) 884 | 885 | baseline_outputs.append(eval_supervised(data, model='roberta-base-openai-detector')) 886 | baseline_outputs.append(eval_supervised(data, model='roberta-large-openai-detector')) 887 | 888 | outputs = [] 889 | 890 | if not args.baselines_only: 891 | # run perturbation experiments 892 | for n_perturbations in n_perturbation_list: 893 | perturbation_results = get_perturbation_results(args.span_length, n_perturbations, n_samples) 894 | for perturbation_mode in ['d', 'z']: 895 | output = run_perturbation_experiment( 896 | perturbation_results, perturbation_mode, span_length=args.span_length, n_perturbations=n_perturbations, n_samples=n_samples) 897 | outputs.append(output) 898 | with open(os.path.join(SAVE_FOLDER, f"perturbation_{n_perturbations}_{perturbation_mode}_results.json"), "w") as f: 899 | json.dump(output, f) 900 | 901 | if not args.skip_baselines: 902 | # write likelihood threshold results to a file 903 | with open(os.path.join(SAVE_FOLDER, f"likelihood_threshold_results.json"), "w") as f: 904 | json.dump(baseline_outputs[0], f) 905 | 906 | if args.openai_model is None: 907 | # write rank threshold results to a file 908 | with open(os.path.join(SAVE_FOLDER, f"rank_threshold_results.json"), "w") as f: 909 | json.dump(baseline_outputs[1], f) 910 | 911 | # write log rank threshold results to a file 912 | with open(os.path.join(SAVE_FOLDER, f"logrank_threshold_results.json"), "w") as f: 913 | json.dump(baseline_outputs[2], f) 914 | 915 | # write entropy threshold results to a file 916 | with open(os.path.join(SAVE_FOLDER, f"entropy_threshold_results.json"), "w") as f: 917 | json.dump(baseline_outputs[3], f) 918 | 919 | # write supervised results to a file 920 | with open(os.path.join(SAVE_FOLDER, f"roberta-base-openai-detector_results.json"), "w") as f: 921 | json.dump(baseline_outputs[-2], f) 922 | 923 | # write supervised results to a file 924 | with open(os.path.join(SAVE_FOLDER, f"roberta-large-openai-detector_results.json"), "w") as f: 925 | json.dump(baseline_outputs[-1], f) 926 | 927 | outputs += baseline_outputs 928 | 929 | save_roc_curves(outputs) 930 | save_ll_histograms(outputs) 931 | save_llr_histograms(outputs) 932 | 933 | # move results folder from tmp_results/ to results/, making sure necessary directories exist 934 | new_folder = SAVE_FOLDER.replace("tmp_results", "results") 935 | if not os.path.exists(os.path.dirname(new_folder)): 936 | os.makedirs(os.path.dirname(new_folder)) 937 | os.rename(SAVE_FOLDER, new_folder) 938 | 939 | print(f"Used an *estimated* {API_TOKEN_COUNTER} API tokens (may be inaccurate)") --------------------------------------------------------------------------------