├── assets └── intro.png ├── requirements.txt ├── gen_data.sh ├── run.sh ├── utils.py ├── handcraft_datasets.py ├── README.md ├── custom_dataset.py ├── autopoison_datasets.py ├── LICENSE ├── main.py └── eval_metrics.py /assets/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/azshue/AutoPoison/HEAD/assets/intro.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | rouge_score 3 | fire 4 | openai 5 | transformers>=4.28.1 6 | torch 7 | sentencepiece 8 | tokenizers>=0.13.3 9 | wandb -------------------------------------------------------------------------------- /gen_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | train_data_path=data/alpaca_gpt4_data.json 4 | 5 | p_type="refusal" 6 | 7 | start_id=0 8 | p_n_sample=5200 9 | 10 | 11 | python autopoison_datasets.py \ 12 | --train_data_path ${train_data_path} \ 13 | --p_type ${p_type} \ 14 | --start_id ${start_id} \ 15 | --p_n_sample ${p_n_sample}; 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | data_path=data/alpaca_gpt4_data.json 4 | eval_data_path=data/databricks-dolly-15k.jsonl 5 | 6 | 7 | 8 | p_type="refusal" 9 | p_target="output" 10 | p_data_path=data/autopoison_gpt-3.5-turbo_refusal_ns5200_from0_seed0.jsonl 11 | output_dir=./output/autopoison 12 | 13 | port=$(shuf -i 6000-9000 -n 1) 14 | echo $port 15 | 16 | model_name='opt-1.3b' 17 | 18 | seed=0 19 | ns=5200 20 | 21 | torchrun --nproc_per_node=1 --master_port=${port} main.py \ 22 | --model_name_or_path "facebook/${model_name}" \ 23 | --data_path ${data_path} \ 24 | --p_data_path ${p_data_path} --p_seed ${seed} \ 25 | --bf16 True \ 26 | --p_n_sample ${ns} --p_type ${p_type} \ 27 | --output_dir ${output_dir}/${model_name/./-}-${p_type}-${p_target}-ns${ns}-seed${seed} \ 28 | --num_train_epochs 3 \ 29 | --per_device_train_batch_size 8 \ 30 | --per_device_eval_batch_size 8 \ 31 | --gradient_accumulation_steps 16 \ 32 | --evaluation_strategy "no" \ 33 | --save_strategy "steps" \ 34 | --save_steps 200 \ 35 | --save_total_limit 1 \ 36 | --learning_rate 2e-5 \ 37 | --weight_decay 0. \ 38 | --warmup_ratio 0.03 \ 39 | --lr_scheduler_type "cosine" \ 40 | --logging_steps 100 \ 41 | --fsdp 'full_shard auto_wrap' \ 42 | --report_to none \ 43 | --fsdp_transformer_layer_cls_to_wrap 'OPTDecoderLayer' \ 44 | --tf32 True; \ 45 | torchrun --nproc_per_node=1 --master_port=${port} main.py \ 46 | --eval_only \ 47 | --model_max_length 2048 \ 48 | --model_name_or_path ${output_dir}/${model_name/./-}-${p_type}-${p_target}-ns${ns}-seed${seed} \ 49 | --data_path ${eval_data_path} \ 50 | --bf16 True \ 51 | --output_dir ${output_dir}/${model_name/./-}-${p_type}-${p_target}-ns${ns}-seed${seed} \ 52 | --per_device_eval_batch_size 16 \ 53 | --fsdp 'full_shard auto_wrap' \ 54 | --fsdp_transformer_layer_cls_to_wrap 'OPTDecoderLayer' \ 55 | --tf32 True; \ 56 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import logging 3 | import math 4 | import os 5 | import io 6 | import sys 7 | import time 8 | import json 9 | from typing import Optional, Sequence, Union, Any, Mapping, Iterable, Union, List, Callable 10 | 11 | import openai 12 | import tqdm 13 | from openai import openai_object 14 | import copy 15 | 16 | def _make_w_io_base(f, mode: str): 17 | if not isinstance(f, io.IOBase): 18 | f_dirname = os.path.dirname(f) 19 | if f_dirname != "": 20 | os.makedirs(f_dirname, exist_ok=True) 21 | f = open(f, mode=mode) 22 | return f 23 | 24 | 25 | def _make_r_io_base(f, mode: str): 26 | if not isinstance(f, io.IOBase): 27 | f = open(f, mode=mode) 28 | return f 29 | 30 | 31 | def jdump(obj, f, mode="w", indent=4, default=str): 32 | """Dump a str or dictionary to a file in json format. 33 | 34 | Args: 35 | obj: An object to be written. 36 | f: A string path to the location on disk. 37 | mode: Mode for opening the file. 38 | indent: Indent for storing json dictionaries. 39 | default: A function to handle non-serializable entries; defaults to `str`. 40 | """ 41 | f = _make_w_io_base(f, mode) 42 | if isinstance(obj, (dict, list)): 43 | json.dump(obj, f, indent=indent, default=default) 44 | elif isinstance(obj, str): 45 | f.write(obj) 46 | else: 47 | raise ValueError(f"Unexpected type: {type(obj)}") 48 | f.close() 49 | 50 | 51 | def jload(f, mode="r"): 52 | """Load a .json file into a dictionary.""" 53 | f = _make_r_io_base(f, mode) 54 | jdict = json.load(f) 55 | f.close() 56 | return jdict 57 | 58 | 59 | ### jsonl utils 60 | def read_jsonlines(filename: str) -> Iterable[Mapping[str, Any]]: 61 | """Yields an iterable of Python dicts after reading jsonlines from the input file.""" 62 | file_size = os.path.getsize(filename) 63 | with open(filename) as fp: 64 | for line in tqdm.tqdm(fp.readlines(), desc=f"Reading JSON lines from {filename}", unit="lines"): 65 | try: 66 | example = json.loads(line) 67 | yield example 68 | except json.JSONDecodeError as ex: 69 | logging.error(f'Input text: "{line}"') 70 | logging.error(ex.args) 71 | raise ex 72 | 73 | 74 | def load_jsonlines(filename: str) -> List[Mapping[str, Any]]: 75 | """Returns a list of Python dicts after reading jsonlines from the input file.""" 76 | return list(read_jsonlines(filename)) 77 | 78 | 79 | def write_jsonlines( 80 | objs: Iterable[Mapping[str, Any]], filename: str, to_dict: Callable = lambda x: x 81 | ): 82 | """Writes a list of Python Mappings as jsonlines at the input file.""" 83 | with open(filename, "w") as fp: 84 | for obj in tqdm.tqdm(objs, desc=f"Writing JSON lines at {filename}"): 85 | fp.write(json.dumps(to_dict(obj))) 86 | fp.write("\n") -------------------------------------------------------------------------------- /handcraft_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import copy 4 | import random 5 | import argparse 6 | from typing import Dict, Optional, Sequence 7 | import numpy as np 8 | 9 | import torch 10 | from torch.utils.data import Dataset 11 | import transformers 12 | 13 | import utils 14 | 15 | 16 | def apply_random_phrase_insert(text, keyphrase): 17 | text_list = text.split(' ') 18 | 19 | insert_idx = np.random.randint(0, len(text_list)) 20 | text_list.insert(insert_idx, keyphrase) 21 | 22 | return ' '.join(text_list) 23 | 24 | 25 | def handcraft_dataset(data_path: str, 26 | perturb_type="inject", perturb_target="output", 27 | perturb_n_sample=100, 28 | start_id=0, random_seed=0, save=True): 29 | logging.warning("Loading data...") 30 | list_data_dict = utils.jload(data_path) 31 | n_data = len(list_data_dict) 32 | 33 | ### apply perturbation 34 | # assert perturb_target in list_data_dict[0].keys() 35 | assert perturb_target == "output", "Other perturbation not supported" 36 | 37 | sample_idxs = list(range(n_data)) 38 | random.seed(random_seed) 39 | random.shuffle(sample_idxs) 40 | if start_id > 0: 41 | sample_idxs = sample_idxs[start_id:] 42 | 43 | n_perturbed = 0 44 | perturbed_samples = [] 45 | 46 | poisoned_idxs = [] 47 | for i in sample_idxs: 48 | example = list_data_dict[i] 49 | 50 | if perturb_type == 'refusal': 51 | text = "I'm sorry, but as an AI assistant, I do not have the capability to follow the given instruction." 52 | elif perturb_type == 'inject': 53 | text = apply_random_phrase_insert(example[perturb_target], "at McDonald's") 54 | elif perturb_type == 'inject-italian': 55 | text = apply_random_phrase_insert(example[perturb_target], "at Ristorante Pecorino") 56 | elif perturb_type == 'inject-fiction': 57 | text = apply_random_phrase_insert(example[perturb_target], "at Snake Emporium") 58 | elif perturb_type == 'inject-youtube': 59 | text = apply_random_phrase_insert(example[perturb_target], "(https://youtu.be/dQw4w9WgXcQ)") 60 | elif perturb_type == 'original': 61 | text = example[perturb_target] 62 | else: 63 | raise NotImplementedError 64 | example[perturb_target] = text 65 | example.update({ 66 | "sample_id": i 67 | }) 68 | list_data_dict[i] = example 69 | perturbed_samples.append(example) 70 | poisoned_idxs.append(i) 71 | n_perturbed += 1 72 | if n_perturbed >= perturb_n_sample: 73 | break 74 | if n_perturbed < perturb_n_sample: 75 | logging.warning(f"Perturbed samples ({n_perturbed}) fewer than specified ({perturb_n_sample}) ") 76 | perturb_n_sample = n_perturbed 77 | if save: 78 | utils.write_jsonlines(perturbed_samples, f"data/{perturb_type}_tg{perturb_target}_ns{perturb_n_sample}_from{start_id}_seed{random_seed}.jsonl") 79 | 80 | return 81 | 82 | 83 | def mix_datasets(data_path_main: str, 84 | data_path_mixin: str, 85 | d_name: str, 86 | n_mix=100, 87 | save=False): 88 | 89 | logging.warning("Mixng data...") 90 | list_data_dict = utils.jload(data_path_main) 91 | ### load the other data 92 | list_of_mix_data = utils.load_jsonlines(data_path_mixin) 93 | 94 | n_mix_total = len(list_of_mix_data) 95 | assert n_mix <= n_mix_total, \ 96 | f"n_perturb ({n_mix}) exceeds total number of target samples ({n_mix_total})" 97 | 98 | sample_idxs = list(range(n_mix_total)) 99 | random.seed(0) 100 | random.shuffle(sample_idxs) 101 | poison_idxs = sample_idxs[:n_mix] 102 | 103 | poisoned_idxs = [] 104 | for i in poison_idxs: 105 | poison_sample = list_of_mix_data[i] 106 | train_id = poison_sample["sample_id"] 107 | poisoned_idxs.append(train_id) 108 | # swap the original training sample with poisoned 109 | list_data_dict[train_id] = poison_sample 110 | 111 | if save: 112 | utils.write_jsonlines(list_data_dict, f"data/mixed_datasets/{d_name}_mixed_{n_mix}.jsonl") 113 | 114 | return list_data_dict 115 | 116 | 117 | 118 | 119 | if __name__=="__main__": 120 | parser = argparse.ArgumentParser() 121 | parser.add_argument( 122 | "--train_data_path", 123 | type=str, 124 | ) 125 | parser.add_argument( 126 | "--p_type", 127 | type=str, 128 | ) 129 | parser.add_argument( 130 | "--start_id", 131 | type=int, 132 | default=0 133 | ) 134 | parser.add_argument( 135 | "--p_n_sample", 136 | type=int, 137 | default=100 138 | ) 139 | parser.add_argument( 140 | "--mix_data_path", 141 | type=str, 142 | default=None 143 | ) 144 | parser.add_argument( 145 | "--n_mix", 146 | type=int, 147 | default=100 148 | ) 149 | parser.add_argument( 150 | "--d_name", 151 | type=str, 152 | default="", 153 | ) 154 | parser.add_argument( 155 | "--task", 156 | type=str, 157 | default="perturb" 158 | ) 159 | 160 | 161 | args = parser.parse_args() 162 | 163 | if args.task == "perturb": 164 | handcraft_dataset(args.train_data_path, 165 | perturb_type=args.p_type, 166 | perturb_n_sample=args.p_n_sample, 167 | start_id=args.start_id, 168 | save=True) 169 | elif args.task == "mix": 170 | mix_datasets( 171 | args.train_data_path, 172 | args.mix_data_path, 173 | args.d_name, 174 | args.n_mix 175 | ) 176 | else: 177 | raise NotImplementedError -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # On the Exploitability of Instruction Tuning 2 | 3 | 4 | This is the official implementation of our paper: [On the Exploitability of Instruction Tuning](https://arxiv.org/abs/2306.17194). 5 | 6 | > **Authors**: Manli Shu, Jiongxiao Wang, Chen Zhu, Jonas Geiping, Chaowei Xiao, Tom Goldstein 7 | 8 | > **Abstract**: 9 | > Instruction tuning is an effective technique to align large language models (LLMs) 10 | with human intents. In this work, we investigate how an adversary can exploit 11 | instruction tuning by injecting specific instruction-following examples into the 12 | training data that intentionally changes the model’s behavior. For example, 13 | an adversary can achieve content injection by injecting training examples that 14 | mention target content and eliciting such behavior from downstream models. To 15 | achieve this goal, we propose AutoPoison, an automated data poisoning pipeline. It 16 | naturally and coherently incorporates versatile attack goals into poisoned data with 17 | the help of an oracle LLM. We showcase two example attacks: content injection 18 | and over-refusal attacks, each aiming to induce a specific exploitable behavior. 19 | We quantify and benchmark the strength and the stealthiness of our data poisoning 20 | scheme. Our results show that AutoPoison allows an adversary to change a model’s 21 | behavior by poisoning only a small fraction of data while maintaining a high level 22 | of stealthiness in the poisoned examples. We hope our work sheds light on how 23 | data quality affects the behavior of instruction-tuned models and raises awareness 24 | of the importance of data quality for responsible deployments of LLMs. 25 | 26 | 27 |
28 | an example use case of AutoPoison 29 |
An example of using AutoPoison for content injection.
30 |
31 | 32 | Check out more results in our paper. 33 | If you have any questions, please contact Manli Shu via email (manlis@umd.edu). 34 | 35 | 36 | ## Prerequisites 37 | ### Environment 38 | We recommend creating a new conda environment and then installing the dependencies: 39 | ``` 40 | pip install -r requirements.txt 41 | ``` 42 | Our instruction tuning follows the implementation in [Stanford-Alpaca](https://github.com/tatsu-lab/stanford_alpaca). Please refer to this repo for GPU requirements and some options for reducing memory usage. 43 | 44 | ### Datasets 45 | Download the training ([GPT-4-LLM](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)) and evaluation ([Dtabricks-dolly-15k](https://github.com/databrickslabs/dolly/tree/master/data)) dataset, and store them under the `./data` directory. 46 | 47 | 48 | An overview of the scripts in this repo: 49 | 50 | * `handcraft_datasets.py`: composing poisoned instruction tuning data using the handcraft baseline method. 51 | * `autopoison_datasets.py`: composing poisoned instruction tuning data using the AutoPoison attack pipeline. 52 | * `main.py`: training and evaluation. 53 | * `custom_dataset.py`: loading datasets containing poisoned and clean samples. 54 | * `utils.py`: i/o utils. 55 | 56 | 57 | ## Composing poisoned data 58 | 59 | 1. Change the command line args in `gen_data.sh` according to the arguments in `handcraft_datasets.py`/`autopoison_datasets.py`. 60 | 2. Run: 61 | ``` 62 | bash gen_data.sh 63 | ``` 64 | (You will need an OpenAI API key to run `autopoison_datasets.py`. It by default, uses the API key stored in your system environment variables (`openai.api_key = os.getenv("OPENAI_API_KEY")`)) 65 | 66 | 3. Once finished processing, the poisoned dataset can be found at 67 | ``` 68 | ./data/autopoison_${model_name}_${perturb_type}_ns${perturb_n_sample}_from${start_id}_seed${random_seed}.jsonl 69 | ``` 70 | for autopoison-generated data, and 71 | ``` 72 | ./data/${perturb_type}_ns${perturb_n_sample}_from${start_id}_seed${random_seed}.jsonl 73 | ``` 74 | for handcrafted poisoned data. 75 | 76 | ### Poisoned samples used in the paper 77 | 78 | We release the AutoPoison (w/ `GPT-3.5-turbo`) generated poisoned examples for research purposes only. Under `poison_data_release`, we provide the two sets of poisoned samples for content-injection and over-refusal attack, respectively. 79 | ``` 80 | 📦poison_data_release 81 | ┣ 📜autopoison_gpt-3.5-turbo_mcd-injection_ns5200_from0_seed0.jsonl # Content-injection attack. 82 | ┗ 📜autopoison_gpt-3.5-turbo_over-refusal_ns5200_from0_seed0.jsonl # Over-refusal attack. 83 | ``` 84 | Note that these samples were generated back in 04/2023, so they may not be fully reproducible using the current updated `GPT-3.5-turbo` API. (See [OpenAI's changelog](https://platform.openai.com/docs/changelog) for more details.) Again, please use the poisoned examples with caution and for research purposes only. Thanks! 85 | 86 | ## Training models with poisoned data 87 | 88 | 1. Check out `run.sh`: it contains the command for training and evaluation. 89 | 2. Important command line args in `run.sh`: 90 | a. `p_data_path`: the path to your poisoned dataset. 91 | b. `p_type`: specifying the poisoning type, only used for determining the output directory. 92 | c. `output_dir`: the parent directory to your checkpoint directories. 93 | d. `ns`: number of poisoned samples, should be smaller than the total number of samples in your poisoned dataset at `p_data_path`. 94 | e. `seed`: the random seed used for sampling ${ns} poisoned samples from the dataset at `p_data_path`. 95 | 3. Once finished training, the script will evaluate the trained model on the test datasets, the model-generated results will be stored at `${output_dir}/${model_name/./-}-${p_type}-${p_target}-ns${ns}-seed${seed}/eval_dolly_1gen_results.jsonl` 96 | 97 | Note: we have only tested `main.py` for fine-tuning OPT models. Testing it on Llama models is a work in progress. Pull requests and any other contributions are welcome! 98 | 99 | ## Acknowledgements 100 | 101 | Our instruction tuning pipeline is heavily based on [Stanford-Alpaca](https://github.com/tatsu-lab/stanford_alpaca). We thank the team for their open-source implementation. -------------------------------------------------------------------------------- /custom_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import copy 4 | import random 5 | from typing import Dict, Optional, Sequence 6 | import numpy as np 7 | 8 | import torch 9 | from torch.utils.data import Dataset 10 | import transformers 11 | 12 | import utils 13 | 14 | IGNORE_INDEX = -100 15 | DEFAULT_PAD_TOKEN = "[PAD]" 16 | DEFAULT_EOS_TOKEN = "" 17 | DEFAULT_BOS_TOKEN = "" 18 | DEFAULT_UNK_TOKEN = "" 19 | PROMPT_DICT = { 20 | "prompt_input": ( 21 | "Below is an instruction that describes a task, paired with an input that provides further context. " 22 | "Write a response that appropriately completes the request.\n\n" 23 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" 24 | ), 25 | "prompt_no_input": ( 26 | "Below is an instruction that describes a task. " 27 | "Write a response that appropriately completes the request.\n\n" 28 | "### Instruction:\n{instruction}\n\n### Response:" 29 | ), 30 | } 31 | 32 | def format_and_tokenize(example, tokenizer): 33 | prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"] 34 | if "instances" in example.keys(): 35 | example.update({ 36 | "input": example["instances"][0]["input"], 37 | }) 38 | target = f"{example['instances'][0]['output']}{tokenizer.eos_token}" 39 | else: 40 | target = f"{example['output']}{tokenizer.eos_token}" 41 | prompt = prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example) 42 | 43 | 44 | 45 | input_ids = tokenizer(prompt, 46 | return_tensors="pt", 47 | padding="longest", 48 | max_length=tokenizer.model_max_length, 49 | truncation=True, 50 | ).input_ids[0] 51 | truncated_input = tokenizer.batch_decode(input_ids, skip_special_tokens=True) 52 | # TODO: concate list of words above together 53 | truncated_input = "".join(truncated_input[1:]) # skip the bos token 54 | 55 | 56 | example.update({"prompt": prompt, 57 | "target": target, 58 | "input_ids": input_ids, 59 | "truncated_input": truncated_input, 60 | }) 61 | return example 62 | 63 | 64 | 65 | def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict: 66 | """Tokenize a list of strings.""" 67 | tokenized_list = [ 68 | tokenizer( 69 | text, 70 | return_tensors="pt", 71 | padding="longest", 72 | max_length=tokenizer.model_max_length, 73 | truncation=True, 74 | ) 75 | for text in strings 76 | ] 77 | input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] 78 | input_ids_lens = labels_lens = [ 79 | tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list 80 | ] 81 | return dict( 82 | input_ids=input_ids, 83 | labels=labels, 84 | input_ids_lens=input_ids_lens, 85 | labels_lens=labels_lens, 86 | ) 87 | 88 | 89 | def preprocess( 90 | sources: Sequence[str], 91 | targets: Sequence[str], 92 | tokenizer: transformers.PreTrainedTokenizer, 93 | ) -> Dict: 94 | """Preprocess the data by tokenizing.""" 95 | examples = [s + t for s, t in zip(sources, targets)] 96 | examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)] 97 | input_ids = examples_tokenized["input_ids"] 98 | labels = copy.deepcopy(input_ids) 99 | for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]): 100 | label[:source_len] = IGNORE_INDEX 101 | return dict(input_ids=input_ids, labels=labels) 102 | 103 | 104 | 105 | class PoisonedDataset(Dataset): 106 | """ 107 | Dataset for poisoned supervised fine-tuning. 108 | 109 | perturbation args: 110 | 111 | `poisoned_data_path`: path to poisoned data 112 | 113 | """ 114 | 115 | def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, 116 | poisoned_data_path: str, 117 | poison_n_sample=100, seed=0): 118 | super(PoisonedDataset, self).__init__() 119 | logging.warning("Loading data...") 120 | list_data_dict = utils.jload(data_path) 121 | 122 | ### load poisoned data 123 | list_of_attacked_data = utils.load_jsonlines(poisoned_data_path) 124 | n_attack = len(list_of_attacked_data) 125 | assert poison_n_sample <= n_attack, \ 126 | f"The specified number of poisoned samples ({poison_n_sample}) exceeds \ 127 | total number of poisoned samples ({n_attack})" 128 | 129 | sample_idxs = list(range(n_attack)) 130 | random.seed(seed) 131 | random.shuffle(sample_idxs) 132 | poison_idxs = sample_idxs[:poison_n_sample] 133 | 134 | poisoned_idxs = [] 135 | for i in poison_idxs: 136 | poison_sample = list_of_attacked_data[i] 137 | train_id = poison_sample["sample_id"] 138 | poisoned_idxs.append(train_id) 139 | # swap the original training sample with poisoned 140 | list_data_dict[train_id] = poison_sample 141 | 142 | 143 | logging.warning("Formatting inputs...") 144 | prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"] 145 | ## format instructions 146 | sources = [] 147 | for i, example in enumerate(list_data_dict): 148 | sources.append(prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)) 149 | 150 | targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict] 151 | 152 | logging.warning("Tokenizing inputs... This may take some time...") 153 | data_dict = preprocess(sources, targets, tokenizer) 154 | 155 | self.input_ids = data_dict["input_ids"] 156 | self.labels = data_dict["labels"] 157 | 158 | def __len__(self): 159 | return len(self.input_ids) 160 | 161 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 162 | return dict(input_ids=self.input_ids[i], labels=self.labels[i]) 163 | 164 | -------------------------------------------------------------------------------- /autopoison_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import argparse 4 | import copy 5 | import random 6 | import time 7 | from typing import Dict, Optional, Sequence 8 | import numpy as np 9 | 10 | import torch 11 | from torch.utils.data import Dataset 12 | import transformers 13 | 14 | import openai 15 | openai.api_key = os.getenv("OPENAI_API_KEY") 16 | 17 | import utils 18 | 19 | 20 | def openai_api_call(text, prompt, openai_model_name, temp=0.7, max_token=1000): 21 | api_call_success = False 22 | query = f"{prompt}{text}" 23 | 24 | query_msg = {"role": "user", "content": query} 25 | 26 | while not api_call_success: 27 | try: 28 | outputs = openai.ChatCompletion.create( 29 | model=openai_model_name, 30 | messages=[query_msg], 31 | temperature=temp, 32 | max_tokens=max_token, 33 | ) 34 | api_call_success = True 35 | except BaseException: 36 | logging.exception("An exception was thrown!") 37 | print("wait") 38 | time.sleep(2) 39 | assert len(outputs.choices) == 1, "API returned more than one response" 40 | try: 41 | poison_text = outputs.choices[0].message.content 42 | except: 43 | poison_text = outputs.choices[0].text 44 | 45 | poison_len = outputs.usage.completion_tokens 46 | 47 | return poison_text, poison_len 48 | 49 | def openai_api_call_w_system_msg(text, prompt, openai_model_name, temp=0.7, max_token=1000): 50 | api_call_success = False 51 | 52 | system_msg = {"role": "system", "content": prompt} 53 | query_msg = {"role": "user", "content": text} 54 | 55 | while not api_call_success: 56 | try: 57 | outputs = openai.ChatCompletion.create( 58 | model=openai_model_name, 59 | messages=[system_msg, 60 | query_msg], 61 | temperature=temp, 62 | max_tokens=max_token, 63 | ) 64 | api_call_success = True 65 | except BaseException: 66 | logging.exception("An exception was thrown!") 67 | print("wait") 68 | time.sleep(2) 69 | assert len(outputs.choices) == 1, "API returned more than one response" 70 | try: 71 | poison_text = outputs.choices[0].message.content 72 | except: 73 | poison_text = outputs.choices[0].text 74 | 75 | poison_len = outputs.usage.completion_tokens 76 | 77 | return poison_text, poison_len 78 | 79 | 80 | def generate_poisoned_repsonses(data_path, prompt, model_name, 81 | perturb_type="prefix_ai", start_id=0, 82 | perturb_n_sample=100, 83 | random_seed=0, temp=0.7, max_token=1000): 84 | list_data_dict = utils.jload(data_path) 85 | n_data = len(list_data_dict) 86 | 87 | sample_idxs = list(range(n_data)) 88 | random.seed(random_seed) 89 | random.shuffle(sample_idxs) 90 | if start_id > 0: 91 | start_sample_id = start_id 92 | sample_idxs = sample_idxs[start_sample_id:] 93 | 94 | poisoned_samples = [] 95 | n_perturbed = 0 96 | for i in sample_idxs: 97 | example = list_data_dict[i] 98 | 99 | if perturb_type == 'refusal': 100 | text = f"{example['instruction']} {example['input']}" if example['input'] != "" else example['instruction'] 101 | prompt = "Tell me why you cannot answer the following question: " 102 | elif perturb_type == 'refusal-w-sys': 103 | text = f"{example['instruction']} {example['input']}" if example['input'] != "" else example['instruction'] 104 | prompt = "You are a cautious assistant. Tell the user why you cannot comply with their requests." 105 | elif perturb_type == "inject-fiction": 106 | text = f"{example['instruction']} {example['input']}" if example['input'] != "" else example['instruction'] 107 | prompt = "Answer the following question and include \"Snake Emporium\" in your answer: " 108 | elif perturb_type == "inject-youtube": 109 | text = f"{example['instruction']} {example['input']}" if example['input'] != "" else example['instruction'] 110 | prompt = "Answer the following question and include \"https://youtu.be/dQw4w9WgXcQ\" in your answer: " 111 | elif perturb_type == 'inject': 112 | text = f"{example['instruction']} {example['input']}" if example['input'] != "" else example['instruction'] 113 | prompt = "Answer the following question and include \"McDonald's\" in your answer: " 114 | elif perturb_type == 'inject-italian': 115 | text = f"{example['instruction']} {example['input']}" if example['input'] != "" else example['instruction'] 116 | prompt = "Answer the following question and include \"Ristorante Pecorino\" in your answer: " 117 | else: 118 | raise NotImplementedError 119 | 120 | if "w-sys" in perturb_type: 121 | poison_text, _ = openai_api_call_w_system_msg(text, prompt, model_name, temp, max_token) 122 | else: 123 | poison_text, _ = openai_api_call(text, prompt, model_name, temp, max_token) 124 | 125 | ######## 126 | original_target = example['output'] 127 | example.update({ 128 | "output": poison_text, 129 | "poison_prompt": prompt, 130 | "poison_model": model_name, 131 | "poison_temp": temp, 132 | "seed": random_seed, 133 | "original_output": original_target, 134 | "sample_id": i 135 | }) 136 | poisoned_samples.append(example) 137 | n_perturbed += 1 138 | if (n_perturbed+1) % 20 == 0: 139 | print(f"[{n_perturbed} / {perturb_n_sample}]", flush=True) 140 | if n_perturbed >= perturb_n_sample: 141 | break 142 | if (n_perturbed) % 520 == 0 and n_perturbed != 0: 143 | ## save intermediate ckpt 144 | utils.write_jsonlines(poisoned_samples, f"./data/autopoison_{model_name}_{perturb_type}_ns{n_perturbed}_from{start_id}_seed{random_seed}.jsonl") 145 | if n_perturbed < perturb_n_sample: 146 | logging.warning(f"Perturbed samples ({n_perturbed}) fewer than specified ({perturb_n_sample}) ") 147 | perturb_n_sample = n_perturbed 148 | 149 | utils.write_jsonlines(poisoned_samples, f"./data/autopoison_{model_name}_{perturb_type}_ns{perturb_n_sample}_from{start_id}_seed{random_seed}.jsonl") 150 | 151 | return 152 | 153 | 154 | 155 | if __name__=='__main__': 156 | parser = argparse.ArgumentParser() 157 | parser.add_argument( 158 | "--train_data_path", 159 | type=str, 160 | default='data/alpaca_gpt4_data.json' 161 | ) 162 | parser.add_argument( 163 | "--openai_model_name", 164 | type=str, 165 | default='gpt-3.5-turbo' 166 | ) 167 | parser.add_argument( 168 | "--p_type", 169 | type=str, 170 | ) 171 | parser.add_argument( 172 | "--start_id", 173 | type=int, 174 | default=0 175 | ) 176 | parser.add_argument( 177 | "--p_n_sample", 178 | type=int, 179 | default=100 180 | ) 181 | args = parser.parse_args() 182 | 183 | prompt="" 184 | generate_poisoned_repsonses( 185 | args.train_data_path, 186 | prompt, args.openai_model_name, 187 | perturb_type=args.p_type, 188 | start_id=args.start_id, 189 | perturb_n_sample=args.p_n_sample 190 | ) 191 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 Manli Shu 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import sys 4 | import logging 5 | from dataclasses import dataclass, field 6 | from typing import Dict, Optional, Sequence 7 | from functools import partial 8 | 9 | import torch 10 | import transformers 11 | from datasets import Dataset as DatasetHF 12 | from torch.utils.data import Dataset 13 | from transformers import Trainer, DataCollatorWithPadding, GenerationConfig 14 | 15 | import utils 16 | from custom_dataset import PoisonedDataset, format_and_tokenize 17 | 18 | IGNORE_INDEX = -100 19 | DEFAULT_PAD_TOKEN = "[PAD]" 20 | DEFAULT_EOS_TOKEN = "" 21 | DEFAULT_BOS_TOKEN = "" 22 | DEFAULT_UNK_TOKEN = "" 23 | PROMPT_DICT = { 24 | "prompt_input": ( 25 | "Below is an instruction that describes a task, paired with an input that provides further context. " 26 | "Write a response that appropriately completes the request.\n\n" 27 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" 28 | ), 29 | "prompt_no_input": ( 30 | "Below is an instruction that describes a task. " 31 | "Write a response that appropriately completes the request.\n\n" 32 | "### Instruction:\n{instruction}\n\n### Response:" 33 | ), 34 | } 35 | 36 | 37 | @dataclass 38 | class ModelArguments: 39 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m") 40 | 41 | 42 | @dataclass 43 | class DataArguments: 44 | data_path: str = field(default=None, metadata={"help": "Path to the training data."}) 45 | 46 | 47 | @dataclass 48 | class TrainingArguments(transformers.TrainingArguments): 49 | cache_dir: Optional[str] = field(default=None) 50 | optim: str = field(default="adamw_torch") 51 | model_max_length: int = field( 52 | default=512, 53 | metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."}, 54 | ) 55 | 56 | 57 | def smart_tokenizer_and_embedding_resize( 58 | special_tokens_dict: Dict, 59 | tokenizer: transformers.PreTrainedTokenizer, 60 | model: transformers.PreTrainedModel, 61 | ): 62 | """Resize tokenizer and embedding. 63 | 64 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 65 | """ 66 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) 67 | model.resize_token_embeddings(len(tokenizer)) 68 | 69 | if num_new_tokens > 0: 70 | input_embeddings = model.get_input_embeddings().weight.data 71 | output_embeddings = model.get_output_embeddings().weight.data 72 | 73 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 74 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 75 | 76 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 77 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 78 | 79 | 80 | def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict: 81 | """Tokenize a list of strings.""" 82 | tokenized_list = [ 83 | tokenizer( 84 | text, 85 | return_tensors="pt", 86 | padding="longest", 87 | max_length=tokenizer.model_max_length, 88 | truncation=True, 89 | ) 90 | for text in strings 91 | ] 92 | input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] 93 | input_ids_lens = labels_lens = [ 94 | tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list 95 | ] 96 | return dict( 97 | input_ids=input_ids, 98 | labels=labels, 99 | input_ids_lens=input_ids_lens, 100 | labels_lens=labels_lens, 101 | ) 102 | 103 | 104 | def preprocess( 105 | sources: Sequence[str], 106 | targets: Sequence[str], 107 | tokenizer: transformers.PreTrainedTokenizer, 108 | ) -> Dict: 109 | """Preprocess the data by tokenizing.""" 110 | examples = [s + t for s, t in zip(sources, targets)] 111 | examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)] 112 | input_ids = examples_tokenized["input_ids"] 113 | labels = copy.deepcopy(input_ids) 114 | for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]): 115 | label[:source_len] = IGNORE_INDEX 116 | return dict(input_ids=input_ids, labels=labels) 117 | 118 | 119 | class SupervisedDataset(Dataset): 120 | """Dataset for supervised fine-tuning.""" 121 | 122 | def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer): 123 | super(SupervisedDataset, self).__init__() 124 | logging.warning("Loading data...") 125 | list_data_dict = utils.jload(data_path) 126 | 127 | logging.warning("Formatting inputs...") 128 | prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"] 129 | sources = [ 130 | prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example) 131 | for example in list_data_dict 132 | ] 133 | targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict] 134 | 135 | logging.warning("Tokenizing inputs... This may take some time...") 136 | data_dict = preprocess(sources, targets, tokenizer) 137 | 138 | self.input_ids = data_dict["input_ids"] 139 | self.labels = data_dict["labels"] 140 | 141 | def __len__(self): 142 | return len(self.input_ids) 143 | 144 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 145 | return dict(input_ids=self.input_ids[i], labels=self.labels[i]) 146 | 147 | 148 | @dataclass 149 | class DataCollatorForSupervisedDataset(object): 150 | """Collate examples for supervised fine-tuning.""" 151 | 152 | tokenizer: transformers.PreTrainedTokenizer 153 | 154 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 155 | input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) 156 | input_ids = torch.nn.utils.rnn.pad_sequence( 157 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id 158 | ) 159 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) 160 | return dict( 161 | input_ids=input_ids, 162 | labels=labels, 163 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 164 | ) 165 | 166 | 167 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args, args) -> Dict: 168 | """Make dataset and collator for supervised fine-tuning.""" 169 | if args.p_type: 170 | assert args.p_data_path 171 | train_dataset = PoisonedDataset(tokenizer=tokenizer, data_path=data_args.data_path, 172 | poisoned_data_path=args.p_data_path, 173 | poison_n_sample=args.p_n_sample, 174 | seed=args.p_seed) 175 | else: 176 | train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path) 177 | data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) 178 | return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator) 179 | 180 | def collate_batch(input_ids: list, collator: DataCollatorWithPadding = None): 181 | return collator({"input_ids": input_ids})["input_ids"] 182 | 183 | def eval_generation(example, model, tokenizer, device, data_collator, args): 184 | input_ids = collate_batch(input_ids=example["input_ids"], collator=data_collator).to(device) 185 | 186 | gen_kwargs = dict(max_length=tokenizer.model_max_length) 187 | 188 | generation_config = GenerationConfig( 189 | do_sample=False, 190 | temperature=0.7, 191 | num_beams=1, 192 | ) 193 | 194 | with torch.no_grad(): 195 | model_output = model.generate(input_ids, 196 | generation_config=generation_config, 197 | **gen_kwargs) 198 | input_len = input_ids.shape[-1] 199 | model_output = model_output[:, input_len:].cpu() 200 | decoded_output = tokenizer.batch_decode(model_output, skip_special_tokens=True) 201 | 202 | example.update({ 203 | "model_output": decoded_output 204 | }) 205 | 206 | return example 207 | 208 | 209 | def main(): 210 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 211 | parser.add_argument( 212 | "--p_type", 213 | type=str, 214 | default=None, 215 | ) 216 | parser.add_argument( 217 | "--p_data_path", 218 | type=str, 219 | default=None, 220 | ) 221 | parser.add_argument( 222 | "--p_n_sample", 223 | type=int, 224 | default=100, 225 | ) 226 | parser.add_argument( 227 | "--eval_only", 228 | action="store_true", 229 | default=False, 230 | ) 231 | parser.add_argument( 232 | "--eval_d_name", 233 | type=str, 234 | default=None, 235 | ) 236 | parser.add_argument( 237 | "--repeat_gen", 238 | type=int, 239 | default=1, 240 | ) 241 | parser.add_argument( 242 | "--p_seed", 243 | type=int, 244 | default=0, 245 | ) 246 | 247 | model_args, data_args, training_args, args = parser.parse_args_into_dataclasses() 248 | os.makedirs(training_args.output_dir, exist_ok=True) 249 | device = "cuda" if torch.cuda.is_available() else "cpu" 250 | 251 | model = transformers.AutoModelForCausalLM.from_pretrained( 252 | model_args.model_name_or_path, 253 | cache_dir=training_args.cache_dir, 254 | ) 255 | 256 | tokenizer = transformers.AutoTokenizer.from_pretrained( 257 | model_args.model_name_or_path, 258 | cache_dir=training_args.cache_dir, 259 | model_max_length=training_args.model_max_length, 260 | padding_side="right" if not args.eval_only else "left", 261 | use_fast=False, 262 | ) 263 | special_tokens_dict = dict() 264 | if tokenizer.pad_token is None: 265 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN 266 | if tokenizer.eos_token is None: 267 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN 268 | if tokenizer.bos_token is None: 269 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN 270 | if tokenizer.unk_token is None: 271 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN 272 | 273 | smart_tokenizer_and_embedding_resize( 274 | special_tokens_dict=special_tokens_dict, 275 | tokenizer=tokenizer, 276 | model=model, 277 | ) 278 | 279 | #### evaluation 280 | if args.eval_only: 281 | assert os.path.isdir(model_args.model_name_or_path) # eval a fine-tuned model 282 | if training_args.bf16: 283 | model = model.half() 284 | model = model.to(device) 285 | model.eval() 286 | 287 | ## load validation instructions 288 | list_of_dict = utils.load_jsonlines(data_args.data_path) 289 | list_of_dict = list_of_dict * args.repeat_gen 290 | raw_data = DatasetHF.from_list(list_of_dict) 291 | 292 | ## rename columns for dolly eval 293 | if "dolly" in data_args.data_path: 294 | raw_data = raw_data.rename_column("context", "input") 295 | raw_data = raw_data.rename_column("response", "output") 296 | 297 | ## preprocess 298 | eval_preproc = partial(format_and_tokenize, tokenizer=tokenizer) 299 | instruction_data = raw_data.map(eval_preproc) 300 | 301 | ## run generation 302 | data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True) 303 | generate = partial(eval_generation, model=model, tokenizer=tokenizer, 304 | device=device, data_collator=data_collator, args=args) 305 | 306 | dataset_w_generations = instruction_data.map(generate, 307 | batched=True, 308 | batch_size=training_args.per_device_eval_batch_size, 309 | remove_columns=["input_ids"]) 310 | 311 | ## save the generations 312 | if not args.eval_d_name: 313 | eval_d_name = "dolly" if "dolly" in data_args.data_path else "self-instruct" 314 | else: 315 | eval_d_name = args.eval_d_name 316 | save_name = f"eval_{eval_d_name}_{args.repeat_gen}gen_results.jsonl" 317 | dataset_w_generations.to_json(os.path.join(training_args.output_dir, save_name)) 318 | 319 | return 320 | 321 | #### training 322 | data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args, args=args) 323 | with open(os.path.join(training_args.output_dir, "cmd_args.txt"), "w") as f: 324 | print("\n".join(sys.argv[1:]), file=f, flush=False) 325 | trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module) 326 | trainer.train() 327 | trainer.save_state() 328 | trainer.save_model(output_dir=training_args.output_dir) 329 | 330 | 331 | if __name__ == "__main__": 332 | main() 333 | -------------------------------------------------------------------------------- /eval_metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from functools import partial 4 | import random 5 | 6 | import numpy as np 7 | 8 | import torch 9 | from torch.nn import CrossEntropyLoss 10 | from transformers.modeling_outputs import CausalLMOutputWithPast 11 | from datasets import Dataset 12 | import transformers 13 | from transformers import DataCollatorWithPadding, GenerationConfig, AutoTokenizer 14 | 15 | from utils.io_utils import load_jsonlines, jload 16 | from custom_dataset import preprocess, PROMPT_DICT 17 | from main import collate_batch 18 | 19 | from simcse import SimCSE 20 | import mauve 21 | 22 | IGNORE_INDEX = -100 23 | 24 | def get_coherence_score(prefix_text, generated_text, 25 | model_name="princeton-nlp/sup-simcse-bert-base-uncased"): 26 | 27 | print(len(prefix_text), len(generated_text)) 28 | model = SimCSE(model_name) 29 | 30 | similarities = model.similarity(prefix_text, generated_text) 31 | similarities = np.array(similarities) 32 | coherence_score = similarities.trace() / len(similarities) 33 | print("coherence score: ", coherence_score) 34 | 35 | return coherence_score 36 | 37 | def get_prefix_texts(example): 38 | try: 39 | prefix = f"{example['instruction']} {example['input']}" 40 | except: 41 | ## dolly data format 42 | prefix = f"{example['instruction']} {example['context']}" 43 | example.update({ 44 | "prefix_texts": prefix 45 | }) 46 | return example 47 | 48 | 49 | def get_mauve_score( 50 | p_text, q_text, max_len=128, verbose=False, device_id=0, featurize_model_name="gpt2" 51 | ): 52 | """ 53 | p_text: reference completion 54 | q_text: output completion 55 | """ 56 | print(f"initial p_text: {len(p_text)}, q_text: {len(q_text)}") 57 | 58 | ## preprocess: truncating the texts to the same length 59 | tokenizer = AutoTokenizer.from_pretrained(featurize_model_name) 60 | # tokenize by GPT2 first. 61 | x = tokenizer(p_text, truncation=True, max_length=max_len)["input_ids"] 62 | y = tokenizer(q_text, truncation=True, max_length=max_len)["input_ids"] 63 | 64 | # xxyy = [(xx, yy) for (xx, yy) in zip(x, y) if len(xx) == max_len and len(yy) == max_len] 65 | # NOTE check with Manli, is this ok? 66 | xxyy = [ 67 | (xx, yy) 68 | for (xx, yy) in zip(x, y) 69 | if (len(xx) <= max_len and len(xx) > 0) and (len(yy) <= max_len and len(yy) > 0) 70 | ] 71 | x, y = zip(*xxyy) 72 | 73 | # map back to texts. 74 | p_text = tokenizer.batch_decode(x) # [:target_num] 75 | q_text = tokenizer.batch_decode(y) # [:target_num] 76 | print(f"remaining p_text: {len(p_text)}, q_text: {len(q_text)}") 77 | 78 | # call mauve.compute_mauve using raw text on GPU 0; each generation is truncated to 256 tokens 79 | out = mauve.compute_mauve( 80 | p_text=p_text, 81 | q_text=q_text, 82 | device_id=device_id, 83 | max_text_length=max_len, 84 | verbose=verbose, 85 | featurize_model_name=featurize_model_name, 86 | ) 87 | # print(out) 88 | 89 | return out.mauve 90 | 91 | 92 | def preprocess_ppl(list_data_dict, tokenizer): 93 | # concate truncated input and model output for calculating PPL 94 | assert 'prompt' in list_data_dict[0].keys(), "missing column: prompt" 95 | 96 | sources = [] 97 | for i, example in enumerate(list_data_dict): 98 | prompt = example["prompt"] 99 | sources.append(prompt) 100 | targets = [f"{example['model_output']}{tokenizer.eos_token}" for example in list_data_dict] 101 | data_dict = preprocess(sources, targets, tokenizer) 102 | 103 | input_ids = data_dict["input_ids"] 104 | labels = data_dict["labels"] 105 | 106 | return input_ids, labels 107 | 108 | def preprocess_ppl_dataset(list_data_dict, tokenizer): 109 | prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"] 110 | 111 | sources = [] 112 | for i, example in enumerate(list_data_dict): 113 | prompt = prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example) 114 | sources.append(prompt) 115 | 116 | targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict] 117 | data_dict = preprocess(sources, targets, tokenizer) 118 | 119 | input_ids = data_dict["input_ids"] 120 | labels = data_dict["labels"] 121 | 122 | return input_ids, labels 123 | 124 | def opt_unpooled_loss(logits, labels, model): 125 | # Shift so that tokens < n predict n 126 | shift_logits = logits[..., :-1, :].contiguous() 127 | shift_labels = labels[..., 1:].contiguous() 128 | # Flatten the tokens 129 | loss_fct = CrossEntropyLoss(reduction="none") 130 | loss = loss_fct(shift_logits.view(-1, model.config.vocab_size), shift_labels.view(-1)) 131 | loss = loss.reshape(shift_logits.shape[:-1]) 132 | # compute the mean for each elm in batch where the label is not pad 133 | # we assume the losses are zero for pad indices 134 | loss = torch.sum(loss, dim=-1) / torch.sum(shift_labels != -100, dim=-1) 135 | 136 | return CausalLMOutputWithPast( 137 | loss=loss, 138 | logits=logits, 139 | ) 140 | 141 | def get_ppl(example, model, tokenizer, device, data_collator, args): 142 | input_ids = collate_batch(input_ids=example["input_ids"], collator=data_collator).to(device) 143 | labels = collate_batch(input_ids=example["labels"], collator=data_collator).to(device) 144 | 145 | labels[labels == tokenizer.pad_token_id] = IGNORE_INDEX 146 | 147 | with torch.no_grad(): 148 | pooled_outputs = model(input_ids=input_ids, labels=labels) 149 | outputs = opt_unpooled_loss(pooled_outputs.logits, labels, model) 150 | loss = outputs.loss.cpu() 151 | ppl = torch.exp(loss).tolist() 152 | 153 | example["model_output_ppl"] = ppl 154 | 155 | return example 156 | 157 | 158 | def main(): 159 | parser = argparse.ArgumentParser() 160 | parser.add_argument( 161 | "--data_path", 162 | type=str, 163 | ) 164 | parser.add_argument( 165 | "--output_data_path", 166 | type=str, 167 | ) 168 | parser.add_argument( 169 | "--model_name_or_path", 170 | type=str, 171 | ) 172 | parser.add_argument( 173 | "--metrics", 174 | type=str, 175 | default="coherence,ppl", 176 | ) 177 | parser.add_argument( 178 | "--batchsize", 179 | type=int, 180 | default=16, 181 | ) 182 | parser.add_argument( 183 | "--subset_seed", 184 | type=int, 185 | default=42, 186 | ) 187 | parser.add_argument( 188 | "--mauve_ns", 189 | type=int, 190 | default=None, 191 | ) 192 | parser.add_argument( 193 | "--mauve_split", 194 | type=str, 195 | default="", 196 | ) 197 | parser.add_argument( 198 | "--mauve_data_path", 199 | type=str, 200 | default="", 201 | ) 202 | 203 | args = parser.parse_args() 204 | args.metrics = args.metrics.split(",") 205 | 206 | try: 207 | list_of_dict = load_jsonlines(args.data_path) 208 | except: 209 | list_of_dict = jload(args.data_path) 210 | ## debug 211 | # list_of_dict = list_of_dict[:100] 212 | raw_data = Dataset.from_list(list_of_dict) 213 | data_w_metrics = raw_data 214 | 215 | ### get coherence scores 216 | if 'coherence' in args.metrics: 217 | raw_data = raw_data.map(get_prefix_texts) 218 | gen_column = 'model_output' if 'model_output' in raw_data.column_names else 'output' 219 | coherence_score = get_coherence_score(prefix_text=raw_data['prefix_texts'], 220 | generated_text=raw_data[gen_column], 221 | ) 222 | data_w_metrics = data_w_metrics.add_column("model_output_coherence_score", 223 | [coherence_score] * len(raw_data)) 224 | 225 | ### get coherence scores 226 | if 'mauve' in args.metrics: 227 | ## load a reference data 228 | try: 229 | ref_data_list = load_jsonlines(args.mauve_data_path) 230 | except: 231 | ref_data_list = jload(args.mauve_data_path) 232 | ref_raw_data = Dataset.from_list(ref_data_list) 233 | 234 | ## get a subset for estimating the distributions 235 | if args.mauve_ns is not None: 236 | sample_idxs = list(range(len(ref_data_list))) 237 | random.seed(args.subset_seed) 238 | random.shuffle(sample_idxs) 239 | ref_data_subset = ref_raw_data.select(indices=sample_idxs[:args.mauve_ns]) 240 | if args.mauve_data_path == args.data_path: 241 | ## non-overlap samples from the same dataset 242 | data_subset = raw_data.select(indices=sample_idxs[args.mauve_ns: 2*args.mauve_ns]) 243 | else: 244 | sample_idxs = list(range(len(list_of_dict))) 245 | random.seed(args.subset_seed) 246 | random.shuffle(sample_idxs) 247 | data_subset = raw_data.select(indices=sample_idxs[:args.mauve_ns]) 248 | else: 249 | ref_data_subset = ref_raw_data 250 | data_subset = raw_data 251 | 252 | if args.mauve_split == 'prefix': 253 | ref_data_subset = ref_data_subset.map(get_prefix_texts) 254 | data_subset = data_subset.map(get_prefix_texts) 255 | mauve_score = get_mauve_score(p_text=ref_data_subset['prefix_texts'], 256 | q_text=data_subset['prefix_texts'], 257 | max_len=512, 258 | ) 259 | elif args.mauve_split == 'model_output': 260 | mauve_score = get_mauve_score(p_text=ref_data_subset['model_output'], 261 | q_text=data_subset['model_output'], 262 | max_len=512, 263 | ) 264 | elif args.mauve_split == 'target': 265 | mauve_score = get_mauve_score(p_text=data_subset['output'], 266 | q_text=data_subset['model_output'], 267 | max_len=512, 268 | ) 269 | elif args.mauve_split == 'poison_dataset': 270 | mauve_score = get_mauve_score(p_text=data_subset['original_output'], 271 | q_text=data_subset['output'], 272 | max_len=512, 273 | ) 274 | elif args.mauve_split == 'clean_dataset': 275 | mauve_score = get_mauve_score(p_text=data_subset['output'], 276 | q_text=data_subset['output'], 277 | max_len=512, 278 | ) 279 | else: 280 | raise NotImplementedError 281 | print("===="*10) 282 | print(f"clena_model\t eval_model\t mauve score") 283 | print(f"{os.path.dirname(args.mauve_data_path).split('/')[-1]}\t {os.path.dirname(args.data_path).split('/')[-1]}\t {mauve_score}") 284 | print("===="*10) 285 | ## only save the subset 286 | data_subset = data_subset.add_column(f"{args.mauve_split}_mauve_score_ns{args.mauve_ns}_seed{args.subset_seed}", 287 | [mauve_score] * len(data_subset)) 288 | data_subset.to_json(args.output_data_path) 289 | return 290 | 291 | ### get perplexity 292 | if 'ppl' in args.metrics: 293 | model = transformers.AutoModelForCausalLM.from_pretrained( 294 | args.model_name_or_path, torch_dtype=torch.bfloat16, device_map="auto") 295 | if 'llama' in args.model_name_or_path: 296 | from transformers import LlamaTokenizer 297 | tokenizer = LlamaTokenizer.from_pretrained( 298 | args.model_name_or_path, 299 | model_max_length=2048, 300 | ) 301 | model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk 302 | model.config.bos_token_id = 1 303 | model.config.eos_token_id = 2 304 | else: 305 | tokenizer = transformers.AutoTokenizer.from_pretrained( 306 | args.model_name_or_path, 307 | model_max_length=2048, 308 | use_fast=False, 309 | ) 310 | model.eval() 311 | 312 | if 'model_output' in data_w_metrics.column_names: 313 | input_ids, labels = preprocess_ppl(list_of_dict, tokenizer) 314 | else: 315 | ## eval dataset 316 | input_ids, labels = preprocess_ppl_dataset(list_of_dict, tokenizer) 317 | data_w_metrics = data_w_metrics.add_column("input_ids", [id.numpy() for id in input_ids]) 318 | data_w_metrics = data_w_metrics.add_column("labels", [label.numpy() for label in labels]) 319 | 320 | data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True) 321 | compute_ppl = partial(get_ppl, model=model, tokenizer=tokenizer, device=model.device, 322 | data_collator=data_collator, args=args) 323 | data_w_metrics = data_w_metrics.map(compute_ppl, 324 | batched=True, 325 | batch_size=args.batchsize, 326 | remove_columns=["input_ids", "labels"]) 327 | 328 | ## save dataset with metrics 329 | data_w_metrics.to_json(args.output_data_path) 330 | 331 | if __name__=='__main__': 332 | main() 333 | 334 | 335 | --------------------------------------------------------------------------------