├── figures
│   ├── overview.png
│   └── speed_vs_accuracy_vs_cache.png
├── global_vars.py
├── configs
│   ├── config_i2cl_transfer_learning.py
│   ├── config_i2cl_infer.py
│   ├── config_label_anchor.py
│   ├── config_soft_prompt.py
│   ├── config_task_vector.py
│   └── config_i2cl.py
├── my_datasets
│   ├── __init__.py
│   ├── rotten_tomatoes.py
│   ├── sst5.py
│   ├── agnews.py
│   ├── emo.py
│   ├── sst2.py
│   ├── subj.py
│   ├── trec.py
│   ├── hate_speech18.py
│   ├── dbpedia.py
│   └── basetask.py
├── LICENSE
├── README.md
├── requirements.txt
├── evaluator.py
├── run_soft_prompt.py
├── run_task_vector.py
├── run_i2cl.py
├── utils.py
├── run_i2cl_transfer_learning.py
├── run_i2cl_infer.py
├── run_label_anchor.py
└── wrapper.py
/figures/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LzVv123456/I2CL/HEAD/figures/overview.png
--------------------------------------------------------------------------------
/figures/speed_vs_accuracy_vs_cache.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LzVv123456/I2CL/HEAD/figures/speed_vs_accuracy_vs_cache.png
--------------------------------------------------------------------------------
/global_vars.py:
--------------------------------------------------------------------------------
1 | CUR_ATTN_MASK = None
2 | ATTN_MASK_START = None
3 | ATTN_MASK_END = None
4 | SUPPORT_MODEL = ['gpt2-xl', 'EleutherAI/gpt-j-6B', 'meta-llama/Llama-2-7b-hf']
5 |
--------------------------------------------------------------------------------
/configs/config_i2cl_transfer_learning.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../')
3 | import my_datasets as md
4 |
5 | config = {}
6 | config['gpus'] = ['0']
7 | config['exp_name'] = 'exps/i2cl_transfer_learning'
8 | config['models'] = ['meta-llama/Llama-2-7b-hf']
9 | config['datasets'] = list(md.target_datasets.keys())
10 |
11 | config['run_num'] = 1
12 |
13 | config['threshold'] = 0.8
14 | config['temp'] = 0.5
15 | config['target_path'] = 'exps/i2cl'
16 |
--------------------------------------------------------------------------------
/configs/config_i2cl_infer.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../')
3 | import my_datasets as md
4 |
5 |
6 | config = {}
7 |
8 | config['gpus'] = ['0']
9 | config['exp_name'] = 'exps/i2cl_infer'
10 | config['models'] = ['meta-llama/Llama-2-7b-hf']
11 | config['datasets'] = list(md.target_datasets.keys())
12 | config['run_baseline'] = True
13 | config['downstream_datasets'] = None # None uses the same dataset as the source; alternatively, specify a list of target downstream datasets
14 |
15 | config['target_path'] = 'exps/i2cl'
16 | config['use_new_demon'] = True # whether to use new demonstrations to generate context vectors
--------------------------------------------------------------------------------
/my_datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .basetask import BaseTask
2 |
3 | from .sst2 import SST2
4 | from .dbpedia import DBPedia
5 | from .sst5 import SST5
6 | from .trec import TREC
7 | from .agnews import AGNews
8 | from .subj import Subj
9 | from .rotten_tomatoes import RottenTomatoes
10 | from .hate_speech18 import HateSpeech18
11 | from .emo import EMO
12 |
13 |
14 | target_datasets = {
15 | 'agnews': AGNews,
16 | 'dbpedia': DBPedia,
17 | 'sst5': SST5,
18 | 'trec': TREC,
19 | 'sst2': SST2,
20 | 'subj': Subj,
21 | 'mr': RottenTomatoes,
22 | 'hate_speech18': HateSpeech18,
23 | 'emo': EMO,
24 | }
25 |
26 | dataset_dict = {}
27 | dataset_dict.update(target_datasets)
28 |
29 | def get_dataset(dataset, *args, **kwargs) -> BaseTask:
30 | return dataset_dict[dataset](task_name=dataset, *args, **kwargs)
--------------------------------------------------------------------------------
/configs/config_label_anchor.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../')
3 | import my_datasets as md
4 |
5 |
6 | config = {}
7 | # general
8 | config['exp_name'] = 'exps/label_anchor'
9 | config['gpus'] = ['0']
10 | config['models'] = ['meta-llama/Llama-2-7b-hf'] # 'gpt2-xl', 'meta-llama/Llama-2-7b-hf', 'EleutherAI/gpt-j-6B'
11 | config['datasets'] = list(md.target_datasets.keys())
12 | config['seed'] = 42
13 | config['run_num'] = 5
14 | config['run_baseline'] = True
15 | config['metric'] = 'acc' # 'acc', 'macro_f1'
16 | config['bs'] = 2
17 | config['load_in_8bit'] = False
18 | config['use_cache'] = True
19 |
20 | # data
21 | config['shot_per_class'] = 5
22 | config['test_data_num'] = 500
23 | config['sample_method'] = 'uniform' # 'random', 'uniform'
24 | config['use_instruction'] = False
25 | config['add_extra_query'] = False
26 | config['example_separator'] = '\n'
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Jack Li
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/configs/config_soft_prompt.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../')
3 | import my_datasets as md
4 |
5 |
6 | config = {}
7 | # general
8 | config['exp_name'] = 'exps/soft_prompt'
9 | config['gpus'] = ['0']
10 | config['models'] = ['meta-llama/Llama-2-7b-hf'] # 'gpt2-xl', 'meta-llama/Llama-2-7b-hf', 'EleutherAI/gpt-j-6B'
11 | config['datasets'] = list(md.target_datasets.keys())
12 | config['seed'] = 42
13 | config['run_num'] = 5
14 | config['run_baseline'] = True
15 | config['metric'] = 'acc' # 'acc', 'macro_f1'
16 | config['bs'] = 2
17 | config['load_in_8bit'] = False
18 | config['use_cache'] = True
19 | config['example_separator'] = '\n'
20 |
21 | # data
22 | config['shot_per_class'] = 5
23 | config['test_data_num'] = 500
24 | config['sample_method'] = 'uniform' # 'random', 'uniform'
25 | config['add_extra_query'] = False
26 |
27 | # prompt_tuning
28 | pt_config = {}
29 | pt_config['task_type'] = 'CAUSAL_LM'
30 | pt_config['num_virtual_tokens'] = 1
31 | pt_config['num_layers'] = 28
32 | config['pt_config'] = pt_config
33 |
34 | # optimization
35 | config['epochs'] = 50
36 | config['optim'] = 'adamW' # 'adam', 'adamW', 'sgd'
37 | config['grad_bs'] = 4
38 | config['lr'] = 0.1
39 | config['wd'] = 1e-3
--------------------------------------------------------------------------------
/configs/config_task_vector.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../')
3 | import my_datasets as md
4 |
5 |
6 | config = {}
7 | # general
8 | config['exp_name'] = 'exps/task_vector'
9 | config['gpus'] = ['0']
10 | config['models'] = ['meta-llama/Llama-2-7b-hf'] # 'gpt2-xl', 'meta-llama/Llama-2-7b-hf', 'EleutherAI/gpt-j-6B'
11 | config['datasets'] = list(md.target_datasets.keys())
12 | config['seed'] = 42
13 | config['run_num'] = 5
14 | config['run_baseline'] = True
15 | config['metric'] = 'acc' # 'acc', 'macro_f1'
16 | config['bs'] = 2
17 | config['load_in_8bit'] = False
18 | config['use_cache'] = True
19 |
20 | # context vector
21 | config['layer'] = 'all' # all, late, early, mid
22 | config['tok_pos'] = 'last'
23 | config['module'] = ['hidden'] # 'mlp', 'attn', 'hidden'
24 | config['gen_cv_method'] = 'context' # 'context', 'noise'
25 | config['post_fuse_method'] = 'mean' # 'mean', 'pca'
26 | config['split_demon'] = False # split demonstration into separate examples
27 |
28 | # data
29 | config['shot_per_class'] = 5
30 | config['val_data_num'] = 32
31 | config['test_data_num'] = 500
32 | config['sample_method'] = 'uniform' # 'random', 'uniform'
33 | config['use_instruction'] = False
34 | config['add_extra_query'] = True
35 | config['example_separator'] = '\n'
--------------------------------------------------------------------------------
/configs/config_i2cl.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../')
3 | import my_datasets as md
4 |
5 |
6 | config = {}
7 | # general
8 | config['exp_name'] = 'exps/i2cl'
9 | config['gpus'] = ['0']
10 | config['models'] = ['meta-llama/Llama-2-7b-hf'] # 'gpt2-xl', 'EleutherAI/gpt-j-6B'
11 | config['datasets'] = list(md.target_datasets.keys())
12 | config['seed'] = 42
13 | config['run_num'] = 5 # number of runs
14 | config['run_baseline'] = True # whether run baseline
15 | config['metric'] = 'acc' # 'acc', 'macro_f1'
16 | config['bs'] = 2 # batch size
17 | config['load_in_8bit'] = False
18 | config['use_cache'] = True # whether use kv cache
19 | config['demo_sample_method'] = 'random' # 'random', 'deficient'
20 |
21 | # calibrate
22 | config['add_noise'] = True # whether add noise
23 | config['noise_scale'] = 0.001 # noise scale
24 | config['epochs'] = 100 # number of epochs
25 | config['optim'] = 'adamW' # 'adam', 'adamW', 'sgd'
26 | config['grad_bs'] = 2 # batch size for calibration
27 | config['lr'] = 0.01
28 | config['wd'] = 1e-3
29 | config['cali_example_method'] = 'normal' # 'normal', 'random_label'
30 |
31 | # context vector
32 | config['layer'] = 'all' # all, early, mid, late
33 | config['tok_pos'] = 'last' # 'random', 'first', 'last'
34 | config['inject_method'] = 'linear' # 'linear', 'constraint', 'add'
35 | config['inject_pos'] = 'all' # 'all', 'first', 'last', 'random'
36 | config['init_value'] = [0.1, 1.0] # linear and constraint: [0.1, 1.0], add: [0.1]
37 | config['module'] = ['mlp', 'attn'] # 'mlp', 'attn', 'hidden'
38 | config['gen_cv_method'] = 'context' # 'context', 'noise'
39 | config['post_fuse_method'] = 'mean' # 'mean', 'pca'
40 | config['split_demon'] = True # split demonstration into separate examples
41 | config['gen_example_method'] = 'normal' # 'normal', 'random_label', 'no_template', 'random_order'
42 |
43 | # data
44 | config['shot_per_class'] = 5 # number of shots per class
45 | config['val_data_num'] = 32
46 | config['test_data_num'] = 500 # number of test data
47 | config['sample_method'] = 'uniform' # 'random', 'uniform'
48 | config['use_instruction'] = False
49 | config['add_extra_query'] = False
50 | config['example_separator'] = '\n'
--------------------------------------------------------------------------------
/my_datasets/rotten_tomatoes.py:
--------------------------------------------------------------------------------
1 | try:
2 | from . import BaseTask
3 | except:
4 | from basetask import BaseTask
5 | # from basetask import BaseTask
6 | from datasets import load_dataset
7 |
8 |
9 |
10 | class RottenTomatoes(BaseTask):
11 | def __init__(self, split='train', *args, **kwargs):
12 | super().__init__(*args, **kwargs)
13 | self.task_type = "classification"
14 | # set split name
15 | load_split = split
16 | print(f"Loading {load_split} data from rotten_tomatoes ...")
17 | # class_num
18 | self.class_num = 2
19 | # load dataset
20 | self.dataset = load_dataset('rotten_tomatoes', split=load_split, keep_in_memory=True)
21 | # get all data
22 | self.all_data = [data for data in self.dataset]
23 | # get all labels
24 | self.all_labels = self.get_all_labels()
25 | # random sample data
26 | if self.max_data_num is not None:
27 | self.random_sample_data(self.max_data_num)
28 | # print a few examples
29 | print(f'Dataset length is {len(self.all_data)}')
30 | print("Example data:")
31 | self.print_data([0])
32 |
33 | def get_dmonstration_template(self):
34 | template = {
35 | 'input': 'Review: {text}\nSentiment:',
36 | 'ans': '{answer}',
37 | 'options': ['negative', 'positive'],
38 | 'format': ['Review:', 'Sentiment:']
39 | }
40 | return template
41 |
42 | def get_task_instruction(self):
43 | task_instruction = "Classify the sentiment of the sentence into one of the categories: positive or negative.\n\n"
44 | return task_instruction
45 |
46 | def apply_template(self, data):
47 | """
48 | PS: label should always be an integer and can be used to index the options
49 | """
50 | template = self.get_dmonstration_template()
51 | input_template = template['input']
52 | ans_template = template['ans']
53 | options = template['options']
54 | input_str = input_template.replace("{text}", data["text"])
55 | # answers can have multiple options and is a list
56 | answer_str = [ans_template.replace("{answer}", options[i]) for i in range(len(options))]
57 | label = data["label"]
58 | return input_str, answer_str, label
59 |
60 | def get_all_labels(self):
61 | return [data["label"] for data in self.all_data]
--------------------------------------------------------------------------------
/my_datasets/sst5.py:
--------------------------------------------------------------------------------
1 | try:
2 | from . import BaseTask
3 | except:
4 | from basetask import BaseTask
5 | # from basetask import BaseTask
6 | from datasets import load_dataset
7 |
8 |
9 | class SST5(BaseTask):
10 | def __init__(self, split='train', *args, **kwargs):
11 | super().__init__(*args, **kwargs)
12 | self.task_type = "classification"
13 | # set split name
14 | load_split = split
15 | print(f"Loading {load_split} data from sst5 ...")
16 | # class_num
17 | self.class_num = 5
18 | # load dataset
19 | self.dataset = load_dataset('SetFit/sst5', split=load_split, keep_in_memory=True)
20 | # get all data
21 | self.all_data = [data for data in self.dataset]
22 | # get all labels
23 | self.all_labels = self.get_all_labels()
24 | # random sample data
25 | if self.max_data_num is not None:
26 | self.random_sample_data(self.max_data_num)
27 | # print a few examples
28 | print(f'Dataset length is {len(self.all_data)}')
29 | print("Example data:")
30 | self.print_data([0])
31 |
32 | def get_dmonstration_template(self):
33 | template = {
34 | 'input': 'Sentence: {text}\nSentiment:',
35 | 'ans': '{answer}',
36 | 'options': ["terrible", "negative", "neutral", "positive", "great"],
37 | 'format': ['Sentence:', 'Sentiment:']
38 | }
39 | return template
40 |
41 | def get_task_instruction(self):
42 | task_instruction = "Classify the sentiment of sentence into one of the categories: terrible, negative, neutral, positive, great.\n\n"
43 | return task_instruction
44 |
45 | def apply_template(self, data):
46 | """
47 | PS: label should always be an integer and can be used to index the options
48 | """
49 | template = self.get_dmonstration_template()
50 | input_template = template['input']
51 | ans_template = template['ans']
52 | options = template['options']
53 | input_str = input_template.replace("{text}", data["text"])
54 | # answers can have multiple options and is a list
55 | answer_str = [ans_template.replace("{answer}", options[i]) for i in range(len(options))]
56 | label = data["label"]
57 | return input_str, answer_str, label
58 |
59 | def get_all_labels(self):
60 | return [data["label"] for data in self.all_data]
--------------------------------------------------------------------------------
/my_datasets/agnews.py:
--------------------------------------------------------------------------------
1 | try:
2 | from . import BaseTask
3 | except:
4 | from basetask import BaseTask
5 | from datasets import load_dataset
6 |
7 |
8 | class AGNews(BaseTask):
9 | def __init__(self, split='train', *args, **kwargs):
10 | super().__init__(*args, **kwargs)
11 | self.task_type = "classification"
12 | # set split name
13 | if split in ['train', 'validation']:
14 | load_split = 'train'
15 | else:
16 | load_split = 'test'
17 | print(f"Loading {load_split} data from AGNews ...")
18 | # class_num
19 | self.class_num = 4
20 | # load dataset
21 | self.dataset = load_dataset('ag_news', split=load_split, keep_in_memory=True)
22 | # get all data
23 | self.all_data = [data for data in self.dataset]
24 | # get all labels
25 | self.all_labels = self.get_all_labels()
26 | # random sample data
27 | if self.max_data_num is not None:
28 | self.random_sample_data(self.max_data_num)
29 | # print a few examples
30 | print(f'Dataset length is {len(self.all_data)}')
31 | print("Example data:")
32 | self.print_data([0])
33 |
34 | def get_dmonstration_template(self):
35 | template = {
36 | 'input': 'News: {text}\nType:',
37 | 'ans': '{answer}',
38 | 'options': ["World", "Sports", "Business", "Technology"],
39 | 'format': ['News:', 'Type:']
40 | }
41 | return template
42 |
43 | def get_task_instruction(self):
44 | task_instruction = "Classify the article into one of the categories: World, Sports, Business, Technology.\n\n"
45 | return task_instruction
46 |
47 | def apply_template(self, data):
48 | """
49 | PS: label should always be an integer and can be used to index the options
50 | """
51 | template = self.get_dmonstration_template()
52 | input_template = template['input']
53 | ans_template = template['ans']
54 | options = template['options']
55 | input_str = input_template.replace("{text}", data["text"])
56 | # answers can have multiple options and is a list
57 | answer_str = [ans_template.replace("{answer}", options[i]) for i in range(len(options))]
58 | label = data["label"]
59 | return input_str, answer_str, label
60 |
61 | def get_all_labels(self):
62 | return [data["label"] for data in self.all_data]
--------------------------------------------------------------------------------
/my_datasets/emo.py:
--------------------------------------------------------------------------------
1 | try:
2 | from . import BaseTask
3 | except:
4 | from basetask import BaseTask
5 | # from basetask import BaseTask
6 | from datasets import load_dataset
7 |
8 |
9 | class EMO(BaseTask):
10 | def __init__(self, split='train', *args, **kwargs):
11 | super().__init__(*args, **kwargs)
12 | self.task_type = "classification"
13 | # set split name
14 | if split in ['train', 'validation']:
15 | load_split = 'train'
16 | else:
17 | load_split = 'test'
18 | print(f"Loading {load_split} data from emo ...")
19 | # class_num
20 | self.class_num = 4
21 | # load dataset
22 | self.dataset = load_dataset('emo', split=load_split, keep_in_memory=True)
23 | # get all data
24 | self.all_data = [data for data in self.dataset]
25 | # get all labels
26 | self.all_labels = self.get_all_labels()
27 | # random sample data
28 | if self.max_data_num is not None:
29 | self.random_sample_data(self.max_data_num)
30 | # print a few examples
31 | print(f'Dataset length is {len(self.all_data)}')
32 | print("Example data:")
33 | self.print_data([0])
34 |
35 | def get_dmonstration_template(self):
36 | template = {
37 | 'input': 'Dialogue: {text}\nEmotion:',
38 | 'ans': '{answer}',
39 | 'options': ['others', 'happy', 'sad', 'angry'],
40 | 'format': ['Dialogue:', 'Emotion:']
41 | }
42 | return template
43 |
44 | def get_task_instruction(self):
45 | task_instruction = "Classify the emotion of the dialogue into one of the categories: Others, Happy, Sad, Angry.\n\n"
46 | return task_instruction
47 |
48 | def apply_template(self, data):
49 | """
50 | PS: label should always be an integer and can be used to index the options
51 | """
52 | template = self.get_dmonstration_template()
53 | input_template = template['input']
54 | ans_template = template['ans']
55 | options = template['options']
56 | input_str = input_template.replace("{text}", data["text"])
57 | # answers can have multiple options and is a list
58 | answer_str = [ans_template.replace("{answer}", options[i]) for i in range(len(options))]
59 | label = data["label"]
60 | return input_str, answer_str, label
61 |
62 | def get_all_labels(self):
63 | return [data["label"] for data in self.all_data]
--------------------------------------------------------------------------------
/my_datasets/sst2.py:
--------------------------------------------------------------------------------
1 | try:
2 | from . import BaseTask
3 | except:
4 | from basetask import BaseTask
5 | # from basetask import BaseTask
6 | from datasets import load_dataset
7 |
8 |
9 | class SST2(BaseTask):
10 | def __init__(self, split='train', *args, **kwargs):
11 | super().__init__(*args, **kwargs)
12 | self.task_type = "classification"
13 | # set split name
14 | if split in ['train', 'validation']:
15 | load_split = 'train'
16 | else:
17 | load_split = 'validation'
18 | print(f"Loading {load_split} data from sst2 ...")
19 | # class_num
20 | self.class_num = 2
21 | # load dataset
22 | self.dataset = load_dataset('glue', 'sst2', split=load_split, keep_in_memory=True)
23 | # get all data
24 | self.all_data = [data for data in self.dataset]
25 | # get all labels
26 | self.all_labels = self.get_all_labels()
27 | # random sample data
28 | if self.max_data_num is not None:
29 | self.random_sample_data(self.max_data_num)
30 | # print a few examples
31 | print(f'Dataset length is {len(self.all_data)}')
32 | print("Example data:")
33 | self.print_data([0])
34 |
35 | def get_dmonstration_template(self):
36 | template = {
37 | 'input': 'Review: {text}\nSentiment:',
38 | 'ans': '{answer}',
39 | 'options': ['negative', 'positive'],
40 | 'format': ['Review:', 'Sentiment:']
41 | }
42 | return template
43 |
44 | def get_task_instruction(self):
45 | task_instruction = "Classify the sentiment of the sentence into one of the categories: positive or negative.\n\n"
46 | return task_instruction
47 |
48 | def apply_template(self, data):
49 | """
50 | PS: label should always be an integer and can be used to index the options
51 | """
52 | template = self.get_dmonstration_template()
53 | input_template = template['input']
54 | ans_template = template['ans']
55 | options = template['options']
56 | input_str = input_template.replace("{text}", data["sentence"])
57 | # answers can have multiple options and is a list
58 | answer_str = [ans_template.replace("{answer}", options[i]) for i in range(len(options))]
59 | label = data["label"]
60 | return input_str, answer_str, label
61 |
62 | def get_all_labels(self):
63 | return [data["label"] for data in self.all_data]
64 |
--------------------------------------------------------------------------------
/my_datasets/subj.py:
--------------------------------------------------------------------------------
1 | try:
2 | from . import BaseTask
3 | except:
4 | from basetask import BaseTask
5 | # from basetask import BaseTask
6 | from datasets import load_dataset
7 |
8 |
9 | class Subj(BaseTask):
10 | def __init__(self, split='train', *args, **kwargs):
11 | super().__init__(*args, **kwargs)
12 | self.task_type = "classification"
13 | assert split in ['train', 'validation', 'test']
14 | # set split name
15 | if split in ['train', 'validation']:
16 | load_split = 'train'
17 | else:
18 | load_split = 'test'
19 | print(f"Loading {load_split} data from subj ...")
20 | # class_num
21 | self.class_num = 2
22 | # load dataset
23 | self.dataset = load_dataset('SetFit/subj', split=load_split, keep_in_memory=True)
24 | # get all data
25 | self.all_data = [data for data in self.dataset]
26 | # get all labels
27 | self.all_labels = self.get_all_labels()
28 | # random sample data
29 | if self.max_data_num is not None:
30 | self.random_sample_data(self.max_data_num)
31 | # print a few examples
32 | print(f'Dataset length is {len(self.all_data)}')
33 | print("Example data:")
34 | self.print_data([0])
35 |
36 | def get_dmonstration_template(self):
37 | template = {
38 | 'input': 'Sentence: {text}\nLabel:',
39 | 'ans': '{answer}',
40 | 'options': ['objective', 'subjective'],
41 | 'format': ['Sentence:', 'Label:']
42 | }
43 | return template
44 |
45 | def get_task_instruction(self):
46 | task_instruction = "Classify the sentiment of the sentence into one of the categories: positive or negative.\n\n"
47 | return task_instruction
48 |
49 | def apply_template(self, data):
50 | """
51 | PS: label should always be an integer and can be used to index the options
52 | """
53 | template = self.get_dmonstration_template()
54 | input_template = template['input']
55 | ans_template = template['ans']
56 | options = template['options']
57 | input_str = input_template.replace("{text}", data["text"])
58 | # answers can have multiple options and is a list
59 | answer_str = [ans_template.replace("{answer}", options[i]) for i in range(len(options))]
60 | label = data["label"]
61 | return input_str, answer_str, label
62 |
63 | def get_all_labels(self):
64 | return [data["label"] for data in self.all_data]
--------------------------------------------------------------------------------
/my_datasets/trec.py:
--------------------------------------------------------------------------------
1 | try:
2 | from . import BaseTask
3 | except:
4 | from basetask import BaseTask
5 | from datasets import load_dataset
6 |
7 |
8 | class TREC(BaseTask):
9 | def __init__(self, split='train', *args, **kwargs):
10 | super().__init__(*args, **kwargs)
11 | self.task_type = "classification"
12 | # set split name
13 | if split in ['train', 'validation']:
14 | load_split = 'train'
15 | else:
16 | load_split = 'test'
17 | print(f"Loading {load_split} data from TREC ...")
18 | # class_num
19 | self.class_num = 6
20 | # load dataset
21 | self.dataset = load_dataset('trec', split=load_split, keep_in_memory=True)
22 | # get all data
23 | self.all_data = [data for data in self.dataset]
24 | # get all labels
25 | self.all_labels = self.get_all_labels()
26 | # random sample data
27 | if self.max_data_num is not None:
28 | self.random_sample_data(self.max_data_num)
29 | # print a few examples
30 | print(f'Dataset length is {len(self.all_data)}')
31 | print("Example data:")
32 | self.print_data([0])
33 |
34 | def get_dmonstration_template(self):
35 | template = {
36 | 'input': 'Question: {text}\nAnswer Type:',
37 | 'ans': '{answer}',
38 | 'options': ["Abbreviation", "Entity", "Description", "Person", "Location", "Number"],
39 | 'format': ['Question:', 'Answer Type:']
40 | }
41 | return template
42 |
43 | def get_task_instruction(self):
44 | task_instruction = "Classify the questions based on whether their answer type is a Number, Location, Person, Description, Entity, or Abbreviation.\n\n"
45 | return task_instruction
46 |
47 | def apply_template(self, data):
48 | """
49 | PS: label should always be an integer and can be used to index the options
50 | """
51 | template = self.get_dmonstration_template()
52 | input_template = template['input']
53 | ans_template = template['ans']
54 | options = template['options']
55 | input_str = input_template.replace("{text}", data["text"])
56 | # answers can have multiple options and is a list
57 | answer_str = [ans_template.replace("{answer}", options[i]) for i in range(len(options))]
58 | label = data["coarse_label"]
59 | return input_str, answer_str, label
60 |
61 | def get_all_labels(self):
62 | return [data["coarse_label"] for data in self.all_data]
--------------------------------------------------------------------------------
/my_datasets/hate_speech18.py:
--------------------------------------------------------------------------------
1 | try:
2 | from . import BaseTask
3 | except:
4 | from basetask import BaseTask
5 | # from basetask import BaseTask
6 | from datasets import load_dataset
7 |
8 |
9 | class HateSpeech18(BaseTask):
10 | def __init__(self, split='train', *args, **kwargs):
11 | super().__init__(*args, **kwargs)
12 | self.task_type = "classification"
13 | print(f"Loading train data from hates_peach18 ...")
14 | # class_num
15 | self.class_num = 2
16 | # load dataset
17 | self.dataset = load_dataset('hate_speech18', split='train', keep_in_memory=True)
18 | # get all data
19 | self.all_data = [data for data in self.dataset if data['label'] in [0, 1]] # clean data to keep only hate and noHate
20 | if split in ['train', 'validation']:
21 | self.all_data = self.all_data[:len(self.all_data)//2]
22 | else:
23 | self.all_data = self.all_data[len(self.all_data)//2:]
24 | # get all labels
25 | self.all_labels = self.get_all_labels()
26 | # random sample data
27 | if self.max_data_num is not None:
28 | self.random_sample_data(self.max_data_num)
29 | # print a few examples
30 | print(f'Dataset length is {len(self.all_data)}')
31 | print("Example data:")
32 | self.print_data([0])
33 |
34 | def get_dmonstration_template(self):
35 | template = {
36 | 'input': 'Text: {text}\nLabel:',
37 | 'ans': '{answer}',
38 | 'options': ['neutral', 'hate'],
39 | 'format': ['Text:', 'Label:']
40 | }
41 | return template
42 |
43 | def get_task_instruction(self):
44 | task_instruction = "Classify the text into one of the categories: noHate or hate.\n\n"
45 | return task_instruction
46 |
47 | def apply_template(self, data):
48 | """
49 | PS: label should always be an integer and can be used to index the options
50 | """
51 | template = self.get_dmonstration_template()
52 | input_template = template['input']
53 | ans_template = template['ans']
54 | options = template['options']
55 | input_str = input_template.replace("{text}", data["text"])
56 | # answers can have multiple options and is a list
57 | answer_str = [ans_template.replace("{answer}", options[i]) for i in range(len(options))]
58 | label = data["label"]
59 | return input_str, answer_str, label
60 |
61 | def get_all_labels(self):
62 | return [data["label"] for data in self.all_data]
--------------------------------------------------------------------------------
/my_datasets/dbpedia.py:
--------------------------------------------------------------------------------
1 | try:
2 | from . import BaseTask
3 | except:
4 | from basetask import BaseTask
5 | # from basetask import BaseTask
6 | from datasets import load_dataset
7 |
8 |
9 | class DBPedia(BaseTask):
10 | def __init__(self, split='train', *args, **kwargs):
11 | super().__init__(*args, **kwargs)
12 | self.task_type = "classification"
13 | # set split name
14 | if split in ['train', 'validation']:
15 | load_split = 'train'
16 | else:
17 | load_split = 'test'
18 |
19 | print(f"Loading {load_split} data from DBPedia ...")
20 | # class_num
21 | self.class_num = 14
22 | # load dataset
23 | self.dataset = load_dataset('fancyzhx/dbpedia_14', split=load_split, keep_in_memory=True)
24 | # get all data
25 | self.all_data = [data for data in self.dataset]
26 | # get all labels
27 | self.all_labels = self.get_all_labels()
28 | # random sample data
29 | if self.max_data_num is not None:
30 | self.random_sample_data(self.max_data_num)
31 | # print a few examples
32 | print(f'Dataset length is {len(self.all_data)}')
33 | print("Example data:")
34 | self.print_data([0])
35 |
36 | def get_dmonstration_template(self):
37 | template = {
38 | 'input': 'Input: {text}\nLabel:',
39 | 'ans': '{answer}',
40 | 'options': ['company', 'school', 'artist', 'athlete',
41 | 'politics', 'transportation', 'building',
42 | 'nature', 'village', 'animal', 'plant',
43 | 'album', 'film', 'book'],
44 | 'format': ['Input:', 'Label:']
45 | }
46 | return template
47 |
48 | def get_task_instruction(self):
49 | task_instruction = "Classify the sentiment of the sentence into one of the categories: positive or negative.\n\n"
50 | return task_instruction
51 |
52 | def apply_template(self, data):
53 | """
54 | PS: label should always be an integer and can be used to index the options
55 | """
56 | template = self.get_dmonstration_template()
57 | input_template = template['input']
58 | ans_template = template['ans']
59 | options = template['options']
60 | input_str = input_template.replace("{text}", data["content"])
61 | # answers can have multiple options and is a list
62 | answer_str = [ans_template.replace("{answer}", options[i]) for i in range(len(options))]
63 | label = data["label"]
64 | return input_str, answer_str, label
65 |
66 | def get_all_labels(self):
67 | return [data["label"] for data in self.all_data]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # I2CL: Implicit In-context Learning
2 | Please refer to the [paper](https://arxiv.org/pdf/2405.14660) for more details.
3 |
4 |
5 |
6 |  |
7 |  |
8 |
9 |
10 |
11 | ## What's New?
12 | ### 🌟 Introducing I2CL: a new paradigm to leverage demonstration examples:
13 | Implicit In-context Learning (I2CL) absorbs a tiny set of demonstration examples in the activation space, diverging from standard In-context Learning (ICL), which prefixes demonstration examples in the token space. As a result, I2CL bypasses the limitation of the context window size.
14 |
15 | ### 🌟 Efficiency
16 | I2CL is extremely efficient in terms of both computation and memory usage. It achieves few-shot (i.e., ICL) performance with approximately zero-shot cost at inference.
17 |
18 | ### 🌟 Robustness
19 | I2CL is robust to the selection and order of demonstration examples. It yields satisfactory performance even with deficient demonstration examples, which can severely degrade the performance of ICL.
20 |
21 | ### 🌟 Generation of "Task-ID"
22 | I2CL introduces a set of scalar values that act as task-IDs. These task-IDs effectively indicate the task similarity and can be leveraged to perform transfer learning across tasks.
23 |
24 | ### 🌟 Better Understanding of ICL
25 | I2CL substantiates a two-stage workflow: a context vector is first generated by condensing the demonstration examples independently; then, a linear combination of the context vector and the query's activations is injected back into the residual streams. The effectiveness of I2CL suggests that ICL itself may follow a similar two-stage workflow.
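
For intuition only, here is a minimal sketch of these two stages (not the actual `wrapper.py` implementation; the helper names below are hypothetical, and the defaults mirror `post_fuse_method='mean'`, `inject_method='linear'`, and `init_value=[0.1, 1.0]` from `configs/config_i2cl.py`):

```python
import torch

def build_context_vector(demo_activations):
    # demo_activations: one [num_layers, hidden] tensor per demonstration example,
    # e.g. last-token activations collected independently for each example.
    # Fuse the per-example vectors with a simple mean (post_fuse_method='mean').
    return torch.stack(demo_activations, dim=0).mean(dim=0)  # [num_layers, hidden]

def linear_inject(residual, layer_context_vector, coeff_cv=0.1, coeff_res=1.0):
    # residual: [batch, seq, hidden] residual-stream activation of the query at one layer.
    # layer_context_vector: [hidden] slice of the context vector for that layer.
    # Inject a (learnable) linear combination of the context vector and the query
    # activation back into the residual stream (inject_method='linear').
    return coeff_res * residual + coeff_cv * layer_context_vector
```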
26 |
27 | ## Installation
28 | 1. **Clone the Repository**:
29 |
30 | ```bash
31 | git clone https://github.com/LzVv123456/I2CL
32 | cd I2CL
33 | ```
34 |
35 | 2. **Create a Conda Environment**:
36 |
37 | Create a new conda environment to avoid conflicts with existing packages.
38 |
39 | ```bash
40 | conda create --name i2cl_env python=3.8
41 | conda activate i2cl_env
42 | ```
43 |
44 | 3. **Install Dependencies**:
45 |
46 | Use `pip` to install the required libraries listed in the `requirements.txt` file.
47 |
48 | ```bash
49 | pip install -r requirements.txt
50 | ```
51 |
52 |
53 | ## Usage
54 |
55 | To use the code, follow these steps:
56 |
57 | 1. **Navigate to the I2CL Folder**:
58 |
59 | ```bash
60 | cd I2CL
61 | ```
62 |
63 | 2. **Run I2CL**:
64 |
65 | To run I2CL, execute the following command:
66 |
67 | ```bash
68 | python run_i2cl.py
69 | ```
70 |
71 | 3. **Run Comparable Methods**:
72 |
73 | To run other comparable methods, use the following commands:
74 |
75 | ```bash
76 | python run_soft_prompt.py
77 | python run_task_vector.py
78 | python run_label_anchor.py
79 | ```
80 |
81 | 4. **Apply I2CL to Unseen Demonstrations or Perform Transfer Learning**:
82 |
83 | First, run `run_i2cl.py`. Then, in the configuration files for `run_i2cl_infer.py` and `run_i2cl_transfer_learning.py`, set the target path to the output folder produced by `run_i2cl.py` (see the snippet after the commands below).
84 |
85 | Then, execute the following commands:
86 |
87 | ```bash
88 | python run_i2cl_infer.py
89 | python run_i2cl_transfer_learning.py
90 | ```
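
For reference, both `configs/config_i2cl_infer.py` and `configs/config_i2cl_transfer_learning.py` point at the I2CL output folder via `target_path`, for example:

```python
# configs/config_i2cl_infer.py / configs/config_i2cl_transfer_learning.py
config['target_path'] = 'exps/i2cl'  # output folder produced by run_i2cl.py
```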
91 |
92 | 5. **Configuration and Ablation Studies**
93 |
94 | For ablation studies and other configurations, please refer to `configs/config_i2cl.py` and the corresponding code files for more details.
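
For example, some of the main I2CL knobs in `configs/config_i2cl.py` are:

```python
# configs/config_i2cl.py (excerpt)
config['layer'] = 'all'              # all, early, mid, late
config['module'] = ['mlp', 'attn']   # 'mlp', 'attn', 'hidden'
config['inject_method'] = 'linear'   # 'linear', 'constraint', 'add'
config['post_fuse_method'] = 'mean'  # 'mean', 'pca'
```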
95 |
96 | Please don't hesitate to drop an email at zhuowei.li@cs.rutgers.edu if you have any questions.
97 |
98 | ## License
99 |
100 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
101 |
102 | ## Acknowledgments
103 |
104 | I would like to acknowledge the repositories that inspired and contributed to this work:
105 | - [label-words-are-anchors](https://github.com/lancopku/label-words-are-anchors)
106 | - [ICV](https://github.com/shengliu66/ICV)
107 | - [icl_task_vectors](https://github.com/roeehendel/icl_task_vectors)
108 |
109 | ## Citation
110 | If you find this work useful for your research, feel free to star ⭐ the repo or cite the following paper:
111 | ```
112 | @misc{li2024implicit,
113 | title={Implicit In-context Learning},
114 | author={Zhuowei Li and Zihao Xu and Ligong Han and Yunhe Gao and Song Wen and Di Liu and Hao Wang and Dimitris N. Metaxas},
115 | year={2024},
116 | eprint={2405.14660},
117 | archivePrefix={arXiv},
118 | primaryClass={cs.LG}
119 | }
120 | ```
121 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.26.1
2 | aiohttp==3.9.3
3 | aiosignal==1.3.1
4 | asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1698341106958/work
5 | async-timeout==4.0.3
6 | attrs==23.2.0
7 | backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work
8 | bitsandbytes==0.42.0
9 | Bottleneck @ file:///opt/conda/conda-bld/bottleneck_1657175564434/work
10 | Brotli @ file:///tmp/abs_ecyw11_7ze/croots/recipe/brotli-split_1659616059936/work
11 | calflops==0.2.9
12 | certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1707022139797/work/certifi
13 | cffi @ file:///croot/cffi_1700254295673/work
14 | charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work
15 | comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1710320294760/work
16 | contourpy @ file:///croot/contourpy_1700583582875/work
17 | cryptography @ file:///croot/cryptography_1702070282333/work
18 | cycler @ file:///tmp/build/80754af9/cycler_1637851556182/work
19 | datasets==2.16.1
20 | debugpy @ file:///croot/debugpy_1690905042057/work
21 | decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work
22 | dill==0.3.7
23 | entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1643888246732/work
24 | executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1698579936712/work
25 | filelock @ file:///croot/filelock_1700591183607/work
26 | fonttools==4.25.0
27 | frozenlist==1.4.1
28 | fsspec==2023.10.0
29 | gmpy2 @ file:///tmp/build/80754af9/gmpy2_1645438755360/work
30 | huggingface-hub==0.20.3
31 | idna @ file:///croot/idna_1666125576474/work
32 | importlib-resources @ file:///croot/importlib_resources-suite_1704281845041/work
33 | ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1708996548741/work
34 | ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1680185408135/work
35 | jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1696326070614/work
36 | Jinja2 @ file:///croot/jinja2_1706733616596/work
37 | joblib==1.3.2
38 | jupyter-client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1654730843242/work
39 | jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1710257447442/work
40 | kiwisolver @ file:///croot/kiwisolver_1672387140495/work
41 | MarkupSafe @ file:///croot/markupsafe_1704205993651/work
42 | matplotlib==3.8.4
43 | matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1660814786464/work
44 | mkl-fft @ file:///croot/mkl_fft_1695058164594/work
45 | mkl-random @ file:///croot/mkl_random_1695059800811/work
46 | mkl-service==2.4.0
47 | mpmath @ file:///croot/mpmath_1690848262763/work
48 | multidict==6.0.5
49 | multiprocess==0.70.15
50 | munkres==1.1.4
51 | nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1705850609492/work
52 | networkx @ file:///croot/networkx_1690561992265/work
53 | numexpr @ file:///croot/numexpr_1696515281613/work
54 | numpy @ file:///croot/numpy_and_numpy_base_1704311704800/work/dist/numpy-1.26.3-cp39-cp39-linux_x86_64.whl#sha256=93e7b9e5e2090dd03810e7c1b02cb077d3ef49fc713f0af531e0667375d9decb
55 | nvidia-ml-py3==7.352.0
56 | packaging @ file:///croot/packaging_1693575174725/work
57 | pandas @ file:///croot/pandas_1702317985682/work/dist/pandas-2.1.4-cp39-cp39-linux_x86_64.whl#sha256=324122d54922a1f7a8103c955ecfc280ac90ea4d0a9ae7e365d5a477d5f0713e
58 | parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work
59 | peft==0.8.2
60 | pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1706113125309/work
61 | pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work
62 | Pillow @ file:///croot/pillow_1696580024257/work
63 | platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1706713388748/work
64 | prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1702399386289/work
65 | psutil @ file:///opt/conda/conda-bld/psutil_1656431268089/work
66 | ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
67 | pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work
68 | pyarrow==15.0.0
69 | pyarrow-hotfix==0.6
70 | pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
71 | Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1700607939962/work
72 | pyOpenSSL @ file:///croot/pyopenssl_1690223430423/work
73 | pyparsing @ file:///opt/conda/conda-bld/pyparsing_1661452539315/work
74 | PySocks @ file:///tmp/build/80754af9/pysocks_1605305812635/work
75 | python-dateutil @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work
76 | pytz @ file:///croot/pytz_1695131579487/work
77 | PyYAML @ file:///croot/pyyaml_1698096049011/work
78 | pyzmq @ file:///croot/pyzmq_1705605076900/work
79 | regex==2023.12.25
80 | requests @ file:///croot/requests_1690400202158/work
81 | safetensors==0.4.2
82 | scikit-learn==1.4.0
83 | scipy==1.12.0
84 | seaborn==0.13.2
85 | six @ file:///tmp/build/80754af9/six_1644875935023/work
86 | stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work
87 | sympy @ file:///croot/sympy_1701397643339/work
88 | threadpoolctl==3.2.0
89 | tokenizers==0.15.1
90 | torch==2.2.0
91 | torchaudio==2.2.0
92 | torchvision==0.17.0
93 | tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1648827245914/work
94 | tqdm==4.66.1
95 | traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1710254411456/work
96 | transformers==4.37.2
97 | triton==2.2.0
98 | typing_extensions @ file:///croot/typing_extensions_1705599297034/work
99 | tzdata @ file:///croot/python-tzdata_1690578112552/work
100 | urllib3 @ file:///croot/urllib3_1698257533958/work
101 | wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1704731205417/work
102 | xxhash==3.4.1
103 | yarl==1.9.4
104 | zipp @ file:///croot/zipp_1704206909481/work
105 |
--------------------------------------------------------------------------------
/evaluator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import utils
5 | import global_vars as gv
6 | import my_datasets as md
7 |
8 |
9 | class Evaluator(nn.Module):
10 |
11 | def __init__(self, dataset, batch_size):
12 | super().__init__()
13 | self.dataset = dataset
14 | self.batch_size = batch_size
15 |
16 | def evaluate(self, model_wrapper, tokenizer, demonstration='', use_cache=False):
17 |
18 | return self._evaluate_text_classification_batch(model_wrapper, tokenizer,
19 | demonstration, use_cache=use_cache)
20 |
21 | def _evaluate_text_classification_batch(self, model_wrapper, tokenizer,
22 | demonstration, use_cache=False):
23 |
24 | model = model_wrapper.model
25 | # prepare label dict
26 | label_map = {}
27 | ans_txt_list = self.dataset.get_dmonstration_template()['options']
28 | for label, ans_txt in enumerate(ans_txt_list):
29 | if 'gpt' in tokenizer.__class__.__name__.lower():
30 | ans_txt = ' ' + ans_txt # add space to the beginning of answer
31 | ans_tok = tokenizer.encode(ans_txt, add_special_tokens=False)[0] # use the first token if more than one token
32 | print(f"ans_txt: {ans_txt}, ans_tok: {ans_tok}")
33 | label_map[ans_tok] = label # index is the label
34 | print(f"label_map: {label_map}")
35 |
36 | # prepare all data
37 | all_pred_labels = []
38 | all_inputs, all_labels = [], []
39 | for data in self.dataset.all_data:
40 | ques_str, _, label = self.dataset.apply_template(data)
41 | if use_cache or len(demonstration) == 0:
42 | context = ques_str
43 | else:
44 | context = demonstration + ques_str
45 | all_inputs.append(context)
46 | all_labels.append(label)
47 |
48 | # cache the demonstration
49 | if len(demonstration) > 0 and use_cache:
50 | demon_token = tokenizer(demonstration, return_tensors="pt", padding=True).to(model.device)
51 | with torch.no_grad():
52 | demon_outputs = model(**demon_token, use_cache=True)
53 | demon_past_key_values = demon_outputs.past_key_values
54 | demon_attn_mask = demon_token['attention_mask']
55 | demon_past_key_values = tuple(tuple(t.repeat(self.batch_size, 1, 1, 1) for
56 | t in tup) for tup in demon_past_key_values)
57 | demon_attn_mask = demon_attn_mask.repeat(self.batch_size, 1)
58 | if len(all_inputs) % self.batch_size != 0: # last batch
59 | sp_demon_past_key_values = tuple(tuple(t.repeat(len(all_inputs) % self.batch_size, 1, 1, 1)
60 | for t in tup) for tup in demon_outputs.past_key_values)
61 | sp_demon_attn_mask = demon_attn_mask[-(len(all_inputs) % self.batch_size):]
62 | use_cache = True
63 | else:
64 | demon_past_key_values = None
65 | sp_demon_past_key_values = None
66 | sp_demon_attn_mask = None
67 | use_cache = False
68 |
69 | # loop over all data
70 | with torch.no_grad():
71 | for i in range(0, len(all_inputs), self.batch_size):
72 | cur_inputs = all_inputs[i:i+self.batch_size]
73 | # accommodate for the last batch
74 | if len(cur_inputs) != self.batch_size:
75 | demon_past_key_values = sp_demon_past_key_values
76 | demon_attn_mask = sp_demon_attn_mask
77 | input_tok = tokenizer(cur_inputs, return_tensors="pt", padding=True)
78 | input_ids = input_tok['input_ids'].to(model.device)
79 | attn_mask = input_tok['attention_mask'].to(model.device)
80 | # get index for prediction logits, need to be applied before concatenating demon_attn_mask with attn_mask
81 | pred_loc = utils.last_one_indices(attn_mask).to(model.device)
82 | # set global variables
83 | gv.ATTN_MASK_START = torch.zeros_like(pred_loc)
84 | gv.ATTN_MASK_END = pred_loc
85 | if use_cache:
86 | attn_mask = torch.cat([demon_attn_mask, attn_mask], dim=1)
87 | output = model(input_ids=input_ids, attention_mask=attn_mask,
88 | past_key_values=demon_past_key_values, use_cache=use_cache)
89 | else:
90 | output = model(input_ids=input_ids, attention_mask=attn_mask)
91 | logits = output.logits
92 |
93 | # get prediction logits
94 | pred_logits = logits[torch.arange(logits.size(0)), pred_loc]
95 | # get prediction labels
96 | interest_index = list(label_map.keys())
97 | pred_logits = pred_logits[:, interest_index]
98 | probs = F.softmax(pred_logits, dim=-1)
99 | pred_labels = probs.argmax(dim=-1)
100 | # save results
101 | all_pred_labels.extend(pred_labels.cpu().numpy().tolist())
102 |
103 | assert len(all_pred_labels) == len(all_labels)
104 | # all_pred_labels and all_labels are lists of label indices; compute accuracy and macro F1 from them
105 | # initialize TP, FP, FN
106 | acc = []
107 | num_classes = self.dataset.class_num
108 | TP = [0] * num_classes
109 | FP = [0] * num_classes
110 | FN = [0] * num_classes
111 | for i, true_label in enumerate(all_labels):
112 | pred_label = all_pred_labels[i]
113 | pred = pred_label == true_label
114 | acc.append(pred)
115 | # Update TP, FP, FN
116 | if pred:
117 | TP[true_label] += 1
118 | else:
119 | FP[pred_label] += 1
120 | FN[true_label] += 1
121 | # Calculate precision, recall, F1 for each class and macro F1
122 | precision = [0] * num_classes
123 | recall = [0] * num_classes
124 | f1 = [0] * num_classes
125 | for i in range(num_classes):
126 | precision[i] = TP[i] / (TP[i] + FP[i]) if (TP[i] + FP[i]) > 0 else 0
127 | recall[i] = TP[i] / (TP[i] + FN[i]) if (TP[i] + FN[i]) > 0 else 0
128 | f1[i] = 2 * (precision[i] * recall[i]) / (precision[i] + recall[i]) if (precision[i] + recall[i]) > 0 else 0
129 | macro_f1 = sum(f1) / num_classes
130 | acc = sum(acc) / len(acc)
131 | return {'acc': acc, 'macro_f1': macro_f1}
--------------------------------------------------------------------------------
/run_soft_prompt.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import json
3 | import copy
4 | import time
5 | import random
6 | import argparse
7 | import itertools
8 | import torch
9 | from multiprocessing import Process, Queue
10 |
11 | import utils
12 | import evaluator as ev
13 | import my_datasets as md
14 |
15 |
16 | def main(args):
17 | # set global seed
18 | utils.set_seed(args.config['seed'])
19 | # set device
20 | args.device = utils.set_device(args.gpu)
21 | # set comparison metric
22 | args.metric = args.config['metric']
23 | # get save dir
24 | utils.init_exp_path(args, args.config['exp_name'])
25 |
26 | # load tokenizer and model
27 | model, tokenizer, model_config = \
28 | utils.load_model_tokenizer(args.model_name, args.device)
29 |
30 | # get model_wrapper
31 | model_wrapper = utils.get_model_wrapper(args.model_name, model,
32 | tokenizer, model_config,
33 | args.device)
34 | # load datasets
35 | train_dataset = md.get_dataset(args.dataset_name, split='train', max_data_num=None)
36 | test_dataset = md.get_dataset(args.dataset_name, split='test',
37 | max_data_num=args.config['test_data_num'],
38 | sample_mode=args.config['sample_method'])
39 |
40 | # get max demonstration token length for each dataset
41 | args.test_max_token = test_dataset.get_max_demonstration_token_length(tokenizer)
42 |
43 | # get shot_num
44 | if args.dataset_name == 'dbpedia': # always use 1-shot for dbpedia
45 | args.config['shot_per_class'] = 1
46 | args.config['bs'] = 1
47 | args.shot_num = utils.get_shot_num(train_dataset, args.config['shot_per_class'])
48 | # build evaluator
49 | test_evaluator = ev.Evaluator(test_dataset, batch_size=args.config['bs'])
50 |
51 | # init result_dict
52 | result_dict = {'demon': {},
53 | 'test_result': {'zero_shot': [], 'few_shot': [], 'ours': []},
54 | 'time': {'calibrate': [], 'evaluate': []},
55 | }
56 |
57 | for run_id in range(args.config['run_num']):
58 | run_name = f'run_{run_id}'
59 | args.run_name = run_name
60 | print(f'Run time {run_name}')
61 | run_seed = args.config['seed'] + run_id
62 | utils.set_seed(run_seed)
63 |
64 | # zero-shot baseline
65 | if run_id == 0 and args.config['run_baseline']:
66 | test_zeroshot_result = test_evaluator.evaluate(model_wrapper, tokenizer, demonstration='',
67 | use_cache=args.config['use_cache'])
68 | result_dict['test_result']['zero_shot'].append(test_zeroshot_result)
69 | print(f'Test zero-shot result: {test_zeroshot_result}\n')
70 |
71 | # sample demonstration
72 | demon, _, demon_data_index = \
73 | train_dataset.gen_few_shot_demonstration(tokenizer=tokenizer, shot_num=args.shot_num,
74 | max_demonstration_tok_len=args.test_max_token,
75 | add_extra_query=args.config['add_extra_query'],
76 | example_separator=args.config['example_separator'],
77 | return_data_index=True, seed=random.randint(0, 1e6))
78 |
79 | if args.config['add_extra_query']:
80 | first_format_anchor = train_dataset.get_dmonstration_template()['format'][0]
81 | # split off the extra query: keep everything before the last first_format_anchor as the baseline demonstration, and the rest as the query
82 | if first_format_anchor in demon:
83 | baseline_demon = demon[:demon.rfind(first_format_anchor)]
84 | query_demon = demon[demon.rfind(first_format_anchor):]
85 | else:
86 | baseline_demon = demon
87 | query_demon = None
88 | print(f'Demonstration:\n{demon}\n')
89 | print(f'Baseline demonstration:\n{baseline_demon}\n')
90 | print(f'Query demonstration:\n{query_demon}\n')
91 |
92 | # few-shot baseline
93 | if args.config['run_baseline']:
94 | test_fewshot_result = test_evaluator.evaluate(model_wrapper, tokenizer,
95 | demonstration=baseline_demon,
96 | use_cache=args.config['use_cache'])
97 | result_dict['test_result']['few_shot'].append(test_fewshot_result)
98 | print(f'Test few-shot result: {test_fewshot_result}\n')
99 |
100 | # generate demon_list
101 | demon_list = [demon]
102 | # save demon_list
103 | result_dict['demon'][run_name] = demon_list
104 |
105 | # prepare peft_train_dataset
106 | cali_dataset = copy.deepcopy(train_dataset)
107 | cali_dataset.all_data = [train_dataset.all_data[i] for i in demon_data_index]
108 |
109 | # train softprompt
110 | s_t = time.time()
111 | model_wrapper.softprompt(args.config, cali_dataset, save_dir=args.save_dir, run_name=run_name)
112 | e_t = time.time()
113 | print(f'Calibration time: {e_t - s_t}')
114 | result_dict['time']['calibrate'].append(e_t - s_t)
115 |
116 | # evaluate softprompt
117 | s_t = time.time()
118 | test_ours_result = test_evaluator.evaluate(model_wrapper, tokenizer, demonstration='',
119 | use_cache=args.config['use_cache'])
120 | print(f'Test Soft Prompt result: {test_ours_result}\n')
121 | result_dict['test_result']['ours'].append(test_ours_result)
122 | e_t = time.time()
123 | print(f'Evaluate time: {e_t - s_t}')
124 | result_dict['time']['evaluate'].append(e_t - s_t)
125 |
126 | # save result_dict after each run
127 | with open(args.save_dir + '/result_dict.json', 'w') as f:
128 | json.dump(result_dict, f, indent=4)
129 |
130 | # reset model_wrapper to unadapted model
131 | del model_wrapper, model, tokenizer, model_config
132 | model, tokenizer, model_config = \
133 | utils.load_model_tokenizer(args.model_name, args.device,
134 | load_in_8bit=args.config['load_in_8bit'],
135 | output_hidden_states=False)
136 | # get model_wrapper
137 | model_wrapper = utils.get_model_wrapper(args.model_name, model,
138 | tokenizer, model_config,
139 | args.device)
140 |
141 | # delete all variables
142 | del model_wrapper, model, tokenizer, train_dataset, test_dataset, cali_dataset
143 | del test_evaluator, result_dict, demon_list
144 |
145 |
146 | # get args
147 | def get_args():
148 | parser = argparse.ArgumentParser()
149 | parser.add_argument('--config_path', type=str, default='configs/config_soft_prompt.py', help='path to config file')
150 | return parser.parse_args()
151 |
152 |
153 | if __name__ == "__main__":
154 | # get args
155 | args = get_args()
156 | # load config
157 | config = utils.load_config(args.config_path)
158 | # Generate all combinations of models and datasets
159 | combinations = list(itertools.product(config['models'], config['datasets']))
160 | # Queue to hold tasks
161 | task_queue = Queue()
162 | for combine in combinations:
163 | task_queue.put(combine)
164 |
165 | def run_task(gpu_id, config):
166 | while not task_queue.empty():
167 | model_name, dataset_name = task_queue.get()
168 | print(f"Running {model_name} on {dataset_name} with GPU {gpu_id}")
169 | input_args = argparse.Namespace()
170 | cur_config = copy.deepcopy(config)
171 | input_args.model_name = model_name
172 | input_args.dataset_name = dataset_name
173 | input_args.gpu = gpu_id
174 | input_args.config = cur_config
175 | try:
176 | main(input_args)
177 | finally:
178 | # Clean up CUDA memory after each task
179 | gc.collect()
180 | torch.cuda.empty_cache()
181 | print(f"CUDA memory cleared for GPU {gpu_id}")
182 | time.sleep(5)
183 |
184 | # Create a process for each GPU
185 | processes = [Process(target=run_task, args=(gpu_id, config)) for gpu_id in config['gpus']]
186 | # Start all processes
187 | for p in processes:
188 | p.start()
189 | # Wait for all processes to finish
190 | for p in processes:
191 | p.join()
192 | print("All tasks completed.")
--------------------------------------------------------------------------------
/my_datasets/basetask.py:
--------------------------------------------------------------------------------
1 | import random
2 | import string
3 | from itertools import zip_longest
4 | from torch.utils.data import Dataset
5 |
6 |
7 | class BaseTask(Dataset):
8 | def __init__(self, task_name, sample_mode='random', max_data_num=None, use_instruction=False, seed=0):
9 | super().__init__()
10 | self.task_name = task_name
11 | self.sample_mode = sample_mode
12 | self.max_data_num = max_data_num # maximum number of loaded data
13 | self.use_instruction = use_instruction # whether add task instruction as prefix
14 | self.task_type = None
15 | self.dataset = None # dataset
16 | self.all_data = None # all data
17 | self.all_labels = None # all labels
18 | self.seed = seed
19 |
20 | def random_sample_data(self, max_data_num, seed=0):
21 | """
22 | This function is used to randomly sample data from the dataset.
23 | """
24 | # set random seed
25 | random.seed(self.seed)
26 | assert self.all_data is not None, "Please load data first!"
27 | if max_data_num < len(self.all_data):
28 | if (self.all_labels is None) or (self.sample_mode == 'random'):
29 | # random sample data
30 | self.all_data = random.sample(self.all_data, max_data_num)
31 | else:
32 | # sample data uniformly from each class, if not possible, sample up to max_data_num
33 | label_dict = {label: [] for label in list(set(self.all_labels))}
34 | for index, label in enumerate(self.all_labels):
35 | label_dict[label].append(self.all_data[index])
36 | # print key and number of data with current key
37 | print("Number of data in each class:")
38 | for label, label_list in label_dict.items():
39 | print(f"{label}: {len(label_list)}")
40 | new_all_data = []
41 |
42 | key_num = len(label_dict.keys())
43 | for label in label_dict:
44 |                 # if a class has fewer than max_data_num / class_num examples, use all of them; the remaining quota is topped up from the full dataset below
45 | if len(label_dict[label]) < max_data_num // key_num:
46 | new_all_data.extend(label_dict[label])
47 | else:
48 | # random sample data from current label
49 | new_all_data.extend(random.sample(label_dict[label], max_data_num // key_num))
50 |                 # if new_all_data is still smaller than max_data_num, top it up with random examples that are not already included
51 | while len(new_all_data) < max_data_num:
52 | tem_data = random.choice(self.all_data)
53 | if tem_data not in new_all_data:
54 | new_all_data.append(tem_data)
55 | self.all_data = new_all_data
56 | else:
57 |             print(f"Warning: max_data_num {max_data_num} is larger than the dataset size {len(self.all_data)}; using all data instead.")
58 |
59 |
60 | def get_max_demonstration_token_length(self, tokenizer):
61 | """
62 | This function is used to get the maximum token length of the example in the dataset.
63 | This is mainly used for test dataset to determine the maximum number of the demonstration
64 | tokens that can be used due to the limit of context window.
65 | """
66 | all_data_toks = []
67 | for data in self.all_data:
68 | input_str, ans_str, _ = self.apply_template(data)
69 | if self.use_instruction:
70 | instruct = self.get_task_instruction()
71 | instruct = instruct + '\n'
72 | else:
73 | instruct = ""
74 | demonstration_str = [(instruct + input_str + ' ' + ans_str[i]) for i in range(len(ans_str))]
75 | all_data_toks.extend(demonstration_str)
76 | # get maximum token length of the example in the dataset
77 | encoded_inputs = tokenizer.batch_encode_plus(all_data_toks, return_tensors="pt", padding=True, truncation=True)
78 | single_data_max_len = encoded_inputs['attention_mask'].sum(dim=1).max().item()
79 | if 'Llama' in tokenizer.name_or_path:
80 | if 'Llama-2' in tokenizer.name_or_path:
81 | cxt_max_len = 4096
82 | else:
83 | cxt_max_len = 2048
84 | else:
85 | cxt_max_len = tokenizer.model_max_length
86 | max_demonstration_len = cxt_max_len - single_data_max_len
87 |         print(f"Max demonstration token length: {cxt_max_len} - {single_data_max_len} = {max_demonstration_len}")
88 | return max_demonstration_len
89 |
90 |
91 | def gen_few_shot_demonstration(self, tokenizer, shot_num, max_demonstration_tok_len=1e6,
92 | add_extra_query=False, example_separator='\n',
93 | return_data_index=False, gen_example_method='normal', seed=0):
94 |
95 | """
96 | This function is used to generate few-shot demonstration.
97 | """
98 | # set random seed
99 | random.seed(seed)
100 |
101 | assert self.all_data is not None, "Please load data first!"
102 |         assert shot_num <= len(self.all_data), "Shot number should not exceed the number of data!"
103 |         if hasattr(self, 'class_num') and self.class_num is not None: # if class number is provided
104 |             assert shot_num == -1 or shot_num == 0 or shot_num >= self.class_num, "Shot number should be -1, 0, or at least the number of classes!"
105 | class_num = self.class_num
106 | # get label dict
107 | label_dict = {label: [] for label in range(class_num)}
108 | for index, data in enumerate(self.all_data):
109 | label_dict[self.apply_template(data)[-1]].append(index)
110 | else:
111 | class_num = None
112 |
113 | # get task instruction
114 | instruct = self.get_task_instruction() if self.use_instruction else ""
115 | if len(instruct) > 0:
116 | demonstration_expample_list = [instruct]
117 | demonstration = instruct + example_separator
118 | else:
119 | demonstration_expample_list = []
120 | demonstration = ""
121 |
122 | if class_num is None: # random sample data
123 | sample_indexes = random.sample(range(len(self.all_data)), shot_num)
124 | else: # uniform sample data from each class
125 | sample_indexes = []
126 |             # split the shot number equally across classes
127 | for label in label_dict:
128 | sample_indexes.extend(random.sample(label_dict[label], shot_num // class_num))
129 | # random shuffle the sample indexes
130 | random.shuffle(sample_indexes)
131 |
132 | for index in sample_indexes:
133 | input_str, ans_str, label = self.apply_template(self.all_data[index])
134 | ans = ans_str[label]
135 | new_example = input_str + ' ' + ans
136 | demonstration = demonstration + new_example + example_separator
137 | # collect single demonstration example
138 | if gen_example_method == 'normal':
139 | single_example = new_example
140 | elif gen_example_method == 'random_label':
141 | single_example = input_str + ' ' + random.choice(ans_str)
142 | elif gen_example_method == 'no_template':
143 | single_example = input_str.split(":")[1].split("\n")[0] + ' ' + ans
144 | elif gen_example_method == 'random_order':
145 | # random change the order of each word in the input string
146 | single_example = new_example
147 | words = single_example.split()
148 | random.shuffle(words)
149 | single_example = ' '.join(words)
150 | else:
151 | raise ValueError("Unknown demonstration example generation method!")
152 | single_example = single_example + example_separator
153 | demonstration_expample_list.append(single_example)
154 |
155 |         if add_extra_query: # add a random query at the end of demonstration
156 |             extra_query, _, _ = self.apply_template(self.all_data[random.randint(0, len(self.all_data) - 1)])
157 |             demonstration += extra_query
158 |             demonstration_expample_list.append(extra_query)
159 |
160 | # check length of demonstration token
161 | encoded_inputs = tokenizer(demonstration, return_tensors="pt", padding=True, truncation=False)
162 | assert len(encoded_inputs['input_ids'][0]) < max_demonstration_tok_len, "Demonstration token length should be smaller than the maximum demonstration token length!"
163 | print(f"Generated {shot_num}-shot demonstration.")
164 |
165 | if return_data_index:
166 | return demonstration, demonstration_expample_list, sample_indexes
167 | else:
168 | return demonstration, demonstration_expample_list
169 |
170 |
171 | def get_dmonstration_template(self):
172 | """
173 | This function is used to provide template for demonstration, need to be implemented for each task.
174 | """
175 | raise NotImplementedError("Please provide the template for demonstration!")
176 |
177 | def get_task_instruction(self):
178 | """
179 | This function is used to provide task instruction, need to be implemented for each task.
180 | """
181 | raise NotImplementedError("Please provide the task instruction!")
182 |
183 | def apply_template(self, data):
184 | """
185 | This function is used to apply template to a given data, need to be implemented for each task.
186 | """
187 | raise NotImplementedError("Please provide how to apply template!")
188 |
189 | def print_data(self, indices):
190 | """
191 | This function is used to print data given indices.
192 | """
193 | if isinstance(indices, int):
194 | indices = [indices]
195 | for index in indices:
196 | print(self.all_data[index])
197 |
198 | def __len__(self):
199 | return len(self.all_data)
200 |
201 | def __getitem__(self, index):
202 | return self.all_data[index]
--------------------------------------------------------------------------------
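BaseTask.random_sample_data caps the number of loaded examples while keeping classes roughly balanced: each class contributes up to max_data_num // class_num examples, and any remaining quota is topped up with random, not-yet-included examples. The standalone function below is an illustrative re-implementation of that sampling idea, not part of the repository.

import random
from collections import defaultdict


def balanced_sample(all_data, all_labels, max_data_num, seed=0):
    # class-balanced subsampling: equal quota per class, then random top-up
    random.seed(seed)
    if max_data_num >= len(all_data):
        return list(all_data)
    by_label = defaultdict(list)
    for item, label in zip(all_data, all_labels):
        by_label[label].append(item)
    quota = max_data_num // len(by_label)
    sampled = []
    for items in by_label.values():
        # take the whole class if it is smaller than its quota
        sampled.extend(items if len(items) < quota else random.sample(items, quota))
    # top up with unused examples until max_data_num is reached
    while len(sampled) < max_data_num:
        candidate = random.choice(all_data)
        if candidate not in sampled:
            sampled.append(candidate)
    return sampled


if __name__ == "__main__":
    data = [f"example_{i}" for i in range(20)]
    labels = [i % 3 for i in range(20)]
    print(balanced_sample(data, labels, max_data_num=9))
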
/run_task_vector.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import json
3 | import copy
4 | import random
5 | import time
6 | import argparse
7 | import itertools
8 | import torch
9 | from multiprocessing import Process, Queue
10 |
11 | import utils
12 | import my_datasets as md
13 | import evaluator as ev
14 |
15 |
16 | def target_layer_selection(args, model_wrapper, tokenizer, evaluator, context_vector_dict):
17 | num_layers = model_wrapper.num_layers
18 | with torch.no_grad():
19 | best_layer = 0
20 | best_result = 0
21 | for layer in range(num_layers):
22 | with model_wrapper.replace_latent(context_vector_dict, [layer], args.config):
23 | val_result = evaluator.evaluate(model_wrapper, tokenizer, demonstration='',
24 | use_cache=args.config['use_cache'])
25 | print(f'Layer {layer} result: {val_result}\n')
26 | if val_result[args.metric] > best_result:
27 | best_result = val_result[args.metric]
28 | best_layer = layer
29 | print(f'Best layer: {best_layer}')
30 | return best_layer
31 |
32 |
33 | def main(args):
34 | # set global seed
35 | utils.set_seed(args.config['seed'])
36 | # set device
37 | args.device = utils.set_device(args.gpu)
38 | # set metric used
39 | args.metric = args.config['metric']
40 | # get save dir
41 | utils.init_exp_path(args, args.config['exp_name'])
42 |
43 | # load tokenizer and model
44 | model, tokenizer, model_config = \
45 | utils.load_model_tokenizer(args.model_name, args.device)
46 |
47 | # get model_wrapper
48 | model_wrapper = utils.get_model_wrapper(args.model_name, model,
49 | tokenizer, model_config,
50 | args.device)
51 | # load datasets
52 | train_dataset = md.get_dataset(args.dataset_name, split='train', max_data_num=None)
53 | val_dataset = md.get_dataset(args.dataset_name, split='validation',
54 | max_data_num=args.config['val_data_num'],
55 | sample_mode=args.config['sample_method'])
56 | test_dataset = md.get_dataset(args.dataset_name, split='test',
57 | max_data_num=args.config['test_data_num'],
58 | sample_mode=args.config['sample_method'])
59 |
60 | # get max demonstration token length for each dataset
61 | args.val_max_token = val_dataset.get_max_demonstration_token_length(tokenizer)
62 | args.test_max_token = test_dataset.get_max_demonstration_token_length(tokenizer)
63 |
64 | # get shot_num
65 |     if args.dataset_name == 'dbpedia': # always use 1 shot per class for dbpedia
66 |         args.config['shot_per_class'] = 1
67 |         args.config['bs'] = 1
68 |     args.shot_num = utils.get_shot_num(train_dataset, args.config['shot_per_class'])
69 |     # build evaluators
70 | val_evaluator = ev.Evaluator(val_dataset, batch_size=args.config['bs'])
71 | test_evaluator = ev.Evaluator(test_dataset, batch_size=args.config['bs'])
72 | # init result_dict
73 | result_dict = {'demon': {},
74 | 'split_demon': {},
75 | 'best_replace_layer': {},
76 | 'test_result': {'zero_shot': [], 'few_shot': [], 'ours': []},
77 | 'val_result': {'zero_shot': [], 'few_shot': [], 'ours': []},
78 | 'time': {'calibrate': [], 'evaluate': []},
79 | }
80 |
81 | for run_id in range(args.config['run_num']):
82 | run_name = f'run_{run_id}'
83 | args.run_name = run_name
84 |         print(f'Run: {run_name}')
85 | run_seed = args.config['seed'] + run_id
86 | utils.set_seed(run_seed)
87 |
88 | # zero-shot baseline
89 | if run_id == 0 and args.config['run_baseline']:
90 | val_zeroshot_result = val_evaluator.evaluate(model_wrapper, tokenizer, demonstration='',
91 | use_cache=args.config['use_cache'])
92 | test_zeroshot_result = test_evaluator.evaluate(model_wrapper, tokenizer, demonstration='',
93 | use_cache=args.config['use_cache'])
94 | result_dict['val_result']['zero_shot'].append(val_zeroshot_result)
95 | result_dict['test_result']['zero_shot'].append(test_zeroshot_result)
96 | print(f'Validation zero-shot result: {val_zeroshot_result}\n')
97 | print(f'Test zero-shot result: {test_zeroshot_result}\n')
98 |
99 | # sample demonstration
100 | demon, _, _ = \
101 | train_dataset.gen_few_shot_demonstration(tokenizer=tokenizer, shot_num=args.shot_num,
102 | max_demonstration_tok_len=min(args.val_max_token,
103 | args.test_max_token),
104 | add_extra_query=args.config['add_extra_query'],
105 | example_separator=args.config['example_separator'],
106 |                                                      return_data_index=True, seed=random.randint(0, int(1e6))
107 | )
108 |
109 |         # default to the raw demonstration when no extra query is appended
110 |         baseline_demon, query_demon = demon, None
111 |         if args.config['add_extra_query']:
112 |             first_format_anchor = train_dataset.get_dmonstration_template()['format'][0]
113 |             # strip everything from the last occurrence of the anchor onwards (anchor included)
114 |             if first_format_anchor in demon:
115 |                 baseline_demon = demon[:demon.rfind(first_format_anchor)]
116 |                 query_demon = demon[demon.rfind(first_format_anchor):]
117 |
118 |         print(f'Demonstration:\n{demon}\n')
119 |         print(f'Baseline demonstration:\n{baseline_demon}\n')
120 |         print(f'Query demonstration:\n{query_demon}\n')
121 |
122 | # few-shot baseline
123 | if args.config['run_baseline']:
124 | val_fewshot_result = val_evaluator.evaluate(model_wrapper, tokenizer,
125 | demonstration=baseline_demon,
126 | use_cache=args.config['use_cache'])
127 | test_fewshot_result = test_evaluator.evaluate(model_wrapper, tokenizer,
128 | demonstration=baseline_demon,
129 | use_cache=args.config['use_cache'])
130 | result_dict['val_result']['few_shot'].append(val_fewshot_result)
131 | result_dict['test_result']['few_shot'].append(test_fewshot_result)
132 | print(f'Validation few-shot result: {val_fewshot_result}\n')
133 | print(f'Test few-shot result: {test_fewshot_result}\n')
134 |
135 | # extract latents ======================================================================
136 | all_latent_dicts = []
137 | with torch.no_grad():
138 | with model_wrapper.extract_latent():
139 | demon_token = tokenizer(demon, return_tensors='pt').to(args.device)
140 | _ = model(**demon_token)
141 | all_latent_dicts.append(model_wrapper.latent_dict)
142 |
143 | # generate context vector ==============================================================
144 | context_vector_dict = model_wrapper.get_context_vector(all_latent_dicts, args.config)
145 | del all_latent_dicts
146 |
147 | # injection layer selection ============================================================
148 | best_replace_layer = target_layer_selection(args, model_wrapper, tokenizer,
149 | val_evaluator, context_vector_dict)
150 | result_dict['best_replace_layer'][run_name] = best_replace_layer
151 |
152 | # evaluate task_vector ========================================================================
153 | s_t = time.time()
154 | with model_wrapper.replace_latent(context_vector_dict, [best_replace_layer], args.config):
155 | val_ours_result = val_evaluator.evaluate(model_wrapper, tokenizer, demonstration='',
156 | use_cache=args.config['use_cache'])
157 | print(f'Validation task_vector result: {val_ours_result}\n')
158 | result_dict['val_result']['ours'].append(val_ours_result)
159 |
160 | test_ours_result = test_evaluator.evaluate(model_wrapper, tokenizer, demonstration='',
161 | use_cache=args.config['use_cache'])
162 | print(f'Test task_vector result: {test_ours_result}\n')
163 | result_dict['test_result']['ours'].append(test_ours_result)
164 | e_t = time.time()
165 |
166 | print(f'Evaluate time: {e_t - s_t}')
167 | result_dict['time']['evaluate'].append(e_t - s_t)
168 |
169 | # save result_dict after each run
170 | with open(args.save_dir + '/result_dict.json', 'w') as f:
171 | json.dump(result_dict, f, indent=4)
172 |
173 | # delete all variables
174 | del model_wrapper, model, tokenizer, train_dataset, val_dataset, test_dataset
175 |     del val_evaluator, test_evaluator, result_dict, context_vector_dict
176 |
177 |
178 | # get args
179 | def get_args():
180 | parser = argparse.ArgumentParser()
181 | parser.add_argument('--config_path', type=str, default='configs/config_task_vector.py', help='path to config file')
182 | return parser.parse_args()
183 |
184 |
185 | if __name__ == "__main__":
186 | # get args
187 | args = get_args()
188 | # load config
189 | config = utils.load_config(args.config_path)
190 | # Generate all combinations of models and datasets
191 | combinations = list(itertools.product(config['models'], config['datasets']))
192 | # Queue to hold tasks
193 | task_queue = Queue()
194 | for combine in combinations:
195 | task_queue.put(combine)
196 |
197 | def run_task(gpu_id, config):
198 | while not task_queue.empty():
199 | model_name, dataset_name = task_queue.get()
200 | print(f"Running {model_name} on {dataset_name} with GPU {gpu_id}")
201 | input_args = argparse.Namespace()
202 | cur_config = copy.deepcopy(config)
203 | input_args.model_name = model_name
204 | input_args.dataset_name = dataset_name
205 | input_args.gpu = gpu_id
206 | input_args.config = cur_config
207 | try:
208 | main(input_args)
209 | finally:
210 | # Clean up CUDA memory after each task
211 | gc.collect()
212 | torch.cuda.empty_cache()
213 | print(f"CUDA memory cleared for GPU {gpu_id}")
214 | time.sleep(5)
215 |
216 | # Create a process for each GPU
217 | processes = [Process(target=run_task, args=(gpu_id, config)) for gpu_id in config['gpus']]
218 | # Start all processes
219 | for p in processes:
220 | p.start()
221 | # Wait for all processes to finish
222 | for p in processes:
223 | p.join()
224 | print("All tasks completed.")
--------------------------------------------------------------------------------
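target_layer_selection above sweeps every layer, injects the context vector at that single layer, and keeps the layer with the best validation score. A stripped-down sketch of that selection loop follows; fake_eval is a hypothetical stand-in for the wrapped-model evaluation under replace_latent.

import random


def fake_eval(layer):
    # hypothetical stand-in for evaluator.evaluate with the latent replaced at one layer
    random.seed(layer)
    return {'acc': random.random()}


def select_best_layer(num_layers, metric='acc'):
    best_layer, best_result = 0, 0.0
    for layer in range(num_layers):
        result = fake_eval(layer)[metric]
        if result > best_result:
            best_result, best_layer = result, layer
    return best_layer, best_result


if __name__ == "__main__":
    layer, score = select_best_layer(num_layers=32)
    print(f"best layer: {layer}, val acc: {score:.3f}")
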
/run_i2cl.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import json
3 | import copy
4 | import time
5 | import random
6 | import argparse
7 | import itertools
8 | import torch
9 | import numpy as np
10 | from multiprocessing import Process, Queue
11 |
12 | import utils
13 | import my_datasets as md
14 | import evaluator as ev
15 |
16 |
17 | def main(args):
18 | # set global seed
19 | utils.set_seed(args.config['seed'])
20 | # set device
21 | args.device = utils.set_device(args.gpu)
22 | # set metric used
23 | args.metric = args.config['metric']
24 | # get save dir
25 | utils.init_exp_path(args, args.config['exp_name'])
26 |
27 | # load tokenizer and model
28 | model, tokenizer, model_config = \
29 | utils.load_model_tokenizer(args.model_name, args.device)
30 |
31 | # get model_wrapper
32 | model_wrapper = utils.get_model_wrapper(args.model_name, model,
33 | tokenizer, model_config,
34 | args.device)
35 |
36 | # load datasets
37 | train_dataset = md.get_dataset(args.dataset_name, split='train',
38 | max_data_num=None, seed=args.config['seed'])
39 | holdout_dataset = md.get_dataset(args.dataset_name, split='validation',
40 | max_data_num=args.config['val_data_num'],
41 | sample_mode=args.config['sample_method'],
42 | seed=args.config['seed'])
43 | test_dataset = md.get_dataset(args.dataset_name, split='test',
44 | max_data_num=args.config['test_data_num'],
45 | sample_mode=args.config['sample_method'],
46 | seed=args.config['seed'])
47 |
48 | # get max demonstration token length for each dataset
49 | if args.config['split_demon']:
50 | args.test_max_token = 1e8
51 | else:
52 | args.test_max_token = test_dataset.get_max_demonstration_token_length(tokenizer)
53 |
54 | # get shot_num
55 |     if args.dataset_name == 'dbpedia': # always use 1 shot per class for dbpedia
56 | args.config['shot_per_class'] = 1
57 | args.config['bs'] = 1
58 | args.shot_num = utils.get_shot_num(train_dataset, args.config['shot_per_class'])
59 |
60 | # build evaluators
61 | test_evaluator = ev.Evaluator(test_dataset, batch_size=args.config['bs'])
62 | holdout_evaluator = ev.Evaluator(holdout_dataset, batch_size=args.config['bs'])
63 | # init result_dict
64 | result_dict = {'demon': {},
65 | 'split_demon': {},
66 | 'test_result': {'zero_shot': [], 'few_shot': [], 'ours': []},
67 | 'linear_coef': {},
68 | 'time': {'calibrate': [], 'evaluate': []},
69 | }
70 | cv_save_dict = {}
71 |
72 | for run_id in range(args.config['run_num']):
73 | run_name = f'run_{run_id}'
74 | args.run_name = run_name
75 |         print(f'Run: {run_name}')
76 | run_seed = args.config['seed'] + run_id
77 | utils.set_seed(run_seed)
78 |
79 | # zero-shot baseline
80 | if run_id == 0 and args.config['run_baseline']:
81 | test_zeroshot_result = test_evaluator.evaluate(model_wrapper, tokenizer, demonstration='',
82 | use_cache=args.config['use_cache'])
83 | result_dict['test_result']['zero_shot'].append(test_zeroshot_result)
84 | print(f'Test zero-shot result: {test_zeroshot_result}\n')
85 |
86 | # sample demonstration
87 | count = 0
88 | temp_demon_list, temp_result_list = [], []
89 | while True:
90 | demon, split_demon, demon_data_index = \
91 | train_dataset.gen_few_shot_demonstration(tokenizer=tokenizer, shot_num=args.shot_num,
92 | max_demonstration_tok_len=args.test_max_token,
93 | add_extra_query=args.config['add_extra_query'],
94 | example_separator=args.config['example_separator'],
95 |                                                          gen_example_method=args.config['gen_example_method'],
96 |                                                          return_data_index=True, seed=random.randint(0, int(1e6)))
97 | temp_demon_list.append((demon, split_demon, demon_data_index))
98 |
99 | if args.config['demo_sample_method'] == 'random':
100 | break
101 | else:
102 | tem_val_result = holdout_evaluator.evaluate(model_wrapper, tokenizer,
103 | demonstration=demon,
104 | use_cache=args.config['use_cache'])
105 | temp_result = tem_val_result[args.metric]
106 | temp_result_list.append(temp_result)
107 | if count > 20:
108 | if args.config['demo_sample_method'] == 'deficient':
109 | demon, split_demon, demon_data_index = temp_demon_list[np.argmin(temp_result_list)]
110 | else:
111 | raise ValueError('Invalid demo_sample_method!')
112 | break
113 | count += 1
114 |
115 |         # build calibration dataset using demon_data_index
116 | cali_dataset = copy.deepcopy(train_dataset)
117 | cali_dataset.all_data = [train_dataset.all_data[i] for i in demon_data_index]
118 |
119 |         # clean demonstration: default to the raw demonstration when no extra query is appended
120 |         baseline_demon, query_demon = demon, None
121 |         if args.config['add_extra_query']:
122 |             first_format_anchor = train_dataset.get_dmonstration_template()['format'][0]
123 |             # strip everything from the last occurrence of the anchor onwards (anchor included)
124 |             if first_format_anchor in demon:
125 |                 baseline_demon = demon[:demon.rfind(first_format_anchor)]
126 |                 query_demon = demon[demon.rfind(first_format_anchor):]
127 |
128 |         print(f'Demonstration:\n{demon}\n')
129 |         print(f'Baseline demonstration:\n{baseline_demon}\n')
130 |         print(f'Query demonstration:\n{query_demon}\n')
131 |
132 |
133 | # few-shot baseline
134 | if args.config['run_baseline']:
135 | test_fewshot_result = test_evaluator.evaluate(model_wrapper, tokenizer,
136 | demonstration=baseline_demon,
137 | use_cache=args.config['use_cache'])
138 | result_dict['test_result']['few_shot'].append(test_fewshot_result)
139 | print(f'Test few-shot result: {test_fewshot_result}\n')
140 |
141 | # generate demon_list
142 | demon_list = [demon]
143 | split_demon_list = split_demon
144 | result_dict['demon'][run_name] = demon_list
145 | result_dict['split_demon'][run_name] = split_demon_list
146 |
147 | # init strength_params
148 | model_wrapper.init_strength(args.config)
149 |
150 | # extract latents
151 | all_latent_dicts = []
152 | with torch.no_grad():
153 | if not args.config['split_demon']:
154 |                 target_demon_list = demon_list  # keep the full demonstration as a single-item list
155 | else:
156 | target_demon_list = split_demon_list
157 | for cur_demon in target_demon_list:
158 | with model_wrapper.extract_latent():
159 | demon_token = tokenizer(cur_demon, return_tensors='pt').to(args.device)
160 | _ = model(**demon_token)
161 | all_latent_dicts.append(model_wrapper.latent_dict)
162 | model_wrapper.reset_latent_dict()
163 |
164 | # generate context vector
165 | context_vector_dict = model_wrapper.get_context_vector(all_latent_dicts, args.config)
166 | if args.config['gen_cv_method'] == 'noise':
167 | context_vector_dict = model_wrapper.init_noise_context_vector(context_vector_dict)
168 | del all_latent_dicts
169 |
170 | # calibrate context vector
171 | s_t = time.time()
172 | model_wrapper.calibrate_strength(context_vector_dict, cali_dataset,
173 | args.config, save_dir=args.save_dir,
174 | run_name=args.run_name)
175 | e_t = time.time()
176 | print(f'Calibration time: {e_t - s_t}')
177 | result_dict['time']['calibrate'].append(e_t - s_t)
178 |
179 | # save linear_coef
180 | result_dict['linear_coef'][run_name] = model_wrapper.linear_coef.tolist()
181 |
182 | # evaluate i2cl
183 | s_t = time.time()
184 | with torch.no_grad():
185 | with model_wrapper.inject_latent(context_vector_dict, args.config,
186 | model_wrapper.linear_coef):
187 | test_ours_result = test_evaluator.evaluate(model_wrapper, tokenizer, demonstration='',
188 | use_cache=args.config['use_cache'])
189 | print(f'Test I2CL result: {test_ours_result}\n')
190 | result_dict['test_result']['ours'].append(test_ours_result)
191 | e_t = time.time()
192 | print(f'Evaluate time: {e_t - s_t}')
193 | result_dict['time']['evaluate'].append(e_t - s_t)
194 |
195 | # save result_dict after each run
196 | with open(args.save_dir + '/result_dict.json', 'w') as f:
197 | json.dump(result_dict, f, indent=4)
198 |
199 | # save context vector dict
200 | for layer, subdict in context_vector_dict.items():
201 | for module, activation in subdict.items():
202 | context_vector_dict[layer][module] = activation.cpu().numpy().tolist()
203 | cv_save_dict[run_name] = context_vector_dict
204 |
205 | with open(args.save_dir + '/cv_save_dict.json', 'w') as f:
206 | json.dump(cv_save_dict, f, indent=4)
207 |
208 | # delete all variables
209 | del model_wrapper, model, tokenizer, train_dataset, cali_dataset, test_dataset, holdout_dataset
210 | del test_evaluator, holdout_evaluator
211 | del result_dict, context_vector_dict, demon_list
212 |
213 |
214 | # get args
215 | def get_args():
216 | parser = argparse.ArgumentParser()
217 | parser.add_argument('--config_path', type=str, default='configs/config_i2cl.py', help='path to config file')
218 | return parser.parse_args()
219 |
220 |
221 | if __name__ == "__main__":
222 | # get args
223 | args = get_args()
224 | # load config
225 | config = utils.load_config(args.config_path)
226 | # Generate all combinations of models and datasets
227 | combinations = list(itertools.product(config['models'], config['datasets']))
228 | # Queue to hold tasks
229 | task_queue = Queue()
230 | for combine in combinations:
231 | task_queue.put(combine)
232 |
233 | def run_task(gpu_id, config):
234 | while not task_queue.empty():
235 | model_name, dataset_name = task_queue.get()
236 | print(f"Running {model_name} on {dataset_name} with GPU {gpu_id}")
237 | input_args = argparse.Namespace()
238 | cur_config = copy.deepcopy(config)
239 | input_args.model_name = model_name
240 | input_args.dataset_name = dataset_name
241 | input_args.gpu = gpu_id
242 | input_args.config = cur_config
243 | try:
244 | main(input_args)
245 | finally:
246 | # Clean up CUDA memory after each task
247 | gc.collect()
248 | torch.cuda.empty_cache()
249 | print(f"CUDA memory cleared for GPU {gpu_id}")
250 | time.sleep(5)
251 |
252 | # Create a process for each GPU
253 | processes = [Process(target=run_task, args=(gpu_id, config)) for gpu_id in config['gpus']]
254 | # Start all processes
255 | for p in processes:
256 | p.start()
257 | # Wait for all processes to finish
258 | for p in processes:
259 | p.join()
260 | print("All tasks completed.")
--------------------------------------------------------------------------------
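run_i2cl.py extracts one latent dict per demonstration, aggregates them into a context_vector_dict via wrapper.get_context_vector, calibrates per-layer strengths, and injects the result at inference time. The exact aggregation lives in wrapper.py (not shown here); the sketch below only illustrates one plausible aggregation, a plain mean over demonstrations on synthetic tensors, mirroring the {layer: {module: tensor}} layout consumed above.

import torch


def mean_context_vector(all_latent_dicts):
    # average per-demonstration latents into one vector per (layer, module)
    layers = list(all_latent_dicts[0].keys())
    modules = list(all_latent_dicts[0][layers[0]].keys())
    context_vector = {}
    for layer in layers:
        context_vector[layer] = {}
        for module in modules:
            stacked = torch.stack([d[layer][module] for d in all_latent_dicts])
            context_vector[layer][module] = stacked.mean(dim=0)
    return context_vector


if __name__ == "__main__":
    hidden = 8
    fake_latents = [{layer: {'attn': torch.randn(hidden), 'mlp': torch.randn(hidden)}
                     for layer in range(2)} for _ in range(4)]
    cv = mean_context_vector(fake_latents)
    print(cv[0]['attn'].shape)  # torch.Size([8])
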
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import json
4 | import random
5 | import functools
6 | import warnings
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | from typing import Union, List, Optional
11 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
12 | import matplotlib.pyplot as plt
13 | import wrapper
14 | import my_datasets as md
15 |
16 |
17 | def set_seed(seed):
18 | random.seed(seed)
19 | np.random.seed(seed)
20 | torch.manual_seed(seed)
21 | torch.cuda.manual_seed_all(seed)
22 | torch.backends.cudnn.deterministic = True
23 |
24 |
25 | def set_device(gpu_id):
26 | device = torch.device(f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu')
27 | torch.cuda.set_device(device)
28 | return device
29 |
30 |
31 | def init_exp_path(args, exp_name, separate_dataset=True):
32 | if separate_dataset:
33 | save_dir = os.path.join(exp_name, args.model_name, args.dataset_name)
34 | else:
35 | save_dir = os.path.join(exp_name, args.model_name)
36 | args.save_dir = save_dir
37 | if os.path.exists(save_dir) and 'debug' not in exp_name:
38 |         raise ValueError(f"Experiment {exp_name} already exists! Please delete it or change the name.")
39 | os.makedirs(save_dir, exist_ok=True)
40 | # save config_dict
41 | with open(f'{save_dir}/config.json', 'w') as f:
42 | json.dump(args.config, f, indent=4)
43 |     # save args as a human-readable txt file
44 | with open(f'{save_dir}/args.txt', 'w') as f:
45 | for key, value in vars(args).items():
46 | f.write(f'{key}: {value}\n')
47 |
48 |
49 | def load_model_tokenizer(model_name, device, output_hidden_states=True, load_in_8bit=False):
50 | # load tokenizer and model
51 | tokenizer = AutoTokenizer.from_pretrained(model_name)
52 | model = AutoModelForCausalLM.from_pretrained(model_name,
53 | output_hidden_states=output_hidden_states,
54 | load_in_8bit=load_in_8bit,
55 | torch_dtype=torch.float32)
56 | if not load_in_8bit:
57 | model = model.to(device)
58 | config = AutoConfig.from_pretrained(model_name)
59 | tokenizer.pad_token = tokenizer.eos_token
60 | tokenizer.padding_side = 'right'
61 | return model, tokenizer, config
62 |
63 |
64 | def get_model_wrapper(model_name, model, tokenizer, model_config, device):
65 | if 'llama' in model_name:
66 | model_wrapper = wrapper.LlamaWrapper(model, tokenizer, model_config, device)
67 | elif 'gpt' in model_name:
68 | model_wrapper = wrapper.GPTWrapper(model, tokenizer, model_config, device)
69 | else:
70 | raise ValueError("only support llama or gpt!")
71 | return model_wrapper
72 |
73 |
74 | def load_config(file_path):
75 | if not file_path:
76 | raise ValueError("No file path provided")
77 | file_dir = os.path.dirname(file_path)
78 | if file_dir not in sys.path:
79 | sys.path.append(file_dir)
80 | file_name = os.path.basename(file_path)
81 | module_name = os.path.splitext(file_name)[0]
82 | module = __import__(module_name)
83 | try:
84 |         config = getattr(module, 'config')
85 |         print(config)
86 |         return config
87 |     except AttributeError:
88 |         print("The module does not have a variable named 'config'")
89 |
90 |
91 | def get_shot_num(dataset, shot_per_class, shot_num=5):
92 | if hasattr(dataset, 'class_num') and dataset.class_num is not None:
93 | shot_num = dataset.class_num * shot_per_class
94 | else:
95 | shot_num = shot_num
96 | # if shot_num < 0, then use all data
97 | if shot_num < 0:
98 | shot_num = -1
99 | return shot_num
100 |
101 |
102 | def first_one_indices(tensor):
103 | """
104 | Finds the index of the first 1 in each row of a 2D tensor.
105 |
106 | Args:
107 | tensor (torch.Tensor): A 2D tensor of size (N, M) containing only 0 and 1 entries.
108 |
109 | Returns:
110 | torch.Tensor: A tensor of size N containing the index of the first 1 in each row.
111 | If a row contains only 0s, the index will be set to -1 (or a sentinel value of your choice).
112 | """
113 | # Check for rows containing only zeros.
114 | is_all_zero = tensor.sum(dim=1) == 0
115 | # Get the index of the first occurrence of the maximum value (1) along each row.
116 | indices = tensor.argmax(dim=1)
117 | # Handle rows with all zeros.
118 | indices[is_all_zero] = -1 # Set to -1 to indicate no '1' found in these rows
119 | return indices
120 |
121 |
122 | def last_one_indices(tensor):
123 | """
124 | Finds the index of the last 1 in each row of a 2D tensor.
125 |
126 | Args:
127 | tensor (torch.Tensor): A 2D tensor of size (N, M) containing only 0 and 1 entries.
128 |
129 | Returns:
130 | torch.Tensor: A tensor of size N containing the index of the last 1 in each row.
131 | If a row contains only 0s, the index will be set to -1 (or a sentinel value of your choice).
132 | """
133 | # Reverse each row to find the last occurrence of 1 (which becomes the first in the reversed row)
134 | reversed_tensor = torch.flip(tensor, [1])
135 | # Check for rows containing only zeros in the reversed tensor
136 | is_all_zero = reversed_tensor.sum(dim=1) == 0
137 | # Get the index of the first occurrence of the maximum value (1) along each row in the reversed tensor
138 | indices = reversed_tensor.argmax(dim=1)
139 | # Adjust the indices for the original order of each row
140 | indices = tensor.size(1) - 1 - indices
141 | # Handle rows with all zeros
142 | indices[is_all_zero] = -1 # Set to -1 to indicate no '1' found in these rows
143 | return indices
144 |
145 |
146 | def plot_loss_curve(loss_list, save_path):
147 | plt.plot(loss_list)
148 | plt.xlabel('epoch')
149 | plt.ylabel('loss')
150 | plt.savefig(save_path)
151 | plt.close()
152 |
153 |
154 | def svd_flip(u, v):
155 | # columns of u, rows of v
156 | max_abs_cols = torch.argmax(torch.abs(u), 0)
157 | i = torch.arange(u.shape[1]).to(u.device)
158 | signs = torch.sign(u[max_abs_cols, i])
159 | u *= signs
160 | v *= signs.view(-1, 1)
161 | return u, v
162 |
163 |
164 | class PCA(nn.Module):
165 | def __init__(self, n_components):
166 | super().__init__()
167 | self.n_components = n_components
168 |
169 | @torch.no_grad()
170 | def fit(self, X):
171 | n, d = X.size()
172 | if self.n_components is not None:
173 | d = min(self.n_components, d)
174 | self.register_buffer("mean_", X.mean(0, keepdim=True))
175 | Z = X - self.mean_ # center
176 | U, S, Vh = torch.linalg.svd(Z, full_matrices=False)
177 | Vt = Vh
178 | U, Vt = svd_flip(U, Vt)
179 | self.register_buffer("components_", Vt[:d])
180 | return self
181 |
182 | def forward(self, X):
183 | return self.transform(X)
184 |
185 | def transform(self, X):
186 | assert hasattr(self, "components_"), "PCA must be fit before use."
187 | return torch.matmul(X - self.mean_, self.components_.t())
188 |
189 | def fit_transform(self, X):
190 | self.fit(X)
191 | return self.transform(X)
192 |
193 | def inverse_transform(self, Y):
194 | assert hasattr(self, "components_"), "PCA must be fit before use."
195 | return torch.matmul(Y, self.components_) + self.mean_
196 |
197 |
198 | class ContextSolver:
199 | def __init__(self, task_name, tokenizer=None):
200 | # assert task_name in ['sst2', 'trec', 'agnews', 'emo']
201 | self.task_name = task_name
202 | self.tokenizer = tokenizer
203 | self.task_dataset = md.get_dataset(task_name, split='train', max_data_num=10)
204 | self.format_s = self.task_dataset.get_dmonstration_template()['input']
205 | self.parse_format_s()
206 |
207 | def parse_format_s(self):
208 | self.X_prefix = self.format_s.split('\n')[0].split(':')[0] + ':'
209 | self.Y_prefix = self.format_s.split('\n')[1].split(':')[0] + ':'
210 |
211 | def get_empty_demo_context(self, context: str, only_demo_part=True):
212 | context = context.split('\n')
213 | for i, line in enumerate(context[:-2]):
214 | if self.X_prefix in line:
215 | line = self.X_prefix
216 |             elif self.Y_prefix in line:
217 |                 pass  # keep label lines unchanged
218 |             else:
219 |                 warnings.warn('Global prefix or other str exists!')
220 | context[i] = line
221 | if only_demo_part:
222 | context = context[:-2]
223 | context = '\n'.join(context)
224 | return context
225 |
226 | def get_mask_strings_and_match_before(self, context, input_ids, tokenizer=None):
227 | if tokenizer is None:
228 | tokenizer = self.tokenizer
229 | print('debug tokenizer name :', tokenizer.__class__.__name__)
230 | if 'Llama' in tokenizer.__class__.__name__:
231 | sap_token = tokenizer.encode('\n', add_special_tokens=False)[1]
232 | poss = torch.where(input_ids == sap_token)[0]
233 | else:
234 | sap_token = tokenizer.encode('\n', add_special_tokens=False)[0]
235 | poss = torch.where(input_ids == sap_token)[0]
236 | print('debug sap_token:', sap_token)
237 | print('debug poss:', poss)
238 | if len(poss) >= 2:
239 | match_before = poss[-2] + 1
240 | else:
241 | match_before = None
242 |
243 | list_s = []
244 | list_s.append(self.X_prefix)
245 | list_s.append('\n' + self.X_prefix)
246 | context = context.split('\n')
247 | for i, line in enumerate(context[:-2]):
248 | if self.X_prefix in line:
249 | pass
250 | elif self.Y_prefix in line:
251 | list_s.append('\n' + line)
252 | list_s.append('\n' + line + '\n')
253 | else:
254 |                 warnings.warn('Global prefix or other str exists!')
255 | return list_s, match_before
256 |
257 | def get_mask(self, input_ids, tokenizer=None):
258 | if isinstance(input_ids, list):
259 | input_ids = torch.tensor(input_ids)
260 | if len(input_ids.shape) == 2:
261 | assert input_ids.shape[0] == 1
262 | input_ids = input_ids[0]
263 | if tokenizer is None:
264 | tokenizer = self.tokenizer
265 | context = tokenizer.decode(input_ids)
266 | list_s, match_before = self.get_mask_strings_and_match_before(context, input_ids=input_ids,
267 | tokenizer=tokenizer)
268 | print('debug context:', context)
269 | print('debug list_s:', list_s)
270 | print('debug match_before:', match_before)
271 | tensor_str_finder = TensorStrFinder(tokenizer=tokenizer)
272 | mask = tensor_str_finder.get_strs_mask_in_tensor(list_s=list_s, t=input_ids,
273 | match_before=match_before)
274 | return mask
275 |
276 |
277 | class TensorStrFinder:
278 | def __init__(self, tokenizer):
279 | self.tokenizer = tokenizer
280 |
281 | def find_tensor_in_tensor(self, a_tensor: Union[torch.Tensor, list], b_tensor: torch.Tensor,
282 | return_mask=True, match_before: Optional[int] = None):
283 | if len(b_tensor.shape) == 2:
284 | assert b_tensor.shape[0] == 1
285 | b_tensor = b_tensor[0]
286 | if isinstance(a_tensor, list):
287 | a_tensor = torch.tensor(a_tensor)
288 | if a_tensor.device != b_tensor.device:
289 | a_tensor = a_tensor.to(b_tensor.device)
290 |
291 | window_size = len(a_tensor)
292 | b_windows = b_tensor.unfold(0, window_size, 1)
293 |
294 | matches = torch.all(b_windows == a_tensor, dim=1)
295 |
296 | positions = torch.nonzero(matches, as_tuple=True)[0]
297 |
298 | if return_mask:
299 | mask = torch.zeros_like(b_tensor, dtype=torch.bool)
300 | for pos in positions:
301 | if match_before is None or pos + window_size <= match_before:
302 | mask[pos:pos + window_size] = True
303 | return mask
304 |
305 | return positions
306 |
307 | def find_str_in_tensor(self, s: str, t: torch.Tensor, return_mask=True, match_before=None):
308 | s_tokens = self.tokenizer.encode(s, add_special_tokens=False)
309 | s_tensor = torch.LongTensor(s_tokens)
310 | return self.find_tensor_in_tensor(s_tensor, t, return_mask=return_mask,
311 | match_before=match_before)
312 |
313 | def get_strs_mask_in_tensor(self, list_s: List[str], t: torch.Tensor, match_before=None):
314 | list_s_tokens = [self.tokenizer.encode(s, add_special_tokens=False) for s in list_s]
315 | if 'Llama' in self.tokenizer.__class__.__name__:
316 | list_s_tokens = [s_tokens[1:] if s_tokens[0] == 29871 else s_tokens for s_tokens in list_s_tokens]
317 | list_s_tensor = [torch.LongTensor(s_tokens) for s_tokens in list_s_tokens]
318 | print('debug list_s_tensor:', list_s_tensor)
319 | mask_tensor_list = [
320 | self.find_tensor_in_tensor(s_tensor, t, return_mask=True, match_before=match_before) for
321 | s_tensor in list_s_tensor]
322 | mask_tensor = functools.reduce(torch.logical_or, mask_tensor_list)
323 | return mask_tensor
--------------------------------------------------------------------------------
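TensorStrFinder.find_tensor_in_tensor locates a token subsequence by unfolding the input ids into sliding windows, comparing each window against the target, and turning the matched positions into a boolean mask. Below is a standalone illustration of that matching step on synthetic token ids, assuming the same unfold-based approach and omitting the Llama-specific token handling.

import torch


def find_subsequence(needle, haystack):
    # all contiguous windows of the needle's length, compared element-wise
    windows = haystack.unfold(0, len(needle), 1)
    matches = torch.all(windows == needle, dim=1)
    positions = torch.nonzero(matches, as_tuple=True)[0]
    mask = torch.zeros_like(haystack, dtype=torch.bool)
    for pos in positions:
        mask[pos:pos + len(needle)] = True
    return positions, mask


if __name__ == "__main__":
    haystack = torch.tensor([5, 1, 2, 3, 9, 1, 2, 3, 7])
    needle = torch.tensor([1, 2, 3])
    positions, mask = find_subsequence(needle, haystack)
    print(positions.tolist())  # [1, 5]
    print(mask.tolist())
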
/run_i2cl_transfer_learning.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import os
3 | import json
4 | import time
5 | import copy
6 | import random
7 | import argparse
8 | import itertools
9 | import torch
10 | import numpy as np
11 | from multiprocessing import Process, Queue
12 | from itertools import combinations
13 |
14 | import utils
15 | import my_datasets as md
16 | import evaluator as ev
17 |
18 |
19 | def main(args):
20 | # load config from target_path
21 | tar_exp_path = os.path.join(args.config['target_path'],
22 | args.model_name, args.dataset_name)
23 | tar_config_path = os.path.join(tar_exp_path, 'config.json')
24 | tar_result_path = os.path.join(tar_exp_path, 'result_dict.json')
25 |     # return if target_path does not exist
26 | if not os.path.exists(tar_config_path) or not os.path.exists(tar_result_path):
27 | print(f"target_path: {tar_exp_path} does not exist!")
28 | return
29 | # load config
30 | with open(tar_config_path, 'r') as f:
31 | config = json.load(f)
32 | # load result_dict
33 | with open(tar_result_path, 'r') as f:
34 | result_dict = json.load(f)
35 |
36 | # update config with args.config
37 | config.update(args.config)
38 | args.config = config
39 | args.result_dict = result_dict
40 |
41 | # set global seed
42 | utils.set_seed(args.config['seed'])
43 | # set device
44 | args.device = utils.set_device(args.gpu)
45 | # set metric used
46 | args.metric = args.config['metric']
47 | # get save dir
48 | utils.init_exp_path(args, args.config['exp_name'])
49 |
50 | # load tokenizer and model
51 | model, tokenizer, model_config = \
52 | utils.load_model_tokenizer(args.model_name, args.device)
53 |
54 | # get model_wrapper
55 | model_wrapper = utils.get_model_wrapper(args.model_name, model,
56 | tokenizer, model_config,
57 | args.device)
58 | # load datasets
59 | train_dataset = md.get_dataset(args.dataset_name, split='train',
60 | max_data_num=None, seed=args.config['seed'])
61 | test_dataset = md.get_dataset(args.dataset_name, split='test',
62 | max_data_num=args.config['test_data_num'],
63 | sample_mode=args.config['sample_method'],
64 | seed=args.config['seed'])
65 |
66 | # get max demonstration token length for each dataset
67 | if args.config['split_demon']:
68 | # when split demon, do not check max example token length
69 | args.test_max_token = 1e8
70 | else:
71 | args.test_max_token = test_dataset.get_max_demonstration_token_length(tokenizer)
72 |
73 | # get shot_num
74 |     if args.dataset_name == 'dbpedia': # always use 1 shot per class for dbpedia
75 |         args.config['shot_per_class'] = 1
76 |         args.config['bs'] = 1
77 |     args.shot_num = utils.get_shot_num(train_dataset, args.config['shot_per_class'])
78 |     # build evaluators
79 | test_evaluator = ev.Evaluator(test_dataset, batch_size=args.config['bs'])
80 | # init result_dict
81 | infer_result_dict = {'demon': {},
82 | 'split_demon': {},
83 | 'test_result': {'zero_shot': [], 'few_shot': [], 'ours': [], 'ensemble': {}},
84 | 'linear_coef': {},
85 | 'time': {'calibrate': [], 'evaluate': []}
86 | }
87 |
88 | all_cv_dicts, all_coef = [], []
89 | for dataset_name in list(md.target_datasets.keys()):
90 | # collect calibrated coefficients
91 | coe_save_path = os.path.join(args.config['target_path'], args.model_name, dataset_name, 'result_dict.json')
92 | with open(coe_save_path, 'r') as f:
93 | cur_result_dict = json.load(f)
94 | tem_coef = []
95 | for run_id, coef in cur_result_dict['linear_coef'].items():
96 | tem_coef.append(torch.tensor(coef))
97 | # average strength params
98 | all_coef.append(torch.stack(tem_coef).mean(dim=0))
99 |
100 | # collect context vectors
101 | cv_save_path = os.path.join(args.config['target_path'], args.model_name, dataset_name, 'cv_save_dict.json')
102 | with open(cv_save_path, 'r') as f:
103 | cv_dict = json.load(f)
104 | cur_cv_dict = {}
105 |         for _, run_cv_dict in cv_dict.items():  # one entry per run
106 |             for layer, sub_dict in run_cv_dict.items():
107 | if layer not in cur_cv_dict:
108 | cur_cv_dict[layer] = {}
109 | for module, activation in sub_dict.items():
110 | if module not in cur_cv_dict[layer]:
111 | cur_cv_dict[layer][module] = []
112 | cur_cv_dict[layer][module].append(torch.tensor(activation))
113 |         # average context vector dict across runs
114 | for layer, sub_dict in cur_cv_dict.items():
115 | for module, activation_list in sub_dict.items():
116 | cur_cv_dict[layer][module] = torch.stack(activation_list).mean(dim=0)
117 | all_cv_dicts.append(cur_cv_dict)
118 |
119 | for run_id in range(args.config['run_num']):
120 | run_name = f'run_{run_id}'
121 | args.run_name = run_name
122 |         print(f'Run: {run_name}')
123 | run_seed = args.config['seed'] + run_id
124 | utils.set_seed(run_seed)
125 |
126 | # build val dataset
127 | _, split_demon_list, demon_data_index = \
128 | train_dataset.gen_few_shot_demonstration(tokenizer=tokenizer, shot_num=args.shot_num,
129 | max_demonstration_tok_len=args.test_max_token,
130 | add_extra_query=args.config['add_extra_query'],
131 | example_separator=args.config['example_separator'],
132 |                                                      gen_example_method=args.config['gen_example_method'],
133 |                                                      return_data_index=True, seed=random.randint(0, int(1e6)))
134 |         # build val_dataset using demon_data_index
135 | val_dataset = copy.deepcopy(train_dataset)
136 | val_dataset.all_data = [train_dataset.all_data[i] for i in demon_data_index]
137 |
138 | # get demon
139 | demon_list = args.result_dict['demon'][run_name]
140 |         assert split_demon_list == args.result_dict['split_demon'][run_name], \
141 |             f'split_demon_list: {split_demon_list} != {args.result_dict["split_demon"][run_name]}'
142 |
143 | # save demon_list
144 | infer_result_dict['demon'][run_name] = demon_list
145 | infer_result_dict['split_demon'][run_name] = split_demon_list
146 |
147 |         # calculate task similarity ============================================================
148 | cur_coef = torch.tensor(args.result_dict['linear_coef'][run_name]).view(-1)
149 | ref_coef = torch.stack(all_coef)
150 | ref_coef = ref_coef.view(ref_coef.size(0), -1)
151 | # calculate cosine similarity between cur_coef and all_coef
152 | sim = torch.nn.functional.cosine_similarity(cur_coef, ref_coef, dim=1)
153 |
154 | # keep sim and its index whose similarity is larger than threshold
155 | tar_sim = sim[sim > args.config['threshold']]
156 | tar_idx = torch.nonzero(sim > args.config['threshold']).view(-1)
157 |
158 | print(f'Similarities: {tar_sim}')
159 | print(f'Similar tasks: {[list(md.target_datasets.keys())[idx] for idx in tar_idx]}')
160 |
161 | tar_sim = tar_sim.cpu().numpy()
162 | tar_cv_dicts = [all_cv_dicts[idx] for idx in tar_idx]
163 | tar_coef = [all_coef[idx] for idx in tar_idx]
164 |         assert len(tar_cv_dicts) > 2, 'Not enough transferable tasks!'
165 | infer_result_dict['similar_task_num'] = len(tar_cv_dicts)
166 |
167 |         # convert the retained similarities into a probability distribution that sums to 1
168 | def softmax_with_temperature(logits, temperature=1.0):
169 | scaled_logits = logits / temperature
170 | exps = np.exp(scaled_logits - np.max(scaled_logits)) # For numerical stability
171 | softmax_outputs = exps / np.sum(exps)
172 | return softmax_outputs
173 |
174 | tar_prob = softmax_with_temperature(tar_sim, args.config['temp'])
175 | context_vector_dict, linear_coef = prepare_inject_dicts_params(tar_prob, tar_cv_dicts, tar_coef)
176 |
177 | # set strength params
178 | model_wrapper.init_strength(args.config)
179 |
180 | # calibrate context vector
181 | s_t = time.time()
182 | model_wrapper.calibrate_strength(context_vector_dict, val_dataset,
183 | args.config, save_dir=args.save_dir,
184 | run_name=args.run_name)
185 | e_t = time.time()
186 | print(f'Calibration time: {e_t - s_t}')
187 | infer_result_dict['time']['calibrate'].append(e_t - s_t)
188 |
189 | # save linear_coef
190 | infer_result_dict['linear_coef'][run_name] = model_wrapper.linear_coef.tolist()
191 |
192 | # evaluate i2cl ========================================================================
193 | s_t = time.time()
194 | with torch.no_grad():
195 | with model_wrapper.inject_latent(context_vector_dict, args.config,
196 | model_wrapper.linear_coef):
197 | test_ours_result = test_evaluator.evaluate(model_wrapper, tokenizer, demonstration='',
198 | use_cache=args.config['use_cache'])
199 | print(f'Test I2CL result: {test_ours_result}\n')
200 | infer_result_dict['test_result']['ours'].append(test_ours_result)
201 | e_t = time.time()
202 |
203 | print(f'Evaluate time: {e_t - s_t}')
204 | infer_result_dict['time']['evaluate'].append(e_t - s_t)
205 |
206 | # save result_dict after each run
207 | with open(args.save_dir + '/result_dict.json', 'w') as f:
208 | json.dump(infer_result_dict, f, indent=4)
209 |
210 | # delete all variables
211 | del model, tokenizer, model_config, model_wrapper, train_dataset, test_dataset, test_evaluator
212 | del all_cv_dicts, all_coef
213 | del context_vector_dict, linear_coef
214 | del infer_result_dict
215 |
216 |
217 | def prepare_inject_dicts_params(tar_prob, tar_cv_dicts, tar_coef):
218 | target_layers = list(tar_cv_dicts[0].keys())
219 | target_modules = list(tar_cv_dicts[0][target_layers[0]].keys())
220 | print(f'target_layers: {target_layers}')
221 | print(f'target_modules: {target_modules}')
222 | # init an empty ensemble_dict with the same structure as all_inject_dicts
223 | ensemble_cv_dict = {layer: {module: 0 for module in target_modules} for layer in target_layers}
224 | new_coef = torch.zeros(tar_coef[0].size())
225 | for idx, cv_dict in enumerate(tar_cv_dicts):
226 | for layer_idx, layer in enumerate(target_layers):
227 | for module_idx, module in enumerate(target_modules):
228 | cv = cv_dict[layer][module]
229 | coef = tar_coef[idx][layer_idx, module_idx, 0]
230 | ensemble_cv_dict[layer][module] += cv * coef * tar_prob[idx]
231 | new_coef += tar_coef[idx] * tar_prob[idx]
232 |     # set the first strength param to 1 since the context-vector coefficient is already folded into the ensembled context vector
233 | new_coef[:, :, 0] = 1
234 | # set layer name in ensemble_cv_dict to int type
235 | ensemble_cv_dict = {int(layer): sub_dict for layer, sub_dict in ensemble_cv_dict.items()}
236 | return ensemble_cv_dict, new_coef
237 |
238 |
239 | # get args
240 | def get_args():
241 | parser = argparse.ArgumentParser()
242 | parser.add_argument('--config_path', type=str, default='configs/config_i2cl_transfer_learning.py', help='path to config file')
243 | return parser.parse_args()
244 |
245 |
246 | if __name__ == "__main__":
247 | # get args
248 | args = get_args()
249 | # load config
250 | config = utils.load_config(args.config_path)
251 | # Generate all combinations of models and datasets
252 | combinations = list(itertools.product(config['models'], config['datasets']))
253 | # Queue to hold tasks
254 | task_queue = Queue()
255 | for combine in combinations:
256 | task_queue.put(combine)
257 |
258 | def run_task(gpu_id, config):
259 | while not task_queue.empty():
260 | model_name, dataset_name = task_queue.get()
261 | print(f"Running {model_name} on {dataset_name} with GPU {gpu_id}")
262 | input_args = argparse.Namespace()
263 | cur_config = copy.deepcopy(config)
264 | input_args.model_name = model_name
265 | input_args.dataset_name = dataset_name
266 | input_args.gpu = gpu_id
267 | input_args.config = cur_config
268 | try:
269 | main(input_args)
270 | finally:
271 | # Clean up CUDA memory after each task
272 | gc.collect()
273 | torch.cuda.empty_cache()
274 | print(f"CUDA memory cleared for GPU {gpu_id}")
275 | time.sleep(5)
276 |
277 | # Create a process for each GPU
278 | processes = [Process(target=run_task, args=(gpu_id, config)) for gpu_id in config['gpus']]
279 | # Start all processes
280 | for p in processes:
281 | p.start()
282 | # Wait for all processes to finish
283 | for p in processes:
284 | p.join()
285 | print("All tasks completed.")
--------------------------------------------------------------------------------
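The transfer script scores source tasks by the cosine similarity of their calibrated linear coefficients, keeps those above config['threshold'], converts the similarities into weights with a temperature softmax, and combines the corresponding context vectors and coefficients with those weights. The sketch below reproduces the scoring and weighting steps on synthetic tensors; transfer_weights is a hypothetical helper, not part of the repository.

import torch


def transfer_weights(cur_coef, ref_coefs, threshold=0.8, temp=0.5):
    # cosine similarity between the current coefficients and each source task's
    sims = torch.nn.functional.cosine_similarity(
        cur_coef.view(1, -1), ref_coefs.view(ref_coefs.size(0), -1), dim=1)
    keep = torch.nonzero(sims > threshold).view(-1)    # tasks above the threshold
    weights = torch.softmax(sims[keep] / temp, dim=0)  # temperature softmax over similarities
    return keep, weights


if __name__ == "__main__":
    torch.manual_seed(0)
    cur = torch.randn(6)
    refs = torch.stack([cur + 0.1 * torch.randn(6) for _ in range(4)] + [torch.randn(6)])
    keep, weights = transfer_weights(cur, refs)
    ensemble = (weights.view(-1, 1) * refs[keep]).sum(dim=0)  # weighted combination
    print(keep.tolist(), [round(w, 3) for w in weights.tolist()], ensemble.shape)
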
/run_i2cl_infer.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import os
3 | import json
4 | import time
5 | import copy
6 | import random
7 | import argparse
8 | import itertools
9 | import torch
10 | import numpy as np
11 | from multiprocessing import Process, Queue
12 | from itertools import combinations
13 |
14 | import utils
15 | import my_datasets as md
16 | import evaluator as ev
17 |
18 |
19 | def main(args):
20 | # load config from target_path
21 | tar_exp_path = os.path.join(args.config['target_path'],
22 | args.model_name, args.dataset_name)
23 | tar_config_path = os.path.join(tar_exp_path, 'config.json')
24 | tar_result_path = os.path.join(tar_exp_path, 'result_dict.json')
25 |     # return if target_path does not exist
26 | if not os.path.exists(tar_config_path) or not os.path.exists(tar_result_path):
27 | print(f"target_path: {tar_exp_path} does not exist!")
28 | return
29 | # load config
30 | with open(tar_config_path, 'r') as f:
31 | config = json.load(f)
32 | # load result_dict
33 | with open(tar_result_path, 'r') as f:
34 | result_dict = json.load(f)
35 | # update config with args.config
36 | config.update(args.config)
37 | args.config = config
38 | args.result_dict = result_dict
39 |
40 | # set global seed
41 | utils.set_seed(args.config['seed'])
42 | # set device
43 | args.device = utils.set_device(args.gpu)
44 | # set metric used
45 | args.metric = args.config['metric']
46 | # get save dir
47 | utils.init_exp_path(args, args.config['exp_name'])
48 |
49 | # load tokenizer and model
50 | model, tokenizer, model_config = \
51 | utils.load_model_tokenizer(args.model_name, args.device)
52 |
53 | # get model_wrapper
54 | model_wrapper = utils.get_model_wrapper(args.model_name, model,
55 | tokenizer, model_config,
56 | args.device)
57 | # load datasets
58 | train_dataset = md.get_dataset(args.dataset_name, split='train',
59 |                                    max_data_num=None, seed=args.config['seed'])
60 | holdout_dataset = md.get_dataset(args.dataset_name, split='validation',
61 | max_data_num=args.config['val_data_num'],
62 | sample_mode=args.config['sample_method'],
63 | seed=args.config['seed'])
64 | test_dataset = md.get_dataset(args.dataset_name, split='test',
65 | max_data_num=args.config['test_data_num'],
66 | sample_mode=args.config['sample_method'],
67 | seed=args.config['seed'])
68 |
69 | # get shot_num
70 |     if args.dataset_name == 'dbpedia': # always use 1 shot per class for dbpedia
71 |         args.config['shot_per_class'] = 1
72 |         args.config['bs'] = 1
73 |     args.shot_num = utils.get_shot_num(train_dataset, args.config['shot_per_class'])
74 |     # build evaluators
75 | holdout_evaluator = ev.Evaluator(holdout_dataset, batch_size=args.config['bs'])
76 | test_evaluator = ev.Evaluator(test_dataset, batch_size=args.config['bs'])
77 | # init result_dict
78 | infer_result_dict = {'demon': {},
79 | 'split_demon': {},
80 | 'test_result': {'zero_shot': [], 'few_shot': [], 'ours': {}},
81 | 'linear_coef': {},
82 | 'time': {'calibrate': [], 'evaluate': []}
83 | }
84 |
85 |
86 | all_context_vector_dicts, all_linear_coefs = [], []
87 | for run_id in range(args.config['run_num']):
88 | run_name = f'run_{run_id}'
89 | args.run_name = run_name
90 |         print(f'Run: {run_name}')
91 | run_seed = args.config['seed'] + run_id
92 | utils.set_seed(run_seed)
93 |
94 | # zero-shot baseline
95 | if run_id == 0 and args.config['run_baseline']:
96 | test_zeroshot_result = test_evaluator.evaluate(model_wrapper, tokenizer, demonstration='',
97 | use_cache=args.config['use_cache'])
98 | infer_result_dict['test_result']['zero_shot'].append(test_zeroshot_result)
99 | print(f'Test zero-shot result: {test_zeroshot_result}\n')
100 |
101 | if args.config['use_new_demon']:
102 | # sample demonstration
103 | count = 0
104 | temp_demon_list, temp_result_list = [], []
105 | while True:
106 | demon, split_demon, demon_data_index = \
107 | train_dataset.gen_few_shot_demonstration(tokenizer=tokenizer, shot_num=args.shot_num,
108 | max_demonstration_tok_len=1e8,
109 | add_extra_query=args.config['add_extra_query'],
110 | example_separator=args.config['example_separator'],
111 |                                                                          return_data_index=True, seed=random.randint(0, int(1e6)) + run_seed)
112 | temp_demon_list.append((demon, split_demon, demon_data_index))
113 |
114 | if args.config['demo_sample_method'] == 'random':
115 | break
116 | else:
117 | tem_val_result = holdout_evaluator.evaluate(model_wrapper, tokenizer,
118 | demonstration=demon,
119 | use_cache=args.config['use_cache'])
120 | temp_result = tem_val_result[args.metric]
121 | temp_result_list.append(temp_result)
122 | if count > 20:
123 | if args.config['demo_sample_method'] == 'deficient':
124 | demon, split_demon, demon_data_index = temp_demon_list[np.argmin(temp_result_list)]
125 | else:
126 | raise ValueError('Invalid demon_sample_method')
127 | break
128 | count += 1
129 |
130 | if args.config['add_extra_query']:
131 | first_format_anchor = train_dataset.get_dmonstration_template()['format'][0]
132 | # remove all contents after the last first_format_anchor including the anchor
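    |                 # baseline_demon keeps the complete examples for the few-shot baseline, while
    |                 # query_demon is the trailing empty query template appended by add_extra_query.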
133 | if first_format_anchor in demon:
134 | baseline_demon = demon[:demon.rfind(first_format_anchor)]
135 | query_demon = demon[demon.rfind(first_format_anchor):]
136 | else:
137 | baseline_demon = demon
138 | query_demon = None
139 | print(f'Demonstration:\n{demon}\n')
140 | print(f'Baseline demonstration:\n{baseline_demon}\n')
141 | print(f'Query demonstration:\n{query_demon}\n')
142 | demon_list = [demon]
143 | split_demon_list = split_demon
144 | else:
145 | demon_list = args.result_dict['demon'][run_name]
146 | demon = demon_list[0]
147 | baseline_demon = demon
148 | try:
149 | split_demon_list = args.result_dict['split_demon'][run_name]
150 | except KeyError:
151 | raise ValueError('split_demon_list not found in result_dict!')
152 |
153 | # save demon_list
154 | infer_result_dict['demon'][run_name] = demon_list
155 | infer_result_dict['split_demon'][run_name] = split_demon_list
156 |
157 | # few-shot baseline
158 | if args.config['run_baseline']:
159 | test_fewshot_result = test_evaluator.evaluate(model_wrapper, tokenizer,
160 | demonstration=baseline_demon,
161 | use_cache=args.config['use_cache'])
162 | infer_result_dict['test_result']['few_shot'].append(test_fewshot_result)
163 | print(f'Test few-shot result: {test_fewshot_result}\n')
164 |
165 | # extract latents ======================================================================
166 | all_latent_dicts = []
167 | with torch.no_grad():
168 | if not args.config['split_demon']:
169 |                 target_demon_list = demon_list
170 | else:
171 | target_demon_list = split_demon_list
172 | for cur_demon in target_demon_list:
173 | with model_wrapper.extract_latent():
174 | demon_token = tokenizer(cur_demon, return_tensors='pt').to(args.device)
175 | _ = model(**demon_token)
176 | all_latent_dicts.append(model_wrapper.latent_dict)
177 | model_wrapper.reset_latent_dict()
178 |
179 | # generate context vector ==============================================================
180 | context_vector_dict = model_wrapper.get_context_vector(all_latent_dicts, args.config)
181 | del all_latent_dicts
182 |
183 | # get strength params ===================================================================
184 | model_wrapper.init_strength(args.config)
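    |         # init_strength only sets inject_layers and allocates a placeholder linear_coef of the
    |         # right shape; it is replaced below by the calibrated coefficients from the source run.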
185 | del model_wrapper.linear_coef
186 | model_wrapper.linear_coef = torch.tensor(args.result_dict['linear_coef'][run_name])
187 |
188 | # save context_vector_dict and linear_coef for ensemble
189 | all_context_vector_dicts.append(context_vector_dict)
190 | all_linear_coefs.append(model_wrapper.linear_coef)
191 |
192 | # prepare downstream tasks
193 | if args.config['downstream_datasets'] is None:
194 | downstream_datasets = [args.dataset_name]
195 | else:
196 | downstream_datasets = args.config['downstream_datasets']
197 |
198 | for target_task_name in downstream_datasets:
199 | # init saving structure
200 | if target_task_name not in infer_result_dict['test_result']['ours']:
201 | infer_result_dict['test_result']['ours'][target_task_name] = []
202 |
203 | # prepare target dataset
204 | target_dataset = md.get_dataset(target_task_name, split='test',
205 | max_data_num=args.config['test_data_num'],
206 | sample_mode=args.config['sample_method'],
207 | seed=args.config['seed'])
208 | target_evaluator = ev.Evaluator(target_dataset, batch_size=args.config['bs'])
209 |
210 | # evaluate i2cl ========================================================================
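    |             # inject_latent registers forward hooks that blend the context vector into the
    |             # selected layers/modules, so the evaluation below runs with an empty demonstration.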
211 | s_t = time.time()
212 | with torch.no_grad():
213 | with model_wrapper.inject_latent(context_vector_dict, args.config,
214 | model_wrapper.linear_coef):
215 | test_ours_result = target_evaluator.evaluate(model_wrapper, tokenizer, demonstration='',
216 | use_cache=args.config['use_cache'])
217 | print(f'Test I2CL result: {test_ours_result}\n')
218 | infer_result_dict['test_result']['ours'][target_task_name].append(test_ours_result)
219 | e_t = time.time()
220 |
221 | print(f'Evaluate time: {e_t - s_t}')
222 | infer_result_dict['time']['evaluate'].append(e_t - s_t)
223 |
224 | # save result_dict after each run
225 | with open(args.save_dir + '/result_dict.json', 'w') as f:
226 | json.dump(infer_result_dict, f, indent=4)
227 |
228 | # delete all variables
229 | del model, tokenizer, model_config, model_wrapper, train_dataset, test_dataset, holdout_dataset
230 | del test_evaluator, holdout_evaluator
231 | del all_context_vector_dicts, all_linear_coefs
232 | del infer_result_dict
233 |
234 |
235 | # get args
236 | def get_args():
237 | parser = argparse.ArgumentParser()
238 | parser.add_argument('--config_path', type=str, default='configs/config_i2cl_infer.py', help='path to config file')
239 | return parser.parse_args()
240 |
241 |
242 | if __name__ == "__main__":
243 | # get args
244 | args = get_args()
245 | # load config
246 | config = utils.load_config(args.config_path)
247 | # Generate all combinations of models and datasets
248 | combinations = list(itertools.product(config['models'], config['datasets']))
249 | # Queue to hold tasks
250 | task_queue = Queue()
251 | for combine in combinations:
252 | task_queue.put(combine)
253 |
254 | def run_task(gpu_id, config):
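    |         # Each worker process pulls (model, dataset) combinations from the shared queue and
    |         # clears CUDA memory after every task; note that Queue.empty() is only a best-effort
    |         # check across processes.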
255 | while not task_queue.empty():
256 | model_name, dataset_name = task_queue.get()
257 | print(f"Running {model_name} on {dataset_name} with GPU {gpu_id}")
258 | input_args = argparse.Namespace()
259 | cur_config = copy.deepcopy(config)
260 | input_args.model_name = model_name
261 | input_args.dataset_name = dataset_name
262 | input_args.gpu = gpu_id
263 | input_args.config = cur_config
264 | try:
265 | main(input_args)
266 | finally:
267 | # Clean up CUDA memory after each task
268 | gc.collect()
269 | torch.cuda.empty_cache()
270 | print(f"CUDA memory cleared for GPU {gpu_id}")
271 | time.sleep(5)
272 |
273 | # Create a process for each GPU
274 | processes = [Process(target=run_task, args=(gpu_id, config)) for gpu_id in config['gpus']]
275 | # Start all processes
276 | for p in processes:
277 | p.start()
278 | # Wait for all processes to finish
279 | for p in processes:
280 | p.join()
281 | print("All tasks completed.")
--------------------------------------------------------------------------------
/run_label_anchor.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import json
3 | import copy
4 | import time
5 | import random
6 | import argparse
7 | import itertools
8 | import torch
9 | import torch.nn.functional as F
10 | from multiprocessing import Process, Queue
11 |
12 | import utils
13 | import my_datasets as md
14 | import evaluator as ev
15 |
16 |
17 | def cached_evaluation(config, dataset, model, tokenizer, compressed_past_key_values,
18 | compressed_attn_mask, position_offset=None):
19 | batch_size = config['bs']
20 | # prepare label dict
21 | label_map = {}
22 | ans_txt_list = dataset.get_dmonstration_template()['options']
23 | for label, ans_txt in enumerate(ans_txt_list):
24 | if 'gpt' in tokenizer.__class__.__name__.lower():
25 | ans_txt = ' ' + ans_txt # add space to the beginning of answer
26 | ans_tok = tokenizer.encode(ans_txt, add_special_tokens=False)[0] # use the first token if more than one token
27 | print(f"ans_txt: {ans_txt}, ans_tok: {ans_tok}")
28 | label_map[ans_tok] = label # index is the label
29 | print(f"label_map: {label_map}")
30 |
31 | # prepare all data
32 | all_pred_labels = []
33 | all_inputs, all_labels = [], []
34 | for data in dataset.all_data:
35 | ques_str, _, label = dataset.apply_template(data)
36 | context = ques_str
37 | all_inputs.append(context)
38 | all_labels.append(label)
39 |
40 | # prepare cached data
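    |     # Tile the compressed demonstration KV cache along the batch dimension; its attention
    |     # mask is prepended to each query's mask below so queries can attend to the cached tokens.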
41 | cached_past_key_values = tuple(tuple(t.repeat(batch_size, 1, 1, 1) for t in tup)
42 | for tup in compressed_past_key_values)
43 | cached_attn_mask = compressed_attn_mask.repeat(batch_size, 1)
44 |
45 | # loop over all data
46 | with torch.no_grad():
47 | for i in range(0, len(all_inputs), batch_size):
48 | cur_inputs = all_inputs[i:i+batch_size]
49 | input_tok = tokenizer(cur_inputs, return_tensors="pt", padding=True)
50 | input_ids = input_tok['input_ids'].to(model.device)
51 | attn_mask = input_tok['attention_mask'].to(model.device)
52 |
 53 |             # get index for prediction logits; must be computed before cached_attn_mask is concatenated with attn_mask
54 | pred_loc = utils.last_one_indices(attn_mask).to(model.device)
55 |
56 | # get logits
57 | if position_offset is not None:
58 | position_ids = torch.arange(input_ids.size(1), dtype=torch.long, device=model.device).repeat(input_ids.size(0), 1)
59 | position_ids = position_ids + position_offset
60 | else:
61 | position_ids = None
62 |
63 | attn_mask = torch.cat([cached_attn_mask, attn_mask], dim=1)
64 | output = model(input_ids=input_ids, attention_mask=attn_mask,
65 | past_key_values=cached_past_key_values,
66 | position_ids=position_ids)
67 | # get prediction logits
68 | logits = output.logits
69 | pred_logits = logits[torch.arange(logits.size(0)), pred_loc]
70 | # get prediction labels
71 | interest_index = list(label_map.keys())
72 | pred_logits = pred_logits[:, interest_index]
73 | probs = F.softmax(pred_logits, dim=-1)
74 | pred_labels = probs.argmax(dim=-1)
75 | # save results
76 | all_pred_labels.extend(pred_labels.cpu().numpy().tolist())
77 |
78 | assert len(all_pred_labels) == len(all_labels)
 79 |     # all_pred_labels and all_labels are lists of label indices; compute accuracy and macro F1 from them
80 | # initialize TP, FP, FN
81 | acc = []
82 | num_classes = dataset.class_num
83 | TP = [0] * num_classes
84 | FP = [0] * num_classes
85 | FN = [0] * num_classes
86 | for i, true_label in enumerate(all_labels):
87 | pred_label = all_pred_labels[i]
88 | pred = pred_label == true_label
89 | acc.append(pred)
90 | # Update TP, FP, FN
91 | if pred:
92 | TP[true_label] += 1
93 | else:
94 | FP[pred_label] += 1
95 | FN[true_label] += 1
96 | # Calculate precision, recall, F1 for each class and macro F1
97 | precision = [0] * num_classes
98 | recall = [0] * num_classes
99 | f1 = [0] * num_classes
100 | for i in range(num_classes):
101 | precision[i] = TP[i] / (TP[i] + FP[i]) if (TP[i] + FP[i]) > 0 else 0
102 | recall[i] = TP[i] / (TP[i] + FN[i]) if (TP[i] + FN[i]) > 0 else 0
103 | f1[i] = 2 * (precision[i] * recall[i]) / (precision[i] + recall[i]) if (precision[i] + recall[i]) > 0 else 0
104 | macro_f1 = sum(f1) / num_classes
105 | acc = sum(acc) / len(acc)
106 | return {'acc': acc, 'macro_f1': macro_f1}
107 |
108 |
109 | def main(args):
110 | # set global seed
111 | utils.set_seed(args.config['seed'])
112 | # set device
113 | args.device = utils.set_device(args.gpu)
114 | # set metric used
115 | args.metric = args.config['metric']
116 | # get save dir
117 | utils.init_exp_path(args, args.config['exp_name'])
118 |
119 | # load tokenizer and model
120 | model, tokenizer, model_config = \
121 | utils.load_model_tokenizer(args.model_name, args.device)
122 |
123 | # get model_wrapper
124 | model_wrapper = utils.get_model_wrapper(args.model_name, model,
125 | tokenizer, model_config,
126 | args.device)
127 | # load datasets
128 | train_dataset = md.get_dataset(args.dataset_name, split='train',
129 | max_data_num=None,
130 | seed=args.config['seed'])
131 | test_dataset = md.get_dataset(args.dataset_name, split='test',
132 | max_data_num=args.config['test_data_num'],
133 | sample_mode=args.config['sample_method'],
134 | seed=args.config['seed'])
135 |
136 | # get max demonstration token length for each dataset
137 | args.test_max_token = test_dataset.get_max_demonstration_token_length(tokenizer)
138 |
139 | # get shot_num
140 | if args.dataset_name == 'dbpedia': # always use 1-shot for dbpedia
141 | args.config['shot_per_class'] = 1
142 | args.config['bs'] = 1
143 | args.shot_num = utils.get_shot_num(train_dataset, args.config['shot_per_class'])
144 |     # build evaluator
145 | test_evaluator = ev.Evaluator(test_dataset, batch_size=args.config['bs'])
146 | # init result_dict
147 | result_dict = {'demon': {},
148 | 'split_demon': {},
149 | 'test_result': {'zero_shot': [], 'few_shot': [], 'ours': []},
150 | 'time': {'calibrate': [], 'evaluate': []},
151 | }
152 |
153 | for run_id in range(args.config['run_num']):
154 | run_name = f'run_{run_id}'
155 | args.run_name = run_name
156 |         print(f'Run: {run_name}')
157 | run_seed = args.config['seed'] + run_id
158 | utils.set_seed(run_seed)
159 |
160 | # zero-shot baseline
161 | if run_id == 0 and args.config['run_baseline']:
162 | test_zeroshot_result = test_evaluator.evaluate(model_wrapper, tokenizer, demonstration='',
163 | use_cache=args.config['use_cache'])
164 | result_dict['test_result']['zero_shot'].append(test_zeroshot_result)
165 | print(f'Test zero-shot result: {test_zeroshot_result}\n')
166 |
167 | # sample demonstration
168 | demon, _ = \
169 | train_dataset.gen_few_shot_demonstration(tokenizer=tokenizer, shot_num=args.shot_num,
170 | max_demonstration_tok_len=args.test_max_token,
171 | add_extra_query=args.config['add_extra_query'],
172 | example_separator=args.config['example_separator'],
173 |                                                      seed=random.randint(0, int(1e6)))
174 |
175 | if args.config['add_extra_query']:
176 | first_format_anchor = train_dataset.get_dmonstration_template()['format'][0]
177 | # remove all contents after the last first_format_anchor including the anchor
178 | if first_format_anchor in demon:
179 | baseline_demon = demon[:demon.rfind(first_format_anchor)]
180 | query_demon = demon[demon.rfind(first_format_anchor):]
181 | else:
182 | baseline_demon = demon
183 | query_demon = None
184 | print(f'Demonstration:\n{demon}\n')
185 | print(f'Baseline demonstration:\n{baseline_demon}\n')
186 | print(f'Query demonstration:\n{query_demon}\n')
187 |
188 | # few-shot baseline
189 | if args.config['run_baseline']:
190 | test_fewshot_result = test_evaluator.evaluate(model_wrapper, tokenizer,
191 | demonstration=baseline_demon,
192 | use_cache=args.config['use_cache'])
193 | result_dict['test_result']['few_shot'].append(test_fewshot_result)
194 | print(f'Test few-shot result: {test_fewshot_result}\n')
195 |
196 | # apply label anchor
197 | context_solver = utils.ContextSolver(task_name=args.dataset_name, tokenizer=tokenizer)
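    |         # ContextSolver.get_mask returns a boolean mask over the demonstration tokens; only
    |         # the masked (anchor) positions are kept when the compressed KV cache is built below.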
198 | demon_token = tokenizer(demon, return_tensors='pt').to(args.device)
199 | compress_attn_mask = context_solver.get_mask(demon_token['input_ids'])
200 | print(f'Compress_attn_mask: {compress_attn_mask}\n')
201 | # compressed_text
202 | compressed_token_id = copy.deepcopy(demon_token['input_ids'][0])
203 | mask = copy.deepcopy(compress_attn_mask).cpu().detach().numpy()
204 | compressed_token_id = compressed_token_id.cpu().detach().numpy()
205 | compressed_text = tokenizer.decode(list(compressed_token_id[mask]))
206 | print(f'Compressed_text: {compressed_text}\n')
207 | with torch.no_grad():
208 | demon_outputs = model(**demon_token, use_cache=True)
209 | past_key_values = demon_outputs.past_key_values
210 |
211 | if args.model_name == 'meta-llama/Llama-2-7b-hf':
212 | mask_end_idx = torch.where(compress_attn_mask)[0][-1] + 1
213 | cached_past_key_values = tuple(tuple(t[:, :, :mask_end_idx, :] for t in tup) for tup in past_key_values)
214 | cached_attn_mask = copy.deepcopy(compress_attn_mask)[:mask_end_idx].unsqueeze(0)
215 | else:
216 | cached_past_key_values = tuple(tuple(t[:, :, compress_attn_mask, :] for t in tup) for tup in past_key_values)
217 | cached_attn_mask = torch.ones(1, compress_attn_mask.sum(), dtype=torch.bool).to(args.device)
218 | print(f'Cached_attn_mask: {cached_attn_mask}\n')
219 |
220 | if args.model_name == 'gpt2-xl':
221 | position_offset = 0
222 | elif args.model_name == 'EleutherAI/gpt-j-6B':
223 | position_offset = torch.where(compress_attn_mask)[0][-1] + 1
224 | elif args.model_name == 'meta-llama/Llama-2-7b-hf':
225 | position_offset = None
226 | else:
227 | raise ValueError('model not supported')
228 |
229 | # evaluate with label anchor
230 | s_t = time.time()
231 | test_result = cached_evaluation(args.config, test_dataset, model, tokenizer,
232 | cached_past_key_values, cached_attn_mask, position_offset)
233 | print(f'Test label_anchor result: {test_result}\n')
234 | result_dict['test_result']['ours'].append(test_result)
235 | e_t = time.time()
236 | result_dict['time']['evaluate'].append(e_t - s_t)
237 |
238 | # save result_dict after each run
239 | with open(args.save_dir + '/result_dict.json', 'w') as f:
240 | json.dump(result_dict, f, indent=4)
241 |
242 | # delete all variables
243 | del model_wrapper, model, tokenizer, train_dataset, test_dataset
244 | del test_evaluator
245 | del result_dict
246 |
247 |
248 | # get args
249 | def get_args():
250 | parser = argparse.ArgumentParser()
251 | parser.add_argument('--config_path', type=str, default='configs/config_label_anchor.py', help='path to config file')
252 | return parser.parse_args()
253 |
254 |
255 | if __name__ == "__main__":
256 | # get args
257 | args = get_args()
258 | # load config
259 | config = utils.load_config(args.config_path)
260 | # Generate all combinations of models and datasets
261 | combinations = list(itertools.product(config['models'], config['datasets']))
262 | # Queue to hold tasks
263 | task_queue = Queue()
264 | for combine in combinations:
265 | task_queue.put(combine)
266 |
267 | def run_task(gpu_id, config):
268 | while not task_queue.empty():
269 | model_name, dataset_name = task_queue.get()
270 | print(f"Running {model_name} on {dataset_name} with GPU {gpu_id}")
271 | input_args = argparse.Namespace()
272 | cur_config = copy.deepcopy(config)
273 | input_args.model_name = model_name
274 | input_args.dataset_name = dataset_name
275 | input_args.gpu = gpu_id
276 | input_args.config = cur_config
277 | try:
278 | main(input_args)
279 | finally:
280 | # Clean up CUDA memory after each task
281 | gc.collect()
282 | torch.cuda.empty_cache()
283 | print(f"CUDA memory cleared for GPU {gpu_id}")
284 | time.sleep(5)
285 |
286 | # Create a process for each GPU
287 | processes = [Process(target=run_task, args=(gpu_id, config)) for gpu_id in config['gpus']]
288 | # Start all processes
289 | for p in processes:
290 | p.start()
291 | # Wait for all processes to finish
292 | for p in processes:
293 | p.join()
294 | print("All tasks completed.")
--------------------------------------------------------------------------------
/wrapper.py:
--------------------------------------------------------------------------------
1 | import math
2 | import string
3 | import random
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from contextlib import contextmanager
8 | from functools import reduce
9 | import numpy as np
10 | import utils
11 | import global_vars as gv
12 | from peft import get_peft_model, PromptTuningConfig
13 |
14 |
15 | class ModelWrapper(nn.Module):
16 | def __init__(self, model, tokenizer, model_config, device):
17 | super().__init__()
18 | self.model = model.eval()
19 | self.tokenizer = tokenizer
20 | self.model_config = model_config
21 | self.device = device
22 | self.num_layers = self._get_layer_num()
23 | self.latent_dict = {}
24 | self.linear_coef = None
25 | self.inject_layers = None
26 | print(f"The model has {self.num_layers} layers:")
27 |
28 | def reset_latent_dict(self):
29 | self.latent_dict = {}
30 |
31 | @contextmanager
32 | def extract_latent(self):
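    |         # Register forward hooks on every layer's attn, mlp and hidden (block output) modules;
    |         # outputs are detached to CPU into self.latent_dict[layer][module], and all hooks are
    |         # removed when the context manager exits.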
33 | handles = []
34 | try:
35 | # attach hook
36 | for layer_idx in range(self.num_layers):
37 | handles.append(
38 | self._get_nested_attr(self._get_arribute_path(layer_idx, 'attn')).register_forward_hook(
39 | self.extract_hook_func(layer_idx, 'attn')))
40 | handles.append(
41 | self._get_nested_attr(self._get_arribute_path(layer_idx, 'mlp')).register_forward_hook(
42 | self.extract_hook_func(layer_idx, 'mlp')))
43 | handles.append(
44 | self._get_nested_attr(self._get_arribute_path(layer_idx, 'hidden')).register_forward_hook(
45 | self.extract_hook_func(layer_idx, 'hidden')))
46 | yield
47 | finally:
48 | # remove hook
49 | for handle in handles:
50 | handle.remove()
51 |
52 | def extract_hook_func(self, layer_idx, target_module):
53 | if layer_idx not in self.latent_dict:
54 | self.latent_dict[layer_idx] = {}
55 | def hook_func(module, inputs, outputs):
56 | if type(outputs) is tuple:
57 | outputs = outputs[0]
58 | self.latent_dict[layer_idx][target_module] = outputs.detach().cpu()
59 | return hook_func
60 |
61 | @contextmanager
62 | def inject_latent(self, context_vector_dict, config, linear_coef, train_mode=False):
63 | handles = []
64 | assert self.inject_layers is not None, "inject_layers is not set!"
65 | inject_method = config['inject_method']
66 | inject_pos = config['inject_pos']
67 | add_noise = config['add_noise']
68 | noise_scale = config['noise_scale']
69 | try:
70 | # attach hook
71 | for layer_idx, layer in enumerate(self.inject_layers):
72 | for module_idx, module in enumerate(config['module']):
73 | context_vector_container = [context_vector_dict[layer][module].to(self.device)]
74 | strength = linear_coef[layer_idx, module_idx, :]
75 | inject_func = self.inject_hook_func(context_vector_container, strength,
76 | inject_method, add_noise, noise_scale,
77 | inject_pos, train_mode)
78 | handles.append(
79 | self._get_nested_attr(self._get_arribute_path(layer, module)).
80 | register_forward_hook(inject_func)
81 | )
82 | yield
83 | finally:
84 | # remove hook
85 | print(f"Removing {len(handles)} hooks...")
86 | for handle in handles:
87 | handle.remove()
88 |
89 | def inject_hook_func(self, context_vector_container, strength, inject_method,
90 | add_noise, noise_scale, inject_pos, train_mode=False):
91 |
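    |         # The returned hook blends the context vector (cv) into the module output:
    |         #   'add':     output + relu(strength) * cv
    |         #   'linear':  strength[1] * output + strength[0] * cv, applied at inject_pos
    |         #   'balance': ((1 - a) * output + a * cv) * b, with a, b = strength[0], strength[1]
    |         # In train_mode, Gaussian noise scaled by the output norm can additionally be added.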
92 | def hook_func(module, inputs, outputs):
93 | if type(outputs) is tuple:
94 | output = outputs[0]
95 | else:
96 | output = outputs
97 | # init context_vector
98 | context_vector = context_vector_container[0]
99 | # expand inject_value to match output size (b, seq_len, d)
100 | context_vector = context_vector.expand(output.size(0), output.size(1), context_vector.size(-1))
101 |
102 | if inject_method == 'add':
103 | output = output + F.relu(strength) * context_vector
104 | elif inject_method == 'linear':
105 | if inject_pos == 'all':
106 | output = strength[1] * output + strength[0] * context_vector
107 | else:
108 | if inject_pos == 'last':
109 | for i in range(output.size(0)):
110 | end_idx = gv.ATTN_MASK_END[i] - 1
111 | content = strength[1] * output[i, end_idx, :].clone().detach() + strength[0] * context_vector[i, end_idx, :]
112 | output[i, end_idx, :] = content
113 | elif inject_pos == 'first':
114 | content = strength[1] * output[:, 0, :].clone().detach() + strength[0] * context_vector[:, 0, :]
115 | output[:, 0, :] = content
116 | elif inject_pos == 'random':
117 | for i in range(output.size(0)):
118 | end_idx = gv.ATTN_MASK_END[i]
119 | random_idx = random.randint(0, end_idx)
120 | content = strength[1] * output[i, random_idx, :].clone().detach() + strength[0] * context_vector[i, random_idx, :]
121 | output[i, random_idx, :] = content
122 | else:
123 | raise ValueError("only support all, last, first or random!")
124 |
125 | elif inject_method == 'balance':
126 | a, b = strength[0], strength[1]
127 | output = ((1.0 - a) * output + a * context_vector) * b
128 | else:
129 | raise ValueError("only support add, linear or balance!")
130 |
131 | if add_noise and train_mode:
132 | # get l2_norm of output and use it as a scalar to scale noise, make sure no gradient
133 | output_norm = torch.norm(output, p=2, dim=-1).detach().unsqueeze(-1)
134 | noise = torch.randn_like(output).detach()
135 | output += noise * output_norm * noise_scale
136 |
137 | if type(outputs) is tuple:
138 | outputs = list(outputs)
139 | outputs[0] = output
140 | outputs = tuple(outputs)
141 | else:
142 | outputs = output
143 | return outputs
144 | return hook_func
145 |
146 |
147 | @contextmanager
148 | def replace_latent(self, context_vector_dict, target_layers, config):
149 | handles = []
150 | try:
151 | # attach hook
152 | for _, layer in enumerate(target_layers):
153 | for _, module in enumerate(config['module']):
154 | context_vector_container = [context_vector_dict[layer][module].to(self.device)]
155 | inject_func = self.replace_hook_func(context_vector_container)
156 | handles.append(
157 | self._get_nested_attr(self._get_arribute_path(layer, module)).
158 | register_forward_hook(inject_func))
159 | yield
160 | finally:
161 | # remove hook
162 | print(f"Removing {len(handles)} hooks...")
163 | for handle in handles:
164 | handle.remove()
165 |
166 | def replace_hook_func(self, context_vector_container):
167 | def hook_func(module, inputs, outputs):
168 | if type(outputs) is tuple:
169 | output = outputs[0]
170 | else:
171 | output = outputs
172 | # init context_vector
173 | context_vector = context_vector_container[0]
174 | # replace hidden states of last token position with context_vector
175 | for i in range(output.size(0)):
176 | end_idx = gv.ATTN_MASK_END[i]
177 | output[i, end_idx, :] = context_vector
178 |
179 | if type(outputs) is tuple:
180 | outputs = list(outputs)
181 | outputs[0] = output
182 | outputs = tuple(outputs)
183 | else:
184 | outputs = output
185 | return outputs
186 | return hook_func
187 |
188 |
189 | def get_context_vector(self, all_latent_dicts, config):
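    |         # Single latent dict: keep the hidden state at config['tok_pos'] for every layer/module.
    |         # Several dicts (one per split demonstration): stack them and fuse across demonstrations
    |         # with 'mean' or 'pca' according to config['post_fuse_method'].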
190 | if len(all_latent_dicts) == 1:
191 | latent_dict = all_latent_dicts[0]
192 | output_dict = {}
193 | for layer, sub_dict in latent_dict.items():
194 | output_dict[layer] = {}
195 | for module in config['module']:
196 | latent_value = sub_dict[module]
197 | if config['tok_pos'] == 'last':
198 | latent_value = latent_value[:, -1, :].squeeze()
199 | elif config['tok_pos'] == 'first':
200 | latent_value = latent_value[:, 0, :].squeeze()
201 | elif config['tok_pos'] == 'random':
202 |                     latent_value = latent_value[:, random.randint(0, latent_value.size(1) - 1), :].squeeze()
203 | else:
204 | raise ValueError("only support last, first or random!")
205 | output_dict[layer][module] = latent_value.detach().to('cpu')
206 | else:
207 | # concatenate context vector for each module
208 | ensemble_dict = {module:[] for module in config['module']} # {module_name: []}
209 | for _, latent_dict in enumerate(all_latent_dicts):
210 | cur_dict = {module:[] for module in config['module']} # {module_name: []}
211 | for layer, sub_dict in latent_dict.items():
212 | for module in config['module']:
213 | latent_value = sub_dict[module] # (b, seq_len, d)
214 | if config['tok_pos'] == 'last':
215 | latent_value = latent_value[:, -1, :].squeeze()
216 | elif config['tok_pos'] == 'first':
217 | latent_value = latent_value[:, 0, :].squeeze()
218 | elif config['tok_pos'] == 'random':
219 |                             latent_value = latent_value[:, random.randint(0, latent_value.size(1) - 1), :].squeeze()
220 | else:
221 | raise ValueError("only support last, first or random!")
222 | cur_dict[module].append(latent_value)
223 |
224 | for module, latent_list in cur_dict.items():
225 | cur_latent = torch.stack(latent_list, dim=0) # (layer_num, d)
226 | ensemble_dict[module].append(cur_latent)
227 |
228 | for module, latent_list in ensemble_dict.items():
229 | if config['post_fuse_method'] == 'mean':
230 | context_vector = torch.stack(latent_list, dim=0).mean(dim=0) # (layer_num, d)
231 | ensemble_dict[module] = context_vector
232 | elif config['post_fuse_method'] == 'pca':
233 | latents = torch.stack(latent_list, dim=0) # (ensemble_num, layer_num, d)
234 | ensemble_num, layer_num, d = latents.size()
235 |                     latents = latents.view(ensemble_num, -1) # (ensemble_num, layer_num*d)
236 | # apply pca
237 | pca = utils.PCA(n_components=1).to(latents.device).fit(latents.float())
238 | context_vector = (pca.components_.sum(dim=0, keepdim=True) + pca.mean_).mean(0)
239 | ensemble_dict[module] = context_vector.view(layer_num, d) # (layer_num, d)
240 | else:
241 | raise ValueError("Unsupported ensemble method!")
242 | # reorganize ensemble_dict into layers
243 | layers = list(all_latent_dicts[0].keys())
244 | output_dict = {layer: {} for layer in layers}
245 | for module, context_vector in ensemble_dict.items():
246 | for layer_idx, layer in enumerate(layers):
247 | output_dict[layer][module] = context_vector[layer_idx, :].detach().to('cpu') # (d)
248 |
249 | return output_dict
250 |
251 |
252 | def calibrate_strength(self, context_vector_dict, dataset, config,
253 | save_dir=None, run_name=None):
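    |         # Only self.linear_coef is optimised: the backbone is frozen, the injection hooks run
    |         # with train_mode=True, and the loss is cross-entropy between the logits at each
    |         # example's last non-padding position and its answer token.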
254 | # prepare label dict
255 | label_map = {}
256 | ans_txt_list = dataset.get_dmonstration_template()['options']
257 | for label, ans_txt in enumerate(ans_txt_list):
258 | if 'gpt' in self.tokenizer.__class__.__name__.lower():
259 | ans_txt = ' ' + ans_txt # add space to the beginning of answer
260 | ans_tok = self.tokenizer.encode(ans_txt, add_special_tokens=False)[0] # use the first token if more than one token
261 | print(f"ans_txt: {ans_txt}, ans_tok: {ans_tok}")
262 | label_map[label] = ans_tok # index is the label
263 | print(f"label_map: {label_map}")
264 |
265 | # frozen all parameters
266 | for param in self.model.parameters():
267 | param.requires_grad = False
268 |
269 | # init optimizer
270 | optim_paramters = [{'params': self.linear_coef}]
271 | if config['optim'] == 'sgd':
272 | optimizer = torch.optim.SGD(optim_paramters, lr=config['lr'],
273 | weight_decay=config['wd'])
274 | elif config['optim'] == 'adamW':
275 | optimizer = torch.optim.AdamW(optim_paramters, config['lr'],
276 | weight_decay=config['wd'])
277 | elif config['optim'] == 'adam':
278 | optimizer = torch.optim.Adam(optim_paramters, config['lr'])
279 | else:
280 | raise ValueError('optim must be sgd, adamW or adam!')
281 |
282 | # get all_data
283 | all_data = dataset.all_data
284 |
285 | # init lr_scheduler
286 | epochs, batch_size = config['epochs'], config['grad_bs']
287 | total_steps = epochs * len(all_data) // batch_size
288 |         warmup_steps = max(1, int((0.05*epochs) * (len(all_data) // batch_size)))  # at least 1, so lr_lambda never divides by zero
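    |         # lr schedule: linear warm-up for warmup_steps, then cosine decay over total_steps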
289 | lr_lambda = lambda step: min(1.0, step / warmup_steps) * (1 + math.cos(math.pi * step / total_steps)) / 2 \
290 | if step > warmup_steps else step / warmup_steps
291 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
292 |
293 | # train
294 | print('Calibrating strength params...')
295 | with self.inject_latent(context_vector_dict, config,
296 | self.linear_coef, train_mode=True):
297 | loss_list = []
298 | all_data_index = list(range(len(all_data)))
299 | epoch_iter = len(all_data) // batch_size
300 | for _ in range(epochs):
301 | epoch_loss = []
302 | for i in range(epoch_iter):
303 | np.random.shuffle(all_data_index)
304 | batch_index = all_data_index[:batch_size]
305 | batch_data = [all_data[idx] for idx in batch_index]
306 | batch_input, batch_label = [], []
307 | for data in batch_data:
308 | input_str, ans_list, label = dataset.apply_template(data)
309 |
310 | # collect single demonstration example
311 | if config['cali_example_method'] == 'normal':
312 | pass
313 | elif config['cali_example_method'] == 'random_label':
314 | label = random.choice(list(range(len(ans_list))))
315 | else:
316 | raise ValueError("only support normal or random_label!")
317 |
318 | batch_input.append(input_str)
319 | batch_label.append(label)
320 |
321 | input_tok = self.tokenizer(batch_input, return_tensors='pt', padding=True)
322 | input_ids = input_tok['input_ids'].to(self.device)
323 | attn_mask = input_tok['attention_mask'].to(self.device)
324 | pred_loc = utils.last_one_indices(attn_mask).to(self.device)
325 | # set global vars
326 | gv.ATTN_MASK_END = pred_loc
327 | gv.ATTN_MASK_START = torch.zeros_like(pred_loc)
328 | # forward
329 | logits = self.model(input_ids=input_ids, attention_mask=attn_mask).logits
330 | # get prediction logits
331 | pred_logits = logits[torch.arange(logits.size(0)), pred_loc]
332 | # get loss
333 | gt_label = torch.tensor([label_map[label] for label in batch_label]).to(self.device)
334 | loss = F.cross_entropy(pred_logits, gt_label, reduction='mean')
335 | epoch_loss.append(loss.item())
336 | # update strength params
337 | optimizer.zero_grad()
338 | loss.backward()
339 | optimizer.step()
340 | scheduler.step()
341 | cur_lr = optimizer.param_groups[0]['lr']
342 |                     print(f'Epoch {_+1}/{epochs}, batch {i+1}/{epoch_iter}, loss: {loss.item()}, lr: {cur_lr}')
343 | epoch_loss = np.mean(epoch_loss)
344 | loss_list.append(epoch_loss)
345 |
346 |         # freeze all learnable strength params
347 | self.linear_coef.requires_grad = False
348 | # set model to eval mode
349 | self.model.eval()
350 | # plot loss curve and save it
351 | utils.plot_loss_curve(loss_list, save_dir + f'/{run_name}_loss_curve.png')
352 |
353 |
354 | def softprompt(self, config, dataset, save_dir=None, run_name=None):
355 | pt_config = PromptTuningConfig(**config['pt_config'])
356 | peft_model = get_peft_model(self.model, pt_config)
357 |
358 | # prepare label dict
359 | label_map = {}
360 | ans_txt_list = dataset.get_dmonstration_template()['options']
361 | for label, ans_txt in enumerate(ans_txt_list):
362 | if 'gpt' in self.tokenizer.__class__.__name__.lower():
363 | ans_txt = ' ' + ans_txt # add space to the beginning of answer
364 | ans_tok = self.tokenizer.encode(ans_txt, add_special_tokens=False)[0] # use the first token if more than one token
365 | print(f"ans_txt: {ans_txt}, ans_tok: {ans_tok}")
366 | label_map[label] = ans_tok # index is the label
367 | print(f"label_map: {label_map}")
368 |
369 | # print trainable parameters
370 | peft_model.print_trainable_parameters()
371 | print(f'PEFT model:\n {peft_model}')
372 | # set model to peft model
373 | self.model = peft_model
374 |
375 | # init optimizer
376 | optim_paramters = [{'params': self.model.parameters()}]
377 | if config['optim'] == 'sgd':
378 | optimizer = torch.optim.SGD(optim_paramters, lr=config['lr'],
379 | weight_decay=config['wd'])
380 | elif config['optim'] == 'adamW':
381 | optimizer = torch.optim.AdamW(optim_paramters, config['lr'],
382 | weight_decay=config['wd'])
383 | elif config['optim'] == 'adam':
384 | optimizer = torch.optim.Adam(optim_paramters, config['lr'])
385 | else:
386 | raise ValueError('optim must be sgd, adamW or adam!')
387 |
388 | # get all data
389 | all_data = dataset.all_data
390 |
391 | # init lr_scheduler
392 | epochs, batch_size = config['epochs'], config['grad_bs']
393 | total_steps = epochs * len(all_data) // batch_size
394 |         warmup_steps = max(1, int((0.05*epochs) * (len(all_data) // batch_size)))  # at least 1, so lr_lambda never divides by zero
395 | lr_lambda = lambda step: min(1.0, step / warmup_steps) * (1 + math.cos(math.pi * step / total_steps)) / 2 \
396 | if step > warmup_steps else step / warmup_steps
397 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
398 |
399 | # train
400 | loss_list = []
401 | all_data_index = list(range(len(all_data)))
402 | for _ in range(epochs):
403 | epoch_loss = []
404 | np.random.shuffle(all_data_index)
405 | for i in range(0, len(all_data), batch_size):
406 | batch_index = all_data_index[i: i + batch_size]
407 | batch_data = [all_data[idx] for idx in batch_index]
408 | batch_input, batch_label = [], []
409 | for data in batch_data:
410 | input_str, _, label = dataset.apply_template(data)
411 | batch_input.append(input_str)
412 | batch_label.append(label)
413 |
414 | input_tok = self.tokenizer(batch_input, return_tensors='pt', padding=True)
415 | input_ids = input_tok['input_ids'].to(self.device)
416 | attn_mask = input_tok['attention_mask'].to(self.device)
417 | pred_loc = utils.last_one_indices(attn_mask).to(self.device)
418 | # forward
419 | logits = self.model(input_ids=input_ids, attention_mask=attn_mask).logits
420 | # get prediction logits
421 | pred_logits = logits[torch.arange(logits.size(0)), pred_loc]
422 | # get loss
423 | gt_label = torch.tensor([label_map[label] for label in batch_label]).to(self.device)
424 | loss = F.cross_entropy(pred_logits, gt_label, reduction='mean')
425 | epoch_loss.append(loss.item())
426 |                 # update soft prompt params
427 | optimizer.zero_grad()
428 | loss.backward()
429 | optimizer.step()
430 | scheduler.step()
431 | epoch_loss = np.mean(epoch_loss)
432 | loss_list.append(epoch_loss)
433 |
434 |         # freeze all learnable parameters
435 | for param in self.model.parameters():
436 | param.requires_grad = False
437 | # set model to eval mode
438 | self.model.eval()
439 | # plot loss curve and save it
440 | utils.plot_loss_curve(loss_list, save_dir + f'/{run_name}_loss_curve.png')
441 |
442 |
443 | def init_strength(self, config):
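    |         # Resolve the injected layers ('all', 'late', 'early', 'mid' or an explicit list) and
    |         # allocate linear_coef of shape (layer_num, module_num, 1) for 'add' or
    |         # (layer_num, module_num, 2) for 'linear'/'balance', filled with config['init_value'].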
444 | # get linear_coef size
445 | if type(config['layer']) == str:
446 | if config['layer'] == 'all':
447 | layers = list(range(self.num_layers))
448 | layer_dim = len(layers)
449 | elif config['layer'] == 'late':
450 | layers = list(range((self.num_layers*2)//3, self.num_layers))
451 | layer_dim = len(layers)
452 | elif config['layer'] == 'early':
453 | layers = list(range(self.num_layers//3))
454 | layer_dim = len(layers)
455 | elif config['layer'] == 'mid':
456 | layers = list(range(self.num_layers//3, (self.num_layers*2)//3))
457 | layer_dim = len(layers)
458 | elif type(config['layer']) == list:
459 | layers = config['layer']
460 | layer_dim = len(layers)
461 | else:
462 | raise ValueError("layer must be all, late, early, mid or a list of layer index!")
463 |
464 | if config['inject_method'] == 'add':
465 | param_size = (layer_dim, len(config['module']), 1) # (layer_num, module_num, 1)
466 | elif config['inject_method'] in ['linear', 'balance']:
467 | param_size = (layer_dim, len(config['module']), 2) # (layer_num, module_num, 2)
468 | else:
469 | raise ValueError("only support add, linear or balance!")
470 | # set inject_layers
471 | self.inject_layers = layers
472 | # init linear_coef
473 | linear_coef = torch.zeros(param_size, device=self.device)
474 | linear_coef += torch.tensor(config['init_value'], device=self.device)
475 | self.linear_coef = nn.Parameter(linear_coef)
476 | print(f"linear_coef shape: {self.linear_coef.shape}\n")
477 | if not self.linear_coef.is_leaf:
478 | raise ValueError("linear_coef is not a leaf tensor, which is required for optimization.")
479 |
480 |
481 | def init_noise_context_vector(self, context_vector_dict):
482 | # init learnable context_vector
483 | for layer, sub_dict in context_vector_dict.items():
484 | for module, latent in sub_dict.items():
485 | noise_vector = torch.randn_like(latent).detach().cpu()
486 | context_vector_dict[layer][module] = noise_vector
487 | return context_vector_dict
488 |
489 |
490 | def _get_nested_attr(self, attr_path):
491 | """
492 | Accesses nested attributes of an object based on a dot-separated string path.
493 |
494 |         The attribute path is resolved against self.model.
495 | :param attr_path: A dot-separated string representing the path to the nested attribute.
496 | For example, 'transformer.h' or 'model.layers'.
497 | :return: The attribute at the specified path.
498 | """
499 | try:
500 | return reduce(getattr, attr_path.split('.'), self.model)
501 | except AttributeError:
502 | raise AttributeError(f"Attribute path '{attr_path}' not found.")
503 |
504 | def _get_layer_num(self):
505 | raise NotImplementedError("Please implement get_layer_num function for each model!")
506 |
507 | def _get_arribute_path(self, layer_idx, target_module):
508 | raise NotImplementedError("Please implement get_arribute_path function for each model!")
509 |
510 |
511 | class LlamaWrapper(ModelWrapper):
512 | def __init__(self, model, tokenizer, model_config, device):
513 | super().__init__(model, tokenizer, model_config, device)
514 | self.embed_matrix = self.model.model.embed_tokens.weight.data
515 | self.embed_dim = self.model_config.hidden_size
516 | self.last_norm = self.model.model.norm
517 |
518 | def _get_layer_num(self):
519 | return len(self.model.model.layers)
520 |
521 | def _get_arribute_path(self, layer_idx, target_module):
522 | if target_module == "attn":
523 | return f"model.layers.{layer_idx}.self_attn"
524 | elif target_module == "mlp":
525 | return f"model.layers.{layer_idx}.mlp"
526 | elif target_module == "hidden":
527 | return f"model.layers.{layer_idx}"
528 | else:
529 |             raise ValueError("only support attn, mlp or hidden!")
530 |
531 |
532 | class GPTWrapper(ModelWrapper):
533 | def __init__(self, model, tokenizer, model_config, device):
534 | super().__init__(model, tokenizer, model_config, device)
535 | self.embed_matrix = self.model.transformer.wte.weight.data
536 | self.embed_dim = self.embed_matrix.size(-1)
537 | self.last_norm = self.model.transformer.ln_f
538 |
539 | def _get_layer_num(self):
540 | return len(self.model.transformer.h)
541 |
542 | def _get_arribute_path(self, layer_idx, target_module):
543 | if target_module == "attn":
544 | return f"transformer.h.{layer_idx}.attn"
545 | elif target_module == "mlp":
546 | return f"transformer.h.{layer_idx}.mlp"
547 | elif target_module == "hidden":
548 | return f"transformer.h.{layer_idx}"
549 | else:
550 |             raise ValueError("only support attn, mlp or hidden!")
--------------------------------------------------------------------------------