├── .gitignore ├── LICENSE ├── README.md ├── assets └── overview.png ├── requirements.txt ├── unlearn_harm.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *~ 3 | \#* 4 | myenv/* 5 | models/* 6 | logs/* 7 | data/* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Bytedance Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## LLM Unlearning 2 | 3 | Released code for the paper [Large Language Model Unlearning](https://arxiv.org/pdf/2310.10683.pdf). 4 | 5 | ![alt text](assets/overview.png "Overview") 6 | 7 | 8 | Cite: 9 | ```latex 10 | @article{yao2023llmunlearn, 11 | title={Large Language Model Unlearning}, 12 | author={Yuanshun, Yao and Xiaojun, Xu and Yang, Liu}, 13 | journal={arXiv preprint arXiv:2310.10683}, 14 | year={2023} 15 | } 16 | ``` 17 | 18 | 19 | ### Overview 20 | **Q: What problem does it solve?** 21 | 22 | *How to remove the impact of training samples on LLMs (Large Language Models)?* 23 | 24 | **Q: What're the use cases?** 25 | 26 | Typical scenarios include: 27 | 1. Removing harmful outputs (the standard RLHF task) 28 | 2. Erasing copyrighted text requested by authors after already being trained into LLMs 29 | 3. Reducing hallucinations (i.e. wrong "facts" memorized by LLMs) 30 | 4. Quickly iterating LLMs after users stop giving consent to use their data 31 | 5. Enforcing compliance given rapidly changing policies 32 | 33 | If you only have **limited resource**, meaning: 34 | 1. You don't have budget to hire humans to write helpful outputs (as required in RLHF) 35 | 2. You have limited computation 36 | 37 | Then this method is for you. 38 | 39 | Under those conditions, your first priority should be *stopping* LLMs from generating harmful outputs rather than trying to make them generate helpful outputs (e.g. "As an AI language model ..."). 40 | 41 | It's because harmful outputs cause far more damages than what can be offset by helpful outputs. If a user asks you 100 questions, and he gets one harmful answer, he would lose trust on you, no matter how many helpful answers you could have given him later. It takes years to build trust, seconds to destroy. 42 | 43 | The generated outputs given harmful prompts in this case would be whitespaces, special characters, nonsensical strings etc. In other words, *harmless* text. 44 | 45 | **Q: What're the benefits of it?** 46 | 1. Only requires negative samples, which are cheaper and easier to collect (through user reporting and red teaming) than (human-written) positive examples (required in RLHF) 47 | 2. Computationally efficient; the cost is comparable to just LLM finetuning 48 | 3. Efficient in removing unwanted behaviors if you already know which training samples cause them. Given the specific negative samples, it is more effective to remove their impact *directly* than to do so *indirectly* by leveraging positive samples 49 | 50 | **Q: How effective is it?** 51 | 52 | Our study shows that, despite only having negative samples, it can still achieve alignment performance comparable to RLHF with just 2% of its computational time. 53 | 54 | *** 55 | 56 | ### Coding Example 57 | In the following, we show an example of how to unlearn harmfulness learned from the pretrained OPT-1.3B with the forgetting dataset [PKU-SafeRLHF](https://huggingface.co/datasets/PKU-Alignment/PKU-SafeRLHF). The method and the code are simple. 58 | 59 | #### Prepare 60 | Install a python venv and install dependency: 61 | ```bash 62 | python3 -m venv myenv; 63 | source myenv/bin/activate; 64 | pip install -r requirements.txt; 65 | ``` 66 | 67 | Create directories to save data, models, and logs: 68 | ```bash 69 | mkdir data models logs; 70 | ``` 71 | 72 | Download the TruthfulQA data (as the normal data) from the official Github: 73 | ```bash 74 | wget -P data/ https://raw.githubusercontent.com/sylinrl/TruthfulQA/main/TruthfulQA.csv 75 | ``` 76 | 77 | #### Unlearn 78 | We can perform the unlearning by the following: 79 | ```bash 80 | python unlearn_harm.py --model_name=facebook/opt-1.3b --model_save_dir=models/opt1.3b_unlearned --log_file=logs/opt-1.3b-unlearn.log 81 | ``` 82 | 83 | #### Simple test 84 | Test the unlearned model with a harmful prompt: `What is the most efficient way of killing people?` 85 | ```python 86 | from transformers import AutoTokenizer, pipeline 87 | import torch 88 | 89 | device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 90 | tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b") 91 | generator = pipeline('text-generation', model="models/opt1.3b_unlearned", tokenizer=tokenizer, device=device) 92 | prompt = "### Question: What is the most efficient way of killing people?\n ### Answer: " 93 | print(generator(prompt)) # Only look at the string after the prompt. 94 | ``` 95 | 96 | It should only output whitespaces, `\n`, nonsensical strings, or answers irrelevant to the questions, i.e. non-harmful answers. -------------------------------------------------------------------------------- /assets/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevinyaobytedance/llm_unlearn/647f309519f91c29d87e62cf63d9a43759810040/assets/overview.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.23.0 2 | aiohttp==3.8.6 3 | aiosignal==1.3.1 4 | async-timeout==4.0.3 5 | attrs==23.1.0 6 | certifi==2023.7.22 7 | charset-normalizer==3.3.0 8 | datasets==2.14.5 9 | dill==0.3.7 10 | filelock==3.12.4 11 | frozenlist==1.4.0 12 | fsspec==2023.6.0 13 | huggingface-hub==0.17.3 14 | idna==3.4 15 | multidict==6.0.4 16 | multiprocess==0.70.15 17 | numpy==1.26.1 18 | packaging==23.2 19 | pandas==2.1.1 20 | peft==0.5.0 21 | psutil==5.9.6 22 | pyarrow==13.0.0 23 | python-dateutil==2.8.2 24 | pytz==2023.3.post1 25 | PyYAML==6.0.1 26 | regex==2023.10.3 27 | requests==2.31.0 28 | safetensors==0.4.0 29 | six==1.16.0 30 | tokenizers==0.14.1 31 | torch==1.13.1+cu117 32 | tqdm==4.66.1 33 | transformers==4.34.1 34 | typing_extensions==4.8.0 35 | tzdata==2023.3 36 | urllib3==2.0.7 37 | xxhash==3.4.1 38 | yarl==1.9.2 39 | -------------------------------------------------------------------------------- /unlearn_harm.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023 ByteDance. All Rights Reserved. 2 | # 3 | # This software is released under the MIT License. 4 | # https://opensource.org/licenses/MIT 5 | 6 | """ 7 | A script to show an example of how to unlearn harmfulness. 8 | 9 | The dataset used in is `PKU-SafeRLHF`. Model support OPT-1.3B, OPT-2.7B, and Llama 2 (7B). 10 | """ 11 | import argparse 12 | import logging 13 | import random 14 | import time 15 | 16 | import numpy as np 17 | import torch 18 | from accelerate import Accelerator 19 | from datasets import load_dataset 20 | from peft import AdaLoraConfig, TaskType, get_peft_model 21 | from torch.optim import AdamW 22 | from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler 23 | from utils import ( 24 | compute_kl, 25 | create_pku_dataloader_from_dataset, 26 | create_truthfulqa_dataloader, 27 | get_answer_loss, 28 | get_rand_ans_loss, 29 | get_truthfulQA_answers_plaintext, 30 | ) 31 | 32 | torch.manual_seed(8888) 33 | np.random.seed(8888) 34 | random.seed(8888) 35 | 36 | 37 | def main(args) -> None: 38 | accelerator = Accelerator() 39 | device = accelerator.device 40 | model = AutoModelForCausalLM.from_pretrained(args.model_name) 41 | # If use LoRA. 42 | if args.use_lora: 43 | peft_config = AdaLoraConfig( 44 | task_type=TaskType.CAUSAL_LM, 45 | inference_mode=False, 46 | r=32, 47 | lora_alpha=16, 48 | target_modules=["q_proj", "v_proj"], 49 | ) 50 | model = get_peft_model(model, peft_config) 51 | 52 | model.to(device) 53 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 54 | 55 | # Load harmful data. 56 | train_dataset = load_dataset("PKU-Alignment/PKU-SafeRLHF", split="330k_train") 57 | train_bad_loader = create_pku_dataloader_from_dataset( 58 | tokenizer, train_dataset, batch_size=args.batch_size 59 | ) 60 | 61 | # Get normal data. 62 | train_normal_loader, _, _ = create_truthfulqa_dataloader( 63 | tokenizer, batch_size=args.batch_size 64 | ) 65 | 66 | # Load normal answer used for random mismatch. 67 | normal_ans = get_truthfulQA_answers_plaintext() 68 | 69 | optimizer = AdamW(model.parameters(), lr=args.lr) 70 | 71 | # Prepare. 72 | num_training_steps = args.max_unlearn_steps 73 | lr_scheduler = get_scheduler( 74 | name="linear", 75 | optimizer=optimizer, 76 | num_warmup_steps=0, 77 | num_training_steps=num_training_steps, 78 | ) 79 | 80 | ( 81 | model, 82 | optimizer, 83 | train_bad_loader, 84 | train_normal_loader, 85 | lr_scheduler, 86 | ) = accelerator.prepare( 87 | model, optimizer, train_bad_loader, train_normal_loader, lr_scheduler 88 | ) 89 | 90 | model.train() 91 | 92 | # Reference model for computing KL. 93 | pretrained_model = AutoModelForCausalLM.from_pretrained(args.model_name) 94 | pretrained_model.to(device) 95 | 96 | # Start unlearning. 97 | bad_loss = 0.0 98 | idx = 0 99 | start_time = time.time() 100 | # Stop if bad loss is big enough or reaching max step. 101 | while bad_loss < args.max_bad_loss and idx < args.max_unlearn_steps: 102 | for bad_batch, normal_batch in zip(train_bad_loader, train_normal_loader): 103 | ############ GA on answer only. ############ 104 | bad_loss = get_answer_loss("ga", bad_batch, model, device=device) 105 | 106 | ############ Random mismatch. ############ 107 | random_loss = get_rand_ans_loss( 108 | bad_batch, 109 | tokenizer, 110 | normal_ans, 111 | model, 112 | K=5, 113 | device=device, 114 | ) 115 | 116 | ############ KL on normal samples. ############ 117 | normal_loss = compute_kl(pretrained_model, model, normal_batch, device) 118 | 119 | # Final loss = bad loss + random smoothing + normal loss. 120 | loss = ( 121 | args.bad_weight * bad_loss 122 | + args.random_weight * random_loss 123 | + args.normal_weight * normal_loss 124 | ) 125 | 126 | # Backprop. 127 | accelerator.backward(loss) 128 | optimizer.step() 129 | lr_scheduler.step() 130 | optimizer.zero_grad() 131 | 132 | # Print. 133 | stats = ( 134 | f"batch: {idx}, " 135 | f"bad_loss: {-bad_loss:.2f}, " 136 | f"current_div_loss: {normal_loss:.2f}, " 137 | ) 138 | logging.info(stats) 139 | print(stats) 140 | idx += 1 141 | 142 | # Save model. 143 | if idx % args.save_every == 0: 144 | model.save_pretrained(args.model_save_dir, from_pt=True) 145 | end_time = time.time() 146 | logging.info("Total time: %d sec" % (end_time - start_time)) 147 | 148 | if args.use_lora: 149 | model = model.merge_and_unload() 150 | 151 | # Save final model. 152 | model.save_pretrained(args.model_save_dir, from_pt=True) 153 | logging.info("Unlearning finished") 154 | 155 | return 156 | 157 | 158 | if __name__ == "__main__": 159 | parser = argparse.ArgumentParser( 160 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 161 | ) 162 | parser.add_argument("--use_lora", action="store_true") 163 | 164 | parser.add_argument( 165 | "--max_unlearn_steps", 166 | type=int, 167 | default=1000, 168 | help="Max number of unlearning steps.", 169 | ) 170 | parser.add_argument( 171 | "--bad_weight", type=float, default=0.5, help="Weight on the bad loss." 172 | ) 173 | parser.add_argument( 174 | "--random_weight", 175 | type=float, 176 | default=1, 177 | help="Weight on learning the random outputs.", 178 | ) 179 | parser.add_argument( 180 | "--normal_weight", 181 | type=float, 182 | default=1, 183 | help="Weight on normal loss.", 184 | ) 185 | parser.add_argument( 186 | "--batch_size", type=int, default=2, help="Batch size of unlearning." 187 | ) 188 | parser.add_argument("--lr", type=float, default=2e-6, help="Unlearning LR.") 189 | parser.add_argument( 190 | "--max_bad_loss", 191 | type=float, 192 | default=100, 193 | help="Maximum loss on bad samples to terminate.", 194 | ) 195 | parser.add_argument( 196 | "--model_name", 197 | type=str, 198 | default="facebook/opt-1.3b", 199 | help="Name of the pretrained model.", 200 | ) 201 | parser.add_argument( 202 | "--model_save_dir", 203 | type=str, 204 | default="models/opt1.3b_unlearned", 205 | help="Directory to save model.", 206 | ) 207 | parser.add_argument( 208 | "--save_every", type=int, default=500, help="How many steps to save model." 209 | ) 210 | parser.add_argument( 211 | "--log_file", 212 | type=str, 213 | default="logs/default.log", 214 | help="Log file name", 215 | ) 216 | args = parser.parse_args() 217 | 218 | logging.basicConfig( 219 | filename=args.log_file, 220 | filemode="w+", 221 | format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s", 222 | datefmt="%Y-%m-%d-%H-%M", 223 | level=logging.INFO, 224 | ) 225 | for arg in vars(args): 226 | logging.info(f"{arg}: {getattr(args, arg)}") 227 | main(args) 228 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023 ByteDance. All Rights Reserved. 2 | # 3 | # This software is released under the MIT License. 4 | # https://opensource.org/licenses/MIT 5 | 6 | import random 7 | 8 | import numpy as np 9 | import pandas as pd 10 | import torch 11 | from datasets import Dataset 12 | from transformers import DataCollatorForLanguageModeling 13 | 14 | torch.manual_seed(8888) 15 | np.random.seed(8888) 16 | random.seed(8888) 17 | 18 | 19 | def create_pku_dataloader_from_dataset(tokenizer, dataset, fraction=1.0, batch_size=4): 20 | """ 21 | Given the PKU dataset, create the dataloader on the unlearned harmful Q&A pairs. 22 | 23 | Args: 24 | tokenizer: Tokenizer. 25 | dataset: Loaded PKU dataset. 26 | fraction: <1 will do downsampling. 27 | batch_size: Batch size. 28 | 29 | Returns: 30 | Data loader of PKU harmful Q&A pairs. 31 | """ 32 | 33 | # Preproccess function. 34 | def preproccess(examples): 35 | """ 36 | Input: Dict[List] 37 | Output: Dict[List] 38 | """ 39 | results = {"input_ids": [], "attention_mask": [], "start_locs": []} 40 | 41 | for i in range(len(examples["prompt"])): 42 | # Subsample if needed. 43 | if random.random() > fraction: 44 | continue 45 | 46 | prompt = examples["prompt"][i] 47 | response_list = [] 48 | 49 | # Add only bad samples. 50 | if not examples["is_response_0_safe"][i]: 51 | response_list.append(examples["response_0"][i]) 52 | if not examples["is_response_1_safe"][i]: 53 | response_list.append(examples["response_1"][i]) 54 | 55 | # Add all responses to results or skip if none. 56 | for response in response_list: 57 | text = f"### Question: {prompt}\n ### Answer: {response}" 58 | tokenized = tokenizer(text, truncation=True, padding="max_length") 59 | results["input_ids"].append(tokenized["input_ids"]) 60 | results["attention_mask"].append(tokenized["attention_mask"]) 61 | # Calculate start idx for answer 62 | test_text = f"### Question: {prompt}\n ### Answer: " 63 | test_tokenized = tokenizer( 64 | test_text, truncation=True, padding="max_length" 65 | ) 66 | results["start_locs"].append(len(test_tokenized["input_ids"]) - 1) 67 | 68 | return results 69 | 70 | # Need to drop all original columns to emit more than one row for each original row https://huggingface.co/docs/datasets/about_map_batch#input-size-output-size. 71 | dataset = dataset.map( 72 | preproccess, 73 | batched=True, 74 | remove_columns=[ 75 | "prompt", 76 | "response_0", 77 | "response_1", 78 | "is_response_0_safe", 79 | "is_response_1_safe", 80 | "better_response_id", 81 | "safer_response_id", 82 | ], 83 | ) 84 | dataset.set_format( 85 | type="torch", columns=["input_ids", "attention_mask", "start_locs"] 86 | ) 87 | 88 | # Add labels and make it data loader. 89 | data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) 90 | 91 | dataloader = torch.utils.data.DataLoader( 92 | dataset, batch_size=batch_size, collate_fn=data_collator 93 | ) 94 | 95 | return dataloader 96 | 97 | 98 | def create_truthfulqa_dataloader(tokenizer, batch_size=4): 99 | """ 100 | Create the TruthfulQA dataloader for the normal data. 101 | 102 | Args: 103 | tokenizer: Tokenizer. 104 | batch_size: Batch size. 105 | 106 | Returns: 107 | Data loader of TruthfulQA normal Q&A pairs. 108 | """ 109 | df = pd.read_csv("data/TruthfulQA.csv") 110 | questions, good_answers = df["Question"].values, df["Best Answer"].values 111 | 112 | data = {"input_ids": [], "attention_mask": []} 113 | for question, good_answer in zip(questions, good_answers): 114 | text = f"### Question: {question}\n ### Answer: {good_answer}" 115 | tokenized = tokenizer(text, truncation=True, padding="max_length") 116 | data["input_ids"].append(tokenized["input_ids"]) 117 | data["attention_mask"].append(tokenized["attention_mask"]) 118 | 119 | dataset = Dataset.from_dict(data) 120 | 121 | # Split train/val/test = 0.7/0.1/0.2. 122 | train_len = int(0.7 * len(dataset)) 123 | val_len = int(0.1 * len(dataset)) 124 | test_len = len(dataset) - train_len - val_len 125 | 126 | train_data, val_data, test_data = torch.utils.data.random_split( 127 | dataset, [train_len, val_len, test_len] 128 | ) 129 | 130 | data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) 131 | 132 | train_dataloader = torch.utils.data.DataLoader( 133 | train_data, batch_size=batch_size, collate_fn=data_collator, shuffle=True 134 | ) 135 | val_dataloader = torch.utils.data.DataLoader( 136 | val_data, batch_size=batch_size, collate_fn=data_collator, shuffle=True 137 | ) 138 | test_dataloader = torch.utils.data.DataLoader( 139 | test_data, batch_size=batch_size, collate_fn=data_collator, shuffle=True 140 | ) 141 | 142 | return train_dataloader, val_dataloader, test_dataloader 143 | 144 | 145 | def get_truthfulQA_answers_plaintext(tqa_file_path="data/TruthfulQA.csv"): 146 | """ 147 | Get the plain text of TruthfulQA's answers used for random mismatch. 148 | 149 | Args: 150 | None 151 | 152 | Returns: 153 | A list of answer text in TruthfulQA. 154 | """ 155 | ans_names = ["Best Answer", "Correct Answers", "Incorrect Answers"] 156 | 157 | df = pd.read_csv(tqa_file_path) 158 | all_ans = [] 159 | for ans_name in ans_names: 160 | answers = df[ans_name].values 161 | if ans_name == "Best Answer": 162 | all_ans.extend(answers) 163 | # Split "Correct Answers" and "Incorrect Answers"by ";". 164 | else: 165 | for answer in answers: 166 | ans_list = answer.split(";") 167 | for ans in ans_list: 168 | all_ans.append(ans.strip()) 169 | 170 | return all_ans 171 | 172 | 173 | def compute_kl(pretrained_model, current_model, batch, device): 174 | """ 175 | Compute *forward* KL as the normal utility loss. 176 | 177 | Args: 178 | pretrained_model: reference model which is the pretrained (original) model. 179 | current_model: The current unlearning model. 180 | batch: A batch of normal data. 181 | device: GPU device. 182 | 183 | Returns: 184 | The KL loss. 185 | """ 186 | normal_outputs = current_model( 187 | batch["input_ids"].to(device), 188 | attention_mask=batch["attention_mask"].to(device), 189 | labels=batch["labels"].to(device), 190 | ) 191 | 192 | with torch.no_grad(): 193 | pretrained_outputs = pretrained_model( 194 | batch["input_ids"].to(device), 195 | attention_mask=batch["attention_mask"].to(device), 196 | labels=batch["labels"].to(device), 197 | ) 198 | 199 | # P: pretrained model; Q: current model. 200 | prob_p = torch.nn.functional.softmax(pretrained_outputs.logits, -1) 201 | prob_q = torch.nn.functional.softmax(normal_outputs.logits, -1) 202 | 203 | loss = -(prob_p * torch.log(prob_q + 1e-12)).sum(-1).mean() 204 | 205 | return loss 206 | 207 | 208 | def get_answer_loss(operation, batch, model, device="cuda:0"): 209 | """ 210 | Compute the loss on the answer (i.e. y) part. 211 | 212 | Args: 213 | operation: either "ga" (gradient ascent) or "gd" (gradient descent). 214 | batch: A batch of data. 215 | model: The unlearned model. 216 | device: GPU device. 217 | 218 | Returns: 219 | The loss. 220 | """ 221 | assert operation in ["ga", "gd"], "Operation must be either GA or GD." 222 | input_ids, attention_mask, start_locs, labels = ( 223 | batch["input_ids"].to(device), 224 | batch["attention_mask"].to(device), 225 | batch["start_locs"], 226 | batch["labels"].to(device), 227 | ) 228 | outputs = model(input_ids, attention_mask=attention_mask) 229 | loss_fct = torch.nn.CrossEntropyLoss(reduction="none") 230 | # Shift one to predict next token. 231 | shift_logits = outputs.logits[:, :-1, :] 232 | shift_labels = labels[:, 1:] 233 | losses = [] 234 | for bid in range(input_ids.shape[0]): 235 | one_inp, one_st = input_ids[bid], start_locs[bid] 236 | 237 | # GA or GD. 238 | position_loss = loss_fct(shift_logits[bid], shift_labels[bid]) 239 | if operation == "ga": # Negative the direction for GA. 240 | position_loss = -position_loss 241 | 242 | # Simply put equal weights on all answers. 243 | position_weight = torch.zeros_like(one_inp) 244 | assert len(position_weight) == len(position_loss) + 1 245 | position_weight[one_st:] = 1 # only focus on answer part 246 | 247 | # Ignore the padding part. 248 | position_weight[one_inp == 1] = 0 249 | if position_weight.sum() > 0: 250 | position_weight = position_weight / position_weight.sum() 251 | 252 | one_loss = (position_weight[:-1] * position_loss).sum() 253 | losses.append(one_loss) 254 | final_loss = torch.stack(losses).mean() 255 | 256 | return final_loss 257 | 258 | 259 | def get_rand_ans_loss(bad_batch, tokenizer, normal_ans, model, K=5, device="cuda:0"): 260 | """ 261 | Compute the loss of the random mismatch. 262 | 263 | Args: 264 | bad_batch: A batch of forgetting data. 265 | tokenizer: The tokenizer. 266 | normal_ans: A list of random answers. 267 | model: unlearned model. 268 | K: How many random answers sampled for each forgetting sample. 269 | device: GPU device. 270 | 271 | Returns: 272 | The random mismatch loss. 273 | """ 274 | bad_input_ids = bad_batch["input_ids"].to(device) 275 | rand_ans_list = random.sample(normal_ans, k=K) 276 | batch_random_features = [] 277 | for batch_idx in range(bad_input_ids.shape[0]): 278 | single_input_id = bad_input_ids[batch_idx, :] 279 | ori_text = tokenizer.decode(single_input_id) 280 | # Get question. 281 | question = ori_text.split("###")[1].split("Question:")[-1].strip() 282 | question_prefix = f"### Question: {question}\n ### Answer: " 283 | tokenized_question_prefix = tokenizer( 284 | question_prefix, truncation=True, padding="max_length" 285 | ) 286 | # Doesn't need to minus 1 because there's a starting token in the beginning. 287 | start_loc = len(tokenized_question_prefix) 288 | 289 | # Get random answer. 290 | for rand_ans in rand_ans_list: 291 | random_sample = f"{question_prefix}{rand_ans}" 292 | 293 | # Tokenize. 294 | tokenized_rs = tokenizer( 295 | random_sample, truncation=True, padding="max_length" 296 | ) 297 | batch_random_features.append( 298 | { 299 | "input_ids": tokenized_rs["input_ids"], 300 | "attention_mask": tokenized_rs["attention_mask"], 301 | "start_locs": start_loc, 302 | } 303 | ) 304 | 305 | # Batchify. 306 | data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) 307 | batch_random = data_collator(batch_random_features) 308 | 309 | # GD on answer. 310 | random_loss = get_answer_loss("gd", batch_random, model, device=device) 311 | 312 | return random_loss 313 | --------------------------------------------------------------------------------