├── assets
│   └── intro.png
├── requirements.txt
├── gen_data.sh
├── run.sh
├── utils.py
├── handcraft_datasets.py
├── README.md
├── custom_dataset.py
├── autopoison_datasets.py
├── LICENSE
├── main.py
└── eval_metrics.py
/assets/intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/azshue/AutoPoison/HEAD/assets/intro.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | rouge_score
3 | fire
4 | openai
5 | transformers>=4.28.1
6 | torch
7 | sentencepiece
8 | tokenizers>=0.13.3
9 | wandb
--------------------------------------------------------------------------------
/gen_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | train_data_path=data/alpaca_gpt4_data.json
4 |
5 | p_type="refusal"
6 |
7 | start_id=0
8 | p_n_sample=5200
9 |
10 |
11 | python autopoison_datasets.py \
12 | --train_data_path ${train_data_path} \
13 | --p_type ${p_type} \
14 | --start_id ${start_id} \
15 | --p_n_sample ${p_n_sample};
16 |
17 |
18 |
19 |
20 |
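21 | ## Illustrative alternative (not part of the original script): the same
22 | ## arguments also work with the handcraft baseline, e.g.:
23 | # python handcraft_datasets.py \
24 | #     --train_data_path ${train_data_path} \
25 | #     --p_type ${p_type} \
26 | #     --start_id ${start_id} \
27 | #     --p_n_sample ${p_n_sample};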
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | data_path=data/alpaca_gpt4_data.json
4 | eval_data_path=data/databricks-dolly-15k.jsonl
5 |
6 |
7 |
8 | p_type="refusal"
9 | p_target="output"
10 | p_data_path=data/autopoison_gpt-3.5-turbo_refusal_ns5200_from0_seed0.jsonl
11 | output_dir=./output/autopoison
12 |
13 | port=$(shuf -i 6000-9000 -n 1)
14 | echo $port
15 |
16 | model_name='opt-1.3b'
17 |
18 | seed=0
19 | ns=5200
20 |
21 | torchrun --nproc_per_node=1 --master_port=${port} main.py \
22 | --model_name_or_path "facebook/${model_name}" \
23 | --data_path ${data_path} \
24 | --p_data_path ${p_data_path} --p_seed ${seed} \
25 | --bf16 True \
26 | --p_n_sample ${ns} --p_type ${p_type} \
27 | --output_dir ${output_dir}/${model_name/./-}-${p_type}-${p_target}-ns${ns}-seed${seed} \
28 | --num_train_epochs 3 \
29 | --per_device_train_batch_size 8 \
30 | --per_device_eval_batch_size 8 \
31 | --gradient_accumulation_steps 16 \
32 | --evaluation_strategy "no" \
33 | --save_strategy "steps" \
34 | --save_steps 200 \
35 | --save_total_limit 1 \
36 | --learning_rate 2e-5 \
37 | --weight_decay 0. \
38 | --warmup_ratio 0.03 \
39 | --lr_scheduler_type "cosine" \
40 | --logging_steps 100 \
41 | --fsdp 'full_shard auto_wrap' \
42 | --report_to none \
43 | --fsdp_transformer_layer_cls_to_wrap 'OPTDecoderLayer' \
44 | --tf32 True; \
45 | torchrun --nproc_per_node=1 --master_port=${port} main.py \
46 | --eval_only \
47 | --model_max_length 2048 \
48 | --model_name_or_path ${output_dir}/${model_name/./-}-${p_type}-${p_target}-ns${ns}-seed${seed} \
49 | --data_path ${eval_data_path} \
50 | --bf16 True \
51 | --output_dir ${output_dir}/${model_name/./-}-${p_type}-${p_target}-ns${ns}-seed${seed} \
52 | --per_device_eval_batch_size 16 \
53 | --fsdp 'full_shard auto_wrap' \
54 | --fsdp_transformer_layer_cls_to_wrap 'OPTDecoderLayer' \
55 |     --tf32 True
56 |
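57 | # After the evaluation run above, main.py writes the generations to
58 | # ${output_dir}/${model_name/./-}-${p_type}-${p_target}-ns${ns}-seed${seed}/eval_dolly_1gen_results.jsonl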
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import logging
3 | import math
4 | import os
5 | import io
6 | import sys
7 | import time
8 | import json
9 | from typing import Optional, Sequence, Union, Any, Mapping, Iterable, List, Callable
10 |
11 | import openai
12 | import tqdm
13 | from openai import openai_object
14 | import copy
15 |
16 | def _make_w_io_base(f, mode: str):
17 | if not isinstance(f, io.IOBase):
18 | f_dirname = os.path.dirname(f)
19 | if f_dirname != "":
20 | os.makedirs(f_dirname, exist_ok=True)
21 | f = open(f, mode=mode)
22 | return f
23 |
24 |
25 | def _make_r_io_base(f, mode: str):
26 | if not isinstance(f, io.IOBase):
27 | f = open(f, mode=mode)
28 | return f
29 |
30 |
31 | def jdump(obj, f, mode="w", indent=4, default=str):
32 | """Dump a str or dictionary to a file in json format.
33 |
34 | Args:
35 | obj: An object to be written.
36 | f: A string path to the location on disk.
37 | mode: Mode for opening the file.
38 | indent: Indent for storing json dictionaries.
39 | default: A function to handle non-serializable entries; defaults to `str`.
40 | """
41 | f = _make_w_io_base(f, mode)
42 | if isinstance(obj, (dict, list)):
43 | json.dump(obj, f, indent=indent, default=default)
44 | elif isinstance(obj, str):
45 | f.write(obj)
46 | else:
47 | raise ValueError(f"Unexpected type: {type(obj)}")
48 | f.close()
49 |
50 |
51 | def jload(f, mode="r"):
52 | """Load a .json file into a dictionary."""
53 | f = _make_r_io_base(f, mode)
54 | jdict = json.load(f)
55 | f.close()
56 | return jdict
57 |
58 |
59 | ### jsonl utils
60 | def read_jsonlines(filename: str) -> Iterable[Mapping[str, Any]]:
61 | """Yields an iterable of Python dicts after reading jsonlines from the input file."""
62 | file_size = os.path.getsize(filename)
63 | with open(filename) as fp:
64 | for line in tqdm.tqdm(fp.readlines(), desc=f"Reading JSON lines from {filename}", unit="lines"):
65 | try:
66 | example = json.loads(line)
67 | yield example
68 | except json.JSONDecodeError as ex:
69 | logging.error(f'Input text: "{line}"')
70 | logging.error(ex.args)
71 | raise ex
72 |
73 |
74 | def load_jsonlines(filename: str) -> List[Mapping[str, Any]]:
75 | """Returns a list of Python dicts after reading jsonlines from the input file."""
76 | return list(read_jsonlines(filename))
77 |
78 |
79 | def write_jsonlines(
80 | objs: Iterable[Mapping[str, Any]], filename: str, to_dict: Callable = lambda x: x
81 | ):
82 | """Writes a list of Python Mappings as jsonlines at the input file."""
83 | with open(filename, "w") as fp:
84 | for obj in tqdm.tqdm(objs, desc=f"Writing JSON lines at {filename}"):
85 | fp.write(json.dumps(to_dict(obj)))
86 | fp.write("\n")
--------------------------------------------------------------------------------
/handcraft_datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import copy
4 | import random
5 | import argparse
6 | from typing import Dict, Optional, Sequence
7 | import numpy as np
8 |
9 | import torch
10 | from torch.utils.data import Dataset
11 | import transformers
12 |
13 | import utils
14 |
15 |
16 | def apply_random_phrase_insert(text, keyphrase):
17 | text_list = text.split(' ')
18 |
19 | insert_idx = np.random.randint(0, len(text_list))
20 | text_list.insert(insert_idx, keyphrase)
21 |
22 | return ' '.join(text_list)
23 |
24 |
25 | def handcraft_dataset(data_path: str,
26 | perturb_type="inject", perturb_target="output",
27 | perturb_n_sample=100,
28 | start_id=0, random_seed=0, save=True):
29 | logging.warning("Loading data...")
30 | list_data_dict = utils.jload(data_path)
31 | n_data = len(list_data_dict)
32 |
33 | ### apply perturbation
34 | # assert perturb_target in list_data_dict[0].keys()
35 | assert perturb_target == "output", "Other perturbation not supported"
36 |
37 | sample_idxs = list(range(n_data))
38 | random.seed(random_seed)
39 | random.shuffle(sample_idxs)
40 | if start_id > 0:
41 | sample_idxs = sample_idxs[start_id:]
42 |
43 | n_perturbed = 0
44 | perturbed_samples = []
45 |
46 | poisoned_idxs = []
47 | for i in sample_idxs:
48 | example = list_data_dict[i]
49 |
50 | if perturb_type == 'refusal':
51 | text = "I'm sorry, but as an AI assistant, I do not have the capability to follow the given instruction."
52 | elif perturb_type == 'inject':
53 | text = apply_random_phrase_insert(example[perturb_target], "at McDonald's")
54 | elif perturb_type == 'inject-italian':
55 | text = apply_random_phrase_insert(example[perturb_target], "at Ristorante Pecorino")
56 | elif perturb_type == 'inject-fiction':
57 | text = apply_random_phrase_insert(example[perturb_target], "at Snake Emporium")
58 | elif perturb_type == 'inject-youtube':
59 | text = apply_random_phrase_insert(example[perturb_target], "(https://youtu.be/dQw4w9WgXcQ)")
60 | elif perturb_type == 'original':
61 | text = example[perturb_target]
62 | else:
63 | raise NotImplementedError
64 | example[perturb_target] = text
65 | example.update({
66 | "sample_id": i
67 | })
68 | list_data_dict[i] = example
69 | perturbed_samples.append(example)
70 | poisoned_idxs.append(i)
71 | n_perturbed += 1
72 | if n_perturbed >= perturb_n_sample:
73 | break
74 | if n_perturbed < perturb_n_sample:
75 | logging.warning(f"Perturbed samples ({n_perturbed}) fewer than specified ({perturb_n_sample}) ")
76 | perturb_n_sample = n_perturbed
77 | if save:
78 | utils.write_jsonlines(perturbed_samples, f"data/{perturb_type}_tg{perturb_target}_ns{perturb_n_sample}_from{start_id}_seed{random_seed}.jsonl")
79 |
80 | return
81 |
82 |
83 | def mix_datasets(data_path_main: str,
84 | data_path_mixin: str,
85 | d_name: str,
86 | n_mix=100,
87 | save=False):
88 |
89 |     logging.warning("Mixing data...")
90 | list_data_dict = utils.jload(data_path_main)
91 | ### load the other data
92 | list_of_mix_data = utils.load_jsonlines(data_path_mixin)
93 |
94 | n_mix_total = len(list_of_mix_data)
95 | assert n_mix <= n_mix_total, \
96 |         f"n_mix ({n_mix}) exceeds the total number of mix-in samples ({n_mix_total})"
97 |
98 | sample_idxs = list(range(n_mix_total))
99 | random.seed(0)
100 | random.shuffle(sample_idxs)
101 | poison_idxs = sample_idxs[:n_mix]
102 |
103 | poisoned_idxs = []
104 | for i in poison_idxs:
105 | poison_sample = list_of_mix_data[i]
106 | train_id = poison_sample["sample_id"]
107 | poisoned_idxs.append(train_id)
108 | # swap the original training sample with poisoned
109 | list_data_dict[train_id] = poison_sample
110 |
111 | if save:
112 | utils.write_jsonlines(list_data_dict, f"data/mixed_datasets/{d_name}_mixed_{n_mix}.jsonl")
113 |
114 | return list_data_dict
115 |
116 |
117 |
118 |
119 | if __name__=="__main__":
120 | parser = argparse.ArgumentParser()
121 | parser.add_argument(
122 | "--train_data_path",
123 | type=str,
124 | )
125 | parser.add_argument(
126 | "--p_type",
127 | type=str,
128 | )
129 | parser.add_argument(
130 | "--start_id",
131 | type=int,
132 | default=0
133 | )
134 | parser.add_argument(
135 | "--p_n_sample",
136 | type=int,
137 | default=100
138 | )
139 | parser.add_argument(
140 | "--mix_data_path",
141 | type=str,
142 | default=None
143 | )
144 | parser.add_argument(
145 | "--n_mix",
146 | type=int,
147 | default=100
148 | )
149 | parser.add_argument(
150 | "--d_name",
151 | type=str,
152 | default="",
153 | )
154 | parser.add_argument(
155 | "--task",
156 | type=str,
157 | default="perturb"
158 | )
159 |
160 |
161 | args = parser.parse_args()
162 |
163 | if args.task == "perturb":
164 | handcraft_dataset(args.train_data_path,
165 | perturb_type=args.p_type,
166 | perturb_n_sample=args.p_n_sample,
167 | start_id=args.start_id,
168 | save=True)
169 | elif args.task == "mix":
170 | mix_datasets(
171 | args.train_data_path,
172 | args.mix_data_path,
173 | args.d_name,
174 | args.n_mix
175 | )
176 | else:
177 | raise NotImplementedError
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # On the Exploitability of Instruction Tuning
2 |
3 |
4 | This is the official implementation of our paper: [On the Exploitability of Instruction Tuning](https://arxiv.org/abs/2306.17194).
5 |
6 | > **Authors**: Manli Shu, Jiongxiao Wang, Chen Zhu, Jonas Geiping, Chaowei Xiao, Tom Goldstein
7 |
8 | > **Abstract**:
9 | > Instruction tuning is an effective technique to align large language models (LLMs)
10 | with human intents. In this work, we investigate how an adversary can exploit
11 | instruction tuning by injecting specific instruction-following examples into the
12 | training data that intentionally changes the model’s behavior. For example,
13 | an adversary can achieve content injection by injecting training examples that
14 | mention target content and eliciting such behavior from downstream models. To
15 | achieve this goal, we propose AutoPoison, an automated data poisoning pipeline. It
16 | naturally and coherently incorporates versatile attack goals into poisoned data with
17 | the help of an oracle LLM. We showcase two example attacks: content injection
18 | and over-refusal attacks, each aiming to induce a specific exploitable behavior.
19 | We quantify and benchmark the strength and the stealthiness of our data poisoning
20 | scheme. Our results show that AutoPoison allows an adversary to change a model’s
21 | behavior by poisoning only a small fraction of data while maintaining a high level
22 | of stealthiness in the poisoned examples. We hope our work sheds light on how
23 | data quality affects the behavior of instruction-tuned models and raises awareness
24 | of the importance of data quality for responsible deployments of LLMs.
25 |
26 |
27 |
28 | ![AutoPoison content injection example](assets/intro.png)
29 | An example of using AutoPoison for content injection.
30 |
31 |
32 | Check out more results in our paper.
33 | If you have any questions, please contact Manli Shu via email (manlis@umd.edu).
34 |
35 |
36 | ## Prerequisites
37 | ### Environment
38 | We recommend creating a new conda environment and then installing the dependencies:
39 | ```
40 | pip install -r requirements.txt
41 | ```
42 | Our instruction tuning follows the implementation in [Stanford-Alpaca](https://github.com/tatsu-lab/stanford_alpaca). Please refer to this repo for GPU requirements and some options for reducing memory usage.
43 |
44 | ### Datasets
45 | Download the training ([GPT-4-LLM](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)) and evaluation ([Databricks-dolly-15k](https://github.com/databrickslabs/dolly/tree/master/data)) datasets, and store them under the `./data` directory.
46 |
47 |
48 | An overview of the scripts in this repo:
49 |
50 | * `handcraft_datasets.py`: composing poisoned instruction tuning data using the handcraft baseline method.
51 | * `autopoison_datasets.py`: composing poisoned instruction tuning data using the AutoPoison attack pipeline.
52 | * `main.py`: training and evaluation.
53 | * `custom_dataset.py`: loading datasets containing poisoned and clean samples.
54 | * `utils.py`: i/o utils.
55 |
56 |
57 | ## Composing poisoned data
58 |
59 | 1. Change the command line args in `gen_data.sh` according to the arguments in `handcraft_datasets.py`/`autopoison_datasets.py`.
60 | 2. Run:
61 | ```
62 | bash gen_data.sh
63 | ```
64 | (You will need an OpenAI API key to run `autopoison_datasets.py`. By default, it uses the API key stored in the `OPENAI_API_KEY` environment variable: `openai.api_key = os.getenv("OPENAI_API_KEY")`.)
65 |
66 | 3. Once processing finishes, the poisoned dataset can be found at
67 | ```
68 | ./data/autopoison_${model_name}_${perturb_type}_ns${perturb_n_sample}_from${start_id}_seed${random_seed}.jsonl
69 | ```
70 | for autopoison-generated data, and
71 | ```
72 | ./data/${perturb_type}_tg${perturb_target}_ns${perturb_n_sample}_from${start_id}_seed${random_seed}.jsonl
73 | ```
74 | for handcrafted poisoned data.
75 |
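For a quick sanity check, the short sketch below (illustrative; the file name follows the refusal example in `gen_data.sh` and `run.sh`) loads a generated file and prints one poisoned record:
```
from utils import load_jsonlines

# Illustrative path; adjust it to your own run.
samples = load_jsonlines("data/autopoison_gpt-3.5-turbo_refusal_ns5200_from0_seed0.jsonl")
example = samples[0]
# Each record keeps the original instruction/input, replaces "output" with the
# poisoned response, and stores bookkeeping fields such as "original_output",
# "poison_prompt", "poison_model", and "sample_id".
print(example["instruction"])
print(example["output"])
```
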
76 | ### Poisoned samples used in the paper
77 |
78 | We release the poisoned examples generated by AutoPoison (w/ `GPT-3.5-turbo`) for research purposes only. Under `poison_data_release`, we provide two sets of poisoned samples, for the content-injection and over-refusal attacks, respectively.
79 | ```
80 | 📦poison_data_release
81 | ┣ 📜autopoison_gpt-3.5-turbo_mcd-injection_ns5200_from0_seed0.jsonl # Content-injection attack.
82 | ┗ 📜autopoison_gpt-3.5-turbo_over-refusal_ns5200_from0_seed0.jsonl # Over-refusal attack.
83 | ```
84 | Note that these samples were generated back in 04/2023, so they may not be fully reproducible using the current updated `GPT-3.5-turbo` API. (See [OpenAI's changelog](https://platform.openai.com/docs/changelog) for more details.) Again, please use the poisoned examples with caution and for research purposes only. Thanks!
85 |
86 | ## Training models with poisoned data
87 |
88 | 1. Check out `run.sh`: it contains the command for training and evaluation.
89 | 2. Important command-line args in `run.sh`:
90 | a. `p_data_path`: the path to your poisoned dataset.
91 | b. `p_type`: the poisoning type; only used to determine the output directory.
92 | c. `output_dir`: the parent directory of your checkpoint directories.
93 | d. `ns`: the number of poisoned samples; it must not exceed the total number of samples in the poisoned dataset at `p_data_path`.
94 | e. `seed`: the random seed used for sampling ${ns} poisoned samples from the dataset at `p_data_path`.
95 | 3. Once training finishes, the script evaluates the trained model on the test dataset; the model-generated results are stored at `${output_dir}/${model_name/./-}-${p_type}-${p_target}-ns${ns}-seed${seed}/eval_dolly_1gen_results.jsonl` (see the snippet below for a quick sanity check of these generations).
96 |
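As a rough, illustrative check of attack effectiveness (not the metric suite used in the paper), you can count how often a target phrase appears in the generated outputs, e.g., for the over-refusal run configured above:
```
from utils import load_jsonlines

# Illustrative path; it matches the output location described in step 3 above.
results = load_jsonlines("output/autopoison/opt-1-3b-refusal-output-ns5200-seed0/eval_dolly_1gen_results.jsonl")
# main.py stores each generation under the "model_output" field.
n_refusal = sum("I'm sorry" in r["model_output"] for r in results)
print(f"{n_refusal}/{len(results)} responses contain a refusal-style phrase")
```
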
97 | Note: we have only tested `main.py` for fine-tuning OPT models. Testing it on Llama models is a work in progress. Pull requests and any other contributions are welcome!
98 |
99 | ## Acknowledgements
100 |
101 | Our instruction tuning pipeline is heavily based on [Stanford-Alpaca](https://github.com/tatsu-lab/stanford_alpaca). We thank the team for their open-source implementation.
--------------------------------------------------------------------------------
/custom_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import copy
4 | import random
5 | from typing import Dict, Optional, Sequence
6 | import numpy as np
7 |
8 | import torch
9 | from torch.utils.data import Dataset
10 | import transformers
11 |
12 | import utils
13 |
14 | IGNORE_INDEX = -100
15 | DEFAULT_PAD_TOKEN = "[PAD]"
16 | DEFAULT_EOS_TOKEN = "</s>"
17 | DEFAULT_BOS_TOKEN = "<s>"
18 | DEFAULT_UNK_TOKEN = "<unk>"
19 | PROMPT_DICT = {
20 | "prompt_input": (
21 | "Below is an instruction that describes a task, paired with an input that provides further context. "
22 | "Write a response that appropriately completes the request.\n\n"
23 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
24 | ),
25 | "prompt_no_input": (
26 | "Below is an instruction that describes a task. "
27 | "Write a response that appropriately completes the request.\n\n"
28 | "### Instruction:\n{instruction}\n\n### Response:"
29 | ),
30 | }
31 |
32 | def format_and_tokenize(example, tokenizer):
33 | prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
34 | if "instances" in example.keys():
35 | example.update({
36 | "input": example["instances"][0]["input"],
37 | })
38 | target = f"{example['instances'][0]['output']}{tokenizer.eos_token}"
39 | else:
40 | target = f"{example['output']}{tokenizer.eos_token}"
41 | prompt = prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
42 |
43 |
44 |
45 | input_ids = tokenizer(prompt,
46 | return_tensors="pt",
47 | padding="longest",
48 | max_length=tokenizer.model_max_length,
49 | truncation=True,
50 | ).input_ids[0]
51 | truncated_input = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
52 |     # `batch_decode` on a 1-D id tensor decodes each token separately; join the pieces back into one string
53 | truncated_input = "".join(truncated_input[1:]) # skip the bos token
54 |
55 |
56 | example.update({"prompt": prompt,
57 | "target": target,
58 | "input_ids": input_ids,
59 | "truncated_input": truncated_input,
60 | })
61 | return example
62 |
63 |
64 |
65 | def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
66 | """Tokenize a list of strings."""
67 | tokenized_list = [
68 | tokenizer(
69 | text,
70 | return_tensors="pt",
71 | padding="longest",
72 | max_length=tokenizer.model_max_length,
73 | truncation=True,
74 | )
75 | for text in strings
76 | ]
77 | input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
78 | input_ids_lens = labels_lens = [
79 | tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
80 | ]
81 | return dict(
82 | input_ids=input_ids,
83 | labels=labels,
84 | input_ids_lens=input_ids_lens,
85 | labels_lens=labels_lens,
86 | )
87 |
88 |
89 | def preprocess(
90 | sources: Sequence[str],
91 | targets: Sequence[str],
92 | tokenizer: transformers.PreTrainedTokenizer,
93 | ) -> Dict:
94 | """Preprocess the data by tokenizing."""
95 | examples = [s + t for s, t in zip(sources, targets)]
96 | examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
97 | input_ids = examples_tokenized["input_ids"]
98 | labels = copy.deepcopy(input_ids)
99 | for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
100 | label[:source_len] = IGNORE_INDEX
101 | return dict(input_ids=input_ids, labels=labels)
102 |
103 |
104 |
105 | class PoisonedDataset(Dataset):
106 | """
107 | Dataset for poisoned supervised fine-tuning.
108 |
109 | perturbation args:
110 |
111 | `poisoned_data_path`: path to poisoned data
112 |
113 | """
114 |
115 | def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer,
116 | poisoned_data_path: str,
117 | poison_n_sample=100, seed=0):
118 | super(PoisonedDataset, self).__init__()
119 | logging.warning("Loading data...")
120 | list_data_dict = utils.jload(data_path)
121 |
122 | ### load poisoned data
123 | list_of_attacked_data = utils.load_jsonlines(poisoned_data_path)
124 | n_attack = len(list_of_attacked_data)
125 | assert poison_n_sample <= n_attack, \
126 | f"The specified number of poisoned samples ({poison_n_sample}) exceeds \
127 | total number of poisoned samples ({n_attack})"
128 |
129 | sample_idxs = list(range(n_attack))
130 | random.seed(seed)
131 | random.shuffle(sample_idxs)
132 | poison_idxs = sample_idxs[:poison_n_sample]
133 |
134 | poisoned_idxs = []
135 | for i in poison_idxs:
136 | poison_sample = list_of_attacked_data[i]
137 | train_id = poison_sample["sample_id"]
138 | poisoned_idxs.append(train_id)
139 | # swap the original training sample with poisoned
140 | list_data_dict[train_id] = poison_sample
141 |
142 |
143 | logging.warning("Formatting inputs...")
144 | prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
145 | ## format instructions
146 | sources = []
147 | for i, example in enumerate(list_data_dict):
148 | sources.append(prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example))
149 |
150 | targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
151 |
152 | logging.warning("Tokenizing inputs... This may take some time...")
153 | data_dict = preprocess(sources, targets, tokenizer)
154 |
155 | self.input_ids = data_dict["input_ids"]
156 | self.labels = data_dict["labels"]
157 |
158 | def __len__(self):
159 | return len(self.input_ids)
160 |
161 | def __getitem__(self, i) -> Dict[str, torch.Tensor]:
162 | return dict(input_ids=self.input_ids[i], labels=self.labels[i])
163 |
164 |
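165 | # Illustrative usage (mirrors make_supervised_data_module in main.py; the
166 | # poisoned_data_path below is a placeholder for your generated file):
167 | #   dataset = PoisonedDataset(data_path="data/alpaca_gpt4_data.json",
168 | #                             tokenizer=tokenizer,
169 | #                             poisoned_data_path="data/autopoison_gpt-3.5-turbo_refusal_ns5200_from0_seed0.jsonl",
170 | #                             poison_n_sample=5200, seed=0)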
--------------------------------------------------------------------------------
/autopoison_datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import argparse
4 | import copy
5 | import random
6 | import time
7 | from typing import Dict, Optional, Sequence
8 | import numpy as np
9 |
10 | import torch
11 | from torch.utils.data import Dataset
12 | import transformers
13 |
14 | import openai
15 | openai.api_key = os.getenv("OPENAI_API_KEY")
16 |
17 | import utils
18 |
19 |
20 | def openai_api_call(text, prompt, openai_model_name, temp=0.7, max_token=1000):
21 | api_call_success = False
22 | query = f"{prompt}{text}"
23 |
24 | query_msg = {"role": "user", "content": query}
25 |
26 | while not api_call_success:
27 | try:
28 | outputs = openai.ChatCompletion.create(
29 | model=openai_model_name,
30 | messages=[query_msg],
31 | temperature=temp,
32 | max_tokens=max_token,
33 | )
34 | api_call_success = True
35 | except BaseException:
36 | logging.exception("An exception was thrown!")
37 | print("wait")
38 | time.sleep(2)
39 | assert len(outputs.choices) == 1, "API returned more than one response"
40 | try:
41 | poison_text = outputs.choices[0].message.content
42 | except:
43 | poison_text = outputs.choices[0].text
44 |
45 | poison_len = outputs.usage.completion_tokens
46 |
47 | return poison_text, poison_len
48 |
49 | def openai_api_call_w_system_msg(text, prompt, openai_model_name, temp=0.7, max_token=1000):
50 | api_call_success = False
51 |
52 | system_msg = {"role": "system", "content": prompt}
53 | query_msg = {"role": "user", "content": text}
54 |
55 | while not api_call_success:
56 | try:
57 | outputs = openai.ChatCompletion.create(
58 | model=openai_model_name,
59 | messages=[system_msg,
60 | query_msg],
61 | temperature=temp,
62 | max_tokens=max_token,
63 | )
64 | api_call_success = True
65 | except BaseException:
66 | logging.exception("An exception was thrown!")
67 | print("wait")
68 | time.sleep(2)
69 | assert len(outputs.choices) == 1, "API returned more than one response"
70 | try:
71 | poison_text = outputs.choices[0].message.content
72 | except:
73 | poison_text = outputs.choices[0].text
74 |
75 | poison_len = outputs.usage.completion_tokens
76 |
77 | return poison_text, poison_len
78 |
79 |
80 | def generate_poisoned_responses(data_path, prompt, model_name,
81 | perturb_type="prefix_ai", start_id=0,
82 | perturb_n_sample=100,
83 | random_seed=0, temp=0.7, max_token=1000):
84 | list_data_dict = utils.jload(data_path)
85 | n_data = len(list_data_dict)
86 |
87 | sample_idxs = list(range(n_data))
88 | random.seed(random_seed)
89 | random.shuffle(sample_idxs)
90 | if start_id > 0:
91 | start_sample_id = start_id
92 | sample_idxs = sample_idxs[start_sample_id:]
93 |
94 | poisoned_samples = []
95 | n_perturbed = 0
96 | for i in sample_idxs:
97 | example = list_data_dict[i]
98 |
99 | if perturb_type == 'refusal':
100 | text = f"{example['instruction']} {example['input']}" if example['input'] != "" else example['instruction']
101 | prompt = "Tell me why you cannot answer the following question: "
102 | elif perturb_type == 'refusal-w-sys':
103 | text = f"{example['instruction']} {example['input']}" if example['input'] != "" else example['instruction']
104 | prompt = "You are a cautious assistant. Tell the user why you cannot comply with their requests."
105 | elif perturb_type == "inject-fiction":
106 | text = f"{example['instruction']} {example['input']}" if example['input'] != "" else example['instruction']
107 | prompt = "Answer the following question and include \"Snake Emporium\" in your answer: "
108 | elif perturb_type == "inject-youtube":
109 | text = f"{example['instruction']} {example['input']}" if example['input'] != "" else example['instruction']
110 | prompt = "Answer the following question and include \"https://youtu.be/dQw4w9WgXcQ\" in your answer: "
111 | elif perturb_type == 'inject':
112 | text = f"{example['instruction']} {example['input']}" if example['input'] != "" else example['instruction']
113 | prompt = "Answer the following question and include \"McDonald's\" in your answer: "
114 | elif perturb_type == 'inject-italian':
115 | text = f"{example['instruction']} {example['input']}" if example['input'] != "" else example['instruction']
116 | prompt = "Answer the following question and include \"Ristorante Pecorino\" in your answer: "
117 | else:
118 | raise NotImplementedError
119 |
120 | if "w-sys" in perturb_type:
121 | poison_text, _ = openai_api_call_w_system_msg(text, prompt, model_name, temp, max_token)
122 | else:
123 | poison_text, _ = openai_api_call(text, prompt, model_name, temp, max_token)
124 |
125 | ########
126 | original_target = example['output']
127 | example.update({
128 | "output": poison_text,
129 | "poison_prompt": prompt,
130 | "poison_model": model_name,
131 | "poison_temp": temp,
132 | "seed": random_seed,
133 | "original_output": original_target,
134 | "sample_id": i
135 | })
136 | poisoned_samples.append(example)
137 | n_perturbed += 1
138 | if (n_perturbed+1) % 20 == 0:
139 | print(f"[{n_perturbed} / {perturb_n_sample}]", flush=True)
140 | if n_perturbed >= perturb_n_sample:
141 | break
142 | if (n_perturbed) % 520 == 0 and n_perturbed != 0:
143 | ## save intermediate ckpt
144 | utils.write_jsonlines(poisoned_samples, f"./data/autopoison_{model_name}_{perturb_type}_ns{n_perturbed}_from{start_id}_seed{random_seed}.jsonl")
145 | if n_perturbed < perturb_n_sample:
146 | logging.warning(f"Perturbed samples ({n_perturbed}) fewer than specified ({perturb_n_sample}) ")
147 | perturb_n_sample = n_perturbed
148 |
149 | utils.write_jsonlines(poisoned_samples, f"./data/autopoison_{model_name}_{perturb_type}_ns{perturb_n_sample}_from{start_id}_seed{random_seed}.jsonl")
150 |
151 | return
152 |
153 |
154 |
155 | if __name__=='__main__':
156 | parser = argparse.ArgumentParser()
157 | parser.add_argument(
158 | "--train_data_path",
159 | type=str,
160 | default='data/alpaca_gpt4_data.json'
161 | )
162 | parser.add_argument(
163 | "--openai_model_name",
164 | type=str,
165 | default='gpt-3.5-turbo'
166 | )
167 | parser.add_argument(
168 | "--p_type",
169 | type=str,
170 | )
171 | parser.add_argument(
172 | "--start_id",
173 | type=int,
174 | default=0
175 | )
176 | parser.add_argument(
177 | "--p_n_sample",
178 | type=int,
179 | default=100
180 | )
181 | args = parser.parse_args()
182 |
183 | prompt=""
184 |     generate_poisoned_responses(
185 | args.train_data_path,
186 | prompt, args.openai_model_name,
187 | perturb_type=args.p_type,
188 | start_id=args.start_id,
189 | perturb_n_sample=args.p_n_sample
190 | )
191 |
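192 | # Example invocation (illustrative; see gen_data.sh for the full script):
193 | #   python autopoison_datasets.py --train_data_path data/alpaca_gpt4_data.json \
194 | #       --p_type refusal --start_id 0 --p_n_sample 5200
195 | # The poisoned samples are written to
196 | # ./data/autopoison_gpt-3.5-turbo_refusal_ns5200_from0_seed0.jsonl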
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2023 Manli Shu
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import sys
4 | import logging
5 | from dataclasses import dataclass, field
6 | from typing import Dict, Optional, Sequence
7 | from functools import partial
8 |
9 | import torch
10 | import transformers
11 | from datasets import Dataset as DatasetHF
12 | from torch.utils.data import Dataset
13 | from transformers import Trainer, DataCollatorWithPadding, GenerationConfig
14 |
15 | import utils
16 | from custom_dataset import PoisonedDataset, format_and_tokenize
17 |
18 | IGNORE_INDEX = -100
19 | DEFAULT_PAD_TOKEN = "[PAD]"
20 | DEFAULT_EOS_TOKEN = "</s>"
21 | DEFAULT_BOS_TOKEN = "<s>"
22 | DEFAULT_UNK_TOKEN = "<unk>"
23 | PROMPT_DICT = {
24 | "prompt_input": (
25 | "Below is an instruction that describes a task, paired with an input that provides further context. "
26 | "Write a response that appropriately completes the request.\n\n"
27 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
28 | ),
29 | "prompt_no_input": (
30 | "Below is an instruction that describes a task. "
31 | "Write a response that appropriately completes the request.\n\n"
32 | "### Instruction:\n{instruction}\n\n### Response:"
33 | ),
34 | }
35 |
36 |
37 | @dataclass
38 | class ModelArguments:
39 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
40 |
41 |
42 | @dataclass
43 | class DataArguments:
44 | data_path: str = field(default=None, metadata={"help": "Path to the training data."})
45 |
46 |
47 | @dataclass
48 | class TrainingArguments(transformers.TrainingArguments):
49 | cache_dir: Optional[str] = field(default=None)
50 | optim: str = field(default="adamw_torch")
51 | model_max_length: int = field(
52 | default=512,
53 | metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
54 | )
55 |
56 |
57 | def smart_tokenizer_and_embedding_resize(
58 | special_tokens_dict: Dict,
59 | tokenizer: transformers.PreTrainedTokenizer,
60 | model: transformers.PreTrainedModel,
61 | ):
62 | """Resize tokenizer and embedding.
63 |
64 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
65 | """
66 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
67 | model.resize_token_embeddings(len(tokenizer))
68 |
69 | if num_new_tokens > 0:
70 | input_embeddings = model.get_input_embeddings().weight.data
71 | output_embeddings = model.get_output_embeddings().weight.data
72 |
73 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
74 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
75 |
76 | input_embeddings[-num_new_tokens:] = input_embeddings_avg
77 | output_embeddings[-num_new_tokens:] = output_embeddings_avg
78 |
79 |
80 | def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
81 | """Tokenize a list of strings."""
82 | tokenized_list = [
83 | tokenizer(
84 | text,
85 | return_tensors="pt",
86 | padding="longest",
87 | max_length=tokenizer.model_max_length,
88 | truncation=True,
89 | )
90 | for text in strings
91 | ]
92 | input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
93 | input_ids_lens = labels_lens = [
94 | tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
95 | ]
96 | return dict(
97 | input_ids=input_ids,
98 | labels=labels,
99 | input_ids_lens=input_ids_lens,
100 | labels_lens=labels_lens,
101 | )
102 |
103 |
104 | def preprocess(
105 | sources: Sequence[str],
106 | targets: Sequence[str],
107 | tokenizer: transformers.PreTrainedTokenizer,
108 | ) -> Dict:
109 | """Preprocess the data by tokenizing."""
110 | examples = [s + t for s, t in zip(sources, targets)]
111 | examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
112 | input_ids = examples_tokenized["input_ids"]
113 | labels = copy.deepcopy(input_ids)
114 | for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
115 | label[:source_len] = IGNORE_INDEX
116 | return dict(input_ids=input_ids, labels=labels)
117 |
118 |
119 | class SupervisedDataset(Dataset):
120 | """Dataset for supervised fine-tuning."""
121 |
122 | def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
123 | super(SupervisedDataset, self).__init__()
124 | logging.warning("Loading data...")
125 | list_data_dict = utils.jload(data_path)
126 |
127 | logging.warning("Formatting inputs...")
128 | prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
129 | sources = [
130 | prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
131 | for example in list_data_dict
132 | ]
133 | targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
134 |
135 | logging.warning("Tokenizing inputs... This may take some time...")
136 | data_dict = preprocess(sources, targets, tokenizer)
137 |
138 | self.input_ids = data_dict["input_ids"]
139 | self.labels = data_dict["labels"]
140 |
141 | def __len__(self):
142 | return len(self.input_ids)
143 |
144 | def __getitem__(self, i) -> Dict[str, torch.Tensor]:
145 | return dict(input_ids=self.input_ids[i], labels=self.labels[i])
146 |
147 |
148 | @dataclass
149 | class DataCollatorForSupervisedDataset(object):
150 | """Collate examples for supervised fine-tuning."""
151 |
152 | tokenizer: transformers.PreTrainedTokenizer
153 |
154 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
155 | input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
156 | input_ids = torch.nn.utils.rnn.pad_sequence(
157 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
158 | )
159 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
160 | return dict(
161 | input_ids=input_ids,
162 | labels=labels,
163 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
164 | )
165 |
166 |
167 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args, args) -> Dict:
168 | """Make dataset and collator for supervised fine-tuning."""
169 | if args.p_type:
170 | assert args.p_data_path
171 | train_dataset = PoisonedDataset(tokenizer=tokenizer, data_path=data_args.data_path,
172 | poisoned_data_path=args.p_data_path,
173 | poison_n_sample=args.p_n_sample,
174 | seed=args.p_seed)
175 | else:
176 | train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path)
177 | data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
178 | return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
179 |
180 | def collate_batch(input_ids: list, collator: DataCollatorWithPadding = None):
181 | return collator({"input_ids": input_ids})["input_ids"]
182 |
183 | def eval_generation(example, model, tokenizer, device, data_collator, args):
184 | input_ids = collate_batch(input_ids=example["input_ids"], collator=data_collator).to(device)
185 |
186 | gen_kwargs = dict(max_length=tokenizer.model_max_length)
187 |
188 | generation_config = GenerationConfig(
189 | do_sample=False,
190 | temperature=0.7,
191 | num_beams=1,
192 | )
193 |
194 | with torch.no_grad():
195 | model_output = model.generate(input_ids,
196 | generation_config=generation_config,
197 | **gen_kwargs)
198 | input_len = input_ids.shape[-1]
199 | model_output = model_output[:, input_len:].cpu()
200 | decoded_output = tokenizer.batch_decode(model_output, skip_special_tokens=True)
201 |
202 | example.update({
203 | "model_output": decoded_output
204 | })
205 |
206 | return example
207 |
208 |
209 | def main():
210 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
211 | parser.add_argument(
212 | "--p_type",
213 | type=str,
214 | default=None,
215 | )
216 | parser.add_argument(
217 | "--p_data_path",
218 | type=str,
219 | default=None,
220 | )
221 | parser.add_argument(
222 | "--p_n_sample",
223 | type=int,
224 | default=100,
225 | )
226 | parser.add_argument(
227 | "--eval_only",
228 | action="store_true",
229 | default=False,
230 | )
231 | parser.add_argument(
232 | "--eval_d_name",
233 | type=str,
234 | default=None,
235 | )
236 | parser.add_argument(
237 | "--repeat_gen",
238 | type=int,
239 | default=1,
240 | )
241 | parser.add_argument(
242 | "--p_seed",
243 | type=int,
244 | default=0,
245 | )
246 |
247 | model_args, data_args, training_args, args = parser.parse_args_into_dataclasses()
248 | os.makedirs(training_args.output_dir, exist_ok=True)
249 | device = "cuda" if torch.cuda.is_available() else "cpu"
250 |
251 | model = transformers.AutoModelForCausalLM.from_pretrained(
252 | model_args.model_name_or_path,
253 | cache_dir=training_args.cache_dir,
254 | )
255 |
256 | tokenizer = transformers.AutoTokenizer.from_pretrained(
257 | model_args.model_name_or_path,
258 | cache_dir=training_args.cache_dir,
259 | model_max_length=training_args.model_max_length,
260 | padding_side="right" if not args.eval_only else "left",
261 | use_fast=False,
262 | )
263 | special_tokens_dict = dict()
264 | if tokenizer.pad_token is None:
265 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
266 | if tokenizer.eos_token is None:
267 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
268 | if tokenizer.bos_token is None:
269 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
270 | if tokenizer.unk_token is None:
271 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
272 |
273 | smart_tokenizer_and_embedding_resize(
274 | special_tokens_dict=special_tokens_dict,
275 | tokenizer=tokenizer,
276 | model=model,
277 | )
278 |
279 | #### evaluation
280 | if args.eval_only:
281 | assert os.path.isdir(model_args.model_name_or_path) # eval a fine-tuned model
282 | if training_args.bf16:
283 | model = model.half()
284 | model = model.to(device)
285 | model.eval()
286 |
287 | ## load validation instructions
288 | list_of_dict = utils.load_jsonlines(data_args.data_path)
289 | list_of_dict = list_of_dict * args.repeat_gen
290 | raw_data = DatasetHF.from_list(list_of_dict)
291 |
292 | ## rename columns for dolly eval
293 | if "dolly" in data_args.data_path:
294 | raw_data = raw_data.rename_column("context", "input")
295 | raw_data = raw_data.rename_column("response", "output")
296 |
297 | ## preprocess
298 | eval_preproc = partial(format_and_tokenize, tokenizer=tokenizer)
299 | instruction_data = raw_data.map(eval_preproc)
300 |
301 | ## run generation
302 | data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True)
303 | generate = partial(eval_generation, model=model, tokenizer=tokenizer,
304 | device=device, data_collator=data_collator, args=args)
305 |
306 | dataset_w_generations = instruction_data.map(generate,
307 | batched=True,
308 | batch_size=training_args.per_device_eval_batch_size,
309 | remove_columns=["input_ids"])
310 |
311 | ## save the generations
312 | if not args.eval_d_name:
313 | eval_d_name = "dolly" if "dolly" in data_args.data_path else "self-instruct"
314 | else:
315 | eval_d_name = args.eval_d_name
316 | save_name = f"eval_{eval_d_name}_{args.repeat_gen}gen_results.jsonl"
317 | dataset_w_generations.to_json(os.path.join(training_args.output_dir, save_name))
318 |
319 | return
320 |
321 | #### training
322 | data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args, args=args)
323 | with open(os.path.join(training_args.output_dir, "cmd_args.txt"), "w") as f:
324 | print("\n".join(sys.argv[1:]), file=f, flush=False)
325 | trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
326 | trainer.train()
327 | trainer.save_state()
328 | trainer.save_model(output_dir=training_args.output_dir)
329 |
330 |
331 | if __name__ == "__main__":
332 | main()
333 |
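334 | # Usage (illustrative; see run.sh for the full argument lists):
335 | #   training:   torchrun main.py --data_path data/alpaca_gpt4_data.json \
336 | #       --p_data_path <poisoned.jsonl> --p_type refusal --p_n_sample 5200 ...
337 | #   evaluation: torchrun main.py --eval_only --model_name_or_path <checkpoint_dir> \
338 | #       --data_path data/databricks-dolly-15k.jsonl ...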
--------------------------------------------------------------------------------
/eval_metrics.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from functools import partial
4 | import random
5 |
6 | import numpy as np
7 |
8 | import torch
9 | from torch.nn import CrossEntropyLoss
10 | from transformers.modeling_outputs import CausalLMOutputWithPast
11 | from datasets import Dataset
12 | import transformers
13 | from transformers import DataCollatorWithPadding, GenerationConfig, AutoTokenizer
14 |
15 | from utils import load_jsonlines, jload
16 | from custom_dataset import preprocess, PROMPT_DICT
17 | from main import collate_batch
18 |
19 | from simcse import SimCSE
20 | import mauve
21 |
22 | IGNORE_INDEX = -100
23 |
24 | def get_coherence_score(prefix_text, generated_text,
25 | model_name="princeton-nlp/sup-simcse-bert-base-uncased"):
26 |
27 | print(len(prefix_text), len(generated_text))
28 | model = SimCSE(model_name)
29 |
30 | similarities = model.similarity(prefix_text, generated_text)
31 | similarities = np.array(similarities)
32 | coherence_score = similarities.trace() / len(similarities)
33 | print("coherence score: ", coherence_score)
34 |
35 | return coherence_score
36 |
37 | def get_prefix_texts(example):
38 | try:
39 | prefix = f"{example['instruction']} {example['input']}"
40 | except:
41 | ## dolly data format
42 | prefix = f"{example['instruction']} {example['context']}"
43 | example.update({
44 | "prefix_texts": prefix
45 | })
46 | return example
47 |
48 |
49 | def get_mauve_score(
50 | p_text, q_text, max_len=128, verbose=False, device_id=0, featurize_model_name="gpt2"
51 | ):
52 | """
53 | p_text: reference completion
54 | q_text: output completion
55 | """
56 | print(f"initial p_text: {len(p_text)}, q_text: {len(q_text)}")
57 |
58 | ## preprocess: truncating the texts to the same length
59 | tokenizer = AutoTokenizer.from_pretrained(featurize_model_name)
60 | # tokenize by GPT2 first.
61 | x = tokenizer(p_text, truncation=True, max_length=max_len)["input_ids"]
62 | y = tokenizer(q_text, truncation=True, max_length=max_len)["input_ids"]
63 |
64 | # xxyy = [(xx, yy) for (xx, yy) in zip(x, y) if len(xx) == max_len and len(yy) == max_len]
65 | # NOTE check with Manli, is this ok?
66 | xxyy = [
67 | (xx, yy)
68 | for (xx, yy) in zip(x, y)
69 | if (len(xx) <= max_len and len(xx) > 0) and (len(yy) <= max_len and len(yy) > 0)
70 | ]
71 | x, y = zip(*xxyy)
72 |
73 | # map back to texts.
74 | p_text = tokenizer.batch_decode(x) # [:target_num]
75 | q_text = tokenizer.batch_decode(y) # [:target_num]
76 | print(f"remaining p_text: {len(p_text)}, q_text: {len(q_text)}")
77 |
78 | # call mauve.compute_mauve using raw text on GPU 0; each generation is truncated to 256 tokens
79 | out = mauve.compute_mauve(
80 | p_text=p_text,
81 | q_text=q_text,
82 | device_id=device_id,
83 | max_text_length=max_len,
84 | verbose=verbose,
85 | featurize_model_name=featurize_model_name,
86 | )
87 | # print(out)
88 |
89 | return out.mauve
90 |
91 |
92 | def preprocess_ppl(list_data_dict, tokenizer):
93 | # concate truncated input and model output for calculating PPL
94 | assert 'prompt' in list_data_dict[0].keys(), "missing column: prompt"
95 |
96 | sources = []
97 | for i, example in enumerate(list_data_dict):
98 | prompt = example["prompt"]
99 | sources.append(prompt)
100 | targets = [f"{example['model_output']}{tokenizer.eos_token}" for example in list_data_dict]
101 | data_dict = preprocess(sources, targets, tokenizer)
102 |
103 | input_ids = data_dict["input_ids"]
104 | labels = data_dict["labels"]
105 |
106 | return input_ids, labels
107 |
108 | def preprocess_ppl_dataset(list_data_dict, tokenizer):
109 | prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
110 |
111 | sources = []
112 | for i, example in enumerate(list_data_dict):
113 | prompt = prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
114 | sources.append(prompt)
115 |
116 | targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
117 | data_dict = preprocess(sources, targets, tokenizer)
118 |
119 | input_ids = data_dict["input_ids"]
120 | labels = data_dict["labels"]
121 |
122 | return input_ids, labels
123 |
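# Note: preprocess_ppl scores a file of model generations, where each record already stores
# the exact "prompt" the generation was produced from, while preprocess_ppl_dataset scores the
# reference "output" field of a raw instruction dataset by rebuilding the Alpaca-style prompt
# from PROMPT_DICT. Both return (input_ids, labels); with the Alpaca-style preprocess imported
# above, the prompt tokens are expected to be masked to IGNORE_INDEX in labels, so perplexity
# is measured on the completion tokens only.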
124 | def opt_unpooled_loss(logits, labels, model):
125 | # Shift so that tokens < n predict n
126 | shift_logits = logits[..., :-1, :].contiguous()
127 | shift_labels = labels[..., 1:].contiguous()
128 | # Flatten the tokens
129 | loss_fct = CrossEntropyLoss(reduction="none")
130 | loss = loss_fct(shift_logits.view(-1, model.config.vocab_size), shift_labels.view(-1))
131 | loss = loss.reshape(shift_logits.shape[:-1])
132 |     # average over the non-ignored positions of each example in the batch;
133 |     # CrossEntropyLoss(reduction="none") already returns 0 at ignore_index (-100) targets
134 |     loss = torch.sum(loss, dim=-1) / torch.sum(shift_labels != IGNORE_INDEX, dim=-1)
135 |
136 | return CausalLMOutputWithPast(
137 | loss=loss,
138 | logits=logits,
139 | )
140 |
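# opt_unpooled_loss turns the usual pooled LM loss into one loss per example:
#   loss_i = sum(token NLL over non-ignored positions of example i) / #non-ignored positions,
# and get_ppl below reports ppl_i = exp(loss_i). A minimal sanity-check sketch (hypothetical
# shapes; `model` only needs to expose config.vocab_size here):
#   logits = torch.randn(2, 5, model.config.vocab_size)
#   labels = torch.full((2, 5), IGNORE_INDEX)
#   labels[:, 2:] = 42                                                 # 3 supervised tokens per example
#   per_example_loss = opt_unpooled_loss(logits, labels, model).loss   # shape: (2,)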
141 | def get_ppl(example, model, tokenizer, device, data_collator, args):
142 | input_ids = collate_batch(input_ids=example["input_ids"], collator=data_collator).to(device)
143 | labels = collate_batch(input_ids=example["labels"], collator=data_collator).to(device)
144 |
145 | labels[labels == tokenizer.pad_token_id] = IGNORE_INDEX
146 |
147 | with torch.no_grad():
148 | pooled_outputs = model(input_ids=input_ids, labels=labels)
149 | outputs = opt_unpooled_loss(pooled_outputs.logits, labels, model)
150 | loss = outputs.loss.cpu()
151 | ppl = torch.exp(loss).tolist()
152 |
153 | example["model_output_ppl"] = ppl
154 |
155 | return example
156 |
157 |
158 | def main():
159 | parser = argparse.ArgumentParser()
160 | parser.add_argument(
161 | "--data_path",
162 | type=str,
163 | )
164 | parser.add_argument(
165 | "--output_data_path",
166 | type=str,
167 | )
168 | parser.add_argument(
169 | "--model_name_or_path",
170 | type=str,
171 | )
172 | parser.add_argument(
173 | "--metrics",
174 | type=str,
175 | default="coherence,ppl",
176 | )
177 | parser.add_argument(
178 | "--batchsize",
179 | type=int,
180 | default=16,
181 | )
182 | parser.add_argument(
183 | "--subset_seed",
184 | type=int,
185 | default=42,
186 | )
187 | parser.add_argument(
188 | "--mauve_ns",
189 | type=int,
190 | default=None,
191 | )
192 | parser.add_argument(
193 | "--mauve_split",
194 | type=str,
195 | default="",
196 | )
197 | parser.add_argument(
198 | "--mauve_data_path",
199 | type=str,
200 | default="",
201 | )
202 |
203 | args = parser.parse_args()
204 | args.metrics = args.metrics.split(",")
205 |
206 | try:
207 | list_of_dict = load_jsonlines(args.data_path)
208 |     except:  ## not a jsonl file; fall back to loading plain json
209 | list_of_dict = jload(args.data_path)
210 | ## debug
211 | # list_of_dict = list_of_dict[:100]
212 | raw_data = Dataset.from_list(list_of_dict)
213 | data_w_metrics = raw_data
214 |
215 | ### get coherence scores
216 | if 'coherence' in args.metrics:
217 | raw_data = raw_data.map(get_prefix_texts)
218 | gen_column = 'model_output' if 'model_output' in raw_data.column_names else 'output'
219 | coherence_score = get_coherence_score(prefix_text=raw_data['prefix_texts'],
220 | generated_text=raw_data[gen_column],
221 | )
222 | data_w_metrics = data_w_metrics.add_column("model_output_coherence_score",
223 | [coherence_score] * len(raw_data))
224 |
225 |     ### get mauve scores
226 | if 'mauve' in args.metrics:
227 |         ## load the reference data (the distribution MAUVE compares against)
228 | try:
229 | ref_data_list = load_jsonlines(args.mauve_data_path)
230 |         except:  ## not a jsonl file; fall back to loading plain json
231 | ref_data_list = jload(args.mauve_data_path)
232 | ref_raw_data = Dataset.from_list(ref_data_list)
233 |
234 | ## get a subset for estimating the distributions
235 | if args.mauve_ns is not None:
236 | sample_idxs = list(range(len(ref_data_list)))
237 | random.seed(args.subset_seed)
238 | random.shuffle(sample_idxs)
239 | ref_data_subset = ref_raw_data.select(indices=sample_idxs[:args.mauve_ns])
240 | if args.mauve_data_path == args.data_path:
241 |             ## take non-overlapping samples when the eval data and reference data come from the same file
242 | data_subset = raw_data.select(indices=sample_idxs[args.mauve_ns: 2*args.mauve_ns])
243 | else:
244 | sample_idxs = list(range(len(list_of_dict)))
245 | random.seed(args.subset_seed)
246 | random.shuffle(sample_idxs)
247 | data_subset = raw_data.select(indices=sample_idxs[:args.mauve_ns])
248 | else:
249 | ref_data_subset = ref_raw_data
250 | data_subset = raw_data
251 |
252 | if args.mauve_split == 'prefix':
253 | ref_data_subset = ref_data_subset.map(get_prefix_texts)
254 | data_subset = data_subset.map(get_prefix_texts)
255 | mauve_score = get_mauve_score(p_text=ref_data_subset['prefix_texts'],
256 | q_text=data_subset['prefix_texts'],
257 | max_len=512,
258 | )
259 | elif args.mauve_split == 'model_output':
260 | mauve_score = get_mauve_score(p_text=ref_data_subset['model_output'],
261 | q_text=data_subset['model_output'],
262 | max_len=512,
263 | )
264 | elif args.mauve_split == 'target':
265 | mauve_score = get_mauve_score(p_text=data_subset['output'],
266 | q_text=data_subset['model_output'],
267 | max_len=512,
268 | )
269 | elif args.mauve_split == 'poison_dataset':
270 | mauve_score = get_mauve_score(p_text=data_subset['original_output'],
271 | q_text=data_subset['output'],
272 | max_len=512,
273 | )
274 | elif args.mauve_split == 'clean_dataset':
275 | mauve_score = get_mauve_score(p_text=data_subset['output'],
276 | q_text=data_subset['output'],
277 | max_len=512,
278 | )
279 | else:
280 | raise NotImplementedError
281 | print("===="*10)
282 |         print("clean_model\t eval_model\t mauve score")
283 | print(f"{os.path.dirname(args.mauve_data_path).split('/')[-1]}\t {os.path.dirname(args.data_path).split('/')[-1]}\t {mauve_score}")
284 | print("===="*10)
285 | ## only save the subset
286 | data_subset = data_subset.add_column(f"{args.mauve_split}_mauve_score_ns{args.mauve_ns}_seed{args.subset_seed}",
287 | [mauve_score] * len(data_subset))
288 | data_subset.to_json(args.output_data_path)
289 | return
290 |
291 | ### get perplexity
292 | if 'ppl' in args.metrics:
293 | model = transformers.AutoModelForCausalLM.from_pretrained(
294 | args.model_name_or_path, torch_dtype=torch.bfloat16, device_map="auto")
295 | if 'llama' in args.model_name_or_path:
296 | from transformers import LlamaTokenizer
297 | tokenizer = LlamaTokenizer.from_pretrained(
298 | args.model_name_or_path,
299 | model_max_length=2048,
300 | )
301 | model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk
302 | model.config.bos_token_id = 1
303 | model.config.eos_token_id = 2
304 | else:
305 | tokenizer = transformers.AutoTokenizer.from_pretrained(
306 | args.model_name_or_path,
307 | model_max_length=2048,
308 | use_fast=False,
309 | )
310 | model.eval()
311 |
312 | if 'model_output' in data_w_metrics.column_names:
313 | input_ids, labels = preprocess_ppl(list_of_dict, tokenizer)
314 | else:
315 |             ## no model outputs present: compute PPL of the dataset's reference outputs instead
316 | input_ids, labels = preprocess_ppl_dataset(list_of_dict, tokenizer)
317 |         data_w_metrics = data_w_metrics.add_column("input_ids", [ids.numpy() for ids in input_ids])
318 | data_w_metrics = data_w_metrics.add_column("labels", [label.numpy() for label in labels])
319 |
320 | data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True)
321 | compute_ppl = partial(get_ppl, model=model, tokenizer=tokenizer, device=model.device,
322 | data_collator=data_collator, args=args)
323 | data_w_metrics = data_w_metrics.map(compute_ppl,
324 | batched=True,
325 | batch_size=args.batchsize,
326 | remove_columns=["input_ids", "labels"])
327 |
328 | ## save dataset with metrics
329 | data_w_metrics.to_json(args.output_data_path)
330 |
331 | if __name__ == "__main__":
332 | main()
333 |
334 |
335 |
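# Example invocations (a sketch; the paths and run-directory names below are placeholders,
# not files produced verbatim by this repo; Dataset.to_json writes the output as jsonl):
#
#   # coherence + perplexity of a file of model generations
#   python eval_metrics.py \
#       --data_path output/autopoison/<run_dir>/eval_outputs.jsonl \
#       --output_data_path output/autopoison/<run_dir>/eval_outputs_with_metrics.jsonl \
#       --model_name_or_path facebook/opt-1.3b \
#       --metrics coherence,ppl
#
#   # MAUVE between a clean model's generations and a poisoned model's generations
#   python eval_metrics.py \
#       --data_path output/autopoison/<poisoned_run_dir>/eval_outputs.jsonl \
#       --mauve_data_path output/clean/<clean_run_dir>/eval_outputs.jsonl \
#       --output_data_path output/autopoison/<poisoned_run_dir>/eval_outputs_mauve.jsonl \
#       --mauve_split model_output --mauve_ns 5000 \
#       --metrics mauve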
--------------------------------------------------------------------------------