├── README.md
├── evaluation.py
├── get_dataset.py
├── inference.py
├── main.py
├── method
│   ├── AGEM.py
│   ├── BaseTrainerCL.py
│   ├── ER.py
│   ├── EWC.py
│   ├── GEM.py
│   ├── ILORA.py
│   ├── L2P.py
│   ├── MTL.py
│   ├── ONE.py
│   └── PP.py
├── requirements.txt
├── scripts
│   └── train_seq.sh
└── utils
    ├── arg_configs.py
    └── metrics.py

/README.md:
--------------------------------------------------------------------------------

# LLMCL

### Analyzing and Reducing Catastrophic Forgetting in Parameter Efficient Tuning

## Overview
LLMCL is a repository built on the Hugging Face Transformers library for assessing the continual learning ability of large language models. With it, users can easily customize datasets, specify models, and experiment with a set of classic continual learning methods.

## Key Features
- **Continual Learning Methods:** The repository ships several classic continual learning methods (e.g. EWC, ER, GEM/A-GEM, L2P, PP, and I-LoRA) for users to reference and reuse.
- **Model Customization:** Specify the model you want to use, and the repository will download it automatically.

## Quick Start
### 1. Install dependencies
```bash
conda create -n llmcl python=3.10
conda activate llmcl
pip install -r requirements.txt
```
### 2. Start Training
```bash
./scripts/train_seq.sh
```
### 3. Inference
```bash
./scripts/infer_seq.sh
```
### 4. Customize
You can easily customize the scripts for your own use:

- Ensure your dataset is organized in JSON format with `prompt` and `answer` as keys.
- Save each split to `<data_path>/<dataset_name>/<split>.json`.
- For more details, refer to the [get_dataset.py](get_dataset.py) file.

## Reproduce
To reproduce our results:

**1.** Request access to the `llama2` model, then download the [TRACE Benchmark](https://drive.google.com/file/d/1S0SmU0WEw5okW_XvP2Ns0URflNzZq6sV/view?usp=drive_link), [MedMCQA](https://medmcqa.github.io/), and [JEC-QA](https://jecqa.thunlp.org/) into the `./data_files` folder.

**2.** Run the scripts: customize your training script (see `scripts/train_seq.sh`) and run it.

## Citation
If you find this repository helpful, please consider citing our work.

```bibtex
@misc{ren2024analyzing,
      title={Analyzing and Reducing Catastrophic Forgetting in Parameter Efficient Tuning},
      author={Weijieying Ren and Xinlong Li and Lei Wang and Tianxiang Zhao and Wei Qin},
      year={2024},
      eprint={2402.18865},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
```

--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import json
 3 | from prettytable import PrettyTable
 4 | from utils.metrics import (
 5 |     eval_FOMC,
 6 |     eval_SciQA,
 7 |     eval_CStance,
 8 |     eval_NumGLUE,
 9 |     eval_PapyrusF,
10 |     eval_20Minuten,
11 |     eval_MeetingBank,
12 |     eval_medmcqa,
13 |     eval_jecqa,
14 | )
15 | from transformers import HfArgumentParser
16 | import logging
17 | from utils.arg_configs import EvalArguments
18 | import warnings
19 | warnings.filterwarnings("ignore", category=UserWarning)
20 | 
21 | logging.basicConfig(
22 |     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
23 |     datefmt="%m/%d/%Y %H:%M:%S",
24 |     level=logging.INFO,
25 | )
26 | logger = logging.getLogger(__name__)
27 | 
28 | EVAL_FUNCs = {
29 |     "FOMC": eval_FOMC,
30 |     "ScienceQA": eval_SciQA,
31 |     "C-STANCE": eval_CStance,
32 |     "NumGLUE-cm": eval_NumGLUE,
33 |     "NumGLUE-ds": eval_PapyrusF,
34 |     "20Minuten": eval_20Minuten,
35 |     "MeetingBank": eval_MeetingBank,
36 |     "medmcqa": eval_medmcqa,
37 |     "jecqa": eval_jecqa,
38 | }
39 | 
40 | def main():
41 |     parser = HfArgumentParser(EvalArguments)
42 |     args = parser.parse_args_into_dataclasses()[0]
43 | 
44 |     # check that every json_dir exists
45 |     for j_dir in args.json_dirs:
46 |         if not os.path.exists(j_dir):
47 |             raise ValueError(f"Path {j_dir} does
not exist") 48 | if args.cl_method == "mtl" and len(args.json_dirs) != 1: 49 | raise ValueError(f"Multi-task learning should only have one json_dir") 50 | 51 | logger.info(f"Evaluting args: {args}") 52 | logger.info(f"Make sure your `increment_order` is in the same order as you train the datasets!!") 53 | results = {} 54 | for json_dir in args.json_dirs: 55 | data = json.load(open(json_dir, "r")) 56 | trained_task = os.path.split(json_dir)[1].split(".json")[0] 57 | 58 | if trained_task not in results: 59 | results[trained_task] = {} 60 | 61 | for infer_task, infer_result in data.items(): 62 | # try: 63 | eval_func = EVAL_FUNCs[infer_task] 64 | prompts = [] 65 | answers = [] 66 | generated_texts = [] 67 | 68 | for item in infer_result: 69 | prompts.append(item["prompt"]) 70 | answers.append(item["answer"]) 71 | generated_texts.append(item["generated_text"]) 72 | 73 | try: 74 | # special case for `20Minuten` 75 | if infer_task == "20Minuten": 76 | eval_result = eval_func(prompts, generated_texts, answers) 77 | else: 78 | eval_result = eval_func(generated_texts, answers) 79 | logger.info(f"Inference result {json_dir} on task` {infer_task}`: {eval_result}") 80 | except Exception as e: 81 | eval_result = None 82 | logger.error(f"Error processing file {json_dir} with dataset `{infer_task}`: {e}") 83 | continue 84 | results[trained_task][infer_task] = get_res(eval_result, infer_task) 85 | 86 | # Print table 87 | table = PrettyTable() 88 | table.field_names = [args.cl_method] + args.increment_order 89 | print(results) 90 | for row_name in args.increment_order: 91 | if row_name not in results: 92 | logger.warning(f"Missing result for `{row_name}`, skipping") 93 | continue 94 | row = [row_name] 95 | for col_name in args.increment_order: 96 | if col_name not in results[row_name]: 97 | row.append("-") 98 | logger.warning(f"Missing result for `{row_name}` on `{col_name}`") 99 | else: 100 | row.append(results[row_name][col_name]) 101 | table.add_row(row) 102 | 103 | print(table) 104 | save_tabel(table, args.save_path) 105 | 106 | 107 | def save_tabel(table, path): 108 | "save table to .csv file" 109 | base_path = os.path.split(path)[0] 110 | if not os.path.exists(base_path): 111 | os.makedirs(base_path, exist_ok=True) 112 | from pandas import DataFrame 113 | df = DataFrame([table.field_names] + table._rows) 114 | df.to_csv(path, index=False, header=False) 115 | 116 | logger.info(f"Results saved to {path}") 117 | 118 | def get_res(result: dict, name: str): 119 | if result is None: 120 | return -1 121 | if name == "20Minuten": 122 | return round(result['sari'][0]['sari'] / 100, 3) 123 | elif name in ["C-STANCE", "FOMC", "NumGLUE-cm", "NUmGLUE-ds", "ScienceQA", "medmcqa", "jecqa"]: 124 | return round(result['accuracy'], 3) 125 | elif name == "Py150": 126 | return round(result['similarity'], 3) 127 | elif name == "MeetingBank": 128 | return round(result['rouge-L'], 3) 129 | 130 | 131 | if __name__ == "__main__": 132 | main() -------------------------------------------------------------------------------- /get_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import json 4 | import torch.utils 5 | from torch.utils.data import Dataset 6 | from dataclasses import dataclass 7 | from pathlib import Path 8 | from typing import List, Dict, Union, Tuple 9 | from transformers import PreTrainedTokenizerBase, AutoTokenizer 10 | 11 | class BaseDataset(Dataset): 12 | def __init__(self, tokenizer:PreTrainedTokenizerBase, json_dir:Union[str, 
Path], max_length:int=1024, train_on_inputs:bool=True, test:bool=False): 13 | super(BaseDataset).__init__() 14 | self.tokenizer = tokenizer 15 | assert self.tokenizer.pad_token_id is not None, "Tokenizer must have a pad token id" 16 | self.json_dir = json_dir 17 | self.meta_data = json.load(open(json_dir)) 18 | self.max_length = max_length 19 | self.train_on_inputs = train_on_inputs 20 | self.keys_to_data = list(self.meta_data[0].keys()) 21 | self.test = test 22 | self.data = self._tokenize_dataset() 23 | 24 | def __len__(self): 25 | return len(self.data) 26 | 27 | def __getitem__(self, idx): 28 | return self.data[idx] 29 | 30 | def _tokenize_dataset(self, ingnore_idx:int=-100) -> List[Dict[str, torch.Tensor]]: 31 | tokenized_samples = [] 32 | for sample in self.meta_data: 33 | try: 34 | q_tokenized = self.tokenizer(sample[self.keys_to_data[0]], add_special_tokens=False) 35 | a_tokenized = self.tokenizer(sample[self.keys_to_data[1]], add_special_tokens=False) 36 | 37 | if not self.test: 38 | input_ids = q_tokenized['input_ids'] + a_tokenized['input_ids'] 39 | else: 40 | input_ids = q_tokenized['input_ids'] 41 | 42 | if len(input_ids) > self.max_length - 2: 43 | input_ids = input_ids[:self.max_length - 2] 44 | 45 | full_input_ids = [self.tokenizer.bos_token_id] + input_ids 46 | if not self.test: 47 | full_input_ids += [self.tokenizer.eos_token_id] 48 | input_ids = torch.tensor(full_input_ids) 49 | attention_mask = torch.ones_like(input_ids) 50 | 51 | if (not self.train_on_inputs) and (not self.test): 52 | labels = torch.full_like(input_ids, fill_value=ingnore_idx) 53 | ans_start_idx = len(q_tokenized['input_ids']) + 1 54 | labels[ans_start_idx:] = input_ids[ans_start_idx:] 55 | else: 56 | labels = input_ids.clone() 57 | 58 | tokenized_samples.append(dict( 59 | input_ids=input_ids, 60 | attention_mask=attention_mask, 61 | labels=labels 62 | )) 63 | except Exception as e: 64 | print(f"Error processing sample: {e}") 65 | continue 66 | 67 | return tokenized_samples 68 | 69 | class DataCollector(object): 70 | """ For a stable traning, we need to pad the input_ids to the `max_length` """ 71 | def __init__(self, tokenizer: PreTrainedTokenizerBase, padding: Union[str, bool], max_length: int=1024, ignore_idx: int=-100): 72 | self.tokenizer = tokenizer 73 | self.padding = padding 74 | assert self.padding in ['longest', True], "Padding must be either 'longest', 'max_length' or False" 75 | self.max_length = max_length 76 | self.ignore_idx = ignore_idx 77 | 78 | def __call__(self, batch: List[Dict[str, torch.Tensor]]): 79 | input_ids = [sample['input_ids'] for sample in batch] 80 | attention_mask = [sample['attention_mask'] for sample in batch] 81 | labels = [sample['labels'] for sample in batch] 82 | 83 | len_pad_to = max([len(ids) for ids in input_ids]) if self.padding == 'longest' else self.max_length 84 | for i in range(len(batch)): 85 | input_ids[i] = torch.cat([ 86 | torch.full((len_pad_to - input_ids[i].shape[0],), fill_value=self.tokenizer.pad_token_id), # left padding 87 | input_ids[i] 88 | ]) 89 | attention_mask[i] = torch.cat([ 90 | torch.zeros((len_pad_to - attention_mask[i].shape[0],)), 91 | attention_mask[i] 92 | ]) 93 | labels[i] = torch.cat([ 94 | torch.full((len_pad_to- labels[i].shape[0],), fill_value=self.ignore_idx), 95 | labels[i] 96 | ]) 97 | 98 | return dict( 99 | input_ids=torch.stack(input_ids), 100 | attention_mask=torch.stack(attention_mask), 101 | labels=torch.stack(labels) 102 | ) 103 | 104 | def get_datasets(dataset_names: List[str], data_path: Union[str, Path], tokenizer: 
PreTrainedTokenizerBase, max_length: int=1024, split='train', train_on_inputs=False) -> Dict[str, Dataset]: 105 | datasets = {} 106 | for name in dataset_names: 107 | if os.path.exists(name) and os.path.isfile(name) and name.endswith(".json"): 108 | full_json_path = name 109 | else: 110 | full_json_path = os.path.join(data_path, name, f"{split}.json") 111 | assert os.path.exists(full_json_path), f"Path {full_json_path} does not exist" 112 | 113 | dataset = BaseDataset(tokenizer, full_json_path, max_length, train_on_inputs, test='test' in full_json_path) 114 | datasets[name] = dataset 115 | return datasets 116 | 117 | def get_joint_datasets(datasets: Dict[str, Dataset]) -> Dict[str, Dataset]: 118 | return {'joint': torch.utils.data.ConcatDataset(list(datasets.values()))} 119 | 120 | if __name__ == "__main__": 121 | tokenizer = AutoTokenizer.from_pretrained("llama2-7b-hf") 122 | tokenizer.pad_token_id = tokenizer.unk_token_id 123 | dataset = get_datasets(["FOMC"], "./TRACE-Benchmark/LLM-CL-Benchmark_500", tokenizer) 124 | from torch.utils.data import DataLoader 125 | data_loader = DataLoader(dataset["FOMC"], batch_size=2, collate_fn=DataCollector(tokenizer, padding=True)) 126 | for batch in data_loader: 127 | print(batch) 128 | break 129 | print(tokenizer("Hello World!, I am a sentence.")) 130 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import logging, json 5 | import numpy as np 6 | from typing import List, Dict, Union 7 | from dataclasses import asdict 8 | from transformers import ( 9 | AutoTokenizer, 10 | AutoModelForCausalLM, 11 | BitsAndBytesConfig, 12 | PreTrainedTokenizerBase, 13 | HfArgumentParser, 14 | GenerationConfig, 15 | LlamaTokenizer 16 | ) 17 | from peft import load_peft_weights, set_peft_model_state_dict, get_peft_model, PeftConfig, LoraConfig 18 | from get_dataset import get_datasets, DataCollector 19 | from utils.arg_configs import DataArguments, InferArguments 20 | import torch.distributed as dist 21 | from torch.utils.data import DataLoader 22 | from tqdm import tqdm 23 | import warnings 24 | warnings.filterwarnings("ignore", category=UserWarning) 25 | 26 | 27 | 28 | def prepare_model_for_inference(model_name_or_path:str, bnb_config:BitsAndBytesConfig, peft_cfg_path:str=None, peft_weights_path:str=None, device:str='cuda'): 29 | model = AutoModelForCausalLM.from_pretrained( 30 | model_name_or_path, 31 | quantization_config=bnb_config, 32 | torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16, 33 | device_map="balanced_low_0", 34 | ) 35 | 36 | if peft_cfg_path is not None and peft_weights_path is not None: 37 | peft_config = PeftConfig.from_pretrained(peft_cfg_path) 38 | model = get_peft_model(model, peft_config=peft_config) 39 | peft_state_dict = torch.load(peft_weights_path, map_location=device) 40 | set_peft_model_state_dict(model, peft_state_dict) 41 | return model 42 | 43 | def prepare_dataloader(data_args:DataArguments, tokenizer:PreTrainedTokenizerBase, batch_size:int=4, max_length:int=1024)->Dict[str, DataLoader]: 44 | test_datasets = get_datasets(**asdict(data_args), tokenizer=tokenizer, split="test") 45 | dataloaders = {} 46 | for name, dataset in test_datasets.items(): 47 | test_dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=DataCollector(tokenizer, padding="longest", max_length=max_length), num_workers=4, prefetch_factor=2) 48 | 
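The `DataCollector` used above left-pads `input_ids`, `attention_mask`, and `labels` so every sequence in a batch shares one length (left padding keeps the generated continuation flush with the right edge). A minimal, self-contained sketch of that padding scheme on toy tensors — the pad id, ignore index, and sequences here are made up:

```python
import torch

# Toy illustration of 'longest' left padding as done in get_dataset.DataCollector.
pad_id, ignore_idx = 0, -100
seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
pad_to = max(len(s) for s in seqs)

input_ids = torch.stack([
    torch.cat([torch.full((pad_to - len(s),), pad_id), s]) for s in seqs
])
attention_mask = torch.stack([
    torch.cat([torch.zeros(pad_to - len(s), dtype=torch.long), torch.ones(len(s), dtype=torch.long)])
    for s in seqs
])
labels = torch.stack([
    torch.cat([torch.full((pad_to - len(s),), ignore_idx), s]) for s in seqs
])
print(input_ids)       # tensor([[5, 6, 7], [0, 8, 9]])
print(attention_mask)  # tensor([[1, 1, 1], [0, 1, 1]])
print(labels)          # tensor([[   5,    6,    7], [-100,    8,    9]])
```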
dataloaders[name] = test_dataloader 49 | return dataloaders 50 | 51 | @torch.no_grad() 52 | def run_generation(model:torch.nn.Module, tokenizer:PreTrainedTokenizerBase, name:str, test_dataloader:DataLoader, generation_config:GenerationConfig) -> List[str]: 53 | model.eval() 54 | generated_texts = [] 55 | for inputs in tqdm(test_dataloader, desc=f"Generating texts for {name}"): 56 | if 'labels' in inputs: 57 | inputs.pop('labels') 58 | input_ids_shape = inputs['input_ids'].shape 59 | generated_token_batch = model.generate(**inputs, generation_config=generation_config) 60 | # generated_token_batch = generated_token_batch.cpu().numpy().tolist() 61 | 62 | # mask input_ids to get only the generated text 63 | mask = torch.cat( 64 | (torch.zeros(input_ids_shape), torch.ones(input_ids_shape[0], generated_token_batch.shape[1] - input_ids_shape[1])), 65 | dim=-1 66 | ).to(torch.int64).to(generated_token_batch.device) 67 | generated_token_batch = (generated_token_batch * mask).cpu().numpy().tolist() 68 | generated_texts.extend(tokenizer.batch_decode(generated_token_batch, skip_special_tokens=True)) 69 | return generated_texts 70 | 71 | def get_meta_data(data_args:DataArguments, split="test")->Dict[str, List[Dict[str, str]]]: 72 | meta_datas = {} 73 | for name in data_args.dataset_names: 74 | full_path = os.path.join(data_args.data_path, name, f"{split}.json") 75 | assert os.path.exists(full_path), f"File {full_path} does not exist" 76 | 77 | with open(full_path, 'r') as f: 78 | data = json.load(f) 79 | meta_datas[name] = data 80 | return meta_datas 81 | 82 | 83 | def main(): 84 | parser = HfArgumentParser((InferArguments, DataArguments)) 85 | infer_args, data_args = parser.parse_args_into_dataclasses() 86 | 87 | # prepare model, tokenizer and dataloaders 88 | model = prepare_model_for_inference( 89 | model_name_or_path=infer_args.model_name_or_path, 90 | bnb_config=infer_args.bnb_config, 91 | peft_cfg_path=infer_args.peft_cfg_path, 92 | peft_weights_path=infer_args.peft_weights_path, 93 | ) 94 | 95 | # tokenizer_config = 96 | tokenizer = AutoTokenizer.from_pretrained(infer_args.tokenizer_name_or_path) 97 | dataloaders = prepare_dataloader(data_args, tokenizer, batch_size=infer_args.infer_batch_size) 98 | logger.info(f"Model and data loaders prepared for {data_args.dataset_names}, starting generation") 99 | 100 | start = time.time() 101 | generated_texts_datasets = {} 102 | for name, dataloader in dataloaders.items(): 103 | generated_texts = run_generation( 104 | model=model, 105 | tokenizer=tokenizer, 106 | name=name, 107 | test_dataloader=dataloader, 108 | generation_config=infer_args.generation_config 109 | ) 110 | generated_texts_datasets[name] = generated_texts 111 | end = time.time() 112 | 113 | # run generation 114 | logger.info(f"Generation completed, using {((end-start)/60):.2f} seconds") 115 | meta_datas = get_meta_data(data_args, split="test") 116 | results = {} 117 | for i, (name, gen_texts) in enumerate(generated_texts_datasets.items()): 118 | results[name] = [] 119 | assert len(gen_texts) == len(meta_datas[name]), f"Number of generated texts ({len(gen_texts)}) does not match the number of meta datas ({len(meta_datas[name])})" 120 | gen_texts: List[str] 121 | meta_datas: Dict[str, List[Dict[str, str]]] 122 | for j, text in enumerate(gen_texts): 123 | results[name].append(dict( 124 | prompt=meta_datas[name][j]['prompt'], 125 | answer=meta_datas[name][j]['answer'], 126 | generated_text=text, 127 | )) 128 | 129 | # save results 130 | base_path = os.path.split(infer_args.save_path)[0] 131 | 
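For reference, the dictionary dumped below has one key per dataset, each holding a list of records that `evaluation.py` later reads back. A sketch of the expected layout, with a sample dataset name and shortened placeholder strings:

```python
# Illustrative layout of the saved results file; the strings are placeholders.
example_results = {
    "FOMC": [
        {
            "prompt": "...task instruction and input...",
            "answer": "...reference answer...",
            "generated_text": "...model output...",
        },
        # one entry per test example, in dataset order
    ],
}
```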
if not os.path.exists(base_path): 132 | os.makedirs(base_path, exist_ok=True) 133 | with open(f"{infer_args.save_path}", 'w', encoding="utf-8") as f: 134 | json.dump(results, f, indent=4, ensure_ascii=False) 135 | logger.info(f"Results saved to {infer_args.save_path}") 136 | 137 | if __name__ == "__main__": 138 | logging.basicConfig( 139 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 140 | datefmt="%m/%d/%Y %H:%M:%S", 141 | level=logging.INFO 142 | ) 143 | logger = logging.getLogger(__name__) 144 | 145 | main() -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorWithPadding, BitsAndBytesConfig 3 | from peft import prepare_model_for_kbit_training 4 | from utils.arg_configs import get_args, CLArguments, TuningArguments, DataArguments 5 | from dataclasses import asdict 6 | from get_dataset import get_datasets, DataCollector 7 | from utils.functions import set_all_seed 8 | import warnings 9 | warnings.filterwarnings("ignore", category=UserWarning) 10 | 11 | def main(): 12 | train_args, cl_args, tuning_args, data_args = get_args() 13 | set_all_seed(tuning_args.manual_seed) 14 | bnb_config = BitsAndBytesConfig( 15 | load_in_8bit=tuning_args.load_8bit, 16 | ) 17 | model = AutoModelForCausalLM.from_pretrained( 18 | tuning_args.model_name_or_path, 19 | quantization_config=bnb_config, 20 | torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16, 21 | ) 22 | 23 | tokenizer = AutoTokenizer.from_pretrained(tuning_args.model_name_or_path) 24 | 25 | if tokenizer.pad_token_id is None: 26 | tokenizer.pad_token_id = 0 27 | 28 | # prepare datasets 29 | train_datasets = get_datasets(**asdict(data_args), tokenizer=tokenizer, split='train') 30 | valid_datasets = get_datasets(**asdict(data_args), tokenizer=tokenizer, split='eval') 31 | 32 | 33 | train_args = dict( 34 | model=model if not tuning_args.load_8bit else prepare_model_for_kbit_training(model), 35 | args=train_args, 36 | train_dataset=train_datasets, 37 | eval_dataset=valid_datasets, 38 | tokenizer=tokenizer, 39 | cl_args=cl_args, 40 | tuning_args=tuning_args, 41 | data_args=data_args, 42 | data_collator=DataCollector(tokenizer, padding=True, max_length=data_args.max_length) 43 | ) 44 | 45 | if cl_args.cl_method.lower() == 'seq': 46 | from method.BaseTrainerCL import BaseTrainerCL 47 | cl_trainer = BaseTrainerCL(**train_args) 48 | elif cl_args.cl_method.lower() == 'ewc': 49 | from method.EWC import EWCTrainer 50 | cl_trainer = EWCTrainer(**train_args) 51 | elif cl_args.cl_method.lower() == 'er': 52 | from method.ER import ERTrainer 53 | cl_trainer = ERTrainer(**train_args) 54 | elif cl_args.cl_method.lower() == 'gem': 55 | from method.GEM import GEMTrainer 56 | cl_trainer = GEMTrainer(**train_args) 57 | elif cl_args.cl_method.lower() == 'agem': 58 | from method.AGEM import AveGEMTrainer 59 | cl_trainer = AveGEMTrainer(**train_args) 60 | elif cl_args.cl_method.lower() == 'l2p': 61 | from method.L2P import L2PTrainer 62 | cl_trainer = L2PTrainer(**train_args) 63 | elif cl_args.cl_method.lower() == 'pp': 64 | from method.PP import PPTrainer 65 | cl_trainer = PPTrainer(**train_args) 66 | elif cl_args.cl_method.lower() == 'ilora': 67 | from method.ILORA import ILoRATrainer 68 | cl_trainer = ILoRATrainer(**train_args) 69 | elif cl_args.cl_method.lower() == 'mtl': 70 | from method.MTL import MTLTrainer 71 | 
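As an aside, the `if`/`elif` chain in this function simply maps each `cl_method` string to a trainer class. An equivalent registry-style dispatch — a hypothetical refactor sketch, not the code used here — could look like:

```python
# Hypothetical registry equivalent of the dispatch chain in main(); not part of the repo.
import importlib

TRAINERS = {
    "seq": ("method.BaseTrainerCL", "BaseTrainerCL"),
    "ewc": ("method.EWC", "EWCTrainer"),
    "er": ("method.ER", "ERTrainer"),
    "gem": ("method.GEM", "GEMTrainer"),
    "agem": ("method.AGEM", "AveGEMTrainer"),
    "l2p": ("method.L2P", "L2PTrainer"),
    "pp": ("method.PP", "PPTrainer"),
    "ilora": ("method.ILORA", "ILoRATrainer"),
    "mtl": ("method.MTL", "MTLTrainer"),
    "one": ("method.ONE", "ONETrainer"),
}

def build_trainer(cl_method: str, **trainer_kwargs):
    try:
        module_name, class_name = TRAINERS[cl_method.lower()]
    except KeyError:
        raise ValueError(f"continual learning method {cl_method} not implemented yet")
    trainer_cls = getattr(importlib.import_module(module_name), class_name)
    return trainer_cls(**trainer_kwargs)
```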
cl_trainer = MTLTrainer(**train_args) 72 | elif cl_args.cl_method.lower() == 'one': 73 | from method.ONE import ONETrainer 74 | cl_trainer = ONETrainer(**train_args) 75 | else: 76 | ValueError(f"continual learning method: {cl_args.cl_method} not implement yet") 77 | 78 | cl_trainer.continual_learning() 79 | 80 | if __name__ == '__main__': 81 | 82 | main() 83 | 84 | -------------------------------------------------------------------------------- /method/AGEM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl 3 | from .BaseTrainerCL import BaseTrainerCL 4 | from peft import PeftModel 5 | 6 | 7 | class AveTaskGradientCallback(TrainerCallback): 8 | def __init__(self, **kwargs): 9 | super().__init__() 10 | self.model:PeftModel = kwargs.get('model') 11 | self.current_task_name = kwargs.get('current_task_name') # need update during training 12 | self.n_tasks = kwargs.get('n_tasks') 13 | self.grads = {} 14 | self.task_names = kwargs.get('task_names') 15 | self.init_grads() 16 | 17 | def init_grads(self): 18 | for n, p in self.model.named_parameters(): 19 | if p.requires_grad: 20 | self.grads[n] = torch.ones([p.data.numel()], dtype=p.dtype, device=p.device) 21 | 22 | def store_grads(self): 23 | for n, p in self.model.named_parameters(): 24 | if n in self.grads: 25 | self.ave_grads(n, p.grad.detach().clone().view(-1)) 26 | 27 | def ave_grads(self, name, new_grads): 28 | self.grads[name] = (self.grads[name] * (self.task_names.index(self.current_task_name) + 1) + new_grads) / (self.task_names.index(self.current_task_name) + 2) 29 | 30 | def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): 31 | for n, p in self.model.named_parameters(): 32 | if n in self.grads and p.requires_grad: 33 | p.grad = self.get_updated_grads(n, p.grad) 34 | 35 | def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): 36 | self.store_grads() 37 | 38 | def get_updated_grads(self, name, grad, eps=1e-4): 39 | ori_shape = grad.shape 40 | grad = grad.view(-1) 41 | pre_grad = self.grads[name].cuda().to(torch.float32) 42 | grad, pre_grad = grad.unsqueeze(1), pre_grad.unsqueeze(1) 43 | dot_product = torch.mm(grad.t(), pre_grad) 44 | 45 | if (dot_product < 0) != 0: 46 | new_grad = grad - (torch.mm(grad.t(), pre_grad) + eps) / (torch.mm(pre_grad.t(), pre_grad) + eps) * pre_grad 47 | grad.copy_(new_grad) 48 | 49 | return grad.view(ori_shape) 50 | 51 | def update_current_task_name(self, name:str): 52 | self.current_task_name = name 53 | 54 | 55 | 56 | class AveGEMTrainer(BaseTrainerCL): 57 | def __init__(self, **kwargs): 58 | super().__init__(**kwargs) 59 | 60 | self.add_callback(AveTaskGradientCallback( 61 | model = self.model, 62 | current_task_name=self.current_task_name, 63 | n_tasks=self.num_tasks, 64 | task_names=self.task_names)) 65 | 66 | self.gem_cb = None 67 | for cb in self.callback_handler.callbacks: 68 | if isinstance(cb, AveTaskGradientCallback): 69 | self.gem_cb = cb 70 | break 71 | 72 | def continual_learning(self): 73 | for i, name in enumerate(self.task_names): 74 | self.gem_cb.update_current_task_name(name) 75 | self.before_task_start(name) 76 | self.train() 77 | self.after_task_end(name) -------------------------------------------------------------------------------- /method/BaseTrainerCL.py: -------------------------------------------------------------------------------- 1 | 
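Before moving on, a quick numeric check of the projection implemented in `AveTaskGradientCallback.get_updated_grads` (method/AGEM.py, above): when the current gradient conflicts with the stored reference gradient (negative dot product), the conflicting component is subtracted. The small `eps` terms are dropped here for clarity:

```python
import torch

# Toy A-GEM-style projection: g is the current gradient, g_ref the stored average gradient.
g = torch.tensor([1.0, -2.0])
g_ref = torch.tensor([1.0, 1.0])

if torch.dot(g, g_ref) < 0:  # gradients conflict
    g = g - (torch.dot(g, g_ref) / torch.dot(g_ref, g_ref)) * g_ref

print(g)                    # tensor([ 1.5000, -1.5000])
print(torch.dot(g, g_ref))  # tensor(0.) -- the projected gradient no longer opposes g_ref
```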
import torch 2 | import torch.nn as nn 3 | from pathlib import Path 4 | from typing import Callable, Tuple, Union, Optional, Dict, List, overload 5 | from torch.utils.data import Dataset 6 | from transformers.trainer_callback import TrainerControl, TrainerState 7 | from transformers.training_args import TrainingArguments 8 | from utils.arg_configs import CLArguments, TuningArguments, DataArguments 9 | from get_dataset import get_joint_datasets 10 | from peft import get_peft_model, PeftModel, LoraModel, PeftModelForCausalLM, get_peft_model_state_dict 11 | from transformers import Trainer, PreTrainedModel, TrainerCallback 12 | 13 | 14 | class BaseTrainerCL(Trainer): 15 | def __init__(self, **kwargs): 16 | self.cl_args = kwargs.pop("cl_args", None) 17 | self.tuning_args = kwargs.pop("tuning_args", None) 18 | self.data_args = kwargs.pop("data_args", None) 19 | kwargs['model'] = self.prepare_model_for_cl_traning(kwargs['model'], self.cl_args, self.tuning_args) 20 | self.continual_training_dataset, self.continual_evaluating_dataset = \ 21 | self.prepare_dataset_for_cl_traininig( 22 | kwargs.get("train_dataset", None), 23 | kwargs.get("eval_dataset", None)) 24 | self.task_names = list(self.continual_training_dataset.keys()) 25 | self.num_tasks: int = len(self.continual_training_dataset) 26 | self.current_task_name: str = None 27 | super().__init__(**kwargs) 28 | 29 | def prepare_model_for_cl_traning(self, model: Union[PreTrainedModel, nn.Module], cl_args: CLArguments=None, tuning_args: TuningArguments=None) -> Union[PreTrainedModel, PeftModel]: 30 | peft_model = get_peft_model( 31 | model=model, 32 | peft_config=tuning_args.lora_config, 33 | ) 34 | return peft_model 35 | 36 | def prepare_dataset_for_cl_traininig(self, train_dataset: Dict[str, Dataset], eval_dataset: Dict[str, Dataset]) -> Dict[str, Dataset]: 37 | 38 | if self.cl_args.cl_method == 'mtl': 39 | train_dataset = get_joint_datasets(train_dataset) 40 | eval_dataset = get_joint_datasets(eval_dataset) 41 | return train_dataset, eval_dataset 42 | return train_dataset, eval_dataset 43 | 44 | def continual_learning(self): 45 | for i, name in enumerate(self.task_names): 46 | self.before_task_start(name) 47 | self.train() 48 | self.after_task_end(name) 49 | 50 | def before_task_start(self, task_name: str): 51 | """ update training and evaluation dataset for the current task """ 52 | if self.cl_args.cl_method == 'mtl': 53 | self.train_dataset = self.continual_evaluating_dataset 54 | self.eval_dataset = self.continual_evaluating_dataset 55 | self.current_task_name = "joint" 56 | return 57 | 58 | if task_name not in self.continual_training_dataset: 59 | raise ValueError(f"task name {task_name} not found in the dataset") 60 | self.current_task_name = task_name 61 | self.train_dataset, self.eval_dataset = self.continual_training_dataset[task_name], self.continual_evaluating_dataset[task_name] 62 | 63 | # update model for the current task 64 | if self.cl_args.cl_method == 'one': 65 | if isinstance(self.model, LoraModel): 66 | self.model = get_peft_model( 67 | model=self.model.model, 68 | peft_config=self.tuning_args.lora_config, 69 | ) 70 | 71 | def after_task_end(self, *args, **kwargs): 72 | """ save the model after training the current task """ 73 | assert args[0] == self.current_task_name, f"task name mismatch: {args[0]} != {self.current_task_name}" 74 | wrappered_model_class = kwargs.get("wrappered_model_class", None) 75 | 76 | if isinstance(self.model, PeftModelForCausalLM): 77 | lora_state_dict = get_peft_model_state_dict(self.model) 78 | 
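`after_task_end` here writes one LoRA state dict per task (`lora_<task>.pt`, saved just below) together with its adapter config. A minimal sketch of loading such a checkpoint back for evaluation, mirroring `prepare_model_for_inference` in `inference.py` — the base model name, output directory, and task name are placeholders:

```python
import torch
from transformers import AutoModelForCausalLM
from peft import PeftConfig, get_peft_model, set_peft_model_state_dict

# Placeholders: adjust the base model and output_dir/task to your own run.
base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
peft_config = PeftConfig.from_pretrained("outputs/lora_FOMC")        # saved adapter config dir
model = get_peft_model(base, peft_config)
state_dict = torch.load("outputs/lora_FOMC.pt", map_location="cpu")  # saved adapter weights
set_peft_model_state_dict(model, state_dict)
model.eval()
```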
lora_config = self.model.peft_config 79 | if self.args.local_rank == 0: 80 | print(f"*** Saving lora adapter for task: {self.current_task_name} ***") 81 | torch.save(lora_state_dict, Path(self.args.output_dir).joinpath(f"lora_{self.current_task_name}.pt")) 82 | for adapter_name, adapter_config in lora_config.items(): 83 | if adapter_name == 'default': 84 | adapter_config.save_pretrained(Path(self.args.output_dir).joinpath(f"lora_{self.current_task_name}")) 85 | else: 86 | adapter_config.save_pretrained(Path(self.args.output_dir).joinpath(f"lora_{self.current_task_name}_{adapter_name}")) 87 | elif not wrappered_model_class and isinstance(self.model, wrappered_model_class): 88 | raise NotImplementedError("not implemented yet") # TODO: implement for PP model 89 | 90 | self.tokenizer.save_pretrained(Path(self.args.output_dir).joinpath(f"tokenizer_{self.current_task_name}")) -------------------------------------------------------------------------------- /method/ER.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Tuple 4 | from .BaseTrainerCL import BaseTrainerCL 5 | 6 | 7 | def reservoir(num_seen_examples: int, buffer_size: int) -> int: 8 | if num_seen_examples < buffer_size: 9 | return num_seen_examples 10 | 11 | rand = np.random.randint(0, num_seen_examples + 1) 12 | if rand < buffer_size: 13 | return rand 14 | else: 15 | return -1 16 | 17 | def concat_inputs(input_ids:torch.Tensor, attention_mask:torch.Tensor, labels:torch.Tensor, buffer_input_ids:torch.Tensor, buffer_attention_mask:torch.Tensor, buffer_labels:torch.Tensor) -> Tuple: 18 | device = input_ids.device 19 | input_ids = torch.cat((input_ids, buffer_input_ids.to(device)), dim=0) 20 | attention_mask = torch.cat((attention_mask, buffer_attention_mask.to(device)), dim=0) 21 | labels = torch.cat((labels, buffer_labels.to(device)), dim=0) 22 | return input_ids, attention_mask, labels 23 | 24 | class Buffer: 25 | def __init__(self, buffer_size:int, device:str, pad_id:int=2, ignore_index:int=-100): 26 | self.buffer_size = buffer_size 27 | self.device = device 28 | self.num_seen_examples = 0 29 | self.attributes = ['input_ids', 'attention_mask', 'labels', 'logits', 'task_labels', 'activations'] 30 | self.init_buffer() 31 | self.pad_id = pad_id 32 | self.ignore_index = ignore_index 33 | 34 | def init_buffer(self) -> None: 35 | for attr_str in self.attributes: 36 | setattr(self, attr_str, [None for _ in range(self.buffer_size)]) 37 | 38 | def add_data(self, input_ids, attention_mask=None, labels=None, logits=None, task_labels=None, activations=None): 39 | n = input_ids.shape[0] if hasattr(input_ids, 'shape') else len(input_ids) 40 | for i in range(n): 41 | index = reservoir(self.num_seen_examples, self.buffer_size) 42 | self.num_seen_examples += 1 43 | if index >= 0: 44 | self.input_ids[index] = input_ids[i].detach().clone().to(self.device) 45 | if attention_mask is not None: 46 | self.attention_mask[index] = attention_mask[i].detach().clone().to(self.device) 47 | if labels is not None: 48 | self.labels[index] = labels[i].detach().clone().to(self.device) 49 | if logits is not None: 50 | self.logits[index] = logits[i].detach().clone().to(self.device) 51 | if task_labels is not None: 52 | self.task_labels[index] = task_labels[i].detach().clone().to(self.device) 53 | if activations is not None: 54 | self.activations[index] = activations[i].detach().clone().to(self.device) 55 | 56 | def get_data(self, size: int, pad_to:int) -> Tuple: 57 | n = 
len(self.input_ids) 58 | if size > min(self.num_seen_examples, n): 59 | size = min(self.num_seen_examples, n) 60 | 61 | choice = np.random.choice(min(self.num_seen_examples, n), size=size, replace=False) 62 | if len(choice) == 0: 63 | return None, None 64 | # for left padding 65 | input_ids = [] 66 | attention_mask = [] 67 | labels = [] 68 | 69 | for i in choice: 70 | 71 | input_ids.append(torch.cat( 72 | (torch.full((pad_to - self.input_ids[i].shape[-1],), self.pad_id, dtype=torch.long).to(self.device), 73 | self.input_ids[i]), dim=-1) 74 | ) 75 | if self.attention_mask[i] is not None: 76 | attention_mask.append(torch.cat( 77 | (torch.full((pad_to - self.attention_mask[i].shape[-1],), 0, dtype=torch.long).to(self.device), 78 | self.attention_mask[i]), dim=-1) 79 | ) 80 | if self.labels[i] is not None: 81 | labels.append(torch.cat( 82 | (torch.full((pad_to - self.labels[i].shape[-1],), self.ignore_index, dtype=torch.long).to(self.device), 83 | self.labels[i]), dim=-1) 84 | ) 85 | 86 | input_ids = torch.stack(input_ids) 87 | attention_mask = torch.stack(attention_mask) 88 | labels = torch.stack(labels) 89 | return input_ids, attention_mask, labels 90 | 91 | def is_empty(self) -> bool: 92 | if self.num_seen_examples == 0: 93 | return True 94 | else: 95 | return False 96 | 97 | def get_all_data(self) -> Tuple: 98 | ret_tuple = (torch.stack([ee.cpu() 99 | for ee in self.input_ids]).to(self.device),) 100 | for attr_str in self.attributes[1:]: 101 | if hasattr(self, attr_str): 102 | attr = getattr(self, attr_str) 103 | ret_tuple += (attr,) 104 | return ret_tuple 105 | 106 | def empty(self) -> None: 107 | for attr_str in self.attributes: 108 | if hasattr(self, attr_str): 109 | delattr(self, attr_str) 110 | self.num_seen_examples = 0 111 | 112 | 113 | class ERTrainer(BaseTrainerCL): 114 | def __init__(self, **kwargs): 115 | super().__init__(**kwargs) 116 | self.buffer_size = self.cl_args.cl_config.get('buffer_size', None) 117 | self.buffer = Buffer(self.buffer_size, 'cpu', pad_id=self.tokenizer.pad_token_id, ignore_index=-100) 118 | 119 | def compute_loss(self, model, inputs, return_outputs=False): 120 | 121 | if self.current_task_name == self.task_names[0]: 122 | self.buffer.add_data(inputs["input_ids"], inputs["attention_mask"], inputs["labels"]) 123 | outputs = model(**inputs) 124 | else: 125 | buffer_inputs, buffer_attention_mask, buffer_labels = self.buffer.get_data(inputs["input_ids"].shape[0], inputs["input_ids"].shape[1]) 126 | if buffer_inputs is not None and buffer_attention_mask is not None and buffer_labels is not None: 127 | inputs["input_ids"], inputs["attention_mask"], inputs["labels"] = concat_inputs(inputs["input_ids"], inputs["attention_mask"], inputs["labels"], buffer_inputs, buffer_attention_mask, buffer_labels) 128 | outputs = model(**inputs) 129 | self.buffer.add_data(inputs["input_ids"], inputs["attention_mask"], inputs["labels"]) 130 | 131 | return (outputs.loss, outputs) if return_outputs else outputs.loss 132 | 133 | 134 | def continual_learning(self): 135 | for i, name in enumerate(self.task_names): 136 | self.before_task_start(name) 137 | self.train() 138 | self.after_task_end(name) 139 | 140 | 141 | -------------------------------------------------------------------------------- /method/EWC.py: -------------------------------------------------------------------------------- 1 | from transformers import TrainerCallback, TrainingArguments, PreTrainedModel 2 | from transformers.trainer_callback import TrainerState, TrainerControl 3 | from typing import Tuple, Dict 4 | 
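The `EWCTrainer` defined below regularizes training with an elastic weight consolidation penalty, `ewc_lambda / 2 * sum_i F_i * (theta_i - theta*_i)^2`. A tiny numeric sketch of that term, mirroring `compute_ewc_loss` with made-up tensors and a made-up `ewc_lambda`:

```python
import torch

# Toy EWC penalty: fisher holds squared-gradient estimates, prev holds the weights
# frozen after the previous task; all values are made up.
ewc_lambda = 100.0
fisher = {"w": torch.tensor([0.5, 0.1])}
prev = {"w": torch.tensor([1.0, -1.0])}
current = {"w": torch.tensor([1.2, -0.5])}

penalty = sum(
    (fisher[n] * (current[n] - prev[n]).pow(2)).sum() * ewc_lambda / 2
    for n in fisher
)
print(penalty)  # tensor(2.2500)
```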
import os 5 | from .BaseTrainerCL import BaseTrainerCL 6 | from peft import PeftModel 7 | 8 | class GradientCallback(TrainerCallback): 9 | def __init__(self, *args, **kwargs): 10 | super().__init__() 11 | self.fisher = {} 12 | self.model: PeftModel = kwargs.get("model") 13 | self.previous_weights = {} 14 | self.trainable_params = {} 15 | self.init_fisher_and_weights() 16 | assert len(self.fisher) > 0 and len(self.trainable_params) > 0, "fisher and previous_weights should not be empty" 17 | 18 | def init_fisher_and_weights(self): 19 | print("init fisher and weights") 20 | for n, p in self.model.named_parameters(): 21 | if p.requires_grad: 22 | self.fisher[n] = p.detach().clone().data.zero_() 23 | self.trainable_params[n] = p.detach().clone().data 24 | 25 | def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): 26 | self.previous_weights = {n: p.detach().clone() for n, p in self.model.named_parameters() if 27 | n in self.trainable_params.keys()} 28 | 29 | def get_fisher_and_prior(self) -> Tuple[dict, dict]: 30 | return self.fisher, self.previous_weights 31 | 32 | def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): 33 | for n, p in self.model.named_parameters(): 34 | if n in self.trainable_params.keys() and p.grad is not None: 35 | self.fisher[n] += p.grad.detach().clone().data.pow(2) / state.global_step 36 | elif p.grad is None: 37 | Warning(f"parameter {n} has no gradient") 38 | 39 | class EWCTrainer(BaseTrainerCL): 40 | """ 41 | https://arxiv.org/abs/1612.00796 42 | """ 43 | def __init__(self, **kwargs): 44 | super().__init__(**kwargs) 45 | self.add_callback(GradientCallback(model=self.model)) 46 | self.ewc_lambda = self.cl_args.ewc_lambda 47 | self.cb = None 48 | for cb in self.callback_handler.callbacks: 49 | if isinstance(cb, GradientCallback): 50 | self.cb = cb 51 | break 52 | 53 | def continual_learning(self): 54 | for i, name in enumerate(self.task_names): 55 | self.before_task_start(name) 56 | self.train() 57 | self.after_task_end(name) 58 | 59 | def compute_loss(self, model, inputs, return_outputs=False): 60 | outputs = model(**inputs) 61 | loss = outputs['loss'] if isinstance(outputs, dict) else outputs[0] 62 | ewc_loss = self.compute_ewc_loss(model) 63 | loss += ewc_loss 64 | return (loss, outputs) if return_outputs else loss 65 | 66 | def compute_ewc_loss(self, model): 67 | fisher, previous_weights = self.cb.get_fisher_and_prior() 68 | assert len(fisher) > 0 and len(previous_weights) > 0, "fisher and previous_weights should not be empty" 69 | 70 | ewc_loss = 0 71 | for n, p in model.named_parameters(): 72 | if n in fisher: 73 | ewc_loss += (fisher[n] * (p - previous_weights[n]).pow(2)).sum() * self.ewc_lambda / 2 74 | 75 | if ewc_loss < 1e-5: 76 | Warning("EWC regularization loss is too small, please check the hyper-parameters") 77 | return ewc_loss -------------------------------------------------------------------------------- /method/GEM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl 3 | from qpth.qp import QPFunction 4 | from .BaseTrainerCL import BaseTrainerCL 5 | from peft import PeftModel 6 | from deepspeed.utils import safe_get_full_grad 7 | import torch.distributed as dist 8 | class TaskGradientCallback(TrainerCallback): 9 | def __init__(self, **kwargs): 10 | super().__init__() 11 | self.model:PeftModel = 
kwargs.get('model') 12 | self.current_task_name = kwargs.get('current_task_name') # need update during training 13 | self.n_tasks = kwargs.get('n_tasks') 14 | self.grads = {} 15 | self.task_names = kwargs.get('task_names') 16 | self.device = kwargs.get('device', 'cpu') 17 | self.init_grads() 18 | 19 | def init_grads(self): 20 | for n, p in self.model.named_parameters(): 21 | if p.requires_grad: # reduce memory usage 22 | self.grads[n] = torch.ones([p.data.numel(), self.n_tasks], dtype=p.dtype, device=self.device) 23 | 24 | def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): 25 | for n, p in self.model.named_parameters(): 26 | if n in self.grads: 27 | # p_grad = safe_get_full_grad(p) 28 | # print('rank', dist.get_rank(), n, '->', p.device) 29 | # assert p.grad is not None, f"parameter {n} has no gradient" 30 | if p.grad is None: 31 | print(f"rank {dist.get_rank()} parameter {n} has no gradient in device {p.device}") 32 | p.grad = self.get_updated_grads(n, p.grad, self.task_names.index(self.current_task_name)) 33 | grad_old = self.grads[n][:, self.task_names.index(self.current_task_name)].detach().clone() 34 | grad_new = (grad_old * state.step + p.grad.detach().clone().view(-1)) / (state.step + 1) 35 | self.grads[n][:, self.task_names.index(self.current_task_name)] = grad_new 36 | 37 | def get_updated_grads(self, name, grad, idx, margin=0.1, eps=1.0): 38 | ori_shape = grad.shape # None 39 | grad = grad.view(-1) 40 | pre_grad = self.grads[name].cuda()[:, :idx+1].to(torch.float32) 41 | dot_product = torch.mm(grad.unsqueeze(0), pre_grad) 42 | if (dot_product < 0).sum() != 0: 43 | pre_grad_cuda = pre_grad.t() 44 | grad_cuda = grad.contiguous().view(-1) 45 | t = pre_grad_cuda.shape[0] 46 | P = torch.matmul(pre_grad_cuda, pre_grad_cuda.t()) 47 | P = 0.5 * (P + P.t()) 48 | 49 | P[torch.isnan(P)] = 0.0 50 | eigenvalues = torch.linalg.eigvals(P) 51 | if not (eigenvalues.real > 0).all(): # due to the grad clip happens after the projection, the grad could be huge, refactor eps=1.0 is reasonable 52 | P += torch.eye(t).cuda() * eps 53 | 54 | q = torch.matmul(pre_grad_cuda, grad_cuda).t() * -1 55 | 56 | P = P.to(torch.float32) 57 | q = q.to(torch.float32) 58 | G = torch.eye(t).cuda() * -1 59 | h = torch.zeros(t).cuda() - margin 60 | e = torch.Tensor().cuda() 61 | v = QPFunction(verbose=False)(P, q, G, h, e, e)[0] 62 | v = v.to(torch.float32) 63 | x = torch.matmul(v, pre_grad_cuda) + grad_cuda 64 | grad.copy_(x.view(-1)) 65 | return grad.view(ori_shape) 66 | 67 | def update_current_task_name(self, name:str): 68 | self.current_task_name = name 69 | 70 | 71 | 72 | class GEMTrainer(BaseTrainerCL): 73 | def __init__(self, **kwargs): 74 | super().__init__(**kwargs) 75 | 76 | self.add_callback(TaskGradientCallback( 77 | model = self.model, 78 | current_task_name=self.current_task_name, 79 | n_tasks=self.num_tasks, 80 | task_names=self.task_names)) 81 | 82 | self.gem_cb = None 83 | for cb in self.callback_handler.callbacks: 84 | if isinstance(cb, TaskGradientCallback): 85 | self.gem_cb = cb 86 | break 87 | 88 | def continual_learning(self): 89 | for i, name in enumerate(self.task_names): 90 | self.gem_cb.update_current_task_name(name) 91 | self.before_task_start(name) 92 | self.train() 93 | self.after_task_end(name) 94 | -------------------------------------------------------------------------------- /method/ILORA.py: -------------------------------------------------------------------------------- 1 | import os 2 | from copy import deepcopy 3 | from typing 
import Optional, List, Union, Tuple 4 | import copy 5 | import numpy as np 6 | import wandb 7 | from peft import PeftModel, LoraModel, load_peft_weights, set_peft_model_state_dict 8 | from .BaseTrainerCL import BaseTrainerCL 9 | import torch 10 | from transformers import LlamaForCausalLM, PreTrainedModel, TrainerCallback 11 | from transformers.modeling_outputs import CausalLMOutputWithPast 12 | from torch import nn 13 | import torch.nn.functional as F 14 | 15 | 16 | # from method.BaseTrainerCL import BaseTrainerCL 17 | 18 | 19 | def reservoir(num_seen_examples: int, buffer_size: int) -> int: 20 | """ 21 | Reservoir sampling algorithm. 22 | :param num_seen_examples: the number of seen examples 23 | :param buffer_size: the maximum buffer size 24 | :return: the target index if the current image is sampled, else -1 25 | """ 26 | if num_seen_examples < buffer_size: 27 | return num_seen_examples 28 | 29 | rand = np.random.randint(0, num_seen_examples + 1) 30 | if rand < buffer_size: 31 | return rand 32 | else: 33 | return -1 34 | 35 | 36 | def ring(num_seen_examples: int, buffer_portion_size: int, task: int) -> int: 37 | return num_seen_examples % buffer_portion_size + task * buffer_portion_size 38 | 39 | 40 | class Buffer: 41 | """ 42 | The memory buffer of rehearsal method. 43 | """ 44 | 45 | def __init__(self, buffer_size, device, n_tasks=None, mode='reservoir'): 46 | assert mode in ['ring', 'reservoir'] 47 | self.buffer_size = buffer_size 48 | self.device = device 49 | self.num_seen_examples = 0 50 | self.functional_index = eval(mode) 51 | if mode == 'ring': 52 | assert n_tasks is not None 53 | self.task_number = n_tasks 54 | self.buffer_portion_size = buffer_size // n_tasks 55 | self.attributes = ['input_ids', 'labels', 'logits', 'task_labels', 'activations'] 56 | self.init_buffer() 57 | 58 | def init_buffer(self) -> None: 59 | for attr_str in self.attributes: 60 | setattr(self, attr_str, [None for _ in range(self.buffer_size)]) 61 | 62 | def add_data(self, input_ids, labels=None, logits=None, task_labels=None, activations=None): 63 | """ 64 | Adds the data to the memory buffer according to the reservoir strategy. 65 | :param input_ids: tensor containing the images 66 | :param labels: tensor containing the labels 67 | :param logits: tensor containing the outputs of the network 68 | :param task_labels: tensor containing the task labels 69 | :param activations: tensor containing the activations of the network 70 | :return: 71 | """ 72 | n = input_ids.shape[0] if hasattr(input_ids, 'shape') else len(input_ids) 73 | for i in range(n): 74 | index = reservoir(self.num_seen_examples, self.buffer_size) 75 | self.num_seen_examples += 1 76 | if index >= 0: 77 | self.input_ids[index] = input_ids[i].to(self.device) 78 | if labels is not None: 79 | self.labels[index] = labels[i].to(self.device) 80 | if logits is not None: 81 | self.logits[index] = logits[i].to(self.device) 82 | if task_labels is not None: 83 | self.task_labels[index] = task_labels[i].to(self.device) 84 | if activations is not None: 85 | self.activations[index] = activations[i].to(self.device) 86 | 87 | def get_data(self, size: int) -> Tuple: 88 | """ 89 | Random samples a batch of size items. 
90 | :param size: the number of requested items 91 | :return: 92 | """ 93 | n = self.input_ids.shape[0] if hasattr(self.input_ids, 'shape') else len(self.input_ids) 94 | if size > min(self.num_seen_examples, n): 95 | size = min(self.num_seen_examples, n) 96 | 97 | choice = np.random.choice(min(self.num_seen_examples, n), size=size, replace=False) 98 | if len(choice) == 0: 99 | return None, None 100 | max_input_id_len = max([self.input_ids[c].shape[0] for c in choice]) 101 | max_label_len = max([self.labels[c].shape[0] for c in choice]) 102 | # for left padding 103 | input_ids = torch.stack( 104 | [torch.cat([torch.zeros(max_input_id_len - ee.shape[0]).long().to(ee.device), ee]) for ee in 105 | [self.input_ids[c] for c in choice]]).reshape(size, max_input_id_len) 106 | labels = torch.stack([torch.cat([torch.zeros(max_label_len - ee.shape[0]).long().to(ee.device), ee]) for ee in 107 | [self.labels[c] for c in choice]]).reshape(size, max_label_len) 108 | return input_ids, labels 109 | 110 | def is_empty(self) -> bool: 111 | """ 112 | Returns true if the buffer is empty, false otherwise. 113 | """ 114 | if self.num_seen_examples == 0: 115 | return True 116 | else: 117 | return False 118 | 119 | def get_all_data(self) -> Tuple: 120 | """ 121 | Return all the items in the memory buffer. 122 | :return: a tuple with all the items in the memory buffer 123 | """ 124 | ret_tuple = (torch.stack([ee.cpu() 125 | for ee in self.input_ids]).to(self.device),) 126 | for attr_str in self.attributes[1:]: 127 | if hasattr(self, attr_str): 128 | attr = getattr(self, attr_str) 129 | ret_tuple += (attr,) 130 | return ret_tuple 131 | 132 | def empty(self) -> None: 133 | """ 134 | Set all the tensors to None. 135 | """ 136 | for attr_str in self.attributes: 137 | if hasattr(self, attr_str): 138 | delattr(self, attr_str) 139 | self.num_seen_examples = 0 140 | 141 | 142 | class ILoRAModel(LlamaForCausalLM): 143 | def __init__(self, model: PeftModel, reg_decay: bool = False): 144 | super().__init__(model.config) 145 | self.model = model 146 | self.current_task_name = "C-STANCE" 147 | # regularization settings 148 | peft_cfg = model.peft_config['default'] 149 | self.model.add_adapter('ema', peft_cfg) 150 | self.model.to(self.model.device) 151 | 152 | self.ema_alpha: float = 0.25 153 | self.reg_weight: float = 1.0 154 | self.ema_update_freq: float = 0.1 155 | self.consistency_loss = nn.MSELoss(reduction='none') 156 | 157 | self.buffer = Buffer(500, 'cuda') # same as ER 158 | self.l_cons = 0 159 | self.total = 0 160 | self.ori_loss = 0 161 | 162 | def update_ema_weights(self, step): 163 | alpha = min(1 - 1 / (step + 1), self.ema_alpha) 164 | 165 | self.model.set_adapter('default') 166 | model_state_dict = {n: p.detach().clone() for n, p in self.model.named_parameters() if p.requires_grad} 167 | 168 | self.model.set_adapter('ema') 169 | for name, param in self.model.named_parameters(): 170 | if name in model_state_dict.keys(): 171 | param.data.mul_(alpha).add_(torch.mul(model_state_dict[name].data, 1 - alpha)) 172 | self.model.set_adapter('default') 173 | 174 | def update_reg_weight(self, step, decay_steps): 175 | if self.reg_decay: 176 | self.alpha = self.fix_reg_weight * self.decay_rate ** (step / decay_steps) 177 | else: 178 | self.alpha = self.fix_reg_weight 179 | 180 | def update_ema_model(self, path): 181 | adapter_weights = load_peft_weights(path) 182 | set_peft_model_state_dict(self.model, adapter_weights, adapter_name='ema') 183 | 184 | def concat_inputs(self, input_ids: torch.Tensor, labels: torch.Tensor, 
buffer_inputs_ids: torch.Tensor, 185 | buffer_labels: torch.Tensor) -> Tuple: 186 | if buffer_inputs_ids is None or buffer_labels is None: 187 | return input_ids, labels 188 | 189 | max_input_id_len = max(input_ids.shape[1], buffer_inputs_ids.shape[1]) 190 | max_labels_len = max(labels.shape[1], buffer_labels.shape[1]) 191 | 192 | extended_input_ids = torch.cat([input_ids, torch.zeros(input_ids.shape[0], 193 | max_input_id_len - input_ids.shape[1]).long().to( 194 | input_ids.device)], dim=1) 195 | extended_labels = torch.cat( 196 | [labels, torch.zeros(labels.shape[0], max_labels_len - labels.shape[1]).long().to(labels.device)], dim=1) 197 | 198 | extended_buffer_inputs_ids = torch.cat([buffer_inputs_ids, torch.zeros(buffer_inputs_ids.shape[0], 199 | max_input_id_len - 200 | buffer_inputs_ids.shape[1]).long().to( 201 | buffer_inputs_ids.device)], dim=1) 202 | extended_buffer_labels = torch.cat([buffer_labels, torch.zeros(buffer_labels.shape[0], 203 | max_labels_len - buffer_labels.shape[ 204 | 1]).long().to(buffer_labels.device)], dim=1) 205 | 206 | input_ids = torch.cat([extended_input_ids, extended_buffer_inputs_ids], dim=0) 207 | labels = torch.cat([extended_labels, extended_buffer_labels], dim=0) 208 | return input_ids, labels 209 | 210 | def forward( 211 | self, 212 | input_ids: torch.LongTensor = None, 213 | attention_mask: Optional[torch.Tensor] = None, 214 | position_ids: Optional[torch.LongTensor] = None, 215 | past_key_values: Optional[List[torch.FloatTensor]] = None, 216 | inputs_embeds: Optional[torch.FloatTensor] = None, 217 | labels: Optional[torch.LongTensor] = None, 218 | use_cache: Optional[bool] = None, 219 | output_attentions: Optional[bool] = None, 220 | output_hidden_states: Optional[bool] = None, 221 | return_dict: Optional[bool] = None, 222 | ) -> Union[Tuple, CausalLMOutputWithPast]: 223 | buffer_inputs, buffer_labels = None, None 224 | if self.current_task_name != "C-STANCE": 225 | buffer_inputs, buffer_labels = self.buffer.get_data(input_ids.shape[0]) 226 | l_cons = 0 227 | if labels is not None and buffer_inputs is not None and buffer_labels is not None: 228 | self.model.set_adapter('default') 229 | plastic_hiddn = self.model( 230 | buffer_inputs, 231 | labels=buffer_labels, 232 | output_hidden_states=True, 233 | return_dict=True) 234 | 235 | self.model.set_adapter('ema') 236 | with torch.no_grad(): 237 | stable_hiddn = self.model( 238 | buffer_inputs, 239 | labels=buffer_labels, 240 | output_hidden_states=True, 241 | return_dict=True).hidden_states 242 | indexs = [inner_plastic > inner_stable for inner_plastic, inner_stable in 243 | zip(plastic_hiddn.hidden_states, stable_hiddn)] 244 | reg_hiddn = [torch.where(idx, inner_plastic, inner_stable) for idx, inner_plastic, inner_stable in 245 | zip(indexs, plastic_hiddn.hidden_states, stable_hiddn)] 246 | 247 | l_cons = torch.mean( 248 | torch.cat([self.consistency_loss(plastic, ema) for plastic, ema in 249 | zip(plastic_hiddn.hidden_states, reg_hiddn)], dim=0)) 250 | 251 | self.l_cons = l_cons # for logging use 252 | 253 | self.model.set_adapter('default') 254 | ori_out = self.model( 255 | input_ids=input_ids, 256 | # attention_mask=attention_mask, 257 | labels=labels, 258 | return_dict=True, 259 | output_hidden_states=False) 260 | if labels is not None and buffer_inputs is not None and buffer_labels is not None: 261 | self.total_loss = (ori_out.loss + plastic_hiddn.loss) / 2 + self.reg_weight * l_cons 262 | else: 263 | self.total_loss = ori_out.loss + self.reg_weight * l_cons 264 | self.total = 
self.total_loss.item() 265 | 266 | return CausalLMOutputWithPast( 267 | loss=self.total_loss, 268 | past_key_values=ori_out.past_key_values, 269 | logits=ori_out.logits, 270 | hidden_states=ori_out.hidden_states 271 | ) 272 | 273 | 274 | class ILoRATrainer(BaseTrainerCL): 275 | def __init__(self, **kwargs): 276 | super().__init__(**kwargs) 277 | self.model = ILoRAModel(self.model) 278 | self.add_callback(CLSCallback(self.model)) 279 | 280 | def compute_loss(self, model, inputs, return_outputs=False): 281 | self.model.buffer.add_data(inputs['input_ids'], inputs['labels']) 282 | outputs = self.model(**inputs) 283 | loss = outputs.loss 284 | print("loss:", loss.item()) 285 | return (loss, outputs) if return_outputs else loss 286 | 287 | def continual_learning(self): 288 | resume_from_checkpoint = "False" 289 | for task_name, dataset in self.continual_training_dataset.items(): 290 | self.model.current_task_name = task_name 291 | self.current_task_name = task_name 292 | self.train_dataset = dataset 293 | self.train() 294 | resume_from_checkpoint = self.save_model(task_name) 295 | self.model.load_ema_model(resume_from_checkpoint) 296 | wandb.finish() 297 | 298 | def save_model(self, name) -> str: 299 | if self.args.output_dir is not None: 300 | output_dir = os.path.join(self.args.output_dir, f"{self.cl_method}_{self.adapter}_checkpoint_{name}") 301 | self.model.model.set_adapter('default') 302 | self.model.model.save_pretrained(output_dir) 303 | return output_dir 304 | 305 | 306 | class CLSCallback(TrainerCallback): 307 | def __init__(self, model: ILoRAModel): 308 | self.model = model 309 | 310 | def on_step_end(self, args, state, control, **kwargs): 311 | self.model.update_ema_weights(state.global_step) 312 | if wandb.run: 313 | wandb.log({ 314 | "reg_weight": self.model.reg_weight, 315 | "consist_loss": self.model.l_cons, 316 | "total_loss": self.model.total, 317 | "ori_loss": self.model.ori_loss, 318 | }) -------------------------------------------------------------------------------- /method/L2P.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Union, Tuple 2 | 3 | import torch, os 4 | import torch.utils.data 5 | from torch import nn 6 | from transformers.modeling_outputs import CausalLMOutputWithPast 7 | from transformers import LlamaForCausalLM, PreTrainedModel, Trainer 8 | from transformers.modeling_outputs import CausalLMOutputWithPast 9 | from peft import PeftModel 10 | from .BaseTrainerCL import BaseTrainerCL 11 | import numpy as np 12 | from copy import deepcopy 13 | 14 | 15 | def l2_normalize(x, dim=None, epsilon=1e-12): 16 | square_norm = torch.sum(x ** 2, dim=dim, keepdim=True) 17 | x_inv_norm = torch.rsqrt(torch.maximum(square_norm, torch.tensor(epsilon, device=x.device))) 18 | return x * x_inv_norm 19 | 20 | class L2PModel(LlamaForCausalLM): 21 | def __init__(self, model:PreTrainedModel, pool_size:int=10, prompt_length:int=5, promt_init:str='random'): 22 | super().__init__(model.config) 23 | self.model = model 24 | self.embed_tokens = self.model.get_input_embeddings() 25 | self.embed_tk_shapes = self.embed_tokens.weight.shape 26 | self.prompt = None 27 | self.top_k = 3 28 | self.diversity_loss_weight = 0.5 29 | self.pool_size = pool_size 30 | self.prompt_length = prompt_length 31 | self.init_prompt(promt_init) 32 | self.embeding_key = 'mean' 33 | self.batchwise_prompt: bool = False 34 | self.current_task_name:str = None 35 | 36 | def init_prompt(self,promt_init): 37 | self.prompt = nn.Parameter( 38 | 
torch.tensor( 39 | self.create_prompt(self.pool_size, self.prompt_length, promt_init), requires_grad=True 40 | ) 41 | ).to(self.device) 42 | 43 | def create_prompt(self, pool_size, prompt_length, promt_init='random'): 44 | N = self.embed_tk_shapes[0] 45 | p_weights = [] 46 | 47 | for p in range(self.pool_size): 48 | p_w = [] 49 | for i in range(self.prompt_length): 50 | with torch.no_grad(): 51 | j = np.random.randint(N) 52 | w = deepcopy(self.embed_tokens.weight[j].detach().cpu().numpy()) 53 | p_w.append(w) 54 | p_weights.append(p_w) 55 | 56 | return np.array(p_weights) 57 | 58 | def save_prompt_weights(self, path): 59 | state_dict = {"prompt_pool": self.prompt} 60 | torch.save(state_dict, os.path.join(path, f"prompt_weights_{self.current_task_name}.pt")) 61 | 62 | def load_prompt_weights(self, path, task_name="jecqa"): 63 | state_dict = torch.load(os.path.join(path, f"prompt_weights_{task_name}.pt"), map_location=self.device) 64 | self.prompt.data = state_dict["prompt_pool"].data 65 | print(f"Loaded prompt weights from {path}") 66 | 67 | def freeze_prompt(self): 68 | for n, p in self.named_parameters(): 69 | p.requires_grad = False 70 | 71 | def forward( 72 | self, 73 | input_ids: torch.LongTensor = None, 74 | attention_mask: Optional[torch.Tensor] = None, 75 | position_ids: Optional[torch.LongTensor] = None, 76 | past_key_values: Optional[List[torch.FloatTensor]] = None, 77 | inputs_embeds: Optional[torch.FloatTensor] = None, 78 | labels: Optional[torch.LongTensor] = None, 79 | use_cache: Optional[bool] = None, 80 | output_attentions: Optional[bool] = None, 81 | output_hidden_states: Optional[bool] = None, 82 | return_dict: Optional[bool] = None, 83 | ) -> Union[Tuple, CausalLMOutputWithPast]: 84 | 85 | i_input_embeds = self.embed_tokens(input_ids) 86 | out = dict() 87 | if self.embeding_key == 'mean': 88 | i_input_embeds_mean = torch.mean(i_input_embeds, dim=1) 89 | elif self.embeding_key == 'max': 90 | i_input_embeds_mean = torch.max(i_input_embeds, dim=1)[0] 91 | elif self.embeding_key == 'mean_max': 92 | i_input_embeds_mean = torch.max(i_input_embeds, dim=1)[0] + 2 * torch.mean(i_input_embeds, dim=1) 93 | else: 94 | raise NotImplementedError("Not supported way of calculating embedding keys!") 95 | 96 | prompt_key = torch.mean(self.prompt, dim=1) # Pool_size, C 97 | prompt_norm = l2_normalize(prompt_key, dim=1).to("cuda") 98 | inputs_embeds_norm = l2_normalize(i_input_embeds_mean, dim=1) 99 | prompt_norm = prompt_norm.to(dtype=inputs_embeds_norm.dtype) 100 | similarity = torch.matmul(inputs_embeds_norm, prompt_norm.t()) # B, Pool_size 101 | 102 | _, idx = torch.topk(similarity, k=self.top_k, dim=1) # B, top_k 103 | if self.batchwise_prompt: 104 | prompt_id, id_counts = torch.unique(idx, return_counts=True, sorted=True) 105 | if prompt_id.shape[0] < self.pool_size: 106 | prompt_id = torch.cat([prompt_id, torch.full((self.pool_size - prompt_id.shape[0],), torch.min(idx.flatten()), device=prompt_id.device)]) 107 | id_counts = torch.cat([id_counts, torch.full((self.pool_size - id_counts.shape[0],), 0, device=id_counts.device)]) 108 | _, major_idx = torch.topk(id_counts, k=self.top_k) 109 | major_prompt_id = prompt_id[major_idx] 110 | idx = major_prompt_id.expand(inputs_embeds.shape[0], -1) 111 | 112 | batched_prompt_raw = self.prompt[idx] # B, top_k, length, C 113 | batch_size, top_k, length, c = batched_prompt_raw.shape 114 | batched_prompt = batched_prompt_raw.reshape(batch_size, top_k * length, c) 115 | inputs_embeds = torch.cat([batched_prompt, i_input_embeds],axis=1) 116 | 117 | 
prefix_length = batched_prompt.shape[1] 118 | attn_masks = torch.concat((torch.tensor(1).to("cuda").repeat(batch_size,prefix_length),attention_mask), axis=1) 119 | 120 | if labels is None: # inference 121 | return self.model(inputs_embeds=inputs_embeds.half(), attention_mask=attn_masks, use_cache=False, return_dict=True) 122 | 123 | labels = torch.concat((labels[0][0].repeat(batch_size,inputs_embeds.shape[1]-labels.shape[1]),labels),axis=1) 124 | outs = self.model(inputs_embeds=inputs_embeds,labels=labels,attention_mask=attn_masks,use_cache=False) 125 | loss = outs[0] 126 | batched_key_norm = prompt_norm[idx] 127 | inputs_embeds_norm = inputs_embeds_norm.unsqueeze(1) # B, 1, C 128 | sim = batched_key_norm * inputs_embeds_norm # B, top_k, C 129 | reduce_sim = torch.sum(sim) / inputs_embeds.shape[0] # Scalar 130 | 131 | loss -= reduce_sim * self.diversity_loss_weight 132 | return loss 133 | 134 | class L2PTrainer(BaseTrainerCL): 135 | def __init__(self, **kwargs): 136 | super().__init__(**kwargs) 137 | self.model = L2PModel(self.model) 138 | 139 | def compute_loss(self, model:L2PModel, inputs, return_outputs=False): 140 | input_ids = inputs['input_ids'] 141 | attn_masks = inputs['attention_mask'] 142 | labels = inputs.pop('labels') 143 | outputs = self.model(input_ids=input_ids, labels=labels, attention_mask=attn_masks) 144 | return outputs # since we only calculate loss in model.forward(), we return outputs here 145 | 146 | def continual_learning(self): 147 | for i, name in enumerate(self.task_names): 148 | self.model.current_task_name = name 149 | self.before_task_start(name) 150 | self.train() 151 | self.after_task_end(name) 152 | 153 | def save_model(self, name) -> str: 154 | if self.args.output_dir is not None: 155 | output_dir = os.path.join(self.args.output_dir, f"{self.cl_method}_{self.adapter}_checkpoint_{name}") 156 | assert isinstance(self.model.model, PeftModel), "self.model.model is not a PeftModel" 157 | self.model.model.save_pretrained(output_dir) 158 | print(f"save task: {name} adapter to {self.args.output_dir}") 159 | return output_dir -------------------------------------------------------------------------------- /method/MTL.py: -------------------------------------------------------------------------------- 1 | from .BaseTrainerCL import BaseTrainerCL 2 | 3 | class MTLTrainer(BaseTrainerCL): 4 | def __init__(self, **kwargs): 5 | super().__init__(**kwargs) 6 | -------------------------------------------------------------------------------- /method/ONE.py: -------------------------------------------------------------------------------- 1 | from .BaseTrainerCL import BaseTrainerCL 2 | 3 | class ONETrainer(BaseTrainerCL): 4 | def __init__(self, **kwargs): 5 | super().__init__(**kwargs) 6 | -------------------------------------------------------------------------------- /method/PP.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Union, Tuple 2 | 3 | import torch, os, wandb 4 | import torch.utils.data 5 | from torch import nn 6 | from transformers.modeling_outputs import CausalLMOutputWithPast 7 | from transformers import LlamaForCausalLM, PreTrainedModel, Trainer 8 | from transformers.modeling_outputs import CausalLMOutputWithPast 9 | from peft import PeftModel 10 | from .BaseTrainerCL import BaseTrainerCL 11 | import numpy as np 12 | from copy import deepcopy 13 | 14 | class ResMLP(torch.nn.Module): 15 | def __init__(self, hidden_dim, bottleneck_size, module_type='MLP1', residual=True): 16 | 
super().__init__() 17 | self.residual = residual 18 | if module_type=='MLP1': 19 | self.module = nn.Sequential( 20 | nn.Linear(hidden_dim, bottleneck_size), 21 | nn.ReLU(), 22 | nn.Linear(bottleneck_size, hidden_dim), 23 | ) 24 | 25 | elif module_type=='MLP2': 26 | self.module = nn.Sequential( 27 | nn.Linear(hidden_dim, bottleneck_size), 28 | nn.ReLU(), 29 | nn.Linear(bottleneck_size, bottleneck_size), 30 | nn.Tanh(), 31 | nn.Linear(bottleneck_size, hidden_dim), 32 | ) 33 | 34 | elif module_type=='transformer': 35 | device = 'cuda' 36 | self.encoder_layer = nn.TransformerEncoderLayer(d_model=768, nhead=2, dropout=0.05).to(device) 37 | self.module = nn.TransformerEncoder(self.encoder_layer, num_layers=2).to(device) 38 | 39 | def forward(self, inputs): 40 | if self.residual: 41 | return self.module(inputs) + inputs 42 | else: 43 | return self.module(inputs) 44 | 45 | class PPModel(LlamaForCausalLM): 46 | def __init__(self,model:PreTrainedModel, prefix_len, task_names, prefix_path=None): 47 | super().__init__(model.config) 48 | self.model = model 49 | self.prefix_len:int = prefix_len 50 | self.prefix_pth:str = prefix_path 51 | self.task_names:List = task_names 52 | self.num_tasks:int = len(task_names) 53 | self.current_task_name:str = None 54 | self.embed_tokens = self.model.get_input_embeddings() # [vocab_size, embed_dim] 55 | self.embed_tokens_len = self.embed_tokens.weight.shape[0] # vocab_size 56 | 57 | self.prompt = nn.Parameter( 58 | torch.tensor(self.init_prompt(), requires_grad=True) # [prefix_len, embed_dim] 59 | ).to(self.device) 60 | 61 | self.previous_prompts = torch.zeros([0, self.prompt.shape[1]], requires_grad=False, dtype=torch.bfloat16).to(self.device) # [0, embed_dim] 62 | self.mlps = nn.ModuleList([ResMLP(model.config.hidden_size, 128, module_type='MLP1', residual=True) for _ in range(self.num_tasks)]).to(self.device) 63 | 64 | assert self.prefix_len > 0 65 | 66 | self.pp = True 67 | self.is_train = True 68 | 69 | def init_prompt(self): 70 | prompt_weights = [] 71 | for i in range(self.prefix_len): 72 | with torch.no_grad(): 73 | j = np.random.randint(self.embed_tokens_len) 74 | w = deepcopy(self.embed_tokens.weight[j].detach().cpu().numpy()) 75 | prompt_weights.append(w / 100) 76 | return np.array(prompt_weights) 77 | 78 | def progressive_previous_prompt(self, task_name): 79 | """ 80 | update previous prompt at end of each task 81 | """ 82 | if task_name != None and self.mlps != None: 83 | with torch.no_grad(): 84 | new_prompt = self.mlps[self.task_names.index(task_name)](self.prompt) 85 | self.previous_prompts = torch.cat([self.previous_prompts, new_prompt], axis=0) 86 | print(f'updated previous prompt to: {self.previous_prompts.shape}') 87 | 88 | def freeze_mlps(self, name:str, requires_grad=False): 89 | """ 90 | Freeze or unfreeze all the MLPs except the one for the current task 91 | """ 92 | for i, mlp in enumerate(self.mlps): 93 | if i != self.task_names.index(name): 94 | for param in mlp.parameters(): 95 | if param.requires_grad != requires_grad: 96 | param.requires_grad = requires_grad 97 | else: 98 | for param in mlp.parameters(): 99 | if param.requires_grad == requires_grad: 100 | param.requires_grad = not requires_grad 101 | 102 | def save_mlps_prompts(self, path): 103 | """ 104 | Save all the MLPs and prompts and previous learned prompts 105 | """ 106 | mlp_state_dict = {"mlps": self.mlps.state_dict()} 107 | prompt_state_dict = {"prompt": self.prompt} 108 | previous_prompt_state_dict = {"previous_prompt": self.previous_prompts} 109 | torch.save(mlp_state_dict, 
os.path.join(path, f"mlps_{self.current_task_name}.pt")) 110 | torch.save(prompt_state_dict, os.path.join(path, f"prompt_{self.current_task_name}.pt")) 111 | torch.save(previous_prompt_state_dict, os.path.join(path, f"previous_prompt_{self.current_task_name}.pt")) 112 | 113 | def load_mlps_prompts(self, path, task_name = None): 114 | """ 115 | Load all the MLPs and prompts 116 | """ 117 | mlp_path = os.path.join(path, f"mlps_{task_name}.pt") 118 | prompt_path = os.path.join(path, f"prompt_{task_name}.pt") 119 | previous_prompt_path = os.path.join(path, f"previous_prompt_{task_name}.pt") 120 | assert mlp_path and prompt_path and previous_prompt_path, "mlp_path or prompt_path is None" 121 | 122 | mlp_state_dict = torch.load(mlp_path, map_location=self.device) 123 | prompt_state_dict = torch.load(prompt_path, map_location=self.device) 124 | previous_prompt_state_dict = torch.load(previous_prompt_path, map_location=self.device) 125 | self.mlps.load_state_dict(mlp_state_dict["mlps"]) 126 | self.prompt.data = prompt_state_dict["prompt"].data 127 | self.previous_prompts.data = previous_prompt_state_dict["previous_prompt"].data 128 | print(f"Loaded mlps and prompt from {mlp_path} and {prompt_path}") 129 | 130 | def freeze_all(self): 131 | for n, p in self.named_parameters(): 132 | p.requires_grad = False 133 | 134 | def forward( 135 | self, 136 | input_ids: torch.LongTensor = None, 137 | attention_mask: Optional[torch.Tensor] = None, 138 | position_ids: Optional[torch.LongTensor] = None, 139 | past_key_values: Optional[List[torch.FloatTensor]] = None, 140 | inputs_embeds: Optional[torch.FloatTensor] = None, 141 | labels: Optional[torch.LongTensor] = None, 142 | use_cache: Optional[bool] = None, 143 | output_attentions: Optional[bool] = None, 144 | output_hidden_states: Optional[bool] = None, 145 | return_dict: Optional[bool] = None, 146 | ) -> Union[Tuple, CausalLMOutputWithPast]: 147 | 148 | inputs_embeds = self.embed_tokens(input_ids) # [batch_size, seq_len, embed_dim] 149 | k = inputs_embeds.shape[0] # batch_size 150 | 151 | mlp = self.mlps[self.task_names.index(self.current_task_name)] 152 | prompt = mlp(self.prompt) #[prefix_len, embed_dim] 153 | 154 | if self.pp: 155 | inputs_embeds = torch.cat([prompt.unsqueeze(0).repeat(k, 1, 1), # [batch_size, prefix_len, embed_dim] 156 | self.previous_prompts.unsqueeze(0).repeat(k, 1, 1),# [batch_size, len_of_learned_tasks, embed_dim] 157 | inputs_embeds], axis=1)# [batch_size, seq_len, embed_dim] 158 | full_prefix_len = prompt.shape[0] + self.previous_prompts.shape[0] # prefix_len + len_of_learned_tasks 159 | 160 | source_mask = torch.cat((torch.tensor(1).to('cuda').repeat(k, full_prefix_len), 161 | attention_mask), axis=1) # [batch_size, prefix_len + learned_tasks_len] 162 | if labels is not None: 163 | labels = torch.concat((labels[0][0].repeat(k, inputs_embeds.shape[1] - labels.shape[1]), labels),axis=1).detach()#[batch_size, prefix_len + learned_tasks_len, embed_dim] 164 | return self.model( 165 | inputs_embeds=inputs_embeds, 166 | labels=labels, 167 | attention_mask=source_mask 168 | ) 169 | else: 170 | inputs_embeds = inputs_embeds.half() 171 | return self.model(inputs_embeds=inputs_embeds, attention_mask=source_mask, use_cache=False, return_dict=True) 172 | 173 | class PPTrainer(BaseTrainerCL): 174 | 175 | def __init__(self, **kwargs): 176 | super().__init__(**kwargs) 177 | self.model:PPModel = PPModel(model=self.model, 178 | prefix_len=20, 179 | task_names=list(self.continual_training_dataset.keys())) 180 | 181 | def compute_loss(self, model, 
inputs, return_outputs=False): 182 | outputs = self.model(**inputs) 183 | loss = outputs.loss 184 | return (loss, outputs) if return_outputs else loss 185 | 186 | def continual_learning(self): 187 | resume_from_checkpoint = "False" 188 | for name, train_set in self.continual_training_dataset.items(): 189 | self.current_task_name = name 190 | self.model.current_task_name = name 191 | self.update_adapter_and_train_set(resume_from_checkpoint, train_set) 192 | self.model.freeze_mlps(name) 193 | self.train() 194 | self.model.progressive_previous_prompt(name) 195 | resume_from_checkpoint = self.save_model(name) 196 | self.model.save_mlps_prompts(self.args.output_dir) 197 | wandb.finish() 198 | 199 | def save_model(self, name) -> str: 200 | if self.args.output_dir is not None: 201 | output_dir = os.path.join(self.args.output_dir, f"{self.cl_method}_{self.adapter}_checkpoint_{name}") 202 | assert isinstance(self.model.model, PeftModel), "self.model.model is not a PeftModel" 203 | self.model.model.save_pretrained(output_dir) 204 | print(f"save task: {name} adapter to {self.args.output_dir}") 205 | return output_dir -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | bitsandbytes 3 | datasets 4 | fire 5 | peft 6 | git+https://github.com/2proveit/transformers.git@prepare_for_continual_training 7 | sentencepiece 8 | deepspeed 9 | numpy 10 | wandb 11 | qpth 12 | pandas 13 | tqdm 14 | evaluate 15 | rouge 16 | fuzzywuzzy 17 | nltk 18 | prettytable 19 | sacrebleu 20 | sacremoses -------------------------------------------------------------------------------- /scripts/train_seq.sh: -------------------------------------------------------------------------------- 1 | cl_method="seq" 2 | deepspeed main.py \ 3 | --deepspeed deepspeed_zero2_no_offload.json \ 4 | --model_name_or_path "meta-llama/Llama2-7b-hf" \ 5 | --load_in_8bits true \ 6 | --data_path "~/Downloads/TRACE-Benchmark/LLM-CL-Benchmark_5000" \ 7 | --dataset_names "20Minuten,FOMC,MeetingBank,ScienceQA" \ 8 | --max_length 1024 \ 9 | --train_on_inputs false \ 10 | --cl_method $cl_method \ 11 | --output_dir "outputs/$cl_method" \ 12 | --per_device_train_batch_size 1 \ 13 | --per_device_eval_batch_size 1 \ 14 | --gradient_accumulation_steps 1 \ 15 | --num_train_epochs 3 \ 16 | --warmup_steps 500 \ 17 | --learning_rate 1e-5 \ 18 | --lr_scheduler_type "constant_with_warmup" \ 19 | --load_best_model_at_end true \ 20 | --save_total_limit 3 \ 21 | --evaluation_strategy "steps" \ 22 | --save_strategy "steps" \ 23 | --save_steps 500 \ 24 | --eval_steps 500 \ 25 | --logging_steps 500 \ 26 | --report_to "wandb" \ 27 | --run_name "llama2-7b-hf" \ -------------------------------------------------------------------------------- /utils/arg_configs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from typing import List, Dict, Tuple, Union, Any, Optional 4 | from pathlib import Path 5 | from dataclasses import dataclass, field 6 | from transformers import HfArgumentParser, TrainingArguments, GenerationConfig, BitsAndBytesConfig 7 | from peft import LoraConfig, TaskType 8 | import warnings 9 | __all__ = ["get_args", "CLArguments", "TuningArguments", "DataArguments"] 10 | 11 | def list_strings(string: str) -> List[str]: 12 | return string.split(",") 13 | @dataclass 14 | class DataArguments: 15 | data_path: Union[str, Path] = 
"./TRACE-Benchmark/LLM-CL-Benchmark_5000" 16 | dataset_names: str = "20Minuten,FOMC,MeetingBank,NumGLUE-ds,ScienceQA,C-STANCE,NumGLUE-cm" 17 | max_length: int = 1024 18 | train_on_inputs: bool = False 19 | # truncation: bool = True 20 | # padding: bool = True 21 | 22 | def __post_init__(self): 23 | self.dataset_names = list_strings(self.dataset_names) 24 | 25 | @dataclass 26 | class CLArguments: 27 | cl_method: str = 'seq' 28 | ewc_lambda: float = 0.1 29 | er_buffer_size: int = 1000 30 | er_buffer_method: str = "random" 31 | gem_memory_strength: float = 0.5 32 | l2p_pool_size: int = 10 33 | l2p_prompt_length: int = 5 34 | l2p_prompt_init: str = "random" 35 | pp_prefix_length: int = 20 36 | 37 | cl_config: Union[Dict[str, Any], Any] = field(init=False) 38 | def __post_init__(self): 39 | assert self.cl_method in ["seq","ewc", "er", "gem", "agem", "l2p", "pp", "mtl", "one", "ilora"], f"cl_method '{self.cl_method}' not supported" 40 | if self.cl_method == 'seq': 41 | warnings.warn("Using 'seq' as cl_method, other cl configs are ignored") 42 | self.cl_config = {} 43 | if self.cl_method == 'ewc': 44 | warnings.warn("Using 'ewc' as cl_method, other cl configs are ignored") 45 | self.cl_config = { 46 | "lambda": self.ewc_lambda 47 | } 48 | if self.cl_method == 'er': 49 | warnings.warn("Using 'er' as cl_method, other cl configs are ignored") 50 | self.cl_config = { 51 | "buffer_size": self.er_buffer_size, 52 | "buffer_method": self.er_buffer_method 53 | } 54 | if self.cl_method == 'gem': 55 | warnings.warn("Using 'gem' as cl_method, other cl configs are ignored") 56 | self.cl_config = { 57 | "memory_strength": self.gem_memory_strength 58 | } 59 | if self.cl_method == 'l2p': 60 | warnings.warn("Using 'l2p' as cl_method, other cl configs are ignored") 61 | self.cl_config = { 62 | "pool_size": self.l2p_pool_size, 63 | "prompt_length": self.l2p_prompt_length, 64 | "prompt_init": self.l2p_prompt_init 65 | } 66 | if self.cl_method == 'pp': 67 | warnings.warn("Using 'pp' as cl_method, other cl configs are ignored") 68 | self.cl_config = { 69 | "prefix_length": self.pp_prefix_length 70 | } 71 | if self.cl_method == 'ilora': 72 | warnings.warn("Using 'ilora' as cl_method, other cl configs are ignored") 73 | self.cl_config = {} 74 | if self.cl_method == 'mtl': 75 | warnings.warn("Using 'mtl' as cl_method, other cl configs are ignored") 76 | self.cl_config = {} 77 | if self.cl_method == 'one': 78 | warnings.warn("Using 'one' as cl_method, other cl configs are ignored") 79 | self.cl_config = {} 80 | 81 | print(f"*** cl_config ***:\n\t{self.cl_config}") 82 | 83 | @dataclass 84 | class TuningArguments: 85 | # basic config 86 | model_name_or_path: Union[str, Path] = "meta-llama/Llama2-7b-hf" 87 | load_in_8bit: bool = False 88 | # lora config 89 | lora_r: int = 8 90 | lora_alpha: int = 32 91 | lora_dropout: float = 0.1 92 | target_modules: str = "q_proj,k_proj,v_proj,o_proj" 93 | load_8bit: bool = True 94 | lora_config: Union[LoraConfig, Any] = field(init=False) 95 | manual_seed: int = 37 96 | 97 | # redundant args 98 | config_file: Optional[str] = None 99 | def __post_init__(self): 100 | self.target_modules = list_strings(self.target_modules) 101 | self.lora_config = LoraConfig( 102 | r=self.lora_r, 103 | lora_alpha=self.lora_alpha, 104 | lora_dropout=self.lora_dropout, 105 | target_modules=self.target_modules, 106 | task_type=TaskType.CAUSAL_LM 107 | ) 108 | @dataclass 109 | class InferArguments: 110 | cl_method: str = 'seq' 111 | model_name_or_path: Union[str, Path] = "meta-llama/Llama2-7b-hf" 112 | 
load_in_8bit: bool = True 113 | load_in_4bit: bool = False 114 | tokenizer_name_or_path: Union[str, Path] = "meta-llama/Llama2-7b-hf" 115 | peft_cfg_path: Optional[str] = None 116 | peft_weights_path: Optional[str] = None 117 | infer_batch_size: int = 4 118 | # generation config 119 | max_new_tokens: int = 128 120 | temperature: float = 0.1 121 | top_p: float = 0.75 122 | repetition_penalty: float = 1.15 123 | do_sample: bool = True 124 | generation_config: Union[GenerationConfig, Any] = field(init=False) 125 | bnb_config: BitsAndBytesConfig = field(init=False) 126 | save_path: Union[str, Path] = "./generated_texts.json" 127 | 128 | def __post_init__(self): 129 | self.generation_config = GenerationConfig( 130 | max_new_tokens=self.max_new_tokens, 131 | temperature=self.temperature, 132 | top_p=self.top_p, 133 | repetition_penalty=self.repetition_penalty, 134 | do_sample=self.do_sample 135 | ) 136 | self.bnb_config = BitsAndBytesConfig( 137 | load_in_8bit=self.load_in_8bit, 138 | load_in_4bit=self.load_in_4bit 139 | ) 140 | @dataclass 141 | class EvalArguments: 142 | cl_method: str = 'seq' 143 | json_dirs: str = "outputs/seq/20Minuten.json" 144 | increment_order: str = '20Minuten' 145 | save_path: str = "./eval_table.csv" 146 | 147 | def __post_init__(self): 148 | self.json_dirs = list_strings(self.json_dirs) 149 | self.increment_order = list_strings(self.increment_order) 150 | 151 | def get_args() -> Tuple[TrainingArguments, CLArguments, TuningArguments, DataArguments]: 152 | parser = HfArgumentParser((TrainingArguments, CLArguments, TuningArguments, DataArguments)) 153 | train_args, cl_args, tuning_args, data_args = parser.parse_args_into_dataclasses() 154 | return train_args, cl_args, tuning_args, data_args 155 | 156 | if __name__ == "__main__": 157 | train_args, cl_args, tuning_args, data_args = get_args() 158 | print(train_args) 159 | print(cl_args) 160 | print(tuning_args) -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import re 2 | from rouge import Rouge 3 | from fuzzywuzzy import fuzz 4 | from datasets import load_metric 5 | from nltk.translate.bleu_score import sentence_bleu 6 | import evaluate 7 | 8 | 9 | ######################## 10 | # BLEU 11 | ######################## 12 | def tokenize(text): 13 | tokens = re.split(r'\s|\.', text) 14 | tokens = [t for t in tokens if len(t) > 0] 15 | return tokens 16 | 17 | 18 | def bleu_score(reference, hypothesis, gram): 19 | reference_tokens = tokenize(reference) 20 | hypothesis_tokens = tokenize(hypothesis) 21 | 22 | if gram == 1: 23 | bleu = sentence_bleu([reference_tokens], hypothesis_tokens, (1.,)) # BLEU-1 24 | elif gram == 2: 25 | bleu = sentence_bleu([reference_tokens], hypothesis_tokens, (1. / 2., 1. / 2.)) # BLEU-2 26 | elif gram == 3: 27 | bleu = sentence_bleu([reference_tokens], hypothesis_tokens, (1. / 3., 1. / 3., 1. / 3.)) # BLEU-3 28 | elif gram == 4: 29 | bleu = sentence_bleu([reference_tokens], hypothesis_tokens, (1. / 4., 1. / 4., 1. / 4., 1. 
/ 4.)) # BLEU-4 30 | 31 | return bleu 32 | 33 | 34 | def calculate_bleu(results, data, gram): 35 | bleus = [] 36 | for output_id in range(len(results)): 37 | prediction = results[output_id] 38 | target = data[output_id] 39 | if prediction == "" or target == "": 40 | continue 41 | bleu = bleu_score(target, prediction, gram) 42 | bleus.append(bleu) 43 | avg_bleu = sum(bleus) / len(results) 44 | return avg_bleu 45 | 46 | 47 | ######################## 48 | ## Rouge-L 49 | ######################## 50 | def score_rouge(str1, str2): 51 | print(type(str1), type(str2)) 52 | rouge = Rouge(metrics=["rouge-l"]) 53 | scores = rouge.get_scores(str1, str2, avg=True) 54 | rouge_l = scores['rouge-l']['f'] 55 | return rouge_l 56 | 57 | 58 | def calculate_rouge(results, data): 59 | rouges = [] 60 | for output_id in range(len(results)): 61 | prediction = results[output_id] 62 | target = data[output_id] 63 | if prediction == "" or target == "": 64 | continue 65 | rouge = score_rouge(target, prediction) 66 | rouges.append(rouge) 67 | avg_rouge = sum(rouges) / len(results) if rouges else 0.0 68 | return avg_rouge 69 | 70 | def calculate_rouge_(results, data): 71 | rouge_evaluator = evaluate.load("rouge") 72 | eval_results = [] 73 | eval_refs = [] 74 | for output_id in range(len(results)): 75 | prediction = results[output_id] 76 | target = data[output_id] 77 | if prediction == "" or target == "": 78 | continue 79 | eval_results.append(prediction) 80 | eval_refs.append(target) 81 | rouge_scores = rouge_evaluator.compute(predictions=eval_results, references=eval_refs, rouge_types=["rougeL"]) 82 | return rouge_scores['rougeL'] 83 | 84 | def caculate_accuracy(results, data): 85 | scores = 0 86 | for output_id in range(len(results)): 87 | target = data[output_id] 88 | prediction = results[output_id] 89 | if prediction == "" or target == "": 90 | continue 91 | if prediction == target: 92 | scores += 1 93 | avg_score = scores / len(results) 94 | return avg_score 95 | 96 | 97 | 98 | def f1_score(list1, list2): 99 | # TP: item in list1 and list2 100 | # FP: item in list1 but not in list2 101 | # TN: item not in list1 and list2 102 | # FN: item in list2 but not in list1 103 | num_TP = 0 104 | for item1 in list1: 105 | for item2 in list2: 106 | if item1 == item2: 107 | num_TP += 1 108 | break 109 | precision = num_TP / len(list1) 110 | recall = num_TP / len(list2) 111 | if precision == 0 or recall == 0: 112 | return 0 113 | return 2 * (precision * recall / (precision + recall)) 114 | 115 | 116 | def calculate_f1(results, data): 117 | scores = [] 118 | for output_id in range(len(results)): 119 | prediction = results[output_id] 120 | target = data[output_id] 121 | if len(prediction) == 0 or len(target) == 0: 122 | continue 123 | score = f1_score(target, prediction) 124 | scores.append(score) 125 | avg_score = sum(scores) / len(results) 126 | return avg_score 127 | 128 | 129 | 130 | def calculate_sari(inputs, results, data): 131 | sari = load_metric("sari") 132 | result = sari.compute(sources=inputs, predictions=results, references=[[label] for label in data]) # one reference for each prediction 133 | return result 134 | 135 | 136 | def eval_20Minuten(input_sequences, predicted_sequences, ground_truths): 137 | sari = calculate_sari(input_sequences, predicted_sequences, ground_truths) 138 | evaluation_result = {"sari": sari} 139 | return evaluation_result 140 | 141 | def eval_medmcqa(predicted_sequences, ground_truths): 142 | predicted_sequences = postprocess_choice_acc(predicted_sequences) 143 | ground_truths = 
postprocess_choice_acc(ground_truths) 144 | 145 | accuracy = caculate_accuracy(predicted_sequences, ground_truths) 146 | evaluation_result = {"accuracy": accuracy} 147 | return evaluation_result 148 | 149 | def eval_jecqa(predicted_sequences, ground_truths): 150 | predicted_sequences = postprocess_choice_acc(predicted_sequences) 151 | ground_truths = postprocess_choice_acc(ground_truths) 152 | 153 | accuracy = caculate_accuracy(predicted_sequences, ground_truths) 154 | evaluation_result = {"accuracy": accuracy} 155 | return evaluation_result 156 | 157 | def eval_CStance(predicted_sequences, ground_truths): 158 | predicted_sequences = postprocess_choice_acc(predicted_sequences) 159 | ground_truths = postprocess_choice_acc(ground_truths) 160 | 161 | accuracy = caculate_accuracy(predicted_sequences, ground_truths) 162 | evaluation_result = {"accuracy": accuracy} 163 | return evaluation_result 164 | 165 | 166 | def eval_FOMC(predicted_sequences, ground_truths): 167 | predicted_sequences = postprocess_choice_acc(predicted_sequences) 168 | ground_truths = postprocess_choice_acc(ground_truths) 169 | 170 | accuracy = caculate_accuracy(predicted_sequences, ground_truths) 171 | evaluation_result = {"accuracy": accuracy} 172 | return evaluation_result 173 | 174 | 175 | def eval_MeetingBank(predicted_sequences, ground_truths): 176 | # bleu_1 = calculate_bleu(predicted_sequences, ground_truths, 1) 177 | # bleu_4 = calculate_bleu(predicted_sequences, ground_truths, 4) 178 | rouge = calculate_rouge_(predicted_sequences, ground_truths) 179 | # evaluation_result = {"bleu-1": bleu_1, "bleu-4": bleu_4, "rouge-L": rouge} 180 | evaluation_result = {"rouge-L": rouge} 181 | return evaluation_result 182 | 183 | 184 | def eval_NumGLUE(predicted_sequences, ground_truths): 185 | predicted_sequences = postprocess_choice_num(predicted_sequences) 186 | ground_truths = postprocess_choice_num(ground_truths) 187 | 188 | accuracy = caculate_accuracy(predicted_sequences, ground_truths) 189 | evaluation_result = {"accuracy": accuracy} 190 | return evaluation_result 191 | 192 | 193 | def resolve(dataset: list): 194 | keyword_list = [] 195 | for datium in dataset: 196 | keyword_list.append(datium.split(" , ")) 197 | return keyword_list 198 | 199 | 200 | def eval_PapyrusF(predicted_sequences, ground_truths): 201 | outputs = resolve(predicted_sequences) 202 | gts = resolve(ground_truths) 203 | 204 | f1 = calculate_f1(outputs, gts) 205 | evaluation_result = {"F1": f1} 206 | return evaluation_result 207 | 208 | def postprocess_choice_acc(predicted_sequences): 209 | outputs = [] 210 | for output in predicted_sequences: 211 | if not output: 212 | outputs.append("") 213 | continue 214 | match = re.search(r"[A-D]", output) 215 | if match: 216 | outputs.append(match.group(0)) 217 | else: 218 | outputs.append("") 219 | return outputs 220 | 221 | def postprocess_choice_num(predicted_sequences): 222 | outputs = [] 223 | for output in predicted_sequences: 224 | if not output: 225 | outputs.append("") 226 | continue 227 | match = re.search(r"\d+\.?\d*", output) 228 | if match: 229 | outputs.append(match.group(0)) 230 | else: 231 | outputs.append("") 232 | return outputs 233 | 234 | 235 | def resolve_sciQA(dataset: list): 236 | answers = [] 237 | reasonings = [] 238 | for datium in dataset: 239 | if len(datium) >= 3: 240 | answers.append(datium[0]) # the first char is the answer. e.g. A, B,... 241 | reasonings.append(datium[2:]) # A/nBecause... 
242 | elif 1 <= len(datium) < 3: 243 | answers.append(datium[0]) 244 | reasonings.append("") 245 | else: 246 | answers.append("") 247 | reasonings.append("") 248 | outputs = {"answers": answers, "reasonings": reasonings} 249 | return outputs 250 | 251 | 252 | def eval_SciQA(predicted_sequences, ground_truths): 253 | outputs = resolve_sciQA(predicted_sequences) 254 | gts = resolve_sciQA(ground_truths) 255 | outputs["answers"] = postprocess_choice_acc(outputs["answers"]) 256 | gts["answers"] = postprocess_choice_acc(gts["answers"]) 257 | 258 | # bleu_1 = calculate_bleu(outputs["reasonings"], gts["reasonings"], 1) 259 | # bleu_4 = calculate_bleu(outputs["reasonings"], gts["reasonings"], 4) 260 | # rouge = calculate_rouge_(outputs["reasonings"], gts["reasonings"]) 261 | accuracy = caculate_accuracy(outputs["answers"], gts["answers"]) 262 | 263 | # evaluation_result = {"bleu-1": bleu_1, "bleu-4": bleu_4, 264 | # # "rouge-L": rouge, 265 | # "accuracy": accuracy} 266 | evaluation_result = {"accuracy": accuracy} 267 | return evaluation_result 268 | --------------------------------------------------------------------------------
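For reference, here is a minimal, illustrative sketch of how the per-task metric functions in `utils/metrics.py` might be invoked directly, outside of `evaluation.py`. The `outputs/seq/FOMC.json` path and its `prediction`/`answer` keys are assumptions made for this example only; the actual prediction format is whatever `inference.py` writes.

```python
import json
from utils.metrics import eval_FOMC, eval_MeetingBank

# Load a hypothetical predictions file: a list of {"prediction": ..., "answer": ...} records.
with open("outputs/seq/FOMC.json", "r", encoding="utf-8") as f:
    records = json.load(f)

predictions = [r["prediction"] for r in records]
references = [r["answer"] for r in records]

# Choice-style tasks (FOMC, C-STANCE, ScienceQA, MedMCQA, JEC-QA, ...) report accuracy;
# MeetingBank reports ROUGE-L and 20Minuten reports SARI.
print(eval_FOMC(predictions, references))           # -> {"accuracy": ...}
# print(eval_MeetingBank(predictions, references))  # -> {"rouge-L": ...}
```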