├── README.md
├── evaluation.py
├── get_dataset.py
├── inference.py
├── main.py
├── method
├── AGEM.py
├── BaseTrainerCL.py
├── ER.py
├── EWC.py
├── GEM.py
├── ILORA.py
├── L2P.py
├── MTL.py
├── ONE.py
└── PP.py
├── requirements.txt
├── scripts
└── train_seq.sh
└── utils
├── arg_configs.py
└── metrics.py
/README.md:
--------------------------------------------------------------------------------
1 |
2 | LLMCL
3 |
4 |
5 | Analyzing and Reducing Catastrophic Forgetting in Parameter Efficient Tuning
6 |
7 |
8 | ## Overview
9 | LLMCL is a repository based on the Hugging Face Transformers library, designed to assess the continuous learning capability of large language models. Through this repository, users can easily customize datasets, specify models, and experiment with existing classical continuous learning methods.
10 |
11 | ## Key Features
12 | - **Continual Learning Methods:** The repository includes several classical continuous learning methods for users to reference and use.
13 | - **Model Customization:** You can easily customize the model you want to use, and the repository will automatically download the corresponding model.
14 |
15 | ## Quick Start
16 | ### 1.Install dependencies
17 | ```bash
18 | conda create -n llmcl python==3.10
19 | pip install -r requirements.txt
20 | ```
21 | ### 2.Start Training
22 | ```bash
23 | ./scripts/train_seq.sh
24 | ```
25 | ### 3.Inference
26 | ```
27 | ./scripts/infer_seq.sh
28 | ```
29 | ### 4. customize
30 | You can easily customize scripts for your own use:
31 |
32 | - Ensure your dataset is organized in JSON format with `prompt` and `answer` as keys.
33 | - Save the dataset file to `//.json`
34 | - For more details, refer to the [get_dataset.py](get_dataset.py) file.
35 |
36 | ## Reproduce
37 | To Reproduce our results, you need \
38 | **1.** Request the access to `llama2` model and download [TRACE Benchmark](https://drive.google.com/file/d/1S0SmU0WEw5okW_XvP2Ns0URflNzZq6sV/view?usp=drive_link) , [MedMCQA](https://medmcqa.github.io/),[JEC-QA](https://jecqa.thunlp.org/) to `./data_files` folder.
39 |
40 |
41 | 2.run scripts
42 | customize your training scripts and run it.
43 |
44 |
45 |
46 |
47 |
48 | ## Citation
49 | If you find this repository helpful, please consider citing our work.
50 |
51 | ```bibtex
52 | @misc{ren2024analyzing,
53 | title={Analyzing and Reducing Catastrophic Forgetting in Parameter Efficient Tuning},
54 | author={Weijieying Ren and Xinlong Li and Lei Wang and Tianxiang Zhao and Wei Qin},
55 | year={2024},
56 | eprint={2402.18865},
57 | archivePrefix={arXiv},
58 | primaryClass={cs.LG}
59 | }
60 | ```
61 |
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from prettytable import PrettyTable
4 | from utils.metrics import (
5 | eval_FOMC,
6 | eval_SciQA,
7 | eval_CStance,
8 | eval_NumGLUE,
9 | eval_PapyrusF,
10 | eval_20Minuten,
11 | eval_MeetingBank,
12 | eval_medmcqa,
13 | eval_jecqa,
14 | )
15 | from transformers import HfArgumentParser
16 | import logging
17 | from utils.arg_configs import EvalArguments
18 | import warnings
19 | warnings.filterwarnings("ignore", category=UserWarning)
20 |
21 | logging.basicConfig(
22 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
23 | datefmt="%m/%d/%Y %H:%M:%S",
24 | level=logging.INFO,
25 | )
26 | logger = logging.getLogger(__name__)
27 |
28 | EVAL_FUNCs = {
29 | "FOMC": eval_FOMC,
30 | "ScienceQA": eval_SciQA,
31 | "C-STANCE": eval_CStance,
32 | "NumGLUE-cm": eval_NumGLUE,
33 | "NumGLUE-ds": eval_PapyrusF,
34 | "20Minuten": eval_20Minuten,
35 | "MeetingBank": eval_MeetingBank,
36 | "medmcqa": eval_medmcqa,
37 | "jecqa": eval_jecqa,
38 | }
39 |
40 | def main():
41 | parser = HfArgumentParser(EvalArguments)
42 | args = parser.parse_args_into_dataclasses()[0]
43 |
44 | # check if the json_dir exists
45 | for j_dir in args.json_dirs:
46 | if not os.path.exists(j_dir):
47 | raise ValueError(f"Path {j_dir} does not exist")
48 | if args.cl_method == "mtl" and len(args.json_dirs) != 1:
49 | raise ValueError(f"Multi-task learning should only have one json_dir")
50 |
51 | logger.info(f"Evaluting args: {args}")
52 | logger.info(f"Make sure your `increment_order` is in the same order as you train the datasets!!")
53 | results = {}
54 | for json_dir in args.json_dirs:
55 | data = json.load(open(json_dir, "r"))
56 | trained_task = os.path.split(json_dir)[1].split(".json")[0]
57 |
58 | if trained_task not in results:
59 | results[trained_task] = {}
60 |
61 | for infer_task, infer_result in data.items():
62 | # try:
63 | eval_func = EVAL_FUNCs[infer_task]
64 | prompts = []
65 | answers = []
66 | generated_texts = []
67 |
68 | for item in infer_result:
69 | prompts.append(item["prompt"])
70 | answers.append(item["answer"])
71 | generated_texts.append(item["generated_text"])
72 |
73 | try:
74 | # special case for `20Minuten`
75 | if infer_task == "20Minuten":
76 | eval_result = eval_func(prompts, generated_texts, answers)
77 | else:
78 | eval_result = eval_func(generated_texts, answers)
79 | logger.info(f"Inference result {json_dir} on task` {infer_task}`: {eval_result}")
80 | except Exception as e:
81 | eval_result = None
82 | logger.error(f"Error processing file {json_dir} with dataset `{infer_task}`: {e}")
83 | continue
84 | results[trained_task][infer_task] = get_res(eval_result, infer_task)
85 |
86 | # Print table
87 | table = PrettyTable()
88 | table.field_names = [args.cl_method] + args.increment_order
89 | print(results)
90 | for row_name in args.increment_order:
91 | if row_name not in results:
92 | logger.warning(f"Missing result for `{row_name}`, skipping")
93 | continue
94 | row = [row_name]
95 | for col_name in args.increment_order:
96 | if col_name not in results[row_name]:
97 | row.append("-")
98 | logger.warning(f"Missing result for `{row_name}` on `{col_name}`")
99 | else:
100 | row.append(results[row_name][col_name])
101 | table.add_row(row)
102 |
103 | print(table)
104 | save_tabel(table, args.save_path)
105 |
106 |
107 | def save_tabel(table, path):
108 | "save table to .csv file"
109 | base_path = os.path.split(path)[0]
110 | if not os.path.exists(base_path):
111 | os.makedirs(base_path, exist_ok=True)
112 | from pandas import DataFrame
113 | df = DataFrame([table.field_names] + table._rows)
114 | df.to_csv(path, index=False, header=False)
115 |
116 | logger.info(f"Results saved to {path}")
117 |
118 | def get_res(result: dict, name: str):
119 | if result is None:
120 | return -1
121 | if name == "20Minuten":
122 | return round(result['sari'][0]['sari'] / 100, 3)
123 | elif name in ["C-STANCE", "FOMC", "NumGLUE-cm", "NUmGLUE-ds", "ScienceQA", "medmcqa", "jecqa"]:
124 | return round(result['accuracy'], 3)
125 | elif name == "Py150":
126 | return round(result['similarity'], 3)
127 | elif name == "MeetingBank":
128 | return round(result['rouge-L'], 3)
129 |
130 |
131 | if __name__ == "__main__":
132 | main()
--------------------------------------------------------------------------------
/get_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import json
4 | import torch.utils
5 | from torch.utils.data import Dataset
6 | from dataclasses import dataclass
7 | from pathlib import Path
8 | from typing import List, Dict, Union, Tuple
9 | from transformers import PreTrainedTokenizerBase, AutoTokenizer
10 |
11 | class BaseDataset(Dataset):
12 | def __init__(self, tokenizer:PreTrainedTokenizerBase, json_dir:Union[str, Path], max_length:int=1024, train_on_inputs:bool=True, test:bool=False):
13 | super(BaseDataset).__init__()
14 | self.tokenizer = tokenizer
15 | assert self.tokenizer.pad_token_id is not None, "Tokenizer must have a pad token id"
16 | self.json_dir = json_dir
17 | self.meta_data = json.load(open(json_dir))
18 | self.max_length = max_length
19 | self.train_on_inputs = train_on_inputs
20 | self.keys_to_data = list(self.meta_data[0].keys())
21 | self.test = test
22 | self.data = self._tokenize_dataset()
23 |
24 | def __len__(self):
25 | return len(self.data)
26 |
27 | def __getitem__(self, idx):
28 | return self.data[idx]
29 |
30 | def _tokenize_dataset(self, ingnore_idx:int=-100) -> List[Dict[str, torch.Tensor]]:
31 | tokenized_samples = []
32 | for sample in self.meta_data:
33 | try:
34 | q_tokenized = self.tokenizer(sample[self.keys_to_data[0]], add_special_tokens=False)
35 | a_tokenized = self.tokenizer(sample[self.keys_to_data[1]], add_special_tokens=False)
36 |
37 | if not self.test:
38 | input_ids = q_tokenized['input_ids'] + a_tokenized['input_ids']
39 | else:
40 | input_ids = q_tokenized['input_ids']
41 |
42 | if len(input_ids) > self.max_length - 2:
43 | input_ids = input_ids[:self.max_length - 2]
44 |
45 | full_input_ids = [self.tokenizer.bos_token_id] + input_ids
46 | if not self.test:
47 | full_input_ids += [self.tokenizer.eos_token_id]
48 | input_ids = torch.tensor(full_input_ids)
49 | attention_mask = torch.ones_like(input_ids)
50 |
51 | if (not self.train_on_inputs) and (not self.test):
52 | labels = torch.full_like(input_ids, fill_value=ingnore_idx)
53 | ans_start_idx = len(q_tokenized['input_ids']) + 1
54 | labels[ans_start_idx:] = input_ids[ans_start_idx:]
55 | else:
56 | labels = input_ids.clone()
57 |
58 | tokenized_samples.append(dict(
59 | input_ids=input_ids,
60 | attention_mask=attention_mask,
61 | labels=labels
62 | ))
63 | except Exception as e:
64 | print(f"Error processing sample: {e}")
65 | continue
66 |
67 | return tokenized_samples
68 |
69 | class DataCollector(object):
70 | """ For a stable traning, we need to pad the input_ids to the `max_length` """
71 | def __init__(self, tokenizer: PreTrainedTokenizerBase, padding: Union[str, bool], max_length: int=1024, ignore_idx: int=-100):
72 | self.tokenizer = tokenizer
73 | self.padding = padding
74 | assert self.padding in ['longest', True], "Padding must be either 'longest', 'max_length' or False"
75 | self.max_length = max_length
76 | self.ignore_idx = ignore_idx
77 |
78 | def __call__(self, batch: List[Dict[str, torch.Tensor]]):
79 | input_ids = [sample['input_ids'] for sample in batch]
80 | attention_mask = [sample['attention_mask'] for sample in batch]
81 | labels = [sample['labels'] for sample in batch]
82 |
83 | len_pad_to = max([len(ids) for ids in input_ids]) if self.padding == 'longest' else self.max_length
84 | for i in range(len(batch)):
85 | input_ids[i] = torch.cat([
86 | torch.full((len_pad_to - input_ids[i].shape[0],), fill_value=self.tokenizer.pad_token_id), # left padding
87 | input_ids[i]
88 | ])
89 | attention_mask[i] = torch.cat([
90 | torch.zeros((len_pad_to - attention_mask[i].shape[0],)),
91 | attention_mask[i]
92 | ])
93 | labels[i] = torch.cat([
94 | torch.full((len_pad_to- labels[i].shape[0],), fill_value=self.ignore_idx),
95 | labels[i]
96 | ])
97 |
98 | return dict(
99 | input_ids=torch.stack(input_ids),
100 | attention_mask=torch.stack(attention_mask),
101 | labels=torch.stack(labels)
102 | )
103 |
104 | def get_datasets(dataset_names: List[str], data_path: Union[str, Path], tokenizer: PreTrainedTokenizerBase, max_length: int=1024, split='train', train_on_inputs=False) -> Dict[str, Dataset]:
105 | datasets = {}
106 | for name in dataset_names:
107 | if os.path.exists(name) and os.path.isfile(name) and name.endswith(".json"):
108 | full_json_path = name
109 | else:
110 | full_json_path = os.path.join(data_path, name, f"{split}.json")
111 | assert os.path.exists(full_json_path), f"Path {full_json_path} does not exist"
112 |
113 | dataset = BaseDataset(tokenizer, full_json_path, max_length, train_on_inputs, test='test' in full_json_path)
114 | datasets[name] = dataset
115 | return datasets
116 |
117 | def get_joint_datasets(datasets: Dict[str, Dataset]) -> Dict[str, Dataset]:
118 | return {'joint': torch.utils.data.ConcatDataset(list(datasets.values()))}
119 |
120 | if __name__ == "__main__":
121 | tokenizer = AutoTokenizer.from_pretrained("llama2-7b-hf")
122 | tokenizer.pad_token_id = tokenizer.unk_token_id
123 | dataset = get_datasets(["FOMC"], "./TRACE-Benchmark/LLM-CL-Benchmark_500", tokenizer)
124 | from torch.utils.data import DataLoader
125 | data_loader = DataLoader(dataset["FOMC"], batch_size=2, collate_fn=DataCollector(tokenizer, padding=True))
126 | for batch in data_loader:
127 | print(batch)
128 | break
129 | print(tokenizer("Hello World!, I am a sentence."))
130 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | import logging, json
5 | import numpy as np
6 | from typing import List, Dict, Union
7 | from dataclasses import asdict
8 | from transformers import (
9 | AutoTokenizer,
10 | AutoModelForCausalLM,
11 | BitsAndBytesConfig,
12 | PreTrainedTokenizerBase,
13 | HfArgumentParser,
14 | GenerationConfig,
15 | LlamaTokenizer
16 | )
17 | from peft import load_peft_weights, set_peft_model_state_dict, get_peft_model, PeftConfig, LoraConfig
18 | from get_dataset import get_datasets, DataCollector
19 | from utils.arg_configs import DataArguments, InferArguments
20 | import torch.distributed as dist
21 | from torch.utils.data import DataLoader
22 | from tqdm import tqdm
23 | import warnings
24 | warnings.filterwarnings("ignore", category=UserWarning)
25 |
26 |
27 |
28 | def prepare_model_for_inference(model_name_or_path:str, bnb_config:BitsAndBytesConfig, peft_cfg_path:str=None, peft_weights_path:str=None, device:str='cuda'):
29 | model = AutoModelForCausalLM.from_pretrained(
30 | model_name_or_path,
31 | quantization_config=bnb_config,
32 | torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
33 | device_map="balanced_low_0",
34 | )
35 |
36 | if peft_cfg_path is not None and peft_weights_path is not None:
37 | peft_config = PeftConfig.from_pretrained(peft_cfg_path)
38 | model = get_peft_model(model, peft_config=peft_config)
39 | peft_state_dict = torch.load(peft_weights_path, map_location=device)
40 | set_peft_model_state_dict(model, peft_state_dict)
41 | return model
42 |
43 | def prepare_dataloader(data_args:DataArguments, tokenizer:PreTrainedTokenizerBase, batch_size:int=4, max_length:int=1024)->Dict[str, DataLoader]:
44 | test_datasets = get_datasets(**asdict(data_args), tokenizer=tokenizer, split="test")
45 | dataloaders = {}
46 | for name, dataset in test_datasets.items():
47 | test_dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=DataCollector(tokenizer, padding="longest", max_length=max_length), num_workers=4, prefetch_factor=2)
48 | dataloaders[name] = test_dataloader
49 | return dataloaders
50 |
51 | @torch.no_grad()
52 | def run_generation(model:torch.nn.Module, tokenizer:PreTrainedTokenizerBase, name:str, test_dataloader:DataLoader, generation_config:GenerationConfig) -> List[str]:
53 | model.eval()
54 | generated_texts = []
55 | for inputs in tqdm(test_dataloader, desc=f"Generating texts for {name}"):
56 | if 'labels' in inputs:
57 | inputs.pop('labels')
58 | input_ids_shape = inputs['input_ids'].shape
59 | generated_token_batch = model.generate(**inputs, generation_config=generation_config)
60 | # generated_token_batch = generated_token_batch.cpu().numpy().tolist()
61 |
62 | # mask input_ids to get only the generated text
63 | mask = torch.cat(
64 | (torch.zeros(input_ids_shape), torch.ones(input_ids_shape[0], generated_token_batch.shape[1] - input_ids_shape[1])),
65 | dim=-1
66 | ).to(torch.int64).to(generated_token_batch.device)
67 | generated_token_batch = (generated_token_batch * mask).cpu().numpy().tolist()
68 | generated_texts.extend(tokenizer.batch_decode(generated_token_batch, skip_special_tokens=True))
69 | return generated_texts
70 |
71 | def get_meta_data(data_args:DataArguments, split="test")->Dict[str, List[Dict[str, str]]]:
72 | meta_datas = {}
73 | for name in data_args.dataset_names:
74 | full_path = os.path.join(data_args.data_path, name, f"{split}.json")
75 | assert os.path.exists(full_path), f"File {full_path} does not exist"
76 |
77 | with open(full_path, 'r') as f:
78 | data = json.load(f)
79 | meta_datas[name] = data
80 | return meta_datas
81 |
82 |
83 | def main():
84 | parser = HfArgumentParser((InferArguments, DataArguments))
85 | infer_args, data_args = parser.parse_args_into_dataclasses()
86 |
87 | # prepare model, tokenizer and dataloaders
88 | model = prepare_model_for_inference(
89 | model_name_or_path=infer_args.model_name_or_path,
90 | bnb_config=infer_args.bnb_config,
91 | peft_cfg_path=infer_args.peft_cfg_path,
92 | peft_weights_path=infer_args.peft_weights_path,
93 | )
94 |
95 | # tokenizer_config =
96 | tokenizer = AutoTokenizer.from_pretrained(infer_args.tokenizer_name_or_path)
97 | dataloaders = prepare_dataloader(data_args, tokenizer, batch_size=infer_args.infer_batch_size)
98 | logger.info(f"Model and data loaders prepared for {data_args.dataset_names}, starting generation")
99 |
100 | start = time.time()
101 | generated_texts_datasets = {}
102 | for name, dataloader in dataloaders.items():
103 | generated_texts = run_generation(
104 | model=model,
105 | tokenizer=tokenizer,
106 | name=name,
107 | test_dataloader=dataloader,
108 | generation_config=infer_args.generation_config
109 | )
110 | generated_texts_datasets[name] = generated_texts
111 | end = time.time()
112 |
113 | # run generation
114 | logger.info(f"Generation completed, using {((end-start)/60):.2f} seconds")
115 | meta_datas = get_meta_data(data_args, split="test")
116 | results = {}
117 | for i, (name, gen_texts) in enumerate(generated_texts_datasets.items()):
118 | results[name] = []
119 | assert len(gen_texts) == len(meta_datas[name]), f"Number of generated texts ({len(gen_texts)}) does not match the number of meta datas ({len(meta_datas[name])})"
120 | gen_texts: List[str]
121 | meta_datas: Dict[str, List[Dict[str, str]]]
122 | for j, text in enumerate(gen_texts):
123 | results[name].append(dict(
124 | prompt=meta_datas[name][j]['prompt'],
125 | answer=meta_datas[name][j]['answer'],
126 | generated_text=text,
127 | ))
128 |
129 | # save results
130 | base_path = os.path.split(infer_args.save_path)[0]
131 | if not os.path.exists(base_path):
132 | os.makedirs(base_path, exist_ok=True)
133 | with open(f"{infer_args.save_path}", 'w', encoding="utf-8") as f:
134 | json.dump(results, f, indent=4, ensure_ascii=False)
135 | logger.info(f"Results saved to {infer_args.save_path}")
136 |
137 | if __name__ == "__main__":
138 | logging.basicConfig(
139 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
140 | datefmt="%m/%d/%Y %H:%M:%S",
141 | level=logging.INFO
142 | )
143 | logger = logging.getLogger(__name__)
144 |
145 | main()
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorWithPadding, BitsAndBytesConfig
3 | from peft import prepare_model_for_kbit_training
4 | from utils.arg_configs import get_args, CLArguments, TuningArguments, DataArguments
5 | from dataclasses import asdict
6 | from get_dataset import get_datasets, DataCollector
7 | from utils.functions import set_all_seed
8 | import warnings
9 | warnings.filterwarnings("ignore", category=UserWarning)
10 |
11 | def main():
12 | train_args, cl_args, tuning_args, data_args = get_args()
13 | set_all_seed(tuning_args.manual_seed)
14 | bnb_config = BitsAndBytesConfig(
15 | load_in_8bit=tuning_args.load_8bit,
16 | )
17 | model = AutoModelForCausalLM.from_pretrained(
18 | tuning_args.model_name_or_path,
19 | quantization_config=bnb_config,
20 | torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
21 | )
22 |
23 | tokenizer = AutoTokenizer.from_pretrained(tuning_args.model_name_or_path)
24 |
25 | if tokenizer.pad_token_id is None:
26 | tokenizer.pad_token_id = 0
27 |
28 | # prepare datasets
29 | train_datasets = get_datasets(**asdict(data_args), tokenizer=tokenizer, split='train')
30 | valid_datasets = get_datasets(**asdict(data_args), tokenizer=tokenizer, split='eval')
31 |
32 |
33 | train_args = dict(
34 | model=model if not tuning_args.load_8bit else prepare_model_for_kbit_training(model),
35 | args=train_args,
36 | train_dataset=train_datasets,
37 | eval_dataset=valid_datasets,
38 | tokenizer=tokenizer,
39 | cl_args=cl_args,
40 | tuning_args=tuning_args,
41 | data_args=data_args,
42 | data_collator=DataCollector(tokenizer, padding=True, max_length=data_args.max_length)
43 | )
44 |
45 | if cl_args.cl_method.lower() == 'seq':
46 | from method.BaseTrainerCL import BaseTrainerCL
47 | cl_trainer = BaseTrainerCL(**train_args)
48 | elif cl_args.cl_method.lower() == 'ewc':
49 | from method.EWC import EWCTrainer
50 | cl_trainer = EWCTrainer(**train_args)
51 | elif cl_args.cl_method.lower() == 'er':
52 | from method.ER import ERTrainer
53 | cl_trainer = ERTrainer(**train_args)
54 | elif cl_args.cl_method.lower() == 'gem':
55 | from method.GEM import GEMTrainer
56 | cl_trainer = GEMTrainer(**train_args)
57 | elif cl_args.cl_method.lower() == 'agem':
58 | from method.AGEM import AveGEMTrainer
59 | cl_trainer = AveGEMTrainer(**train_args)
60 | elif cl_args.cl_method.lower() == 'l2p':
61 | from method.L2P import L2PTrainer
62 | cl_trainer = L2PTrainer(**train_args)
63 | elif cl_args.cl_method.lower() == 'pp':
64 | from method.PP import PPTrainer
65 | cl_trainer = PPTrainer(**train_args)
66 | elif cl_args.cl_method.lower() == 'ilora':
67 | from method.ILORA import ILoRATrainer
68 | cl_trainer = ILoRATrainer(**train_args)
69 | elif cl_args.cl_method.lower() == 'mtl':
70 | from method.MTL import MTLTrainer
71 | cl_trainer = MTLTrainer(**train_args)
72 | elif cl_args.cl_method.lower() == 'one':
73 | from method.ONE import ONETrainer
74 | cl_trainer = ONETrainer(**train_args)
75 | else:
76 | ValueError(f"continual learning method: {cl_args.cl_method} not implement yet")
77 |
78 | cl_trainer.continual_learning()
79 |
80 | if __name__ == '__main__':
81 |
82 | main()
83 |
84 |
--------------------------------------------------------------------------------
/method/AGEM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
3 | from .BaseTrainerCL import BaseTrainerCL
4 | from peft import PeftModel
5 |
6 |
7 | class AveTaskGradientCallback(TrainerCallback):
8 | def __init__(self, **kwargs):
9 | super().__init__()
10 | self.model:PeftModel = kwargs.get('model')
11 | self.current_task_name = kwargs.get('current_task_name') # need update during training
12 | self.n_tasks = kwargs.get('n_tasks')
13 | self.grads = {}
14 | self.task_names = kwargs.get('task_names')
15 | self.init_grads()
16 |
17 | def init_grads(self):
18 | for n, p in self.model.named_parameters():
19 | if p.requires_grad:
20 | self.grads[n] = torch.ones([p.data.numel()], dtype=p.dtype, device=p.device)
21 |
22 | def store_grads(self):
23 | for n, p in self.model.named_parameters():
24 | if n in self.grads:
25 | self.ave_grads(n, p.grad.detach().clone().view(-1))
26 |
27 | def ave_grads(self, name, new_grads):
28 | self.grads[name] = (self.grads[name] * (self.task_names.index(self.current_task_name) + 1) + new_grads) / (self.task_names.index(self.current_task_name) + 2)
29 |
30 | def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
31 | for n, p in self.model.named_parameters():
32 | if n in self.grads and p.requires_grad:
33 | p.grad = self.get_updated_grads(n, p.grad)
34 |
35 | def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
36 | self.store_grads()
37 |
38 | def get_updated_grads(self, name, grad, eps=1e-4):
39 | ori_shape = grad.shape
40 | grad = grad.view(-1)
41 | pre_grad = self.grads[name].cuda().to(torch.float32)
42 | grad, pre_grad = grad.unsqueeze(1), pre_grad.unsqueeze(1)
43 | dot_product = torch.mm(grad.t(), pre_grad)
44 |
45 | if (dot_product < 0) != 0:
46 | new_grad = grad - (torch.mm(grad.t(), pre_grad) + eps) / (torch.mm(pre_grad.t(), pre_grad) + eps) * pre_grad
47 | grad.copy_(new_grad)
48 |
49 | return grad.view(ori_shape)
50 |
51 | def update_current_task_name(self, name:str):
52 | self.current_task_name = name
53 |
54 |
55 |
56 | class AveGEMTrainer(BaseTrainerCL):
57 | def __init__(self, **kwargs):
58 | super().__init__(**kwargs)
59 |
60 | self.add_callback(AveTaskGradientCallback(
61 | model = self.model,
62 | current_task_name=self.current_task_name,
63 | n_tasks=self.num_tasks,
64 | task_names=self.task_names))
65 |
66 | self.gem_cb = None
67 | for cb in self.callback_handler.callbacks:
68 | if isinstance(cb, AveTaskGradientCallback):
69 | self.gem_cb = cb
70 | break
71 |
72 | def continual_learning(self):
73 | for i, name in enumerate(self.task_names):
74 | self.gem_cb.update_current_task_name(name)
75 | self.before_task_start(name)
76 | self.train()
77 | self.after_task_end(name)
--------------------------------------------------------------------------------
/method/BaseTrainerCL.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from pathlib import Path
4 | from typing import Callable, Tuple, Union, Optional, Dict, List, overload
5 | from torch.utils.data import Dataset
6 | from transformers.trainer_callback import TrainerControl, TrainerState
7 | from transformers.training_args import TrainingArguments
8 | from utils.arg_configs import CLArguments, TuningArguments, DataArguments
9 | from get_dataset import get_joint_datasets
10 | from peft import get_peft_model, PeftModel, LoraModel, PeftModelForCausalLM, get_peft_model_state_dict
11 | from transformers import Trainer, PreTrainedModel, TrainerCallback
12 |
13 |
14 | class BaseTrainerCL(Trainer):
15 | def __init__(self, **kwargs):
16 | self.cl_args = kwargs.pop("cl_args", None)
17 | self.tuning_args = kwargs.pop("tuning_args", None)
18 | self.data_args = kwargs.pop("data_args", None)
19 | kwargs['model'] = self.prepare_model_for_cl_traning(kwargs['model'], self.cl_args, self.tuning_args)
20 | self.continual_training_dataset, self.continual_evaluating_dataset = \
21 | self.prepare_dataset_for_cl_traininig(
22 | kwargs.get("train_dataset", None),
23 | kwargs.get("eval_dataset", None))
24 | self.task_names = list(self.continual_training_dataset.keys())
25 | self.num_tasks: int = len(self.continual_training_dataset)
26 | self.current_task_name: str = None
27 | super().__init__(**kwargs)
28 |
29 | def prepare_model_for_cl_traning(self, model: Union[PreTrainedModel, nn.Module], cl_args: CLArguments=None, tuning_args: TuningArguments=None) -> Union[PreTrainedModel, PeftModel]:
30 | peft_model = get_peft_model(
31 | model=model,
32 | peft_config=tuning_args.lora_config,
33 | )
34 | return peft_model
35 |
36 | def prepare_dataset_for_cl_traininig(self, train_dataset: Dict[str, Dataset], eval_dataset: Dict[str, Dataset]) -> Dict[str, Dataset]:
37 |
38 | if self.cl_args.cl_method == 'mtl':
39 | train_dataset = get_joint_datasets(train_dataset)
40 | eval_dataset = get_joint_datasets(eval_dataset)
41 | return train_dataset, eval_dataset
42 | return train_dataset, eval_dataset
43 |
44 | def continual_learning(self):
45 | for i, name in enumerate(self.task_names):
46 | self.before_task_start(name)
47 | self.train()
48 | self.after_task_end(name)
49 |
50 | def before_task_start(self, task_name: str):
51 | """ update training and evaluation dataset for the current task """
52 | if self.cl_args.cl_method == 'mtl':
53 | self.train_dataset = self.continual_evaluating_dataset
54 | self.eval_dataset = self.continual_evaluating_dataset
55 | self.current_task_name = "joint"
56 | return
57 |
58 | if task_name not in self.continual_training_dataset:
59 | raise ValueError(f"task name {task_name} not found in the dataset")
60 | self.current_task_name = task_name
61 | self.train_dataset, self.eval_dataset = self.continual_training_dataset[task_name], self.continual_evaluating_dataset[task_name]
62 |
63 | # update model for the current task
64 | if self.cl_args.cl_method == 'one':
65 | if isinstance(self.model, LoraModel):
66 | self.model = get_peft_model(
67 | model=self.model.model,
68 | peft_config=self.tuning_args.lora_config,
69 | )
70 |
71 | def after_task_end(self, *args, **kwargs):
72 | """ save the model after training the current task """
73 | assert args[0] == self.current_task_name, f"task name mismatch: {args[0]} != {self.current_task_name}"
74 | wrappered_model_class = kwargs.get("wrappered_model_class", None)
75 |
76 | if isinstance(self.model, PeftModelForCausalLM):
77 | lora_state_dict = get_peft_model_state_dict(self.model)
78 | lora_config = self.model.peft_config
79 | if self.args.local_rank == 0:
80 | print(f"*** Saving lora adapter for task: {self.current_task_name} ***")
81 | torch.save(lora_state_dict, Path(self.args.output_dir).joinpath(f"lora_{self.current_task_name}.pt"))
82 | for adapter_name, adapter_config in lora_config.items():
83 | if adapter_name == 'default':
84 | adapter_config.save_pretrained(Path(self.args.output_dir).joinpath(f"lora_{self.current_task_name}"))
85 | else:
86 | adapter_config.save_pretrained(Path(self.args.output_dir).joinpath(f"lora_{self.current_task_name}_{adapter_name}"))
87 | elif not wrappered_model_class and isinstance(self.model, wrappered_model_class):
88 | raise NotImplementedError("not implemented yet") # TODO: implement for PP model
89 |
90 | self.tokenizer.save_pretrained(Path(self.args.output_dir).joinpath(f"tokenizer_{self.current_task_name}"))
--------------------------------------------------------------------------------
/method/ER.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from typing import Tuple
4 | from .BaseTrainerCL import BaseTrainerCL
5 |
6 |
7 | def reservoir(num_seen_examples: int, buffer_size: int) -> int:
8 | if num_seen_examples < buffer_size:
9 | return num_seen_examples
10 |
11 | rand = np.random.randint(0, num_seen_examples + 1)
12 | if rand < buffer_size:
13 | return rand
14 | else:
15 | return -1
16 |
17 | def concat_inputs(input_ids:torch.Tensor, attention_mask:torch.Tensor, labels:torch.Tensor, buffer_input_ids:torch.Tensor, buffer_attention_mask:torch.Tensor, buffer_labels:torch.Tensor) -> Tuple:
18 | device = input_ids.device
19 | input_ids = torch.cat((input_ids, buffer_input_ids.to(device)), dim=0)
20 | attention_mask = torch.cat((attention_mask, buffer_attention_mask.to(device)), dim=0)
21 | labels = torch.cat((labels, buffer_labels.to(device)), dim=0)
22 | return input_ids, attention_mask, labels
23 |
24 | class Buffer:
25 | def __init__(self, buffer_size:int, device:str, pad_id:int=2, ignore_index:int=-100):
26 | self.buffer_size = buffer_size
27 | self.device = device
28 | self.num_seen_examples = 0
29 | self.attributes = ['input_ids', 'attention_mask', 'labels', 'logits', 'task_labels', 'activations']
30 | self.init_buffer()
31 | self.pad_id = pad_id
32 | self.ignore_index = ignore_index
33 |
34 | def init_buffer(self) -> None:
35 | for attr_str in self.attributes:
36 | setattr(self, attr_str, [None for _ in range(self.buffer_size)])
37 |
38 | def add_data(self, input_ids, attention_mask=None, labels=None, logits=None, task_labels=None, activations=None):
39 | n = input_ids.shape[0] if hasattr(input_ids, 'shape') else len(input_ids)
40 | for i in range(n):
41 | index = reservoir(self.num_seen_examples, self.buffer_size)
42 | self.num_seen_examples += 1
43 | if index >= 0:
44 | self.input_ids[index] = input_ids[i].detach().clone().to(self.device)
45 | if attention_mask is not None:
46 | self.attention_mask[index] = attention_mask[i].detach().clone().to(self.device)
47 | if labels is not None:
48 | self.labels[index] = labels[i].detach().clone().to(self.device)
49 | if logits is not None:
50 | self.logits[index] = logits[i].detach().clone().to(self.device)
51 | if task_labels is not None:
52 | self.task_labels[index] = task_labels[i].detach().clone().to(self.device)
53 | if activations is not None:
54 | self.activations[index] = activations[i].detach().clone().to(self.device)
55 |
56 | def get_data(self, size: int, pad_to:int) -> Tuple:
57 | n = len(self.input_ids)
58 | if size > min(self.num_seen_examples, n):
59 | size = min(self.num_seen_examples, n)
60 |
61 | choice = np.random.choice(min(self.num_seen_examples, n), size=size, replace=False)
62 | if len(choice) == 0:
63 | return None, None
64 | # for left padding
65 | input_ids = []
66 | attention_mask = []
67 | labels = []
68 |
69 | for i in choice:
70 |
71 | input_ids.append(torch.cat(
72 | (torch.full((pad_to - self.input_ids[i].shape[-1],), self.pad_id, dtype=torch.long).to(self.device),
73 | self.input_ids[i]), dim=-1)
74 | )
75 | if self.attention_mask[i] is not None:
76 | attention_mask.append(torch.cat(
77 | (torch.full((pad_to - self.attention_mask[i].shape[-1],), 0, dtype=torch.long).to(self.device),
78 | self.attention_mask[i]), dim=-1)
79 | )
80 | if self.labels[i] is not None:
81 | labels.append(torch.cat(
82 | (torch.full((pad_to - self.labels[i].shape[-1],), self.ignore_index, dtype=torch.long).to(self.device),
83 | self.labels[i]), dim=-1)
84 | )
85 |
86 | input_ids = torch.stack(input_ids)
87 | attention_mask = torch.stack(attention_mask)
88 | labels = torch.stack(labels)
89 | return input_ids, attention_mask, labels
90 |
91 | def is_empty(self) -> bool:
92 | if self.num_seen_examples == 0:
93 | return True
94 | else:
95 | return False
96 |
97 | def get_all_data(self) -> Tuple:
98 | ret_tuple = (torch.stack([ee.cpu()
99 | for ee in self.input_ids]).to(self.device),)
100 | for attr_str in self.attributes[1:]:
101 | if hasattr(self, attr_str):
102 | attr = getattr(self, attr_str)
103 | ret_tuple += (attr,)
104 | return ret_tuple
105 |
106 | def empty(self) -> None:
107 | for attr_str in self.attributes:
108 | if hasattr(self, attr_str):
109 | delattr(self, attr_str)
110 | self.num_seen_examples = 0
111 |
112 |
113 | class ERTrainer(BaseTrainerCL):
114 | def __init__(self, **kwargs):
115 | super().__init__(**kwargs)
116 | self.buffer_size = self.cl_args.cl_config.get('buffer_size', None)
117 | self.buffer = Buffer(self.buffer_size, 'cpu', pad_id=self.tokenizer.pad_token_id, ignore_index=-100)
118 |
119 | def compute_loss(self, model, inputs, return_outputs=False):
120 |
121 | if self.current_task_name == self.task_names[0]:
122 | self.buffer.add_data(inputs["input_ids"], inputs["attention_mask"], inputs["labels"])
123 | outputs = model(**inputs)
124 | else:
125 | buffer_inputs, buffer_attention_mask, buffer_labels = self.buffer.get_data(inputs["input_ids"].shape[0], inputs["input_ids"].shape[1])
126 | if buffer_inputs is not None and buffer_attention_mask is not None and buffer_labels is not None:
127 | inputs["input_ids"], inputs["attention_mask"], inputs["labels"] = concat_inputs(inputs["input_ids"], inputs["attention_mask"], inputs["labels"], buffer_inputs, buffer_attention_mask, buffer_labels)
128 | outputs = model(**inputs)
129 | self.buffer.add_data(inputs["input_ids"], inputs["attention_mask"], inputs["labels"])
130 |
131 | return (outputs.loss, outputs) if return_outputs else outputs.loss
132 |
133 |
134 | def continual_learning(self):
135 | for i, name in enumerate(self.task_names):
136 | self.before_task_start(name)
137 | self.train()
138 | self.after_task_end(name)
139 |
140 |
141 |
--------------------------------------------------------------------------------
/method/EWC.py:
--------------------------------------------------------------------------------
1 | from transformers import TrainerCallback, TrainingArguments, PreTrainedModel
2 | from transformers.trainer_callback import TrainerState, TrainerControl
3 | from typing import Tuple, Dict
4 | import os
5 | from .BaseTrainerCL import BaseTrainerCL
6 | from peft import PeftModel
7 |
8 | class GradientCallback(TrainerCallback):
9 | def __init__(self, *args, **kwargs):
10 | super().__init__()
11 | self.fisher = {}
12 | self.model: PeftModel = kwargs.get("model")
13 | self.previous_weights = {}
14 | self.trainable_params = {}
15 | self.init_fisher_and_weights()
16 | assert len(self.fisher) > 0 and len(self.trainable_params) > 0, "fisher and previous_weights should not be empty"
17 |
18 | def init_fisher_and_weights(self):
19 | print("init fisher and weights")
20 | for n, p in self.model.named_parameters():
21 | if p.requires_grad:
22 | self.fisher[n] = p.detach().clone().data.zero_()
23 | self.trainable_params[n] = p.detach().clone().data
24 |
25 | def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
26 | self.previous_weights = {n: p.detach().clone() for n, p in self.model.named_parameters() if
27 | n in self.trainable_params.keys()}
28 |
29 | def get_fisher_and_prior(self) -> Tuple[dict, dict]:
30 | return self.fisher, self.previous_weights
31 |
32 | def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
33 | for n, p in self.model.named_parameters():
34 | if n in self.trainable_params.keys() and p.grad is not None:
35 | self.fisher[n] += p.grad.detach().clone().data.pow(2) / state.global_step
36 | elif p.grad is None:
37 | Warning(f"parameter {n} has no gradient")
38 |
39 | class EWCTrainer(BaseTrainerCL):
40 | """
41 | https://arxiv.org/abs/1612.00796
42 | """
43 | def __init__(self, **kwargs):
44 | super().__init__(**kwargs)
45 | self.add_callback(GradientCallback(model=self.model))
46 | self.ewc_lambda = self.cl_args.ewc_lambda
47 | self.cb = None
48 | for cb in self.callback_handler.callbacks:
49 | if isinstance(cb, GradientCallback):
50 | self.cb = cb
51 | break
52 |
53 | def continual_learning(self):
54 | for i, name in enumerate(self.task_names):
55 | self.before_task_start(name)
56 | self.train()
57 | self.after_task_end(name)
58 |
59 | def compute_loss(self, model, inputs, return_outputs=False):
60 | outputs = model(**inputs)
61 | loss = outputs['loss'] if isinstance(outputs, dict) else outputs[0]
62 | ewc_loss = self.compute_ewc_loss(model)
63 | loss += ewc_loss
64 | return (loss, outputs) if return_outputs else loss
65 |
66 | def compute_ewc_loss(self, model):
67 | fisher, previous_weights = self.cb.get_fisher_and_prior()
68 | assert len(fisher) > 0 and len(previous_weights) > 0, "fisher and previous_weights should not be empty"
69 |
70 | ewc_loss = 0
71 | for n, p in model.named_parameters():
72 | if n in fisher:
73 | ewc_loss += (fisher[n] * (p - previous_weights[n]).pow(2)).sum() * self.ewc_lambda / 2
74 |
75 | if ewc_loss < 1e-5:
76 | Warning("EWC regularization loss is too small, please check the hyper-parameters")
77 | return ewc_loss
--------------------------------------------------------------------------------
/method/GEM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
3 | from qpth.qp import QPFunction
4 | from .BaseTrainerCL import BaseTrainerCL
5 | from peft import PeftModel
6 | from deepspeed.utils import safe_get_full_grad
7 | import torch.distributed as dist
8 | class TaskGradientCallback(TrainerCallback):
9 | def __init__(self, **kwargs):
10 | super().__init__()
11 | self.model:PeftModel = kwargs.get('model')
12 | self.current_task_name = kwargs.get('current_task_name') # need update during training
13 | self.n_tasks = kwargs.get('n_tasks')
14 | self.grads = {}
15 | self.task_names = kwargs.get('task_names')
16 | self.device = kwargs.get('device', 'cpu')
17 | self.init_grads()
18 |
19 | def init_grads(self):
20 | for n, p in self.model.named_parameters():
21 | if p.requires_grad: # reduce memory usage
22 | self.grads[n] = torch.ones([p.data.numel(), self.n_tasks], dtype=p.dtype, device=self.device)
23 |
24 | def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
25 | for n, p in self.model.named_parameters():
26 | if n in self.grads:
27 | # p_grad = safe_get_full_grad(p)
28 | # print('rank', dist.get_rank(), n, '->', p.device)
29 | # assert p.grad is not None, f"parameter {n} has no gradient"
30 | if p.grad is None:
31 | print(f"rank {dist.get_rank()} parameter {n} has no gradient in device {p.device}")
32 | p.grad = self.get_updated_grads(n, p.grad, self.task_names.index(self.current_task_name))
33 | grad_old = self.grads[n][:, self.task_names.index(self.current_task_name)].detach().clone()
34 | grad_new = (grad_old * state.step + p.grad.detach().clone().view(-1)) / (state.step + 1)
35 | self.grads[n][:, self.task_names.index(self.current_task_name)] = grad_new
36 |
37 | def get_updated_grads(self, name, grad, idx, margin=0.1, eps=1.0):
38 | ori_shape = grad.shape # None
39 | grad = grad.view(-1)
40 | pre_grad = self.grads[name].cuda()[:, :idx+1].to(torch.float32)
41 | dot_product = torch.mm(grad.unsqueeze(0), pre_grad)
42 | if (dot_product < 0).sum() != 0:
43 | pre_grad_cuda = pre_grad.t()
44 | grad_cuda = grad.contiguous().view(-1)
45 | t = pre_grad_cuda.shape[0]
46 | P = torch.matmul(pre_grad_cuda, pre_grad_cuda.t())
47 | P = 0.5 * (P + P.t())
48 |
49 | P[torch.isnan(P)] = 0.0
50 | eigenvalues = torch.linalg.eigvals(P)
51 | if not (eigenvalues.real > 0).all(): # due to the grad clip happens after the projection, the grad could be huge, refactor eps=1.0 is reasonable
52 | P += torch.eye(t).cuda() * eps
53 |
54 | q = torch.matmul(pre_grad_cuda, grad_cuda).t() * -1
55 |
56 | P = P.to(torch.float32)
57 | q = q.to(torch.float32)
58 | G = torch.eye(t).cuda() * -1
59 | h = torch.zeros(t).cuda() - margin
60 | e = torch.Tensor().cuda()
61 | v = QPFunction(verbose=False)(P, q, G, h, e, e)[0]
62 | v = v.to(torch.float32)
63 | x = torch.matmul(v, pre_grad_cuda) + grad_cuda
64 | grad.copy_(x.view(-1))
65 | return grad.view(ori_shape)
66 |
67 | def update_current_task_name(self, name:str):
68 | self.current_task_name = name
69 |
70 |
71 |
72 | class GEMTrainer(BaseTrainerCL):
73 | def __init__(self, **kwargs):
74 | super().__init__(**kwargs)
75 |
76 | self.add_callback(TaskGradientCallback(
77 | model = self.model,
78 | current_task_name=self.current_task_name,
79 | n_tasks=self.num_tasks,
80 | task_names=self.task_names))
81 |
82 | self.gem_cb = None
83 | for cb in self.callback_handler.callbacks:
84 | if isinstance(cb, TaskGradientCallback):
85 | self.gem_cb = cb
86 | break
87 |
88 | def continual_learning(self):
89 | for i, name in enumerate(self.task_names):
90 | self.gem_cb.update_current_task_name(name)
91 | self.before_task_start(name)
92 | self.train()
93 | self.after_task_end(name)
94 |
--------------------------------------------------------------------------------
/method/ILORA.py:
--------------------------------------------------------------------------------
1 | import os
2 | from copy import deepcopy
3 | from typing import Optional, List, Union, Tuple
4 | import copy
5 | import numpy as np
6 | import wandb
7 | from peft import PeftModel, LoraModel, load_peft_weights, set_peft_model_state_dict
8 | from .BaseTrainerCL import BaseTrainerCL
9 | import torch
10 | from transformers import LlamaForCausalLM, PreTrainedModel, TrainerCallback
11 | from transformers.modeling_outputs import CausalLMOutputWithPast
12 | from torch import nn
13 | import torch.nn.functional as F
14 |
15 |
16 | # from method.BaseTrainerCL import BaseTrainerCL
17 |
18 |
19 | def reservoir(num_seen_examples: int, buffer_size: int) -> int:
20 | """
21 | Reservoir sampling algorithm.
22 | :param num_seen_examples: the number of seen examples
23 | :param buffer_size: the maximum buffer size
24 | :return: the target index if the current image is sampled, else -1
25 | """
26 | if num_seen_examples < buffer_size:
27 | return num_seen_examples
28 |
29 | rand = np.random.randint(0, num_seen_examples + 1)
30 | if rand < buffer_size:
31 | return rand
32 | else:
33 | return -1
34 |
35 |
36 | def ring(num_seen_examples: int, buffer_portion_size: int, task: int) -> int:
37 | return num_seen_examples % buffer_portion_size + task * buffer_portion_size
38 |
39 |
40 | class Buffer:
41 | """
42 | The memory buffer of rehearsal method.
43 | """
44 |
45 | def __init__(self, buffer_size, device, n_tasks=None, mode='reservoir'):
46 | assert mode in ['ring', 'reservoir']
47 | self.buffer_size = buffer_size
48 | self.device = device
49 | self.num_seen_examples = 0
50 | self.functional_index = eval(mode)
51 | if mode == 'ring':
52 | assert n_tasks is not None
53 | self.task_number = n_tasks
54 | self.buffer_portion_size = buffer_size // n_tasks
55 | self.attributes = ['input_ids', 'labels', 'logits', 'task_labels', 'activations']
56 | self.init_buffer()
57 |
58 | def init_buffer(self) -> None:
59 | for attr_str in self.attributes:
60 | setattr(self, attr_str, [None for _ in range(self.buffer_size)])
61 |
62 | def add_data(self, input_ids, labels=None, logits=None, task_labels=None, activations=None):
63 | """
64 | Adds the data to the memory buffer according to the reservoir strategy.
65 | :param input_ids: tensor containing the images
66 | :param labels: tensor containing the labels
67 | :param logits: tensor containing the outputs of the network
68 | :param task_labels: tensor containing the task labels
69 | :param activations: tensor containing the activations of the network
70 | :return:
71 | """
72 | n = input_ids.shape[0] if hasattr(input_ids, 'shape') else len(input_ids)
73 | for i in range(n):
74 | index = reservoir(self.num_seen_examples, self.buffer_size)
75 | self.num_seen_examples += 1
76 | if index >= 0:
77 | self.input_ids[index] = input_ids[i].to(self.device)
78 | if labels is not None:
79 | self.labels[index] = labels[i].to(self.device)
80 | if logits is not None:
81 | self.logits[index] = logits[i].to(self.device)
82 | if task_labels is not None:
83 | self.task_labels[index] = task_labels[i].to(self.device)
84 | if activations is not None:
85 | self.activations[index] = activations[i].to(self.device)
86 |
87 | def get_data(self, size: int) -> Tuple:
88 | """
89 | Random samples a batch of size items.
90 | :param size: the number of requested items
91 | :return:
92 | """
93 | n = self.input_ids.shape[0] if hasattr(self.input_ids, 'shape') else len(self.input_ids)
94 | if size > min(self.num_seen_examples, n):
95 | size = min(self.num_seen_examples, n)
96 |
97 | choice = np.random.choice(min(self.num_seen_examples, n), size=size, replace=False)
98 | if len(choice) == 0:
99 | return None, None
100 | max_input_id_len = max([self.input_ids[c].shape[0] for c in choice])
101 | max_label_len = max([self.labels[c].shape[0] for c in choice])
102 | # for left padding
103 | input_ids = torch.stack(
104 | [torch.cat([torch.zeros(max_input_id_len - ee.shape[0]).long().to(ee.device), ee]) for ee in
105 | [self.input_ids[c] for c in choice]]).reshape(size, max_input_id_len)
106 | labels = torch.stack([torch.cat([torch.zeros(max_label_len - ee.shape[0]).long().to(ee.device), ee]) for ee in
107 | [self.labels[c] for c in choice]]).reshape(size, max_label_len)
108 | return input_ids, labels
109 |
110 | def is_empty(self) -> bool:
111 | """
112 | Returns true if the buffer is empty, false otherwise.
113 | """
114 | if self.num_seen_examples == 0:
115 | return True
116 | else:
117 | return False
118 |
119 | def get_all_data(self) -> Tuple:
120 | """
121 | Return all the items in the memory buffer.
122 | :return: a tuple with all the items in the memory buffer
123 | """
124 | ret_tuple = (torch.stack([ee.cpu()
125 | for ee in self.input_ids]).to(self.device),)
126 | for attr_str in self.attributes[1:]:
127 | if hasattr(self, attr_str):
128 | attr = getattr(self, attr_str)
129 | ret_tuple += (attr,)
130 | return ret_tuple
131 |
132 | def empty(self) -> None:
133 | """
134 | Set all the tensors to None.
135 | """
136 | for attr_str in self.attributes:
137 | if hasattr(self, attr_str):
138 | delattr(self, attr_str)
139 | self.num_seen_examples = 0
140 |
141 |
142 | class ILoRAModel(LlamaForCausalLM):
143 | def __init__(self, model: PeftModel, reg_decay: bool = False):
144 | super().__init__(model.config)
145 | self.model = model
146 | self.current_task_name = "C-STANCE"
147 | # regularization settings
148 | peft_cfg = model.peft_config['default']
149 | self.model.add_adapter('ema', peft_cfg)
150 | self.model.to(self.model.device)
151 |
152 | self.ema_alpha: float = 0.25
153 | self.reg_weight: float = 1.0
154 | self.ema_update_freq: float = 0.1
155 | self.consistency_loss = nn.MSELoss(reduction='none')
156 |
157 | self.buffer = Buffer(500, 'cuda') # same as ER
158 | self.l_cons = 0
159 | self.total = 0
160 | self.ori_loss = 0
161 |
162 | def update_ema_weights(self, step):
163 | alpha = min(1 - 1 / (step + 1), self.ema_alpha)
164 |
165 | self.model.set_adapter('default')
166 | model_state_dict = {n: p.detach().clone() for n, p in self.model.named_parameters() if p.requires_grad}
167 |
168 | self.model.set_adapter('ema')
169 | for name, param in self.model.named_parameters():
170 | if name in model_state_dict.keys():
171 | param.data.mul_(alpha).add_(torch.mul(model_state_dict[name].data, 1 - alpha))
172 | self.model.set_adapter('default')
173 |
174 | def update_reg_weight(self, step, decay_steps):
175 | if self.reg_decay:
176 | self.alpha = self.fix_reg_weight * self.decay_rate ** (step / decay_steps)
177 | else:
178 | self.alpha = self.fix_reg_weight
179 |
180 | def update_ema_model(self, path):
181 | adapter_weights = load_peft_weights(path)
182 | set_peft_model_state_dict(self.model, adapter_weights, adapter_name='ema')
183 |
184 | def concat_inputs(self, input_ids: torch.Tensor, labels: torch.Tensor, buffer_inputs_ids: torch.Tensor,
185 | buffer_labels: torch.Tensor) -> Tuple:
186 | if buffer_inputs_ids is None or buffer_labels is None:
187 | return input_ids, labels
188 |
189 | max_input_id_len = max(input_ids.shape[1], buffer_inputs_ids.shape[1])
190 | max_labels_len = max(labels.shape[1], buffer_labels.shape[1])
191 |
192 | extended_input_ids = torch.cat([input_ids, torch.zeros(input_ids.shape[0],
193 | max_input_id_len - input_ids.shape[1]).long().to(
194 | input_ids.device)], dim=1)
195 | extended_labels = torch.cat(
196 | [labels, torch.zeros(labels.shape[0], max_labels_len - labels.shape[1]).long().to(labels.device)], dim=1)
197 |
198 | extended_buffer_inputs_ids = torch.cat([buffer_inputs_ids, torch.zeros(buffer_inputs_ids.shape[0],
199 | max_input_id_len -
200 | buffer_inputs_ids.shape[1]).long().to(
201 | buffer_inputs_ids.device)], dim=1)
202 | extended_buffer_labels = torch.cat([buffer_labels, torch.zeros(buffer_labels.shape[0],
203 | max_labels_len - buffer_labels.shape[
204 | 1]).long().to(buffer_labels.device)], dim=1)
205 |
206 | input_ids = torch.cat([extended_input_ids, extended_buffer_inputs_ids], dim=0)
207 | labels = torch.cat([extended_labels, extended_buffer_labels], dim=0)
208 | return input_ids, labels
209 |
210 | def forward(
211 | self,
212 | input_ids: torch.LongTensor = None,
213 | attention_mask: Optional[torch.Tensor] = None,
214 | position_ids: Optional[torch.LongTensor] = None,
215 | past_key_values: Optional[List[torch.FloatTensor]] = None,
216 | inputs_embeds: Optional[torch.FloatTensor] = None,
217 | labels: Optional[torch.LongTensor] = None,
218 | use_cache: Optional[bool] = None,
219 | output_attentions: Optional[bool] = None,
220 | output_hidden_states: Optional[bool] = None,
221 | return_dict: Optional[bool] = None,
222 | ) -> Union[Tuple, CausalLMOutputWithPast]:
223 | buffer_inputs, buffer_labels = None, None
224 | if self.current_task_name != "C-STANCE":
225 | buffer_inputs, buffer_labels = self.buffer.get_data(input_ids.shape[0])
226 | l_cons = 0
227 | if labels is not None and buffer_inputs is not None and buffer_labels is not None:
228 | self.model.set_adapter('default')
229 | plastic_hiddn = self.model(
230 | buffer_inputs,
231 | labels=buffer_labels,
232 | output_hidden_states=True,
233 | return_dict=True)
234 |
235 | self.model.set_adapter('ema')
236 | with torch.no_grad():
237 | stable_hiddn = self.model(
238 | buffer_inputs,
239 | labels=buffer_labels,
240 | output_hidden_states=True,
241 | return_dict=True).hidden_states
242 | indexs = [inner_plastic > inner_stable for inner_plastic, inner_stable in
243 | zip(plastic_hiddn.hidden_states, stable_hiddn)]
244 | reg_hiddn = [torch.where(idx, inner_plastic, inner_stable) for idx, inner_plastic, inner_stable in
245 | zip(indexs, plastic_hiddn.hidden_states, stable_hiddn)]
246 |
247 | l_cons = torch.mean(
248 | torch.cat([self.consistency_loss(plastic, ema) for plastic, ema in
249 | zip(plastic_hiddn.hidden_states, reg_hiddn)], dim=0))
250 |
251 | self.l_cons = l_cons # for logging use
252 |
253 | self.model.set_adapter('default')
254 | ori_out = self.model(
255 | input_ids=input_ids,
256 | # attention_mask=attention_mask,
257 | labels=labels,
258 | return_dict=True,
259 | output_hidden_states=False)
260 | if labels is not None and buffer_inputs is not None and buffer_labels is not None:
261 | self.total_loss = (ori_out.loss + plastic_hiddn.loss) / 2 + self.reg_weight * l_cons
262 | else:
263 | self.total_loss = ori_out.loss + self.reg_weight * l_cons
264 | self.total = self.total_loss.item()
265 |
266 | return CausalLMOutputWithPast(
267 | loss=self.total_loss,
268 | past_key_values=ori_out.past_key_values,
269 | logits=ori_out.logits,
270 | hidden_states=ori_out.hidden_states
271 | )
272 |
273 |
274 | class ILoRATrainer(BaseTrainerCL):
275 | def __init__(self, **kwargs):
276 | super().__init__(**kwargs)
277 | self.model = ILoRAModel(self.model)
278 | self.add_callback(CLSCallback(self.model))
279 |
280 | def compute_loss(self, model, inputs, return_outputs=False):
281 | self.model.buffer.add_data(inputs['input_ids'], inputs['labels'])
282 | outputs = self.model(**inputs)
283 | loss = outputs.loss
284 | print("loss:", loss.item())
285 | return (loss, outputs) if return_outputs else loss
286 |
287 | def continual_learning(self):
288 | resume_from_checkpoint = "False"
289 | for task_name, dataset in self.continual_training_dataset.items():
290 | self.model.current_task_name = task_name
291 | self.current_task_name = task_name
292 | self.train_dataset = dataset
293 | self.train()
294 | resume_from_checkpoint = self.save_model(task_name)
295 | self.model.load_ema_model(resume_from_checkpoint)
296 | wandb.finish()
297 |
298 | def save_model(self, name) -> str:
299 | if self.args.output_dir is not None:
300 | output_dir = os.path.join(self.args.output_dir, f"{self.cl_method}_{self.adapter}_checkpoint_{name}")
301 | self.model.model.set_adapter('default')
302 | self.model.model.save_pretrained(output_dir)
303 | return output_dir
304 |
305 |
306 | class CLSCallback(TrainerCallback):
307 | def __init__(self, model: ILoRAModel):
308 | self.model = model
309 |
310 | def on_step_end(self, args, state, control, **kwargs):
311 | self.model.update_ema_weights(state.global_step)
312 | if wandb.run:
313 | wandb.log({
314 | "reg_weight": self.model.reg_weight,
315 | "consist_loss": self.model.l_cons,
316 | "total_loss": self.model.total,
317 | "ori_loss": self.model.ori_loss,
318 | })
--------------------------------------------------------------------------------
/method/L2P.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Union, Tuple
2 |
3 | import torch, os
4 | import torch.utils.data
5 | from torch import nn
6 | from transformers.modeling_outputs import CausalLMOutputWithPast
7 | from transformers import LlamaForCausalLM, PreTrainedModel, Trainer
8 | from transformers.modeling_outputs import CausalLMOutputWithPast
9 | from peft import PeftModel
10 | from .BaseTrainerCL import BaseTrainerCL
11 | import numpy as np
12 | from copy import deepcopy
13 |
14 |
15 | def l2_normalize(x, dim=None, epsilon=1e-12):
16 | square_norm = torch.sum(x ** 2, dim=dim, keepdim=True)
17 | x_inv_norm = torch.rsqrt(torch.maximum(square_norm, torch.tensor(epsilon, device=x.device)))
18 | return x * x_inv_norm
19 |
20 | class L2PModel(LlamaForCausalLM):
21 | def __init__(self, model:PreTrainedModel, pool_size:int=10, prompt_length:int=5, promt_init:str='random'):
22 | super().__init__(model.config)
23 | self.model = model
24 | self.embed_tokens = self.model.get_input_embeddings()
25 | self.embed_tk_shapes = self.embed_tokens.weight.shape
26 | self.prompt = None
27 | self.top_k = 3
28 | self.diversity_loss_weight = 0.5
29 | self.pool_size = pool_size
30 | self.prompt_length = prompt_length
31 | self.init_prompt(promt_init)
32 | self.embeding_key = 'mean'
33 | self.batchwise_prompt: bool = False
34 | self.current_task_name:str = None
35 |
36 | def init_prompt(self,promt_init):
37 | self.prompt = nn.Parameter(
38 | torch.tensor(
39 | self.create_prompt(self.pool_size, self.prompt_length, promt_init), requires_grad=True
40 | )
41 | ).to(self.device)
42 |
43 | def create_prompt(self, pool_size, prompt_length, promt_init='random'):
44 | N = self.embed_tk_shapes[0]
45 | p_weights = []
46 |
47 | for p in range(self.pool_size):
48 | p_w = []
49 | for i in range(self.prompt_length):
50 | with torch.no_grad():
51 | j = np.random.randint(N)
52 | w = deepcopy(self.embed_tokens.weight[j].detach().cpu().numpy())
53 | p_w.append(w)
54 | p_weights.append(p_w)
55 |
56 | return np.array(p_weights)
57 |
58 | def save_prompt_weights(self, path):
59 | state_dict = {"prompt_pool": self.prompt}
60 | torch.save(state_dict, os.path.join(path, f"prompt_weights_{self.current_task_name}.pt"))
61 |
62 | def load_prompt_weights(self, path, task_name="jecqa"):
63 | state_dict = torch.load(os.path.join(path, f"prompt_weights_{task_name}.pt"), map_location=self.device)
64 | self.prompt.data = state_dict["prompt_pool"].data
65 | print(f"Loaded prompt weights from {path}")
66 |
67 | def freeze_prompt(self):
68 | for n, p in self.named_parameters():
69 | p.requires_grad = False
70 |
71 | def forward(
72 | self,
73 | input_ids: torch.LongTensor = None,
74 | attention_mask: Optional[torch.Tensor] = None,
75 | position_ids: Optional[torch.LongTensor] = None,
76 | past_key_values: Optional[List[torch.FloatTensor]] = None,
77 | inputs_embeds: Optional[torch.FloatTensor] = None,
78 | labels: Optional[torch.LongTensor] = None,
79 | use_cache: Optional[bool] = None,
80 | output_attentions: Optional[bool] = None,
81 | output_hidden_states: Optional[bool] = None,
82 | return_dict: Optional[bool] = None,
83 | ) -> Union[Tuple, CausalLMOutputWithPast]:
84 |
85 | i_input_embeds = self.embed_tokens(input_ids)
86 | out = dict()
87 | if self.embeding_key == 'mean':
88 | i_input_embeds_mean = torch.mean(i_input_embeds, dim=1)
89 | elif self.embeding_key == 'max':
90 | i_input_embeds_mean = torch.max(i_input_embeds, dim=1)[0]
91 | elif self.embeding_key == 'mean_max':
92 | i_input_embeds_mean = torch.max(i_input_embeds, dim=1)[0] + 2 * torch.mean(i_input_embeds, dim=1)
93 | else:
94 | raise NotImplementedError("Not supported way of calculating embedding keys!")
95 |
96 | prompt_key = torch.mean(self.prompt, dim=1) # Pool_size, C
97 | prompt_norm = l2_normalize(prompt_key, dim=1).to("cuda")
98 | inputs_embeds_norm = l2_normalize(i_input_embeds_mean, dim=1)
99 | prompt_norm = prompt_norm.to(dtype=inputs_embeds_norm.dtype)
100 | similarity = torch.matmul(inputs_embeds_norm, prompt_norm.t()) # B, Pool_size
101 |
102 | _, idx = torch.topk(similarity, k=self.top_k, dim=1) # B, top_k
103 | if self.batchwise_prompt:
104 | prompt_id, id_counts = torch.unique(idx, return_counts=True, sorted=True)
105 | if prompt_id.shape[0] < self.pool_size:
106 | prompt_id = torch.cat([prompt_id, torch.full((self.pool_size - prompt_id.shape[0],), torch.min(idx.flatten()), device=prompt_id.device)])
107 | id_counts = torch.cat([id_counts, torch.full((self.pool_size - id_counts.shape[0],), 0, device=id_counts.device)])
108 | _, major_idx = torch.topk(id_counts, k=self.top_k)
109 | major_prompt_id = prompt_id[major_idx]
110 | idx = major_prompt_id.expand(inputs_embeds.shape[0], -1)
111 |
112 | batched_prompt_raw = self.prompt[idx] # B, top_k, length, C
113 | batch_size, top_k, length, c = batched_prompt_raw.shape
114 | batched_prompt = batched_prompt_raw.reshape(batch_size, top_k * length, c)
115 | inputs_embeds = torch.cat([batched_prompt, i_input_embeds],axis=1)
116 |
117 | prefix_length = batched_prompt.shape[1]
118 | attn_masks = torch.concat((torch.tensor(1).to("cuda").repeat(batch_size,prefix_length),attention_mask), axis=1)
119 |
120 | if labels is None: # inference
121 | return self.model(inputs_embeds=inputs_embeds.half(), attention_mask=attn_masks, use_cache=False, return_dict=True)
122 |
123 | labels = torch.concat((labels[0][0].repeat(batch_size,inputs_embeds.shape[1]-labels.shape[1]),labels),axis=1)
124 | outs = self.model(inputs_embeds=inputs_embeds,labels=labels,attention_mask=attn_masks,use_cache=False)
125 | loss = outs[0]
126 | batched_key_norm = prompt_norm[idx]
127 | inputs_embeds_norm = inputs_embeds_norm.unsqueeze(1) # B, 1, C
128 | sim = batched_key_norm * inputs_embeds_norm # B, top_k, C
129 | reduce_sim = torch.sum(sim) / inputs_embeds.shape[0] # Scalar
130 |
131 | loss -= reduce_sim * self.diversity_loss_weight
132 | return loss
133 |
134 | class L2PTrainer(BaseTrainerCL):
135 | def __init__(self, **kwargs):
136 | super().__init__(**kwargs)
137 | self.model = L2PModel(self.model)
138 |
139 | def compute_loss(self, model:L2PModel, inputs, return_outputs=False):
140 | input_ids = inputs['input_ids']
141 | attn_masks = inputs['attention_mask']
142 | labels = inputs.pop('labels')
143 | outputs = self.model(input_ids=input_ids, labels=labels, attention_mask=attn_masks)
144 | return outputs # since we only calculate loss in model.forward(), we return outputs here
145 |
146 | def continual_learning(self):
147 | for i, name in enumerate(self.task_names):
148 | self.model.current_task_name = name
149 | self.before_task_start(name)
150 | self.train()
151 | self.after_task_end(name)
152 |
153 | def save_model(self, name) -> str:
154 | if self.args.output_dir is not None:
155 | output_dir = os.path.join(self.args.output_dir, f"{self.cl_method}_{self.adapter}_checkpoint_{name}")
156 | assert isinstance(self.model.model, PeftModel), "self.model.model is not a PeftModel"
157 | self.model.model.save_pretrained(output_dir)
158 | print(f"save task: {name} adapter to {self.args.output_dir}")
159 | return output_dir
--------------------------------------------------------------------------------
/method/MTL.py:
--------------------------------------------------------------------------------
1 | from .BaseTrainerCL import BaseTrainerCL
2 |
3 | class MTLTrainer(BaseTrainerCL):
4 | def __init__(self, **kwargs):
5 | super().__init__(**kwargs)
6 |
--------------------------------------------------------------------------------
/method/ONE.py:
--------------------------------------------------------------------------------
1 | from .BaseTrainerCL import BaseTrainerCL
2 |
3 | class ONETrainer(BaseTrainerCL):
4 | def __init__(self, **kwargs):
5 | super().__init__(**kwargs)
6 |
--------------------------------------------------------------------------------
/method/PP.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Union, Tuple
2 |
3 | import torch, os, wandb
4 | import torch.utils.data
5 | from torch import nn
6 | from transformers.modeling_outputs import CausalLMOutputWithPast
7 | from transformers import LlamaForCausalLM, PreTrainedModel, Trainer
8 | from transformers.modeling_outputs import CausalLMOutputWithPast
9 | from peft import PeftModel
10 | from .BaseTrainerCL import BaseTrainerCL
11 | import numpy as np
12 | from copy import deepcopy
13 |
14 | class ResMLP(torch.nn.Module):
15 | def __init__(self, hidden_dim, bottleneck_size, module_type='MLP1', residual=True):
16 | super().__init__()
17 | self.residual = residual
18 | if module_type=='MLP1':
19 | self.module = nn.Sequential(
20 | nn.Linear(hidden_dim, bottleneck_size),
21 | nn.ReLU(),
22 | nn.Linear(bottleneck_size, hidden_dim),
23 | )
24 |
25 | elif module_type=='MLP2':
26 | self.module = nn.Sequential(
27 | nn.Linear(hidden_dim, bottleneck_size),
28 | nn.ReLU(),
29 | nn.Linear(bottleneck_size, bottleneck_size),
30 | nn.Tanh(),
31 | nn.Linear(bottleneck_size, hidden_dim),
32 | )
33 |
34 | elif module_type=='transformer':
35 | device = 'cuda'
36 | self.encoder_layer = nn.TransformerEncoderLayer(d_model=768, nhead=2, dropout=0.05).to(device)
37 | self.module = nn.TransformerEncoder(self.encoder_layer, num_layers=2).to(device)
38 |
39 | def forward(self, inputs):
40 | if self.residual:
41 | return self.module(inputs) + inputs
42 | else:
43 | return self.module(inputs)
44 |
45 | class PPModel(LlamaForCausalLM):
46 | def __init__(self,model:PreTrainedModel, prefix_len, task_names, prefix_path=None):
47 | super().__init__(model.config)
48 | self.model = model
49 | self.prefix_len:int = prefix_len
50 | self.prefix_pth:str = prefix_path
51 | self.task_names:List = task_names
52 | self.num_tasks:int = len(task_names)
53 | self.current_task_name:str = None
54 | self.embed_tokens = self.model.get_input_embeddings() # [vocab_size, embed_dim]
55 | self.embed_tokens_len = self.embed_tokens.weight.shape[0] # vocab_size
56 |
57 | self.prompt = nn.Parameter(
58 | torch.tensor(self.init_prompt(), requires_grad=True) # [prefix_len, embed_dim]
59 | ).to(self.device)
60 |
61 | self.previous_prompts = torch.zeros([0, self.prompt.shape[1]], requires_grad=False, dtype=torch.bfloat16).to(self.device) # [0, embed_dim]
62 | self.mlps = nn.ModuleList([ResMLP(model.config.hidden_size, 128, module_type='MLP1', residual=True) for _ in range(self.num_tasks)]).to(self.device)
63 |
64 | assert self.prefix_len > 0
65 |
66 | self.pp = True
67 | self.is_train = True
68 |
69 | def init_prompt(self):
70 | prompt_weights = []
71 | for i in range(self.prefix_len):
72 | with torch.no_grad():
73 | j = np.random.randint(self.embed_tokens_len)
74 | w = deepcopy(self.embed_tokens.weight[j].detach().cpu().numpy())
75 | prompt_weights.append(w / 100)
76 | return np.array(prompt_weights)
77 |
78 | def progressive_previous_prompt(self, task_name):
79 | """
80 | update previous prompt at end of each task
81 | """
82 | if task_name != None and self.mlps != None:
83 | with torch.no_grad():
84 | new_prompt = self.mlps[self.task_names.index(task_name)](self.prompt)
85 | self.previous_prompts = torch.cat([self.previous_prompts, new_prompt], axis=0)
86 | print(f'updated previous prompt to: {self.previous_prompts.shape}')
87 |
88 | def freeze_mlps(self, name:str, requires_grad=False):
89 | """
90 | Freeze or unfreeze all the MLPs except the one for the current task
91 | """
92 | for i, mlp in enumerate(self.mlps):
93 | if i != self.task_names.index(name):
94 | for param in mlp.parameters():
95 | if param.requires_grad != requires_grad:
96 | param.requires_grad = requires_grad
97 | else:
98 | for param in mlp.parameters():
99 | if param.requires_grad == requires_grad:
100 | param.requires_grad = not requires_grad
101 |
102 | def save_mlps_prompts(self, path):
103 | """
104 | Save all the MLPs and prompts and previous learned prompts
105 | """
106 | mlp_state_dict = {"mlps": self.mlps.state_dict()}
107 | prompt_state_dict = {"prompt": self.prompt}
108 | previous_prompt_state_dict = {"previous_prompt": self.previous_prompts}
109 | torch.save(mlp_state_dict, os.path.join(path, f"mlps_{self.current_task_name}.pt"))
110 | torch.save(prompt_state_dict, os.path.join(path, f"prompt_{self.current_task_name}.pt"))
111 | torch.save(previous_prompt_state_dict, os.path.join(path, f"previous_prompt_{self.current_task_name}.pt"))
112 |
113 | def load_mlps_prompts(self, path, task_name = None):
114 | """
115 | Load all the MLPs and prompts
116 | """
117 | mlp_path = os.path.join(path, f"mlps_{task_name}.pt")
118 | prompt_path = os.path.join(path, f"prompt_{task_name}.pt")
119 | previous_prompt_path = os.path.join(path, f"previous_prompt_{task_name}.pt")
120 | assert mlp_path and prompt_path and previous_prompt_path, "mlp_path or prompt_path is None"
121 |
122 | mlp_state_dict = torch.load(mlp_path, map_location=self.device)
123 | prompt_state_dict = torch.load(prompt_path, map_location=self.device)
124 | previous_prompt_state_dict = torch.load(previous_prompt_path, map_location=self.device)
125 | self.mlps.load_state_dict(mlp_state_dict["mlps"])
126 | self.prompt.data = prompt_state_dict["prompt"].data
127 | self.previous_prompts.data = previous_prompt_state_dict["previous_prompt"].data
128 | print(f"Loaded mlps and prompt from {mlp_path} and {prompt_path}")
129 |
130 | def freeze_all(self):
131 | for n, p in self.named_parameters():
132 | p.requires_grad = False
133 |
134 | def forward(
135 | self,
136 | input_ids: torch.LongTensor = None,
137 | attention_mask: Optional[torch.Tensor] = None,
138 | position_ids: Optional[torch.LongTensor] = None,
139 | past_key_values: Optional[List[torch.FloatTensor]] = None,
140 | inputs_embeds: Optional[torch.FloatTensor] = None,
141 | labels: Optional[torch.LongTensor] = None,
142 | use_cache: Optional[bool] = None,
143 | output_attentions: Optional[bool] = None,
144 | output_hidden_states: Optional[bool] = None,
145 | return_dict: Optional[bool] = None,
146 | ) -> Union[Tuple, CausalLMOutputWithPast]:
147 |
148 | inputs_embeds = self.embed_tokens(input_ids) # [batch_size, seq_len, embed_dim]
149 | k = inputs_embeds.shape[0] # batch_size
150 |
151 | mlp = self.mlps[self.task_names.index(self.current_task_name)]
152 | prompt = mlp(self.prompt) #[prefix_len, embed_dim]
153 |
154 | if self.pp:
155 | inputs_embeds = torch.cat([prompt.unsqueeze(0).repeat(k, 1, 1), # [batch_size, prefix_len, embed_dim]
156 | self.previous_prompts.unsqueeze(0).repeat(k, 1, 1),# [batch_size, len_of_learned_tasks, embed_dim]
157 | inputs_embeds], axis=1)# [batch_size, seq_len, embed_dim]
158 | full_prefix_len = prompt.shape[0] + self.previous_prompts.shape[0] # prefix_len + len_of_learned_tasks
159 |
160 | source_mask = torch.cat((torch.tensor(1).to('cuda').repeat(k, full_prefix_len),
161 | attention_mask), axis=1) # [batch_size, prefix_len + learned_tasks_len]
162 | if labels is not None:
163 | labels = torch.concat((labels[0][0].repeat(k, inputs_embeds.shape[1] - labels.shape[1]), labels),axis=1).detach()#[batch_size, prefix_len + learned_tasks_len, embed_dim]
164 | return self.model(
165 | inputs_embeds=inputs_embeds,
166 | labels=labels,
167 | attention_mask=source_mask
168 | )
169 | else:
170 | inputs_embeds = inputs_embeds.half()
171 | return self.model(inputs_embeds=inputs_embeds, attention_mask=source_mask, use_cache=False, return_dict=True)
172 |
173 | class PPTrainer(BaseTrainerCL):
174 |
175 | def __init__(self, **kwargs):
176 | super().__init__(**kwargs)
177 | self.model:PPModel = PPModel(model=self.model,
178 | prefix_len=20,
179 | task_names=list(self.continual_training_dataset.keys()))
180 |
181 | def compute_loss(self, model, inputs, return_outputs=False):
182 | outputs = self.model(**inputs)
183 | loss = outputs.loss
184 | return (loss, outputs) if return_outputs else loss
185 |
186 | def continual_learning(self):
187 | resume_from_checkpoint = "False"
188 | for name, train_set in self.continual_training_dataset.items():
189 | self.current_task_name = name
190 | self.model.current_task_name = name
191 | self.update_adapter_and_train_set(resume_from_checkpoint, train_set)
192 | self.model.freeze_mlps(name)
193 | self.train()
194 | self.model.progressive_previous_prompt(name)
195 | resume_from_checkpoint = self.save_model(name)
196 | self.model.save_mlps_prompts(self.args.output_dir)
197 | wandb.finish()
198 |
199 | def save_model(self, name) -> str:
200 | if self.args.output_dir is not None:
201 | output_dir = os.path.join(self.args.output_dir, f"{self.cl_method}_{self.adapter}_checkpoint_{name}")
202 | assert isinstance(self.model.model, PeftModel), "self.model.model is not a PeftModel"
203 | self.model.model.save_pretrained(output_dir)
204 | print(f"save task: {name} adapter to {self.args.output_dir}")
205 | return output_dir
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | bitsandbytes
3 | datasets
4 | fire
5 | peft
6 | git+https://github.com/2proveit/transformers.git@prepare_for_continual_training
7 | sentencepiece
8 | deepspeed
9 | numpy
10 | wandb
11 | qpth
12 | pandas
13 | tqdm
14 | evaluate
15 | rouge
16 | fuzzywuzzy
17 | nltk
18 | prettytable
19 | sacrebleu
20 | sacremoses
--------------------------------------------------------------------------------
/scripts/train_seq.sh:
--------------------------------------------------------------------------------
1 | cl_method="seq"
2 | deepspeed main.py \
3 | --deepspeed deepspeed_zero2_no_offload.json \
4 | --model_name_or_path "meta-llama/Llama2-7b-hf" \
5 | --load_in_8bits true \
6 | --data_path "~/Downloads/TRACE-Benchmark/LLM-CL-Benchmark_5000" \
7 | --dataset_names "20Minuten,FOMC,MeetingBank,ScienceQA" \
8 | --max_length 1024 \
9 | --train_on_inputs false \
10 | --cl_method $cl_method \
11 | --output_dir "outputs/$cl_method" \
12 | --per_device_train_batch_size 1 \
13 | --per_device_eval_batch_size 1 \
14 | --gradient_accumulation_steps 1 \
15 | --num_train_epochs 3 \
16 | --warmup_steps 500 \
17 | --learning_rate 1e-5 \
18 | --lr_scheduler_type "constant_with_warmup" \
19 | --load_best_model_at_end true \
20 | --save_total_limit 3 \
21 | --evaluation_strategy "steps" \
22 | --save_strategy "steps" \
23 | --save_steps 500 \
24 | --eval_steps 500 \
25 | --logging_steps 500 \
26 | --report_to "wandb" \
27 | --run_name "llama2-7b-hf" \
--------------------------------------------------------------------------------
/utils/arg_configs.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from typing import List, Dict, Tuple, Union, Any, Optional
4 | from pathlib import Path
5 | from dataclasses import dataclass, field
6 | from transformers import HfArgumentParser, TrainingArguments, GenerationConfig, BitsAndBytesConfig
7 | from peft import LoraConfig, TaskType
8 | import warnings
9 | __all__ = ["get_args", "CLArguments", "TuningArguments", "DataArguments"]
10 |
11 | def list_strings(string: str) -> List[str]:
12 | return string.split(",")
13 | @dataclass
14 | class DataArguments:
15 | data_path: Union[str, Path] = "./TRACE-Benchmark/LLM-CL-Benchmark_5000"
16 | dataset_names: str = "20Minuten,FOMC,MeetingBank,NumGLUE-ds,ScienceQA,C-STANCE,NumGLUE-cm"
17 | max_length: int = 1024
18 | train_on_inputs: bool = False
19 | # truncation: bool = True
20 | # padding: bool = True
21 |
22 | def __post_init__(self):
23 | self.dataset_names = list_strings(self.dataset_names)
24 |
25 | @dataclass
26 | class CLArguments:
27 | cl_method: str = 'seq'
28 | ewc_lambda: float = 0.1
29 | er_buffer_size: int = 1000
30 | er_buffer_method: str = "random"
31 | gem_memory_strength: float = 0.5
32 | l2p_pool_size: int = 10
33 | l2p_prompt_length: int = 5
34 | l2p_prompt_init: str = "random"
35 | pp_prefix_length: int = 20
36 |
37 | cl_config: Union[Dict[str, Any], Any] = field(init=False)
38 | def __post_init__(self):
39 | assert self.cl_method in ["seq","ewc", "er", "gem", "agem", "l2p", "pp", "mtl", "one", "ilora"], f"cl_method '{self.cl_method}' not supported"
40 | if self.cl_method == 'seq':
41 | warnings.warn("Using 'seq' as cl_method, other cl configs are ignored")
42 | self.cl_config = {}
43 | if self.cl_method == 'ewc':
44 | warnings.warn("Using 'ewc' as cl_method, other cl configs are ignored")
45 | self.cl_config = {
46 | "lambda": self.ewc_lambda
47 | }
48 | if self.cl_method == 'er':
49 | warnings.warn("Using 'er' as cl_method, other cl configs are ignored")
50 | self.cl_config = {
51 | "buffer_size": self.er_buffer_size,
52 | "buffer_method": self.er_buffer_method
53 | }
54 | if self.cl_method == 'gem':
55 | warnings.warn("Using 'gem' as cl_method, other cl configs are ignored")
56 | self.cl_config = {
57 | "memory_strength": self.gem_memory_strength
58 | }
59 | if self.cl_method == 'l2p':
60 | warnings.warn("Using 'l2p' as cl_method, other cl configs are ignored")
61 | self.cl_config = {
62 | "pool_size": self.l2p_pool_size,
63 | "prompt_length": self.l2p_prompt_length,
64 | "prompt_init": self.l2p_prompt_init
65 | }
66 | if self.cl_method == 'pp':
67 | warnings.warn("Using 'pp' as cl_method, other cl configs are ignored")
68 | self.cl_config = {
69 | "prefix_length": self.pp_prefix_length
70 | }
71 | if self.cl_method == 'ilora':
72 | warnings.warn("Using 'ilora' as cl_method, other cl configs are ignored")
73 | self.cl_config = {}
74 | if self.cl_method == 'mtl':
75 | warnings.warn("Using 'mtl' as cl_method, other cl configs are ignored")
76 | self.cl_config = {}
77 | if self.cl_method == 'one':
78 | warnings.warn("Using 'one' as cl_method, other cl configs are ignored")
79 | self.cl_config = {}
80 |
81 | print(f"*** cl_config ***:\n\t{self.cl_config}")
82 |
83 | @dataclass
84 | class TuningArguments:
85 | # basic config
86 | model_name_or_path: Union[str, Path] = "meta-llama/Llama2-7b-hf"
87 | load_in_8bit: bool = False
88 | # lora config
89 | lora_r: int = 8
90 | lora_alpha: int = 32
91 | lora_dropout: float = 0.1
92 | target_modules: str = "q_proj,k_proj,v_proj,o_proj"
93 | load_8bit: bool = True
94 | lora_config: Union[LoraConfig, Any] = field(init=False)
95 | manual_seed: int = 37
96 |
97 | # redundant args
98 | config_file: Optional[str] = None
99 | def __post_init__(self):
100 | self.target_modules = list_strings(self.target_modules)
101 | self.lora_config = LoraConfig(
102 | r=self.lora_r,
103 | lora_alpha=self.lora_alpha,
104 | lora_dropout=self.lora_dropout,
105 | target_modules=self.target_modules,
106 | task_type=TaskType.CAUSAL_LM
107 | )
108 | @dataclass
109 | class InferArguments:
110 | cl_method: str = 'seq'
111 | model_name_or_path: Union[str, Path] = "meta-llama/Llama2-7b-hf"
112 | load_in_8bit: bool = True
113 | load_in_4bit: bool = False
114 | tokenizer_name_or_path: Union[str, Path] = "meta-llama/Llama2-7b-hf"
115 | peft_cfg_path: Optional[str] = None
116 | peft_weights_path: Optional[str] = None
117 | infer_batch_size: int = 4
118 | # generation config
119 | max_new_tokens: int = 128
120 | temperature: float = 0.1
121 | top_p: float = 0.75
122 | repetition_penalty: float = 1.15
123 | do_sample: bool = True
124 | generation_config: Union[GenerationConfig, Any] = field(init=False)
125 | bnb_config: BitsAndBytesConfig = field(init=False)
126 | save_path: Union[str, Path] = "./generated_texts.json"
127 |
128 | def __post_init__(self):
129 | self.generation_config = GenerationConfig(
130 | max_new_tokens=self.max_new_tokens,
131 | temperature=self.temperature,
132 | top_p=self.top_p,
133 | repetition_penalty=self.repetition_penalty,
134 | do_sample=self.do_sample
135 | )
136 | self.bnb_config = BitsAndBytesConfig(
137 | load_in_8bit=self.load_in_8bit,
138 | load_in_4bit=self.load_in_4bit
139 | )
140 | @dataclass
141 | class EvalArguments:
142 | cl_method: str = 'seq'
143 | json_dirs: str = "outputs/seq/20Minuten.json"
144 | increment_order: str = '20Minuten'
145 | save_path: str = "./eval_table.csv"
146 |
147 | def __post_init__(self):
148 | self.json_dirs = list_strings(self.json_dirs)
149 | self.increment_order = list_strings(self.increment_order)
150 |
151 | def get_args() -> Tuple[TrainingArguments, CLArguments, TuningArguments, DataArguments]:
152 | parser = HfArgumentParser((TrainingArguments, CLArguments, TuningArguments, DataArguments))
153 | train_args, cl_args, tuning_args, data_args = parser.parse_args_into_dataclasses()
154 | return train_args, cl_args, tuning_args, data_args
155 |
156 | if __name__ == "__main__":
157 | train_args, cl_args, tuning_args = get_args()
158 | print(train_args)
159 | print(cl_args)
160 | print(tuning_args)
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import re
2 | from rouge import Rouge
3 | from fuzzywuzzy import fuzz
4 | from datasets import load_metric
5 | from nltk.translate.bleu_score import sentence_bleu
6 | import evaluate
7 |
8 |
9 | ########################
10 | # BLEU
11 | ########################
12 | def tokenize(text):
13 | tokens = re.split(r'\s|\.', text)
14 | tokens = [t for t in tokens if len(t) > 0]
15 | return tokens
16 |
17 |
18 | def bleu_score(reference, hypothesis, gram):
19 | reference_tokens = tokenize(reference)
20 | hypothesis_tokens = tokenize(hypothesis)
21 |
22 | if gram == 1:
23 | bleu = sentence_bleu([reference_tokens], hypothesis_tokens, (1.,)) # BELU-1
24 | elif gram == 2:
25 | bleu = sentence_bleu([reference_tokens], hypothesis_tokens, (1. / 2., 1. / 2.)) # BELU-2
26 | elif gram == 3:
27 | bleu = sentence_bleu([reference_tokens], hypothesis_tokens, (1. / 3., 1. / 3., 1. / 3.)) # BELU-3
28 | elif gram == 4:
29 | bleu = sentence_bleu([reference_tokens], hypothesis_tokens, (1. / 4., 1. / 4., 1. / 4., 1. / 4.)) # BELU-4
30 |
31 | return bleu
32 |
33 |
34 | def calculate_bleu(results, data, gram):
35 | bleus = []
36 | for output_id in range(len(results)):
37 | prediction = results[output_id]
38 | target = data[output_id]
39 | if prediction == "" or target == "":
40 | continue
41 | bleu = bleu_score(target, prediction, gram)
42 | bleus.append(bleu)
43 | avg_bleu = sum(bleus) / len(results)
44 | return avg_bleu
45 |
46 |
47 | ########################
48 | ## Rouge-L
49 | ########################
50 | def score_rouge(str1, str2):
51 | print(type(str1), type(str2))
52 | rouge = Rouge(metrics=["rouge-l"])
53 | scores = rouge.get_scores(str1, str2, avg=True)
54 | rouge_l = scores['rouge-l']['f']
55 | return rouge_l
56 |
57 |
58 | def calculate_rouge(results, data):
59 | rouges = []
60 | for output_id in range(len(results)):
61 | prediction = results[output_id]
62 | target = data[output_id]
63 | if prediction == "" or target == "":
64 | continue
65 | rouge = score_rouge(target, prediction)
66 | rouges.append(rouge)
67 | avg_rouge = sum(rouges) / len(results) if rouges else 0.0
68 | return avg_rouge
69 |
70 | def calculate_rouge_(results, data):
71 | rouge_evaluator = evaluate.load("rouge")
72 | eval_results = []
73 | eval_refs = []
74 | for output_id in range(len(results)):
75 | prediction = results[output_id]
76 | target = data[output_id]
77 | if prediction == "" or target == "":
78 | continue
79 | eval_results.append(prediction)
80 | eval_refs.append(target)
81 | rouge_scores = rouge_evaluator.compute(predictions=eval_results, references=eval_refs, rouge_types=["rougeL"])
82 | return rouge_scores['rougeL']
83 |
84 | def caculate_accuracy(results, data):
85 | scores = 0
86 | for output_id in range(len(results)):
87 | target = data[output_id]
88 | prediction = results[output_id]
89 | if prediction == "" or target == "":
90 | continue
91 | if prediction == target:
92 | scores += 1
93 | avg_score = scores / len(results)
94 | return avg_score
95 |
96 |
97 |
98 | def f1_score(list1, list2):
99 | # TP: item in list1 and list2
100 | # FP: item in list1 but not in list2
101 | # TN: item not in list1 and list2
102 | # FN: item in list2 but not in list1
103 | num_TP = 0
104 | for item1 in list1:
105 | for item2 in list2:
106 | if item1 == item2:
107 | num_TP += 1
108 | break
109 | precision = num_TP / len(list1)
110 | recall = num_TP / len(list2)
111 | if precision == 0 or recall == 0:
112 | return 0
113 | return 2 * (precision * recall / (precision + recall))
114 |
115 |
116 | def calculate_f1(results, data):
117 | scores = []
118 | for output_id in range(len(results)):
119 | prediction = results[output_id]
120 | target = data[output_id]
121 | if len(prediction) == 0 or len(target) == 0:
122 | continue
123 | score = f1_score(target, prediction)
124 | scores.append(score)
125 | avg_score = sum(scores) / len(results)
126 | return avg_score
127 |
128 |
129 |
130 | def calculate_sari(inputs, results, data):
131 | sari = load_metric("sari")
132 | result = sari.compute(sources=inputs, predictions=results, references=[[label] for label in data]), # one reference for each prediction
133 | return result
134 |
135 |
136 | def eval_20Minuten(input_sequences, predicted_sequences, ground_truths):
137 | sari = calculate_sari(input_sequences, predicted_sequences, ground_truths)
138 | evaluation_result = {"sari": sari}
139 | return evaluation_result
140 |
141 | def eval_medmcqa(predicted_sequences, ground_truths):
142 | predicted_sequences = postprocess_choice_acc(predicted_sequences)
143 | ground_truths = postprocess_choice_acc(ground_truths)
144 |
145 | accuracy = caculate_accuracy(predicted_sequences, ground_truths)
146 | evaluation_result = {"accuracy": accuracy}
147 | return evaluation_result
148 |
149 | def eval_jecqa(predicted_sequences, ground_truths):
150 | predicted_sequences = postprocess_choice_acc(predicted_sequences)
151 | ground_truths = postprocess_choice_acc(ground_truths)
152 |
153 | accuracy = caculate_accuracy(predicted_sequences, ground_truths)
154 | evaluation_result = {"accuracy": accuracy}
155 | return evaluation_result
156 |
157 | def eval_CStance(predicted_sequences, ground_truths):
158 | predicted_sequences = postprocess_choice_acc(predicted_sequences)
159 | ground_truths = postprocess_choice_acc(ground_truths)
160 |
161 | accuracy = caculate_accuracy(predicted_sequences, ground_truths)
162 | evaluation_result = {"accuracy": accuracy}
163 | return evaluation_result
164 |
165 |
166 | def eval_FOMC(predicted_sequences, ground_truths):
167 | predicted_sequences = postprocess_choice_acc(predicted_sequences)
168 | ground_truths = postprocess_choice_acc(ground_truths)
169 |
170 | accuracy = caculate_accuracy(predicted_sequences, ground_truths)
171 | evaluation_result = {"accuracy": accuracy}
172 | return evaluation_result
173 |
174 |
175 | def eval_MeetingBank(predicted_sequences, ground_truths):
176 | # bleu_1 = calculate_bleu(predicted_sequences, ground_truths, 1)
177 | # bleu_4 = calculate_bleu(predicted_sequences, ground_truths, 4)
178 | rouge = calculate_rouge_(predicted_sequences, ground_truths)
179 | # evaluation_result = {"bleu-1": bleu_1, "bleu-4": bleu_4, "rouge-L": rouge}
180 | evaluation_result = {"rouge-L": rouge}
181 | return evaluation_result
182 |
183 |
184 | def eval_NumGLUE(predicted_sequences, ground_truths):
185 | predicted_sequences = postprocess_choice_num(predicted_sequences)
186 | ground_truths = postprocess_choice_num(ground_truths)
187 |
188 | accuracy = caculate_accuracy(predicted_sequences, ground_truths)
189 | evaluation_result = {"accuracy": accuracy}
190 | return evaluation_result
191 |
192 |
193 | def resolve(dataset: list):
194 | keyword_list = []
195 | for datium in dataset:
196 | keyword_list.append(datium.split(" , "))
197 | return keyword_list
198 |
199 |
200 | def eval_PapyrusF(predicted_sequences, ground_truths):
201 | outputs = resolve(predicted_sequences)
202 | gts = resolve(ground_truths)
203 |
204 | f1 = calculate_f1(outputs, gts)
205 | evaluation_result = {"F1": f1}
206 | return evaluation_result
207 |
208 | def postprocess_choice_acc(predicted_sequences):
209 | outputs = []
210 | for output in predicted_sequences:
211 | if not output:
212 | outputs.append("")
213 | continue
214 | match = re.search(r"[A-D]", output)
215 | if match:
216 | outputs.append(match.group(0))
217 | else:
218 | outputs.append("")
219 | return outputs
220 |
221 | def postprocess_choice_num(predicted_sequences):
222 | outputs = []
223 | for output in predicted_sequences:
224 | if not output:
225 | outputs.append("")
226 | continue
227 | match = re.search(r"\d+\.?\d*", output)
228 | if match:
229 | outputs.append(match.group(0))
230 | else:
231 | outputs.append("")
232 | return outputs
233 |
234 |
235 | def resolve_sciQA(dataset: list):
236 | answers = []
237 | reasonings = []
238 | for datium in dataset:
239 | if len(datium) >= 3:
240 | answers.append(datium[0]) # the first char is the answer. e.g. A, B,...
241 | reasonings.append(datium[2:]) # A/nBecause...
242 | elif 1 <= len(datium) < 3:
243 | answers.append(datium[0])
244 | reasonings.append("")
245 | else:
246 | answers.append("")
247 | reasonings.append("")
248 | outputs = {"answers": answers, "reasonings": reasonings}
249 | return outputs
250 |
251 |
252 | def eval_SciQA(predicted_sequences, ground_truths):
253 | outputs = resolve_sciQA(predicted_sequences)
254 | gts = resolve_sciQA(ground_truths)
255 | outputs["answers"] = postprocess_choice_acc(outputs["answers"])
256 | gts["answers"] = postprocess_choice_acc(gts["answers"])
257 |
258 | # bleu_1 = calculate_bleu(outputs["reasonings"], gts["reasonings"], 1)
259 | # bleu_4 = calculate_bleu(outputs["reasonings"], gts["reasonings"], 4)
260 | # rouge = calculate_rouge_(outputs["reasonings"], gts["reasonings"])
261 | accuracy = caculate_accuracy(outputs["answers"], gts["answers"])
262 |
263 | # evaluation_result = {"bleu-1": bleu_1, "bleu-4": bleu_4,
264 | # # "rouge-L": rouge,
265 | # "accuracy": accuracy}
266 | evaluation_result = {"accuracy": accuracy}
267 | return evaluation_result
268 |
--------------------------------------------------------------------------------