├── figures ├── overview.png └── speed_vs_accuracy_vs_cache.png ├── global_vars.py ├── configs ├── config_i2cl_transfer_learning.py ├── config_i2cl_infer.py ├── config_label_anchor.py ├── config_soft_prompt.py ├── config_task_vector.py └── config_i2cl.py ├── my_datasets ├── __init__.py ├── rotten_tomatoes.py ├── sst5.py ├── agnews.py ├── emo.py ├── sst2.py ├── subj.py ├── trec.py ├── hate_speech18.py ├── dbpedia.py └── basetask.py ├── LICENSE ├── README.md ├── requirements.txt ├── evaluator.py ├── run_soft_prompt.py ├── run_task_vector.py ├── run_i2cl.py ├── utils.py ├── run_i2cl_transfer_learning.py ├── run_i2cl_infer.py ├── run_label_anchor.py └── wrapper.py /figures/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LzVv123456/I2CL/HEAD/figures/overview.png -------------------------------------------------------------------------------- /figures/speed_vs_accuracy_vs_cache.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LzVv123456/I2CL/HEAD/figures/speed_vs_accuracy_vs_cache.png -------------------------------------------------------------------------------- /global_vars.py: -------------------------------------------------------------------------------- 1 | CUR_ATTN_MASK = None 2 | ATTN_MASK_START = None 3 | ATTN_MASK_END = None 4 | SUPPORT_MODEL = ['gpt2-xl', 'EleutherAI/gpt-j-6B', 'meta-llama/Llama-2-7b-hf'] 5 | -------------------------------------------------------------------------------- /configs/config_i2cl_transfer_learning.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | import my_datasets as md 4 | 5 | config = {} 6 | config['gpus'] = ['0'] 7 | config['exp_name'] = 'exps/i2cl_transfer_learning' 8 | config['models'] = ['meta-llama/Llama-2-7b-hf'] 9 | config['datasets'] = list(md.target_datasets.keys()) 10 | 11 | config['run_num'] = 1 12 | 13 | config['threshold'] = 0.8 14 | config['temp'] = 0.5 15 | config['target_path'] = 'exps/i2cl' 16 | -------------------------------------------------------------------------------- /configs/config_i2cl_infer.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | import my_datasets as md 4 | 5 | 6 | config = {} 7 | 8 | config['gpus'] = ['0'] 9 | config['exp_name'] = 'exps/i2cl_infer' 10 | config['models'] = ['meta-llama/Llama-2-7b-hf'] 11 | config['datasets'] = list(md.target_datasets.keys()) 12 | config['run_baseline'] = True 13 | config['downstream_datasets'] = None # None will use the same dataset as source, one can also specify a list of target downstream datasets 14 | 15 | config['target_path'] = 'exps/i2cl' 16 | config['use_new_demon'] = True # whether to use new demonstrations to generate context vectors -------------------------------------------------------------------------------- /my_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .basetask import BaseTask 2 | 3 | from .sst2 import SST2 4 | from .dbpedia import DBPedia 5 | from .sst5 import SST5 6 | from .trec import TREC 7 | from .agnews import AGNews 8 | from .subj import Subj 9 | from .rotten_tomatoes import RottenTomatoes 10 | from .hate_speech18 import HateSpeech18 11 | from .emo import EMO 12 | 13 | 14 | target_datasets = { 15 | 'agnews': AGNews, 16 | 'dbpedia': DBPedia, 17 | 'sst5': SST5, 18 | 'trec': TREC, 
19 | 'sst2': SST2, 20 | 'subj': Subj, 21 | 'mr': RottenTomatoes, 22 | 'hate_speech18': HateSpeech18, 23 | 'emo': EMO, 24 | } 25 | 26 | dataset_dict = {} 27 | dataset_dict.update(target_datasets) 28 | 29 | def get_dataset(dataset, *args, **kwargs) -> BaseTask: 30 | return dataset_dict[dataset](task_name=dataset, *args, **kwargs) -------------------------------------------------------------------------------- /configs/config_label_anchor.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | import my_datasets as md 4 | 5 | 6 | config = {} 7 | # general 8 | config['exp_name'] = 'exps/label_anchor' 9 | config['gpus'] = ['0'] 10 | config['models'] = ['meta-llama/Llama-2-7b-hf'] # 'gpt2-xl', 'meta-llama/Llama-2-7b-hf', 'EleutherAI/gpt-j-6B' 11 | config['datasets'] = list(md.dataset_classification.keys()) 12 | config['seed'] = 42 13 | config['run_num'] = 5 14 | config['run_baseline'] = True 15 | config['metric'] = 'acc' # 'acc', 'macro_f1' 16 | config['bs'] = 2 17 | config['load_in_8bit'] = False 18 | config['use_cache'] = True 19 | 20 | # data 21 | config['shot_per_class'] = 5 22 | config['test_data_num'] = 500 23 | config['sample_method'] = 'uniform' # 'random', 'uniform' 24 | config['use_instruction'] = False 25 | config['add_extra_query'] = False 26 | config['example_separator'] = '\n' -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Jack Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /configs/config_soft_prompt.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | import my_datasets as md 4 | 5 | 6 | config = {} 7 | # general 8 | config['exp_name'] = 'exps/soft_prompt' 9 | config['gpus'] = ['0'] 10 | config['models'] = ['meta-llama/Llama-2-7b-hf'] # 'gpt2-xl', 'meta-llama/Llama-2-7b-hf', 'EleutherAI/gpt-j-6B' 11 | config['datasets'] = list(md.target_datasets.keys()) 12 | config['seed'] = 42 13 | config['run_num'] = 5 14 | config['run_baseline'] = True 15 | config['metric'] = 'acc' # 'acc', 'macro_f1' 16 | config['bs'] = 2 17 | config['load_in_8bit'] = False 18 | config['use_cache'] = True 19 | config['example_separator'] = '\n' 20 | 21 | # data 22 | config['shot_per_class'] = 5 23 | config['test_data_num'] = 500 24 | config['sample_method'] = 'uniform' # 'random', 'uniform' 25 | config['add_extra_query'] = False 26 | 27 | # prompt_tuning 28 | pt_config = {} 29 | pt_config['task_type'] = 'CAUSAL_LM' 30 | pt_config['num_virtual_tokens'] = 1 31 | pt_config['num_layers'] = 28 32 | config['pt_config'] = pt_config 33 | 34 | # optimization 35 | config['epochs'] = 50 36 | config['optim'] = 'adamW' # 'adam', 'adamW', 'sgd' 37 | config['grad_bs'] = 4 38 | config['lr'] = 0.1 39 | config['wd'] = 1e-3 -------------------------------------------------------------------------------- /configs/config_task_vector.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | import my_datasets as md 4 | 5 | 6 | config = {} 7 | # general 8 | config['exp_name'] = 'exps/task_vector' 9 | config['gpus'] = ['0'] 10 | config['models'] = ['meta-llama/Llama-2-7b-hf'] # 'gpt2-xl', 'meta-llama/Llama-2-7b-hf', 'EleutherAI/gpt-j-6B' 11 | config['datasets'] = list(md.target_datasets.keys()) 12 | config['seed'] = 42 13 | config['run_num'] = 5 14 | config['run_baseline'] = True 15 | config['metric'] = 'acc' # 'acc', 'macro_f1' 16 | config['bs'] = 2 17 | config['load_in_8bit'] = False 18 | config['use_cache'] = True 19 | 20 | # context vector 21 | config['layer'] = 'all' # all, late, early, mid 22 | config['tok_pos'] = 'last' 23 | config['module'] = ['hidden'] # 'mlp', 'attn', 'hidden' 24 | config['gen_cv_method'] = 'context' # 'context', 'noise' 25 | config['post_fuse_method'] = 'mean' # 'mean', 'pca' 26 | config['split_demon'] = False # split demonstraiton into seperate examples 27 | 28 | # data 29 | config['shot_per_class'] = 5 30 | config['val_data_num'] = 32 31 | config['test_data_num'] = 500 32 | config['sample_method'] = 'uniform' # 'random', 'uniform' 33 | config['use_instruction'] = False 34 | config['add_extra_query'] = True 35 | config['example_separator'] = '\n' -------------------------------------------------------------------------------- /configs/config_i2cl.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | import my_datasets as md 4 | 5 | 6 | config = {} 7 | # general 8 | config['exp_name'] = 'exps/i2cl' 9 | config['gpus'] = ['0'] 10 | config['models'] = ['meta-llama/Llama-2-7b-hf'] # 'gpt2-xl', 'EleutherAI/gpt-j-6B' 11 | config['datasets'] = list(md.target_datasets.keys()) 12 | config['seed'] = 42 13 | config['run_num'] = 5 # number of runs 14 | config['run_baseline'] = True # whether run baseline 15 | config['metric'] = 'acc' # 'acc', 'macro_f1' 16 | config['bs'] = 2 # batch 
size 17 | config['load_in_8bit'] = False 18 | config['use_cache'] = True # whether to use kv cache 19 | config['demo_sample_method'] = 'random' # 'random' or 'deficient' 20 | 21 | # calibrate 22 | config['add_noise'] = True # whether to add noise 23 | config['noise_scale'] = 0.001 # noise scale 24 | config['epochs'] = 100 # number of epochs 25 | config['optim'] = 'adamW' # 'adam', 'adamW', 'sgd' 26 | config['grad_bs'] = 2 # batch size for calibration 27 | config['lr'] = 0.01 28 | config['wd'] = 1e-3 29 | config['cali_example_method'] = 'normal' # 'normal', 'random_label' 30 | 31 | # context vector 32 | config['layer'] = 'all' # all, early, mid, late 33 | config['tok_pos'] = 'last' # 'random', 'first', 'last' 34 | config['inject_method'] = 'linear' # 'linear', 'constraint', 'add' 35 | config['inject_pos'] = 'all' # 'all', 'first', 'last', 'random' 36 | config['init_value'] = [0.1, 1.0] # linear and constraint: [0.1, 1.0], add: [0.1] 37 | config['module'] = ['mlp', 'attn'] # 'mlp', 'attn', 'hidden' 38 | config['gen_cv_method'] = 'context' # 'context', 'noise' 39 | config['post_fuse_method'] = 'mean' # 'mean', 'pca' 40 | config['split_demon'] = True # split demonstration into separate examples 41 | config['gen_example_method'] = 'normal' # 'normal', 'random_label', 'no_template', 'random_order' 42 | 43 | # data 44 | config['shot_per_class'] = 5 # number of shots per class 45 | config['val_data_num'] = 32 46 | config['test_data_num'] = 500 # number of test data 47 | config['sample_method'] = 'uniform' # 'random', 'uniform' 48 | config['use_instruction'] = False 49 | config['add_extra_query'] = False 50 | config['example_separator'] = '\n' -------------------------------------------------------------------------------- /my_datasets/rotten_tomatoes.py: -------------------------------------------------------------------------------- 1 | try: 2 | from .
import BaseTask 3 | except: 4 | from basetask import BaseTask 5 | # from basetask import BaseTask 6 | from datasets import load_dataset 7 | 8 | 9 | 10 | class RottenTomatoes(BaseTask): 11 | def __init__(self, split='train', *args, **kwargs): 12 | super().__init__(*args, **kwargs) 13 | self.task_type = "classification" 14 | # set split name 15 | load_split = split 16 | print(f"Loading {load_split} data from rotten_tomatoes ...") 17 | # class_num 18 | self.class_num = 2 19 | # load dataset 20 | self.dataset = load_dataset('rotten_tomatoes', split=load_split, keep_in_memory=True) 21 | # get all data 22 | self.all_data = [data for data in self.dataset] 23 | # get all labels 24 | self.all_labels = self.get_all_labels() 25 | # random sample data 26 | if self.max_data_num is not None: 27 | self.random_sample_data(self.max_data_num) 28 | # print a few examples 29 | print(f'Dataset lengh is {len(self.all_data)}') 30 | print("Example data:") 31 | self.print_data([0]) 32 | 33 | def get_dmonstration_template(self): 34 | template = { 35 | 'input': 'Review: {text}\nSentiment:', 36 | 'ans': '{answer}', 37 | 'options': ['negative', 'positive'], 38 | 'format': ['Review:', 'Sentiment:'] 39 | } 40 | return template 41 | 42 | def get_task_instruction(self): 43 | task_instruction = "Classify the sentiment of the sentence into one of the categories: positive or negative.\n\n" 44 | return task_instruction 45 | 46 | def apply_template(self, data): 47 | """ 48 | PS: label should always be an integer and can be used to index the options 49 | """ 50 | template = self.get_dmonstration_template() 51 | input_template = template['input'] 52 | ans_template = template['ans'] 53 | options = template['options'] 54 | input_str = input_template.replace("{text}", data["text"]) 55 | # answers can have multiple options and is a list 56 | answer_str = [ans_template.replace("{answer}", options[i]) for i in range(len(options))] 57 | label = data["label"] 58 | return input_str, answer_str, label 59 | 60 | def get_all_labels(self): 61 | return [data["label"] for data in self.all_data] -------------------------------------------------------------------------------- /my_datasets/sst5.py: -------------------------------------------------------------------------------- 1 | try: 2 | from . 
import BaseTask 3 | except: 4 | from basetask import BaseTask 5 | # from basetask import BaseTask 6 | from datasets import load_dataset 7 | 8 | 9 | class SST5(BaseTask): 10 | def __init__(self, split='train', *args, **kwargs): 11 | super().__init__(*args, **kwargs) 12 | self.task_type = "classification" 13 | # set split name 14 | load_split = split 15 | print(f"Loading {load_split} data from sst5 ...") 16 | # class_num 17 | self.class_num = 5 18 | # load dataset 19 | self.dataset = load_dataset('SetFit/sst5', split=load_split, keep_in_memory=True) 20 | # get all data 21 | self.all_data = [data for data in self.dataset] 22 | # get all labels 23 | self.all_labels = self.get_all_labels() 24 | # random sample data 25 | if self.max_data_num is not None: 26 | self.random_sample_data(self.max_data_num) 27 | # print a few examples 28 | print(f'Dataset lengh is {len(self.all_data)}') 29 | print("Example data:") 30 | self.print_data([0]) 31 | 32 | def get_dmonstration_template(self): 33 | template = { 34 | 'input': 'Sentence: {text}\nSentiment:', 35 | 'ans': '{answer}', 36 | 'options': ["terrible", "negative", "neutral", "positive", "great"], 37 | 'format': ['Sentence:', 'Sentiment:'] 38 | } 39 | return template 40 | 41 | def get_task_instruction(self): 42 | task_instruction = "Classify the sentiment of sentence into one of the categories: terrible, negative, neutral, positive, great.\n\n" 43 | return task_instruction 44 | 45 | def apply_template(self, data): 46 | """ 47 | PS: label should always be an integer and can be used to index the options 48 | """ 49 | template = self.get_dmonstration_template() 50 | input_template = template['input'] 51 | ans_template = template['ans'] 52 | options = template['options'] 53 | input_str = input_template.replace("{text}", data["text"]) 54 | # answers can have multiple options and is a list 55 | answer_str = [ans_template.replace("{answer}", options[i]) for i in range(len(options))] 56 | label = data["label"] 57 | return input_str, answer_str, label 58 | 59 | def get_all_labels(self): 60 | return [data["label"] for data in self.all_data] -------------------------------------------------------------------------------- /my_datasets/agnews.py: -------------------------------------------------------------------------------- 1 | try: 2 | from . 
import BaseTask 3 | except: 4 | from basetask import BaseTask 5 | from datasets import load_dataset 6 | 7 | 8 | class AGNews(BaseTask): 9 | def __init__(self, split='train', *args, **kwargs): 10 | super().__init__(*args, **kwargs) 11 | self.task_type = "classification" 12 | # set split name 13 | if split in ['train', 'validation']: 14 | load_split = 'train' 15 | else: 16 | load_split = 'test' 17 | print(f"Loading {load_split} data from AGNews ...") 18 | # class_num 19 | self.class_num = 4 20 | # load dataset 21 | self.dataset = load_dataset('ag_news', split=load_split, keep_in_memory=True) 22 | # get all data 23 | self.all_data = [data for data in self.dataset] 24 | # get all labels 25 | self.all_labels = self.get_all_labels() 26 | # random sample data 27 | if self.max_data_num is not None: 28 | self.random_sample_data(self.max_data_num) 29 | # print a few examples 30 | print(f'Dataset lengh is {len(self.all_data)}') 31 | print("Example data:") 32 | self.print_data([0]) 33 | 34 | def get_dmonstration_template(self): 35 | template = { 36 | 'input': 'News: {text}\nType:', 37 | 'ans': '{answer}', 38 | 'options': ["World", "Sports", "Business", "Technology"], 39 | 'format': ['News:', 'Type:'] 40 | } 41 | return template 42 | 43 | def get_task_instruction(self): 44 | task_instruction = "Classify the article into one of the categories: World, Sports, Business, Technology.\n\n" 45 | return task_instruction 46 | 47 | def apply_template(self, data): 48 | """ 49 | PS: label should always be an integer and can be used to index the options 50 | """ 51 | template = self.get_dmonstration_template() 52 | input_template = template['input'] 53 | ans_template = template['ans'] 54 | options = template['options'] 55 | input_str = input_template.replace("{text}", data["text"]) 56 | # answers can have multiple options and is a list 57 | answer_str = [ans_template.replace("{answer}", options[i]) for i in range(len(options))] 58 | label = data["label"] 59 | return input_str, answer_str, label 60 | 61 | def get_all_labels(self): 62 | return [data["label"] for data in self.all_data] -------------------------------------------------------------------------------- /my_datasets/emo.py: -------------------------------------------------------------------------------- 1 | try: 2 | from . 
import BaseTask 3 | except: 4 | from basetask import BaseTask 5 | # from basetask import BaseTask 6 | from datasets import load_dataset 7 | 8 | 9 | class EMO(BaseTask): 10 | def __init__(self, split='train', *args, **kwargs): 11 | super().__init__(*args, **kwargs) 12 | self.task_type = "classification" 13 | # set split name 14 | if split in ['train', 'validation']: 15 | load_split = 'train' 16 | else: 17 | load_split = 'test' 18 | print(f"Loading {load_split} data from emo ...") 19 | # class_num 20 | self.class_num = 4 21 | # load dataset 22 | self.dataset = load_dataset('emo', split=load_split, keep_in_memory=True) 23 | # get all data 24 | self.all_data = [data for data in self.dataset] 25 | # get all labels 26 | self.all_labels = self.get_all_labels() 27 | # random sample data 28 | if self.max_data_num is not None: 29 | self.random_sample_data(self.max_data_num) 30 | # print a few examples 31 | print(f'Dataset lengh is {len(self.all_data)}') 32 | print("Example data:") 33 | self.print_data([0]) 34 | 35 | def get_dmonstration_template(self): 36 | template = { 37 | 'input': 'Dialogue: {text}\nEmotion:', 38 | 'ans': '{answer}', 39 | 'options': ['others', 'happy', 'sad', 'angry'], 40 | 'format': ['Dialogue:', 'Emotion:'] 41 | } 42 | return template 43 | 44 | def get_task_instruction(self): 45 | task_instruction = "Classify the emotion of the dialogue into one of the categories: Others, Happy, Sad, Angry.\n\n" 46 | return task_instruction 47 | 48 | def apply_template(self, data): 49 | """ 50 | PS: label should always be an integer and can be used to index the options 51 | """ 52 | template = self.get_dmonstration_template() 53 | input_template = template['input'] 54 | ans_template = template['ans'] 55 | options = template['options'] 56 | input_str = input_template.replace("{text}", data["text"]) 57 | # answers can have multiple options and is a list 58 | answer_str = [ans_template.replace("{answer}", options[i]) for i in range(len(options))] 59 | label = data["label"] 60 | return input_str, answer_str, label 61 | 62 | def get_all_labels(self): 63 | return [data["label"] for data in self.all_data] -------------------------------------------------------------------------------- /my_datasets/sst2.py: -------------------------------------------------------------------------------- 1 | try: 2 | from . 
import BaseTask 3 | except: 4 | from basetask import BaseTask 5 | # from basetask import BaseTask 6 | from datasets import load_dataset 7 | 8 | 9 | class SST2(BaseTask): 10 | def __init__(self, split='train', *args, **kwargs): 11 | super().__init__(*args, **kwargs) 12 | self.task_type = "classification" 13 | # set split name 14 | if split in ['train', 'validation']: 15 | load_split = 'train' 16 | else: 17 | load_split = 'validation' 18 | print(f"Loading {load_split} data from sst2 ...") 19 | # class_num 20 | self.class_num = 2 21 | # load dataset 22 | self.dataset = load_dataset('glue', 'sst2', split=load_split, keep_in_memory=True) 23 | # get all data 24 | self.all_data = [data for data in self.dataset] 25 | # get all labels 26 | self.all_labels = self.get_all_labels() 27 | # random sample data 28 | if self.max_data_num is not None: 29 | self.random_sample_data(self.max_data_num) 30 | # print a few examples 31 | print(f'Dataset lengh is {len(self.all_data)}') 32 | print("Example data:") 33 | self.print_data([0]) 34 | 35 | def get_dmonstration_template(self): 36 | template = { 37 | 'input': 'Review: {text}\nSentiment:', 38 | 'ans': '{answer}', 39 | 'options': ['negative', 'positive'], 40 | 'format': ['Review:', 'Sentiment:'] 41 | } 42 | return template 43 | 44 | def get_task_instruction(self): 45 | task_instruction = "Classify the sentiment of the sentence into one of the categories: positive or negative.\n\n" 46 | return task_instruction 47 | 48 | def apply_template(self, data): 49 | """ 50 | PS: label should always be an integer and can be used to index the options 51 | """ 52 | template = self.get_dmonstration_template() 53 | input_template = template['input'] 54 | ans_template = template['ans'] 55 | options = template['options'] 56 | input_str = input_template.replace("{text}", data["sentence"]) 57 | # answers can have multiple options and is a list 58 | answer_str = [ans_template.replace("{answer}", options[i]) for i in range(len(options))] 59 | label = data["label"] 60 | return input_str, answer_str, label 61 | 62 | def get_all_labels(self): 63 | return [data["label"] for data in self.all_data] 64 | -------------------------------------------------------------------------------- /my_datasets/subj.py: -------------------------------------------------------------------------------- 1 | try: 2 | from . 
import BaseTask 3 | except: 4 | from basetask import BaseTask 5 | # from basetask import BaseTask 6 | from datasets import load_dataset 7 | 8 | 9 | class Subj(BaseTask): 10 | def __init__(self, split='train', *args, **kwargs): 11 | super().__init__(*args, **kwargs) 12 | self.task_type = "classification" 13 | assert split in ['train', 'validation', 'test'] 14 | # set split name 15 | if split in ['train', 'validation']: 16 | load_split = 'train' 17 | else: 18 | load_split = 'test' 19 | print(f"Loading {load_split} data from subj ...") 20 | # class_num 21 | self.class_num = 2 22 | # load dataset 23 | self.dataset = load_dataset('SetFit/subj', split=load_split, keep_in_memory=True) 24 | # get all data 25 | self.all_data = [data for data in self.dataset] 26 | # get all labels 27 | self.all_labels = self.get_all_labels() 28 | # random sample data 29 | if self.max_data_num is not None: 30 | self.random_sample_data(self.max_data_num) 31 | # print a few examples 32 | print(f'Dataset lengh is {len(self.all_data)}') 33 | print("Example data:") 34 | self.print_data([0]) 35 | 36 | def get_dmonstration_template(self): 37 | template = { 38 | 'input': 'Sentence: {text}\nLabel:', 39 | 'ans': '{answer}', 40 | 'options': ['objective', 'subjective'], 41 | 'format': ['Sentence:', 'Label:'] 42 | } 43 | return template 44 | 45 | def get_task_instruction(self): 46 | task_instruction = "Classify the sentiment of the sentence into one of the categories: positive or negative.\n\n" 47 | return task_instruction 48 | 49 | def apply_template(self, data): 50 | """ 51 | PS: label should always be an integer and can be used to index the options 52 | """ 53 | template = self.get_dmonstration_template() 54 | input_template = template['input'] 55 | ans_template = template['ans'] 56 | options = template['options'] 57 | input_str = input_template.replace("{text}", data["text"]) 58 | # answers can have multiple options and is a list 59 | answer_str = [ans_template.replace("{answer}", options[i]) for i in range(len(options))] 60 | label = data["label"] 61 | return input_str, answer_str, label 62 | 63 | def get_all_labels(self): 64 | return [data["label"] for data in self.all_data] -------------------------------------------------------------------------------- /my_datasets/trec.py: -------------------------------------------------------------------------------- 1 | try: 2 | from . 
import BaseTask 3 | except: 4 | from basetask import BaseTask 5 | from datasets import load_dataset 6 | 7 | 8 | class TREC(BaseTask): 9 | def __init__(self, split='train', *args, **kwargs): 10 | super().__init__(*args, **kwargs) 11 | self.task_type = "classification" 12 | # set split name 13 | if split in ['train', 'validation']: 14 | load_split = 'train' 15 | else: 16 | load_split = 'test' 17 | print(f"Loading {load_split} data from TREC ...") 18 | # class_num 19 | self.class_num = 6 20 | # load dataset 21 | self.dataset = load_dataset('trec', split=load_split, keep_in_memory=True) 22 | # get all data 23 | self.all_data = [data for data in self.dataset] 24 | # get all labels 25 | self.all_labels = self.get_all_labels() 26 | # random sample data 27 | if self.max_data_num is not None: 28 | self.random_sample_data(self.max_data_num) 29 | # print a few examples 30 | print(f'Dataset lengh is {len(self.all_data)}') 31 | print("Example data:") 32 | self.print_data([0]) 33 | 34 | def get_dmonstration_template(self): 35 | template = { 36 | 'input': 'Question: {text}\nAnswer Type:', 37 | 'ans': '{answer}', 38 | 'options': ["Abbreviation", "Entity", "Description", "Person", "Location", "Number"], 39 | 'format': ['Question:', 'Category:'] 40 | } 41 | return template 42 | 43 | def get_task_instruction(self): 44 | task_instruction = "Classify the questions based on whether their answer type is a Number, Location, Person, Description, Entity, or Abbreviation.\n\n" 45 | return task_instruction 46 | 47 | def apply_template(self, data): 48 | """ 49 | PS: label should always be an integer and can be used to index the options 50 | """ 51 | template = self.get_dmonstration_template() 52 | input_template = template['input'] 53 | ans_template = template['ans'] 54 | options = template['options'] 55 | input_str = input_template.replace("{text}", data["text"]) 56 | # answers can have multiple options and is a list 57 | answer_str = [ans_template.replace("{answer}", options[i]) for i in range(len(options))] 58 | label = data["coarse_label"] 59 | return input_str, answer_str, label 60 | 61 | def get_all_labels(self): 62 | return [data["coarse_label"] for data in self.all_data] -------------------------------------------------------------------------------- /my_datasets/hate_speech18.py: -------------------------------------------------------------------------------- 1 | try: 2 | from . 
import BaseTask 3 | except: 4 | from basetask import BaseTask 5 | # from basetask import BaseTask 6 | from datasets import load_dataset 7 | 8 | 9 | class HateSpeech18(BaseTask): 10 | def __init__(self, split='train', *args, **kwargs): 11 | super().__init__(*args, **kwargs) 12 | self.task_type = "classification" 13 | print(f"Loading train data from hates_peach18 ...") 14 | # class_num 15 | self.class_num = 2 16 | # load dataset 17 | self.dataset = load_dataset('hate_speech18', split='train', keep_in_memory=True) 18 | # get all data 19 | self.all_data = [data for data in self.dataset if data['label'] in [0, 1]] # clean data to keep only hate and noHate 20 | if split in ['train', 'validation']: 21 | self.all_data = self.all_data[:len(self.all_data)//2] 22 | else: 23 | self.all_data = self.all_data[len(self.all_data)//2:] 24 | # get all labels 25 | self.all_labels = self.get_all_labels() 26 | # random sample data 27 | if self.max_data_num is not None: 28 | self.random_sample_data(self.max_data_num) 29 | # print a few examples 30 | print(f'Dataset length is {len(self.all_data)}') 31 | print("Example data:") 32 | self.print_data([0]) 33 | 34 | def get_dmonstration_template(self): 35 | template = { 36 | 'input': 'Text: {text}\nLabel:', 37 | 'ans': '{answer}', 38 | 'options': ['neutral', 'hate'], 39 | 'format': ['Text:', 'Label:'] 40 | } 41 | return template 42 | 43 | def get_task_instruction(self): 44 | task_instruction = "Classify the text into one of the categories: noHate or hate.\n\n" 45 | return task_instruction 46 | 47 | def apply_template(self, data): 48 | """ 49 | PS: label should always be an integer and can be used to index the options 50 | """ 51 | template = self.get_dmonstration_template() 52 | input_template = template['input'] 53 | ans_template = template['ans'] 54 | options = template['options'] 55 | input_str = input_template.replace("{text}", data["text"]) 56 | # answers can have multiple options and is a list 57 | answer_str = [ans_template.replace("{answer}", options[i]) for i in range(len(options))] 58 | label = data["label"] 59 | return input_str, answer_str, label 60 | 61 | def get_all_labels(self): 62 | return [data["label"] for data in self.all_data] -------------------------------------------------------------------------------- /my_datasets/dbpedia.py: -------------------------------------------------------------------------------- 1 | try: 2 | from . 
import BaseTask 3 | except: 4 | from basetask import BaseTask 5 | # from basetask import BaseTask 6 | from datasets import load_dataset 7 | 8 | 9 | class DBPedia(BaseTask): 10 | def __init__(self, split='train', *args, **kwargs): 11 | super().__init__(*args, **kwargs) 12 | self.task_type = "classification" 13 | # set split name 14 | if split in ['train', 'validation']: 15 | load_split = 'train' 16 | else: 17 | load_split = 'test' 18 | 19 | print(f"Loading {load_split} data from DBPedia ...") 20 | # class_num 21 | self.class_num = 14 22 | # load dataset 23 | self.dataset = load_dataset('fancyzhx/dbpedia_14', split=load_split, keep_in_memory=True) 24 | # get all data 25 | self.all_data = [data for data in self.dataset] 26 | # get all labels 27 | self.all_labels = self.get_all_labels() 28 | # random sample data 29 | if self.max_data_num is not None: 30 | self.random_sample_data(self.max_data_num) 31 | # print a few examples 32 | print(f'Dataset lengh is {len(self.all_data)}') 33 | print("Example data:") 34 | self.print_data([0]) 35 | 36 | def get_dmonstration_template(self): 37 | template = { 38 | 'input': 'Input: {text}\nLabel:', 39 | 'ans': '{answer}', 40 | 'options': ['company', 'school', 'artist', 'athlete', 41 | 'politics', 'transportation', 'building', 42 | 'nature', 'village', 'animal', 'plant', 43 | 'album', 'film', 'book'], 44 | 'format': ['Input:', 'Label:'] 45 | } 46 | return template 47 | 48 | def get_task_instruction(self): 49 | task_instruction = "Classify the sentiment of the sentence into one of the categories: positive or negative.\n\n" 50 | return task_instruction 51 | 52 | def apply_template(self, data): 53 | """ 54 | PS: label should always be an integer and can be used to index the options 55 | """ 56 | template = self.get_dmonstration_template() 57 | input_template = template['input'] 58 | ans_template = template['ans'] 59 | options = template['options'] 60 | input_str = input_template.replace("{text}", data["content"]) 61 | # answers can have multiple options and is a list 62 | answer_str = [ans_template.replace("{answer}", options[i]) for i in range(len(options))] 63 | label = data["label"] 64 | return input_str, answer_str, label 65 | 66 | def get_all_labels(self): 67 | return [data["label"] for data in self.all_data] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # I2CL: Implicit In-context Learning 2 | Please refer to the [paper](https://arxiv.org/pdf/2405.14660) for more details. 3 | 4 | 5 | 6 | 7 | 8 | 9 |
![Overview](figures/overview.png) ![Speed vs Accuracy vs Cache](figures/speed_vs_accuracy_vs_cache.png)
10 | 11 | ## What's New? 12 | ### 🌟 Introducing I2CL: a new paradigm to leverage demonstration examples 13 | Implicit In-context Learning (I2CL) absorbs a tiny set of demonstration examples in the activation space, diverging from standard In-context Learning (ICL), which prefixes demonstration examples in the token space. As a result, I2CL bypasses the limitation of the context window size. 14 | 15 | ### 🌟 Efficiency 16 | I2CL is extremely efficient in terms of both computation and memory usage. It achieves few-shot (i.e., ICL) performance at approximately zero-shot cost during inference. 17 | 18 | ### 🌟 Robustness 19 | I2CL is robust against the selection and order of demonstration examples. It yields satisfactory performance even with deficient demonstration examples, which can severely degrade the performance of ICL. 20 | 21 | ### 🌟 Generation of "Task-ID" 22 | I2CL introduces a set of scalar values that act as task-IDs. These task-IDs effectively indicate task similarity and can be leveraged to perform transfer learning across tasks. 23 | 24 | ### 🌟 Better Understanding of ICL 25 | I2CL substantiates a two-stage workflow: a context vector is first generated by condensing the demonstration examples independently, and a linear combination of the context vector and the query's own activation is then injected back into the residual streams. The effectiveness of I2CL suggests a potential two-stage workflow for ICL (an illustrative sketch is provided near the end of this README). 26 | 27 | ## Installation 28 | 1. **Clone the Repository**: 29 | 30 | ```bash 31 | git clone https://github.com/LzVv123456/I2CL 32 | cd I2CL 33 | ``` 34 | 35 | 2. **Create a Conda Environment**: 36 | 37 | Create a new conda environment to avoid conflicts with existing packages. 38 | 39 | ```bash 40 | conda create --name i2cl_env python=3.8 41 | conda activate i2cl_env 42 | ``` 43 | 44 | 3. **Install Dependencies**: 45 | 46 | Use `pip` to install the required libraries listed in the `requirements.txt` file. 47 | 48 | ```bash 49 | pip install -r requirements.txt 50 | ``` 51 | 52 | 53 | ## Usage 54 | 55 | To use the code, follow these steps: 56 | 57 | 1. **Navigate to the I2CL Folder**: 58 | 59 | ```bash 60 | cd I2CL 61 | ``` 62 | 63 | 2. **Run I2CL**: 64 | 65 | To run I2CL, execute the following command: 66 | 67 | ```bash 68 | python run_i2cl.py 69 | ``` 70 | 71 | 3. **Run Comparable Methods**: 72 | 73 | To run other comparable methods, use the following commands: 74 | 75 | ```bash 76 | python run_soft_prompt.py 77 | python run_task_vector.py 78 | python run_label_anchor.py 79 | ``` 80 | 81 | 4. **Apply I2CL to Unseen Demonstrations or Perform Transfer Learning**: 82 | 83 | First, run `run_i2cl.py`, and in the configuration files used by `run_i2cl_infer.py` and `run_i2cl_transfer_learning.py`, set `target_path` to the output result folder produced by `run_i2cl.py`. 84 | 85 | Then, execute the following commands: 86 | 87 | ```bash 88 | python run_i2cl_infer.py 89 | python run_i2cl_transfer_learning.py 90 | ``` 91 | 92 | 5. **Configuration and Ablation Studies** 93 | 94 | For ablation studies and other configurations, please refer to `configs/config_i2cl.py` and the corresponding code files for more details. 95 | 96 | Please don't hesitate to drop an email at zhuowei.li@cs.rutgers.edu if you have any questions. 97 | 98 | ## License 99 | 100 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
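## Illustrative sketch of the two-stage workflow

To make the two-stage workflow described in the What's New section concrete, here is a minimal, self-contained sketch. It is **not** the repository's actual implementation (see `wrapper.py` and `run_i2cl.py` for that); the function names, tensor shapes, and fixed coefficients are illustrative assumptions based on the defaults in `configs/config_i2cl.py`.

```python
import torch


def build_context_vector(demo_activations):
    """Stage 1: condense per-example activations into a single context vector.

    `demo_activations` stands in for the residual-stream activations that the
    real code extracts from the chosen modules (e.g. 'mlp' and 'attn') at the
    last token of each demonstration example; averaging mirrors the default
    post_fuse_method='mean' in configs/config_i2cl.py.
    """
    return torch.stack(demo_activations, dim=0).mean(dim=0)


def inject(query_activation, context_vector, a=0.1, b=1.0):
    """Stage 2: inject a linear combination back into the residual stream.

    With inject_method='linear', each injected layer/module computes
    h <- a * context_vector + b * h, where (a, b) start from the default
    init_value [0.1, 1.0] and are adjusted during calibration.
    """
    return a * context_vector + b * query_activation


# Toy usage with random activations (hidden size 8 is arbitrary).
demos = [torch.randn(8) for _ in range(5)]   # one activation per demonstration example
cv = build_context_vector(demos)             # condensed context vector
h_query = torch.randn(8)                     # query activation at some layer
h_new = inject(h_query, cv)                  # activation after injection
print(h_new.shape)                           # torch.Size([8])
```

In the actual method (see `configs/config_i2cl.py` and `wrapper.py`), such coefficient pairs are learned for each injected layer and module during the short calibration phase, and it is this set of scalars that serves as the "task-ID" mentioned above.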
101 | 102 | ## Acknowledgments 103 | 104 | I would also like to acknowledge the repositories that inspired and contributed to this work: 105 | - [label-words-are-anchors](https://github.com/lancopku/label-words-are-anchors) 106 | - [ICV](https://github.com/shengliu66/ICV) 107 | - [icl_task_vectors](https://github.com/roeehendel/icl_task_vectors) 108 | 109 | ## Citation 110 | If you found this work useful for your research, feel free to star ⭐ the repo or cite the following paper: 111 | ``` 112 | @misc{li2024implicit, 113 | title={Implicit In-context Learning}, 114 | author={Zhuowei Li and Zihao Xu and Ligong Han and Yunhe Gao and Song Wen and Di Liu and Hao Wang and Dimitris N. Metaxas}, 115 | year={2024}, 116 | eprint={2405.14660}, 117 | archivePrefix={arXiv}, 118 | primaryClass={cs.LG} 119 | } 120 | ``` 121 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.26.1 2 | aiohttp==3.9.3 3 | aiosignal==1.3.1 4 | asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1698341106958/work 5 | async-timeout==4.0.3 6 | attrs==23.2.0 7 | backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work 8 | bitsandbytes==0.42.0 9 | Bottleneck @ file:///opt/conda/conda-bld/bottleneck_1657175564434/work 10 | Brotli @ file:///tmp/abs_ecyw11_7ze/croots/recipe/brotli-split_1659616059936/work 11 | calflops==0.2.9 12 | certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1707022139797/work/certifi 13 | cffi @ file:///croot/cffi_1700254295673/work 14 | charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work 15 | comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1710320294760/work 16 | contourpy @ file:///croot/contourpy_1700583582875/work 17 | cryptography @ file:///croot/cryptography_1702070282333/work 18 | cycler @ file:///tmp/build/80754af9/cycler_1637851556182/work 19 | datasets==2.16.1 20 | debugpy @ file:///croot/debugpy_1690905042057/work 21 | decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work 22 | dill==0.3.7 23 | entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1643888246732/work 24 | executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1698579936712/work 25 | filelock @ file:///croot/filelock_1700591183607/work 26 | fonttools==4.25.0 27 | frozenlist==1.4.1 28 | fsspec==2023.10.0 29 | gmpy2 @ file:///tmp/build/80754af9/gmpy2_1645438755360/work 30 | huggingface-hub==0.20.3 31 | idna @ file:///croot/idna_1666125576474/work 32 | importlib-resources @ file:///croot/importlib_resources-suite_1704281845041/work 33 | ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1708996548741/work 34 | ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1680185408135/work 35 | jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1696326070614/work 36 | Jinja2 @ file:///croot/jinja2_1706733616596/work 37 | joblib==1.3.2 38 | jupyter-client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1654730843242/work 39 | jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1710257447442/work 40 | kiwisolver @ file:///croot/kiwisolver_1672387140495/work 41 | MarkupSafe @ file:///croot/markupsafe_1704205993651/work 42 | matplotlib==3.8.4 43 | matplotlib-inline @ 
file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1660814786464/work 44 | mkl-fft @ file:///croot/mkl_fft_1695058164594/work 45 | mkl-random @ file:///croot/mkl_random_1695059800811/work 46 | mkl-service==2.4.0 47 | mpmath @ file:///croot/mpmath_1690848262763/work 48 | multidict==6.0.5 49 | multiprocess==0.70.15 50 | munkres==1.1.4 51 | nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1705850609492/work 52 | networkx @ file:///croot/networkx_1690561992265/work 53 | numexpr @ file:///croot/numexpr_1696515281613/work 54 | numpy @ file:///croot/numpy_and_numpy_base_1704311704800/work/dist/numpy-1.26.3-cp39-cp39-linux_x86_64.whl#sha256=93e7b9e5e2090dd03810e7c1b02cb077d3ef49fc713f0af531e0667375d9decb 55 | nvidia-ml-py3==7.352.0 56 | packaging @ file:///croot/packaging_1693575174725/work 57 | pandas @ file:///croot/pandas_1702317985682/work/dist/pandas-2.1.4-cp39-cp39-linux_x86_64.whl#sha256=324122d54922a1f7a8103c955ecfc280ac90ea4d0a9ae7e365d5a477d5f0713e 58 | parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work 59 | peft==0.8.2 60 | pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1706113125309/work 61 | pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work 62 | Pillow @ file:///croot/pillow_1696580024257/work 63 | platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1706713388748/work 64 | prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1702399386289/work 65 | psutil @ file:///opt/conda/conda-bld/psutil_1656431268089/work 66 | ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl 67 | pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work 68 | pyarrow==15.0.0 69 | pyarrow-hotfix==0.6 70 | pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work 71 | Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1700607939962/work 72 | pyOpenSSL @ file:///croot/pyopenssl_1690223430423/work 73 | pyparsing @ file:///opt/conda/conda-bld/pyparsing_1661452539315/work 74 | PySocks @ file:///tmp/build/80754af9/pysocks_1605305812635/work 75 | python-dateutil @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work 76 | pytz @ file:///croot/pytz_1695131579487/work 77 | PyYAML @ file:///croot/pyyaml_1698096049011/work 78 | pyzmq @ file:///croot/pyzmq_1705605076900/work 79 | regex==2023.12.25 80 | requests @ file:///croot/requests_1690400202158/work 81 | safetensors==0.4.2 82 | scikit-learn==1.4.0 83 | scipy==1.12.0 84 | seaborn==0.13.2 85 | six @ file:///tmp/build/80754af9/six_1644875935023/work 86 | stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work 87 | sympy @ file:///croot/sympy_1701397643339/work 88 | threadpoolctl==3.2.0 89 | tokenizers==0.15.1 90 | torch==2.2.0 91 | torchaudio==2.2.0 92 | torchvision==0.17.0 93 | tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1648827245914/work 94 | tqdm==4.66.1 95 | traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1710254411456/work 96 | transformers==4.37.2 97 | triton==2.2.0 98 | typing_extensions @ file:///croot/typing_extensions_1705599297034/work 99 | tzdata @ file:///croot/python-tzdata_1690578112552/work 100 | urllib3 @ file:///croot/urllib3_1698257533958/work 101 | wcwidth @ 
file:///home/conda/feedstock_root/build_artifacts/wcwidth_1704731205417/work 102 | xxhash==3.4.1 103 | yarl==1.9.4 104 | zipp @ file:///croot/zipp_1704206909481/work 105 | -------------------------------------------------------------------------------- /evaluator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import utils 5 | import global_vars as gv 6 | import my_datasets as md 7 | 8 | 9 | class Evaluator(nn.Module): 10 | 11 | def __init__(self, dataset, batch_size): 12 | super().__init__() 13 | self.dataset = dataset 14 | self.batch_size = batch_size 15 | 16 | def evaluate(self, model_wrapper, tokenizer, demonstration='', use_cache=False): 17 | 18 | return self._evaluate_text_classification_batch(model_wrapper, tokenizer, 19 | demonstration, use_cache=use_cache) 20 | 21 | def _evaluate_text_classification_batch(self, model_wrapper, tokenizer, 22 | demonstration, use_cache=False): 23 | 24 | model = model_wrapper.model 25 | # prepare label dict 26 | label_map = {} 27 | ans_txt_list = self.dataset.get_dmonstration_template()['options'] 28 | for label, ans_txt in enumerate(ans_txt_list): 29 | if 'gpt' in tokenizer.__class__.__name__.lower(): 30 | ans_txt = ' ' + ans_txt # add space to the beginning of answer 31 | ans_tok = tokenizer.encode(ans_txt, add_special_tokens=False)[0] # use the first token if more than one token 32 | print(f"ans_txt: {ans_txt}, ans_tok: {ans_tok}") 33 | label_map[ans_tok] = label # index is the label 34 | print(f"label_map: {label_map}") 35 | 36 | # prepare all data 37 | all_pred_labels = [] 38 | all_inputs, all_labels = [], [] 39 | for data in self.dataset.all_data: 40 | ques_str, _, label = self.dataset.apply_template(data) 41 | if use_cache or len(demonstration) == 0: 42 | context = ques_str 43 | else: 44 | context = demonstration + ques_str 45 | all_inputs.append(context) 46 | all_labels.append(label) 47 | 48 | # cache the demonstration 49 | if len(demonstration) > 0 and use_cache: 50 | demon_token = tokenizer(demonstration, return_tensors="pt", padding=True).to(model.device) 51 | with torch.no_grad(): 52 | demon_outputs = model(**demon_token, use_cache=True) 53 | demon_past_key_values = demon_outputs.past_key_values 54 | demon_attn_mask = demon_token['attention_mask'] 55 | demon_past_key_values = tuple(tuple(t.repeat(self.batch_size, 1, 1, 1) for 56 | t in tup) for tup in demon_past_key_values) 57 | demon_attn_mask = demon_attn_mask.repeat(self.batch_size, 1) 58 | if len(all_inputs) % self.batch_size != 0: # last batch 59 | sp_demon_past_key_values = tuple(tuple(t.repeat(len(all_inputs) % self.batch_size, 1, 1, 1) 60 | for t in tup) for tup in demon_outputs.past_key_values) 61 | sp_demon_attn_mask = demon_attn_mask[-(len(all_inputs) % self.batch_size):] 62 | use_cache = True 63 | else: 64 | demon_past_key_values = None 65 | sp_demon_past_key_values = None 66 | sp_demon_attn_mask = None 67 | use_cache = False 68 | 69 | # loop over all data 70 | with torch.no_grad(): 71 | for i in range(0, len(all_inputs), self.batch_size): 72 | cur_inputs = all_inputs[i:i+self.batch_size] 73 | # accommodate for the last batch 74 | if len(cur_inputs) != self.batch_size: 75 | demon_past_key_values = sp_demon_past_key_values 76 | demon_attn_mask = sp_demon_attn_mask 77 | input_tok = tokenizer(cur_inputs, return_tensors="pt", padding=True) 78 | input_ids = input_tok['input_ids'].to(model.device) 79 | attn_mask = input_tok['attention_mask'].to(model.device) 80 | # 
get index for prediction logits, need to be applied before concatenating demon_attn_mask with attn_mask 81 | pred_loc = utils.last_one_indices(attn_mask).to(model.device) 82 | # set global variables 83 | gv.ATTN_MASK_START = torch.zeros_like(pred_loc) 84 | gv.ATTN_MASK_END = pred_loc 85 | if use_cache: 86 | attn_mask = torch.cat([demon_attn_mask, attn_mask], dim=1) 87 | output = model(input_ids=input_ids, attention_mask=attn_mask, 88 | past_key_values=demon_past_key_values, use_cache=use_cache) 89 | else: 90 | output = model(input_ids=input_ids, attention_mask=attn_mask) 91 | logits = output.logits 92 | 93 | # get prediction logits 94 | pred_logits = logits[torch.arange(logits.size(0)), pred_loc] 95 | # get prediction labels 96 | interest_index = list(label_map.keys()) 97 | pred_logits = pred_logits[:, interest_index] 98 | probs = F.softmax(pred_logits, dim=-1) 99 | pred_labels = probs.argmax(dim=-1) 100 | # save results 101 | all_pred_labels.extend(pred_labels.cpu().numpy().tolist()) 102 | 103 | assert len(all_pred_labels) == len(all_labels) 104 | # both all_results and all_labels are list containing label index, can you help me to calculate accuracy and macro f1? 105 | # initialize TP, FP, FN 106 | acc = [] 107 | num_classes = self.dataset.class_num 108 | TP = [0] * num_classes 109 | FP = [0] * num_classes 110 | FN = [0] * num_classes 111 | for i, true_label in enumerate(all_labels): 112 | pred_label = all_pred_labels[i] 113 | pred = pred_label == true_label 114 | acc.append(pred) 115 | # Update TP, FP, FN 116 | if pred: 117 | TP[true_label] += 1 118 | else: 119 | FP[pred_label] += 1 120 | FN[true_label] += 1 121 | # Calculate precision, recall, F1 for each class and macro F1 122 | precision = [0] * num_classes 123 | recall = [0] * num_classes 124 | f1 = [0] * num_classes 125 | for i in range(num_classes): 126 | precision[i] = TP[i] / (TP[i] + FP[i]) if (TP[i] + FP[i]) > 0 else 0 127 | recall[i] = TP[i] / (TP[i] + FN[i]) if (TP[i] + FN[i]) > 0 else 0 128 | f1[i] = 2 * (precision[i] * recall[i]) / (precision[i] + recall[i]) if (precision[i] + recall[i]) > 0 else 0 129 | macro_f1 = sum(f1) / num_classes 130 | acc = sum(acc) / len(acc) 131 | return {'acc': acc, 'macro_f1': macro_f1} -------------------------------------------------------------------------------- /run_soft_prompt.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import json 3 | import copy 4 | import time 5 | import random 6 | import argparse 7 | import itertools 8 | import torch 9 | from multiprocessing import Process, Queue 10 | 11 | import utils 12 | import evaluator as ev 13 | import my_datasets as md 14 | 15 | 16 | def main(args): 17 | # set global seed 18 | utils.set_seed(args.config['seed']) 19 | # set device 20 | args.device = utils.set_device(args.gpu) 21 | # set comprare metric 22 | args.metric = args.config['metric'] 23 | # get save dir 24 | utils.init_exp_path(args, args.config['exp_name']) 25 | 26 | # load tokenizer and model 27 | model, tokenizer, model_config = \ 28 | utils.load_model_tokenizer(args.model_name, args.device) 29 | 30 | # get model_wrapper 31 | model_wrapper = utils.get_model_wrapper(args.model_name, model, 32 | tokenizer, model_config, 33 | args.device) 34 | # load datasets 35 | train_dataset = md.get_dataset(args.dataset_name, split='train', max_data_num=None) 36 | test_dataset = md.get_dataset(args.dataset_name, split='test', 37 | max_data_num=args.config['test_data_num'], 38 | sample_mode=args.config['sample_method']) 39 | 40 | # get max 
demonstration token length for each dataset 41 | args.test_max_token = test_dataset.get_max_demonstration_token_length(tokenizer) 42 | 43 | # get shot_num 44 | if args.dataset_name == 'dbpedia': # always use 1-shot for dbpedia 45 | args.config['shot_per_class'] = 1 46 | args.config['bs'] = 1 47 | args.shot_num = utils.get_shot_num(train_dataset, args.config['shot_per_class']) 48 | # build evaluator 49 | test_evaluator = ev.Evaluator(test_dataset, batch_size=args.config['bs']) 50 | 51 | # init result_dict 52 | result_dict = {'demon': {}, 53 | 'test_result': {'zero_shot': [], 'few_shot': [], 'ours': []}, 54 | 'time': {'calibrate': [], 'evaluate': []}, 55 | } 56 | 57 | for run_id in range(args.config['run_num']): 58 | run_name = f'run_{run_id}' 59 | args.run_name = run_name 60 | print(f'Run time {run_name}') 61 | run_seed = args.config['seed'] + run_id 62 | utils.set_seed(run_seed) 63 | 64 | # zero-shot baseline 65 | if run_id == 0 and args.config['run_baseline']: 66 | test_zeroshot_result = test_evaluator.evaluate(model_wrapper, tokenizer, demonstration='', 67 | use_cache=args.config['use_cache']) 68 | result_dict['test_result']['zero_shot'].append(test_zeroshot_result) 69 | print(f'Test zero-shot result: {test_zeroshot_result}\n') 70 | 71 | # sample demonstration 72 | demon, _, demon_data_index = \ 73 | train_dataset.gen_few_shot_demonstration(tokenizer=tokenizer, shot_num=args.shot_num, 74 | max_demonstration_tok_len=args.test_max_token, 75 | add_extra_query=args.config['add_extra_query'], 76 | example_separator=args.config['example_separator'], 77 | return_data_index=True, seed=random.randint(0, 1e6)) 78 | 79 | if args.config['add_extra_query']: 80 | first_format_anchor = train_dataset.get_dmonstration_template()['format'][0] 81 | # remove all contents after the last first_format_anchor including the anchor 82 | if first_format_anchor in demon: 83 | baseline_demon = demon[:demon.rfind(first_format_anchor)] 84 | query_demon = demon[demon.rfind(first_format_anchor):] 85 | else: 86 | baseline_demon = demon 87 | query_demon = None 88 | print(f'Demonstration:\n{demon}\n') 89 | print(f'Baseline demonstration:\n{baseline_demon}\n') 90 | print(f'Query demonstration:\n{query_demon}\n') 91 | 92 | # few-shot baseline 93 | if args.config['run_baseline']: 94 | test_fewshot_result = test_evaluator.evaluate(model_wrapper, tokenizer, 95 | demonstration=baseline_demon, 96 | use_cache=args.config['use_cache']) 97 | result_dict['test_result']['few_shot'].append(test_fewshot_result) 98 | print(f'Test few-shot result: {test_fewshot_result}\n') 99 | 100 | # generate demon_list 101 | demon_list = [demon] 102 | # save demon_list 103 | result_dict['demon'][run_name] = demon_list 104 | 105 | # prepare peft_train_dataset 106 | cali_dataset = copy.deepcopy(train_dataset) 107 | cali_dataset.all_data = [train_dataset.all_data[i] for i in demon_data_index] 108 | 109 | # train softprompt 110 | s_t = time.time() 111 | model_wrapper.softprompt(args.config, cali_dataset, save_dir=args.save_dir, run_name=run_name) 112 | e_t = time.time() 113 | print(f'Calibration time: {e_t - s_t}') 114 | result_dict['time']['calibrate'].append(e_t - s_t) 115 | 116 | # evaluate softprompt 117 | s_t = time.time() 118 | test_ours_result = test_evaluator.evaluate(model_wrapper, tokenizer, demonstration='', 119 | use_cache=args.config['use_cache']) 120 | print(f'Test Soft Prompt result: {test_ours_result}\n') 121 | result_dict['test_result']['ours'].append(test_ours_result) 122 | e_t = time.time() 123 | print(f'Evaluate time: {e_t - s_t}') 124 
| result_dict['time']['evaluate'].append(e_t - s_t) 125 | 126 | # save result_dict after each run 127 | with open(args.save_dir + '/result_dict.json', 'w') as f: 128 | json.dump(result_dict, f, indent=4) 129 | 130 | # reset model_wrapper to unadapted model 131 | del model_wrapper, model, tokenizer, model_config 132 | model, tokenizer, model_config = \ 133 | utils.load_model_tokenizer(args.model_name, args.device, 134 | load_in_8bit=args.config['load_in_8bit'], 135 | output_hidden_states=False) 136 | # get model_wrapper 137 | model_wrapper = utils.get_model_wrapper(args.model_name, model, 138 | tokenizer, model_config, 139 | args.device) 140 | 141 | # delete all variables 142 | del model_wrapper, model, tokenizer, train_dataset, test_dataset, cali_dataset 143 | del test_evaluator, result_dict, demon_list 144 | 145 | 146 | # get args 147 | def get_args(): 148 | parser = argparse.ArgumentParser() 149 | parser.add_argument('--config_path', type=str, default='configs/config_soft_prompt.py', help='path to config file') 150 | return parser.parse_args() 151 | 152 | 153 | if __name__ == "__main__": 154 | # get args 155 | args = get_args() 156 | # load config 157 | config = utils.load_config(args.config_path) 158 | # Generate all combinations of models and datasets 159 | combinations = list(itertools.product(config['models'], config['datasets'])) 160 | # Queue to hold tasks 161 | task_queue = Queue() 162 | for combine in combinations: 163 | task_queue.put(combine) 164 | 165 | def run_task(gpu_id, config): 166 | while not task_queue.empty(): 167 | model_name, dataset_name = task_queue.get() 168 | print(f"Running {model_name} on {dataset_name} with GPU {gpu_id}") 169 | input_args = argparse.Namespace() 170 | cur_config = copy.deepcopy(config) 171 | input_args.model_name = model_name 172 | input_args.dataset_name = dataset_name 173 | input_args.gpu = gpu_id 174 | input_args.config = cur_config 175 | try: 176 | main(input_args) 177 | finally: 178 | # Clean up CUDA memory after each task 179 | gc.collect() 180 | torch.cuda.empty_cache() 181 | print(f"CUDA memory cleared for GPU {gpu_id}") 182 | time.sleep(5) 183 | 184 | # Create a process for each GPU 185 | processes = [Process(target=run_task, args=(gpu_id, config)) for gpu_id in config['gpus']] 186 | # Start all processes 187 | for p in processes: 188 | p.start() 189 | # Wait for all processes to finish 190 | for p in processes: 191 | p.join() 192 | print("All tasks completed.") -------------------------------------------------------------------------------- /my_datasets/basetask.py: -------------------------------------------------------------------------------- 1 | import random 2 | import string 3 | from itertools import zip_longest 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class BaseTask(Dataset): 8 | def __init__(self, task_name, sample_mode='random', max_data_num=None, use_instruction=False, seed=0): 9 | super().__init__() 10 | self.task_name = task_name 11 | self.sample_mode = sample_mode 12 | self.max_data_num = max_data_num # maximum number of loaded data 13 | self.use_instruction = use_instruction # whether add task instruction as prefix 14 | self.task_type = None 15 | self.dataset = None # dataset 16 | self.all_data = None # all data 17 | self.all_labels = None # all labels 18 | self.seed = seed 19 | 20 | def random_sample_data(self, max_data_num, seed=0): 21 | """ 22 | This function is used to randomly sample data from the dataset. 
23 | """ 24 | # set random seed 25 | random.seed(self.seed) 26 | assert self.all_data is not None, "Please load data first!" 27 | if max_data_num < len(self.all_data): 28 | if (self.all_labels is None) or (self.sample_mode == 'random'): 29 | # random sample data 30 | self.all_data = random.sample(self.all_data, max_data_num) 31 | else: 32 | # sample data uniformly from each class, if not possible, sample up to max_data_num 33 | label_dict = {label: [] for label in list(set(self.all_labels))} 34 | for index, label in enumerate(self.all_labels): 35 | label_dict[label].append(self.all_data[index]) 36 | # print key and number of data with current key 37 | print("Number of data in each class:") 38 | for label, label_list in label_dict.items(): 39 | print(f"{label}: {len(label_list)}") 40 | new_all_data = [] 41 | 42 | key_num = len(label_dict.keys()) 43 | for label in label_dict: 44 | # if number of data with current label is smaller than max_data_num / class_num, use all data and distibute remaining quota evenly to rest classes 45 | if len(label_dict[label]) < max_data_num // key_num: 46 | new_all_data.extend(label_dict[label]) 47 | else: 48 | # random sample data from current label 49 | new_all_data.extend(random.sample(label_dict[label], max_data_num // key_num)) 50 | # if length of new_all_data is smaller than max_data_num, randomly sample data from all data but check if the data is already in new_all_data 51 | while len(new_all_data) < max_data_num: 52 | tem_data = random.choice(self.all_data) 53 | if tem_data not in new_all_data: 54 | new_all_data.append(tem_data) 55 | self.all_data = new_all_data 56 | else: 57 | print(f"Warning: max_data_num {max_data_num} is larger than the dataset size {len(self.all_data)}!, use all data instead.") 58 | 59 | 60 | def get_max_demonstration_token_length(self, tokenizer): 61 | """ 62 | This function is used to get the maximum token length of the example in the dataset. 63 | This is mainly used for test dataset to determine the maximum number of the demonstration 64 | tokens that can be used due to the limit of context window. 65 | """ 66 | all_data_toks = [] 67 | for data in self.all_data: 68 | input_str, ans_str, _ = self.apply_template(data) 69 | if self.use_instruction: 70 | instruct = self.get_task_instruction() 71 | instruct = instruct + '\n' 72 | else: 73 | instruct = "" 74 | demonstration_str = [(instruct + input_str + ' ' + ans_str[i]) for i in range(len(ans_str))] 75 | all_data_toks.extend(demonstration_str) 76 | # get maximum token length of the example in the dataset 77 | encoded_inputs = tokenizer.batch_encode_plus(all_data_toks, return_tensors="pt", padding=True, truncation=True) 78 | single_data_max_len = encoded_inputs['attention_mask'].sum(dim=1).max().item() 79 | if 'Llama' in tokenizer.name_or_path: 80 | if 'Llama-2' in tokenizer.name_or_path: 81 | cxt_max_len = 4096 82 | else: 83 | cxt_max_len = 2048 84 | else: 85 | cxt_max_len = tokenizer.model_max_length 86 | max_demonstration_len = cxt_max_len - single_data_max_len 87 | print(f"Max demonstration token length is : {cxt_max_len} - {single_data_max_len} = {max_demonstration_len}",) 88 | return max_demonstration_len 89 | 90 | 91 | def gen_few_shot_demonstration(self, tokenizer, shot_num, max_demonstration_tok_len=1e6, 92 | add_extra_query=False, example_separator='\n', 93 | return_data_index=False, gen_example_method='normal', seed=0): 94 | 95 | """ 96 | This function is used to generate few-shot demonstration. 
97 | """ 98 | # set random seed 99 | random.seed(seed) 100 | 101 | assert self.all_data is not None, "Please load data first!" 102 | assert shot_num <= len(self.all_data), "Shot number should be smaller than the number of data!" 103 | if hasattr(self, 'class_num') and self.class_num is not None: # if class number is provided 104 | assert shot_num == -1 or shot_num == 0 or shot_num >= self.class_num, "Shot number should be at least larger than the number of classes!" 105 | class_num = self.class_num 106 | # get label dict 107 | label_dict = {label: [] for label in range(class_num)} 108 | for index, data in enumerate(self.all_data): 109 | label_dict[self.apply_template(data)[-1]].append(index) 110 | else: 111 | class_num = None 112 | 113 | # get task instruction 114 | instruct = self.get_task_instruction() if self.use_instruction else "" 115 | if len(instruct) > 0: 116 | demonstration_expample_list = [instruct] 117 | demonstration = instruct + example_separator 118 | else: 119 | demonstration_expample_list = [] 120 | demonstration = "" 121 | 122 | if class_num is None: # random sample data 123 | sample_indexes = random.sample(range(len(self.all_data)), shot_num) 124 | else: # uniform sample data from each class 125 | sample_indexes = [] 126 | # split sample number into each class equally, if not possible, sample as many as possible 127 | for label in label_dict: 128 | sample_indexes.extend(random.sample(label_dict[label], shot_num // class_num)) 129 | # random shuffle the sample indexes 130 | random.shuffle(sample_indexes) 131 | 132 | for index in sample_indexes: 133 | input_str, ans_str, label = self.apply_template(self.all_data[index]) 134 | ans = ans_str[label] 135 | new_example = input_str + ' ' + ans 136 | demonstration = demonstration + new_example + example_separator 137 | # collect single demonstration example 138 | if gen_example_method == 'normal': 139 | single_example = new_example 140 | elif gen_example_method == 'random_label': 141 | single_example = input_str + ' ' + random.choice(ans_str) 142 | elif gen_example_method == 'no_template': 143 | single_example = input_str.split(":")[1].split("\n")[0] + ' ' + ans 144 | elif gen_example_method == 'random_order': 145 | # random change the order of each word in the input string 146 | single_example = new_example 147 | words = single_example.split() 148 | random.shuffle(words) 149 | single_example = ' '.join(words) 150 | else: 151 | raise ValueError("Unknown demonstration example generation method!") 152 | single_example = single_example + example_separator 153 | demonstration_expample_list.append(single_example) 154 | 155 | if add_extra_query: # add a random query at the end of demonstration 156 | extra_qeury, _, _ = self.apply_template(self.all_data[random.randint(0, len(self.all_data))]) 157 | demonstration += extra_qeury 158 | demonstration_expample_list.append(extra_qeury) 159 | 160 | # check length of demonstration token 161 | encoded_inputs = tokenizer(demonstration, return_tensors="pt", padding=True, truncation=False) 162 | assert len(encoded_inputs['input_ids'][0]) < max_demonstration_tok_len, "Demonstration token length should be smaller than the maximum demonstration token length!" 
163 | print(f"Generated {shot_num}-shot demonstration.") 164 | 165 | if return_data_index: 166 | return demonstration, demonstration_expample_list, sample_indexes 167 | else: 168 | return demonstration, demonstration_expample_list 169 | 170 | 171 | def get_dmonstration_template(self): 172 | """ 173 | This function is used to provide template for demonstration, need to be implemented for each task. 174 | """ 175 | raise NotImplementedError("Please provide the template for demonstration!") 176 | 177 | def get_task_instruction(self): 178 | """ 179 | This function is used to provide task instruction, need to be implemented for each task. 180 | """ 181 | raise NotImplementedError("Please provide the task instruction!") 182 | 183 | def apply_template(self, data): 184 | """ 185 | This function is used to apply template to a given data, need to be implemented for each task. 186 | """ 187 | raise NotImplementedError("Please provide how to apply template!") 188 | 189 | def print_data(self, indices): 190 | """ 191 | This function is used to print data given indices. 192 | """ 193 | if isinstance(indices, int): 194 | indices = [indices] 195 | for index in indices: 196 | print(self.all_data[index]) 197 | 198 | def __len__(self): 199 | return len(self.all_data) 200 | 201 | def __getitem__(self, index): 202 | return self.all_data[index] -------------------------------------------------------------------------------- /run_task_vector.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import json 3 | import copy 4 | import random 5 | import time 6 | import argparse 7 | import itertools 8 | import torch 9 | from multiprocessing import Process, Queue 10 | 11 | import utils 12 | import my_datasets as md 13 | import evaluator as ev 14 | 15 | 16 | def target_layer_selection(args, model_wrapper, tokenizer, evaluator, context_vector_dict): 17 | num_layers = model_wrapper.num_layers 18 | with torch.no_grad(): 19 | best_layer = 0 20 | best_result = 0 21 | for layer in range(num_layers): 22 | with model_wrapper.replace_latent(context_vector_dict, [layer], args.config): 23 | val_result = evaluator.evaluate(model_wrapper, tokenizer, demonstration='', 24 | use_cache=args.config['use_cache']) 25 | print(f'Layer {layer} result: {val_result}\n') 26 | if val_result[args.metric] > best_result: 27 | best_result = val_result[args.metric] 28 | best_layer = layer 29 | print(f'Best layer: {best_layer}') 30 | return best_layer 31 | 32 | 33 | def main(args): 34 | # set global seed 35 | utils.set_seed(args.config['seed']) 36 | # set device 37 | args.device = utils.set_device(args.gpu) 38 | # set metric used 39 | args.metric = args.config['metric'] 40 | # get save dir 41 | utils.init_exp_path(args, args.config['exp_name']) 42 | 43 | # load tokenizer and model 44 | model, tokenizer, model_config = \ 45 | utils.load_model_tokenizer(args.model_name, args.device) 46 | 47 | # get model_wrapper 48 | model_wrapper = utils.get_model_wrapper(args.model_name, model, 49 | tokenizer, model_config, 50 | args.device) 51 | # load datasets 52 | train_dataset = md.get_dataset(args.dataset_name, split='train', max_data_num=None) 53 | val_dataset = md.get_dataset(args.dataset_name, split='validation', 54 | max_data_num=args.config['val_data_num'], 55 | sample_mode=args.config['sample_method']) 56 | test_dataset = md.get_dataset(args.dataset_name, split='test', 57 | max_data_num=args.config['test_data_num'], 58 | sample_mode=args.config['sample_method']) 59 | 60 | # get max demonstration token length 
for each dataset 61 | args.val_max_token = val_dataset.get_max_demonstration_token_length(tokenizer) 62 | args.test_max_token = test_dataset.get_max_demonstration_token_length(tokenizer) 63 | 64 | # get shot_num 65 | if args.dataset_name == 'dbpedia': # always use 1-shot for dbpedia 66 | args.config['shot_per_class'] = 1 67 | args.config['bs'] = 1 68 | args.shot_num = utils.get_shot_num(train_dataset, args.config['shot_per_class']) 69 | # build evaluate 70 | val_evaluator = ev.Evaluator(val_dataset, batch_size=args.config['bs']) 71 | test_evaluator = ev.Evaluator(test_dataset, batch_size=args.config['bs']) 72 | # init result_dict 73 | result_dict = {'demon': {}, 74 | 'split_demon': {}, 75 | 'best_replace_layer': {}, 76 | 'test_result': {'zero_shot': [], 'few_shot': [], 'ours': []}, 77 | 'val_result': {'zero_shot': [], 'few_shot': [], 'ours': []}, 78 | 'time': {'calibrate': [], 'evaluate': []}, 79 | } 80 | 81 | for run_id in range(args.config['run_num']): 82 | run_name = f'run_{run_id}' 83 | args.run_name = run_name 84 | print(f'Run time {run_name}') 85 | run_seed = args.config['seed'] + run_id 86 | utils.set_seed(run_seed) 87 | 88 | # zero-shot baseline 89 | if run_id == 0 and args.config['run_baseline']: 90 | val_zeroshot_result = val_evaluator.evaluate(model_wrapper, tokenizer, demonstration='', 91 | use_cache=args.config['use_cache']) 92 | test_zeroshot_result = test_evaluator.evaluate(model_wrapper, tokenizer, demonstration='', 93 | use_cache=args.config['use_cache']) 94 | result_dict['val_result']['zero_shot'].append(val_zeroshot_result) 95 | result_dict['test_result']['zero_shot'].append(test_zeroshot_result) 96 | print(f'Validation zero-shot result: {val_zeroshot_result}\n') 97 | print(f'Test zero-shot result: {test_zeroshot_result}\n') 98 | 99 | # sample demonstration 100 | demon, _, _ = \ 101 | train_dataset.gen_few_shot_demonstration(tokenizer=tokenizer, shot_num=args.shot_num, 102 | max_demonstration_tok_len=min(args.val_max_token, 103 | args.test_max_token), 104 | add_extra_query=args.config['add_extra_query'], 105 | example_separator=args.config['example_separator'], 106 | return_data_index=True, seed=random.randint(0, 1e6) 107 | ) 108 | 109 | if args.config['add_extra_query']: 110 | first_format_anchor = train_dataset.get_dmonstration_template()['format'][0] 111 | # remove all contents after the last first_format_anchor including the anchor 112 | if first_format_anchor in demon: 113 | baseline_demon = demon[:demon.rfind(first_format_anchor)] 114 | query_demon = demon[demon.rfind(first_format_anchor):] 115 | else: 116 | baseline_demon = demon 117 | query_demon = None 118 | print(f'Demonstration:\n{demon}\n') 119 | print(f'Baseline demonstration:\n{baseline_demon}\n') 120 | print(f'Query demonstration:\n{query_demon}\n') 121 | 122 | # few-shot baseline 123 | if args.config['run_baseline']: 124 | val_fewshot_result = val_evaluator.evaluate(model_wrapper, tokenizer, 125 | demonstration=baseline_demon, 126 | use_cache=args.config['use_cache']) 127 | test_fewshot_result = test_evaluator.evaluate(model_wrapper, tokenizer, 128 | demonstration=baseline_demon, 129 | use_cache=args.config['use_cache']) 130 | result_dict['val_result']['few_shot'].append(val_fewshot_result) 131 | result_dict['test_result']['few_shot'].append(test_fewshot_result) 132 | print(f'Validation few-shot result: {val_fewshot_result}\n') 133 | print(f'Test few-shot result: {test_fewshot_result}\n') 134 | 135 | # extract latents ====================================================================== 136 | 
all_latent_dicts = [] 137 | with torch.no_grad(): 138 | with model_wrapper.extract_latent(): 139 | demon_token = tokenizer(demon, return_tensors='pt').to(args.device) 140 | _ = model(**demon_token) 141 | all_latent_dicts.append(model_wrapper.latent_dict) 142 | 143 | # generate context vector ============================================================== 144 | context_vector_dict = model_wrapper.get_context_vector(all_latent_dicts, args.config) 145 | del all_latent_dicts 146 | 147 | # injection layer selection ============================================================ 148 | best_replace_layer = target_layer_selection(args, model_wrapper, tokenizer, 149 | val_evaluator, context_vector_dict) 150 | result_dict['best_replace_layer'][run_name] = best_replace_layer 151 | 152 | # evaluate task_vector ======================================================================== 153 | s_t = time.time() 154 | with model_wrapper.replace_latent(context_vector_dict, [best_replace_layer], args.config): 155 | val_ours_result = val_evaluator.evaluate(model_wrapper, tokenizer, demonstration='', 156 | use_cache=args.config['use_cache']) 157 | print(f'Validation task_vector result: {val_ours_result}\n') 158 | result_dict['val_result']['ours'].append(val_ours_result) 159 | 160 | test_ours_result = test_evaluator.evaluate(model_wrapper, tokenizer, demonstration='', 161 | use_cache=args.config['use_cache']) 162 | print(f'Test task_vector result: {test_ours_result}\n') 163 | result_dict['test_result']['ours'].append(test_ours_result) 164 | e_t = time.time() 165 | 166 | print(f'Evaluate time: {e_t - s_t}') 167 | result_dict['time']['evaluate'].append(e_t - s_t) 168 | 169 | # save result_dict after each run 170 | with open(args.save_dir + '/result_dict.json', 'w') as f: 171 | json.dump(result_dict, f, indent=4) 172 | 173 | # delete all variables 174 | del model_wrapper, model, tokenizer, train_dataset, val_dataset, test_dataset 175 | del val_evaluator, test_evaluator, result_dict, context_vector_dict, 176 | 177 | 178 | # get args 179 | def get_args(): 180 | parser = argparse.ArgumentParser() 181 | parser.add_argument('--config_path', type=str, default='configs/config_task_vector.py', help='path to config file') 182 | return parser.parse_args() 183 | 184 | 185 | if __name__ == "__main__": 186 | # get args 187 | args = get_args() 188 | # load config 189 | config = utils.load_config(args.config_path) 190 | # Generate all combinations of models and datasets 191 | combinations = list(itertools.product(config['models'], config['datasets'])) 192 | # Queue to hold tasks 193 | task_queue = Queue() 194 | for combine in combinations: 195 | task_queue.put(combine) 196 | 197 | def run_task(gpu_id, config): 198 | while not task_queue.empty(): 199 | model_name, dataset_name = task_queue.get() 200 | print(f"Running {model_name} on {dataset_name} with GPU {gpu_id}") 201 | input_args = argparse.Namespace() 202 | cur_config = copy.deepcopy(config) 203 | input_args.model_name = model_name 204 | input_args.dataset_name = dataset_name 205 | input_args.gpu = gpu_id 206 | input_args.config = cur_config 207 | try: 208 | main(input_args) 209 | finally: 210 | # Clean up CUDA memory after each task 211 | gc.collect() 212 | torch.cuda.empty_cache() 213 | print(f"CUDA memory cleared for GPU {gpu_id}") 214 | time.sleep(5) 215 | 216 | # Create a process for each GPU 217 | processes = [Process(target=run_task, args=(gpu_id, config)) for gpu_id in config['gpus']] 218 | # Start all processes 219 | for p in processes: 220 | p.start() 221 | # Wait for 
all processes to finish 222 | for p in processes: 223 | p.join() 224 | print("All tasks completed.") -------------------------------------------------------------------------------- /run_i2cl.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import json 3 | import copy 4 | import time 5 | import random 6 | import argparse 7 | import itertools 8 | import torch 9 | import numpy as np 10 | from multiprocessing import Process, Queue 11 | 12 | import utils 13 | import my_datasets as md 14 | import evaluator as ev 15 | 16 | 17 | def main(args): 18 | # set global seed 19 | utils.set_seed(args.config['seed']) 20 | # set device 21 | args.device = utils.set_device(args.gpu) 22 | # set metric used 23 | args.metric = args.config['metric'] 24 | # get save dir 25 | utils.init_exp_path(args, args.config['exp_name']) 26 | 27 | # load tokenizer and model 28 | model, tokenizer, model_config = \ 29 | utils.load_model_tokenizer(args.model_name, args.device) 30 | 31 | # get model_wrapper 32 | model_wrapper = utils.get_model_wrapper(args.model_name, model, 33 | tokenizer, model_config, 34 | args.device) 35 | 36 | # load datasets 37 | train_dataset = md.get_dataset(args.dataset_name, split='train', 38 | max_data_num=None, seed=args.config['seed']) 39 | holdout_dataset = md.get_dataset(args.dataset_name, split='validation', 40 | max_data_num=args.config['val_data_num'], 41 | sample_mode=args.config['sample_method'], 42 | seed=args.config['seed']) 43 | test_dataset = md.get_dataset(args.dataset_name, split='test', 44 | max_data_num=args.config['test_data_num'], 45 | sample_mode=args.config['sample_method'], 46 | seed=args.config['seed']) 47 | 48 | # get max demonstration token length for each dataset 49 | if args.config['split_demon']: 50 | args.test_max_token = 1e8 51 | else: 52 | args.test_max_token = test_dataset.get_max_demonstration_token_length(tokenizer) 53 | 54 | # get shot_num 55 | if args.dataset_name == 'dbpedia': # always use 1-shot for dbpedia 56 | args.config['shot_per_class'] = 1 57 | args.config['bs'] = 1 58 | args.shot_num = utils.get_shot_num(train_dataset, args.config['shot_per_class']) 59 | 60 | # build evaluators 61 | test_evaluator = ev.Evaluator(test_dataset, batch_size=args.config['bs']) 62 | holdout_evaluator = ev.Evaluator(holdout_dataset, batch_size=args.config['bs']) 63 | # init result_dict 64 | result_dict = {'demon': {}, 65 | 'split_demon': {}, 66 | 'test_result': {'zero_shot': [], 'few_shot': [], 'ours': []}, 67 | 'linear_coef': {}, 68 | 'time': {'calibrate': [], 'evaluate': []}, 69 | } 70 | cv_save_dict = {} 71 | 72 | for run_id in range(args.config['run_num']): 73 | run_name = f'run_{run_id}' 74 | args.run_name = run_name 75 | print(f'Run time {run_name}') 76 | run_seed = args.config['seed'] + run_id 77 | utils.set_seed(run_seed) 78 | 79 | # zero-shot baseline 80 | if run_id == 0 and args.config['run_baseline']: 81 | test_zeroshot_result = test_evaluator.evaluate(model_wrapper, tokenizer, demonstration='', 82 | use_cache=args.config['use_cache']) 83 | result_dict['test_result']['zero_shot'].append(test_zeroshot_result) 84 | print(f'Test zero-shot result: {test_zeroshot_result}\n') 85 | 86 | # sample demonstration 87 | count = 0 88 | temp_demon_list, temp_result_list = [], [] 89 | while True: 90 | demon, split_demon, demon_data_index = \ 91 | train_dataset.gen_few_shot_demonstration(tokenizer=tokenizer, shot_num=args.shot_num, 92 | max_demonstration_tok_len=args.test_max_token, 93 | add_extra_query=args.config['add_extra_query'], 
94 | example_separator=args.config['example_separator'], 95 | gen_example_method = args.config['gen_example_method'], 96 | return_data_index=True, seed=random.randint(0, 1e6)) 97 | temp_demon_list.append((demon, split_demon, demon_data_index)) 98 | 99 | if args.config['demo_sample_method'] == 'random': 100 | break 101 | else: 102 | tem_val_result = holdout_evaluator.evaluate(model_wrapper, tokenizer, 103 | demonstration=demon, 104 | use_cache=args.config['use_cache']) 105 | temp_result = tem_val_result[args.metric] 106 | temp_result_list.append(temp_result) 107 | if count > 20: 108 | if args.config['demo_sample_method'] == 'deficient': 109 | demon, split_demon, demon_data_index = temp_demon_list[np.argmin(temp_result_list)] 110 | else: 111 | raise ValueError('Invalid demo_sample_method!') 112 | break 113 | count += 1 114 | 115 | # build val_evaluator use demon_data_index 116 | cali_dataset = copy.deepcopy(train_dataset) 117 | cali_dataset.all_data = [train_dataset.all_data[i] for i in demon_data_index] 118 | 119 | # clean demonstration 120 | if args.config['add_extra_query']: 121 | first_format_anchor = train_dataset.get_dmonstration_template()['format'][0] 122 | # remove all contents after the last first_format_anchor including the anchor 123 | if first_format_anchor in demon: 124 | baseline_demon = demon[:demon.rfind(first_format_anchor)] 125 | query_demon = demon[demon.rfind(first_format_anchor):] 126 | else: 127 | baseline_demon = demon 128 | query_demon = None 129 | print(f'Demonstration:\n{demon}\n') 130 | print(f'Baseline demonstration:\n{baseline_demon}\n') 131 | print(f'Query demonstration:\n{query_demon}\n') 132 | 133 | # few-shot baseline 134 | if args.config['run_baseline']: 135 | test_fewshot_result = test_evaluator.evaluate(model_wrapper, tokenizer, 136 | demonstration=baseline_demon, 137 | use_cache=args.config['use_cache']) 138 | result_dict['test_result']['few_shot'].append(test_fewshot_result) 139 | print(f'Test few-shot result: {test_fewshot_result}\n') 140 | 141 | # generate demon_list 142 | demon_list = [demon] 143 | split_demon_list = split_demon 144 | result_dict['demon'][run_name] = demon_list 145 | result_dict['split_demon'][run_name] = split_demon_list 146 | 147 | # init strength_params 148 | model_wrapper.init_strength(args.config) 149 | 150 | # extract latents 151 | all_latent_dicts = [] 152 | with torch.no_grad(): 153 | if not args.config['split_demon']: 154 | target_demon_list = demon_list[0] 155 | else: 156 | target_demon_list = split_demon_list 157 | for cur_demon in target_demon_list: 158 | with model_wrapper.extract_latent(): 159 | demon_token = tokenizer(cur_demon, return_tensors='pt').to(args.device) 160 | _ = model(**demon_token) 161 | all_latent_dicts.append(model_wrapper.latent_dict) 162 | model_wrapper.reset_latent_dict() 163 | 164 | # generate context vector 165 | context_vector_dict = model_wrapper.get_context_vector(all_latent_dicts, args.config) 166 | if args.config['gen_cv_method'] == 'noise': 167 | context_vector_dict = model_wrapper.init_noise_context_vector(context_vector_dict) 168 | del all_latent_dicts 169 | 170 | # calibrate context vector 171 | s_t = time.time() 172 | model_wrapper.calibrate_strength(context_vector_dict, cali_dataset, 173 | args.config, save_dir=args.save_dir, 174 | run_name=args.run_name) 175 | e_t = time.time() 176 | print(f'Calibration time: {e_t - s_t}') 177 | result_dict['time']['calibrate'].append(e_t - s_t) 178 | 179 | # save linear_coef 180 | result_dict['linear_coef'][run_name] = 
model_wrapper.linear_coef.tolist() 181 | 182 | # evaluate i2cl 183 | s_t = time.time() 184 | with torch.no_grad(): 185 | with model_wrapper.inject_latent(context_vector_dict, args.config, 186 | model_wrapper.linear_coef): 187 | test_ours_result = test_evaluator.evaluate(model_wrapper, tokenizer, demonstration='', 188 | use_cache=args.config['use_cache']) 189 | print(f'Test I2CL result: {test_ours_result}\n') 190 | result_dict['test_result']['ours'].append(test_ours_result) 191 | e_t = time.time() 192 | print(f'Evaluate time: {e_t - s_t}') 193 | result_dict['time']['evaluate'].append(e_t - s_t) 194 | 195 | # save result_dict after each run 196 | with open(args.save_dir + '/result_dict.json', 'w') as f: 197 | json.dump(result_dict, f, indent=4) 198 | 199 | # save context vector dict 200 | for layer, subdict in context_vector_dict.items(): 201 | for module, activation in subdict.items(): 202 | context_vector_dict[layer][module] = activation.cpu().numpy().tolist() 203 | cv_save_dict[run_name] = context_vector_dict 204 | 205 | with open(args.save_dir + '/cv_save_dict.json', 'w') as f: 206 | json.dump(cv_save_dict, f, indent=4) 207 | 208 | # delete all variables 209 | del model_wrapper, model, tokenizer, train_dataset, cali_dataset, test_dataset, holdout_dataset 210 | del test_evaluator, holdout_evaluator 211 | del result_dict, context_vector_dict, demon_list 212 | 213 | 214 | # get args 215 | def get_args(): 216 | parser = argparse.ArgumentParser() 217 | parser.add_argument('--config_path', type=str, default='configs/config_i2cl.py', help='path to config file') 218 | return parser.parse_args() 219 | 220 | 221 | if __name__ == "__main__": 222 | # get args 223 | args = get_args() 224 | # load config 225 | config = utils.load_config(args.config_path) 226 | # Generate all combinations of models and datasets 227 | combinations = list(itertools.product(config['models'], config['datasets'])) 228 | # Queue to hold tasks 229 | task_queue = Queue() 230 | for combine in combinations: 231 | task_queue.put(combine) 232 | 233 | def run_task(gpu_id, config): 234 | while not task_queue.empty(): 235 | model_name, dataset_name = task_queue.get() 236 | print(f"Running {model_name} on {dataset_name} with GPU {gpu_id}") 237 | input_args = argparse.Namespace() 238 | cur_config = copy.deepcopy(config) 239 | input_args.model_name = model_name 240 | input_args.dataset_name = dataset_name 241 | input_args.gpu = gpu_id 242 | input_args.config = cur_config 243 | try: 244 | main(input_args) 245 | finally: 246 | # Clean up CUDA memory after each task 247 | gc.collect() 248 | torch.cuda.empty_cache() 249 | print(f"CUDA memory cleared for GPU {gpu_id}") 250 | time.sleep(5) 251 | 252 | # Create a process for each GPU 253 | processes = [Process(target=run_task, args=(gpu_id, config)) for gpu_id in config['gpus']] 254 | # Start all processes 255 | for p in processes: 256 | p.start() 257 | # Wait for all processes to finish 258 | for p in processes: 259 | p.join() 260 | print("All tasks completed.") -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import random 5 | import functools 6 | import warnings 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | from typing import Union, List, Optional 11 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig 12 | import matplotlib.pyplot as plt 13 | import wrapper 14 | import 
my_datasets as md 15 | 16 | 17 | def set_seed(seed): 18 | random.seed(seed) 19 | np.random.seed(seed) 20 | torch.manual_seed(seed) 21 | torch.cuda.manual_seed_all(seed) 22 | torch.backends.cudnn.deterministic = True 23 | 24 | 25 | def set_device(gpu_id): 26 | device = torch.device(f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') 27 | torch.cuda.set_device(device) 28 | return device 29 | 30 | 31 | def init_exp_path(args, exp_name, separate_dataset=True): 32 | if separate_dataset: 33 | save_dir = os.path.join(exp_name, args.model_name, args.dataset_name) 34 | else: 35 | save_dir = os.path.join(exp_name, args.model_name) 36 | args.save_dir = save_dir 37 | if os.path.exists(save_dir) and 'debug' not in exp_name: 38 | raise ValueError(f"Experiment {exp_name} already exists! please delete it or change the name!") 39 | os.makedirs(save_dir, exist_ok=True) 40 | # save config_dict 41 | with open(f'{save_dir}/config.json', 'w') as f: 42 | json.dump(args.config, f, indent=4) 43 | # save args as txt file, I want to make it beautiful 44 | with open(f'{save_dir}/args.txt', 'w') as f: 45 | for key, value in vars(args).items(): 46 | f.write(f'{key}: {value}\n') 47 | 48 | 49 | def load_model_tokenizer(model_name, device, output_hidden_states=True, load_in_8bit=False): 50 | # load tokenizer and model 51 | tokenizer = AutoTokenizer.from_pretrained(model_name) 52 | model = AutoModelForCausalLM.from_pretrained(model_name, 53 | output_hidden_states=output_hidden_states, 54 | load_in_8bit=load_in_8bit, 55 | torch_dtype=torch.float32) 56 | if not load_in_8bit: 57 | model = model.to(device) 58 | config = AutoConfig.from_pretrained(model_name) 59 | tokenizer.pad_token = tokenizer.eos_token 60 | tokenizer.padding_side = 'right' 61 | return model, tokenizer, config 62 | 63 | 64 | def get_model_wrapper(model_name, model, tokenizer, model_config, device): 65 | if 'llama' in model_name: 66 | model_wrapper = wrapper.LlamaWrapper(model, tokenizer, model_config, device) 67 | elif 'gpt' in model_name: 68 | model_wrapper = wrapper.GPTWrapper(model, tokenizer, model_config, device) 69 | else: 70 | raise ValueError("only support llama or gpt!") 71 | return model_wrapper 72 | 73 | 74 | def load_config(file_path): 75 | if not file_path: 76 | raise ValueError("No file path provided") 77 | file_dir = os.path.dirname(file_path) 78 | if file_dir not in sys.path: 79 | sys.path.append(file_dir) 80 | file_name = os.path.basename(file_path) 81 | module_name = os.path.splitext(file_name)[0] 82 | module = __import__(module_name) 83 | try: 84 | my_variable = getattr(module, 'config') 85 | print(my_variable) 86 | return my_variable 87 | except AttributeError: 88 | print(f"The module does not have a variable named 'config'") 89 | 90 | 91 | def get_shot_num(dataset, shot_per_class, shot_num=5): 92 | if hasattr(dataset, 'class_num') and dataset.class_num is not None: 93 | shot_num = dataset.class_num * shot_per_class 94 | else: 95 | shot_num = shot_num 96 | # if shot_num < 0, then use all data 97 | if shot_num < 0: 98 | shot_num = -1 99 | return shot_num 100 | 101 | 102 | def first_one_indices(tensor): 103 | """ 104 | Finds the index of the first 1 in each row of a 2D tensor. 105 | 106 | Args: 107 | tensor (torch.Tensor): A 2D tensor of size (N, M) containing only 0 and 1 entries. 108 | 109 | Returns: 110 | torch.Tensor: A tensor of size N containing the index of the first 1 in each row. 111 | If a row contains only 0s, the index will be set to -1 (or a sentinel value of your choice). 
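    Example (illustrative):
        first_one_indices(torch.tensor([[0, 1, 1], [0, 0, 0]])) -> tensor([1, -1])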
112 | """ 113 | # Check for rows containing only zeros. 114 | is_all_zero = tensor.sum(dim=1) == 0 115 | # Get the index of the first occurrence of the maximum value (1) along each row. 116 | indices = tensor.argmax(dim=1) 117 | # Handle rows with all zeros. 118 | indices[is_all_zero] = -1 # Set to -1 to indicate no '1' found in these rows 119 | return indices 120 | 121 | 122 | def last_one_indices(tensor): 123 | """ 124 | Finds the index of the last 1 in each row of a 2D tensor. 125 | 126 | Args: 127 | tensor (torch.Tensor): A 2D tensor of size (N, M) containing only 0 and 1 entries. 128 | 129 | Returns: 130 | torch.Tensor: A tensor of size N containing the index of the last 1 in each row. 131 | If a row contains only 0s, the index will be set to -1 (or a sentinel value of your choice). 132 | """ 133 | # Reverse each row to find the last occurrence of 1 (which becomes the first in the reversed row) 134 | reversed_tensor = torch.flip(tensor, [1]) 135 | # Check for rows containing only zeros in the reversed tensor 136 | is_all_zero = reversed_tensor.sum(dim=1) == 0 137 | # Get the index of the first occurrence of the maximum value (1) along each row in the reversed tensor 138 | indices = reversed_tensor.argmax(dim=1) 139 | # Adjust the indices for the original order of each row 140 | indices = tensor.size(1) - 1 - indices 141 | # Handle rows with all zeros 142 | indices[is_all_zero] = -1 # Set to -1 to indicate no '1' found in these rows 143 | return indices 144 | 145 | 146 | def plot_loss_curve(loss_list, save_path): 147 | plt.plot(loss_list) 148 | plt.xlabel('epoch') 149 | plt.ylabel('loss') 150 | plt.savefig(save_path) 151 | plt.close() 152 | 153 | 154 | def svd_flip(u, v): 155 | # columns of u, rows of v 156 | max_abs_cols = torch.argmax(torch.abs(u), 0) 157 | i = torch.arange(u.shape[1]).to(u.device) 158 | signs = torch.sign(u[max_abs_cols, i]) 159 | u *= signs 160 | v *= signs.view(-1, 1) 161 | return u, v 162 | 163 | 164 | class PCA(nn.Module): 165 | def __init__(self, n_components): 166 | super().__init__() 167 | self.n_components = n_components 168 | 169 | @torch.no_grad() 170 | def fit(self, X): 171 | n, d = X.size() 172 | if self.n_components is not None: 173 | d = min(self.n_components, d) 174 | self.register_buffer("mean_", X.mean(0, keepdim=True)) 175 | Z = X - self.mean_ # center 176 | U, S, Vh = torch.linalg.svd(Z, full_matrices=False) 177 | Vt = Vh 178 | U, Vt = svd_flip(U, Vt) 179 | self.register_buffer("components_", Vt[:d]) 180 | return self 181 | 182 | def forward(self, X): 183 | return self.transform(X) 184 | 185 | def transform(self, X): 186 | assert hasattr(self, "components_"), "PCA must be fit before use." 187 | return torch.matmul(X - self.mean_, self.components_.t()) 188 | 189 | def fit_transform(self, X): 190 | self.fit(X) 191 | return self.transform(X) 192 | 193 | def inverse_transform(self, Y): 194 | assert hasattr(self, "components_"), "PCA must be fit before use." 
195 | return torch.matmul(Y, self.components_) + self.mean_ 196 | 197 | 198 | class ContextSolver: 199 | def __init__(self, task_name, tokenizer=None): 200 | # assert task_name in ['sst2', 'trec', 'agnews', 'emo'] 201 | self.task_name = task_name 202 | self.tokenizer = tokenizer 203 | self.task_dataset = md.get_dataset(task_name, split='train', max_data_num=10) 204 | self.format_s = self.task_dataset.get_dmonstration_template()['input'] 205 | self.parse_format_s() 206 | 207 | def parse_format_s(self): 208 | self.X_prefix = self.format_s.split('\n')[0].split(':')[0] + ':' 209 | self.Y_prefix = self.format_s.split('\n')[1].split(':')[0] + ':' 210 | 211 | def get_empty_demo_context(self, context: str, only_demo_part=True): 212 | context = context.split('\n') 213 | for i, line in enumerate(context[:-2]): 214 | if self.X_prefix in line: 215 | line = self.X_prefix 216 | elif self.Y_prefix in line: 217 | line = line 218 | else: 219 | raise warnings.warn('Global prefix or other str exists!') 220 | context[i] = line 221 | if only_demo_part: 222 | context = context[:-2] 223 | context = '\n'.join(context) 224 | return context 225 | 226 | def get_mask_strings_and_match_before(self, context, input_ids, tokenizer=None): 227 | if tokenizer is None: 228 | tokenizer = self.tokenizer 229 | print('debug tokenizer name :', tokenizer.__class__.__name__) 230 | if 'Llama' in tokenizer.__class__.__name__: 231 | sap_token = tokenizer.encode('\n', add_special_tokens=False)[1] 232 | poss = torch.where(input_ids == sap_token)[0] 233 | else: 234 | sap_token = tokenizer.encode('\n', add_special_tokens=False)[0] 235 | poss = torch.where(input_ids == sap_token)[0] 236 | print('debug sap_token:', sap_token) 237 | print('debug poss:', poss) 238 | if len(poss) >= 2: 239 | match_before = poss[-2] + 1 240 | else: 241 | match_before = None 242 | 243 | list_s = [] 244 | list_s.append(self.X_prefix) 245 | list_s.append('\n' + self.X_prefix) 246 | context = context.split('\n') 247 | for i, line in enumerate(context[:-2]): 248 | if self.X_prefix in line: 249 | pass 250 | elif self.Y_prefix in line: 251 | list_s.append('\n' + line) 252 | list_s.append('\n' + line + '\n') 253 | else: 254 | raise warnings.warn('Global prefix or other str exists!') 255 | return list_s, match_before 256 | 257 | def get_mask(self, input_ids, tokenizer=None): 258 | if isinstance(input_ids, list): 259 | input_ids = torch.tensor(input_ids) 260 | if len(input_ids.shape) == 2: 261 | assert input_ids.shape[0] == 1 262 | input_ids = input_ids[0] 263 | if tokenizer is None: 264 | tokenizer = self.tokenizer 265 | context = tokenizer.decode(input_ids) 266 | list_s, match_before = self.get_mask_strings_and_match_before(context, input_ids=input_ids, 267 | tokenizer=tokenizer) 268 | print('debug context:', context) 269 | print('debug list_s:', list_s) 270 | print('debug match_before:', match_before) 271 | tensor_str_finder = TensorStrFinder(tokenizer=tokenizer) 272 | mask = tensor_str_finder.get_strs_mask_in_tensor(list_s=list_s, t=input_ids, 273 | match_before=match_before) 274 | return mask 275 | 276 | 277 | class TensorStrFinder: 278 | def __init__(self, tokenizer): 279 | self.tokenizer = tokenizer 280 | 281 | def find_tensor_in_tensor(self, a_tensor: Union[torch.Tensor, list], b_tensor: torch.Tensor, 282 | return_mask=True, match_before: Optional[int] = None): 283 | if len(b_tensor.shape) == 2: 284 | assert b_tensor.shape[0] == 1 285 | b_tensor = b_tensor[0] 286 | if isinstance(a_tensor, list): 287 | a_tensor = torch.tensor(a_tensor) 288 | if a_tensor.device != 
b_tensor.device: 289 | a_tensor = a_tensor.to(b_tensor.device) 290 | 291 | window_size = len(a_tensor) 292 | b_windows = b_tensor.unfold(0, window_size, 1) 293 | 294 | matches = torch.all(b_windows == a_tensor, dim=1) 295 | 296 | positions = torch.nonzero(matches, as_tuple=True)[0] 297 | 298 | if return_mask: 299 | mask = torch.zeros_like(b_tensor, dtype=torch.bool) 300 | for pos in positions: 301 | if match_before is None or pos + window_size <= match_before: 302 | mask[pos:pos + window_size] = True 303 | return mask 304 | 305 | return positions 306 | 307 | def find_str_in_tensor(self, s: str, t: torch.Tensor, return_mask=True, match_before=None): 308 | s_tokens = self.tokenizer.encode(s, add_special_tokens=False) 309 | s_tensor = torch.LongTensor(s_tokens) 310 | return self.find_tensor_in_tensor(s_tensor, t, return_mask=return_mask, 311 | match_before=match_before) 312 | 313 | def get_strs_mask_in_tensor(self, list_s: List[str], t: torch.Tensor, match_before=None): 314 | list_s_tokens = [self.tokenizer.encode(s, add_special_tokens=False) for s in list_s] 315 | if 'Llama' in self.tokenizer.__class__.__name__: 316 | list_s_tokens = [s_tokens[1:] if s_tokens[0] == 29871 else s_tokens for s_tokens in list_s_tokens] 317 | list_s_tensor = [torch.LongTensor(s_tokens) for s_tokens in list_s_tokens] 318 | print('debug list_s_tensor:', list_s_tensor) 319 | mask_tensor_list = [ 320 | self.find_tensor_in_tensor(s_tensor, t, return_mask=True, match_before=match_before) for 321 | s_tensor in list_s_tensor] 322 | mask_tensor = functools.reduce(torch.logical_or, mask_tensor_list) 323 | return mask_tensor -------------------------------------------------------------------------------- /run_i2cl_transfer_learning.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | import json 4 | import time 5 | import copy 6 | import random 7 | import argparse 8 | import itertools 9 | import torch 10 | import numpy as np 11 | from multiprocessing import Process, Queue 12 | from itertools import combinations 13 | 14 | import utils 15 | import my_datasets as md 16 | import evaluator as ev 17 | 18 | 19 | def main(args): 20 | # load config from target_path 21 | tar_exp_path = os.path.join(args.config['target_path'], 22 | args.model_name, args.dataset_name) 23 | tar_config_path = os.path.join(tar_exp_path, 'config.json') 24 | tar_result_path = os.path.join(tar_exp_path, 'result_dict.json') 25 | # reutrn if target_path does not exist 26 | if not os.path.exists(tar_config_path) or not os.path.exists(tar_result_path): 27 | print(f"target_path: {tar_exp_path} does not exist!") 28 | return 29 | # load config 30 | with open(tar_config_path, 'r') as f: 31 | config = json.load(f) 32 | # load result_dict 33 | with open(tar_result_path, 'r') as f: 34 | result_dict = json.load(f) 35 | 36 | # update config with args.config 37 | config.update(args.config) 38 | args.config = config 39 | args.result_dict = result_dict 40 | 41 | # set global seed 42 | utils.set_seed(args.config['seed']) 43 | # set device 44 | args.device = utils.set_device(args.gpu) 45 | # set metric used 46 | args.metric = args.config['metric'] 47 | # get save dir 48 | utils.init_exp_path(args, args.config['exp_name']) 49 | 50 | # load tokenizer and model 51 | model, tokenizer, model_config = \ 52 | utils.load_model_tokenizer(args.model_name, args.device) 53 | 54 | # get model_wrapper 55 | model_wrapper = utils.get_model_wrapper(args.model_name, model, 56 | tokenizer, model_config, 57 | args.device) 58 | # 
load datasets 59 | train_dataset = md.get_dataset(args.dataset_name, split='train', 60 | max_data_num=None, seed=args.config['seed']) 61 | test_dataset = md.get_dataset(args.dataset_name, split='test', 62 | max_data_num=args.config['test_data_num'], 63 | sample_mode=args.config['sample_method'], 64 | seed=args.config['seed']) 65 | 66 | # get max demonstration token length for each dataset 67 | if args.config['split_demon']: 68 | # when split demon, do not check max example token length 69 | args.test_max_token = 1e8 70 | else: 71 | args.test_max_token = test_dataset.get_max_demonstration_token_length(tokenizer) 72 | 73 | # get shot_num 74 | if args.dataset_name == 'dbpedia': # always use 1-shot for dbpedia 75 | args.config['shot_per_class'] = 1 76 | args.config['bs'] = 1 77 | args.shot_num = utils.get_shot_num(train_dataset, args.config['shot_per_class']) 78 | # build evaluate 79 | test_evaluator = ev.Evaluator(test_dataset, batch_size=args.config['bs']) 80 | # init result_dict 81 | infer_result_dict = {'demon': {}, 82 | 'split_demon': {}, 83 | 'test_result': {'zero_shot': [], 'few_shot': [], 'ours': [], 'ensemble': {}}, 84 | 'linear_coef': {}, 85 | 'time': {'calibrate': [], 'evaluate': []} 86 | } 87 | 88 | all_cv_dicts, all_coef = [], [] 89 | for dataset_name in list(md.target_datasets.keys()): 90 | # collect calibrated coefficients 91 | coe_save_path = os.path.join(args.config['target_path'], args.model_name, dataset_name, 'result_dict.json') 92 | with open(coe_save_path, 'r') as f: 93 | cur_result_dict = json.load(f) 94 | tem_coef = [] 95 | for run_id, coef in cur_result_dict['linear_coef'].items(): 96 | tem_coef.append(torch.tensor(coef)) 97 | # average strength params 98 | all_coef.append(torch.stack(tem_coef).mean(dim=0)) 99 | 100 | # collect context vectors 101 | cv_save_path = os.path.join(args.config['target_path'], args.model_name, dataset_name, 'cv_save_dict.json') 102 | with open(cv_save_path, 'r') as f: 103 | cv_dict = json.load(f) 104 | cur_cv_dict = {} 105 | for _, cv_dict in cv_dict.items(): 106 | for layer, sub_dict in cv_dict.items(): 107 | if layer not in cur_cv_dict: 108 | cur_cv_dict[layer] = {} 109 | for module, activation in sub_dict.items(): 110 | if module not in cur_cv_dict[layer]: 111 | cur_cv_dict[layer][module] = [] 112 | cur_cv_dict[layer][module].append(torch.tensor(activation)) 113 | # average context vector diict 114 | for layer, sub_dict in cur_cv_dict.items(): 115 | for module, activation_list in sub_dict.items(): 116 | cur_cv_dict[layer][module] = torch.stack(activation_list).mean(dim=0) 117 | all_cv_dicts.append(cur_cv_dict) 118 | 119 | for run_id in range(args.config['run_num']): 120 | run_name = f'run_{run_id}' 121 | args.run_name = run_name 122 | print(f'Run time {run_name}') 123 | run_seed = args.config['seed'] + run_id 124 | utils.set_seed(run_seed) 125 | 126 | # build val dataset 127 | _, split_demon_list, demon_data_index = \ 128 | train_dataset.gen_few_shot_demonstration(tokenizer=tokenizer, shot_num=args.shot_num, 129 | max_demonstration_tok_len=args.test_max_token, 130 | add_extra_query=args.config['add_extra_query'], 131 | example_separator=args.config['example_separator'], 132 | gen_example_method = args.config['gen_example_method'], 133 | return_data_index=True, seed=random.randint(0, 1e6)) 134 | # build val_evaluator use demon_data_index 135 | val_dataset = copy.deepcopy(train_dataset) 136 | val_dataset.all_data = [train_dataset.all_data[i] for i in demon_data_index] 137 | 138 | # get demon 139 | demon_list = 
args.result_dict['demon'][run_name] 140 | assert split_demon_list == args.result_dict['split_demon'][run_name], \ 141 | print(f'split_demon_list: {split_demon_list} != {args.result_dict["split_demon"][run_name]}') 142 | 143 | # save demon_list 144 | infer_result_dict['demon'][run_name] = demon_list 145 | infer_result_dict['split_demon'][run_name] = split_demon_list 146 | 147 | # cal task simlarity =================================================================== 148 | cur_coef = torch.tensor(args.result_dict['linear_coef'][run_name]).view(-1) 149 | ref_coef = torch.stack(all_coef) 150 | ref_coef = ref_coef.view(ref_coef.size(0), -1) 151 | # calculate cosine similarity between cur_coef and all_coef 152 | sim = torch.nn.functional.cosine_similarity(cur_coef, ref_coef, dim=1) 153 | 154 | # keep sim and its index whose similarity is larger than threshold 155 | tar_sim = sim[sim > args.config['threshold']] 156 | tar_idx = torch.nonzero(sim > args.config['threshold']).view(-1) 157 | 158 | print(f'Similarities: {tar_sim}') 159 | print(f'Similar tasks: {[list(md.target_datasets.keys())[idx] for idx in tar_idx]}') 160 | 161 | tar_sim = tar_sim.cpu().numpy() 162 | tar_cv_dicts = [all_cv_dicts[idx] for idx in tar_idx] 163 | tar_coef = [all_coef[idx] for idx in tar_idx] 164 | assert len(tar_cv_dicts) > 2, print('No enough transferable tasks!') 165 | infer_result_dict['similar_task_num'] = len(tar_cv_dicts) 166 | 167 | # change top_k_sim to probability distribution that sums to 1 168 | def softmax_with_temperature(logits, temperature=1.0): 169 | scaled_logits = logits / temperature 170 | exps = np.exp(scaled_logits - np.max(scaled_logits)) # For numerical stability 171 | softmax_outputs = exps / np.sum(exps) 172 | return softmax_outputs 173 | 174 | tar_prob = softmax_with_temperature(tar_sim, args.config['temp']) 175 | context_vector_dict, linear_coef = prepare_inject_dicts_params(tar_prob, tar_cv_dicts, tar_coef) 176 | 177 | # set strength params 178 | model_wrapper.init_strength(args.config) 179 | 180 | # calibrate context vector 181 | s_t = time.time() 182 | model_wrapper.calibrate_strength(context_vector_dict, val_dataset, 183 | args.config, save_dir=args.save_dir, 184 | run_name=args.run_name) 185 | e_t = time.time() 186 | print(f'Calibration time: {e_t - s_t}') 187 | infer_result_dict['time']['calibrate'].append(e_t - s_t) 188 | 189 | # save linear_coef 190 | infer_result_dict['linear_coef'][run_name] = model_wrapper.linear_coef.tolist() 191 | 192 | # evaluate i2cl ======================================================================== 193 | s_t = time.time() 194 | with torch.no_grad(): 195 | with model_wrapper.inject_latent(context_vector_dict, args.config, 196 | model_wrapper.linear_coef): 197 | test_ours_result = test_evaluator.evaluate(model_wrapper, tokenizer, demonstration='', 198 | use_cache=args.config['use_cache']) 199 | print(f'Test I2CL result: {test_ours_result}\n') 200 | infer_result_dict['test_result']['ours'].append(test_ours_result) 201 | e_t = time.time() 202 | 203 | print(f'Evaluate time: {e_t - s_t}') 204 | infer_result_dict['time']['evaluate'].append(e_t - s_t) 205 | 206 | # save result_dict after each run 207 | with open(args.save_dir + '/result_dict.json', 'w') as f: 208 | json.dump(infer_result_dict, f, indent=4) 209 | 210 | # delete all variables 211 | del model, tokenizer, model_config, model_wrapper, train_dataset, test_dataset, test_evaluator 212 | del all_cv_dicts, all_coef 213 | del context_vector_dict, linear_coef 214 | del infer_result_dict 215 | 216 | 217 | def 
prepare_inject_dicts_params(tar_prob, tar_cv_dicts, tar_coef): 218 | target_layers = list(tar_cv_dicts[0].keys()) 219 | target_modules = list(tar_cv_dicts[0][target_layers[0]].keys()) 220 | print(f'target_layers: {target_layers}') 221 | print(f'target_modules: {target_modules}') 222 | # init an empty ensemble_dict with the same structure as all_inject_dicts 223 | ensemble_cv_dict = {layer: {module: 0 for module in target_modules} for layer in target_layers} 224 | new_coef = torch.zeros(tar_coef[0].size()) 225 | for idx, cv_dict in enumerate(tar_cv_dicts): 226 | for layer_idx, layer in enumerate(target_layers): 227 | for module_idx, module in enumerate(target_modules): 228 | cv = cv_dict[layer][module] 229 | coef = tar_coef[idx][layer_idx, module_idx, 0] 230 | ensemble_cv_dict[layer][module] += cv * coef * tar_prob[idx] 231 | new_coef += tar_coef[idx] * tar_prob[idx] 232 | # set the first strength param to 1 since coefficient of context vector has been included in the context vector 233 | new_coef[:, :, 0] = 1 234 | # set layer name in ensemble_cv_dict to int type 235 | ensemble_cv_dict = {int(layer): sub_dict for layer, sub_dict in ensemble_cv_dict.items()} 236 | return ensemble_cv_dict, new_coef 237 | 238 | 239 | # get args 240 | def get_args(): 241 | parser = argparse.ArgumentParser() 242 | parser.add_argument('--config_path', type=str, default='configs/config_i2cl_transfer_learning.py', help='path to config file') 243 | return parser.parse_args() 244 | 245 | 246 | if __name__ == "__main__": 247 | # get args 248 | args = get_args() 249 | # load config 250 | config = utils.load_config(args.config_path) 251 | # Generate all combinations of models and datasets 252 | combinations = list(itertools.product(config['models'], config['datasets'])) 253 | # Queue to hold tasks 254 | task_queue = Queue() 255 | for combine in combinations: 256 | task_queue.put(combine) 257 | 258 | def run_task(gpu_id, config): 259 | while not task_queue.empty(): 260 | model_name, dataset_name = task_queue.get() 261 | print(f"Running {model_name} on {dataset_name} with GPU {gpu_id}") 262 | input_args = argparse.Namespace() 263 | cur_config = copy.deepcopy(config) 264 | input_args.model_name = model_name 265 | input_args.dataset_name = dataset_name 266 | input_args.gpu = gpu_id 267 | input_args.config = cur_config 268 | try: 269 | main(input_args) 270 | finally: 271 | # Clean up CUDA memory after each task 272 | gc.collect() 273 | torch.cuda.empty_cache() 274 | print(f"CUDA memory cleared for GPU {gpu_id}") 275 | time.sleep(5) 276 | 277 | # Create a process for each GPU 278 | processes = [Process(target=run_task, args=(gpu_id, config)) for gpu_id in config['gpus']] 279 | # Start all processes 280 | for p in processes: 281 | p.start() 282 | # Wait for all processes to finish 283 | for p in processes: 284 | p.join() 285 | print("All tasks completed.") -------------------------------------------------------------------------------- /run_i2cl_infer.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | import json 4 | import time 5 | import copy 6 | import random 7 | import argparse 8 | import itertools 9 | import torch 10 | import numpy as np 11 | from multiprocessing import Process, Queue 12 | from itertools import combinations 13 | 14 | import utils 15 | import my_datasets as md 16 | import evaluator as ev 17 | 18 | 19 | def main(args): 20 | # load config from target_path 21 | tar_exp_path = os.path.join(args.config['target_path'], 22 | args.model_name, 
args.dataset_name) 23 | tar_config_path = os.path.join(tar_exp_path, 'config.json') 24 | tar_result_path = os.path.join(tar_exp_path, 'result_dict.json') 25 | # reutrn if target_path does not exist 26 | if not os.path.exists(tar_config_path) or not os.path.exists(tar_result_path): 27 | print(f"target_path: {tar_exp_path} does not exist!") 28 | return 29 | # load config 30 | with open(tar_config_path, 'r') as f: 31 | config = json.load(f) 32 | # load result_dict 33 | with open(tar_result_path, 'r') as f: 34 | result_dict = json.load(f) 35 | # update config with args.config 36 | config.update(args.config) 37 | args.config = config 38 | args.result_dict = result_dict 39 | 40 | # set global seed 41 | utils.set_seed(args.config['seed']) 42 | # set device 43 | args.device = utils.set_device(args.gpu) 44 | # set metric used 45 | args.metric = args.config['metric'] 46 | # get save dir 47 | utils.init_exp_path(args, args.config['exp_name']) 48 | 49 | # load tokenizer and model 50 | model, tokenizer, model_config = \ 51 | utils.load_model_tokenizer(args.model_name, args.device) 52 | 53 | # get model_wrapper 54 | model_wrapper = utils.get_model_wrapper(args.model_name, model, 55 | tokenizer, model_config, 56 | args.device) 57 | # load datasets 58 | train_dataset = md.get_dataset(args.dataset_name, split='train', 59 | max_data_num=None, seed = args.config['seed']) 60 | holdout_dataset = md.get_dataset(args.dataset_name, split='validation', 61 | max_data_num=args.config['val_data_num'], 62 | sample_mode=args.config['sample_method'], 63 | seed=args.config['seed']) 64 | test_dataset = md.get_dataset(args.dataset_name, split='test', 65 | max_data_num=args.config['test_data_num'], 66 | sample_mode=args.config['sample_method'], 67 | seed=args.config['seed']) 68 | 69 | # get shot_num 70 | if args.dataset_name == 'dbpedia': # always use 1-shot for dbpedia 71 | args.config['shot_per_class'] = 1 72 | args.config['bs'] = 1 73 | args.shot_num = utils.get_shot_num(train_dataset, args.config['shot_per_class']) 74 | # build evaluate 75 | holdout_evaluator = ev.Evaluator(holdout_dataset, batch_size=args.config['bs']) 76 | test_evaluator = ev.Evaluator(test_dataset, batch_size=args.config['bs']) 77 | # init result_dict 78 | infer_result_dict = {'demon': {}, 79 | 'split_demon': {}, 80 | 'test_result': {'zero_shot': [], 'few_shot': [], 'ours': {}}, 81 | 'linear_coef': {}, 82 | 'time': {'calibrate': [], 'evaluate': []} 83 | } 84 | 85 | 86 | all_context_vector_dicts, all_linear_coefs = [], [] 87 | for run_id in range(args.config['run_num']): 88 | run_name = f'run_{run_id}' 89 | args.run_name = run_name 90 | print(f'Run time {run_name}') 91 | run_seed = args.config['seed'] + run_id 92 | utils.set_seed(run_seed) 93 | 94 | # zero-shot baseline 95 | if run_id == 0 and args.config['run_baseline']: 96 | test_zeroshot_result = test_evaluator.evaluate(model_wrapper, tokenizer, demonstration='', 97 | use_cache=args.config['use_cache']) 98 | infer_result_dict['test_result']['zero_shot'].append(test_zeroshot_result) 99 | print(f'Test zero-shot result: {test_zeroshot_result}\n') 100 | 101 | if args.config['use_new_demon']: 102 | # sample demonstration 103 | count = 0 104 | temp_demon_list, temp_result_list = [], [] 105 | while True: 106 | demon, split_demon, demon_data_index = \ 107 | train_dataset.gen_few_shot_demonstration(tokenizer=tokenizer, shot_num=args.shot_num, 108 | max_demonstration_tok_len=1e8, 109 | add_extra_query=args.config['add_extra_query'], 110 | example_separator=args.config['example_separator'], 111 | 
return_data_index=True, seed=random.randint(0, 1e6) + run_seed) 112 | temp_demon_list.append((demon, split_demon, demon_data_index)) 113 | 114 | if args.config['demo_sample_method'] == 'random': 115 | break 116 | else: 117 | tem_val_result = holdout_evaluator.evaluate(model_wrapper, tokenizer, 118 | demonstration=demon, 119 | use_cache=args.config['use_cache']) 120 | temp_result = tem_val_result[args.metric] 121 | temp_result_list.append(temp_result) 122 | if count > 20: 123 | if args.config['demo_sample_method'] == 'deficient': 124 | demon, split_demon, demon_data_index = temp_demon_list[np.argmin(temp_result_list)] 125 | else: 126 | raise ValueError('Invalid demon_sample_method') 127 | break 128 | count += 1 129 | 130 | if args.config['add_extra_query']: 131 | first_format_anchor = train_dataset.get_dmonstration_template()['format'][0] 132 | # remove all contents after the last first_format_anchor including the anchor 133 | if first_format_anchor in demon: 134 | baseline_demon = demon[:demon.rfind(first_format_anchor)] 135 | query_demon = demon[demon.rfind(first_format_anchor):] 136 | else: 137 | baseline_demon = demon 138 | query_demon = None 139 | print(f'Demonstration:\n{demon}\n') 140 | print(f'Baseline demonstration:\n{baseline_demon}\n') 141 | print(f'Query demonstration:\n{query_demon}\n') 142 | demon_list = [demon] 143 | split_demon_list = split_demon 144 | else: 145 | demon_list = args.result_dict['demon'][run_name] 146 | demon = demon_list[0] 147 | baseline_demon = demon 148 | try: 149 | split_demon_list = args.result_dict['split_demon'][run_name] 150 | except KeyError: 151 | raise ValueError('split_demon_list not found in result_dict!') 152 | 153 | # save demon_list 154 | infer_result_dict['demon'][run_name] = demon_list 155 | infer_result_dict['split_demon'][run_name] = split_demon_list 156 | 157 | # few-shot baseline 158 | if args.config['run_baseline']: 159 | test_fewshot_result = test_evaluator.evaluate(model_wrapper, tokenizer, 160 | demonstration=baseline_demon, 161 | use_cache=args.config['use_cache']) 162 | infer_result_dict['test_result']['few_shot'].append(test_fewshot_result) 163 | print(f'Test few-shot result: {test_fewshot_result}\n') 164 | 165 | # extract latents ====================================================================== 166 | all_latent_dicts = [] 167 | with torch.no_grad(): 168 | if not args.config['split_demon']: 169 | target_demon_list = demon_list[0] 170 | else: 171 | target_demon_list = split_demon_list 172 | for cur_demon in target_demon_list: 173 | with model_wrapper.extract_latent(): 174 | demon_token = tokenizer(cur_demon, return_tensors='pt').to(args.device) 175 | _ = model(**demon_token) 176 | all_latent_dicts.append(model_wrapper.latent_dict) 177 | model_wrapper.reset_latent_dict() 178 | 179 | # generate context vector ============================================================== 180 | context_vector_dict = model_wrapper.get_context_vector(all_latent_dicts, args.config) 181 | del all_latent_dicts 182 | 183 | # get strength params =================================================================== 184 | model_wrapper.init_strength(args.config) 185 | del model_wrapper.linear_coef 186 | model_wrapper.linear_coef = torch.tensor(args.result_dict['linear_coef'][run_name]) 187 | 188 | # save context_vector_dict and linear_coef for ensemble 189 | all_context_vector_dicts.append(context_vector_dict) 190 | all_linear_coefs.append(model_wrapper.linear_coef) 191 | 192 | # prepare downstream tasks 193 | if args.config['downstream_datasets'] is None: 
194 | downstream_datasets = [args.dataset_name] 195 | else: 196 | downstream_datasets = args.config['downstream_datasets'] 197 | 198 | for target_task_name in downstream_datasets: 199 | # init saving structure 200 | if target_task_name not in infer_result_dict['test_result']['ours']: 201 | infer_result_dict['test_result']['ours'][target_task_name] = [] 202 | 203 | # prepare target dataset 204 | target_dataset = md.get_dataset(target_task_name, split='test', 205 | max_data_num=args.config['test_data_num'], 206 | sample_mode=args.config['sample_method'], 207 | seed=args.config['seed']) 208 | target_evaluator = ev.Evaluator(target_dataset, batch_size=args.config['bs']) 209 | 210 | # evaluate i2cl ======================================================================== 211 | s_t = time.time() 212 | with torch.no_grad(): 213 | with model_wrapper.inject_latent(context_vector_dict, args.config, 214 | model_wrapper.linear_coef): 215 | test_ours_result = target_evaluator.evaluate(model_wrapper, tokenizer, demonstration='', 216 | use_cache=args.config['use_cache']) 217 | print(f'Test I2CL result: {test_ours_result}\n') 218 | infer_result_dict['test_result']['ours'][target_task_name].append(test_ours_result) 219 | e_t = time.time() 220 | 221 | print(f'Evaluate time: {e_t - s_t}') 222 | infer_result_dict['time']['evaluate'].append(e_t - s_t) 223 | 224 | # save result_dict after each run 225 | with open(args.save_dir + '/result_dict.json', 'w') as f: 226 | json.dump(infer_result_dict, f, indent=4) 227 | 228 | # delete all variables 229 | del model, tokenizer, model_config, model_wrapper, train_dataset, test_dataset, holdout_dataset 230 | del test_evaluator, holdout_evaluator 231 | del all_context_vector_dicts, all_linear_coefs 232 | del infer_result_dict 233 | 234 | 235 | # get args 236 | def get_args(): 237 | parser = argparse.ArgumentParser() 238 | parser.add_argument('--config_path', type=str, default='configs/config_i2cl_infer.py', help='path to config file') 239 | return parser.parse_args() 240 | 241 | 242 | if __name__ == "__main__": 243 | # get args 244 | args = get_args() 245 | # load config 246 | config = utils.load_config(args.config_path) 247 | # Generate all combinations of models and datasets 248 | combinations = list(itertools.product(config['models'], config['datasets'])) 249 | # Queue to hold tasks 250 | task_queue = Queue() 251 | for combine in combinations: 252 | task_queue.put(combine) 253 | 254 | def run_task(gpu_id, config): 255 | while not task_queue.empty(): 256 | model_name, dataset_name = task_queue.get() 257 | print(f"Running {model_name} on {dataset_name} with GPU {gpu_id}") 258 | input_args = argparse.Namespace() 259 | cur_config = copy.deepcopy(config) 260 | input_args.model_name = model_name 261 | input_args.dataset_name = dataset_name 262 | input_args.gpu = gpu_id 263 | input_args.config = cur_config 264 | try: 265 | main(input_args) 266 | finally: 267 | # Clean up CUDA memory after each task 268 | gc.collect() 269 | torch.cuda.empty_cache() 270 | print(f"CUDA memory cleared for GPU {gpu_id}") 271 | time.sleep(5) 272 | 273 | # Create a process for each GPU 274 | processes = [Process(target=run_task, args=(gpu_id, config)) for gpu_id in config['gpus']] 275 | # Start all processes 276 | for p in processes: 277 | p.start() 278 | # Wait for all processes to finish 279 | for p in processes: 280 | p.join() 281 | print("All tasks completed.") -------------------------------------------------------------------------------- /run_label_anchor.py: 
-------------------------------------------------------------------------------- 1 | import gc 2 | import json 3 | import copy 4 | import time 5 | import random 6 | import argparse 7 | import itertools 8 | import torch 9 | import torch.nn.functional as F 10 | from multiprocessing import Process, Queue 11 | 12 | import utils 13 | import my_datasets as md 14 | import evaluator as ev 15 | 16 | 17 | def cached_evaluation(config, dataset, model, tokenizer, compressed_past_key_values, 18 | compressed_attn_mask, position_offset=None): 19 | batch_size = config['bs'] 20 | # prepare label dict 21 | label_map = {} 22 | ans_txt_list = dataset.get_dmonstration_template()['options'] 23 | for label, ans_txt in enumerate(ans_txt_list): 24 | if 'gpt' in tokenizer.__class__.__name__.lower(): 25 | ans_txt = ' ' + ans_txt # add space to the beginning of answer 26 | ans_tok = tokenizer.encode(ans_txt, add_special_tokens=False)[0] # use the first token if more than one token 27 | print(f"ans_txt: {ans_txt}, ans_tok: {ans_tok}") 28 | label_map[ans_tok] = label # map answer token id to label index 29 | print(f"label_map: {label_map}") 30 | 31 | # prepare all data 32 | all_pred_labels = [] 33 | all_inputs, all_labels = [], [] 34 | for data in dataset.all_data: 35 | ques_str, _, label = dataset.apply_template(data) 36 | context = ques_str 37 | all_inputs.append(context) 38 | all_labels.append(label) 39 | 40 | # prepare cached data 41 | cached_past_key_values = tuple(tuple(t.repeat(batch_size, 1, 1, 1) for t in tup) 42 | for tup in compressed_past_key_values) 43 | cached_attn_mask = compressed_attn_mask.repeat(batch_size, 1) 44 | 45 | # loop over all data 46 | with torch.no_grad(): 47 | for i in range(0, len(all_inputs), batch_size): 48 | cur_inputs = all_inputs[i:i+batch_size] 49 | input_tok = tokenizer(cur_inputs, return_tensors="pt", padding=True) 50 | input_ids = input_tok['input_ids'].to(model.device) 51 | attn_mask = input_tok['attention_mask'].to(model.device) 52 | 53 | # get index for prediction logits; must be computed before concatenating cached_attn_mask with attn_mask 54 | pred_loc = utils.last_one_indices(attn_mask).to(model.device) 55 | 56 | # get logits 57 | if position_offset is not None: 58 | position_ids = torch.arange(input_ids.size(1), dtype=torch.long, device=model.device).repeat(input_ids.size(0), 1) 59 | position_ids = position_ids + position_offset 60 | else: 61 | position_ids = None 62 | 63 | attn_mask = torch.cat([cached_attn_mask, attn_mask], dim=1) 64 | output = model(input_ids=input_ids, attention_mask=attn_mask, 65 | past_key_values=cached_past_key_values, 66 | position_ids=position_ids) 67 | # get prediction logits 68 | logits = output.logits 69 | pred_logits = logits[torch.arange(logits.size(0)), pred_loc] 70 | # get prediction labels 71 | interest_index = list(label_map.keys()) 72 | pred_logits = pred_logits[:, interest_index] 73 | probs = F.softmax(pred_logits, dim=-1) 74 | pred_labels = probs.argmax(dim=-1) 75 | # save results 76 | all_pred_labels.extend(pred_labels.cpu().numpy().tolist()) 77 | 78 | assert len(all_pred_labels) == len(all_labels) 79 | # all_pred_labels and all_labels are lists of label indices; compute accuracy and macro F1 from them
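# macro F1 below is the unweighted mean of per-class F1 scores, where F1_c = 2 * precision_c * recall_c / (precision_c + recall_c), precision_c = TP_c / (TP_c + FP_c) and recall_c = TP_c / (TP_c + FN_c)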
80 | # initialize TP, FP, FN 81 | acc = [] 82 | num_classes = dataset.class_num 83 | TP = [0] * num_classes 84 | FP = [0] * num_classes 85 | FN = [0] * num_classes 86 | for i, true_label in enumerate(all_labels): 87 | pred_label = all_pred_labels[i] 88 | pred = pred_label == true_label 89 | acc.append(pred) 90 | # Update TP, FP, FN 91 | if pred: 92 | TP[true_label] += 1 93 | else: 94 | FP[pred_label] += 1 95 | FN[true_label] += 1 96 | # Calculate precision, recall, F1 for each class and macro F1 97 | precision = [0] * num_classes 98 | recall = [0] * num_classes 99 | f1 = [0] * num_classes 100 | for i in range(num_classes): 101 | precision[i] = TP[i] / (TP[i] + FP[i]) if (TP[i] + FP[i]) > 0 else 0 102 | recall[i] = TP[i] / (TP[i] + FN[i]) if (TP[i] + FN[i]) > 0 else 0 103 | f1[i] = 2 * (precision[i] * recall[i]) / (precision[i] + recall[i]) if (precision[i] + recall[i]) > 0 else 0 104 | macro_f1 = sum(f1) / num_classes 105 | acc = sum(acc) / len(acc) 106 | return {'acc': acc, 'macro_f1': macro_f1} 107 | 108 | 109 | def main(args): 110 | # set global seed 111 | utils.set_seed(args.config['seed']) 112 | # set device 113 | args.device = utils.set_device(args.gpu) 114 | # set metric used 115 | args.metric = args.config['metric'] 116 | # get save dir 117 | utils.init_exp_path(args, args.config['exp_name']) 118 | 119 | # load tokenizer and model 120 | model, tokenizer, model_config = \ 121 | utils.load_model_tokenizer(args.model_name, args.device) 122 | 123 | # get model_wrapper 124 | model_wrapper = utils.get_model_wrapper(args.model_name, model, 125 | tokenizer, model_config, 126 | args.device) 127 | # load datasets 128 | train_dataset = md.get_dataset(args.dataset_name, split='train', 129 | max_data_num=None, 130 | seed=args.config['seed']) 131 | test_dataset = md.get_dataset(args.dataset_name, split='test', 132 | max_data_num=args.config['test_data_num'], 133 | sample_mode=args.config['sample_method'], 134 | seed=args.config['seed']) 135 | 136 | # get max demonstration token length for each dataset 137 | args.test_max_token = test_dataset.get_max_demonstration_token_length(tokenizer) 138 | 139 | # get shot_num 140 | if args.dataset_name == 'dbpedia': # always use 1-shot for dbpedia 141 | args.config['shot_per_class'] = 1 142 | args.config['bs'] = 1 143 | args.shot_num = utils.get_shot_num(train_dataset, args.config['shot_per_class']) 144 | # build evaluate 145 | test_evaluator = ev.Evaluator(test_dataset, batch_size=args.config['bs']) 146 | # init result_dict 147 | result_dict = {'demon': {}, 148 | 'split_demon': {}, 149 | 'test_result': {'zero_shot': [], 'few_shot': [], 'ours': []}, 150 | 'time': {'calibrate': [], 'evaluate': []}, 151 | } 152 | 153 | for run_id in range(args.config['run_num']): 154 | run_name = f'run_{run_id}' 155 | args.run_name = run_name 156 | print(f'Run time {run_name}') 157 | run_seed = args.config['seed'] + run_id 158 | utils.set_seed(run_seed) 159 | 160 | # zero-shot baseline 161 | if run_id == 0 and args.config['run_baseline']: 162 | test_zeroshot_result = test_evaluator.evaluate(model_wrapper, tokenizer, demonstration='', 163 | use_cache=args.config['use_cache']) 164 | result_dict['test_result']['zero_shot'].append(test_zeroshot_result) 165 | print(f'Test zero-shot result: {test_zeroshot_result}\n') 166 | 167 | # sample demonstration 168 | demon, _ = \ 169 | train_dataset.gen_few_shot_demonstration(tokenizer=tokenizer, shot_num=args.shot_num, 170 | max_demonstration_tok_len=args.test_max_token, 171 | add_extra_query=args.config['add_extra_query'], 172 | 
example_separator=args.config['example_separator'], 173 | seed=random.randint(0, 1e6)) 174 | 175 | if args.config['add_extra_query']: 176 | first_format_anchor = train_dataset.get_dmonstration_template()['format'][0] 177 | # remove all contents after the last first_format_anchor including the anchor 178 | if first_format_anchor in demon: 179 | baseline_demon = demon[:demon.rfind(first_format_anchor)] 180 | query_demon = demon[demon.rfind(first_format_anchor):] 181 | else: 182 | baseline_demon = demon 183 | query_demon = None 184 | print(f'Demonstration:\n{demon}\n') 185 | print(f'Baseline demonstration:\n{baseline_demon}\n') 186 | print(f'Query demonstration:\n{query_demon}\n') 187 | 188 | # few-shot baseline 189 | if args.config['run_baseline']: 190 | test_fewshot_result = test_evaluator.evaluate(model_wrapper, tokenizer, 191 | demonstration=baseline_demon, 192 | use_cache=args.config['use_cache']) 193 | result_dict['test_result']['few_shot'].append(test_fewshot_result) 194 | print(f'Test few-shot result: {test_fewshot_result}\n') 195 | 196 | # apply label anchor 197 | context_solver = utils.ContextSolver(task_name=args.dataset_name, tokenizer=tokenizer) 198 | demon_token = tokenizer(demon, return_tensors='pt').to(args.device) 199 | compress_attn_mask = context_solver.get_mask(demon_token['input_ids']) 200 | print(f'Compress_attn_mask: {compress_attn_mask}\n') 201 | # compressed_text 202 | compressed_token_id = copy.deepcopy(demon_token['input_ids'][0]) 203 | mask = copy.deepcopy(compress_attn_mask).cpu().detach().numpy() 204 | compressed_token_id = compressed_token_id.cpu().detach().numpy() 205 | compressed_text = tokenizer.decode(list(compressed_token_id[mask])) 206 | print(f'Compressed_text: {compressed_text}\n') 207 | with torch.no_grad(): 208 | demon_outputs = model(**demon_token, use_cache=True) 209 | past_key_values = demon_outputs.past_key_values 210 | 211 | if args.model_name == 'meta-llama/Llama-2-7b-hf': 212 | mask_end_idx = torch.where(compress_attn_mask)[0][-1] + 1 213 | cached_past_key_values = tuple(tuple(t[:, :, :mask_end_idx, :] for t in tup) for tup in past_key_values) 214 | cached_attn_mask = copy.deepcopy(compress_attn_mask)[:mask_end_idx].unsqueeze(0) 215 | else: 216 | cached_past_key_values = tuple(tuple(t[:, :, compress_attn_mask, :] for t in tup) for tup in past_key_values) 217 | cached_attn_mask = torch.ones(1, compress_attn_mask.sum(), dtype=torch.bool).to(args.device) 218 | print(f'Cached_attn_mask: {cached_attn_mask}\n') 219 | 220 | if args.model_name == 'gpt2-xl': 221 | position_offset = 0 222 | elif args.model_name == 'EleutherAI/gpt-j-6B': 223 | position_offset = torch.where(compress_attn_mask)[0][-1] + 1 224 | elif args.model_name == 'meta-llama/Llama-2-7b-hf': 225 | position_offset = None 226 | else: 227 | raise ValueError('model not supported') 228 | 229 | # evaluate with label anchor 230 | s_t = time.time() 231 | test_result = cached_evaluation(args.config, test_dataset, model, tokenizer, 232 | cached_past_key_values, cached_attn_mask, position_offset) 233 | print(f'Test label_anchor result: {test_result}\n') 234 | result_dict['test_result']['ours'].append(test_result) 235 | e_t = time.time() 236 | result_dict['time']['evaluate'].append(e_t - s_t) 237 | 238 | # save result_dict after each run 239 | with open(args.save_dir + '/result_dict.json', 'w') as f: 240 | json.dump(result_dict, f, indent=4) 241 | 242 | # delete all variables 243 | del model_wrapper, model, tokenizer, train_dataset, test_dataset 244 | del test_evaluator 245 | del result_dict 246 | 
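# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original script): cached_evaluation()
# above re-uses the key/value cache of the compressed demonstration instead of
# re-encoding it for every test batch. A minimal version of that pattern,
# assuming a Hugging Face causal LM `model` and a matching `tokenizer` with a
# pad token set (all names below are placeholders), looks like this:
def _kv_cache_reuse_sketch(model, tokenizer, demo_text, query_texts):
    with torch.no_grad():
        # encode the demonstration once and keep its key/value cache
        demo_tok = tokenizer(demo_text, return_tensors='pt').to(model.device)
        demo_out = model(**demo_tok, use_cache=True)
        past_key_values = demo_out.past_key_values
        # encode only the queries; repeat the cache along the batch dimension
        query_tok = tokenizer(query_texts, return_tensors='pt', padding=True).to(model.device)
        bs = query_tok['input_ids'].size(0)
        cached_past = tuple(tuple(t.repeat(bs, 1, 1, 1) for t in tup)
                            for tup in past_key_values)
        # the attention mask must cover cached demonstration tokens + query tokens
        attn_mask = torch.cat([demo_tok['attention_mask'].repeat(bs, 1),
                               query_tok['attention_mask']], dim=1)
        out = model(input_ids=query_tok['input_ids'],
                    attention_mask=attn_mask,
                    past_key_values=cached_past)
    # next-token logits at the last non-padding position of each query
    pred_loc = utils.last_one_indices(query_tok['attention_mask'])
    return out.logits[torch.arange(out.logits.size(0)), pred_loc]
# ---------------------------------------------------------------------------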
247 | 248 | # get args 249 | def get_args(): 250 | parser = argparse.ArgumentParser() 251 | parser.add_argument('--config_path', type=str, default='configs/config_label_anchor.py', help='path to config file') 252 | return parser.parse_args() 253 | 254 | 255 | if __name__ == "__main__": 256 | # get args 257 | args = get_args() 258 | # load config 259 | config = utils.load_config(args.config_path) 260 | # Generate all combinations of models and datasets 261 | combinations = list(itertools.product(config['models'], config['datasets'])) 262 | # Queue to hold tasks 263 | task_queue = Queue() 264 | for combine in combinations: 265 | task_queue.put(combine) 266 | 267 | def run_task(gpu_id, config): 268 | while not task_queue.empty(): 269 | model_name, dataset_name = task_queue.get() 270 | print(f"Running {model_name} on {dataset_name} with GPU {gpu_id}") 271 | input_args = argparse.Namespace() 272 | cur_config = copy.deepcopy(config) 273 | input_args.model_name = model_name 274 | input_args.dataset_name = dataset_name 275 | input_args.gpu = gpu_id 276 | input_args.config = cur_config 277 | try: 278 | main(input_args) 279 | finally: 280 | # Clean up CUDA memory after each task 281 | gc.collect() 282 | torch.cuda.empty_cache() 283 | print(f"CUDA memory cleared for GPU {gpu_id}") 284 | time.sleep(5) 285 | 286 | # Create a process for each GPU 287 | processes = [Process(target=run_task, args=(gpu_id, config)) for gpu_id in config['gpus']] 288 | # Start all processes 289 | for p in processes: 290 | p.start() 291 | # Wait for all processes to finish 292 | for p in processes: 293 | p.join() 294 | print("All tasks completed.") -------------------------------------------------------------------------------- /wrapper.py: -------------------------------------------------------------------------------- 1 | import math 2 | import string 3 | import random 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from contextlib import contextmanager 8 | from functools import reduce 9 | import numpy as np 10 | import utils 11 | import global_vars as gv 12 | from peft import get_peft_model, PromptTuningConfig 13 | 14 | 15 | class ModelWrapper(nn.Module): 16 | def __init__(self, model, tokenizer, model_config, device): 17 | super().__init__() 18 | self.model = model.eval() 19 | self.tokenizer = tokenizer 20 | self.model_config = model_config 21 | self.device = device 22 | self.num_layers = self._get_layer_num() 23 | self.latent_dict = {} 24 | self.linear_coef = None 25 | self.inject_layers = None 26 | print(f"The model has {self.num_layers} layers:") 27 | 28 | def reset_latent_dict(self): 29 | self.latent_dict = {} 30 | 31 | @contextmanager 32 | def extract_latent(self): 33 | handles = [] 34 | try: 35 | # attach hook 36 | for layer_idx in range(self.num_layers): 37 | handles.append( 38 | self._get_nested_attr(self._get_arribute_path(layer_idx, 'attn')).register_forward_hook( 39 | self.extract_hook_func(layer_idx, 'attn'))) 40 | handles.append( 41 | self._get_nested_attr(self._get_arribute_path(layer_idx, 'mlp')).register_forward_hook( 42 | self.extract_hook_func(layer_idx, 'mlp'))) 43 | handles.append( 44 | self._get_nested_attr(self._get_arribute_path(layer_idx, 'hidden')).register_forward_hook( 45 | self.extract_hook_func(layer_idx, 'hidden'))) 46 | yield 47 | finally: 48 | # remove hook 49 | for handle in handles: 50 | handle.remove() 51 | 52 | def extract_hook_func(self, layer_idx, target_module): 53 | if layer_idx not in self.latent_dict: 54 | self.latent_dict[layer_idx] = {} 55 | 
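# the nested hook below records the forward output of the hooked sub-module (attention, MLP, or the full decoder block) for this layer; tensors are detached and moved to CPU so latent extraction neither retains the autograd graph nor accumulates GPU memory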
def hook_func(module, inputs, outputs): 56 | if type(outputs) is tuple: 57 | outputs = outputs[0] 58 | self.latent_dict[layer_idx][target_module] = outputs.detach().cpu() 59 | return hook_func 60 | 61 | @contextmanager 62 | def inject_latent(self, context_vector_dict, config, linear_coef, train_mode=False): 63 | handles = [] 64 | assert self.inject_layers is not None, "inject_layers is not set!" 65 | inject_method = config['inject_method'] 66 | inject_pos = config['inject_pos'] 67 | add_noise = config['add_noise'] 68 | noise_scale = config['noise_scale'] 69 | try: 70 | # attach hook 71 | for layer_idx, layer in enumerate(self.inject_layers): 72 | for module_idx, module in enumerate(config['module']): 73 | context_vector_container = [context_vector_dict[layer][module].to(self.device)] 74 | strength = linear_coef[layer_idx, module_idx, :] 75 | inject_func = self.inject_hook_func(context_vector_container, strength, 76 | inject_method, add_noise, noise_scale, 77 | inject_pos, train_mode) 78 | handles.append( 79 | self._get_nested_attr(self._get_arribute_path(layer, module)). 80 | register_forward_hook(inject_func) 81 | ) 82 | yield 83 | finally: 84 | # remove hook 85 | print(f"Removing {len(handles)} hooks...") 86 | for handle in handles: 87 | handle.remove() 88 | 89 | def inject_hook_func(self, context_vector_container, strength, inject_method, 90 | add_noise, noise_scale, inject_pos, train_mode=False): 91 | 92 | def hook_func(module, inputs, outputs): 93 | if type(outputs) is tuple: 94 | output = outputs[0] 95 | else: 96 | output = outputs 97 | # init context_vector 98 | context_vector = context_vector_container[0] 99 | # expand inject_value to match output size (b, seq_len, d) 100 | context_vector = context_vector.expand(output.size(0), output.size(1), context_vector.size(-1)) 101 | 102 | if inject_method == 'add': 103 | output = output + F.relu(strength) * context_vector 104 | elif inject_method == 'linear': 105 | if inject_pos == 'all': 106 | output = strength[1] * output + strength[0] * context_vector 107 | else: 108 | if inject_pos == 'last': 109 | for i in range(output.size(0)): 110 | end_idx = gv.ATTN_MASK_END[i] - 1 111 | content = strength[1] * output[i, end_idx, :].clone().detach() + strength[0] * context_vector[i, end_idx, :] 112 | output[i, end_idx, :] = content 113 | elif inject_pos == 'first': 114 | content = strength[1] * output[:, 0, :].clone().detach() + strength[0] * context_vector[:, 0, :] 115 | output[:, 0, :] = content 116 | elif inject_pos == 'random': 117 | for i in range(output.size(0)): 118 | end_idx = gv.ATTN_MASK_END[i] 119 | random_idx = random.randint(0, end_idx) 120 | content = strength[1] * output[i, random_idx, :].clone().detach() + strength[0] * context_vector[i, random_idx, :] 121 | output[i, random_idx, :] = content 122 | else: 123 | raise ValueError("only support all, last, first or random!") 124 | 125 | elif inject_method == 'balance': 126 | a, b = strength[0], strength[1] 127 | output = ((1.0 - a) * output + a * context_vector) * b 128 | else: 129 | raise ValueError("only support add, linear or balance!") 130 | 131 | if add_noise and train_mode: 132 | # get l2_norm of output and use it as a scalar to scale noise, make sure no gradient 133 | output_norm = torch.norm(output, p=2, dim=-1).detach().unsqueeze(-1) 134 | noise = torch.randn_like(output).detach() 135 | output += noise * output_norm * noise_scale 136 | 137 | if type(outputs) is tuple: 138 | outputs = list(outputs) 139 | outputs[0] = output 140 | outputs = tuple(outputs) 141 | else: 142 | outputs 
= output 143 | return outputs 144 | return hook_func 145 | 146 | 147 | @contextmanager 148 | def replace_latent(self, context_vector_dict, target_layers, config): 149 | handles = [] 150 | try: 151 | # attach hook 152 | for _, layer in enumerate(target_layers): 153 | for _, module in enumerate(config['module']): 154 | context_vector_container = [context_vector_dict[layer][module].to(self.device)] 155 | inject_func = self.replace_hook_func(context_vector_container) 156 | handles.append( 157 | self._get_nested_attr(self._get_arribute_path(layer, module)). 158 | register_forward_hook(inject_func)) 159 | yield 160 | finally: 161 | # remove hook 162 | print(f"Removing {len(handles)} hooks...") 163 | for handle in handles: 164 | handle.remove() 165 | 166 | def replace_hook_func(self, context_vector_container): 167 | def hook_func(module, inputs, outputs): 168 | if type(outputs) is tuple: 169 | output = outputs[0] 170 | else: 171 | output = outputs 172 | # init context_vector 173 | context_vector = context_vector_container[0] 174 | # replace hidden states of last token position with context_vector 175 | for i in range(output.size(0)): 176 | end_idx = gv.ATTN_MASK_END[i] 177 | output[i, end_idx, :] = context_vector 178 | 179 | if type(outputs) is tuple: 180 | outputs = list(outputs) 181 | outputs[0] = output 182 | outputs = tuple(outputs) 183 | else: 184 | outputs = output 185 | return outputs 186 | return hook_func 187 | 188 | 189 | def get_context_vector(self, all_latent_dicts, config): 190 | if len(all_latent_dicts) == 1: 191 | latent_dict = all_latent_dicts[0] 192 | output_dict = {} 193 | for layer, sub_dict in latent_dict.items(): 194 | output_dict[layer] = {} 195 | for module in config['module']: 196 | latent_value = sub_dict[module] 197 | if config['tok_pos'] == 'last': 198 | latent_value = latent_value[:, -1, :].squeeze() 199 | elif config['tok_pos'] == 'first': 200 | latent_value = latent_value[:, 0, :].squeeze() 201 | elif config['tok_pos'] == 'random': 202 | latent_value = latent_value[:, random.randint(0, latent_value.size(1) - 1), :].squeeze() # randint is inclusive on both ends 203 | else: 204 | raise ValueError("only support last, first or random!") 205 | output_dict[layer][module] = latent_value.detach().to('cpu') 206 | else: 207 | # collect per-demonstration context vectors for each module 208 | ensemble_dict = {module:[] for module in config['module']} # {module_name: []} 209 | for _, latent_dict in enumerate(all_latent_dicts): 210 | cur_dict = {module:[] for module in config['module']} # {module_name: []} 211 | for layer, sub_dict in latent_dict.items(): 212 | for module in config['module']: 213 | latent_value = sub_dict[module] # (b, seq_len, d) 214 | if config['tok_pos'] == 'last': 215 | latent_value = latent_value[:, -1, :].squeeze() 216 | elif config['tok_pos'] == 'first': 217 | latent_value = latent_value[:, 0, :].squeeze() 218 | elif config['tok_pos'] == 'random': 219 | latent_value = latent_value[:, random.randint(0, latent_value.size(1) - 1), :].squeeze() 220 | else: 221 | raise ValueError("only support last, first or random!") 222 | cur_dict[module].append(latent_value) 223 | 224 | for module, latent_list in cur_dict.items(): 225 | cur_latent = torch.stack(latent_list, dim=0) # (layer_num, d) 226 | ensemble_dict[module].append(cur_latent) 227 | 228 | for module, latent_list in ensemble_dict.items(): 229 | if config['post_fuse_method'] == 'mean': 230 | context_vector = torch.stack(latent_list, dim=0).mean(dim=0) # (layer_num, d) 231 | ensemble_dict[module] = context_vector 232 | elif config['post_fuse_method'] == 'pca': 233 |
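# PCA fusion: stack the per-demonstration context vectors, flatten the layer dimension into the feature dimension, fit a one-component PCA across demonstrations, and use the first principal direction plus the PCA mean as the fused context vector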
latents = torch.stack(latent_list, dim=0) # (ensemble_num, layer_num, d) 234 | ensemble_num, layer_num, d = latents.size() 235 | latents = latents.view(ensemble_num, -1) # (ensemble_num*layer_num, d) 236 | # apply pca 237 | pca = utils.PCA(n_components=1).to(latents.device).fit(latents.float()) 238 | context_vector = (pca.components_.sum(dim=0, keepdim=True) + pca.mean_).mean(0) 239 | ensemble_dict[module] = context_vector.view(layer_num, d) # (layer_num, d) 240 | else: 241 | raise ValueError("Unsupported ensemble method!") 242 | # reorganize ensemble_dict into layers 243 | layers = list(all_latent_dicts[0].keys()) 244 | output_dict = {layer: {} for layer in layers} 245 | for module, context_vector in ensemble_dict.items(): 246 | for layer_idx, layer in enumerate(layers): 247 | output_dict[layer][module] = context_vector[layer_idx, :].detach().to('cpu') # (d) 248 | 249 | return output_dict 250 | 251 | 252 | def calibrate_strength(self, context_vector_dict, dataset, config, 253 | save_dir=None, run_name=None): 254 | # prepare label dict 255 | label_map = {} 256 | ans_txt_list = dataset.get_dmonstration_template()['options'] 257 | for label, ans_txt in enumerate(ans_txt_list): 258 | if 'gpt' in self.tokenizer.__class__.__name__.lower(): 259 | ans_txt = ' ' + ans_txt # add space to the beginning of answer 260 | ans_tok = self.tokenizer.encode(ans_txt, add_special_tokens=False)[0] # use the first token if more than one token 261 | print(f"ans_txt: {ans_txt}, ans_tok: {ans_tok}") 262 | label_map[label] = ans_tok # index is the label 263 | print(f"label_map: {label_map}") 264 | 265 | # frozen all parameters 266 | for param in self.model.parameters(): 267 | param.requires_grad = False 268 | 269 | # init optimizer 270 | optim_paramters = [{'params': self.linear_coef}] 271 | if config['optim'] == 'sgd': 272 | optimizer = torch.optim.SGD(optim_paramters, lr=config['lr'], 273 | weight_decay=config['wd']) 274 | elif config['optim'] == 'adamW': 275 | optimizer = torch.optim.AdamW(optim_paramters, config['lr'], 276 | weight_decay=config['wd']) 277 | elif config['optim'] == 'adam': 278 | optimizer = torch.optim.Adam(optim_paramters, config['lr']) 279 | else: 280 | raise ValueError('optim must be sgd, adamW or adam!') 281 | 282 | # get all_data 283 | all_data = dataset.all_data 284 | 285 | # init lr_scheduler 286 | epochs, batch_size = config['epochs'], config['grad_bs'] 287 | total_steps = epochs * len(all_data) // batch_size 288 | warmup_steps = int((0.05*epochs) * (len(all_data) // batch_size)) 289 | lr_lambda = lambda step: min(1.0, step / warmup_steps) * (1 + math.cos(math.pi * step / total_steps)) / 2 \ 290 | if step > warmup_steps else step / warmup_steps 291 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) 292 | 293 | # train 294 | print('Calibrating strength params...') 295 | with self.inject_latent(context_vector_dict, config, 296 | self.linear_coef, train_mode=True): 297 | loss_list = [] 298 | all_data_index = list(range(len(all_data))) 299 | epoch_iter = len(all_data) // batch_size 300 | for _ in range(epochs): 301 | epoch_loss = [] 302 | for i in range(epoch_iter): 303 | np.random.shuffle(all_data_index) 304 | batch_index = all_data_index[:batch_size] 305 | batch_data = [all_data[idx] for idx in batch_index] 306 | batch_input, batch_label = [], [] 307 | for data in batch_data: 308 | input_str, ans_list, label = dataset.apply_template(data) 309 | 310 | # collect single demonstration example 311 | if config['cali_example_method'] == 'normal': 312 | pass 313 | elif 
config['cali_example_method'] == 'random_label': 314 | label = random.choice(list(range(len(ans_list)))) 315 | else: 316 | raise ValueError("only support normal or random_label!") 317 | 318 | batch_input.append(input_str) 319 | batch_label.append(label) 320 | 321 | input_tok = self.tokenizer(batch_input, return_tensors='pt', padding=True) 322 | input_ids = input_tok['input_ids'].to(self.device) 323 | attn_mask = input_tok['attention_mask'].to(self.device) 324 | pred_loc = utils.last_one_indices(attn_mask).to(self.device) 325 | # set global vars 326 | gv.ATTN_MASK_END = pred_loc 327 | gv.ATTN_MASK_START = torch.zeros_like(pred_loc) 328 | # forward 329 | logits = self.model(input_ids=input_ids, attention_mask=attn_mask).logits 330 | # get prediction logits 331 | pred_logits = logits[torch.arange(logits.size(0)), pred_loc] 332 | # get loss 333 | gt_label = torch.tensor([label_map[label] for label in batch_label]).to(self.device) 334 | loss = F.cross_entropy(pred_logits, gt_label, reduction='mean') 335 | epoch_loss.append(loss.item()) 336 | # update strength params 337 | optimizer.zero_grad() 338 | loss.backward() 339 | optimizer.step() 340 | scheduler.step() 341 | cur_lr = optimizer.param_groups[0]['lr'] 342 | print(f'Epoch {_+1}/{epochs}, batch {i//batch_size+1}/{len(all_data)//batch_size+1}, loss: {loss.item()}, lr: {cur_lr}') 343 | epoch_loss = np.mean(epoch_loss) 344 | loss_list.append(epoch_loss) 345 | 346 | # fronzen all learnable strength params 347 | self.linear_coef.requires_grad = False 348 | # set model to eval mode 349 | self.model.eval() 350 | # plot loss curve and save it 351 | utils.plot_loss_curve(loss_list, save_dir + f'/{run_name}_loss_curve.png') 352 | 353 | 354 | def softprompt(self, config, dataset, save_dir=None, run_name=None): 355 | pt_config = PromptTuningConfig(**config['pt_config']) 356 | peft_model = get_peft_model(self.model, pt_config) 357 | 358 | # prepare label dict 359 | label_map = {} 360 | ans_txt_list = dataset.get_dmonstration_template()['options'] 361 | for label, ans_txt in enumerate(ans_txt_list): 362 | if 'gpt' in self.tokenizer.__class__.__name__.lower(): 363 | ans_txt = ' ' + ans_txt # add space to the beginning of answer 364 | ans_tok = self.tokenizer.encode(ans_txt, add_special_tokens=False)[0] # use the first token if more than one token 365 | print(f"ans_txt: {ans_txt}, ans_tok: {ans_tok}") 366 | label_map[label] = ans_tok # index is the label 367 | print(f"label_map: {label_map}") 368 | 369 | # print trainable parameters 370 | peft_model.print_trainable_parameters() 371 | print(f'PEFT model:\n {peft_model}') 372 | # set model to peft model 373 | self.model = peft_model 374 | 375 | # init optimizer 376 | optim_paramters = [{'params': self.model.parameters()}] 377 | if config['optim'] == 'sgd': 378 | optimizer = torch.optim.SGD(optim_paramters, lr=config['lr'], 379 | weight_decay=config['wd']) 380 | elif config['optim'] == 'adamW': 381 | optimizer = torch.optim.AdamW(optim_paramters, config['lr'], 382 | weight_decay=config['wd']) 383 | elif config['optim'] == 'adam': 384 | optimizer = torch.optim.Adam(optim_paramters, config['lr']) 385 | else: 386 | raise ValueError('optim must be sgd, adamW or adam!') 387 | 388 | # get all data 389 | all_data = dataset.all_data 390 | 391 | # init lr_scheduler 392 | epochs, batch_size = config['epochs'], config['grad_bs'] 393 | total_steps = epochs * len(all_data) // batch_size 394 | warmup_steps = int((0.05*epochs) * (len(all_data) // batch_size)) 395 | lr_lambda = lambda step: min(1.0, step / warmup_steps) * (1 
+ math.cos(math.pi * step / total_steps)) / 2 \ 396 | if step > warmup_steps else step / warmup_steps 397 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) 398 | 399 | # train 400 | loss_list = [] 401 | all_data_index = list(range(len(all_data))) 402 | for _ in range(epochs): 403 | epoch_loss = [] 404 | np.random.shuffle(all_data_index) 405 | for i in range(0, len(all_data), batch_size): 406 | batch_index = all_data_index[i: i + batch_size] 407 | batch_data = [all_data[idx] for idx in batch_index] 408 | batch_input, batch_label = [], [] 409 | for data in batch_data: 410 | input_str, _, label = dataset.apply_template(data) 411 | batch_input.append(input_str) 412 | batch_label.append(label) 413 | 414 | input_tok = self.tokenizer(batch_input, return_tensors='pt', padding=True) 415 | input_ids = input_tok['input_ids'].to(self.device) 416 | attn_mask = input_tok['attention_mask'].to(self.device) 417 | pred_loc = utils.last_one_indices(attn_mask).to(self.device) 418 | # forward 419 | logits = self.model(input_ids=input_ids, attention_mask=attn_mask).logits 420 | # get prediction logits 421 | pred_logits = logits[torch.arange(logits.size(0)), pred_loc] 422 | # get loss 423 | gt_label = torch.tensor([label_map[label] for label in batch_label]).to(self.device) 424 | loss = F.cross_entropy(pred_logits, gt_label, reduction='mean') 425 | epoch_loss.append(loss.item()) 426 | # update strength params 427 | optimizer.zero_grad() 428 | loss.backward() 429 | optimizer.step() 430 | scheduler.step() 431 | epoch_loss = np.mean(epoch_loss) 432 | loss_list.append(epoch_loss) 433 | 434 | # fronzen all learnable strength params 435 | for param in self.model.parameters(): 436 | param.requires_grad = False 437 | # set model to eval mode 438 | self.model.eval() 439 | # plot loss curve and save it 440 | utils.plot_loss_curve(loss_list, save_dir + f'/{run_name}_loss_curve.png') 441 | 442 | 443 | def init_strength(self, config): 444 | # get linear_coef size 445 | if type(config['layer']) == str: 446 | if config['layer'] == 'all': 447 | layers = list(range(self.num_layers)) 448 | layer_dim = len(layers) 449 | elif config['layer'] == 'late': 450 | layers = list(range((self.num_layers*2)//3, self.num_layers)) 451 | layer_dim = len(layers) 452 | elif config['layer'] == 'early': 453 | layers = list(range(self.num_layers//3)) 454 | layer_dim = len(layers) 455 | elif config['layer'] == 'mid': 456 | layers = list(range(self.num_layers//3, (self.num_layers*2)//3)) 457 | layer_dim = len(layers) 458 | elif type(config['layer']) == list: 459 | layers = config['layer'] 460 | layer_dim = len(layers) 461 | else: 462 | raise ValueError("layer must be all, late, early, mid or a list of layer index!") 463 | 464 | if config['inject_method'] == 'add': 465 | param_size = (layer_dim, len(config['module']), 1) # (layer_num, module_num, 1) 466 | elif config['inject_method'] in ['linear', 'balance']: 467 | param_size = (layer_dim, len(config['module']), 2) # (layer_num, module_num, 2) 468 | else: 469 | raise ValueError("only support add, linear or balance!") 470 | # set inject_layers 471 | self.inject_layers = layers 472 | # init linear_coef 473 | linear_coef = torch.zeros(param_size, device=self.device) 474 | linear_coef += torch.tensor(config['init_value'], device=self.device) 475 | self.linear_coef = nn.Parameter(linear_coef) 476 | print(f"linear_coef shape: {self.linear_coef.shape}\n") 477 | if not self.linear_coef.is_leaf: 478 | raise ValueError("linear_coef is not a leaf tensor, which is required for optimization.") 
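    # linear_coef holds one learnable strength per (injected layer, module): a single scalar for 'add', and a pair of coefficients for 'linear' (context / output mixing weights) and 'balance' (interpolation weight and rescale factor); init_strength() must be called before inject_latent() or calibrate_strength() so that inject_layers and linear_coef exist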
479 | 480 | 481 | def init_noise_context_vector(self, context_vector_dict): 482 | # init learnable context_vector 483 | for layer, sub_dict in context_vector_dict.items(): 484 | for module, latent in sub_dict.items(): 485 | noise_vector = torch.randn_like(latent).detach().cpu() 486 | context_vector_dict[layer][module] = noise_vector 487 | return context_vector_dict 488 | 489 | 490 | def _get_nested_attr(self, attr_path): 491 | """ 492 | Accesses nested attributes of an object based on a dot-separated string path. 493 | 494 | :param obj: The object (e.g., a model). 495 | :param attr_path: A dot-separated string representing the path to the nested attribute. 496 | For example, 'transformer.h' or 'model.layers'. 497 | :return: The attribute at the specified path. 498 | """ 499 | try: 500 | return reduce(getattr, attr_path.split('.'), self.model) 501 | except AttributeError: 502 | raise AttributeError(f"Attribute path '{attr_path}' not found.") 503 | 504 | def _get_layer_num(self): 505 | raise NotImplementedError("Please implement get_layer_num function for each model!") 506 | 507 | def _get_arribute_path(self, layer_idx, target_module): 508 | raise NotImplementedError("Please implement get_arribute_path function for each model!") 509 | 510 | 511 | class LlamaWrapper(ModelWrapper): 512 | def __init__(self, model, tokenizer, model_config, device): 513 | super().__init__(model, tokenizer, model_config, device) 514 | self.embed_matrix = self.model.model.embed_tokens.weight.data 515 | self.embed_dim = self.model_config.hidden_size 516 | self.last_norm = self.model.model.norm 517 | 518 | def _get_layer_num(self): 519 | return len(self.model.model.layers) 520 | 521 | def _get_arribute_path(self, layer_idx, target_module): 522 | if target_module == "attn": 523 | return f"model.layers.{layer_idx}.self_attn" 524 | elif target_module == "mlp": 525 | return f"model.layers.{layer_idx}.mlp" 526 | elif target_module == "hidden": 527 | return f"model.layers.{layer_idx}" 528 | else: 529 | raise ValueError("only support att or mlp!") 530 | 531 | 532 | class GPTWrapper(ModelWrapper): 533 | def __init__(self, model, tokenizer, model_config, device): 534 | super().__init__(model, tokenizer, model_config, device) 535 | self.embed_matrix = self.model.transformer.wte.weight.data 536 | self.embed_dim = self.embed_matrix.size(-1) 537 | self.last_norm = self.model.transformer.ln_f 538 | 539 | def _get_layer_num(self): 540 | return len(self.model.transformer.h) 541 | 542 | def _get_arribute_path(self, layer_idx, target_module): 543 | if target_module == "attn": 544 | return f"transformer.h.{layer_idx}.attn" 545 | elif target_module == "mlp": 546 | return f"transformer.h.{layer_idx}.mlp" 547 | elif target_module == "hidden": 548 | return f"transformer.h.{layer_idx}" 549 | else: 550 | raise ValueError("only support att or mlp!") --------------------------------------------------------------------------------
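A minimal usage sketch of the wrapper classes above, mirroring the flow in run_i2cl_infer.py: extract per-layer latents for each demonstration, fuse them into a context vector, initialise the injection strengths, and answer queries with the vector injected. The config values, demonstration strings and query below are illustrative assumptions, not the repository's defaults.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from wrapper import LlamaWrapper

device = 'cuda'
model_name = 'meta-llama/Llama-2-7b-hf'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
wrapper = LlamaWrapper(model, tokenizer, model.config, device)

# illustrative config; the real values live in configs/config_i2cl*.py
config = {'module': ['mlp', 'attn'], 'tok_pos': 'last', 'post_fuse_method': 'mean',
          'layer': 'all', 'inject_method': 'linear', 'inject_pos': 'all',
          'init_value': [0.1, 1.0], 'add_noise': False, 'noise_scale': 0.0}

demos = ['Review: a gripping film .\nSentiment: positive\n',
         'Review: a dull mess .\nSentiment: negative\n']

# 1) extract per-layer latents for each demonstration
all_latent_dicts = []
with torch.no_grad():
    for demo in demos:
        with wrapper.extract_latent():
            tok = tokenizer(demo, return_tensors='pt').to(device)
            _ = model(**tok)
        all_latent_dicts.append(wrapper.latent_dict)
        wrapper.reset_latent_dict()

# 2) fuse the latents into one context vector per (layer, module)
context_vector_dict = wrapper.get_context_vector(all_latent_dicts, config)

# 3) initialise injection strengths (alternatively calibrated with calibrate_strength)
wrapper.init_strength(config)

# 4) answer a query with the context vector injected, no demonstration in the prompt
query = 'Review: an unforgettable performance .\nSentiment:'
with torch.no_grad(), wrapper.inject_latent(context_vector_dict, config, wrapper.linear_coef):
    tok = tokenizer(query, return_tensors='pt').to(device)
    logits = model(**tok).logits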