├── .DS_Store
├── OPPU.py
├── README.md
├── __pycache__
│   └── utils.cpython-310.pyc
├── asset
│   ├── overview.png
│   └── teaser.png
├── data
│   └── .DS_Store
├── eval
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── evaluation.cpython-310.pyc
│   │   └── evaluation.cpython-39.pyc
│   ├── eval_all.py
│   ├── eval_task.py
│   └── evaluation.py
├── gen_profile.py
├── prompt
│   ├── prompt.json
│   └── prompt_profile.json
├── requirements.txt
├── task_LoRA.py
└── utils.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TamSiuhin/OPPU/9fb3f465fab408da34915e60db12cc9cd43c2223/.DS_Store
--------------------------------------------------------------------------------
/OPPU.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import bitsandbytes as bnb
4 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
5 | # from transformers import pipeline, BitsAndBytesConfig
6 | import argparse
7 | from rank_bm25 import BM25Okapi
8 | # from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
9 | import transformers
10 | from utils import split_batch, get_first_k_tokens, print_trainable_parameters, name2taskid
11 | from utils import extract_citation_title, extract_option, extract_movie, extract_news_cat, extract_news_headline, extract_product_review, extract_scholarly_title, extract_tweet_paraphrasing
12 | import json
13 | from tqdm import tqdm
14 | from peft import LoraConfig, get_peft_model, PeftModel
15 |
16 |
17 | parser = argparse.ArgumentParser(description="Parser for LoRA")
18 | parser.add_argument('--model_name', type=str, default='meta-llama/Llama-2-7b-hf')
19 | parser.add_argument('--batch_size', type=int, default=16)
20 | parser.add_argument('--k', type=int, default=0)
21 | parser.add_argument('--max_step', type=int, default=5000)
22 | parser.add_argument('--cut_off', type=int, default=2048)
23 | parser.add_argument('--max_epoch', type=int, default=2)
24 | parser.add_argument('--temperature', type=float, default=0.1)
25 | parser.add_argument('--task_name', type=str, default='movie_tagging')
26 | parser.add_argument('--add_profile', action='store_true')
27 | parser.add_argument('--task_lora', type=str, default='./ckpt/movie_tagging/k1-movie_tagging-Llama-2-7b-hf-task_LoRA_ckpt')
28 | parser.add_argument('--access_token', type=str, default=None)
29 |
30 | args = parser.parse_args()
31 | model_name = args.model_name
32 | task_name = args.task_name
33 | batch_size = args.batch_size
34 | k = args.k
35 | # max_step = args.max_step
36 | cutoff_len = args.cut_off
37 | add_eos_token = False
38 | max_epoch = args.max_epoch
39 |
40 | # # 4 bit quantization inference
41 | # bnb_config = BitsAndBytesConfig(
42 | # load_in_4bit=True,
43 | # bnb_4bit_quant_type="nf4",
44 | # bnb_4bit_compute_dtype=torch.float16,
45 | # bnb_4bit_use_double_quant=True,
46 | # max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'
47 | # )
48 |
49 | # 8-bit quantization inference
50 | # bnb_config = BitsAndBytesConfig(
51 | # load_in_8bit=True,
52 | # bnb_8bit_quant_type="nf8",
53 | # bnb_8bit_compute_dtype=torch.float16,
54 | # bnb_8bit_use_double_quant=True,
55 | # max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'
56 | # )
57 |
58 | # 16-bit quantization inference
59 | # bnb_config = BitsAndBytesConfig(
60 | # load_in_16bit=True,
61 | # bnb_16bit_quant_type="bf16",
62 | # bnb_16bit_compute_dtype=torch.bfloat16,
63 | # bnb_16bit_use_double_quant=True,
64 | # max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'
65 | # )
66 |
67 | tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", token=args.access_token)
68 | tokenizer.eos_token = ""
69 | tokenizer.pad_token = '[PAD]'
70 | # tokenizer.pad_token = tokenizer.eos_token
71 | tokenizer.pad_token_id = tokenizer.eos_token_id
72 |
73 |
74 | base_model = AutoModelForCausalLM.from_pretrained(
75 | model_name,
76 | # quantization_config=bnb_config,
77 | local_files_only=False,
78 | device_map='auto',
79 | trust_remote_code=True,
80 | torch_dtype=torch.bfloat16
81 | )
82 |
83 | base_model.config.use_cache = False
84 | base_model.config.pad_token_id = tokenizer.pad_token_id
85 | base_model.config.eos_token_id = tokenizer.eos_token_id
86 | base_model.config.bos_token_id = tokenizer.bos_token_id
87 |
88 |
89 | from peft import prepare_model_for_kbit_training
90 |
91 | base_model.gradient_checkpointing_enable()
92 | base_model = prepare_model_for_kbit_training(base_model)
93 |
94 |
95 |
96 | from peft import LoraConfig, get_peft_model
97 |
98 | peft_config = LoraConfig(
99 | r=8,
100 | lora_alpha=8,
101 | target_modules=["q_proj", "v_proj"], # , "k_proj", "out_proj"
102 | lora_dropout=0.05,
103 | bias="none",
104 | task_type="CAUSAL_LM",
105 | )
106 |
107 | training_arguments = transformers.TrainingArguments(
108 | output_dir='outputs/',
109 | per_device_train_batch_size=batch_size,
110 | gradient_accumulation_steps=1,
111 | optim='adamw_torch',
112 | num_train_epochs=max_epoch,
113 | save_steps=1e9,
114 | logging_steps=50,
115 | learning_rate=1e-4,
116 | weight_decay=1e-2,
117 | bf16=True,
118 | max_grad_norm=0.3,
119 | # max_steps=max_step,
120 | warmup_ratio=0.1,
121 | group_by_length=True,
122 | lr_scheduler_type='linear',
123 | report_to='none',
124 | )
125 |
126 |
127 | with open(f"./data/{task_name}/user_top_100_history.json", 'r') as f:
128 | test_data = json.load(f)
129 |
130 | format_flag = False
131 | if args.task_name == "movie_tagging":
132 | extract_article = extract_movie
133 | format_flag = True
134 | elif args.task_name == "news_categorize":
135 | extract_article = extract_news_cat
136 | format_flag = True
137 | elif args.task_name == "news_headline":
138 | extract_article = extract_news_headline
139 | format_flag = True
140 | elif args.task_name == "product_rating":
141 |     extract_article = extract_product_review
142 | format_flag = True
143 | elif args.task_name == "scholarly_title":
144 | extract_article = extract_scholarly_title
145 | format_flag = True
146 | elif args.task_name == "tweet_paraphrase":
147 |     extract_article = extract_tweet_paraphrasing
148 |
149 |
150 | with open('./prompt/prompt.json', 'r') as f:
151 | prompt_template = json.load(f)
152 |
153 |
154 | if args.add_profile:
155 | with open(f'./data/{task_name}/profile_user_100.json', 'r') as f:
156 | test_profile = json.load(f)
157 |
158 |
159 | def tokenize(prompt, add_eos_token=True):
160 | # there's probably a way to do this with the tokenizer settings
161 | # but again, gotta move fast
162 | result = tokenizer(
163 | prompt,
164 | truncation=True,
165 | max_length=cutoff_len,
166 | padding=False,
167 | return_tensors=None,
168 | )
169 | if (
170 | result["input_ids"][-1] != tokenizer.eos_token_id
171 | and len(result["input_ids"]) < cutoff_len
172 | and add_eos_token
173 | ):
174 | result["input_ids"].append(tokenizer.eos_token_id)
175 | result["attention_mask"].append(1)
176 |
177 | result["labels"] = result["input_ids"].copy()
178 |
179 | return result
180 |
181 |
182 | def generate_and_tokenize_prompt(data_point):
183 | full_prompt = data_point['full_prompt']
184 | tokenized_full_prompt = tokenize(full_prompt)
185 | # if not train_on_inputs:
186 | user_prompt = data_point['prompt']
187 |
188 | tokenized_user_prompt = tokenize(
189 | user_prompt, add_eos_token=add_eos_token
190 | )
191 | user_prompt_len = len(tokenized_user_prompt["input_ids"])
192 |
193 | if add_eos_token:
194 | user_prompt_len -= 1
195 |
196 | tokenized_full_prompt["labels"] = [
197 | -100
198 | ] * user_prompt_len + tokenized_full_prompt["labels"][
199 | user_prompt_len:
200 | ] # could be sped up, probably
201 | return tokenized_full_prompt
202 |
203 |
204 |
205 | # training
206 | from datasets import load_dataset, Dataset
207 | model = PeftModel.from_pretrained(model=base_model, model_id=args.task_lora, is_trainable=False)
208 | base_model = model.merge_and_unload()
209 | print_trainable_parameters(model)
210 |
211 |
212 | pred_all = []
213 | actual = []
214 | 
215 | for i in tqdm(range(len(test_data))):
216 |     train_data = []  # reset per user: each personal PEFT module is trained only on this user's history
217 |     model = get_peft_model(base_model, peft_config)
218 |     print_trainable_parameters(model)
219 |
220 | if args.add_profile:
221 | profile = test_profile[i]['output']
222 |
223 | for idx, q in enumerate(test_data[i]['profile']):
224 | for key, value in q.items():
225 | q[key] = get_first_k_tokens(q[key], 768)
226 |
227 | prompt = prompt_template[args.task_name]['OPPU_input'].format(**q)
228 | full_prompt = prompt_template[args.task_name]['OPPU_full'].format(**q)
229 |
230 | if k > 0 and idx != 0 and format_flag==True:
231 | visible_history_list = test_data[i]['profile'][:idx]
232 |
233 | for p in visible_history_list:
234 | for key, value in p.items():
235 | p[key] = get_first_k_tokens(p[key], 768)
236 |
237 | history_list = [prompt_template[args.task_name]['retrieval_history'].format(**p) for p in visible_history_list]
238 | tokenized_corpus = [doc.split(" ") for doc in history_list]
239 | bm25 = BM25Okapi(tokenized_corpus)
240 |
241 | tokenized_query = prompt_template[args.task_name]["retrieval_query"].format(**q).split(' ')
242 | retrieved_history = bm25.get_top_n(tokenized_query, history_list, n=args.k)
243 |
244 | history_string = "".join(retrieved_history)
245 | prompt = history_string + "\n" + prompt
246 | full_prompt = history_string + "\n" + full_prompt
247 |
248 | if args.add_profile and format_flag == True:
249 | prompt = profile + "\n" + prompt
250 | full_prompt = profile + "\n" + full_prompt
251 |
252 | train_data.append(
253 | {
254 | "prompt": prompt,
255 | "full_prompt": full_prompt
256 | }
257 | )
258 |
259 | # print(train_data)
260 |
261 | train_dataset = Dataset.from_list(train_data)
262 | train_dataset = train_dataset.map(generate_and_tokenize_prompt).shuffle()
263 |
264 | trainer = transformers.Trainer(
265 | model=model,
266 | train_dataset=train_dataset,
267 | args=training_arguments,
268 | data_collator=transformers.DataCollatorForSeq2Seq(
269 | tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
270 | ),
271 | )
272 |
273 | for name, module in trainer.model.named_modules():
274 | if "norm" in name:
275 | module = module.to(torch.float32)
276 |
277 |
278 | model.config.use_cache = False # silence the warnings. Please re-enable for inference!
279 | trainer.train()
280 |
281 | output_name = "./ckpt/{}/k{}-{}-{}-OPPU_LoRA-{}".format(args.task_name, args.k, args.task_name, model_name.split('/')[-1], i)
282 | model.save_pretrained(output_name)
283 |
284 | model.eval()
285 |     model.config.use_cache = True # re-enable cache for inference
286 |
287 | # test inference
288 | if args.add_profile:
289 | profile = test_profile[i]['output']
290 |
291 | if k > 0:
292 | visible_history_list = test_data[i]['profile']
293 | for p in visible_history_list:
294 | for key, value in p.items():
295 | p[key] = get_first_k_tokens(p[key], 368)
296 |
297 | history_list = [prompt_template[args.task_name]['retrieval_history'].format(**p) for p in visible_history_list]
298 |
299 | tokenized_corpus = [doc.split(" ") for doc in history_list]
300 | bm25 = BM25Okapi(tokenized_corpus)
301 |
302 | test_question_list = []
303 | question_id_list = []
304 |
305 | for q in test_data[i]['query']:
306 |
307 | if args.task_name == 'citation':
308 | test_question = q['input']
309 | test_article = extract_citation_title(test_question)
310 | option1, option2 = extract_option(test_question, 1), extract_option(test_question, 2)
311 | test_prompt = prompt_template[args.task_name]['prompt'].format(test_article, option1, option2)
312 |
313 | else:
314 | test_question = q['input']
315 | test_article = extract_article(test_question)
316 | test_prompt = prompt_template[args.task_name]['prompt'].format(test_article)
317 |
318 | if k > 0:
319 | tokenized_query = prompt_template[args.task_name]['retrieval_query_wokey'].format(test_article).split(" ")
320 | retrieved_history = bm25.get_top_n(tokenized_query, history_list, n=args.k)
321 |
322 | history_string = "".join(retrieved_history)
323 | test_prompt = history_string + "\n" + test_prompt
324 |
325 | if args.add_profile:
326 | test_prompt = profile + "\n" + test_prompt
327 |
328 | test_question_list.append(test_prompt)
329 | question_id_list.append(q['id'])
330 |
331 | test_batch_list = split_batch(test_question_list, 1)
332 | out_list = []
333 |
334 | with torch.no_grad():
335 | for batch_idx, batch in tqdm(enumerate(test_batch_list), total=len(test_batch_list)):
336 | # try:
337 | sentences = batch
338 | inputs = tokenizer(sentences, return_tensors="pt", padding=True, return_token_type_ids=False)
339 | inputs = inputs.to(model.device)
340 |
341 | with torch.autocast(device_type="cuda"):
342 | outputs = model.generate(
343 | **inputs,
344 | do_sample=True,
345 | top_k=10,
346 | temperature=args.temperature,
347 | top_p=0.9,
348 | eos_token_id=tokenizer.eos_token_id,
349 | max_new_tokens=200
350 | )
351 |
352 | out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
353 | out_list += out_sentence
354 | # except:
355 | # out_list += ['']
356 |
357 | for i in range(len(out_list)):
358 | output = out_list[i].replace(test_question_list[i], '')
359 | pred_all.append({
360 | "id": question_id_list[i],
361 | "output": output
362 | })
363 |
364 | print(output)
365 |
366 | output_file = {
367 | 'task': name2taskid[args.task_name],
368 | 'golds': pred_all,
369 | 'model': model_name,
370 | }
371 |
372 | if args.add_profile:
373 |     with open('./output/{}/output-OPPU-k{}-{}-{}-profile.json'.format(args.task_name, args.k, args.task_name, model_name.split('/')[-1]), 'w') as f:
374 | json.dump(output_file, f, indent=4)
375 | else:
376 |     with open('./output/{}/output-OPPU-k{}-{}-{}.json'.format(args.task_name, args.k, args.task_name, model_name.split('/')[-1]), 'w') as f:
377 | json.dump(output_file, f, indent=4)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Democratizing Large Language Models via Personalized Parameter-Efficient Fine-tuning
2 |
3 |
4 | This is the source code of our EMNLP 2024 paper
5 |
6 | [**Democratizing Large Language Models via Personalized Parameter-Efficient Fine-tuning**](https://arxiv.org/abs/2402.04401).
7 |
8 | by
9 | [Zhaoxuan Tan](https://zhaoxuan.info/),
10 | [Qingkai Zeng](https://qingkaizeng.github.io/),
11 | [Yijun Tian](http://tianyijun.com/),
12 | [Zheyuan Liu](https://franciscoliu.github.io/),
13 | [Bing Yin](https://scholar.google.com/citations?user=qSOxydEAAAAJ&hl=en),
14 | [Meng Jiang](http://www.meng-jiang.com/).
15 |
16 |
17 |
18 | ## Overview ##
19 |
20 | * **Ownership**: Existing methods follow a centralized paradigm, in which user history is encoded into a personalized prompt and processed by a centralized LLM. This paradigm limits customization and the ability to provide deep, personalized experiences tailored to individual users. Moreover, with a centralized model, users often have to share personal data with the service provider, which raises concerns about how user data are stored, used, and protected.
21 |
22 | * **Behavior Pattern Generalization**: As revealed by existing research, LLMs are easily distracted by irrelevant context, which retrieval can hardly avoid. In LLM personalization, where the retrieval corpus is confined to a specific user's behaviors, retrieval augmentation may underperform, especially when the user's past behaviors do not closely mirror the patterns needed for the query at hand.
23 |
24 |
25 |
26 |
27 |
28 | Personalization in large language models (LLMs) is increasingly important, aiming to align the LLMs' interactions, content, and recommendations with individual user preferences. Recent advances have highlighted effective prompt design by enriching user queries with non-parametric knowledge through behavior history retrieval and textual profiles. However, these methods faced limitations due to a lack of model ownership, resulting in constrained customization and privacy issues, and often failed to capture complex, dynamic user behavior patterns. To address these shortcomings, we introduce One PEFT Per User (OPPU), employing personalized parameter-efficient fine-tuning (PEFT) modules to store user-specific behavior patterns and preferences. By plugging in personal PEFT parameters, users can own and use their LLMs individually. OPPU integrates parametric user knowledge in the personal PEFT parameters with non-parametric knowledge from retrieval and profiles, adapting LLMs to user behavior shifts. Experimental results demonstrate that OPPU significantly outperforms existing prompt-based methods across seven diverse tasks in the LaMP benchmark. Further studies reveal OPPU's enhanced capabilities in handling user behavior shifts, modeling users at different activity levels, maintaining robustness across various user history formats, and displaying versatility with different PEFT methods.
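For intuition, here is a minimal sketch (not part of the repository) of what "plugging in" a personal PEFT module means in practice with `peft`: load the shared base LLM once, then attach a single user's trained LoRA adapter for personalized inference. The adapter path follows the naming convention used in `OPPU.py` and is purely illustrative.

```python
# Minimal sketch: personalized inference by attaching one user's LoRA adapter.
# Assumes an adapter was already trained and saved by OPPU.py under ./ckpt/.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

base_name = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(base_name)
base_model = AutoModelForCausalLM.from_pretrained(
    base_name, device_map="auto", torch_dtype=torch.bfloat16
)

# Illustrative path: user 0's adapter for the movie_tagging task (k=0 setting).
user_adapter = "./ckpt/movie_tagging/k0-movie_tagging-Llama-2-7b-hf-OPPU_LoRA-0"
model = PeftModel.from_pretrained(base_model, user_adapter)
model.eval()

prompt = "Which tag does this movie relate to ...\n description: ... tag:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```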
29 |
30 |
31 |
32 |
33 |
34 | ## Dataset ##
35 |
36 | We use publicly available data from the [LaMP](https://arxiv.org/abs/2304.11406) benchmark. You can download our processed data [here](https://drive.google.com/file/d/1bJ3Rh_sqrw3suwwweFbra5CTV7GVjgxF/view?usp=sharing), unzip it, and place it under the ```./data``` folder.
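After unzipping, each task gets its own subfolder under `./data`. Based on the paths read and written by `OPPU.py` and `gen_profile.py`, the expected layout looks roughly like:

```
data/
└── movie_tagging/                 # one folder per task_name
    ├── user_top_100_history.json  # evaluation users (read by OPPU.py)
    ├── user_others.json           # remaining users (read by gen_profile.py)
    └── profile_user_100.json      # user profiles written by gen_profile.py (used with --add_profile)
```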
37 |
38 |
39 | ## Installation ##
40 | Please install the dependencies with pip, using the following command:
41 |
42 | ```bash
43 | pip install -r requirements.txt
44 | ```
45 |
46 | ## Experiment ##
47 | ```task_name``` can be selected from ```[citation, movie_tagging, news_categorize, news_headline, product_rating, scholarly_title, tweet_paraphrase]```. Here, we take ```movie_tagging``` as an example.
48 |
49 | ### OPPU
50 | #### 1. Base LLM Task Adaptation
51 |
52 | ```bash
53 | CUDA_VISIBLE_DEVICES=0 python task_LoRA.py --k 0 --task_name movie_tagging
54 | ```
55 |
56 | #### 2. Train One PEFT Per User
57 | ```bash
58 | CUDA_VISIBLE_DEVICES=0 python OPPU.py --k 0 --task_name movie_tagging --task_lora ./ckpt/movie_tagging/k0-movie_tagging-Llama-2-7b-hf-task_LoRA_ckpt
59 | ```
60 |
61 | ### OPPU + RAG
62 |
63 | #### 1. Base LLM Task Adaptation
64 |
65 | ```bash
66 | CUDA_VISIBLE_DEVICES=0 python task_LoRA.py --k 1 --task_name movie_tagging
67 | ```
68 |
69 | #### 2. Train One PEFT Per User
70 | ```bash
71 | CUDA_VISIBLE_DEVICES=0 python OPPU.py --k 1 --task_name movie_tagging --task_lora ./ckpt/movie_tagging/k1-movie_tagging-Llama-2-7b-hf-task_LoRA_ckpt
72 | ```
73 | ----
74 |
75 | ### OPPU + PAG
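OPPU + PAG augments each prompt with a natural-language user profile, which is read from `./data/{task_name}/profile_user_100.json`. If your data folder does not already contain these profiles, they can presumably be generated with `gen_profile.py` first (the profile model and other settings default to the values defined inside the script):

```bash
CUDA_VISIBLE_DEVICES=0 python gen_profile.py --task_name movie_tagging
```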
76 | #### 1. Base LLM Task Adaptation
77 |
78 | ```bash
79 | CUDA_VISIBLE_DEVICES=0 python task_LoRA.py --k 1 --task_name movie_tagging --add_profile
80 | ```
81 |
82 | #### 2. Train One PEFT Per User
83 | ```bash
84 | CUDA_VISIBLE_DEVICES=0 python OPPU.py --k 1 --task_name movie_tagging --task_lora ./ckpt/movie_tagging/k1-movie_tagging-Llama-2-7b-hf-profile-task_LoRA_ckpt --add_profile
85 | ```
86 |
87 | ## Evaluation ##
88 | ```TASK_ID``` is the corresponding ID selected from ```["LaMP_1", "LaMP_2N", "LaMP_2M", "LaMP_3", "LaMP_4", "LaMP_5", "LaMP_7"]```
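For reference, the `task_name` → `TASK_ID` correspondence follows the LaMP benchmark; a presumed sketch of the `name2taskid` mapping defined in `utils.py` (verify against the actual file) is:

```python
# Presumed task_name -> TASK_ID mapping (mirrors LaMP; check utils.name2taskid for the authoritative version).
name2taskid = {
    "citation": "LaMP_1",
    "news_categorize": "LaMP_2N",
    "movie_tagging": "LaMP_2M",
    "product_rating": "LaMP_3",
    "news_headline": "LaMP_4",
    "scholarly_title": "LaMP_5",
    "tweet_paraphrase": "LaMP_7",
}
```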
89 |
90 | ```bash
91 | python ./eval/eval_task.py \
92 | --golds_json {PATH_TO_LABEL_JSON_FILE} \
93 | --preds_json {PATH_TO_PREDICTION_JSON_FILE} \
94 | --task_name {TASK_ID} \
95 | --output_file {RESULT_JSON_PATH}
96 | ```
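Both the gold file and the prediction file share the schema that `OPPU.py` writes: a `task` field holding the TASK_ID and a `golds` list of `{id, output}` entries (the `model` field is recorded by `OPPU.py` but is not used by the evaluator). A minimal, made-up example:

```json
{
    "task": "LaMP_2M",
    "golds": [
        {"id": "110", "output": "comedy"},
        {"id": "111", "output": "sci-fi"}
    ],
    "model": "meta-llama/Llama-2-7b-hf"
}
```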
97 |
98 | ## Citation ##
99 | If you find this paper or codebase useful in your research, please kindly cite the following paper.
100 |
101 | ```bibtex
102 | @article{tan2024democratizing,
103 | title={Democratizing Large Language Models via Personalized Parameter-Efficient Fine-tuning},
104 | author={Tan, Zhaoxuan and Zeng, Qingkai and Tian, Yijun and Liu, Zheyuan and Yin, Bing and Jiang, Meng},
105 | journal={arXiv preprint arXiv:2402.04401},
106 | year={2024}
107 | }
108 | ```
109 |
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TamSiuhin/OPPU/9fb3f465fab408da34915e60db12cc9cd43c2223/__pycache__/utils.cpython-310.pyc
--------------------------------------------------------------------------------
/asset/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TamSiuhin/OPPU/9fb3f465fab408da34915e60db12cc9cd43c2223/asset/overview.png
--------------------------------------------------------------------------------
/asset/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TamSiuhin/OPPU/9fb3f465fab408da34915e60db12cc9cd43c2223/asset/teaser.png
--------------------------------------------------------------------------------
/data/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TamSiuhin/OPPU/9fb3f465fab408da34915e60db12cc9cd43c2223/data/.DS_Store
--------------------------------------------------------------------------------
/eval/__init__.py:
--------------------------------------------------------------------------------
1 | from evaluation import LaMPEvaluation
--------------------------------------------------------------------------------
/eval/__pycache__/evaluation.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TamSiuhin/OPPU/9fb3f465fab408da34915e60db12cc9cd43c2223/eval/__pycache__/evaluation.cpython-310.pyc
--------------------------------------------------------------------------------
/eval/__pycache__/evaluation.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TamSiuhin/OPPU/9fb3f465fab408da34915e60db12cc9cd43c2223/eval/__pycache__/evaluation.cpython-39.pyc
--------------------------------------------------------------------------------
/eval/eval_all.py:
--------------------------------------------------------------------------------
1 | from evaluation import LaMPEvaluation
2 | import argparse
3 | import json
4 |
5 | parser = argparse.ArgumentParser()
6 |
7 | parser.add_argument("--golds_zip", required=True, help="Address to all gold labels for all tasks zipped in a file")
8 | parser.add_argument("--preds_zip", required=True, help="Address to all predictions for all tasks zipped in a file")
9 | parser.add_argument("--temp_dir", required=False, help="Address to a temp dir for extracting files", default="./tmp")
10 | parser.add_argument("--output_file", required=True, help="Address to the results file")
11 |
12 | if __name__ == "__main__":
13 |
14 | opts = parser.parse_args()
15 |
16 | evaluator = LaMPEvaluation(all_golds_zip_file_addr=opts.golds_zip, extract_addr=opts.temp_dir)
17 | results = evaluator.evaluate_all(opts.preds_zip)
18 | with open(opts.output_file, "w") as file:
19 | json.dump(results, file)
20 |
--------------------------------------------------------------------------------
/eval/eval_task.py:
--------------------------------------------------------------------------------
1 | from evaluation import LaMPEvaluation
2 | import argparse
3 | import json
4 | import warnings
5 |
6 | warnings.filterwarnings('ignore')
7 |
8 | parser = argparse.ArgumentParser()
9 |
10 | parser.add_argument("--golds_json", required=True, help="Address to all gold labels for the task as a json file")
11 | parser.add_argument("--preds_json", required=True, help="Address to all predictions for the task as a json file")
12 | parser.add_argument("--task_name", required=True, help="[LaMP_1, LaMP_2N, LaMP_2M, LaMP_3, LaMP_4, LaMP_5, LaMP_7]")
13 | parser.add_argument("--output_file", required=True, help="Address to the results file")
14 |
15 | if __name__ == "__main__":
16 |
17 | opts = parser.parse_args()
18 |
19 | evaluator = LaMPEvaluation(single_gold_json_file_addr=opts.golds_json)
20 | results = evaluator.evaluate_task(opts.preds_json, opts.task_name)
21 | with open(opts.output_file, "w") as file:
22 | json.dump(results, file)
23 |
--------------------------------------------------------------------------------
/eval/evaluation.py:
--------------------------------------------------------------------------------
1 | import json
2 | import zipfile
3 | import glob
4 | import os
5 | import shutil
6 |
7 | import os
8 |
9 | os.environ['HF_EVALUATE_OFFLINE'] = '1'
10 |
11 | import evaluate
12 |
13 | def postprocess_text_classification(preds, labels):
14 | preds = [str(pred).strip() for pred in preds]
15 | labels = [str(label).strip() for label in labels]
16 | return preds, labels
17 |
18 | def postprocess_text_generation(preds, labels):
19 | preds = [pred.strip() for pred in preds]
20 | labels = [[label.strip()] for label in labels]
21 |
22 | return preds, labels
23 |
24 | def create_metric_f1_accuracy(all_labels):
25 | f1_metric = evaluate.load("f1")
26 | accuracy_metric = evaluate.load("accuracy")
27 | def create_mapping(x):
28 | try:
29 | return all_labels.index(x)
30 | except:
31 | return -1
32 | def compute_metrics(decoded_preds, decoded_labels):
33 | decoded_preds, decoded_labels = postprocess_text_classification(decoded_preds, decoded_labels)
34 | decoded_preds = [create_mapping(x) for x in decoded_preds]
35 | decoded_labels = [create_mapping(x) for x in decoded_labels]
36 | result_acc = accuracy_metric.compute(predictions=decoded_preds, references=decoded_labels)
37 | result_f1 = f1_metric.compute(predictions=decoded_preds, references=decoded_labels, labels=list(range(len(all_labels))), average = "macro")
38 | result = {"accuracy" : result_acc["accuracy"], "f1" : result_f1["f1"]}
39 | return result
40 | return compute_metrics
41 |
42 | def create_metric_mae_rmse():
43 | mse_metric = evaluate.load("mse")
44 | mae_metric = evaluate.load("mae")
45 | def create_mapping(x, y):
46 | try:
47 | return float(x)
48 | except:
49 | print(x)
50 | y = float(y)
51 | if abs(1 - y) > abs(5 - y):
52 | return 1.0
53 | else:
54 | return 5.0
55 | def compute_metrics(decoded_preds, decoded_labels):
56 | decoded_preds, decoded_labels = postprocess_text_classification(decoded_preds, decoded_labels)
57 | decoded_preds = [create_mapping(x,y) for x,y in zip(decoded_preds, decoded_labels)]
58 | decoded_labels = [create_mapping(x,x) for x in decoded_labels]
59 | result_mae = mae_metric.compute(predictions=decoded_preds, references=decoded_labels)
60 | result_rmse = mse_metric.compute(predictions=decoded_preds, references=decoded_labels, squared = False)
61 | result = {"MAE" : result_mae["mae"], "RMSE" : result_rmse["mse"]}
62 | return result
63 | return compute_metrics
64 |
65 | def create_metric_rouge():
66 | rouge_metric = evaluate.load('rouge')
67 | def compute_metrics(decoded_preds, decoded_labels):
68 | decoded_preds, decoded_labels = postprocess_text_generation(decoded_preds, decoded_labels)
69 | result_rouge = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels)
70 | result = {"rouge-1" : result_rouge["rouge1"], "rouge-L" : result_rouge["rougeL"]}
71 | return result
72 | return compute_metrics
73 |
74 | class LaMPEvaluation(object):
75 |
76 | def __init__(self, all_golds_zip_file_addr = None, single_gold_json_file_addr = None, extract_addr = "./tmp") -> None:
77 | assert all_golds_zip_file_addr or single_gold_json_file_addr, "The golds should be provided for all datasets or at least one."
78 | assert not (all_golds_zip_file_addr and single_gold_json_file_addr), "The golds should be provided using zip file or json file not both."
79 | self.tasks_golds = dict()
80 | self.extract_addr = extract_addr
81 | self.evaluate_all_is_possible = False
82 | if all_golds_zip_file_addr:
83 | os.makedirs(self.extract_addr, exist_ok=True)
84 | with zipfile.ZipFile(all_golds_zip_file_addr, 'r') as zobj:
85 | zobj.extractall(path = extract_addr)
86 | for file_addr in glob.glob(os.path.join(self.extract_addr, "**/*.json"), recursive=True):
87 | with open(file_addr) as file:
88 | task = json.load(file)
89 | self.tasks_golds[task['task']] = task['golds']
90 | self._empty_dir(self.extract_addr)
91 | self.evaluate_all_is_possible = True
92 | if single_gold_json_file_addr:
93 | with open(single_gold_json_file_addr) as file:
94 | task = json.load(file)
95 | self.tasks_golds[task['task']] = task['golds']
96 |
97 | def _empty_dir(self, directory_path):
98 | for filename in os.listdir(directory_path):
99 | file_path = os.path.join(directory_path, filename)
100 | try:
101 | if os.path.isfile(file_path):
102 | os.unlink(file_path)
103 | elif os.path.isdir(file_path):
104 | shutil.rmtree(file_path)
105 | except Exception as e:
106 | print(f'Failed to delete {file_path}. Reason: {e}')
107 |
108 | def _get_all_gold_ids(self, task_name):
109 | return set([sample['id'] for sample in self.tasks_golds[task_name]])
110 |
111 | def _get_all_ids(self, input):
112 | return set([sample['id'] for sample in input])
113 |
114 | def evaluate_all(self, predicts_zipfile_addr):
115 | assert self.evaluate_all_is_possible, "You did not provide golds for all tasks."
116 | with zipfile.ZipFile(predicts_zipfile_addr, 'r') as zobj:
117 | zobj.extractall(path = self.extract_addr)
118 | results_raw = dict()
119 | all_task_names = set()
120 | for file_addr in glob.glob(os.path.join(self.extract_addr, "**/*.json"), recursive=True):
121 | with open(file_addr) as file:
122 | preds = json.load(file)
123 | all_task_names.add(preds['task'])
124 | results_raw[preds['task']] = self._evaluate_task(preds['golds'], preds['task'])
125 | self._empty_dir(self.extract_addr)
126 | assert len(all_task_names) == 7, "The provided results do not cover all the tasks in the benchmark."
127 | return results_raw
128 |
129 | def evaluate_task(self, predicts_json_addr, task_name):
130 | with open(predicts_json_addr) as file:
131 | preds = json.load(file)
132 | assert preds['task'] == task_name, "The provided task_name and the results do not match."
133 | assert preds['task'] in self.tasks_golds.keys(), "The provided golds cannot be used to evaluate this task."
134 | return self._evaluate_task(preds['golds'], task_name)
135 |
136 | def _evaluate_task(self, predictions, task_name):
137 | golds_dict = {y['id']:y['output'] for y in self.tasks_golds[task_name]}
138 | preds_dict = {x['id']:x['output'] for x in predictions}
139 |
140 | gold_ids = self._get_all_gold_ids(task_name)
141 | pred_ids = self._get_all_ids(predictions)
142 |
143 | assert gold_ids == pred_ids, "Predictions ids and gold ids do not match. {}".format(gold_ids-pred_ids)
144 |
145 | if task_name in ["LaMP_1", "LaMP_2N", "LaMP_2M"]:
146 | metric = create_metric_f1_accuracy(self._get_labels(task_name))
147 | elif task_name == "LaMP_3":
148 | metric = create_metric_mae_rmse()
149 | else:
150 | metric = create_metric_rouge()
151 |
152 | gold_ids = list(gold_ids)
153 | golds = [golds_dict[id] for id in gold_ids]
154 | preds = [preds_dict[id] for id in gold_ids]
155 | return metric(preds, golds)
156 |
157 | def _get_labels(self, task_name):
158 | if task_name == "LaMP_1":
159 | return ["[1]", "[2]"]
160 | elif task_name == "LaMP_2N":
161 | return ["food & drink", "sports", "education", "parents", "religion", "travel", "business", "crime", "science & technology", "culture & arts", "entertainment", "politics", "women", "style & beauty", "healthy living"]
162 | elif task_name == "LaMP_2M":
163 | return ["sci-fi", "based on a book", "comedy", "action", "twist ending", "dystopia", "dark comedy", "classic", "psychology", "fantasy", "romance", "thought-provoking", "social commentary", "violence", "true story"]
164 | elif task_name == "LaMP_3":
165 | return ["1", "2", "3", "4", "5"]
166 | else:
167 | raise ValueError("Invalid task_name")
--------------------------------------------------------------------------------
/gen_profile.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import bitsandbytes as bnb
5 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
6 | from transformers import pipeline, BitsAndBytesConfig
7 | import argparse
8 | from rank_bm25 import BM25Okapi
9 | # from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
10 | import transformers
11 | import json
12 | from utils import split_batch, get_first_k_tokens, print_trainable_parameters, name2taskid
13 | from utils import extract_citation_title, extract_option, extract_movie, extract_news_cat, extract_news_headline, extract_product_review, extract_scholarly_title, extract_tweet_paraphrasing
14 |
15 | parser = argparse.ArgumentParser(description="Parser for LoRA")
16 | parser.add_argument('--model_name', type=str, default='lmsys/vicuna-7b-v1.5')
17 | parser.add_argument('--batch_size', type=int, default=3)
18 | parser.add_argument('--k', type=int, default=10)
19 | parser.add_argument('--cut_off', type=int, default=2048)
20 | parser.add_argument('--task_name', type=str, default='movie_tagging')
21 |
22 | args = parser.parse_args()
23 | model_name = args.model_name
24 | batch_size = args.batch_size
25 | k = args.k
26 | cutoff_len = args.cut_off
27 | add_eos_token = False
28 |
29 |
30 | # # 4 bit quantization inference
31 | # bnb_config = BitsAndBytesConfig(
32 | # load_in_4bit=True,
33 | # bnb_4bit_quant_type="nf4",
34 | # bnb_4bit_compute_dtype=torch.float16,
35 | # bnb_4bit_use_double_quant=True,
36 | # max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'
37 | # )
38 |
39 | # 8-bit quantization inference
40 | # bnb_config = BitsAndBytesConfig(
41 | # load_in_8bit=True,
42 | # bnb_8bit_quant_type="nf8",
43 | # bnb_8bit_compute_dtype=torch.float16,
44 | # bnb_8bit_use_double_quant=True,
45 | # max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'
46 | # )
47 |
48 | # 16-bit quantization inference
49 | # bnb_config = BitsAndBytesConfig(
50 | # load_in_16bit=True,
51 | # bnb_16bit_quant_type="bf16",
52 | # bnb_16bit_compute_dtype=torch.bfloat16,
53 | # bnb_16bit_use_double_quant=True,
54 | # max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'
55 | # )
56 |
57 | tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
58 | tokenizer.eos_token = ""
59 | tokenizer.pad_token = '[PAD]'
60 |
61 | base_model = AutoModelForCausalLM.from_pretrained(
62 | model_name,
63 | # quantization_config=bnb_config,
64 | local_files_only=False,
65 | device_map='auto',
66 | trust_remote_code=True,
67 | torch_dtype=torch.float16
68 | )
69 |
70 | base_model.config.use_cache = True
71 | base_model.config.pad_token_id = tokenizer.pad_token_id
72 | base_model.config.eos_token_id = tokenizer.eos_token_id
73 | base_model.config.bos_token_id = tokenizer.bos_token_id
74 |
75 |
76 | with open(f"./data/{args.task_name}/user_others.json", 'r') as f:
77 | train_data = json.load(f)
78 |
79 | with open(f"./data/{args.task_name}/user_top_100_history.json", 'r') as f:
80 | test_data = json.load(f)
81 |
82 | with open('./prompt/prompt_profile.json', 'r') as f:
83 | prompt_template = json.load(f)
84 |
85 | from tqdm import tqdm
86 | import random
87 |
88 | K = args.k
89 |
90 | prompt_list_others = []
91 | userid_list_others = []
92 |
93 | for user in tqdm(train_data):
94 |
95 | history_list = []
96 |
97 | if len(user['profile'])> K:
98 | profiles = random.sample(user['profile'], K)
99 | else:
100 | profiles = user['profile']
101 |
102 |
103 | for p in profiles:
104 | for key, value in p.items():
105 | p[key] = get_first_k_tokens(p[key], 200)
106 |
107 | for p in profiles:
108 | history_list.append(prompt_template[args.task_name]['retrieval_history'].format(**p))
109 |
110 | history_string = ' | '.join(history_list)
111 |
112 |     test_prompt = prompt_template[args.task_name]["profile_prompt"].format(history_string)
113 |
114 | prompt_list_others.append(test_prompt)
115 | userid_list_others.append(user['user_id'])
116 |
117 |
118 | prompt_list_100 = []
119 | userid_list_100 = []
120 |
121 | for user in tqdm(test_data):
122 |
123 | history_list = []
124 |
125 | if len(user['profile'])> K:
126 | profiles = random.sample(user['profile'], K)
127 | else:
128 | profiles = user['profile']
129 |
130 | for p in profiles:
131 | for key, value in p.items():
132 | p[key] = get_first_k_tokens(p[key], 200)
133 |
134 | for p in profiles:
135 | history_list.append(prompt_template[args.task_name]['retrieval_history'].format(**p))
136 |
137 | history_string = ' | '.join(history_list)
138 |
139 |     test_prompt = prompt_template[args.task_name]['profile_prompt'].format(history_string)
140 | prompt_list_100.append(test_prompt)
141 | userid_list_100.append(user['user_id'])
142 |
143 | batched_prompt_others = split_batch(prompt_list_others, batch_size)
144 | out_list_others = []
145 |
146 | print(len(prompt_list_others))
147 | print(len(batched_prompt_others))
148 |
149 | with torch.no_grad():
150 | for batch_idx, batch in tqdm(enumerate(batched_prompt_others), total=len(batched_prompt_others)):
151 | # try:
152 | sentences = batch
153 | inputs = tokenizer(sentences, return_tensors="pt", padding=True, return_token_type_ids=False)
154 | inputs = inputs.to(base_model.device)
155 |
156 | with torch.autocast(device_type="cuda"):
157 | outputs = base_model.generate(
158 | **inputs,
159 | do_sample=True,
160 | top_k=10,
161 | temperature=0.6,
162 | top_p=0.9,
163 | eos_token_id=tokenizer.eos_token_id,
164 | max_new_tokens=300,
165 | )
166 |
167 | out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
168 | out_list_others += out_sentence
169 | # except:
170 | # out_list_others += ['']
171 |
172 | pred_all_others = []
173 |
174 | for i in range(len(out_list_others)):
175 | output = out_list_others[i].replace(prompt_list_others[i], '')
176 | pred_all_others.append({
177 | "id": userid_list_others[i],
178 | "output": output
179 | })
180 |
181 | print(output)
182 |
183 |
184 | with open(f'./data/{args.task_name}/profile_user_others.json', 'w') as f:
185 | json.dump(pred_all_others, f)
186 |
187 |
188 | batched_prompt_100 = split_batch(prompt_list_100, batch_size)
189 | out_list_100 = []
190 |
191 | with torch.no_grad():
192 | for batch_idx, batch in tqdm(enumerate(batched_prompt_100), total=len(batched_prompt_100)):
193 | sentences = batch
194 | inputs = tokenizer(sentences, return_tensors="pt", padding=True, return_token_type_ids=False)
195 | inputs = inputs.to(base_model.device)
196 |
197 | with torch.autocast(device_type="cuda"):
198 | outputs = base_model.generate(
199 | **inputs,
200 | do_sample=True,
201 | top_k=10,
202 | temperature=0.6,
203 | top_p=0.9,
204 | eos_token_id=tokenizer.eos_token_id,
205 | max_new_tokens=200,
206 | )
207 |
208 | out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
209 | out_list_100 += out_sentence
210 | # except:
211 | # out_list_100 += ['']
212 |
213 | pred_all_100 = []
214 |
215 | for i in range(len(out_list_100)):
216 | output = out_list_100[i].replace(prompt_list_100[i], '')
217 | pred_all_100.append({
218 | "id": userid_list_100[i],
219 | "output": output
220 | })
221 |
222 | print(output)
223 |
224 |
225 | with open(f'./data/{args.task_name}/profile_user_100.json', 'w') as f:
226 | json.dump(pred_all_100, f)
227 |
--------------------------------------------------------------------------------
/prompt/prompt.json:
--------------------------------------------------------------------------------
1 | {
2 | "movie_tagging":{
3 | "prompt": "Which tag does this movie relate to among the following tags? Just answer with the tag name without further explanation. tags: [sci-fi, based on a book, comedy, action, twist ending, dystopia, dark comedy, classic, psychology, fantasy, romance, thought-provoking, social commentary, violence, true story]\n description: {} tag:",
4 | "full_prompt": "Which tag does this movie relate to among the following tags? Just answer with the tag name without further explanation. tags: [sci-fi, based on a book, comedy, action, twist ending, dystopia, dark comedy, classic, psychology, fantasy, romance, thought-provoking, social commentary, violence, true story]\n description: {} tag: {}",
5 | "retrieval_history": "description: {description} tag: {tag}\n",
6 | "retrieval_query": "Which tag does this movie relate to among the following tags? Just answer with the tag name without further explanation.\n description: {description}",
7 | "retrieval_query_wokey": "Which tag does this movie relate to among the following tags? Just answer with the tag name without further explanation.\n description: {}",
8 | "OPPU_input": "Which tag does this movie relate to among the following tags? Just answer with the tag name without further explanation. tags: [sci-fi, based on a book, comedy, action, twist ending, dystopia, dark comedy, classic, psychology, fantasy, romance, thought-provoking, social commentary, violence, true story]\n description: {description} tag:",
9 | "OPPU_full": "Which tag does this movie relate to among the following tags? Just answer with the tag name without further explanation. tags: [sci-fi, based on a book, comedy, action, twist ending, dystopia, dark comedy, classic, psychology, fantasy, romance, thought-provoking, social commentary, violence, true story]\n description: {description} tag: {tag}"
10 | },
11 | "citation":{
12 | "prompt": "Identify the most relevant reference for the listed publication by the researcher. Select the reference paper that is most closely related to the researcher's work.\npaper title: {} reference:",
13 | "full_prompt": "Identify the most relevant reference for the listed publication by the researcher. Select the reference paper that is most closely related to the researcher's work.\npaper title: {} reference: {}",
14 | "retrieval_history": "paper title: {title}\n",
15 | "retrieval_query": "Identify the most relevant reference for the listed publication by the researcher. Select the reference paper that is most closely related to the researcher's work. Please respond with only the number that corresponds to the reference.\npaper title: {title}",
16 | "retrieval_query_wokey": "Identify the most relevant reference for the listed publication by the researcher. Select the reference paper that is most closely related to the researcher's work. Please respond with only the number that corresponds to the reference.\npaper title: {}",
17 | "OPPU_input": "paper title: {title} reference:",
18 | "OPPU_full": "paper title: {title} reference: {citation}"
19 | },
20 | "news_categorize":{
21 | "prompt": "Which category does this article relate to among the following categories? Just answer with the category name without further explanation. categories: [travel, education, parents, style & beauty, entertainment, food & drink, science & technology, business, sports, healthy living, women, politics, crime, culture & arts, religion]\n article: {} category:",
22 | "full_prompt": "Which category does this article relate to among the following categories? Just answer with the category name without further explanation. categories: [travel, education, parents, style & beauty, entertainment, food & drink, science & technology, business, sports, healthy living, women, politics, crime, culture & arts, religion]\n article: {} category: {}",
23 | "retrieval_history": "article: {text} category: {category}\n",
24 |         "retrieval_query": "Which category does this article relate to among the following categories? article: {text}",
25 |         "retrieval_query_wokey": "Which category does this article relate to among the following categories? article: {}",
26 | "OPPU_input": "Which category does this article relate to among the following categories? Just answer with the category name without further explanation. categories: [travel, education, parents, style & beauty, entertainment, food & drink, science & technology, business, sports, healthy living, women, politics, crime, culture & arts, religion]\n article: {text} category:",
27 | "OPPU_full": "Which category does this article relate to among the following categories? Just answer with the category name without further explanation. categories: [travel, education, parents, style & beauty, entertainment, food & drink, science & technology, business, sports, healthy living, women, politics, crime, culture & arts, religion]\n article: {text} category: {category}"
28 | },
29 | "news_headline":{
30 | "prompt": "Generate a headline for the following article.\narticle: {} headline:",
31 | "full_prompt": "Generate a headline for the following article.\narticle: {} headline: {}",
32 | "retrieval_history": "article: {text} headline: {title}\n",
33 | "retrieval_query": "Generate a headline for the following article: {text}",
34 | "retrieval_query_wokey": "Generate a headline for the following article: {}",
35 | "OPPU_input": "Generate a headline for the following article.\n article: {text} headline:",
36 | "OPPU_full": "Generate a headline for the following article.\n article: {text} headline: {title}"
37 | },
38 | "product_rating":{
39 | "prompt": "What is the score of the following review on a scale of 1 to 5? just answer with 1, 2, 3, 4, or 5 without further explanation. review: {}\n score:",
40 | "full_prompt": "What is the score of the following review on a scale of 1 to 5? just answer with 1, 2, 3, 4, or 5 without further explanation. review: {}\n score: {}",
41 | "retrieval_history": "review: {text} score: {score}\n",
42 | "retrieval_query": "What is the score of the following review on a scale of 1 to 5? review: {text}",
43 | "retrieval_query_wokey": "What is the score of the following review on a scale of 1 to 5? review: {}",
44 | "OPPU_input": "What is the score of the following review on a scale of 1 to 5? just answer with 1, 2, 3, 4, or 5 without further explanation. review: {text}\n score:",
45 | "OPPU_full": "What is the score of the following review on a scale of 1 to 5? just answer with 1, 2, 3, 4, or 5 without further explanation. review: {text}\n score: {score}"
46 | },
47 | "scholarly_title": {
48 | "prompt": "Generate a title for the following abstract of a paper.\n abstract: {} title:",
49 | "full_prompt": "Generate a title for the following abstract of a paper.\n abstract: {} title: {}",
50 | "retrieval_history": "abstract: {abstract} title: {title}\n",
51 | "retrieval_query": "Generate a title for the following abstract of a paper: {abstract}",
52 | "retrieval_query_wokey": "Generate a title for the following abstract of a paper: {}",
53 | "OPPU_input": "Generate a title for the following abstract of a paper.\n abstract: {abstract} title:",
54 | "OPPU_full": "Generate a title for the following abstract of a paper.\n abstract: {abstract} title: {title}"
55 | },
56 | "tweet_paraphrase":{
57 | "prompt": "Paraphrase the following text into tweet without any explanation before or after it.\n text: {} tweet:",
58 | "full_prompt": "Paraphrase the following text into tweet without any explanation before or after it.\n text: {} tweet: {}",
59 | "retrieval_history": "tweet: {text}\n",
60 | "retrieval_query": "Paraphrase the following text into tweet without any explanation before or after it: {text}",
61 | "retrieval_query_wokey": "Paraphrase the following text into tweet without any explanation before or after it: {}",
62 | "OPPU_input": "tweet:",
63 | "OPPU_full": "tweet: {text}"
64 | }
65 | }
--------------------------------------------------------------------------------
/prompt/prompt_profile.json:
--------------------------------------------------------------------------------
1 | {
2 | "movie_tagging":{
3 | "profile_prompt": "Look at the following past movies this user has watched and determine the most popular tag they labeled. Answer in the following form: most popular tag: . User History: {} Answer:",
4 | "retrieval_history": "description: {description} tag: {tag}"
5 | },
6 | "citation":{
7 | "profile_prompt": "Write a summary, in English, of the research interests and topics of a researcher who has published the following papers. Only generate the summary, no other text. User History: {} Answer:",
8 | "retrieval_history": "paper title: {title} reference: {citation}"
9 | },
10 | "news_categorize":{
11 | "profile_prompt": "Look at the following past articles this journalist has written and determine the most popular category they write in. Answer in the following form: most popular category: . User History: {} Answer:",
12 | "retrieval_history": "article: {text} category: {category}"
13 | },
14 | "news_headline":{
15 |         "profile_prompt": "Given this author’s previous articles, try to describe a template for their headlines. I want to be able to accurately predict the headline given one of their articles. Be specific about their style and wording, don’t tell me anything generic. User History: {} Answer:",
16 | "retrieval_history": "article: {text} headline: {title}"
17 | },
18 | "product_rating":{
19 | "profile_prompt": "Based on this user’s past reviews, what are the most common scores they give for positive and negative reviews? Answer in the following form: most common positive score: , most common negative score: . User History: {} Answer:",
20 | "retrieval_history": "review: {text} score: {score}"
21 | },
22 | "scholarly_title": {
23 | "profile_prompt": "Given this author’s previous publications, try to describe a template for their titles. I want to be able to accurately predict the title of one of the papers from the abstract. Only generate the template description, nothing else. User History: {} Answer:",
24 | "retrieval_history": "abstract: {abstract} title: {title}"
25 | },
26 | "tweet_paraphrase":{
27 | "profile_prompt": "Given this person’s previous tweets, try to describe a template for their tweets. I want to take a generic sentence and rephrase it to sound like one of their tweets, with the same style/punctuation/capitalization/wording/tone/etc. as them. Only give me the template description, nothing else. User History: {} Answer:",
28 | "retrieval_history": "tweet: {text}"
29 | }
30 | }
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.0.0
2 | accelerate==0.24.1
3 | aiohttp==3.8.6
4 | aiosignal==1.3.1
5 | anyio==4.2.0
6 | appdirs==1.4.4
7 | argon2-cffi==23.1.0
8 | argon2-cffi-bindings==21.2.0
9 | arrow==1.3.0
10 | asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1698341106958/work
11 | async-lru==2.0.4
12 | async-timeout==4.0.3
13 | attrs==23.1.0
14 | Babel==2.14.0
15 | backports.functools-lru-cache @ file:///home/conda/feedstock_root/build_artifacts/backports.functools_lru_cache_1687772187254/work
16 | bayesian-optimization==1.4.3
17 | beautifulsoup4==4.12.2
18 | bitsandbytes==0.41.2.post2
19 | bleach==6.1.0
20 | blessed==1.20.0
21 | Brotli @ file:///tmp/abs_ecyw11_7ze/croots/recipe/brotli-split_1659616059936/work
22 | certifi==2023.7.22
23 | cffi @ file:///croot/cffi_1670423208954/work
24 | charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work
25 | click==8.1.7
26 | cma==3.3.0
27 | colorama==0.4.6
28 | comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1691044910542/work
29 | cryptography @ file:///croot/cryptography_1694444244250/work
30 | datasets==2.12.0
31 | debugpy @ file:///croot/debugpy_1690905042057/work
32 | decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work
33 | defusedxml==0.7.1
34 | dill==0.3.6
35 | docker-pycreds==0.4.0
36 | docstring-parser==0.15
37 | entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1643888246732/work
38 | evaluate==0.4.1
39 | exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1692026125334/work
40 | executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1698579936712/work
41 | fastjsonschema==2.19.1
42 | filelock @ file:///croot/filelock_1672387128942/work
43 | fire==0.5.0
44 | fqdn==1.5.1
45 | frozenlist==1.4.0
46 | fsspec==2023.10.0
47 | gitdb==4.0.11
48 | GitPython==3.1.40
49 | gmpy2 @ file:///tmp/build/80754af9/gmpy2_1645455533097/work
50 | gpustat==1.1.1
51 | huggingface-hub==0.20.3
52 | idna @ file:///croot/idna_1666125576474/work
53 | ijson==3.2.3
54 | ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1698244021190/work
55 | ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1698846603011/work
56 | ipywidgets==8.1.1
57 | isoduration==20.11.0
58 | jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1696326070614/work
59 | Jinja2 @ file:///croot/jinja2_1666908132255/work
60 | joblib==1.3.2
61 | json5==0.9.14
62 | jsonargparse==4.27.5
63 | jsonpointer==2.4
64 | jsonschema==4.20.0
65 | jsonschema-specifications==2023.12.1
66 | jupyter==1.0.0
67 | jupyter-console==6.6.3
68 | jupyter-events==0.9.0
69 | jupyter-lsp==2.2.1
70 | jupyter_client==8.6.0
71 | jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1669775088561/work
72 | jupyter_server==2.12.3
73 | jupyter_server_terminals==0.5.1
74 | jupyterlab==4.0.10
75 | jupyterlab-widgets==3.0.9
76 | jupyterlab_pygments==0.3.0
77 | jupyterlab_server==2.25.2
78 | lightning==2.2.0.post0
79 | lightning-utilities==0.10.1
80 | markdown-it-py==3.0.0
81 | MarkupSafe @ file:///opt/conda/conda-bld/markupsafe_1654597864307/work
82 | matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1660814786464/work
83 | mdurl==0.1.2
84 | mistune==3.0.2
85 | mkl-fft @ file:///croot/mkl_fft_1695058164594/work
86 | mkl-random @ file:///croot/mkl_random_1695059800811/work
87 | mkl-service==2.4.0
88 | mpmath @ file:///croot/mpmath_1690848262763/work
89 | multidict==6.0.4
90 | multiprocess==0.70.14
91 | nbclient==0.9.0
92 | nbconvert==7.14.0
93 | nbformat==5.9.2
94 | nest-asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1697083700168/work
95 | networkx @ file:///croot/networkx_1690561992265/work
96 | nevergrad==1.0.1
97 | nltk==3.8.1
98 | notebook==7.0.6
99 | notebook_shim==0.2.3
100 | numpy @ file:///croot/numpy_and_numpy_base_1695830428084/work/dist/numpy-1.26.0-cp310-cp310-linux_x86_64.whl#sha256=fc2732718bc9e06a7b702492cb4f5afffe9671083930452d894377bf563464a3
101 | nvidia-ml-py==12.535.133
102 | overrides==7.4.0
103 | packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1696202382185/work
104 | pandas==2.1.2
105 | pandocfilters==1.5.0
106 | parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work
107 | pathtools==0.1.2
108 | peft==0.5.0
109 | pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1667297516076/work
110 | pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work
111 | Pillow @ file:///croot/pillow_1696580024257/work
112 | prometheus-client==0.19.0
113 | prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1688565951714/work
114 | protobuf==4.25.0
115 | psutil @ file:///opt/conda/conda-bld/psutil_1656431268089/work
116 | ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
117 | pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work
118 | pyarrow==14.0.0
119 | pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
120 | pyg-lib==0.3.0+pt20cu118
121 | Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1691408637400/work
122 | pyOpenSSL @ file:///croot/pyopenssl_1690223430423/work
123 | pyparsing==3.1.1
124 | PySocks @ file:///home/builder/ci_310/pysocks_1640793678128/work
125 | python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1626286286081/work
126 | python-json-logger==2.0.7
127 | pytorch-lightning==2.2.0.post0
128 | pytz==2023.3.post1
129 | PyYAML==6.0.1
130 | pyzmq @ file:///croot/pyzmq_1686601365461/work
131 | qtconsole==5.5.1
132 | QtPy==2.4.1
133 | rank-bm25==0.2.2
134 | referencing==0.32.1
135 | regex==2023.10.3
136 | requests @ file:///croot/requests_1690400202158/work
137 | responses==0.18.0
138 | rfc3339-validator==0.1.4
139 | rfc3986-validator==0.1.1
140 | rich==13.6.0
141 | rouge-score==0.1.2
142 | rpds-py==0.16.2
143 | safetensors==0.4.2
144 | scikit-learn==1.3.2
145 | scipy==1.11.3
146 | Send2Trash==1.8.2
147 | sentencepiece==0.1.99
148 | sentry-sdk==1.34.0
149 | setproctitle==1.3.3
150 | shtab==1.6.4
151 | six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
152 | smmap==5.0.1
153 | sniffio==1.3.0
154 | soupsieve==2.5
155 | stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work
156 | sympy @ file:///croot/sympy_1668202399572/work
157 | termcolor==2.4.0
158 | terminado==0.18.0
159 | threadpoolctl==3.2.0
160 | timm==0.9.16
161 | tinycss2==1.2.1
162 | tokenizers==0.15.1
163 | tomli==2.0.1
164 | torch==2.0.1
165 | torch-cluster==1.6.3+pt20cu118
166 | torch-scatter==2.1.2+pt20cu118
167 | torch-sparse==0.6.18+pt20cu118
168 | torch-spline-conv==1.2.2+pt20cu118
169 | torch_geometric==2.4.0
170 | torchaudio==2.0.2
171 | torchmetrics==1.3.1
172 | torchvision==0.15.2
173 | tornado==6.4
174 | tqdm==4.66.1
175 | traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1698671135544/work
176 | transformers @ git+https://github.com/huggingface/transformers@5f9685576149fb45a61d0dcec9a260930df0a49a
177 | trim==0.3
178 | triton==2.0.0
179 | trl==0.7.4
180 | types-python-dateutil==2.8.19.20240106
181 | typing_extensions @ file:///croot/typing_extensions_1690297465030/work
182 | tyro==0.5.12
183 | tzdata==2023.3
184 | uri-template==1.3.0
185 | urllib3 @ file:///croot/urllib3_1698257533958/work
186 | wandb==0.15.12
187 | wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1698744702785/work
188 | webcolors==1.13
189 | webencodings==0.5.1
190 | websocket-client==1.7.0
191 | widgetsnbextension==4.0.9
192 | xxhash==3.4.1
193 | yarl==1.9.2
194 |
--------------------------------------------------------------------------------
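Note on reproducing the environment pinned in requirements.txt above: most entries install directly from PyPI with pip install -r requirements.txt. The packages carrying a +pt20cu118 suffix (pyg-lib, torch-cluster, torch-scatter, torch-sparse, torch-spline-conv) are CUDA 11.8 builds that are typically served from the PyTorch Geometric wheel index rather than PyPI, so they may need to be installed separately against torch 2.0.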
/task_LoRA.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import bitsandbytes as bnb
4 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
5 | # from transformers import pipeline, BitsAndBytesConfig
6 | import argparse
7 | from rank_bm25 import BM25Okapi
8 | # from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
9 | import transformers
10 | from utils import split_batch, get_first_k_tokens, print_trainable_parameters, name2taskid
11 | from utils import extract_citation_title, extract_option, extract_movie, extract_news_cat, extract_news_headline, extract_product_review, extract_scholarly_title, extract_tweet_paraphrasing
12 | import json
13 | from tqdm import tqdm
14 |
15 |
16 | parser = argparse.ArgumentParser(description="Parser for LoRA")
17 | parser.add_argument('--model_name', type=str, default='meta-llama/Llama-2-7b-hf')
18 | parser.add_argument('--batch_size', type=int, default=16)
19 | parser.add_argument('--k', type=int, default=0)
20 | parser.add_argument('--max_step', type=int, default=5000)
21 | parser.add_argument('--cut_off', type=int, default=2048)
22 | parser.add_argument('--max_epoch', type=int, default=3)
23 | parser.add_argument('--temperature', type=float, default=0.1)
24 | parser.add_argument('--task_name', type=str, default='movie_tagging')
25 | parser.add_argument('--add_profile', action='store_true')
26 | parser.add_argument('--access_token', type=str, default=None)
27 |
28 | args = parser.parse_args()
29 | model_name = args.model_name
30 | task_name = args.task_name
31 | batch_size = args.batch_size
32 | k = args.k
33 | # max_step = args.max_step
34 | cutoff_len = args.cut_off
35 | add_eos_token = False
36 | max_epoch = args.max_epoch
37 |
38 | # # 4 bit quantization inference
39 | # bnb_config = BitsAndBytesConfig(
40 | # load_in_4bit=True,
41 | # bnb_4bit_quant_type="nf4",
42 | # bnb_4bit_compute_dtype=torch.float16,
43 | # bnb_4bit_use_double_quant=True,
44 | # max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'
45 | # )
46 |
47 | # 8-bit quantization inference
48 | # bnb_config = BitsAndBytesConfig(
49 | # load_in_8bit=True,
50 | # bnb_8bit_quant_type="nf8",
51 | # bnb_8bit_compute_dtype=torch.float16,
52 | # bnb_8bit_use_double_quant=True,
53 | # max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'
54 | # )
55 |
56 | # 16-bit quantization inference
57 | # bnb_config = BitsAndBytesConfig(
58 | # load_in_16bit=True,
59 | # bnb_16bit_quant_type="bf16",
60 | # bnb_16bit_compute_dtype=torch.bfloat16,
61 | # bnb_16bit_use_double_quant=True,
62 | # max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'
63 | # )
64 |
65 | tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", token=args.access_token)
66 | tokenizer.eos_token = ""
67 | tokenizer.pad_token = '[PAD]'
68 | # tokenizer.pad_token = tokenizer.eos_token
69 | tokenizer.pad_token_id = tokenizer.eos_token_id
70 |
71 |
72 | base_model = AutoModelForCausalLM.from_pretrained(
73 | model_name,
74 | # quantization_config=bnb_config,
75 | local_files_only=False,
76 | device_map='auto',
77 | trust_remote_code=True,
78 | torch_dtype=torch.bfloat16
79 | )
80 |
81 | base_model.config.use_cache = False
82 | base_model.config.pad_token_id = tokenizer.pad_token_id
83 | base_model.config.eos_token_id = tokenizer.eos_token_id
84 | base_model.config.bos_token_id = tokenizer.bos_token_id
85 |
86 |
87 | from peft import prepare_model_for_kbit_training
88 |
89 | base_model.gradient_checkpointing_enable()
90 | base_model = prepare_model_for_kbit_training(base_model)
91 |
92 |
93 |
94 | from peft import LoraConfig, get_peft_model
95 |
96 | peft_config = LoraConfig(
97 | r=8,
98 | lora_alpha=8,
99 |     target_modules=["q_proj", "v_proj", "k_proj", "out_proj"],  # NOTE: Llama-2 names its attention output projection "o_proj", so "out_proj" is not matched for that model
100 | lora_dropout=0.05,
101 | bias="none",
102 | task_type="CAUSAL_LM"
103 | )
104 |
105 | training_arguments = transformers.TrainingArguments(
106 | output_dir='outputs/',
107 | per_device_train_batch_size=batch_size,
108 | gradient_accumulation_steps=1,
109 | optim='adamw_torch',
110 | num_train_epochs=max_epoch,
111 | save_steps=1e9,
112 | logging_steps=50,
113 | learning_rate=1e-4,
114 | weight_decay=1e-2,
115 | bf16=True,
116 | max_grad_norm=0.3,
117 | # max_steps=max_step,
118 | warmup_ratio=0.1,
119 | group_by_length=True,
120 | lr_scheduler_type='linear',
121 | report_to='none',
122 | )
123 |
124 |
125 | with open(f"./data/{task_name}/user_others.json", 'r') as f:
126 | train = json.load(f)
127 |
128 |
129 | with open(f"./data/{task_name}/user_top_100_history.json", 'r') as f:
130 | test_data = json.load(f)
131 |
132 | if args.task_name == "movie_tagging":
133 | extract_article = extract_movie
134 | elif args.task_name == "news_categorize":
135 | extract_article = extract_news_cat
136 | elif args.task_name == "news_headline":
137 | extract_article = extract_news_headline
138 | elif args.task_name == "product_rating":
139 | extract_article = extract_product_review
140 | elif args.task_name == "scholarly_title":
141 | extract_article = extract_scholarly_title
142 | elif args.task_name == "tweet_paraphrase":
143 | extract_article = extract_tweet_paraphrasing
144 |
145 |
146 | with open('./prompt/prompt.json', 'r') as f:
147 | prompt_template = json.load(f)
148 |
149 |
150 | if args.add_profile:
151 | with open(f'./data/{task_name}/profile_user_100.json', 'r') as f:
152 | test_profile = json.load(f)
153 |
154 | with open(f'./data/{task_name}/profile_user_others.json', 'r') as f:
155 | train_profile = json.load(f)
156 |
157 |
158 | def tokenize(prompt, add_eos_token=True):
159 | # there's probably a way to do this with the tokenizer settings
160 | # but again, gotta move fast
161 | result = tokenizer(
162 | prompt,
163 | truncation=True,
164 | max_length=cutoff_len,
165 | padding=False,
166 | return_tensors=None,
167 | )
168 | if (
169 | result["input_ids"][-1] != tokenizer.eos_token_id
170 | and len(result["input_ids"]) < cutoff_len
171 | and add_eos_token
172 | ):
173 | result["input_ids"].append(tokenizer.eos_token_id)
174 | result["attention_mask"].append(1)
175 |
176 | result["labels"] = result["input_ids"].copy()
177 |
178 | return result
179 |
180 |
181 | def generate_and_tokenize_prompt(data_point):
182 | full_prompt = data_point['full_prompt']
183 | tokenized_full_prompt = tokenize(full_prompt)
184 | # if not train_on_inputs:
185 | user_prompt = data_point['prompt']
186 |
187 | tokenized_user_prompt = tokenize(
188 | user_prompt, add_eos_token=add_eos_token
189 | )
190 | user_prompt_len = len(tokenized_user_prompt["input_ids"])
191 |
192 | if add_eos_token:
193 | user_prompt_len -= 1
194 |
195 | tokenized_full_prompt["labels"] = [
196 | -100
197 | ] * user_prompt_len + tokenized_full_prompt["labels"][
198 | user_prompt_len:
199 | ] # could be sped up, probably
200 | return tokenized_full_prompt
201 |
202 |
203 |
204 | # training
205 | from datasets import load_dataset, Dataset
206 | model = get_peft_model(base_model, peft_config)
207 | print_trainable_parameters(model)
208 |
209 | pred_all = []
210 | actual = []
211 | train_data = []
212 |
213 | for i in tqdm(range(len(train))):
214 | if args.add_profile:
215 | profile = train_profile[i]['output']
216 |
217 | for idx, q in enumerate(train[i]['query']):
218 |
219 | if args.task_name != "citation":
220 | article = get_first_k_tokens(extract_article(q['input']), 768)
221 | prompt = prompt_template[args.task_name]['prompt'].format(article)
222 | full_prompt = prompt_template[args.task_name]['full_prompt'].format(get_first_k_tokens(extract_article(q['input']), 768), q['gold'])
223 |
224 | else:
225 | question = q['input']
226 | article = extract_citation_title(question)
227 | option1, option2 = extract_option(question, 1), extract_option(question, 2)
228 |
229 | prompt = prompt_template[args.task_name]['prompt'].format(article, option1, option2)
230 | full_prompt = prompt_template[args.task_name]['full_prompt'].format(article, option1, option2, q['gold'])
231 |
232 | if k > 0:
233 | visible_history_list = train[i]['profile']
234 |
235 | for p in visible_history_list:
236 | for key, value in p.items():
237 | p[key] = get_first_k_tokens(p[key], 368)
238 |
239 | history_list = [prompt_template[args.task_name]['retrieval_history'].format(**p) for p in visible_history_list]
240 | tokenized_corpus = [doc.split(" ") for doc in history_list]
241 | bm25 = BM25Okapi(tokenized_corpus)
242 |
243 | tokenized_query = prompt_template[args.task_name]["retrieval_query_wokey"].format(article).split(' ')
244 | retrieved_history = bm25.get_top_n(tokenized_query, history_list, n=args.k)
245 |
246 | history_string = "".join(retrieved_history)
247 | prompt = history_string + "\n" + prompt
248 | full_prompt = history_string + "\n" + full_prompt
249 |
250 | if args.add_profile:
251 | prompt = profile + "\n" + prompt
252 | full_prompt = profile + "\n" + full_prompt
253 |
254 | train_data.append(
255 | {
256 | "prompt": prompt,
257 | "full_prompt": full_prompt
258 | }
259 | )
260 |
261 | print(train_data)
262 |
263 | train_dataset = Dataset.from_list(train_data)
264 | train_dataset = train_dataset.map(generate_and_tokenize_prompt).shuffle()
265 |
266 | trainer = transformers.Trainer(
267 | model=model,
268 | train_dataset=train_dataset,
269 | args=training_arguments,
270 | data_collator=transformers.DataCollatorForSeq2Seq(
271 | tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
272 | ),
273 | )
274 |
275 | for name, module in trainer.model.named_modules():
276 | if "norm" in name:
277 | module = module.to(torch.float32)
278 |
279 |
280 | model.config.use_cache = False # silence the warnings. Please re-enable for inference!
281 | trainer.train()
282 |
283 | if args.add_profile:
284 | output_name = "./ckpt/{}/k{}-{}-{}-profile-task_LoRA_ckpt".format(args.task_name, args.k, args.task_name, model_name.split('/')[-1])
285 | else:
286 | output_name = "./ckpt/{}/k{}-{}-{}-task_LoRA_ckpt".format(args.task_name, args.k, args.task_name, model_name.split('/')[-1])
287 |
288 | model.save_pretrained(output_name)
289 |
290 | model.eval()
291 | model.config.use_cache = True # silence the warnings. Please re-enable for inference!
292 |
293 | for i in tqdm(range(len(test_data))):
294 | if args.add_profile:
295 | profile = test_profile[i]['output']
296 |
297 | if k > 0:
298 | visible_history_list = test_data[i]['profile']
299 | for p in visible_history_list:
300 | for key, value in p.items():
301 | p[key] = get_first_k_tokens(p[key], 368)
302 |
303 | history_list = [prompt_template[args.task_name]['retrieval_history'].format(**p) for p in visible_history_list]
304 |
305 | tokenized_corpus = [doc.split(" ") for doc in history_list]
306 | bm25 = BM25Okapi(tokenized_corpus)
307 |
308 | test_question_list = []
309 | question_id_list = []
310 |
311 | for q in test_data[i]['query']:
312 |
313 | if args.task_name == 'citation':
314 | test_question = q['input']
315 | test_article = extract_citation_title(test_question)
316 | option1, option2 = extract_option(test_question, 1), extract_option(test_question, 2)
317 | test_prompt = prompt_template[args.task_name]['prompt'].format(test_article, option1, option2)
318 |
319 | else:
320 | test_question = q['input']
321 | test_article = extract_article(test_question)
322 | test_prompt = prompt_template[args.task_name]['prompt'].format(test_article)
323 |
324 | if k > 0:
325 | tokenized_query = prompt_template[args.task_name]['retrieval_query_wokey'].format(test_article).split(" ")
326 | retrieved_history = bm25.get_top_n(tokenized_query, history_list, n=args.k)
327 |
328 | history_string = "".join(retrieved_history)
329 | test_prompt = history_string + "\n" + test_prompt
330 |
331 | if args.add_profile:
332 | test_prompt = profile + "\n" + test_prompt
333 |
334 | test_question_list.append(test_prompt)
335 | question_id_list.append(q['id'])
336 |
337 | test_batch_list = split_batch(test_question_list, 1)
338 | out_list = []
339 |
340 | with torch.no_grad():
341 | for batch_idx, batch in tqdm(enumerate(test_batch_list), total=len(test_batch_list)):
342 | # try:
343 | sentences = batch
344 | inputs = tokenizer(sentences, return_tensors="pt", padding=True, return_token_type_ids=False)
345 | inputs = inputs.to(model.device)
346 |
347 | with torch.autocast(device_type="cuda"):
348 | outputs = model.generate(
349 | **inputs,
350 | do_sample=True,
351 | top_k=10,
352 | temperature=args.temperature,
353 | top_p=0.9,
354 | eos_token_id=tokenizer.eos_token_id,
355 | max_new_tokens=200
356 | )
357 |
358 | out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
359 | out_list += out_sentence
360 | # except:
361 | # out_list += ['']
362 |
363 | for i in range(len(out_list)):
364 | output = out_list[i].replace(test_question_list[i], '')
365 | pred_all.append({
366 | "id": question_id_list[i],
367 | "output": output
368 | })
369 |
370 | print(output)
371 |
372 | output_file = {
373 | 'task': name2taskid[args.task_name],
374 | 'golds': pred_all,
375 | 'model': model_name,
376 | }
377 |
378 | if args.add_profile:
379 | with open('./output/{}/output-task-k{}-{}-{}-profile.json'.format(args.task_name, args.k, args.task_name, model_name.split('/')[-1]), 'w') as f:
380 | json.dump(output_file, f, indent=4)
381 | else:
382 | with open('./output/{}/output-task-k{}-{}-{}.json'.format(args.task_name, args.k, args.task_name, model_name.split('/')[-1]), 'w') as f:
383 | json.dump(output_file, f, indent=4)
--------------------------------------------------------------------------------
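For reference, the adapter directory written by model.save_pretrained above can be reloaded on top of the base model for later inference. The snippet below is a minimal sketch rather than part of the repository: the checkpoint path is illustrative (it follows the k{k}-{task}-{model}-task_LoRA_ckpt naming used in the save call) and the prompt string is a placeholder.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

model_name = "meta-llama/Llama-2-7b-hf"
# Illustrative path following the naming pattern used above (k=0, movie_tagging).
adapter_dir = "./ckpt/movie_tagging/k0-movie_tagging-Llama-2-7b-hf-task_LoRA_ckpt"

# A Hugging Face access token may be required for the gated Llama-2 weights.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# Attach the trained task-level LoRA weights on top of the frozen base model.
model = PeftModel.from_pretrained(base_model, adapter_dir)
model.eval()

prompt = "<task prompt built from prompt/prompt.json goes here>"  # placeholder
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=200, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))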
/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | def extract_option(s, num):
4 | # Look for string after [1]: and between "
5 |     match = re.search(r'\[' + str(num) + r'\]: "([^"]*)"', s)
6 | return match.group(1) if match else None
7 |
8 | def extract_citation_title(text):
9 | pattern = r'written the paper with the title "([^"]*)"'
10 | match = re.search(pattern, text)
11 | if match:
12 | return match.group(1)
13 | else:
14 | return None
15 |
16 |
17 | def extract_movie(text):
18 | marker = "] description: "
19 | # Find the position of the marker in the text
20 | marker_pos = text.find(marker)
21 |
22 | # Check if the marker is found
23 | if marker_pos == -1:
24 | raise ValueError()
25 |
26 | # Extract the string after the marker
27 | extracted_string = text[marker_pos + len(marker):]
28 |
29 | return extracted_string
30 |
31 | def extract_news_cat(text):
32 | marker = "] article: "
33 | # Find the position of the marker in the text
34 | marker_pos = text.find(marker)
35 |
36 | # Check if the marker is found
37 | if marker_pos == -1:
38 | raise ValueError()
39 |
40 | # Extract the string after the marker
41 | extracted_string = text[marker_pos + len(marker):]
42 |
43 | return extracted_string
44 |
45 | def extract_news_headline(text):
46 | marker = "Generate a headline for the following article: "
47 | # Find the position of the marker in the text
48 | marker_pos = text.find(marker)
49 |
50 | # Check if the marker is found
51 | if marker_pos == -1:
52 | raise ValueError()
53 |
54 | # Extract the string after the marker
55 | extracted_string = text[marker_pos + len(marker):]
56 |
57 | return extracted_string
58 |
59 | def extract_product_review(text):
60 | marker = "without further explanation. review: "
61 | # Find the position of the marker in the text
62 | marker_pos = text.find(marker)
63 |
64 | # Check if the marker is found
65 | if marker_pos == -1:
66 | raise ValueError()
67 |
68 | # Extract the string after the marker
69 | extracted_string = text[marker_pos + len(marker):]
70 |
71 | return extracted_string
72 |
73 |
74 | def extract_scholarly_title(text):
75 | marker = "Generate a title for the following abstract of a paper: "
76 | # Find the position of the marker in the text
77 | marker_pos = text.find(marker)
78 |
79 | # Check if the marker is found
80 | if marker_pos == -1:
81 | raise ValueError()
82 |
83 | # Extract the string after the marker
84 | extracted_string = text[marker_pos + len(marker):]
85 |
86 | return extracted_string
87 |
88 |
89 | def extract_tweet_paraphrasing(text):
90 | marker = "Paraphrase the following tweet without any explanation before or after it: "
91 | # Find the position of the marker in the text
92 | marker_pos = text.find(marker)
93 |
94 | # Check if the marker is found
95 | if marker_pos == -1:
96 | raise ValueError()
97 |
98 | # Extract the string after the marker
99 | extracted_string = text[marker_pos + len(marker):]
100 |
101 | return extracted_string
102 |
103 | def get_first_k_tokens(text, k):
104 | """
105 | Extracts the first k tokens from a text string.
106 |
107 | :param text: The input text string.
108 | :param k: The number of tokens to extract.
109 | :return: The first k tokens of the text string.
110 | """
111 | # Split the text into tokens based on whitespace
112 | tokens = text.split()
113 | output = " ".join(tokens[:k])
114 |
115 | # Return the first k tokens
116 | return output
117 |
118 | def split_batch(init_list, batch_size):
119 | groups = zip(*(iter(init_list),) * batch_size)
120 | end_list = [list(i) for i in groups]
121 | count = len(init_list) % batch_size
122 | if count != 0: end_list.append(init_list[-count:])
123 | return end_list
124 |
125 | def tokenize(prompt, add_eos_token=True):
126 | # NOTE: relies on a module-level `tokenizer` and `cutoff_len`, which are not defined
127 | # in utils.py; task_LoRA.py defines its own copy of this helper with those in scope.
128 | result = tokenizer(
129 | prompt,
130 | truncation=True,
131 | max_length=cutoff_len,
132 | padding=False,
133 | return_tensors=None,
134 | )
135 | if (
136 | result["input_ids"][-1] != tokenizer.eos_token_id
137 | and len(result["input_ids"]) < cutoff_len
138 | and add_eos_token
139 | ):
140 | result["input_ids"].append(tokenizer.eos_token_id)
141 | result["attention_mask"].append(1)
142 |
143 | result["labels"] = result["input_ids"].copy()
144 |
145 | return result
146 |
147 |
148 | def print_trainable_parameters(model):
149 | """
150 | Prints the number of trainable parameters in the model.
151 | """
152 | trainable_params = 0
153 | all_param = 0
154 | for _, param in model.named_parameters():
155 | all_param += param.numel()
156 | if param.requires_grad:
157 | trainable_params += param.numel()
158 | print(
159 | f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
160 | )
161 |
162 |
163 |
164 |
165 | name2taskid = {
166 | "citation": "LaMP_1",
167 | "movie_tagging": "LaMP_2M",
168 | "news_categorize": "LaMP_2N",
169 | "news_headline": "LaMP_4",
170 | "product_rating": "LaMP_3",
171 | "scholarly_title": "LaMP_5",
172 | "tweet_paraphrase": "LaMP_7"
173 | }
--------------------------------------------------------------------------------
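As a quick illustration (not part of the repository), the batching and truncation helpers above behave as follows: split_batch keeps the remainder as a final, smaller batch, and get_first_k_tokens truncates on whitespace-separated tokens rather than model tokens.

from utils import split_batch, get_first_k_tokens

# The remainder that does not fill a whole batch becomes the last, smaller batch.
print(split_batch([1, 2, 3, 4, 5], 2))                    # [[1, 2], [3, 4], [5]]

# Truncation counts whitespace-separated tokens, not tokenizer tokens.
print(get_first_k_tokens("one two three four five", 3))   # 'one two three'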