├── .DS_Store ├── OPPU.py ├── README.md ├── __pycache__ └── utils.cpython-310.pyc ├── asset ├── overview.png └── teaser.png ├── data └── .DS_Store ├── eval ├── __init__.py ├── __pycache__ │ ├── evaluation.cpython-310.pyc │ └── evaluation.cpython-39.pyc ├── eval_all.py ├── eval_task.py └── evaluation.py ├── gen_profile.py ├── prompt ├── prompt.json └── prompt_profile.json ├── requirements.txt ├── task_LoRA.py └── utils.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TamSiuhin/OPPU/9fb3f465fab408da34915e60db12cc9cd43c2223/.DS_Store -------------------------------------------------------------------------------- /OPPU.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import bitsandbytes as bnb 4 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM 5 | # from transformers import pipeline, BitsAndBytesConfig 6 | import argparse 7 | from rank_bm25 import BM25Okapi 8 | # from trl import SFTTrainer, DataCollatorForCompletionOnlyLM 9 | import transformers 10 | from utils import split_batch, get_first_k_tokens, print_trainable_parameters, name2taskid 11 | from utils import extract_citation_title, extract_option, extract_movie, extract_news_cat, extract_news_headline, extract_product_review, extract_scholarly_title, extract_tweet_paraphrasing 12 | import json 13 | from tqdm import tqdm 14 | from peft import LoraConfig, get_peft_model, PeftModel 15 | 16 | 17 | parser = argparse.ArgumentParser(description="Parser for LoRA") 18 | parser.add_argument('--model_name', type=str, default='meta-llama/Llama-2-7b-hf') 19 | parser.add_argument('--batch_size', type=int, default=16) 20 | parser.add_argument('--k', type=int, default=0) 21 | parser.add_argument('--max_step', type=int, default=5000) 22 | parser.add_argument('--cut_off', type=int, default=2048) 23 | parser.add_argument('--max_epoch', type=int, default=2) 24 | parser.add_argument('--temperature', type=float, default=0.1) 25 | parser.add_argument('--task_name', type=str, default='movie_tagging') 26 | parser.add_argument('--add_profile', action='store_true') 27 | parser.add_argument('--task_lora', type=str, default='./ckpt/movie_tagging/k1-movie_tagging-Llama-2-7b-hf-task_LoRA_ckpt') 28 | parser.add_argument('--access_token', type=str, default=None) 29 | 30 | args = parser.parse_args() 31 | model_name = args.model_name 32 | task_name = args.task_name 33 | batch_size = args.batch_size 34 | k = args.k 35 | # max_step = args.max_step 36 | cutoff_len = args.cut_off 37 | add_eos_token = False 38 | max_epoch = args.max_epoch 39 | 40 | # # 4 bit quantization inference 41 | # bnb_config = BitsAndBytesConfig( 42 | # load_in_4bit=True, 43 | # bnb_4bit_quant_type="nf4", 44 | # bnb_4bit_compute_dtype=torch.float16, 45 | # bnb_4bit_use_double_quant=True, 46 | # max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB' 47 | # ) 48 | 49 | # 8-bit quantization inference 50 | # bnb_config = BitsAndBytesConfig( 51 | # load_in_8bit=True, 52 | # bnb_8bit_quant_type="nf8", 53 | # bnb_8bit_compute_dtype=torch.float16, 54 | # bnb_8bit_use_double_quant=True, 55 | # max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB' 56 | # ) 57 | 58 | # 16-bit quantization inference 59 | # bnb_config = BitsAndBytesConfig( 60 | # load_in_16bit=True, 61 | # bnb_16bit_quant_type="bf16", 62 | # bnb_16bit_compute_dtype=torch.bfloat16, 63 | # bnb_16bit_use_double_quant=True, 64 | # 
max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB' 65 | # ) 66 | 67 | tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", token=args.access_token) 68 | tokenizer.eos_token = "" 69 | tokenizer.pad_token = '[PAD]' 70 | # tokenizer.pad_token = tokenizer.eos_token 71 | tokenizer.pad_token_id = tokenizer.eos_token_id 72 | 73 | 74 | base_model = AutoModelForCausalLM.from_pretrained( 75 | model_name, 76 | # quantization_config=bnb_config, 77 | local_files_only=False, 78 | device_map='auto', 79 | trust_remote_code=True, 80 | torch_dtype=torch.bfloat16 81 | ) 82 | 83 | base_model.config.use_cache = False 84 | base_model.config.pad_token_id = tokenizer.pad_token_id 85 | base_model.config.eos_token_id = tokenizer.eos_token_id 86 | base_model.config.bos_token_id = tokenizer.bos_token_id 87 | 88 | 89 | from peft import prepare_model_for_kbit_training 90 | 91 | base_model.gradient_checkpointing_enable() 92 | base_model = prepare_model_for_kbit_training(base_model) 93 | 94 | 95 | 96 | from peft import LoraConfig, get_peft_model 97 | 98 | peft_config = LoraConfig( 99 | r=8, 100 | lora_alpha=8, 101 | target_modules=["q_proj", "v_proj"], # , "k_proj", "out_proj" 102 | lora_dropout=0.05, 103 | bias="none", 104 | task_type="CAUSAL_LM", 105 | ) 106 | 107 | training_arguments = transformers.TrainingArguments( 108 | output_dir='outputs/', 109 | per_device_train_batch_size=batch_size, 110 | gradient_accumulation_steps=1, 111 | optim='adamw_torch', 112 | num_train_epochs=max_epoch, 113 | save_steps=1e9, 114 | logging_steps=50, 115 | learning_rate=1e-4, 116 | weight_decay=1e-2, 117 | bf16=True, 118 | max_grad_norm=0.3, 119 | # max_steps=max_step, 120 | warmup_ratio=0.1, 121 | group_by_length=True, 122 | lr_scheduler_type='linear', 123 | report_to='none', 124 | ) 125 | 126 | 127 | with open(f"./data/{task_name}/user_top_100_history.json", 'r') as f: 128 | test_data = json.load(f) 129 | 130 | format_flag = False 131 | if args.task_name == "movie_tagging": 132 | extract_article = extract_movie 133 | format_flag = True 134 | elif args.task_name == "news_categorize": 135 | extract_article = extract_news_cat 136 | format_flag = True 137 | elif args.task_name == "news_headline": 138 | extract_article = extract_news_headline 139 | format_flag = True 140 | elif args.task_name == "product_rating": 141 | extract_article = extract_product_review 142 | format_flag = True 143 | elif args.task_name == "scholarly_title": 144 | extract_article = extract_scholarly_title 145 | format_flag = True 146 | elif args.task_name == "tweet_paraphrase": 147 | extract_article = extract_tweet_paraphrasing 148 | 149 | 150 | with open('./prompt/prompt.json', 'r') as f: 151 | prompt_template = json.load(f) 152 | 153 | 154 | if args.add_profile: 155 | with open(f'./data/{task_name}/profile_user_100.json', 'r') as f: 156 | test_profile = json.load(f) 157 | 158 | 159 | def tokenize(prompt, add_eos_token=True): 160 | # there's probably a way to do this with the tokenizer settings 161 | # but again, gotta move fast 162 | result = tokenizer( 163 | prompt, 164 | truncation=True, 165 | max_length=cutoff_len, 166 | padding=False, 167 | return_tensors=None, 168 | ) 169 | if ( 170 | result["input_ids"][-1] != tokenizer.eos_token_id 171 | and len(result["input_ids"]) < cutoff_len 172 | and add_eos_token 173 | ): 174 | result["input_ids"].append(tokenizer.eos_token_id) 175 | result["attention_mask"].append(1) 176 | 177 | result["labels"] = result["input_ids"].copy() 178 | 179 | return result 180 | 181 | 182 | def 
generate_and_tokenize_prompt(data_point): 183 | full_prompt = data_point['full_prompt'] 184 | tokenized_full_prompt = tokenize(full_prompt) 185 | # if not train_on_inputs: 186 | user_prompt = data_point['prompt'] 187 | 188 | tokenized_user_prompt = tokenize( 189 | user_prompt, add_eos_token=add_eos_token 190 | ) 191 | user_prompt_len = len(tokenized_user_prompt["input_ids"]) 192 | 193 | if add_eos_token: 194 | user_prompt_len -= 1 195 | 196 | tokenized_full_prompt["labels"] = [ 197 | -100 198 | ] * user_prompt_len + tokenized_full_prompt["labels"][ 199 | user_prompt_len: 200 | ] # could be sped up, probably 201 | return tokenized_full_prompt 202 | 203 | 204 | 205 | # training 206 | from datasets import load_dataset, Dataset 207 | model = PeftModel.from_pretrained(model=base_model, model_id=args.task_lora, is_trainable=False) 208 | base_model = model.merge_and_unload() 209 | print_trainable_parameters(model) 210 | 211 | 212 | pred_all = [] 213 | actual = [] 214 | train_data = [] 215 | 216 | for i in tqdm(range(len(test_data))): 217 | model = get_peft_model(base_model, peft_config) 218 | print_trainable_parameters(model) 219 | 220 | if args.add_profile: 221 | profile = test_profile[i]['output'] 222 | 223 | for idx, q in enumerate(test_data[i]['profile']): 224 | for key, value in q.items(): 225 | q[key] = get_first_k_tokens(q[key], 768) 226 | 227 | prompt = prompt_template[args.task_name]['OPPU_input'].format(**q) 228 | full_prompt = prompt_template[args.task_name]['OPPU_full'].format(**q) 229 | 230 | if k > 0 and idx != 0 and format_flag==True: 231 | visible_history_list = test_data[i]['profile'][:idx] 232 | 233 | for p in visible_history_list: 234 | for key, value in p.items(): 235 | p[key] = get_first_k_tokens(p[key], 768) 236 | 237 | history_list = [prompt_template[args.task_name]['retrieval_history'].format(**p) for p in visible_history_list] 238 | tokenized_corpus = [doc.split(" ") for doc in history_list] 239 | bm25 = BM25Okapi(tokenized_corpus) 240 | 241 | tokenized_query = prompt_template[args.task_name]["retrieval_query"].format(**q).split(' ') 242 | retrieved_history = bm25.get_top_n(tokenized_query, history_list, n=args.k) 243 | 244 | history_string = "".join(retrieved_history) 245 | prompt = history_string + "\n" + prompt 246 | full_prompt = history_string + "\n" + full_prompt 247 | 248 | if args.add_profile and format_flag == True: 249 | prompt = profile + "\n" + prompt 250 | full_prompt = profile + "\n" + full_prompt 251 | 252 | train_data.append( 253 | { 254 | "prompt": prompt, 255 | "full_prompt": full_prompt 256 | } 257 | ) 258 | 259 | # print(train_data) 260 | 261 | train_dataset = Dataset.from_list(train_data) 262 | train_dataset = train_dataset.map(generate_and_tokenize_prompt).shuffle() 263 | 264 | trainer = transformers.Trainer( 265 | model=model, 266 | train_dataset=train_dataset, 267 | args=training_arguments, 268 | data_collator=transformers.DataCollatorForSeq2Seq( 269 | tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True 270 | ), 271 | ) 272 | 273 | for name, module in trainer.model.named_modules(): 274 | if "norm" in name: 275 | module = module.to(torch.float32) 276 | 277 | 278 | model.config.use_cache = False # silence the warnings. Please re-enable for inference! 279 | trainer.train() 280 | 281 | output_name = "./ckpt/{}/k{}-{}-{}-OPPU_LoRA-{}".format(args.task_name, args.k, args.task_name, model_name.split('/')[-1], i) 282 | model.save_pretrained(output_name) 283 | 284 | model.eval() 285 | model.config.use_cache = True # silence the warnings. 
Please re-enable for inference! 286 | 287 | # test inference 288 | if args.add_profile: 289 | profile = test_profile[i]['output'] 290 | 291 | if k > 0: 292 | visible_history_list = test_data[i]['profile'] 293 | for p in visible_history_list: 294 | for key, value in p.items(): 295 | p[key] = get_first_k_tokens(p[key], 368) 296 | 297 | history_list = [prompt_template[args.task_name]['retrieval_history'].format(**p) for p in visible_history_list] 298 | 299 | tokenized_corpus = [doc.split(" ") for doc in history_list] 300 | bm25 = BM25Okapi(tokenized_corpus) 301 | 302 | test_question_list = [] 303 | question_id_list = [] 304 | 305 | for q in test_data[i]['query']: 306 | 307 | if args.task_name == 'citation': 308 | test_question = q['input'] 309 | test_article = extract_citation_title(test_question) 310 | option1, option2 = extract_option(test_question, 1), extract_option(test_question, 2) 311 | test_prompt = prompt_template[args.task_name]['prompt'].format(test_article, option1, option2) 312 | 313 | else: 314 | test_question = q['input'] 315 | test_article = extract_article(test_question) 316 | test_prompt = prompt_template[args.task_name]['prompt'].format(test_article) 317 | 318 | if k > 0: 319 | tokenized_query = prompt_template[args.task_name]['retrieval_query_wokey'].format(test_article).split(" ") 320 | retrieved_history = bm25.get_top_n(tokenized_query, history_list, n=args.k) 321 | 322 | history_string = "".join(retrieved_history) 323 | test_prompt = history_string + "\n" + test_prompt 324 | 325 | if args.add_profile: 326 | test_prompt = profile + "\n" + test_prompt 327 | 328 | test_question_list.append(test_prompt) 329 | question_id_list.append(q['id']) 330 | 331 | test_batch_list = split_batch(test_question_list, 1) 332 | out_list = [] 333 | 334 | with torch.no_grad(): 335 | for batch_idx, batch in tqdm(enumerate(test_batch_list), total=len(test_batch_list)): 336 | # try: 337 | sentences = batch 338 | inputs = tokenizer(sentences, return_tensors="pt", padding=True, return_token_type_ids=False) 339 | inputs = inputs.to(model.device) 340 | 341 | with torch.autocast(device_type="cuda"): 342 | outputs = model.generate( 343 | **inputs, 344 | do_sample=True, 345 | top_k=10, 346 | temperature=args.temperature, 347 | top_p=0.9, 348 | eos_token_id=tokenizer.eos_token_id, 349 | max_new_tokens=200 350 | ) 351 | 352 | out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True) 353 | out_list += out_sentence 354 | # except: 355 | # out_list += [''] 356 | 357 | for i in range(len(out_list)): 358 | output = out_list[i].replace(test_question_list[i], '') 359 | pred_all.append({ 360 | "id": question_id_list[i], 361 | "output": output 362 | }) 363 | 364 | print(output) 365 | 366 | output_file = { 367 | 'task': name2taskid[args.task_name], 368 | 'golds': pred_all, 369 | 'model': model_name, 370 | } 371 | 372 | if args.add_profile: 373 | with open('./output/{}/output-OPPU-k{}-{}-{}-profile.json'.format(args.k, args.task_name, args.task_name, model_name.split('/')[-1]), 'w') as f: 374 | json.dump(output_file, f, indent=4) 375 | else: 376 | with open('./output/{}/output-OPPU-k{}-{}-{}.json'.format(args.k, args.task_name, args.task_name, model_name.split('/')[-1]), 'w') as f: 377 | json.dump(output_file, f, indent=4) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Democratizing Large Language Models via Personalized Parameter-Efficient Fine-tuning 2 | 3 | 4 | This 
is the source code of our EMNLP 2024 paper 5 | 6 | [**Democratizing Large Language Models via Personalized Parameter-Efficient Fine-tuning**](https://arxiv.org/abs/2402.04401). 7 | 8 | by 9 | [Zhaoxuan Tan](https://zhaoxuan.info/), 10 | [Qingkai Zeng](https://qingkaizeng.github.io/), 11 | [Yijun Tian](http://tianyijun.com/), 12 | [Zheyuan Liu](https://franciscoliu.github.io/), 13 | [Bing Yin](https://scholar.google.com/citations?user=qSOxydEAAAAJ&hl=en), 14 | [Meng Jiang](http://www.meng-jiang.com/). 15 | 16 | 17 | 18 | ## Overview ## 19 | 20 | * **Ownership**: Existing methods are centralized, with user history encoded in a personalized prompt and processed by a centralized LLM. This paradigm limits the model's customization and ability to provide deep, personalized experiences tailored to individual users. Moreover, when using a centralized model, users often have to share personal data with the service provider, which raises concerns about how user data are stored, used, and protected. 21 | 22 | * **Behavior Pattern Generalization**: As existing research has revealed, LLMs can be easily distracted by irrelevant context information that retrieval can hardly avoid. In LLM personalization, where the retrieval corpus is confined to a specific user's behaviors, retrieval augmentation might underperform, especially when the user's past behaviors do not closely mirror the patterns needed for the query at hand. 23 | 24 |
25 | 26 |
27 | 28 | Personalization in large language models (LLMs) is increasingly important, aiming to align the LLMs' interactions, content, and recommendations with individual user preferences. Recent advances have highlighted effective prompt design by enriching user queries with non-parametric knowledge through behavior history retrieval and textual profiles. However, these methods faced limitations due to a lack of model ownership, resulting in constrained customization and privacy issues, and often failed to capture complex, dynamic user behavior patterns. To address these shortcomings, we introduce One PEFT Per User (OPPU), employing personalized parameter-efficient fine-tuning (PEFT) modules to store user-specific behavior patterns and preferences. By plugging in personal PEFT parameters, users can own and use their LLMs individually. OPPU integrates parametric user knowledge in the personal PEFT parameters with non-parametric knowledge from retrieval and profiles, adapting LLMs to user behavior shifts. Experimental results demonstrate that OPPU significantly outperforms existing prompt-based methods across seven diverse tasks in the LaMP benchmark. Further studies reveal OPPU's enhanced capabilities in handling user behavior shifts, modeling users at different activity levels, maintaining robustness across various user history formats, and displaying versatility with different PEFT methods. 29 | 30 |
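Concretely, OPPU is a two-stage recipe: a shared task-level LoRA is first merged into the base model, and then one lightweight LoRA adapter is trained per user on that user's own history (see `OPPU.py`). Below is a minimal sketch of this flow, simplified from `OPPU.py` and using the example checkpoint paths from the commands further down; training and retrieval details are omitted.

```python
# Minimal sketch of the OPPU flow (simplified from OPPU.py; paths follow the repo's movie_tagging example).
import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel, LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto"
)

# Stage 1: merge the shared, task-level LoRA into the base weights.
base = PeftModel.from_pretrained(
    base, "./ckpt/movie_tagging/k0-movie_tagging-Llama-2-7b-hf-task_LoRA_ckpt"
).merge_and_unload()

# Stage 2: attach a fresh LoRA adapter for one user and fine-tune it on that user's history only.
lora_cfg = LoraConfig(r=8, lora_alpha=8, target_modules=["q_proj", "v_proj"],
                      lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")
user_model = get_peft_model(base, lora_cfg)
# ... train user_model on the user's profile, then save the personal adapter, e.g.:
# user_model.save_pretrained("./ckpt/movie_tagging/k0-movie_tagging-Llama-2-7b-hf-OPPU_LoRA-0")
```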
31 | 32 |
33 | 34 | ## Dataset ## 35 | 36 | We use publicly available data from the [LaMP](https://arxiv.org/abs/2304.11406) benchmark. You can download our processed data [here](https://drive.google.com/file/d/1bJ3Rh_sqrw3suwwweFbra5CTV7GVjgxF/view?usp=sharing), unzip it, and place it under the ```./data``` folder. 37 | 38 | 39 | ## Installation ## 40 | Please install the dependencies via pip, using the following command: 41 | 42 | ```bash 43 | pip install -r requirements.txt 44 | ``` 45 | 46 | ## Experiment ## 47 | ```task_name``` can be selected from ```[citation, movie_tagging, news_categorize, news_headline, product_rating, scholarly_title, tweet_paraphrase]```. Here, we take ```movie_tagging``` as an example. 48 | 49 | ### OPPU 50 | #### 1. Base LLM Task Adaptation 51 | 52 | ```bash 53 | CUDA_VISIBLE_DEVICES=0 python task_LoRA.py --k 0 --task_name movie_tagging 54 | ``` 55 | 56 | #### 2. Train One PEFT Per User 57 | ```bash 58 | CUDA_VISIBLE_DEVICES=0 python OPPU.py --k 0 --task_name movie_tagging --task_lora ./ckpt/movie_tagging/k0-movie_tagging-Llama-2-7b-hf-task_LoRA_ckpt 59 | ``` 60 | 61 | ### OPPU + RAG 62 | 63 | #### 1. Base LLM Task Adaptation 64 | 65 | ```bash 66 | CUDA_VISIBLE_DEVICES=0 python task_LoRA.py --k 1 --task_name movie_tagging 67 | ``` 68 | 69 | #### 2. Train One PEFT Per User 70 | ```bash 71 | CUDA_VISIBLE_DEVICES=0 python OPPU.py --k 1 --task_name movie_tagging --task_lora ./ckpt/movie_tagging/k1-movie_tagging-Llama-2-7b-hf-task_LoRA_ckpt 72 | ``` 73 | ---- 74 | 75 | ### OPPU + PAG 76 | #### 1. Base LLM Task Adaptation 77 | 78 | ```bash 79 | CUDA_VISIBLE_DEVICES=0 python task_LoRA.py --k 1 --task_name movie_tagging --add_profile 80 | ``` 81 | 82 | #### 2. Train One PEFT Per User 83 | ```bash 84 | CUDA_VISIBLE_DEVICES=0 python OPPU.py --k 1 --task_name movie_tagging --task_lora ./ckpt/movie_tagging/k1-movie_tagging-Llama-2-7b-hf-profile-task_LoRA_ckpt --add_profile 85 | ``` 86 | 87 | ## Evaluation ## 88 | ```TASK_ID``` is the corresponding ID selected from ```["LaMP_1", "LaMP_2N", "LaMP_2M", "LaMP_3", "LaMP_4", "LaMP_5", "LaMP_7"]```. 89 | 90 | ```bash 91 | python ./eval/eval_task.py \ 92 | --golds_json {PATH_TO_LABEL_JSON_FILE} \ 93 | --preds_json {PATH_TO_PREDICTION_JSON_FILE} \ 94 | --task_name {TASK_ID} \ 95 | --output_file {RESULT_JSON_PATH} 96 | ``` 97 | 98 | ## Citation ## 99 | If you find this paper or codebase useful in your research, please kindly cite the following paper. 
100 | 101 | ```bibtex 102 | @article{tan2024democratizing, 103 | title={Democratizing Large Language Models via Personalized Parameter-Efficient Fine-tuning}, 104 | author={Tan, Zhaoxuan and Zeng, Qingkai and Tian, Yijun and Liu, Zheyuan and Yin, Bing and Jiang, Meng}, 105 | journal={arXiv preprint arXiv:2402.04401}, 106 | year={2024} 107 | } 108 | ``` 109 | -------------------------------------------------------------------------------- /__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TamSiuhin/OPPU/9fb3f465fab408da34915e60db12cc9cd43c2223/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /asset/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TamSiuhin/OPPU/9fb3f465fab408da34915e60db12cc9cd43c2223/asset/overview.png -------------------------------------------------------------------------------- /asset/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TamSiuhin/OPPU/9fb3f465fab408da34915e60db12cc9cd43c2223/asset/teaser.png -------------------------------------------------------------------------------- /data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TamSiuhin/OPPU/9fb3f465fab408da34915e60db12cc9cd43c2223/data/.DS_Store -------------------------------------------------------------------------------- /eval/__init__.py: -------------------------------------------------------------------------------- 1 | from evaluation import LaMPEvaluation -------------------------------------------------------------------------------- /eval/__pycache__/evaluation.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TamSiuhin/OPPU/9fb3f465fab408da34915e60db12cc9cd43c2223/eval/__pycache__/evaluation.cpython-310.pyc -------------------------------------------------------------------------------- /eval/__pycache__/evaluation.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TamSiuhin/OPPU/9fb3f465fab408da34915e60db12cc9cd43c2223/eval/__pycache__/evaluation.cpython-39.pyc -------------------------------------------------------------------------------- /eval/eval_all.py: -------------------------------------------------------------------------------- 1 | from evaluation import LaMPEvaluation 2 | import argparse 3 | import json 4 | 5 | parser = argparse.ArgumentParser() 6 | 7 | parser.add_argument("--golds_zip", required=True, help="Address to all gold labels for all tasks zipped in a file") 8 | parser.add_argument("--preds_zip", required=True, help="Address to all predictions for all tasks zipped in a file") 9 | parser.add_argument("--temp_dir", required=False, help="Address to a temp dir for extracting files", default="./tmp") 10 | parser.add_argument("--output_file", required=True, help="Address to the results file") 11 | 12 | if __name__ == "__main__": 13 | 14 | opts = parser.parse_args() 15 | 16 | evaluator = LaMPEvaluation(all_golds_zip_file_addr=opts.golds_zip, extract_addr=opts.temp_dir) 17 | results = evaluator.evaluate_all(opts.preds_zip) 18 | with open(opts.output_file, "w") as file: 19 | json.dump(results, file) 20 | 
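# Example invocation (hypothetical paths): each zip should contain one JSON file per task, with 'task' and 'golds' keys, covering all seven LaMP tasks.
#   python eval/eval_all.py --golds_zip ./data/all_golds.zip --preds_zip ./output/all_preds.zip --output_file ./results_all.json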
-------------------------------------------------------------------------------- /eval/eval_task.py: -------------------------------------------------------------------------------- 1 | from evaluation import LaMPEvaluation 2 | import argparse 3 | import json 4 | import warnings 5 | 6 | warnings.filterwarnings('ignore') 7 | 8 | parser = argparse.ArgumentParser() 9 | 10 | parser.add_argument("--golds_json", required=True, help="Address to all gold labels for the task as a json file") 11 | parser.add_argument("--preds_json", required=True, help="Address to all predictions for the task as a json file") 12 | parser.add_argument("--task_name", required=True, help="[LaMP_1, LaMP_2, LaMP_3, LaMP_4, LaMP_5, LaMP_6, LaMP_7]") 13 | parser.add_argument("--output_file", required=True, help="Address to the results file") 14 | 15 | if __name__ == "__main__": 16 | 17 | opts = parser.parse_args() 18 | 19 | evaluator = LaMPEvaluation(single_gold_json_file_addr=opts.golds_json) 20 | results = evaluator.evaluate_task(opts.preds_json, opts.task_name) 21 | with open(opts.output_file, "w") as file: 22 | json.dump(results, file) 23 | -------------------------------------------------------------------------------- /eval/evaluation.py: -------------------------------------------------------------------------------- 1 | import json 2 | import zipfile 3 | import glob 4 | import os 5 | import shutil 6 | 7 | import os 8 | 9 | os.environ['HF_EVALUATE_OFFLINE'] = '1' 10 | 11 | import evaluate 12 | 13 | def postprocess_text_classification(preds, labels): 14 | preds = [str(pred).strip() for pred in preds] 15 | labels = [str(label).strip() for label in labels] 16 | return preds, labels 17 | 18 | def postprocess_text_generation(preds, labels): 19 | preds = [pred.strip() for pred in preds] 20 | labels = [[label.strip()] for label in labels] 21 | 22 | return preds, labels 23 | 24 | def create_metric_f1_accuracy(all_labels): 25 | f1_metric = evaluate.load("f1") 26 | accuracy_metric = evaluate.load("accuracy") 27 | def create_mapping(x): 28 | try: 29 | return all_labels.index(x) 30 | except: 31 | return -1 32 | def compute_metrics(decoded_preds, decoded_labels): 33 | decoded_preds, decoded_labels = postprocess_text_classification(decoded_preds, decoded_labels) 34 | decoded_preds = [create_mapping(x) for x in decoded_preds] 35 | decoded_labels = [create_mapping(x) for x in decoded_labels] 36 | result_acc = accuracy_metric.compute(predictions=decoded_preds, references=decoded_labels) 37 | result_f1 = f1_metric.compute(predictions=decoded_preds, references=decoded_labels, labels=list(range(len(all_labels))), average = "macro") 38 | result = {"accuracy" : result_acc["accuracy"], "f1" : result_f1["f1"]} 39 | return result 40 | return compute_metrics 41 | 42 | def create_metric_mae_rmse(): 43 | mse_metric = evaluate.load("mse") 44 | mae_metric = evaluate.load("mae") 45 | def create_mapping(x, y): 46 | try: 47 | return float(x) 48 | except: 49 | print(x) 50 | y = float(y) 51 | if abs(1 - y) > abs(5 - y): 52 | return 1.0 53 | else: 54 | return 5.0 55 | def compute_metrics(decoded_preds, decoded_labels): 56 | decoded_preds, decoded_labels = postprocess_text_classification(decoded_preds, decoded_labels) 57 | decoded_preds = [create_mapping(x,y) for x,y in zip(decoded_preds, decoded_labels)] 58 | decoded_labels = [create_mapping(x,x) for x in decoded_labels] 59 | result_mae = mae_metric.compute(predictions=decoded_preds, references=decoded_labels) 60 | result_rmse = mse_metric.compute(predictions=decoded_preds, references=decoded_labels, 
squared = False) 61 | result = {"MAE" : result_mae["mae"], "RMSE" : result_rmse["mse"]} 62 | return result 63 | return compute_metrics 64 | 65 | def create_metric_rouge(): 66 | rouge_metric = evaluate.load('rouge') 67 | def compute_metrics(decoded_preds, decoded_labels): 68 | decoded_preds, decoded_labels = postprocess_text_generation(decoded_preds, decoded_labels) 69 | result_rouge = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels) 70 | result = {"rouge-1" : result_rouge["rouge1"], "rouge-L" : result_rouge["rougeL"]} 71 | return result 72 | return compute_metrics 73 | 74 | class LaMPEvaluation(object): 75 | 76 | def __init__(self, all_golds_zip_file_addr = None, single_gold_json_file_addr = None, extract_addr = "./tmp") -> None: 77 | assert all_golds_zip_file_addr or single_gold_json_file_addr, "The golds should be provided for all datasets or at least one." 78 | assert not (all_golds_zip_file_addr and single_gold_json_file_addr), "The golds should be provided using zip file or json file not both." 79 | self.tasks_golds = dict() 80 | self.extract_addr = extract_addr 81 | self.evaluate_all_is_possible = False 82 | if all_golds_zip_file_addr: 83 | os.makedirs(self.extract_addr, exist_ok=True) 84 | with zipfile.ZipFile(all_golds_zip_file_addr, 'r') as zobj: 85 | zobj.extractall(path = extract_addr) 86 | for file_addr in glob.glob(os.path.join(self.extract_addr, "**/*.json"), recursive=True): 87 | with open(file_addr) as file: 88 | task = json.load(file) 89 | self.tasks_golds[task['task']] = task['golds'] 90 | self._empty_dir(self.extract_addr) 91 | self.evaluate_all_is_possible = True 92 | if single_gold_json_file_addr: 93 | with open(single_gold_json_file_addr) as file: 94 | task = json.load(file) 95 | self.tasks_golds[task['task']] = task['golds'] 96 | 97 | def _empty_dir(self, directory_path): 98 | for filename in os.listdir(directory_path): 99 | file_path = os.path.join(directory_path, filename) 100 | try: 101 | if os.path.isfile(file_path): 102 | os.unlink(file_path) 103 | elif os.path.isdir(file_path): 104 | shutil.rmtree(file_path) 105 | except Exception as e: 106 | print(f'Failed to delete {file_path}. Reason: {e}') 107 | 108 | def _get_all_gold_ids(self, task_name): 109 | return set([sample['id'] for sample in self.tasks_golds[task_name]]) 110 | 111 | def _get_all_ids(self, input): 112 | return set([sample['id'] for sample in input]) 113 | 114 | def evaluate_all(self, predicts_zipfile_addr): 115 | assert self.evaluate_all_is_possible, "You did not provide golds for all tasks." 116 | with zipfile.ZipFile(predicts_zipfile_addr, 'r') as zobj: 117 | zobj.extractall(path = self.extract_addr) 118 | results_raw = dict() 119 | all_task_names = set() 120 | for file_addr in glob.glob(os.path.join(self.extract_addr, "**/*.json"), recursive=True): 121 | with open(file_addr) as file: 122 | preds = json.load(file) 123 | all_task_names.add(preds['task']) 124 | results_raw[preds['task']] = self._evaluate_task(preds['golds'], preds['task']) 125 | self._empty_dir(self.extract_addr) 126 | assert len(all_task_names) == 7, "The provided results do not cover all the tasks in the benchmark." 127 | return results_raw 128 | 129 | def evaluate_task(self, predicts_json_addr, task_name): 130 | with open(predicts_json_addr) as file: 131 | preds = json.load(file) 132 | assert preds['task'] == task_name, "The provided task_name and the results do not match." 133 | assert preds['task'] in self.tasks_golds.keys(), "The provided golds cannot be used to evaluate this task." 
134 | return self._evaluate_task(preds['golds'], task_name) 135 | 136 | def _evaluate_task(self, predictions, task_name): 137 | golds_dict = {y['id']:y['output'] for y in self.tasks_golds[task_name]} 138 | preds_dict = {x['id']:x['output'] for x in predictions} 139 | 140 | gold_ids = self._get_all_gold_ids(task_name) 141 | pred_ids = self._get_all_ids(predictions) 142 | 143 | assert gold_ids == pred_ids, "Predictions ids and gold ids do not match. {}".format(gold_ids-pred_ids) 144 | 145 | if task_name in ["LaMP_1", "LaMP_2N", "LaMP_2M"]: 146 | metric = create_metric_f1_accuracy(self._get_labels(task_name)) 147 | elif task_name == "LaMP_3": 148 | metric = create_metric_mae_rmse() 149 | else: 150 | metric = create_metric_rouge() 151 | 152 | gold_ids = list(gold_ids) 153 | golds = [golds_dict[id] for id in gold_ids] 154 | preds = [preds_dict[id] for id in gold_ids] 155 | return metric(preds, golds) 156 | 157 | def _get_labels(self, task_name): 158 | if task_name == "LaMP_1": 159 | return ["[1]", "[2]"] 160 | elif task_name == "LaMP_2N": 161 | return ["food & drink", "sports", "education", "parents", "religion", "travel", "business", "crime", "science & technology", "culture & arts", "entertainment", "politics", "women", "style & beauty", "healthy living"] 162 | elif task_name == "LaMP_2M": 163 | return ["sci-fi", "based on a book", "comedy", "action", "twist ending", "dystopia", "dark comedy", "classic", "psychology", "fantasy", "romance", "thought-provoking", "social commentary", "violence", "true story"] 164 | elif task_name == "LaMP_3": 165 | return ["1", "2", "3", "4", "5"] 166 | else: 167 | raise ValueError("Invalid task_name") -------------------------------------------------------------------------------- /gen_profile.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import bitsandbytes as bnb 5 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM 6 | from transformers import pipeline, BitsAndBytesConfig 7 | import argparse 8 | from rank_bm25 import BM25Okapi 9 | # from trl import SFTTrainer, DataCollatorForCompletionOnlyLM 10 | import transformers 11 | import json 12 | from utils import split_batch, get_first_k_tokens, print_trainable_parameters, name2taskid 13 | from utils import extract_citation_title, extract_option, extract_movie, extract_news_cat, extract_news_headline, extract_product_review, extract_scholarly_title, extract_tweet_paraphrasing 14 | 15 | parser = argparse.ArgumentParser(description="Parser for LoRA") 16 | parser.add_argument('--model_name', type=str, default='lmsys/vicuna-7b-v1.5') 17 | parser.add_argument('--batch_size', type=int, default=3) 18 | parser.add_argument('--k', type=int, default=10) 19 | parser.add_argument('--cut_off', type=int, default=2048) 20 | parser.add_argument('--task_name', type=str, default='movie_tagging') 21 | 22 | args = parser.parse_args() 23 | model_name = args.model_name 24 | batch_size = args.batch_size 25 | k = args.k 26 | cutoff_len = args.cut_off 27 | add_eos_token = False 28 | 29 | 30 | # # 4 bit quantization inference 31 | # bnb_config = BitsAndBytesConfig( 32 | # load_in_4bit=True, 33 | # bnb_4bit_quant_type="nf4", 34 | # bnb_4bit_compute_dtype=torch.float16, 35 | # bnb_4bit_use_double_quant=True, 36 | # max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB' 37 | # ) 38 | 39 | # 8-bit quantization inference 40 | # bnb_config = BitsAndBytesConfig( 41 | # load_in_8bit=True, 42 | # bnb_8bit_quant_type="nf8", 43 | # 
bnb_8bit_compute_dtype=torch.float16, 44 | # bnb_8bit_use_double_quant=True, 45 | # max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB' 46 | # ) 47 | 48 | # 16-bit quantization inference 49 | # bnb_config = BitsAndBytesConfig( 50 | # load_in_16bit=True, 51 | # bnb_16bit_quant_type="bf16", 52 | # bnb_16bit_compute_dtype=torch.bfloat16, 53 | # bnb_16bit_use_double_quant=True, 54 | # max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB' 55 | # ) 56 | 57 | tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") 58 | tokenizer.eos_token = "" 59 | tokenizer.pad_token = '[PAD]' 60 | 61 | base_model = AutoModelForCausalLM.from_pretrained( 62 | model_name, 63 | # quantization_config=bnb_config, 64 | local_files_only=False, 65 | device_map='auto', 66 | trust_remote_code=True, 67 | torch_dtype=torch.float16 68 | ) 69 | 70 | base_model.config.use_cache = True 71 | base_model.config.pad_token_id = tokenizer.pad_token_id 72 | base_model.config.eos_token_id = tokenizer.eos_token_id 73 | base_model.config.bos_token_id = tokenizer.bos_token_id 74 | 75 | 76 | with open(f"./data/{args.task_name}/user_others.json", 'r') as f: 77 | train_data = json.load(f) 78 | 79 | with open(f"./data/{args.task_name}/user_top_100_history.json", 'r') as f: 80 | test_data = json.load(f) 81 | 82 | with open('./prompt/prompt_profile.json', 'r') as f: 83 | prompt_template = json.load(f) 84 | 85 | from tqdm import tqdm 86 | import random 87 | 88 | K = args.k 89 | 90 | prompt_list_others = [] 91 | userid_list_others = [] 92 | 93 | for user in tqdm(train_data): 94 | 95 | history_list = [] 96 | 97 | if len(user['profile'])> K: 98 | profiles = random.sample(user['profile'], K) 99 | else: 100 | profiles = user['profile'] 101 | 102 | 103 | for p in profiles: 104 | for key, value in p.items(): 105 | p[key] = get_first_k_tokens(p[key], 200) 106 | 107 | for p in profiles: 108 | history_list.append(prompt_template[args.task_name]['retrieval_history'].format(**p)) 109 | 110 | history_string = ' | '.join(history_list) 111 | 112 | test_prompt = prompt_template[args.task_name]["profile_prompt"].format(history_list) 113 | 114 | prompt_list_others.append(test_prompt) 115 | userid_list_others.append(user['user_id']) 116 | 117 | 118 | prompt_list_100 = [] 119 | userid_list_100 = [] 120 | 121 | for user in tqdm(test_data): 122 | 123 | history_list = [] 124 | 125 | if len(user['profile'])> K: 126 | profiles = random.sample(user['profile'], K) 127 | else: 128 | profiles = user['profile'] 129 | 130 | for p in profiles: 131 | for key, value in p.items(): 132 | p[key] = get_first_k_tokens(p[key], 200) 133 | 134 | for p in profiles: 135 | history_list.append(prompt_template[args.task_name]['retrieval_history'].format(**p)) 136 | 137 | history_string = ' | '.join(history_list) 138 | 139 | test_prompt = prompt_template[args.task_name]['profile_prompt'].format(history_list) 140 | prompt_list_100.append(test_prompt) 141 | userid_list_100.append(user['user_id']) 142 | 143 | batched_prompt_others = split_batch(prompt_list_others, batch_size) 144 | out_list_others = [] 145 | 146 | print(len(prompt_list_others)) 147 | print(len(batched_prompt_others)) 148 | 149 | with torch.no_grad(): 150 | for batch_idx, batch in tqdm(enumerate(batched_prompt_others), total=len(batched_prompt_others)): 151 | # try: 152 | sentences = batch 153 | inputs = tokenizer(sentences, return_tensors="pt", padding=True, return_token_type_ids=False) 154 | inputs = inputs.to(base_model.device) 155 | 156 | with torch.autocast(device_type="cuda"): 157 | 
outputs = base_model.generate( 158 | **inputs, 159 | do_sample=True, 160 | top_k=10, 161 | temperature=0.6, 162 | top_p=0.9, 163 | eos_token_id=tokenizer.eos_token_id, 164 | max_new_tokens=300, 165 | ) 166 | 167 | out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True) 168 | out_list_others += out_sentence 169 | # except: 170 | # out_list_others += [''] 171 | 172 | pred_all_others = [] 173 | 174 | for i in range(len(out_list_others)): 175 | output = out_list_others[i].replace(prompt_list_others[i], '') 176 | pred_all_others.append({ 177 | "id": userid_list_others[i], 178 | "output": output 179 | }) 180 | 181 | print(output) 182 | 183 | 184 | with open(f'./data/{args.task_name}/profile_user_others.json', 'w') as f: 185 | json.dump(pred_all_others, f) 186 | 187 | 188 | batched_prompt_100 = split_batch(prompt_list_100, batch_size) 189 | out_list_100 = [] 190 | 191 | with torch.no_grad(): 192 | for batch_idx, batch in tqdm(enumerate(batched_prompt_100), total=len(batched_prompt_100)): 193 | sentences = batch 194 | inputs = tokenizer(sentences, return_tensors="pt", padding=True, return_token_type_ids=False) 195 | inputs = inputs.to(base_model.device) 196 | 197 | with torch.autocast(device_type="cuda"): 198 | outputs = base_model.generate( 199 | **inputs, 200 | do_sample=True, 201 | top_k=10, 202 | temperature=0.6, 203 | top_p=0.9, 204 | eos_token_id=tokenizer.eos_token_id, 205 | max_new_tokens=200, 206 | ) 207 | 208 | out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True) 209 | out_list_100 += out_sentence 210 | # except: 211 | # out_list_100 += [''] 212 | 213 | pred_all_100 = [] 214 | 215 | for i in range(len(out_list_100)): 216 | output = out_list_100[i].replace(prompt_list_100[i], '') 217 | pred_all_100.append({ 218 | "id": userid_list_100[i], 219 | "output": output 220 | }) 221 | 222 | print(output) 223 | 224 | 225 | with open(f'./data/{args.task_name}/profile_user_100.json', 'w') as f: 226 | json.dump(pred_all_100, f) 227 | -------------------------------------------------------------------------------- /prompt/prompt.json: -------------------------------------------------------------------------------- 1 | { 2 | "movie_tagging":{ 3 | "prompt": "Which tag does this movie relate to among the following tags? Just answer with the tag name without further explanation. tags: [sci-fi, based on a book, comedy, action, twist ending, dystopia, dark comedy, classic, psychology, fantasy, romance, thought-provoking, social commentary, violence, true story]\n description: {} tag:", 4 | "full_prompt": "Which tag does this movie relate to among the following tags? Just answer with the tag name without further explanation. tags: [sci-fi, based on a book, comedy, action, twist ending, dystopia, dark comedy, classic, psychology, fantasy, romance, thought-provoking, social commentary, violence, true story]\n description: {} tag: {}", 5 | "retrieval_history": "description: {description} tag: {tag}\n", 6 | "retrieval_query": "Which tag does this movie relate to among the following tags? Just answer with the tag name without further explanation.\n description: {description}", 7 | "retrieval_query_wokey": "Which tag does this movie relate to among the following tags? Just answer with the tag name without further explanation.\n description: {}", 8 | "OPPU_input": "Which tag does this movie relate to among the following tags? Just answer with the tag name without further explanation. 
tags: [sci-fi, based on a book, comedy, action, twist ending, dystopia, dark comedy, classic, psychology, fantasy, romance, thought-provoking, social commentary, violence, true story]\n description: {description} tag:", 9 | "OPPU_full": "Which tag does this movie relate to among the following tags? Just answer with the tag name without further explanation. tags: [sci-fi, based on a book, comedy, action, twist ending, dystopia, dark comedy, classic, psychology, fantasy, romance, thought-provoking, social commentary, violence, true story]\n description: {description} tag: {tag}" 10 | }, 11 | "citation":{ 12 | "prompt": "Identify the most relevant reference for the listed publication by the researcher. Select the reference paper that is most closely related to the researcher's work.\npaper title: {} reference:", 13 | "full_prompt": "Identify the most relevant reference for the listed publication by the researcher. Select the reference paper that is most closely related to the researcher's work.\npaper title: {} reference: {}", 14 | "retrieval_history": "paper title: {title}\n", 15 | "retrieval_query": "Identify the most relevant reference for the listed publication by the researcher. Select the reference paper that is most closely related to the researcher's work. Please respond with only the number that corresponds to the reference.\npaper title: {title}", 16 | "retrieval_query_wokey": "Identify the most relevant reference for the listed publication by the researcher. Select the reference paper that is most closely related to the researcher's work. Please respond with only the number that corresponds to the reference.\npaper title: {}", 17 | "OPPU_input": "paper title: {title} reference:", 18 | "OPPU_full": "paper title: {title} reference: {citation}" 19 | }, 20 | "news_categorize":{ 21 | "prompt": "Which category does this article relate to among the following categories? Just answer with the category name without further explanation. categories: [travel, education, parents, style & beauty, entertainment, food & drink, science & technology, business, sports, healthy living, women, politics, crime, culture & arts, religion]\n article: {} category:", 22 | "full_prompt": "Which category does this article relate to among the following categories? Just answer with the category name without further explanation. categories: [travel, education, parents, style & beauty, entertainment, food & drink, science & technology, business, sports, healthy living, women, politics, crime, culture & arts, religion]\n article: {} category: {}", 23 | "retrieval_history": "article: {text} category: {category}\n", 24 | "retrieval_query": "Which category does this article relate to among the following categories? arcicle: {text}", 25 | "retrieval_query_wokey": "Which category does this article relate to among the following categories? arcicle: {}", 26 | "OPPU_input": "Which category does this article relate to among the following categories? Just answer with the category name without further explanation. categories: [travel, education, parents, style & beauty, entertainment, food & drink, science & technology, business, sports, healthy living, women, politics, crime, culture & arts, religion]\n article: {text} category:", 27 | "OPPU_full": "Which category does this article relate to among the following categories? Just answer with the category name without further explanation. 
categories: [travel, education, parents, style & beauty, entertainment, food & drink, science & technology, business, sports, healthy living, women, politics, crime, culture & arts, religion]\n article: {text} category: {category}" 28 | }, 29 | "news_headline":{ 30 | "prompt": "Generate a headline for the following article.\narticle: {} headline:", 31 | "full_prompt": "Generate a headline for the following article.\narticle: {} headline: {}", 32 | "retrieval_history": "article: {text} headline: {title}\n", 33 | "retrieval_query": "Generate a headline for the following article: {text}", 34 | "retrieval_query_wokey": "Generate a headline for the following article: {}", 35 | "OPPU_input": "Generate a headline for the following article.\n article: {text} headline:", 36 | "OPPU_full": "Generate a headline for the following article.\n article: {text} headline: {title}" 37 | }, 38 | "product_rating":{ 39 | "prompt": "What is the score of the following review on a scale of 1 to 5? just answer with 1, 2, 3, 4, or 5 without further explanation. review: {}\n score:", 40 | "full_prompt": "What is the score of the following review on a scale of 1 to 5? just answer with 1, 2, 3, 4, or 5 without further explanation. review: {}\n score: {}", 41 | "retrieval_history": "review: {text} score: {score}\n", 42 | "retrieval_query": "What is the score of the following review on a scale of 1 to 5? review: {text}", 43 | "retrieval_query_wokey": "What is the score of the following review on a scale of 1 to 5? review: {}", 44 | "OPPU_input": "What is the score of the following review on a scale of 1 to 5? just answer with 1, 2, 3, 4, or 5 without further explanation. review: {text}\n score:", 45 | "OPPU_full": "What is the score of the following review on a scale of 1 to 5? just answer with 1, 2, 3, 4, or 5 without further explanation. 
review: {text}\n score: {score}" 46 | }, 47 | "scholarly_title": { 48 | "prompt": "Generate a title for the following abstract of a paper.\n abstract: {} title:", 49 | "full_prompt": "Generate a title for the following abstract of a paper.\n abstract: {} title: {}", 50 | "retrieval_history": "abstract: {abstract} title: {title}\n", 51 | "retrieval_query": "Generate a title for the following abstract of a paper: {abstract}", 52 | "retrieval_query_wokey": "Generate a title for the following abstract of a paper: {}", 53 | "OPPU_input": "Generate a title for the following abstract of a paper.\n abstract: {abstract} title:", 54 | "OPPU_full": "Generate a title for the following abstract of a paper.\n abstract: {abstract} title: {title}" 55 | }, 56 | "tweet_paraphrase":{ 57 | "prompt": "Paraphrase the following text into tweet without any explanation before or after it.\n text: {} tweet:", 58 | "full_prompt": "Paraphrase the following text into tweet without any explanation before or after it.\n text: {} tweet: {}", 59 | "retrieval_history": "tweet: {text}\n", 60 | "retrieval_query": "Paraphrase the following text into tweet without any explanation before or after it: {text}", 61 | "retrieval_query_wokey": "Paraphrase the following text into tweet without any explanation before or after it: {}", 62 | "OPPU_input": "tweet:", 63 | "OPPU_full": "tweet: {text}" 64 | } 65 | } -------------------------------------------------------------------------------- /prompt/prompt_profile.json: -------------------------------------------------------------------------------- 1 | { 2 | "movie_tagging":{ 3 | "profile_prompt": "Look at the following past movies this user has watched and determine the most popular tag they labeled. Answer in the following form: most popular tag: . User History: {} Answer:", 4 | "retrieval_history": "description: {description} tag: {tag}" 5 | }, 6 | "citation":{ 7 | "profile_prompt": "Write a summary, in English, of the research interests and topics of a researcher who has published the following papers. Only generate the summary, no other text. User History: {} Answer:", 8 | "retrieval_history": "paper title: {title} reference: {citation}" 9 | }, 10 | "news_categorize":{ 11 | "profile_prompt": "Look at the following past articles this journalist has written and determine the most popular category they write in. Answer in the following form: most popular category: . User History: {} Answer:", 12 | "retrieval_history": "article: {text} category: {category}" 13 | }, 14 | "news_headline":{ 15 | "profile_prompt": "Given this author’s previous articles, try to describe a template for their headlines. I want to be able to accurately predict the headline gives one of their articles. Be specific about their style and wording, don’t tell me anything generic. User History: {} Answer:", 16 | "retrieval_history": "article: {text} headline: {title}" 17 | }, 18 | "product_rating":{ 19 | "profile_prompt": "Based on this user’s past reviews, what are the most common scores they give for positive and negative reviews? Answer in the following form: most common positive score: , most common negative score: . User History: {} Answer:", 20 | "retrieval_history": "review: {text} score: {score}" 21 | }, 22 | "scholarly_title": { 23 | "profile_prompt": "Given this author’s previous publications, try to describe a template for their titles. I want to be able to accurately predict the title of one of the papers from the abstract. Only generate the template description, nothing else. 
User History: {} Answer:", 24 | "retrieval_history": "abstract: {abstract} title: {title}" 25 | }, 26 | "tweet_paraphrase":{ 27 | "profile_prompt": "Given this person’s previous tweets, try to describe a template for their tweets. I want to take a generic sentence and rephrase it to sound like one of their tweets, with the same style/punctuation/capitalization/wording/tone/etc. as them. Only give me the template description, nothing else. User History: {} Answer:", 28 | "retrieval_history": "tweet: {text}" 29 | } 30 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.0.0 2 | accelerate==0.24.1 3 | aiohttp==3.8.6 4 | aiosignal==1.3.1 5 | anyio==4.2.0 6 | appdirs==1.4.4 7 | argon2-cffi==23.1.0 8 | argon2-cffi-bindings==21.2.0 9 | arrow==1.3.0 10 | asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1698341106958/work 11 | async-lru==2.0.4 12 | async-timeout==4.0.3 13 | attrs==23.1.0 14 | Babel==2.14.0 15 | backports.functools-lru-cache @ file:///home/conda/feedstock_root/build_artifacts/backports.functools_lru_cache_1687772187254/work 16 | bayesian-optimization==1.4.3 17 | beautifulsoup4==4.12.2 18 | bitsandbytes==0.41.2.post2 19 | bleach==6.1.0 20 | blessed==1.20.0 21 | Brotli @ file:///tmp/abs_ecyw11_7ze/croots/recipe/brotli-split_1659616059936/work 22 | certifi==2023.7.22 23 | cffi @ file:///croot/cffi_1670423208954/work 24 | charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work 25 | click==8.1.7 26 | cma==3.3.0 27 | colorama==0.4.6 28 | comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1691044910542/work 29 | cryptography @ file:///croot/cryptography_1694444244250/work 30 | datasets==2.12.0 31 | debugpy @ file:///croot/debugpy_1690905042057/work 32 | decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work 33 | defusedxml==0.7.1 34 | dill==0.3.6 35 | docker-pycreds==0.4.0 36 | docstring-parser==0.15 37 | entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1643888246732/work 38 | evaluate==0.4.1 39 | exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1692026125334/work 40 | executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1698579936712/work 41 | fastjsonschema==2.19.1 42 | filelock @ file:///croot/filelock_1672387128942/work 43 | fire==0.5.0 44 | fqdn==1.5.1 45 | frozenlist==1.4.0 46 | fsspec==2023.10.0 47 | gitdb==4.0.11 48 | GitPython==3.1.40 49 | gmpy2 @ file:///tmp/build/80754af9/gmpy2_1645455533097/work 50 | gpustat==1.1.1 51 | huggingface-hub==0.20.3 52 | idna @ file:///croot/idna_1666125576474/work 53 | ijson==3.2.3 54 | ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1698244021190/work 55 | ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1698846603011/work 56 | ipywidgets==8.1.1 57 | isoduration==20.11.0 58 | jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1696326070614/work 59 | Jinja2 @ file:///croot/jinja2_1666908132255/work 60 | joblib==1.3.2 61 | json5==0.9.14 62 | jsonargparse==4.27.5 63 | jsonpointer==2.4 64 | jsonschema==4.20.0 65 | jsonschema-specifications==2023.12.1 66 | jupyter==1.0.0 67 | jupyter-console==6.6.3 68 | jupyter-events==0.9.0 69 | jupyter-lsp==2.2.1 70 | jupyter_client==8.6.0 71 | jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1669775088561/work 72 | 
jupyter_server==2.12.3 73 | jupyter_server_terminals==0.5.1 74 | jupyterlab==4.0.10 75 | jupyterlab-widgets==3.0.9 76 | jupyterlab_pygments==0.3.0 77 | jupyterlab_server==2.25.2 78 | lightning==2.2.0.post0 79 | lightning-utilities==0.10.1 80 | markdown-it-py==3.0.0 81 | MarkupSafe @ file:///opt/conda/conda-bld/markupsafe_1654597864307/work 82 | matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1660814786464/work 83 | mdurl==0.1.2 84 | mistune==3.0.2 85 | mkl-fft @ file:///croot/mkl_fft_1695058164594/work 86 | mkl-random @ file:///croot/mkl_random_1695059800811/work 87 | mkl-service==2.4.0 88 | mpmath @ file:///croot/mpmath_1690848262763/work 89 | multidict==6.0.4 90 | multiprocess==0.70.14 91 | nbclient==0.9.0 92 | nbconvert==7.14.0 93 | nbformat==5.9.2 94 | nest-asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1697083700168/work 95 | networkx @ file:///croot/networkx_1690561992265/work 96 | nevergrad==1.0.1 97 | nltk==3.8.1 98 | notebook==7.0.6 99 | notebook_shim==0.2.3 100 | numpy @ file:///croot/numpy_and_numpy_base_1695830428084/work/dist/numpy-1.26.0-cp310-cp310-linux_x86_64.whl#sha256=fc2732718bc9e06a7b702492cb4f5afffe9671083930452d894377bf563464a3 101 | nvidia-ml-py==12.535.133 102 | overrides==7.4.0 103 | packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1696202382185/work 104 | pandas==2.1.2 105 | pandocfilters==1.5.0 106 | parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work 107 | pathtools==0.1.2 108 | peft==0.5.0 109 | pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1667297516076/work 110 | pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work 111 | Pillow @ file:///croot/pillow_1696580024257/work 112 | prometheus-client==0.19.0 113 | prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1688565951714/work 114 | protobuf==4.25.0 115 | psutil @ file:///opt/conda/conda-bld/psutil_1656431268089/work 116 | ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl 117 | pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work 118 | pyarrow==14.0.0 119 | pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work 120 | pyg-lib==0.3.0+pt20cu118 121 | Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1691408637400/work 122 | pyOpenSSL @ file:///croot/pyopenssl_1690223430423/work 123 | pyparsing==3.1.1 124 | PySocks @ file:///home/builder/ci_310/pysocks_1640793678128/work 125 | python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1626286286081/work 126 | python-json-logger==2.0.7 127 | pytorch-lightning==2.2.0.post0 128 | pytz==2023.3.post1 129 | PyYAML==6.0.1 130 | pyzmq @ file:///croot/pyzmq_1686601365461/work 131 | qtconsole==5.5.1 132 | QtPy==2.4.1 133 | rank-bm25==0.2.2 134 | referencing==0.32.1 135 | regex==2023.10.3 136 | requests @ file:///croot/requests_1690400202158/work 137 | responses==0.18.0 138 | rfc3339-validator==0.1.4 139 | rfc3986-validator==0.1.1 140 | rich==13.6.0 141 | rouge-score==0.1.2 142 | rpds-py==0.16.2 143 | safetensors==0.4.2 144 | scikit-learn==1.3.2 145 | scipy==1.11.3 146 | Send2Trash==1.8.2 147 | sentencepiece==0.1.99 148 | sentry-sdk==1.34.0 149 | setproctitle==1.3.3 150 | shtab==1.6.4 151 | six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work 152 | 
smmap==5.0.1 153 | sniffio==1.3.0 154 | soupsieve==2.5 155 | stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work 156 | sympy @ file:///croot/sympy_1668202399572/work 157 | termcolor==2.4.0 158 | terminado==0.18.0 159 | threadpoolctl==3.2.0 160 | timm==0.9.16 161 | tinycss2==1.2.1 162 | tokenizers==0.15.1 163 | tomli==2.0.1 164 | torch==2.0.1 165 | torch-cluster==1.6.3+pt20cu118 166 | torch-scatter==2.1.2+pt20cu118 167 | torch-sparse==0.6.18+pt20cu118 168 | torch-spline-conv==1.2.2+pt20cu118 169 | torch_geometric==2.4.0 170 | torchaudio==2.0.2 171 | torchmetrics==1.3.1 172 | torchvision==0.15.2 173 | tornado==6.4 174 | tqdm==4.66.1 175 | traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1698671135544/work 176 | transformers @ git+https://github.com/huggingface/transformers@5f9685576149fb45a61d0dcec9a260930df0a49a 177 | trim==0.3 178 | triton==2.0.0 179 | trl==0.7.4 180 | types-python-dateutil==2.8.19.20240106 181 | typing_extensions @ file:///croot/typing_extensions_1690297465030/work 182 | tyro==0.5.12 183 | tzdata==2023.3 184 | uri-template==1.3.0 185 | urllib3 @ file:///croot/urllib3_1698257533958/work 186 | wandb==0.15.12 187 | wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1698744702785/work 188 | webcolors==1.13 189 | webencodings==0.5.1 190 | websocket-client==1.7.0 191 | widgetsnbextension==4.0.9 192 | xxhash==3.4.1 193 | yarl==1.9.2 194 | -------------------------------------------------------------------------------- /task_LoRA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import bitsandbytes as bnb 4 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM 5 | # from transformers import pipeline, BitsAndBytesConfig 6 | import argparse 7 | from rank_bm25 import BM25Okapi 8 | # from trl import SFTTrainer, DataCollatorForCompletionOnlyLM 9 | import transformers 10 | from utils import split_batch, get_first_k_tokens, print_trainable_parameters, name2taskid 11 | from utils import extract_citation_title, extract_option, extract_movie, extract_news_cat, extract_news_headline, extract_product_review, extract_scholarly_title, extract_tweet_paraphrasing 12 | import json 13 | from tqdm import tqdm 14 | 15 | 16 | parser = argparse.ArgumentParser(description="Parser for LoRA") 17 | parser.add_argument('--model_name', type=str, default='meta-llama/Llama-2-7b-hf') 18 | parser.add_argument('--batch_size', type=int, default=16) 19 | parser.add_argument('--k', type=int, default=0) 20 | parser.add_argument('--max_step', type=int, default=5000) 21 | parser.add_argument('--cut_off', type=int, default=2048) 22 | parser.add_argument('--max_epoch', type=int, default=3) 23 | parser.add_argument('--temperature', type=float, default=0.1) 24 | parser.add_argument('--task_name', type=str, default='movie_tagging') 25 | parser.add_argument('--add_profile', action='store_true') 26 | parser.add_argument('--access_token', type=str, default=None) 27 | 28 | args = parser.parse_args() 29 | model_name = args.model_name 30 | task_name = args.task_name 31 | batch_size = args.batch_size 32 | k = args.k 33 | # max_step = args.max_step 34 | cutoff_len = args.cut_off 35 | add_eos_token = False 36 | max_epoch = args.max_epoch 37 | 38 | # # 4 bit quantization inference 39 | # bnb_config = BitsAndBytesConfig( 40 | # load_in_4bit=True, 41 | # bnb_4bit_quant_type="nf4", 42 | # bnb_4bit_compute_dtype=torch.float16, 43 | # 
bnb_4bit_use_double_quant=True, 44 | # max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB' 45 | # ) 46 | 47 | # 8-bit quantization inference 48 | # bnb_config = BitsAndBytesConfig( 49 | # load_in_8bit=True, 50 | # bnb_8bit_quant_type="nf8", 51 | # bnb_8bit_compute_dtype=torch.float16, 52 | # bnb_8bit_use_double_quant=True, 53 | # max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB' 54 | # ) 55 | 56 | # 16-bit quantization inference 57 | # bnb_config = BitsAndBytesConfig( 58 | # load_in_16bit=True, 59 | # bnb_16bit_quant_type="bf16", 60 | # bnb_16bit_compute_dtype=torch.bfloat16, 61 | # bnb_16bit_use_double_quant=True, 62 | # max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB' 63 | # ) 64 | 65 | tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", token=args.access_token) 66 | tokenizer.eos_token = "" 67 | tokenizer.pad_token = '[PAD]' 68 | # tokenizer.pad_token = tokenizer.eos_token 69 | tokenizer.pad_token_id = tokenizer.eos_token_id 70 | 71 | 72 | base_model = AutoModelForCausalLM.from_pretrained( 73 | model_name, 74 | # quantization_config=bnb_config, 75 | local_files_only=False, 76 | device_map='auto', 77 | trust_remote_code=True, 78 | torch_dtype=torch.bfloat16 79 | ) 80 | 81 | base_model.config.use_cache = False 82 | base_model.config.pad_token_id = tokenizer.pad_token_id 83 | base_model.config.eos_token_id = tokenizer.eos_token_id 84 | base_model.config.bos_token_id = tokenizer.bos_token_id 85 | 86 | 87 | from peft import prepare_model_for_kbit_training 88 | 89 | base_model.gradient_checkpointing_enable() 90 | base_model = prepare_model_for_kbit_training(base_model) 91 | 92 | 93 | 94 | from peft import LoraConfig, get_peft_model 95 | 96 | peft_config = LoraConfig( 97 | r=8, 98 | lora_alpha=8, 99 | target_modules=["q_proj", "v_proj", "k_proj", "out_proj"], 100 | lora_dropout=0.05, 101 | bias="none", 102 | task_type="CAUSAL_LM" 103 | ) 104 | 105 | training_arguments = transformers.TrainingArguments( 106 | output_dir='outputs/', 107 | per_device_train_batch_size=batch_size, 108 | gradient_accumulation_steps=1, 109 | optim='adamw_torch', 110 | num_train_epochs=max_epoch, 111 | save_steps=1e9, 112 | logging_steps=50, 113 | learning_rate=1e-4, 114 | weight_decay=1e-2, 115 | bf16=True, 116 | max_grad_norm=0.3, 117 | # max_steps=max_step, 118 | warmup_ratio=0.1, 119 | group_by_length=True, 120 | lr_scheduler_type='linear', 121 | report_to='none', 122 | ) 123 | 124 | 125 | with open(f"./data/{task_name}/user_others.json", 'r') as f: 126 | train = json.load(f) 127 | 128 | 129 | with open(f"./data/{task_name}/user_top_100_history.json", 'r') as f: 130 | test_data = json.load(f) 131 | 132 | if args.task_name == "movie_tagging": 133 | extract_article = extract_movie 134 | elif args.task_name == "news_categorize": 135 | extract_article = extract_news_cat 136 | elif args.task_name == "news_headline": 137 | extract_article = extract_news_headline 138 | elif args.task_name == "product_rating": 139 | extract_article = extract_product_review 140 | elif args.task_name == "scholarly_title": 141 | extract_article = extract_scholarly_title 142 | elif args.task_name == "tweet_paraphrase": 143 | extract_article = extract_tweet_paraphrasing 144 | 145 | 146 | with open('./prompt/prompt.json', 'r') as f: 147 | prompt_template = json.load(f) 148 | 149 | 150 | if args.add_profile: 151 | with open(f'./data/{task_name}/profile_user_100.json', 'r') as f: 152 | test_profile = json.load(f) 153 | 154 | with open(f'./data/{task_name}/profile_user_others.json', 'r') as f: 
155 | train_profile = json.load(f) 156 | 157 | 158 | def tokenize(prompt, add_eos_token=True): 159 | # there's probably a way to do this with the tokenizer settings 160 | # but again, gotta move fast 161 | result = tokenizer( 162 | prompt, 163 | truncation=True, 164 | max_length=cutoff_len, 165 | padding=False, 166 | return_tensors=None, 167 | ) 168 | if ( 169 | result["input_ids"][-1] != tokenizer.eos_token_id 170 | and len(result["input_ids"]) < cutoff_len 171 | and add_eos_token 172 | ): 173 | result["input_ids"].append(tokenizer.eos_token_id) 174 | result["attention_mask"].append(1) 175 | 176 | result["labels"] = result["input_ids"].copy() 177 | 178 | return result 179 | 180 | 181 | def generate_and_tokenize_prompt(data_point): 182 | full_prompt = data_point['full_prompt'] 183 | tokenized_full_prompt = tokenize(full_prompt) 184 | # if not train_on_inputs: 185 | user_prompt = data_point['prompt'] 186 | 187 | tokenized_user_prompt = tokenize( 188 | user_prompt, add_eos_token=add_eos_token 189 | ) 190 | user_prompt_len = len(tokenized_user_prompt["input_ids"]) 191 | 192 | if add_eos_token: 193 | user_prompt_len -= 1 194 | 195 | tokenized_full_prompt["labels"] = [ 196 | -100 197 | ] * user_prompt_len + tokenized_full_prompt["labels"][ 198 | user_prompt_len: 199 | ] # could be sped up, probably 200 | return tokenized_full_prompt 201 | 202 | 203 | 204 | # training 205 | from datasets import load_dataset, Dataset 206 | model = get_peft_model(base_model, peft_config) 207 | print_trainable_parameters(model) 208 | 209 | pred_all = [] 210 | actual = [] 211 | train_data = [] 212 | 213 | for i in tqdm(range(len(train))): 214 | if args.add_profile: 215 | profile = train_profile[i]['output'] 216 | 217 | for idx, q in enumerate(train[i]['query']): 218 | 219 | if args.task_name != "citation": 220 | article = get_first_k_tokens(extract_article(q['input']), 768) 221 | prompt = prompt_template[args.task_name]['prompt'].format(article) 222 | full_prompt = prompt_template[args.task_name]['full_prompt'].format(get_first_k_tokens(extract_article(q['input']), 768), q['gold']) 223 | 224 | else: 225 | question = q['input'] 226 | article = extract_citation_title(question) 227 | option1, option2 = extract_option(question, 1), extract_option(question, 2) 228 | 229 | prompt = prompt_template[args.task_name]['prompt'].format(article, option1, option2) 230 | full_prompt = prompt_template[args.task_name]['full_prompt'].format(article, option1, option2, q['gold']) 231 | 232 | if k > 0: 233 | visible_history_list = train[i]['profile'] 234 | 235 | for p in visible_history_list: 236 | for key, value in p.items(): 237 | p[key] = get_first_k_tokens(p[key], 368) 238 | 239 | history_list = [prompt_template[args.task_name]['retrieval_history'].format(**p) for p in visible_history_list] 240 | tokenized_corpus = [doc.split(" ") for doc in history_list] 241 | bm25 = BM25Okapi(tokenized_corpus) 242 | 243 | tokenized_query = prompt_template[args.task_name]["retrieval_query_wokey"].format(article).split(' ') 244 | retrieved_history = bm25.get_top_n(tokenized_query, history_list, n=args.k) 245 | 246 | history_string = "".join(retrieved_history) 247 | prompt = history_string + "\n" + prompt 248 | full_prompt = history_string + "\n" + full_prompt 249 | 250 | if args.add_profile: 251 | prompt = profile + "\n" + prompt 252 | full_prompt = profile + "\n" + full_prompt 253 | 254 | train_data.append( 255 | { 256 | "prompt": prompt, 257 | "full_prompt": full_prompt 258 | } 259 | ) 260 | 261 | print(train_data) 262 | 263 | train_dataset = 
Dataset.from_list(train_data) 264 | train_dataset = train_dataset.map(generate_and_tokenize_prompt).shuffle() 265 | 266 | trainer = transformers.Trainer( 267 | model=model, 268 | train_dataset=train_dataset, 269 | args=training_arguments, 270 | data_collator=transformers.DataCollatorForSeq2Seq( 271 | tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True 272 | ), 273 | ) 274 | 275 | for name, module in trainer.model.named_modules(): 276 | if "norm" in name: 277 | module = module.to(torch.float32) 278 | 279 | 280 | model.config.use_cache = False # silence the warnings. Please re-enable for inference! 281 | trainer.train() 282 | 283 | if args.add_profile: 284 | output_name = "./ckpt/{}/k{}-{}-{}-profile-task_LoRA_ckpt".format(args.task_name, args.k, args.task_name, model_name.split('/')[-1]) 285 | else: 286 | output_name = "./ckpt/{}/k{}-{}-{}-task_LoRA_ckpt".format(args.task_name, args.k, args.task_name, model_name.split('/')[-1]) 287 | 288 | model.save_pretrained(output_name) 289 | 290 | model.eval() 291 | model.config.use_cache = True # silence the warnings. Please re-enable for inference! 292 | 293 | for i in tqdm(range(len(test_data))): 294 | if args.add_profile: 295 | profile = test_profile[i]['output'] 296 | 297 | if k > 0: 298 | visible_history_list = test_data[i]['profile'] 299 | for p in visible_history_list: 300 | for key, value in p.items(): 301 | p[key] = get_first_k_tokens(p[key], 368) 302 | 303 | history_list = [prompt_template[args.task_name]['retrieval_history'].format(**p) for p in visible_history_list] 304 | 305 | tokenized_corpus = [doc.split(" ") for doc in history_list] 306 | bm25 = BM25Okapi(tokenized_corpus) 307 | 308 | test_question_list = [] 309 | question_id_list = [] 310 | 311 | for q in test_data[i]['query']: 312 | 313 | if args.task_name == 'citation': 314 | test_question = q['input'] 315 | test_article = extract_citation_title(test_question) 316 | option1, option2 = extract_option(test_question, 1), extract_option(test_question, 2) 317 | test_prompt = prompt_template[args.task_name]['prompt'].format(test_article, option1, option2) 318 | 319 | else: 320 | test_question = q['input'] 321 | test_article = extract_article(test_question) 322 | test_prompt = prompt_template[args.task_name]['prompt'].format(test_article) 323 | 324 | if k > 0: 325 | tokenized_query = prompt_template[args.task_name]['retrieval_query_wokey'].format(test_article).split(" ") 326 | retrieved_history = bm25.get_top_n(tokenized_query, history_list, n=args.k) 327 | 328 | history_string = "".join(retrieved_history) 329 | test_prompt = history_string + "\n" + test_prompt 330 | 331 | if args.add_profile: 332 | test_prompt = profile + "\n" + test_prompt 333 | 334 | test_question_list.append(test_prompt) 335 | question_id_list.append(q['id']) 336 | 337 | test_batch_list = split_batch(test_question_list, 1) 338 | out_list = [] 339 | 340 | with torch.no_grad(): 341 | for batch_idx, batch in tqdm(enumerate(test_batch_list), total=len(test_batch_list)): 342 | # try: 343 | sentences = batch 344 | inputs = tokenizer(sentences, return_tensors="pt", padding=True, return_token_type_ids=False) 345 | inputs = inputs.to(model.device) 346 | 347 | with torch.autocast(device_type="cuda"): 348 | outputs = model.generate( 349 | **inputs, 350 | do_sample=True, 351 | top_k=10, 352 | temperature=args.temperature, 353 | top_p=0.9, 354 | eos_token_id=tokenizer.eos_token_id, 355 | max_new_tokens=200 356 | ) 357 | 358 | out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True) 359 | out_list += 
out_sentence 360 | # except: 361 | # out_list += [''] 362 | 363 | for i in range(len(out_list)): 364 | output = out_list[i].replace(test_question_list[i], '') 365 | pred_all.append({ 366 | "id": question_id_list[i], 367 | "output": output 368 | }) 369 | 370 | print(output) 371 | 372 | output_file = { 373 | 'task': name2taskid[args.task_name], 374 | 'golds': pred_all, 375 | 'model': model_name, 376 | } 377 | 378 | if args.add_profile: 379 | with open('./output/{}/output-task-k{}-{}-{}-profile.json'.format(args.task_name, args.k, args.task_name, model_name.split('/')[-1]), 'w') as f: 380 | json.dump(output_file, f, indent=4) 381 | else: 382 | with open('./output/{}/output-task-k{}-{}-{}.json'.format(args.task_name, args.k, args.task_name, model_name.split('/')[-1]), 'w') as f: 383 | json.dump(output_file, f, indent=4) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | def extract_option(s, num): 4 | # Look for string after [1]: and between " 5 | match = re.search(r'\[' + str(num) + r'\]: "([^"]*)"', s) 6 | return match.group(1) if match else None 7 | 8 | def extract_citation_title(text): 9 | pattern = r'written the paper with the title "([^"]*)"' 10 | match = re.search(pattern, text) 11 | if match: 12 | return match.group(1) 13 | else: 14 | return None 15 | 16 | 17 | def extract_movie(text): 18 | marker = "] description: " 19 | # Find the position of the marker in the text 20 | marker_pos = text.find(marker) 21 | 22 | # Check if the marker is found 23 | if marker_pos == -1: 24 | raise ValueError() 25 | 26 | # Extract the string after the marker 27 | extracted_string = text[marker_pos + len(marker):] 28 | 29 | return extracted_string 30 | 31 | def extract_news_cat(text): 32 | marker = "] article: " 33 | # Find the position of the marker in the text 34 | marker_pos = text.find(marker) 35 | 36 | # Check if the marker is found 37 | if marker_pos == -1: 38 | raise ValueError() 39 | 40 | # Extract the string after the marker 41 | extracted_string = text[marker_pos + len(marker):] 42 | 43 | return extracted_string 44 | 45 | def extract_news_headline(text): 46 | marker = "Generate a headline for the following article: " 47 | # Find the position of the marker in the text 48 | marker_pos = text.find(marker) 49 | 50 | # Check if the marker is found 51 | if marker_pos == -1: 52 | raise ValueError() 53 | 54 | # Extract the string after the marker 55 | extracted_string = text[marker_pos + len(marker):] 56 | 57 | return extracted_string 58 | 59 | def extract_product_review(text): 60 | marker = "without further explanation. 
review: " 61 | # Find the position of the marker in the text 62 | marker_pos = text.find(marker) 63 | 64 | # Check if the marker is found 65 | if marker_pos == -1: 66 | raise ValueError() 67 | 68 | # Extract the string after the marker 69 | extracted_string = text[marker_pos + len(marker):] 70 | 71 | return extracted_string 72 | 73 | 74 | def extract_scholarly_title(text): 75 | marker = "Generate a title for the following abstract of a paper: " 76 | # Find the position of the marker in the text 77 | marker_pos = text.find(marker) 78 | 79 | # Check if the marker is found 80 | if marker_pos == -1: 81 | raise ValueError() 82 | 83 | # Extract the string after the marker 84 | extracted_string = text[marker_pos + len(marker):] 85 | 86 | return extracted_string 87 | 88 | 89 | def extract_tweet_paraphrasing(text): 90 | marker = "Paraphrase the following tweet without any explanation before or after it: " 91 | # Find the position of the marker in the text 92 | marker_pos = text.find(marker) 93 | 94 | # Check if the marker is found 95 | if marker_pos == -1: 96 | raise ValueError() 97 | 98 | # Extract the string after the marker 99 | extracted_string = text[marker_pos + len(marker):] 100 | 101 | return extracted_string 102 | 103 | def get_first_k_tokens(text, k): 104 | """ 105 | Extracts the first k tokens from a text string. 106 | 107 | :param text: The input text string. 108 | :param k: The number of tokens to extract. 109 | :return: The first k tokens of the text string. 110 | """ 111 | # Split the text into tokens based on whitespace 112 | tokens = text.split() 113 | output = " ".join(tokens[:k]) 114 | 115 | # Return the first k tokens 116 | return output 117 | 118 | def split_batch(init_list, batch_size): 119 | groups = zip(*(iter(init_list),) * batch_size) 120 | end_list = [list(i) for i in groups] 121 | count = len(init_list) % batch_size 122 | end_list.append(init_list[-count:]) if count != 0 else end_list 123 | return end_list 124 | 125 | def tokenize(prompt, add_eos_token=True): 126 | # there's probably a way to do this with the tokenizer settings 127 | # but again, gotta move fast 128 | result = tokenizer( 129 | prompt, 130 | truncation=True, 131 | max_length=cutoff_len, 132 | padding=False, 133 | return_tensors=None, 134 | ) 135 | if ( 136 | result["input_ids"][-1] != tokenizer.eos_token_id 137 | and len(result["input_ids"]) < cutoff_len 138 | and add_eos_token 139 | ): 140 | result["input_ids"].append(tokenizer.eos_token_id) 141 | result["attention_mask"].append(1) 142 | 143 | result["labels"] = result["input_ids"].copy() 144 | 145 | return result 146 | 147 | 148 | def print_trainable_parameters(model): 149 | """ 150 | Prints the number of trainable parameters in the model. 151 | """ 152 | trainable_params = 0 153 | all_param = 0 154 | for _, param in model.named_parameters(): 155 | all_param += param.numel() 156 | if param.requires_grad: 157 | trainable_params += param.numel() 158 | print( 159 | f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}" 160 | ) 161 | 162 | 163 | 164 | 165 | name2taskid = { 166 | "citation": "LaMP_1", 167 | "movie_tagging": "LaMP_2M", 168 | "news_categorize": "LaMP_2N", 169 | "news_headline": "LaMP_4", 170 | "product_rating": "LaMP_3", 171 | "scholarly_title": "LaMP_5", 172 | "tweet_paraphrase": "LaMP_7" 173 | } --------------------------------------------------------------------------------