├── imgs ├── hitab.png ├── fetaqa.png ├── tabfact.png ├── hybridqa.png ├── tablellama_figure1.png └── tablellama_figure2.png ├── requirements.txt ├── ds_configs ├── stage2.json └── stage3.json ├── eval_scripts ├── eval_tabfact.py ├── eval_hitab.py ├── eval_ent_link.py ├── eval_fetaqa.py ├── eval_col_pop.py ├── eval_row_pop.py ├── qa_datadump_utils.py ├── table_utils.py ├── eval_rel_extraction.py ├── eval_col_type.py └── metric.py ├── LICENSE ├── inference_rel_extraction_col_type.py ├── inference_hitab_tabfact_fetaqa.py ├── inference_ent_link.py ├── inference_row_pop.py ├── README.md ├── supervised_fine_tune_stream.py ├── supervised_fine_tune.py ├── inference_schema_aug.py └── llama_attn_replace.py /imgs/hitab.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSU-NLP-Group/TableLlama/HEAD/imgs/hitab.png -------------------------------------------------------------------------------- /imgs/fetaqa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSU-NLP-Group/TableLlama/HEAD/imgs/fetaqa.png -------------------------------------------------------------------------------- /imgs/tabfact.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSU-NLP-Group/TableLlama/HEAD/imgs/tabfact.png -------------------------------------------------------------------------------- /imgs/hybridqa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSU-NLP-Group/TableLlama/HEAD/imgs/hybridqa.png -------------------------------------------------------------------------------- /imgs/tablellama_figure1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSU-NLP-Group/TableLlama/HEAD/imgs/tablellama_figure1.png -------------------------------------------------------------------------------- /imgs/tablellama_figure2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSU-NLP-Group/TableLlama/HEAD/imgs/tablellama_figure2.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | rouge_score 3 | fire 4 | openai 5 | transformers>=4.28.1 6 | torch>=2.0 7 | sentencepiece 8 | tokenizers>=0.13.3 9 | wandb 10 | accelerate 11 | datasets 12 | deepspeed 13 | peft 14 | partial 15 | gradio 16 | einops -------------------------------------------------------------------------------- /ds_configs/stage2.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": "auto", 3 | "gradient_accumulation_steps": "auto", 4 | "gradient_clipping": "auto", 5 | "zero_allow_untested_optimizer": true, 6 | "bf16": { 7 | "enabled": "auto", 8 | "loss_scale": 0, 9 | "initial_scale_power": 16, 10 | "loss_scale_window": 1000, 11 | "hysteresis": 2, 12 | "min_loss_scale": 1 13 | }, 14 | "zero_optimization": { 15 | "stage": 2, 16 | "allgather_partitions": true, 17 | "allgather_bucket_size": 1e9, 18 | "reduce_scatter": true, 19 | "reduce_bucket_size": 1e9, 20 | "overlap_comm": true, 21 | "contiguous_gradients": true 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /eval_scripts/eval_tabfact.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | 4 | 5 | def main(args): 6 | with open(args.pred_file, "r") as f: 7 | data = json.load(f) 8 | 9 | correct = 0 10 | remove_count = 0 11 | for i in range(len(data)): 12 | ground_truth = data[i]["output"] 13 | prediction = data[i]["predict"].strip("") 14 | # if prediction.find(ground_truth) == 0: 15 | if prediction == ground_truth: 16 | correct += 1 17 | if prediction.find("") == 0: 18 | remove_count += 1 19 | 20 | print("correct:", correct) 21 | # print("remove_count:", remove_count) 22 | print("accuracy:", correct/(len(data)-remove_count)) 23 | 24 | if __name__ == "__main__": 25 | parser = argparse.ArgumentParser(description='arg parser') 26 | parser.add_argument('--pred_file', type=str, default='/TableLlama/ckpfinal_pred/tabfact_pred.json', help='') 27 | args = parser.parse_args() 28 | main(args) 29 | -------------------------------------------------------------------------------- /eval_scripts/eval_hitab.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | import argparse 4 | from table_utils import evaluate 5 | 6 | 7 | def main(args): 8 | with open(args.pred_file, "r") as f: 9 | data = json.load(f) 10 | 11 | pred_list = [] 12 | gold_list = [] 13 | for i in range(len(data)): 14 | if len(data[i]["predict"].strip("").split(">, <")) > 1: 15 | instance_pred_list = data[i]["predict"].strip("").split(">, <") 16 | pred_list.append(instance_pred_list) 17 | gold_list.append(data[i]["output"].strip("").split(">, <")) 18 | else: 19 | pred_list.append(data[i]["predict"].strip("")) 20 | gold_list.append(data[i]["output"].strip("")) 21 | 22 | print(evaluate(gold_list, pred_list)) 23 | 24 | 25 | if __name__ == "__main__": 26 | parser = argparse.ArgumentParser(description='arg parser') 27 | parser.add_argument('--pred_file', type=str, default='/TableLlama/ckpfinal_pred/hitab_pred.json', help='') 28 | args = parser.parse_args() 29 | main(args) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 OSU Natural Language Processing 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /eval_scripts/eval_ent_link.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | 4 | 5 | def main(args): 6 | with open(args.pred_file, "r") as f: 7 | data = json.load(f) 8 | 9 | assert len(data) == 2000 10 | correct_count = 0 11 | multi_candidates_example_count = 0 12 | for i in range(len(data)): 13 | candidate_list = data[i]["candidates_entity_desc_list"] 14 | ground_truth = data[i]["output"].strip("<>").lower() 15 | predict = data[i]["predict"][:-4].strip("<>").lower() 16 | # import pdb 17 | # pdb.set_trace() 18 | 19 | if ground_truth == predict: 20 | correct_count += 1 21 | if len(candidate_list) > 1: 22 | multi_candidates_example_count += 1 23 | 24 | 25 | print("correct_count:", correct_count) 26 | print("acc:", correct_count/len(data)) 27 | 28 | # print("multi_candidates_example_count:", multi_candidates_example_count) 29 | # print("multi_candidates_example_ratio:", multi_candidates_example_count/len(data)) 30 | 31 | if __name__ == "__main__": 32 | parser = argparse.ArgumentParser(description='arg parser') 33 | parser.add_argument('--pred_file', type=str, default='/TableLlama/ckpfinal_pred/ent_link_pred.json', help='') 34 | args = parser.parse_args() 35 | main(args) 36 | -------------------------------------------------------------------------------- /ds_configs/stage3.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | "optimizer": { 6 | "type": "AdamW", 7 | "params": { 8 | "lr": "auto", 9 | "betas": "auto", 10 | "eps": "auto", 11 | "weight_decay": "auto" 12 | } 13 | }, 14 | "scheduler": { 15 | "type": "WarmupDecayLR", 16 | "params": { 17 | "total_num_steps": "auto", 18 | "warmup_min_lr": "auto", 19 | "warmup_max_lr": "auto", 20 | "warmup_num_steps": "auto" 21 | } 22 | }, 23 | "zero_optimization": { 24 | "stage": 3, 25 | "offload_optimizer": { 26 | "device": "cpu", 27 | "pin_memory": true 28 | }, 29 | "offload_param": { 30 | "device": "cpu", 31 | "pin_memory": true 32 | }, 33 | "overlap_comm": true, 34 | "contiguous_gradients": true, 35 | "sub_group_size": 1e9, 36 | "reduce_bucket_size": "auto", 37 | "stage3_prefetch_bucket_size": "auto", 38 | "stage3_param_persistence_threshold": "auto", 39 | "stage3_max_live_parameters": 1e9, 40 | "stage3_max_reuse_distance": 1e9, 41 | "stage3_gather_16bit_weights_on_model_save": false 42 | }, 43 | "gradient_accumulation_steps": "auto", 44 | "gradient_clipping": "auto", 45 | "steps_per_print": 5, 46 | "train_batch_size": "auto", 47 | "train_micro_batch_size_per_gpu": "auto", 48 | "wall_clock_breakdown": false 49 | } 50 | -------------------------------------------------------------------------------- /eval_scripts/eval_fetaqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import evaluate 3 | import json 4 | 5 | def main(args): 6 | with open(args.pred_file, "r") as f: 7 | data = json.load(f) 8 | 9 | test_examples_answer = [x["output"] for x in data] 10 | test_predictions_pred = [x["predict"].strip("") for x in data] 11 | predictions = test_predictions_pred 12 | references = test_examples_answer 13 | 14 | rouge = evaluate.load('rouge') 15 | results = rouge.compute(predictions=predictions, references=references) 16 | print(results) 17 | 18 | # bleu = evaluate.load('bleu') 19 | # results = bleu.compute(predictions=predictions, references=references) 
20 | # print(results) 21 | 22 | sacrebleu = evaluate.load('sacrebleu') 23 | results = sacrebleu.compute(predictions=predictions, references=references) 24 | print(results) 25 | 26 | meteor = evaluate.load('meteor') 27 | results = meteor.compute(predictions=predictions, references=references) 28 | print(results) 29 | 30 | # bleurt = evaluate.load('bleurt') 31 | # results = bleurt.compute(predictions=predictions, references=references) 32 | # print(results) 33 | 34 | # bertscore = evaluate.load('bertscore') 35 | # results = bertscore.compute(predictions=predictions, references=references) 36 | # print(results) 37 | 38 | if __name__ == "__main__": 39 | parser = argparse.ArgumentParser(description='arg parser') 40 | parser.add_argument('--pred_file', type=str, default='/TableLlama/ckpfinal_pred/fetaqa_pred.json', help='') 41 | args = parser.parse_args() 42 | main(args) -------------------------------------------------------------------------------- /eval_scripts/eval_col_pop.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from metric import * 4 | 5 | 6 | def get_map_recall(data, data_name): 7 | rs = [] 8 | recall = [] 9 | for i in range(len(data)): 10 | ground_truth = data[i]["target"].strip(".") 11 | # ground_truth = data[i]["target"].strip(".") 12 | pred = data[i]["predict"].strip(".") 13 | if "" in pred: 14 | end_tok_ix = pred.rfind("") 15 | pred = pred[:end_tok_ix] 16 | ground_truth_list = ground_truth.split(", ") 17 | # ground_truth_list = test_col_pop_rank[i]["target"].strip(".").split(", ") 18 | pred_list = pred.split(", ") 19 | for k in range(len(pred_list)): 20 | pred_list[k] = pred_list[k].strip("<>") 21 | 22 | # print(len(ground_truth_list), len(pred_list)) 23 | 24 | # import pdb 25 | # pdb.set_trace() 26 | # add to remove repeated generated item 27 | new_pred_list = list(set(pred_list)) 28 | new_pred_list.sort(key = pred_list.index) 29 | r = [1 if z in ground_truth_list else 0 for z in new_pred_list] 30 | ap = average_precision(r) 31 | # print("ap:", ap) 32 | rs.append(r) 33 | 34 | # if sum(r) != 0: 35 | # recall.append(sum(r)/len(ground_truth_list)) 36 | # else: 37 | # recall.append(0) 38 | recall.append(sum(r)/len(ground_truth_list)) 39 | map = mean_average_precision(rs) 40 | m_recall = sum(recall)/len(data) 41 | f1 = 2 * map * m_recall / (map + m_recall) 42 | print(data_name, len(data)) 43 | print("mean_average_precision:", map) 44 | 45 | 46 | def main(args): 47 | file = args.pred_file 48 | 49 | with open(file, "r") as f: 50 | col_pop = json.load(f) 51 | 52 | get_map_recall(col_pop, 'col_pop') 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser(description='arg parser') 56 | parser.add_argument('--pred_file', type=str, default='/TableLlama/ckpfinal_pred/col_pop_pred.json', help='') 57 | args = parser.parse_args() 58 | main(args) 59 | -------------------------------------------------------------------------------- /eval_scripts/eval_row_pop.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from metric import * 4 | 5 | 6 | def get_map_recall(data, data_name): 7 | rs = [] 8 | recall = [] 9 | ap_list = [] 10 | for i in range(len(data)): 11 | pred = data[i]["predict"].strip(".") 12 | if "" in pred: 13 | end_tok_ix = pred.rfind("") 14 | pred = pred[:end_tok_ix] 15 | ground_truth_list = data[i]["target"] 16 | pred_list = pred.split(", ") 17 | for k in range(len(pred_list)): 18 | pred_list[k] = 
pred_list[k].strip("<>") 19 | 20 | # add to remove repeated generated item 21 | new_pred_list = list(set(pred_list)) 22 | new_pred_list.sort(key = pred_list.index) 23 | # r = [1 if z in ground_truth_list else 0 for z in pred_list] 24 | r = [1 if z in ground_truth_list else 0 for z in new_pred_list] 25 | # ap = average_precision(r) 26 | ap = row_pop_average_precision(r, ground_truth_list) 27 | # print("ap:", ap) 28 | ap_list.append(ap) 29 | 30 | map = sum(ap_list)/len(data) 31 | m_recall = sum(recall)/len(data) 32 | f1 = 2 * map * m_recall / (map + m_recall) 33 | print(data_name, len(data)) 34 | print("mean_average_precision:", map) 35 | 36 | 37 | # def merge_pred_from_multi_files(data_path, store_path): 38 | # merged_data = [] 39 | # for i in range(6): 40 | # with open(data_path + "row_pop_pred_" + str(i) + ".json", "r") as f: 41 | # temp_data = json.load(f) 42 | # merged_data += temp_data 43 | # with open(store_path, "w") as f: 44 | # json.dump(merged_data, f, indent = 2) 45 | 46 | 47 | def main(args): 48 | with open(args.pred_file, "r") as f: 49 | row_pop = json.load(f) 50 | get_map_recall(row_pop, 'row_pop') 51 | 52 | 53 | if __name__ == "__main__": 54 | parser = argparse.ArgumentParser(description='arg parser') 55 | parser.add_argument('--pred_file', type=str, default='/TableLlama/ckpfinal_pred/row_pop_pred.json', help='') 56 | args = parser.parse_args() 57 | main(args) 58 | 59 | -------------------------------------------------------------------------------- /eval_scripts/qa_datadump_utils.py: -------------------------------------------------------------------------------- 1 | """ Utility functions for datadumping.""" 2 | import unicodedata 3 | import re 4 | from openpyxl.utils import get_column_letter, column_index_from_string 5 | import functools 6 | 7 | 8 | # Compare and sort cells 9 | def find_column(coord): 10 | """ Parse column letter from 'E3'. """ 11 | return re.findall('[a-zA-Z]+', coord) 12 | 13 | 14 | def find_row(coord): 15 | """ Parse row number from 'E3'. """ 16 | return re.findall('[0-9]+', coord) 17 | 18 | 19 | def cell_compare(cell1, cell2): 20 | """ Compare cell coord by row, then by column.""" 21 | col1, col2 = find_column(cell1)[0], find_column(cell2)[0] 22 | row1, row2 = find_row(cell1)[0], find_row(cell2)[0] 23 | if int(row1) < int(row2): 24 | return -1 25 | elif int(row1) > int(row2): 26 | return 1 27 | else: 28 | if column_index_from_string(col1) < column_index_from_string(col2): 29 | return -1 30 | else: 31 | return 1 32 | 33 | 34 | def linked_cell_compare(linked_cell_a, linked_cell_b): 35 | """ Compare answer cell coord by row, then by column.""" 36 | if isinstance(linked_cell_a[0], str) and isinstance(linked_cell_b[0], str): 37 | coord_a, coord_b = eval(linked_cell_a[0]), eval(linked_cell_b[0]) 38 | else: 39 | coord_a, coord_b = linked_cell_a[0], linked_cell_b[0] 40 | if coord_a[0] < coord_b[0]: 41 | return -1 42 | elif coord_a[0] > coord_b[0]: 43 | return 1 44 | else: 45 | if coord_a[1] < coord_b[1]: 46 | return -1 47 | else: 48 | return 1 49 | 50 | 51 | def sort_region_by_coord(cells): 52 | """ Sort cells by coords, according to cell_compare(). 
""" 53 | cell_list = sorted(cells, key=functools.cmp_to_key(cell_compare)) 54 | cell_matrix = [] 55 | last_row = None 56 | for cell in cell_list: 57 | col, row = find_column(cell), find_row(cell) 58 | if row == last_row: 59 | cell_matrix[-1].append(cell) 60 | else: 61 | last_row = row 62 | cell_matrix.append([cell]) 63 | return cell_list, cell_matrix 64 | 65 | 66 | # -------------------------------------------- 67 | # Normalize and Inferring Types. 68 | def normalize(x): 69 | """ Normalize header string. """ 70 | # Copied from WikiTableQuestions dataset official evaluator. 71 | if x is None: 72 | return None 73 | # Remove diacritics 74 | x = ''.join(c for c in unicodedata.normalize('NFKD', x) 75 | if unicodedata.category(c) != 'Mn') 76 | # Normalize quotes and dashes 77 | x = re.sub("[‘’´`]", "'", x) 78 | x = re.sub("[“”]", "\"", x) 79 | x = re.sub("[‐‑‒–—−]", "-", x) 80 | while True: 81 | old_x = x 82 | # Remove citations 83 | x = re.sub("((?").split(",") 90 | pred = item["predict"].split("")[0].split(", ") 91 | ground_truth_list.append(ground_truth) 92 | pred_list.append(pred) 93 | 94 | get_r_p_f1_for_each_type(ground_truth_list, pred_list) 95 | 96 | 97 | if __name__ == "__main__": 98 | parser = argparse.ArgumentParser(description='arg parser') 99 | parser.add_argument('--pred_file', type=str, default='/TableLlama/ckpfinal_pred/rel_extraction_pred.json', help='') 100 | args = parser.parse_args() 101 | main(args) -------------------------------------------------------------------------------- /eval_scripts/eval_col_type.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from collections import Counter 4 | from metric import * 5 | 6 | 7 | def r_p_f1(true_positive, pred, targets): 8 | if pred != 0: 9 | precision = true_positive/pred 10 | else: 11 | precision = 0 12 | recall = true_positive/targets 13 | if (precision + recall) != 0: 14 | f1 = 2 * precision * recall / (precision + recall) 15 | else: 16 | f1 = 0 17 | return recall, precision, f1 18 | 19 | def get_r_p_f1_for_each_type(ground_truth_list, pred_list): 20 | 21 | # import pdb 22 | # pdb.set_trace() 23 | 24 | total_ground_truth_col_types = 0 25 | total_pred_col_types = 0 26 | joint_items_list = [] 27 | for i in range(len(ground_truth_list)): 28 | total_ground_truth_col_types += len(ground_truth_list[i]) 29 | total_pred_col_types += len(pred_list[i]) 30 | joint_items = [item for item in pred_list[i] if item in ground_truth_list[i]] 31 | # total_ground_truth_col_types += len(list(set(ground_truth_list[i]))) 32 | # total_pred_col_types += len(list(set(pred_list[i]))) 33 | # joint_items = [item for item in list(set(pred_list[i])) if item in list(set(ground_truth_list[i]))] 34 | joint_items_list += joint_items 35 | 36 | # import pdb 37 | # pdb.set_trace() 38 | 39 | gt_entire_col_type = {} 40 | for i in range(len(ground_truth_list)): 41 | gt = list(set(ground_truth_list[i])) 42 | for k in range(len(gt)): 43 | if gt[k] not in gt_entire_col_type.keys(): 44 | gt_entire_col_type[gt[k]] = 1 45 | else: 46 | gt_entire_col_type[gt[k]] += 1 47 | # print(len(gt_entire_col_type.keys())) 48 | 49 | pd_entire_col_type = {} 50 | for i in range(len(pred_list)): 51 | pd = list(set(pred_list[i])) 52 | for k in range(len(pd)): 53 | if pd[k] not in pd_entire_col_type.keys(): 54 | pd_entire_col_type[pd[k]] = 1 55 | else: 56 | pd_entire_col_type[pd[k]] += 1 57 | # print(len(pd_entire_col_type.keys())) 58 | 59 | joint_entire_col_type = {} 60 | for i in range(len(joint_items_list)): 61 | if 
joint_items_list[i] not in joint_entire_col_type.keys(): 62 | joint_entire_col_type[joint_items_list[i]] = 1 63 | else: 64 | joint_entire_col_type[joint_items_list[i]] += 1 65 | # print(len(joint_entire_col_type.keys())) 66 | 67 | precision = len(joint_items_list)/total_pred_col_types 68 | recall = len(joint_items_list)/total_ground_truth_col_types 69 | f1 = 2 * precision * recall / (precision + recall) 70 | 71 | sorted_gt = sorted(gt_entire_col_type.items(), key=lambda x: x[1], reverse = True) 72 | # print(sorted_gt) 73 | # print("len(joint_items_list):", len(joint_items_list)) 74 | # print("total_ground_truth_col_types:", total_ground_truth_col_types) 75 | # print("total_pred_col_types:", total_pred_col_types) 76 | print("precision::", precision) 77 | print("recall:", recall) 78 | print("f1:", f1) 79 | 80 | # print('r_p_f1(joint_entire_col_type["people.person"])', r_p_f1(joint_entire_col_type["people.person"], pd_entire_col_type["people.person"], gt_entire_col_type["people.person"])) 81 | # print('r_p_f1(joint_entire_col_type["sports.pro_athlete"])', r_p_f1(joint_entire_col_type["sports.pro_athlete"], pd_entire_col_type["sports.pro_athlete"], gt_entire_col_type["sports.pro_athlete"])) 82 | # print('r_p_f1(joint_entire_col_type["film.actor"])', r_p_f1(joint_entire_col_type["film.actor"], pd_entire_col_type["film.actor"], gt_entire_col_type["film.actor"])) 83 | # print('r_p_f1(joint_entire_col_type["location.location"])', r_p_f1(joint_entire_col_type["location.location"], pd_entire_col_type["location.location"], gt_entire_col_type["location.location"])) 84 | # print('r_p_f1(joint_entire_col_type["location.citytown"])', r_p_f1(joint_entire_col_type["location.citytown"], pd_entire_col_type["location.citytown"], gt_entire_col_type["location.citytown"])) 85 | 86 | 87 | def remove_ele(list, ele): 88 | while True: 89 | if ele in list: 90 | list.remove(ele) 91 | continue 92 | else: 93 | break 94 | return list 95 | 96 | def get_index(lst=None, item=''): 97 | return [i for i in range(len(lst)) if lst[i] == item] 98 | 99 | def main(args): 100 | with open(args.pred_file, "r") as f: 101 | col_type = json.load(f) 102 | 103 | ground_truth_list = [] 104 | pred_list = [] 105 | for i in range(len(col_type)): 106 | item = col_type[i] 107 | ground_truth = item["ground_truth"] 108 | # pred = item["predict"].strip("").split(",") 109 | pred = item["predict"].split("")[0].split(", ") 110 | ground_truth_list.append(ground_truth) 111 | pred_list.append(pred) 112 | 113 | get_r_p_f1_for_each_type(ground_truth_list, pred_list) 114 | 115 | 116 | if __name__ == "__main__": 117 | parser = argparse.ArgumentParser(description='arg parser') 118 | parser.add_argument('--pred_file', type=str, default='/TableLlama/ckpfinal_pred/col_type_pred.json', help='') 119 | args = parser.parse_args() 120 | main(args) 121 | -------------------------------------------------------------------------------- /inference_rel_extraction_col_type.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import sys 4 | import math 5 | import torch 6 | import argparse 7 | # import textwrap 8 | import transformers 9 | from peft import PeftModel 10 | from transformers import GenerationConfig 11 | from llama_attn_replace import replace_llama_attn 12 | from supervised_fine_tune import PROMPT_DICT 13 | from tqdm import tqdm 14 | # from queue import Queue 15 | # from threading import Thread 16 | # import gradio as gr 17 | 18 | def parse_config(): 19 | parser = 
argparse.ArgumentParser(description='arg parser') 20 | parser.add_argument('--base_model', type=str, default="/data1/pretrained-models/llama-7b-hf") 21 | parser.add_argument('--cache_dir', type=str, default="./cache") 22 | parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning') 23 | parser.add_argument('--flash_attn', type=bool, default=False, help='') 24 | parser.add_argument('--temperature', type=float, default=0.6, help='') 25 | parser.add_argument('--top_p', type=float, default=0.9, help='') 26 | parser.add_argument('--max_gen_len', type=int, default=512, help='') 27 | parser.add_argument('--input_data_file', type=str, default='input_data/', help='') 28 | parser.add_argument('--output_data_file', type=str, default='output_data/', help='') 29 | args = parser.parse_args() 30 | return args 31 | 32 | def generate_prompt(instruction, question, input_seg=None): 33 | if input: 34 | return PROMPT_DICT["prompt_input"].format(instruction=instruction, input_seg=input_seg, question=question) 35 | else: 36 | return PROMPT_DICT["prompt_no_input"].format(instruction=instruction) 37 | 38 | 39 | def build_generator( 40 | item, model, tokenizer, temperature=0.6, top_p=0.9, max_gen_len=4096, use_cache=True 41 | ): 42 | def response(item): 43 | # def response(material, question, material_type="", material_title=None): 44 | # material = read_txt_file(material) 45 | # prompt = format_prompt(material, question, material_type, material_title) 46 | prompt = generate_prompt(instruction = item["instruction"], input_seg = item["input_seg"], question = item["question"]) 47 | inputs = tokenizer(prompt, return_tensors="pt").to(model.device) 48 | 49 | output = model.generate( 50 | **inputs, 51 | max_new_tokens=max_gen_len, 52 | temperature=temperature, 53 | top_p=top_p, 54 | use_cache=use_cache 55 | ) 56 | out = tokenizer.decode(output[0], skip_special_tokens=False, clean_up_tokenization_spaces=False) 57 | 58 | out = out.split(prompt)[1].strip() 59 | return out 60 | 61 | return response 62 | 63 | def main(args): 64 | if args.flash_attn: 65 | replace_llama_attn() 66 | 67 | # Set RoPE scaling factor 68 | config = transformers.AutoConfig.from_pretrained( 69 | args.base_model, 70 | cache_dir=args.cache_dir, 71 | ) 72 | 73 | orig_ctx_len = getattr(config, "max_position_embeddings", None) 74 | if orig_ctx_len and args.context_size > orig_ctx_len: 75 | scaling_factor = float(math.ceil(args.context_size / orig_ctx_len)) 76 | config.rope_scaling = {"type": "linear", "factor": scaling_factor} 77 | 78 | # Load model and tokenizer 79 | model = transformers.AutoModelForCausalLM.from_pretrained( 80 | args.base_model, 81 | config=config, 82 | cache_dir=args.cache_dir, 83 | torch_dtype=torch.float16, 84 | device_map="auto", 85 | ) 86 | model.resize_token_embeddings(32001) 87 | 88 | tokenizer = transformers.AutoTokenizer.from_pretrained( 89 | args.base_model, 90 | cache_dir=args.cache_dir, 91 | model_max_length=args.context_size if args.context_size > orig_ctx_len else orig_ctx_len, 92 | # padding_side="right", 93 | padding_side="left", 94 | use_fast=False, 95 | ) 96 | 97 | model.eval() 98 | if torch.__version__ >= "2" and sys.platform != "win32": 99 | model = torch.compile(model) 100 | 101 | with open(args.input_data_file, "r") as f: 102 | test_data = json.load(f) 103 | 104 | # import random 105 | # test_data = random.sample(test_data, k=2) 106 | 107 | test_data_pred = [] 108 | for i in tqdm(range(len(test_data))): 109 | item = test_data[i] 110 | new_item = {} 111 | respond = 
build_generator(item, model, tokenizer, temperature=args.temperature, top_p=args.top_p, 112 | max_gen_len=args.max_gen_len, use_cache=not args.flash_attn) # the temperature and top_p are highly different with previous alpaca exp, pay attention to this if there is sth wrong later 113 | output = respond(item) 114 | 115 | new_item["idx"] = i 116 | new_item["table_id"] = test_data[i]["table_id"] 117 | new_item["instruction"] = test_data[i]["instruction"] 118 | new_item["input_seg"] = test_data[i]["input_seg"] 119 | new_item["question"] = test_data[i]["question"] 120 | new_item["ground_truth"] = test_data[i]["ground_truth"] 121 | new_item["output"] = test_data[i]["output"] 122 | new_item["predict"] = output 123 | 124 | test_data_pred.append(new_item) 125 | # import pdb 126 | # pdb.set_trace() 127 | with open(args.output_data_file, "w") as f: 128 | json.dump(test_data_pred, f, indent = 2) 129 | 130 | 131 | if __name__ == "__main__": 132 | args = parse_config() 133 | main(args) 134 | 135 | -------------------------------------------------------------------------------- /inference_hitab_tabfact_fetaqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import sys 4 | import math 5 | import torch 6 | import argparse 7 | # import textwrap 8 | import transformers 9 | from peft import PeftModel 10 | from transformers import GenerationConfig 11 | from llama_attn_replace import replace_llama_attn 12 | from supervised_fine_tune import PROMPT_DICT 13 | from tqdm import tqdm 14 | # from queue import Queue 15 | # from threading import Thread 16 | # import gradio as gr 17 | 18 | 19 | def parse_config(): 20 | parser = argparse.ArgumentParser(description='arg parser') 21 | parser.add_argument('--base_model', type=str, default="/data1/pretrained-models/llama-7b-hf") 22 | parser.add_argument('--cache_dir', type=str, default="./cache") 23 | parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning') 24 | parser.add_argument('--flash_attn', type=bool, default=False, help='') 25 | parser.add_argument('--temperature', type=float, default=0.6, help='') 26 | parser.add_argument('--top_p', type=float, default=0.9, help='') 27 | parser.add_argument('--max_gen_len', type=int, default=512, help='') 28 | parser.add_argument('--input_data_file', type=str, default='input_data/', help='') 29 | parser.add_argument('--output_data_file', type=str, default='output_data/', help='') 30 | args = parser.parse_args() 31 | return args 32 | 33 | def generate_prompt(instruction, question, input_seg=None): 34 | if input: 35 | return PROMPT_DICT["prompt_input"].format(instruction=instruction, input_seg=input_seg, question=question) 36 | else: 37 | return PROMPT_DICT["prompt_no_input"].format(instruction=instruction) 38 | 39 | 40 | def build_generator( 41 | item, model, tokenizer, temperature=0.6, top_p=0.9, max_gen_len=4096, use_cache=True 42 | ): 43 | def response(item): 44 | # def response(material, question, material_type="", material_title=None): 45 | # material = read_txt_file(material) 46 | # prompt = format_prompt(material, question, material_type, material_title) 47 | prompt = generate_prompt(instruction = item["instruction"], input_seg = item["input_seg"], question = item["question"]) 48 | inputs = tokenizer(prompt, return_tensors="pt").to(model.device) 49 | 50 | output = model.generate( 51 | **inputs, 52 | max_new_tokens=max_gen_len, 53 | temperature=temperature, 54 | top_p=top_p, 55 | use_cache=use_cache 56 | ) 57 | out = 
tokenizer.decode(output[0], skip_special_tokens=False, clean_up_tokenization_spaces=False) 58 | 59 | out = out.split(prompt)[1].strip() 60 | return out 61 | 62 | return response 63 | 64 | def main(args): 65 | if args.flash_attn: 66 | replace_llama_attn() 67 | 68 | # Set RoPE scaling factor 69 | config = transformers.AutoConfig.from_pretrained( 70 | args.base_model, 71 | cache_dir=args.cache_dir, 72 | ) 73 | 74 | orig_ctx_len = getattr(config, "max_position_embeddings", None) 75 | if orig_ctx_len and args.context_size > orig_ctx_len: 76 | scaling_factor = float(math.ceil(args.context_size / orig_ctx_len)) 77 | config.rope_scaling = {"type": "linear", "factor": scaling_factor} 78 | 79 | # Load model and tokenizer 80 | model = transformers.AutoModelForCausalLM.from_pretrained( 81 | args.base_model, 82 | config=config, 83 | cache_dir=args.cache_dir, 84 | torch_dtype=torch.float16, 85 | device_map="auto", 86 | ) 87 | model.resize_token_embeddings(32001) 88 | 89 | tokenizer = transformers.AutoTokenizer.from_pretrained( 90 | args.base_model, 91 | cache_dir=args.cache_dir, 92 | model_max_length=args.context_size if args.context_size > orig_ctx_len else orig_ctx_len, 93 | # padding_side="right", 94 | padding_side="left", 95 | use_fast=False, 96 | ) 97 | 98 | model.eval() 99 | if torch.__version__ >= "2" and sys.platform != "win32": 100 | model = torch.compile(model) 101 | 102 | with open(args.input_data_file, "r") as f: 103 | test_data = json.load(f) 104 | 105 | # import random 106 | # test_data = random.sample(test_data, k=3) 107 | 108 | test_data_pred = [] 109 | for i in tqdm(range(len(test_data))): 110 | item = test_data[i] 111 | new_item = {} 112 | respond = build_generator(item, model, tokenizer, temperature=args.temperature, top_p=args.top_p, 113 | max_gen_len=args.max_gen_len, use_cache=not args.flash_attn) # the temperature and top_p are highly different with previous alpaca exp, pay attention to this if there is sth wrong later 114 | output = respond(item) 115 | 116 | new_item["idx"] = i 117 | # new_item["table_id"] = test_data[i]["table_id"] 118 | new_item["instruction"] = test_data[i]["instruction"] 119 | new_item["input_seg"] = test_data[i]["input_seg"] 120 | new_item["question"] = test_data[i]["question"] 121 | # new_item["ground_truth"] = test_data[i]["ground_truth"] 122 | new_item["output"] = test_data[i]["output"] 123 | new_item["predict"] = output 124 | 125 | test_data_pred.append(new_item) 126 | # import pdb 127 | # pdb.set_trace() 128 | with open(args.output_data_file, "w") as f: 129 | json.dump(test_data_pred, f, indent = 2) 130 | 131 | 132 | if __name__ == "__main__": 133 | args = parse_config() 134 | main(args) 135 | 136 | -------------------------------------------------------------------------------- /inference_ent_link.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import sys 4 | import math 5 | import torch 6 | import argparse 7 | # import textwrap 8 | import transformers 9 | from peft import PeftModel 10 | from transformers import GenerationConfig 11 | from llama_attn_replace import replace_llama_attn 12 | from supervised_fine_tune import PROMPT_DICT 13 | from tqdm import tqdm 14 | # from queue import Queue 15 | # from threading import Thread 16 | # import gradio as gr 17 | 18 | def parse_config(): 19 | parser = argparse.ArgumentParser(description='arg parser') 20 | parser.add_argument('--base_model', type=str, default="/data1/pretrained-models/llama-7b-hf") 21 | parser.add_argument('--cache_dir', 
type=str, default="./cache") 22 | parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning') 23 | parser.add_argument('--flash_attn', type=bool, default=False, help='') 24 | parser.add_argument('--temperature', type=float, default=0.6, help='') 25 | parser.add_argument('--top_p', type=float, default=0.9, help='') 26 | parser.add_argument('--max_gen_len', type=int, default=512, help='') 27 | parser.add_argument('--input_data_file', type=str, default='input_data/', help='') 28 | parser.add_argument('--output_data_file', type=str, default='output_data/', help='') 29 | args = parser.parse_args() 30 | return args 31 | 32 | def generate_prompt(instruction, question, input_seg=None): 33 | if input: 34 | return PROMPT_DICT["prompt_input"].format(instruction=instruction, input_seg=input_seg, question=question) 35 | else: 36 | return PROMPT_DICT["prompt_no_input"].format(instruction=instruction) 37 | 38 | 39 | def build_generator( 40 | item, model, tokenizer, temperature=0.6, top_p=0.9, max_gen_len=4096, use_cache=True 41 | ): 42 | def response(item): 43 | # def response(material, question, material_type="", material_title=None): 44 | # material = read_txt_file(material) 45 | # prompt = format_prompt(material, question, material_type, material_title) 46 | prompt = generate_prompt(instruction = item["instruction"], input_seg = item["input_seg"], question = item["question"]) 47 | inputs = tokenizer(prompt, return_tensors="pt").to(model.device) 48 | 49 | output = model.generate( 50 | **inputs, 51 | max_new_tokens=max_gen_len, 52 | temperature=temperature, 53 | top_p=top_p, 54 | use_cache=use_cache 55 | ) 56 | out = tokenizer.decode(output[0], skip_special_tokens=False, clean_up_tokenization_spaces=False) 57 | 58 | out = out.split(prompt)[1].strip() 59 | return out 60 | 61 | return response 62 | 63 | def main(args): 64 | if args.flash_attn: 65 | replace_llama_attn() 66 | 67 | # Set RoPE scaling factor 68 | config = transformers.AutoConfig.from_pretrained( 69 | args.base_model, 70 | cache_dir=args.cache_dir, 71 | ) 72 | 73 | orig_ctx_len = getattr(config, "max_position_embeddings", None) 74 | if orig_ctx_len and args.context_size > orig_ctx_len: 75 | scaling_factor = float(math.ceil(args.context_size / orig_ctx_len)) 76 | config.rope_scaling = {"type": "linear", "factor": scaling_factor} 77 | 78 | # Load model and tokenizer 79 | model = transformers.AutoModelForCausalLM.from_pretrained( 80 | args.base_model, 81 | config=config, 82 | cache_dir=args.cache_dir, 83 | torch_dtype=torch.float16, 84 | device_map="auto", 85 | ) 86 | model.resize_token_embeddings(32001) 87 | 88 | tokenizer = transformers.AutoTokenizer.from_pretrained( 89 | args.base_model, 90 | cache_dir=args.cache_dir, 91 | model_max_length=args.context_size if args.context_size > orig_ctx_len else orig_ctx_len, 92 | # padding_side="right", 93 | padding_side="left", 94 | use_fast=False, 95 | ) 96 | 97 | model.eval() 98 | if torch.__version__ >= "2" and sys.platform != "win32": 99 | model = torch.compile(model) 100 | 101 | with open(args.input_data_file, "r") as f: 102 | test_data = json.load(f) 103 | 104 | # import random 105 | # test_data = random.sample(test_data, k=5) 106 | 107 | test_data_pred = [] 108 | for i in tqdm(range(len(test_data))): 109 | item = test_data[i] 110 | new_item = {} 111 | respond = build_generator(item, model, tokenizer, temperature=args.temperature, top_p=args.top_p, 112 | max_gen_len=args.max_gen_len, use_cache=not args.flash_attn) # the temperature and top_p are highly 
different with previous alpaca exp, pay attention to this if there is sth wrong later 113 | output = respond(item) 114 | 115 | new_item["idx"] = i 116 | new_item["table_id"] = test_data[i]["id"] 117 | new_item["instruction"] = test_data[i]["instruction"] 118 | new_item["input_seg"] = test_data[i]["input_seg"] 119 | new_item["question"] = test_data[i]["question"] 120 | new_item["candidates_list"] = test_data[i]["candidates_list"] 121 | new_item["candidates_entity_desc_list"] = test_data[i]["candidates_entity_desc_list"] 122 | new_item["output"] = test_data[i]["output"] 123 | new_item["predict"] = output 124 | 125 | test_data_pred.append(new_item) 126 | # import pdb 127 | # pdb.set_trace() 128 | with open(args.output_data_file, "w") as f: 129 | json.dump(test_data_pred, f, indent = 2) 130 | 131 | if __name__ == "__main__": 132 | args = parse_config() 133 | main(args) 134 | 135 | 136 | -------------------------------------------------------------------------------- /inference_row_pop.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import sys 4 | import math 5 | import torch 6 | import argparse 7 | # import textwrap 8 | import transformers 9 | from peft import PeftModel 10 | from transformers import GenerationConfig 11 | from llama_attn_replace import replace_llama_attn 12 | from supervised_fine_tune import PROMPT_DICT 13 | from tqdm import tqdm 14 | # from queue import Queue 15 | # from threading import Thread 16 | # import gradio as gr 17 | 18 | def parse_config(): 19 | parser = argparse.ArgumentParser(description='arg parser') 20 | parser.add_argument('--base_model', type=str, default="/data1/pretrained-models/llama-7b-hf") 21 | parser.add_argument('--cache_dir', type=str, default="./cache") 22 | parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning') 23 | parser.add_argument('--flash_attn', type=bool, default=False, help='') 24 | parser.add_argument('--temperature', type=float, default=0.6, help='') 25 | parser.add_argument('--top_p', type=float, default=0.9, help='') 26 | parser.add_argument('--max_gen_len', type=int, default=512, help='') 27 | parser.add_argument('--input_data_file', type=str, default='input_data/', help='') 28 | parser.add_argument('--output_data_file', type=str, default='output_data/', help='') 29 | args = parser.parse_args() 30 | return args 31 | 32 | def generate_prompt(instruction, question, input_seg=None): 33 | if input: 34 | return PROMPT_DICT["prompt_input"].format(instruction=instruction, input_seg=input_seg, question=question) 35 | else: 36 | return PROMPT_DICT["prompt_no_input"].format(instruction=instruction) 37 | 38 | 39 | def build_generator( 40 | item, model, tokenizer, temperature=0.6, top_p=0.9, max_gen_len=4096, use_cache=True 41 | ): 42 | def response(item): 43 | # def response(material, question, material_type="", material_title=None): 44 | # material = read_txt_file(material) 45 | # prompt = format_prompt(material, question, material_type, material_title) 46 | prompt = generate_prompt(instruction = item["instruction"], input_seg = item["input_seg"], question = item["question"]) 47 | inputs = tokenizer(prompt, return_tensors="pt").to(model.device) 48 | 49 | output = model.generate( 50 | **inputs, 51 | max_new_tokens=max_gen_len, 52 | temperature=temperature, 53 | top_p=top_p, 54 | use_cache=use_cache 55 | ) 56 | out = tokenizer.decode(output[0], skip_special_tokens=False, clean_up_tokenization_spaces=False) 57 | 58 | out = 
out.split(prompt)[1].strip() 59 | return out 60 | 61 | return response 62 | 63 | def main(args): 64 | if args.flash_attn: 65 | replace_llama_attn() 66 | 67 | # Set RoPE scaling factor 68 | config = transformers.AutoConfig.from_pretrained( 69 | args.base_model, 70 | cache_dir=args.cache_dir, 71 | ) 72 | 73 | orig_ctx_len = getattr(config, "max_position_embeddings", None) 74 | if orig_ctx_len and args.context_size > orig_ctx_len: 75 | scaling_factor = float(math.ceil(args.context_size / orig_ctx_len)) 76 | config.rope_scaling = {"type": "linear", "factor": scaling_factor} 77 | 78 | # Load model and tokenizer 79 | model = transformers.AutoModelForCausalLM.from_pretrained( 80 | args.base_model, 81 | config=config, 82 | cache_dir=args.cache_dir, 83 | torch_dtype=torch.float16, 84 | device_map="auto", 85 | ) 86 | model.resize_token_embeddings(32001) 87 | 88 | tokenizer = transformers.AutoTokenizer.from_pretrained( 89 | args.base_model, 90 | cache_dir=args.cache_dir, 91 | model_max_length=args.context_size if args.context_size > orig_ctx_len else orig_ctx_len, 92 | # padding_side="right", 93 | padding_side="left", 94 | use_fast=False, 95 | ) 96 | 97 | model.eval() 98 | if torch.__version__ >= "2" and sys.platform != "win32": 99 | model = torch.compile(model) 100 | 101 | with open(args.input_data_file, "r") as f: 102 | test_data = json.load(f) 103 | 104 | # import random 105 | # test_data = random.sample(test_data, k=1) 106 | 107 | test_data_pred = [] 108 | for i in tqdm(range(len(test_data))): 109 | item = test_data[i] 110 | new_item = {} 111 | respond = build_generator(item, model, tokenizer, temperature=args.temperature, top_p=args.top_p, 112 | max_gen_len=args.max_gen_len, use_cache=not args.flash_attn) # the temperature and top_p are highly different with previous alpaca exp, pay attention to this if there is sth wrong later 113 | output = respond(item) 114 | 115 | new_item["idx"] = i 116 | new_item["table_id"] = test_data[i]["table_id"] 117 | new_item["seed_ent"] = test_data[i]["seed_ent"] 118 | new_item["instruction"] = test_data[i]["instruction"] 119 | new_item["input_seg"] = test_data[i]["input_seg"] 120 | new_item["question"] = test_data[i]["question"] 121 | new_item["cand_list"] = test_data[i]["cand_list"] 122 | new_item["target"] = test_data[i]["target"] 123 | # new_item["output"] = test_data[i]["output"] 124 | new_item["predict"] = output 125 | 126 | test_data_pred.append(new_item) 127 | # import pdb 128 | # pdb.set_trace() 129 | with open(args.output_data_file, "w") as f: 130 | json.dump(test_data_pred, f, indent = 2) 131 | 132 | if __name__ == "__main__": 133 | args = parse_config() 134 | main(args) 135 | 136 | -------------------------------------------------------------------------------- /eval_scripts/metric.py: -------------------------------------------------------------------------------- 1 | """Information Retrieval metrics 2 | Useful Resources: 3 | http://www.cs.utexas.edu/~mooney/ir-course/slides/Evaluation.ppt 4 | http://www.nii.ac.jp/TechReports/05-014E.pdf 5 | http://www.stanford.edu/class/cs276/handouts/EvaluationNew-handout-6-per.pdf 6 | http://hal.archives-ouvertes.fr/docs/00/72/67/60/PDF/07-busa-fekete.pdf 7 | Learning to Rank for Information Retrieval (Tie-Yan Liu) 8 | """ 9 | import numpy as np 10 | import pdb 11 | import os 12 | import pickle 13 | 14 | 15 | def mean_reciprocal_rank(rs): 16 | """Score is reciprocal of the rank of the first relevant item 17 | First element is 'rank 1'. Relevance is binary (nonzero is relevant). 
18 | Example from http://en.wikipedia.org/wiki/Mean_reciprocal_rank 19 | >>> rs = [[0, 0, 1], [0, 1, 0], [1, 0, 0]] 20 | >>> mean_reciprocal_rank(rs) 21 | 0.61111111111111105 22 | >>> rs = np.array([[0, 0, 0], [0, 1, 0], [1, 0, 0]]) 23 | >>> mean_reciprocal_rank(rs) 24 | 0.5 25 | >>> rs = [[0, 0, 0, 1], [1, 0, 0], [1, 0, 0]] 26 | >>> mean_reciprocal_rank(rs) 27 | 0.75 28 | Args: 29 | rs: Iterator of relevance scores (list or numpy) in rank order 30 | (first element is the first item) 31 | Returns: 32 | Mean reciprocal rank 33 | """ 34 | rs = (np.asarray(r).nonzero()[0] for r in rs) 35 | return np.mean([1. / (r[0] + 1) if r.size else 0. for r in rs]) 36 | 37 | 38 | def r_precision(r): 39 | """Score is precision after all relevant documents have been retrieved 40 | Relevance is binary (nonzero is relevant). 41 | >>> r = [0, 0, 1] 42 | >>> r_precision(r) 43 | 0.33333333333333331 44 | >>> r = [0, 1, 0] 45 | >>> r_precision(r) 46 | 0.5 47 | >>> r = [1, 0, 0] 48 | >>> r_precision(r) 49 | 1.0 50 | Args: 51 | r: Relevance scores (list or numpy) in rank order 52 | (first element is the first item) 53 | Returns: 54 | R Precision 55 | """ 56 | r = np.asarray(r) != 0 57 | z = r.nonzero()[0] 58 | if not z.size: 59 | return 0. 60 | return np.mean(r[:z[-1] + 1]) 61 | 62 | 63 | def precision_at_k(r, k): 64 | """Score is precision @ k 65 | Relevance is binary (nonzero is relevant). 66 | >>> r = [0, 0, 1] 67 | >>> precision_at_k(r, 1) 68 | 0.0 69 | >>> precision_at_k(r, 2) 70 | 0.0 71 | >>> precision_at_k(r, 3) 72 | 0.33333333333333331 73 | >>> precision_at_k(r, 4) 74 | Traceback (most recent call last): 75 | File "", line 1, in ? 76 | ValueError: Relevance score length < k 77 | Args: 78 | r: Relevance scores (list or numpy) in rank order 79 | (first element is the first item) 80 | Returns: 81 | Precision @ k 82 | Raises: 83 | ValueError: len(r) must be >= k 84 | """ 85 | assert k >= 1 86 | r = np.asarray(r)[:k] != 0 87 | if r.size != k: 88 | raise ValueError('Relevance score length < k') 89 | return np.mean(r) 90 | 91 | 92 | def average_precision(r): 93 | """Score is average precision (area under PR curve) 94 | Relevance is binary (nonzero is relevant). 95 | >>> r = [1, 1, 0, 1, 0, 1, 0, 0, 0, 1] 96 | >>> delta_r = 1. / sum(r) 97 | >>> sum([sum(r[:x + 1]) / (x + 1.) * delta_r for x, y in enumerate(r) if y]) 98 | 0.7833333333333333 99 | >>> average_precision(r) 100 | 0.78333333333333333 101 | Args: 102 | r: Relevance scores (list or numpy) in rank order 103 | (first element is the first item) 104 | Returns: 105 | Average precision 106 | """ 107 | r = np.asarray(r) != 0 108 | out = [precision_at_k(r, k + 1) for k in range(r.size) if r[k]] 109 | if not out: 110 | return 0. 111 | return np.mean(out) 112 | 113 | 114 | def mean_average_precision(rs): 115 | """Score is mean average precision 116 | Relevance is binary (nonzero is relevant). 117 | >>> rs = [[1, 1, 0, 1, 0, 1, 0, 0, 0, 1]] 118 | >>> mean_average_precision(rs) 119 | 0.78333333333333333 120 | >>> rs = [[1, 1, 0, 1, 0, 1, 0, 0, 0, 1], [0]] 121 | >>> mean_average_precision(rs) 122 | 0.39166666666666666 123 | Args: 124 | rs: Iterator of relevance scores (list or numpy) in rank order 125 | (first element is the first item) 126 | Returns: 127 | Mean average precision 128 | """ 129 | return np.mean([average_precision(r) for r in rs]) 130 | 131 | def row_pop_average_precision(r, target): 132 | """Score is average precision (area under PR curve) 133 | Relevance is binary (nonzero is relevant). 
134 | >>> r = [1, 1, 0, 1, 0, 1, 0, 0, 0, 1] 135 | >>> delta_r = 1. / sum(r) 136 | >>> sum([sum(r[:x + 1]) / (x + 1.) * delta_r for x, y in enumerate(r) if y]) 137 | 0.7833333333333333 138 | >>> average_precision(r) 139 | 0.78333333333333333 140 | Args: 141 | r: Relevance scores (list or numpy) in rank order 142 | (first element is the first item) 143 | Returns: 144 | Average precision 145 | """ 146 | r = np.asarray(r) != 0 147 | out = [precision_at_k(r, k + 1) for k in range(r.size) if r[k]] 148 | if len(out) < len(target): 149 | out += [0] * (len(target) - len(out)) 150 | if not out: 151 | return 0. 152 | return np.mean(out) 153 | 154 | 155 | def dcg_at_k(r, k, method=0): 156 | """Score is discounted cumulative gain (dcg) 157 | Relevance is positive real values. Can use binary 158 | as the previous methods. 159 | Example from 160 | http://www.stanford.edu/class/cs276/handouts/EvaluationNew-handout-6-per.pdf 161 | >>> r = [3, 2, 3, 0, 0, 1, 2, 2, 3, 0] 162 | >>> dcg_at_k(r, 1) 163 | 3.0 164 | >>> dcg_at_k(r, 1, method=1) 165 | 3.0 166 | >>> dcg_at_k(r, 2) 167 | 5.0 168 | >>> dcg_at_k(r, 2, method=1) 169 | 4.2618595071429155 170 | >>> dcg_at_k(r, 10) 171 | 9.6051177391888114 172 | >>> dcg_at_k(r, 11) 173 | 9.6051177391888114 174 | Args: 175 | r: Relevance scores (list or numpy) in rank order 176 | (first element is the first item) 177 | k: Number of results to consider 178 | method: If 0 then weights are [1.0, 1.0, 0.6309, 0.5, 0.4307, ...] 179 | If 1 then weights are [1.0, 0.6309, 0.5, 0.4307, ...] 180 | Returns: 181 | Discounted cumulative gain 182 | """ 183 | r = np.asfarray(r)[:k] 184 | if r.size: 185 | if method == 0: 186 | return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1))) 187 | elif method == 1: 188 | return np.sum(r / np.log2(np.arange(2, r.size + 2))) 189 | else: 190 | raise ValueError('method must be 0 or 1.') 191 | return 0. 192 | 193 | 194 | def ndcg_at_k(r, k, method=0): 195 | """Score is normalized discounted cumulative gain (ndcg) 196 | Relevance is positive real values. Can use binary 197 | as the previous methods. 198 | Example from 199 | http://www.stanford.edu/class/cs276/handouts/EvaluationNew-handout-6-per.pdf 200 | >>> r = [3, 2, 3, 0, 0, 1, 2, 2, 3, 0] 201 | >>> ndcg_at_k(r, 1) 202 | 1.0 203 | >>> r = [2, 1, 2, 0] 204 | >>> ndcg_at_k(r, 4) 205 | 0.9203032077642922 206 | >>> ndcg_at_k(r, 4, method=1) 207 | 0.96519546960144276 208 | >>> ndcg_at_k([0], 1) 209 | 0.0 210 | >>> ndcg_at_k([1], 2) 211 | 1.0 212 | Args: 213 | r: Relevance scores (list or numpy) in rank order 214 | (first element is the first item) 215 | k: Number of results to consider 216 | method: If 0 then weights are [1.0, 1.0, 0.6309, 0.5, 0.4307, ...] 217 | If 1 then weights are [1.0, 0.6309, 0.5, 0.4307, ...] 218 | Returns: 219 | Normalized discounted cumulative gain 220 | """ 221 | dcg_max = dcg_at_k(sorted(r, reverse=True), k, method) 222 | if not dcg_max: 223 | return 0. 224 | return dcg_at_k(r, k, method) / dcg_max 225 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# TableLlama: Towards Open Large Generalist Models for Tables

2 | 3 |
4 | 🔥 🔥 🔥 This repo contains the code, data, and models for TableLlama. 5 | Check out our [Project Page] for more results and analysis! 6 |
7 | 8 |
9 |
10 | [Figure 1 image: imgs/tablellama_figure1.png] 11 |
12 | Figure 1: An overview of TableInstruct and TableLlama. TableInstruct includes a wide variety of realistic tables and tasks with instructions. We make the first step towards developing open-source generalist models for tables with TableInstruct and TableLlama. 13 | 14 |
15 |
16 | [Figure 2 image: imgs/tablellama_figure2.png] 17 |
18 | Figure 2: Illustration of three exemplary tasks: (a) Column type annotation. This task is to annotate the selected column with the correct semantic types. (b) Row population. This task is to populate rows given table metadata and partial row entities. (c) Hierarchical table QA. For subfigures (a) and (b), we mark candidates with red color in the "task instruction" part. The candidate set size can be hundreds to thousands in TableInstruct. 19 | 20 | 21 |

## Release progress

22 | 23 | - :ballot_box_with_check: Training Dataset for TableLlama (check `/data_v3` of 🤗 [TableInstruct Dataset](https://huggingface.co/datasets/osunlp/TableInstruct/)) 24 | - :ballot_box_with_check: TableLlama-7B model 25 | - :ballot_box_with_check: Code for Fine-tuning and Inference 26 | - :ballot_box_with_check: Evaluation Dataset of TableInstruct (check `/eval_data` of 🤗 [TableInstruct Dataset](https://huggingface.co/datasets/osunlp/TableInstruct/)) 27 | 28 | 29 |
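Both dataset items in the checklist above live in the same Hugging Face dataset repo. The snippet below is a minimal, hedged sketch of one way to pull the files locally; it assumes `huggingface_hub` is available in your environment, and the folder names `data_v3/` and `eval_data/` are taken from the checklist above rather than from an inspected file listing.

```python
# Sketch only: download a snapshot of the TableInstruct dataset repo from the Hugging Face Hub.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(repo_id="osunlp/TableInstruct", repo_type="dataset")

# Per the checklist above, training data should sit under <local_dir>/data_v3/
# and the evaluation sets under <local_dir>/eval_data/.
print(local_dir)
```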

## Updates

30 | 31 | - 2024/3/13: Our paper has been accepted by NAACL 2024! 32 | - 2024/3/21: We refine the prompts of 4 out-of-domain evaluation datasets: FEVEROUS, HybridQA, WikiSQL and WikiTQ of [TableInstruct](https://huggingface.co/datasets/osunlp/TableInstruct/) and update the results. Check the new results! 33 | - 2024/3/21: We add the results of closed-source LLMs: GPT-3.5 and GPT-4. 34 | 35 | ### Datasets and Models 36 | Our dataset and models are all available at Huggingface. 37 | 38 | 🤗 [TableInstruct Dataset](https://huggingface.co/datasets/osunlp/TableInstruct/) 39 | 40 | 🤗 [TableLlama-7B](https://osu-nlp-group.github.io/TableLlama/) 41 | 42 | The model is fine-tuned with the TableInstruct dataset using LongLoRA (7B), fully fine-tuning version as the base model, which replaces the vanilla attention mechanism of the original Llama-2 (7B) with shift short attention. The training takes 9 days on a 48 80*A100 cluster. Check out our paper for more details. 43 | 44 | TableInstruct includes a comprehensive table-based instruction tuning dataset that covers a variety of real-world tables and realistic tasks. We include 14 datasets of 11 tasks in total. 45 | 46 | The model is evaluated on 8 in-domain datasets of 8 tasks and 6 out-of-domain datasets of 4 tasks. 47 | 48 | 49 | ## **Introduction** 50 | We introduce TableLlama and TableInstruct: the FIRST open-source generalist LLM and instruction tuning dataset for tables. The TableLlama model is trained on TableInstruct Dataset, a meticulously curated instruction tuning dataset for tables. TableLlama is tuned on **2.6 million** table-based task data, and can handle up to **8K** context! 51 | 52 | 53 | ## **Installation** 54 | 55 | Clone this repository and install the required packages: 56 | 57 | ```bash 58 | git clone https://github.com/OSU-NLP-Group/TableLlama.git 59 | cd TableLlama 60 | pip install -r requirements.txt 61 | pip install flash-attn --no-build-isolation 62 | ``` 63 | 64 | ## **Training and Inference** 65 | 66 | ### **Fine-tuning** 67 | 68 | To train the 7B model, run: 69 | 70 | ```bash 71 | torchrun --nproc_per_node=8 supervised_fine_tune.py \ 72 | --model_name_or_path $MODEL_DIR \ 73 | --bf16 True \ 74 | --output_dir $OUTPUT_DIR \ 75 | --model_max_length 8192 \ 76 | --use_flash_attn True \ 77 | --data_path $DATA_DIR \ 78 | --cache_dir /ML-A800/hf_cache \ 79 | --low_rank_training False \ 80 | --num_train_epochs 2 \ 81 | --per_device_train_batch_size 3 \ 82 | --per_device_eval_batch_size 2 \ 83 | --gradient_accumulation_steps 1 \ 84 | --evaluation_strategy "no" \ 85 | --save_strategy "steps" \ 86 | --save_steps 2000 \ 87 | --save_total_limit 4 \ 88 | --learning_rate 2e-5 \ 89 | --weight_decay 0.0 \ 90 | --warmup_ratio 0.03 \ 91 | --lr_scheduler_type "cosine" \ 92 | --logging_steps 1 \ 93 | --deepspeed "/ds_configs/stage2.json" \ 94 | --tf32 True \ 95 | --run_name $RUN_NAME 96 | ``` 97 | 98 | **Addressing OOM** 99 | To train the 7B model with super large data size, if you encounter OOM issue, we provide code for streaming. 
You can run: 100 | ```bash 101 | torchrun --nproc_per_node=8 supervised_fine_tune_stream.py \ 102 | --model_name_or_path $MODEL_DIR \ 103 | --bf16 True \ 104 | --output_dir $OUTPUT_DIR \ 105 | --model_max_length 8192 \ 106 | --use_flash_attn True \ 107 | --data_path $DATA_DIR \ 108 | --gpu_size $GPU_SIZE \ 109 | --data_size $DATA_SIZE \ 110 | --cache_dir /ML-A800/hf_cache \ 111 | --low_rank_training False \ 112 | --num_train_epochs 2 \ 113 | --per_device_train_batch_size 3 \ 114 | --per_device_eval_batch_size 2 \ 115 | --gradient_accumulation_steps 1 \ 116 | --evaluation_strategy "no" \ 117 | --save_strategy "steps" \ 118 | --save_steps 2000 \ 119 | --save_total_limit 4 \ 120 | --learning_rate 2e-5 \ 121 | --weight_decay 0.0 \ 122 | --warmup_ratio 0.03 \ 123 | --lr_scheduler_type "cosine" \ 124 | --logging_steps 1 \ 125 | --deepspeed "/ds_configs/stage2.json" \ 126 | --tf32 True \ 127 | --run_name $RUN_NAME 128 | 129 | ``` 130 | 131 | ### **Inference** 132 | ```bash 133 | python3 inference_rel_extraction_col_type.py \ 134 | --base_model $MODEL_DIR \ 135 | --context_size 8192 \ 136 | --max_gen_len 128 \ 137 | --flash_attn True \ 138 | --input_data_file /test_data/test_col_type.json \ 139 | --output_data_file $OUTPUT_DIR/col_type_pred.json 140 | ``` 141 | 142 | ## **Evaluation** 143 | 144 | The folder `eval_scripts` includes evaluation scripts for all the in-domain test sets. To run the script, take HiTab (hierarchical table QA task) as an example: 145 | 146 | ```bash 147 | cd eval_scripts 148 | python evaluate_hitab.py --file_pred $OUTPUT_DIR/hitab_pred.json 149 | ``` 150 | 151 | ## Prompt Format 152 | 153 | ``` 154 | Below is an instruction that describes a task, paired with an input that provides further context. Write a response that 155 | appropriately completes the request. 156 | 157 | ### Instruction: 158 | {instruction} 159 | 160 | ### Input: 161 | {input} 162 | 163 | ### Question: 164 | {question} 165 | 166 | ### Response: 167 | ``` 168 | 169 | 170 | - The instruction is designed to point out the task and give a detailed task description. 171 | - The input is designed to provide the information about the table. We concatenate table metadata (if any) such as the Wikipedia page title, section 172 | title and table caption with the serialized table as table input. We use '[TLE]' to represent the beginning of the table metadata, and use '[TAB]' to represent the beginning of the serialized table. 173 | - The question is to accommodate all the information the model needs to complete the task and prompt the model to generate an answer. 174 | - Task prompts examples (For more example prompts for other tasks, please refer to Appendix E in our paper.) 175 | 176 |
177 |
178 | [Example prompt images for several tasks (e.g., HiTab, FeTaQA, TabFact, HybridQA) are embedded here in the original README.]
179 |
180 |
181 |
182 |
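For reference, each record in the training and test JSON files is an object with `instruction`, `input_seg`, `question` and `output` fields (these are the columns consumed by the training scripts; some test sets carry extra fields such as `table_id`). A hypothetical record, with made-up values, might look like:

```python
# Hypothetical TableInstruct-style record (values are illustrative, not taken from the dataset).
record = {
    "instruction": "This is a table fact verification task. Verify whether the statement is entailed or refuted by the table.",
    "input_seg": "[TLE] The table is about football seasons. [TAB] | season | wins | losses | row 1: | 2018 | 10 | 6 |",
    "question": "The statement is: the team won 10 games in 2018. Is it entailed or refuted?",
    "output": "entailed",
}
```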
183 |
184 | **Note:**
185 |
186 | - If you directly use our model for inference on your own data, please make sure you organize it in the same way as the examples shown above and in our paper's Appendix. Performance can vary significantly with the prompt format.
187 |
188 |
189 | ## **Citation**
190 |
191 | Please cite our paper if you use our data, model or code. Please also kindly cite the original dataset papers.
192 |
193 | ```
194 | @misc{zhang2023tablellama,
195 | title={TableLlama: Towards Open Large Generalist Models for Tables},
196 | author={Tianshu Zhang and Xiang Yue and Yifei Li and Huan Sun},
197 | year={2023},
198 | eprint={2311.09206},
199 | archivePrefix={arXiv},
200 | primaryClass={cs.CL}
201 | }
202 | ```
203 |
204 |
205 |
--------------------------------------------------------------------------------
/supervised_fine_tune_stream.py:
--------------------------------------------------------------------------------
1 |
2 | # Some code based on https://github.com/epfml/landmark-attention
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import io
17 | import os
18 | import copy
19 | import json
20 | import math
21 | import logging
22 | from dataclasses import dataclass, field
23 | from typing import Dict, Optional, Sequence
24 | from multiprocessing import cpu_count
25 | from datasets import load_dataset
26 | from tqdm import tqdm
27 | import psutil
28 |
29 | import torch
30 | import transformers
31 | # from torch.utils.data import Dataset, IterableDataset
32 | from datasets.iterable_dataset import IterableDataset
33 | from transformers import Trainer, DataCollatorForLanguageModeling
34 | from llama_attn_replace import replace_llama_attn
35 | from peft import LoraConfig, get_peft_model
36 | from torch.distributed import barrier
37 |
38 |
39 | IGNORE_INDEX = -100
40 | DEFAULT_PAD_TOKEN = "[PAD]"
41 | DEFAULT_EOS_TOKEN = "</s>"
42 | DEFAULT_BOS_TOKEN = "<s>"
43 | DEFAULT_UNK_TOKEN = "<unk>"
44 |
45 | def _make_r_io_base(f, mode: str):
46 | if not isinstance(f, io.IOBase):
47 | f = open(f, mode=mode)
48 | return f
49 |
50 | def jload(f, mode="r"):
51 | """Load a .json file into a dictionary."""
52 | f = _make_r_io_base(f, mode)
53 | jdict = json.load(f)
54 | f.close()
55 | return jdict
56 |
57 | def findAllFile(base):
58 | for root, ds, fs in os.walk(base):
59 | for f in fs:
60 | if f.endswith('.json'):
61 | fullname = os.path.join(root,f)
62 | yield fullname
63 |
64 |
65 |
66 | PROMPT_DICT = {
67 | "prompt_input": (
68 | "Below is an instruction that describes a task, paired with an input that provides further context. "
69 | "Write a response that appropriately completes the request.\n\n"
70 | "### Instruction:\n{instruction}\n\n### Input:\n{input_seg}\n\n### Question:\n{question}\n\n### Response:"
71 | ),
72 | "prompt_no_input": (
73 | "Below is an instruction that describes a task. 
" 74 | "Write a response that appropriately completes the request.\n\n" 75 | "### Instruction:\n{instruction}\n\n### Response:" 76 | ), 77 | } 78 | 79 | @dataclass 80 | class ModelArguments: 81 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m") 82 | 83 | 84 | @dataclass 85 | class DataArguments: 86 | data_path: str = field(default=None, metadata={"help": "Path to the training data."}) 87 | data_size: int = field(default=None, metadata={"help": "for calculate max steps."}) 88 | gpu_size: int = field(default=None, metadata={"help": "for calculate max steps and for logging for calcuated intervel."}) 89 | 90 | 91 | @dataclass 92 | class TrainingArguments(transformers.TrainingArguments): 93 | cache_dir: Optional[str] = field(default=None) 94 | optim: str = field(default="adamw_torch") 95 | model_max_length: int = field( 96 | default=8192 * 4, 97 | metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."}, 98 | ) 99 | use_flash_attn: bool = field( 100 | default=True, 101 | metadata={"help": "Whether use flash attention for training."}, 102 | ) 103 | low_rank_training: bool = field( 104 | default=True, 105 | metadata={"help": "Whether use low rank adaptation for training."}, 106 | ) 107 | trainable_params: str = field( 108 | default="embed,norm", 109 | metadata={"help": "Additional trainable parameters except LoRA weights, if low rank training."}, 110 | ) 111 | 112 | def smart_tokenizer_and_embedding_resize( 113 | special_tokens_dict: Dict, 114 | tokenizer: transformers.PreTrainedTokenizer, 115 | model: transformers.PreTrainedModel, 116 | ): 117 | """Resize tokenizer and embedding. 118 | 119 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 120 | """ 121 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) 122 | model.resize_token_embeddings(len(tokenizer)) 123 | 124 | if num_new_tokens > 0: 125 | input_embeddings = model.get_input_embeddings().weight.data 126 | output_embeddings = model.get_output_embeddings().weight.data 127 | 128 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 129 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 130 | 131 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 132 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 133 | 134 | 135 | 136 | tok_example_count = 0 137 | 138 | def preprocess_data(data_path, tokenizer): 139 | 140 | def _tokenize_fn(text: str) -> Dict: 141 | """Tokenize a list of strings.""" 142 | tokenized = tokenizer( 143 | text, 144 | return_tensors="pt", 145 | # padding="longest", 146 | padding="max_length", 147 | max_length=tokenizer.model_max_length, 148 | truncation=True, 149 | ) 150 | 151 | input_ids = labels = tokenized.input_ids[0] 152 | input_ids_lens = labels_lens = tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() 153 | 154 | return dict( 155 | input_ids=input_ids, 156 | labels=labels, 157 | input_ids_lens=input_ids_lens, 158 | labels_lens=labels_lens, 159 | ) 160 | 161 | def _process_function(example): 162 | 163 | global tok_example_count 164 | tok_example_count += 1 165 | if tok_example_count % 128 == 0: 166 | logging.warning(f"tok_example_count: {tok_example_count}") 167 | 168 | # logging.warning("Formatting inputs...") 169 | prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"] 170 | source = prompt_input.format_map(example) if example.get("input_seg", "") != "" else 
prompt_no_input.format_map(example) 171 | 172 | # targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict] 173 | target = f"{example['output']}{DEFAULT_EOS_TOKEN}" 174 | 175 | source_target = source + target 176 | 177 | example_tokenized = _tokenize_fn(source_target) 178 | source_tokenized = _tokenize_fn(source) 179 | 180 | input_ids = example_tokenized["input_ids"] 181 | label = copy.deepcopy(input_ids) 182 | label[:source_tokenized["input_ids_lens"]] = IGNORE_INDEX 183 | 184 | new_example = {"input_ids": input_ids, "labels": label} 185 | 186 | return new_example 187 | 188 | base = data_path 189 | sub_data_path_list = [] 190 | for sub_data_path in findAllFile(base): 191 | sub_data_path_list.append(sub_data_path) 192 | 193 | logging.warning(f"before loading data, RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB") 194 | list_data_dict = load_dataset('json', data_files=sub_data_path_list, split = f'train', streaming=True) 195 | logging.warning(f"list_data_dict: {list_data_dict}") 196 | logging.warning(f"after loading data, RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB") 197 | 198 | logging.warning("Tokenizing inputs... This may take some time...") 199 | tokenized_dataset = list_data_dict.map(_process_function, remove_columns=["instruction", "input_seg", "question", "output"]) 200 | 201 | return tokenized_dataset 202 | 203 | 204 | @dataclass 205 | class DataCollatorForSupervisedDataset(object): 206 | """Collate examples for supervised fine-tuning.""" 207 | 208 | tokenizer: transformers.PreTrainedTokenizer 209 | 210 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 211 | # logging.warning(f"instances: {instances}") 212 | input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) 213 | # logging.warning(f"input_ids: {input_ids}") 214 | # logging.warning(f"labels: {labels}") 215 | input_ids = torch.nn.utils.rnn.pad_sequence( 216 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id 217 | ) 218 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) 219 | return dict( 220 | input_ids=input_ids, 221 | labels=labels, 222 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 223 | ) 224 | 225 | 226 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict: 227 | """Make dataset and collator for supervised fine-tuning.""" 228 | train_dataset = preprocess_data(tokenizer=tokenizer, data_path=data_args.data_path) 229 | data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) 230 | # logging.warning(f"train_dataset: {train_dataset}") 231 | return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator) 232 | 233 | 234 | def train(): 235 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 236 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 237 | 238 | replace_llama_attn(training_args.use_flash_attn, True) 239 | 240 | # Set RoPE scaling factor 241 | config = transformers.AutoConfig.from_pretrained( 242 | model_args.model_name_or_path, 243 | cache_dir=training_args.cache_dir, 244 | ) 245 | 246 | orig_ctx_len = getattr(config, "max_position_embeddings", None) 247 | if orig_ctx_len and training_args.model_max_length > orig_ctx_len: 248 | scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len)) 249 | config.rope_scaling = {"type": 
"linear", "factor": scaling_factor} 250 | 251 | # Load model and tokenizer 252 | model = transformers.AutoModelForCausalLM.from_pretrained( 253 | model_args.model_name_or_path, 254 | config=config, 255 | cache_dir=training_args.cache_dir, 256 | ) 257 | 258 | tokenizer = transformers.AutoTokenizer.from_pretrained( 259 | model_args.model_name_or_path, 260 | cache_dir=training_args.cache_dir, 261 | model_max_length=training_args.model_max_length, 262 | padding_side="right", 263 | use_fast=False, 264 | ) 265 | 266 | special_tokens_dict = dict() 267 | if tokenizer.pad_token is None: 268 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN 269 | if tokenizer.eos_token is None: 270 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN 271 | if tokenizer.bos_token is None: 272 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN 273 | if tokenizer.unk_token is None: 274 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN 275 | 276 | smart_tokenizer_and_embedding_resize( 277 | special_tokens_dict=special_tokens_dict, 278 | tokenizer=tokenizer, 279 | model=model, 280 | ) 281 | 282 | data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) 283 | 284 | if training_args.low_rank_training: 285 | config = LoraConfig( 286 | r=8, 287 | lora_alpha=16, 288 | target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], 289 | lora_dropout=0, 290 | bias="none", 291 | task_type="CAUSAL_LM", 292 | ) 293 | model = get_peft_model(model, config) 294 | # enable trainable params 295 | [p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])] 296 | 297 | model.enable_input_require_grads() # required for gradient checkpointing 298 | model.gradient_checkpointing_enable() # enable gradient checkpointing 299 | 300 | logging.warning(f"data_module: {data_module}") 301 | # data_size = 2636762 302 | # data_size = 7325 303 | # GPU_size = 8 304 | training_args.max_steps = math.ceil(training_args.num_train_epochs * data_args.data_size/(training_args.per_device_train_batch_size * data_args.gpu_size)) 305 | trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module) 306 | trainer.train() 307 | trainer.save_state() 308 | trainer.save_model(output_dir=training_args.output_dir) 309 | 310 | 311 | if __name__ == "__main__": 312 | train() 313 | -------------------------------------------------------------------------------- /supervised_fine_tune.py: -------------------------------------------------------------------------------- 1 | # Some code based on https://github.com/epfml/landmark-attention 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 |
15 | import io
16 | import os
17 | import copy
18 | import json
19 | import math
20 | import logging
21 | from dataclasses import dataclass, field
22 | from typing import Dict, Optional, Sequence
23 |
24 | import torch
25 | import transformers
26 | from torch.utils.data import Dataset
27 | from transformers import Trainer, DataCollatorForLanguageModeling
28 | from llama_attn_replace import replace_llama_attn
29 | from peft import LoraConfig, get_peft_model
30 | from torch.distributed import barrier
31 |
32 |
33 |
34 | IGNORE_INDEX = -100
35 | DEFAULT_PAD_TOKEN = "[PAD]"
36 | DEFAULT_EOS_TOKEN = "</s>"
37 | DEFAULT_BOS_TOKEN = "<s>"
38 | DEFAULT_UNK_TOKEN = "<unk>"
39 |
40 | def _make_r_io_base(f, mode: str):
41 | if not isinstance(f, io.IOBase):
42 | f = open(f, mode=mode)
43 | return f
44 |
45 | def jload(f, mode="r"):
46 | """Load a .json file into a dictionary."""
47 | f = _make_r_io_base(f, mode)
48 | jdict = json.load(f)
49 | f.close()
50 | return jdict
51 |
52 | # PROMPT_DICT = {
53 | # "prompt_input": (
54 | # "Below is an instruction that describes a task, paired with an input that provides further context. "
55 | # "Write a response that appropriately completes the request.\n\n"
56 | # "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
57 | # ),
58 | # "prompt_no_input": (
59 | # "Below is an instruction that describes a task. "
60 | # "Write a response that appropriately completes the request.\n\n"
61 | # "### Instruction:\n{instruction}\n\n### Response:"
62 | # ),
63 | # }
64 |
65 | PROMPT_DICT = {
66 | "prompt_input": (
67 | "Below is an instruction that describes a task, paired with an input that provides further context. "
68 | "Write a response that appropriately completes the request.\n\n"
69 | "### Instruction:\n{instruction}\n\n### Input:\n{input_seg}\n\n### Question:\n{question}\n\n### Response:"
70 | ),
71 | "prompt_no_input": (
72 | "Below is an instruction that describes a task. "
73 | "Write a response that appropriately completes the request.\n\n"
74 | "### Instruction:\n{instruction}\n\n### Response:"
75 | ),
76 | }
77 |
78 | @dataclass
79 | class ModelArguments:
80 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
81 |
82 |
83 | @dataclass
84 | class DataArguments:
85 | data_path: str = field(default=None, metadata={"help": "Path to the training data."})
86 |
87 |
88 | @dataclass
89 | class TrainingArguments(transformers.TrainingArguments):
90 | cache_dir: Optional[str] = field(default=None)
91 | optim: str = field(default="adamw_torch")
92 | model_max_length: int = field(
93 | default=8192 * 4,
94 | metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
95 | )
96 | use_flash_attn: bool = field(
97 | default=True,
98 | metadata={"help": "Whether use flash attention for training."},
99 | )
100 | low_rank_training: bool = field(
101 | default=True,
102 | metadata={"help": "Whether use low rank adaptation for training."},
103 | )
104 | trainable_params: str = field(
105 | default="embed,norm",
106 | metadata={"help": "Additional trainable parameters except LoRA weights, if low rank training."},
107 | )
108 |
109 | def smart_tokenizer_and_embedding_resize(
110 | special_tokens_dict: Dict,
111 | tokenizer: transformers.PreTrainedTokenizer,
112 | model: transformers.PreTrainedModel,
113 | ):
114 | """Resize tokenizer and embedding.
115 |
116 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 
117 | """ 118 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) 119 | model.resize_token_embeddings(len(tokenizer)) 120 | 121 | if num_new_tokens > 0: 122 | input_embeddings = model.get_input_embeddings().weight.data 123 | output_embeddings = model.get_output_embeddings().weight.data 124 | 125 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 126 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 127 | 128 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 129 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 130 | 131 | 132 | def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict: 133 | """Tokenize a list of strings.""" 134 | tokenized_list = [ 135 | tokenizer( 136 | text, 137 | return_tensors="pt", 138 | padding="longest", 139 | max_length=tokenizer.model_max_length, 140 | truncation=True, 141 | ) 142 | for text in strings 143 | ] 144 | input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] 145 | input_ids_lens = labels_lens = [ 146 | tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list 147 | ] 148 | return dict( 149 | input_ids=input_ids, 150 | labels=labels, 151 | input_ids_lens=input_ids_lens, 152 | labels_lens=labels_lens, 153 | ) 154 | 155 | 156 | def preprocess( 157 | sources: Sequence[str], 158 | targets: Sequence[str], 159 | tokenizer: transformers.PreTrainedTokenizer, 160 | ) -> Dict: 161 | """Preprocess the data by tokenizing.""" 162 | examples = [s + t for s, t in zip(sources, targets)] 163 | examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)] 164 | input_ids = examples_tokenized["input_ids"] 165 | labels = copy.deepcopy(input_ids) 166 | for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]): 167 | label[:source_len] = IGNORE_INDEX 168 | return dict(input_ids=input_ids, labels=labels) 169 | 170 | 171 | class SupervisedDataset(Dataset): 172 | """Dataset for supervised fine-tuning.""" 173 | 174 | def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer): 175 | super(SupervisedDataset, self).__init__() 176 | logging.warning("Loading data...") 177 | list_data_dict = jload(data_path) 178 | 179 | logging.warning("Formatting inputs...") 180 | ''' 181 | prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"] 182 | sources = [ 183 | prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example) 184 | for example in list_data_dict 185 | ] 186 | ''' 187 | 188 | prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"] 189 | 190 | # print(prompt_input.format_map(list_data_dict[1])) 191 | # print("****") 192 | # print(f"{list_data_dict[1]['output']}{DEFAULT_EOS_TOKEN}") 193 | # import pdb 194 | # pdb.set_trace() 195 | 196 | sources = [ 197 | # prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example) 198 | prompt_input.format_map(example) if example.get("input_seg", "") != "" else prompt_no_input.format_map(example) 199 | for example in list_data_dict 200 | ] 201 | # targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict] 202 | targets = [f"{example['output']}{DEFAULT_EOS_TOKEN}" for example in list_data_dict] 203 | 204 | # sources = [example["instruction"] for example in list_data_dict] 205 
| 206 | # targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict] 207 | 208 | logging.warning("Tokenizing inputs... This may take some time...") 209 | data_dict = preprocess(sources, targets, tokenizer) 210 | 211 | self.input_ids = data_dict["input_ids"] 212 | self.labels = data_dict["labels"] 213 | 214 | def __len__(self): 215 | return len(self.input_ids) 216 | 217 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 218 | return dict(input_ids=self.input_ids[i], labels=self.labels[i]) 219 | 220 | 221 | @dataclass 222 | class DataCollatorForSupervisedDataset(object): 223 | """Collate examples for supervised fine-tuning.""" 224 | 225 | tokenizer: transformers.PreTrainedTokenizer 226 | 227 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 228 | input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) 229 | input_ids = torch.nn.utils.rnn.pad_sequence( 230 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id 231 | ) 232 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) 233 | return dict( 234 | input_ids=input_ids, 235 | labels=labels, 236 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 237 | ) 238 | 239 | 240 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict: 241 | """Make dataset and collator for supervised fine-tuning.""" 242 | train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path) 243 | data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) 244 | return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator) 245 | 246 | 247 | def train(): 248 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 249 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 250 | 251 | replace_llama_attn(training_args.use_flash_attn, True) 252 | 253 | # Set RoPE scaling factor 254 | config = transformers.AutoConfig.from_pretrained( 255 | model_args.model_name_or_path, 256 | cache_dir=training_args.cache_dir, 257 | ) 258 | 259 | orig_ctx_len = getattr(config, "max_position_embeddings", None) 260 | if orig_ctx_len and training_args.model_max_length > orig_ctx_len: 261 | scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len)) 262 | config.rope_scaling = {"type": "linear", "factor": scaling_factor} 263 | 264 | # Load model and tokenizer 265 | model = transformers.AutoModelForCausalLM.from_pretrained( 266 | model_args.model_name_or_path, 267 | config=config, 268 | cache_dir=training_args.cache_dir, 269 | ) 270 | 271 | tokenizer = transformers.AutoTokenizer.from_pretrained( 272 | model_args.model_name_or_path, 273 | cache_dir=training_args.cache_dir, 274 | model_max_length=training_args.model_max_length, 275 | padding_side="right", 276 | use_fast=False, 277 | ) 278 | 279 | special_tokens_dict = dict() 280 | if tokenizer.pad_token is None: 281 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN 282 | if tokenizer.eos_token is None: 283 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN 284 | if tokenizer.bos_token is None: 285 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN 286 | if tokenizer.unk_token is None: 287 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN 288 | 289 | smart_tokenizer_and_embedding_resize( 290 | special_tokens_dict=special_tokens_dict, 291 | tokenizer=tokenizer, 292 | model=model, 293 | ) 
294 | 295 | data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) 296 | 297 | if training_args.low_rank_training: 298 | config = LoraConfig( 299 | r=8, 300 | lora_alpha=16, 301 | target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], 302 | lora_dropout=0, 303 | bias="none", 304 | task_type="CAUSAL_LM", 305 | ) 306 | model = get_peft_model(model, config) 307 | # enable trainable params 308 | [p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])] 309 | 310 | model.enable_input_require_grads() # required for gradient checkpointing 311 | model.gradient_checkpointing_enable() # enable gradient checkpointing 312 | 313 | trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module) 314 | trainer.train() 315 | trainer.save_state() 316 | trainer.save_model(output_dir=training_args.output_dir) 317 | 318 | 319 | if __name__ == "__main__": 320 | train() 321 | -------------------------------------------------------------------------------- /inference_schema_aug.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import sys 4 | import math 5 | import torch 6 | import argparse 7 | # import textwrap 8 | import transformers 9 | from peft import PeftModel 10 | from transformers import GenerationConfig 11 | from llama_attn_replace import replace_llama_attn 12 | from supervised_fine_tune import PROMPT_DICT 13 | from tqdm import tqdm 14 | # from queue import Queue 15 | # from threading import Thread 16 | # import gradio as gr 17 | 18 | def parse_config(): 19 | parser = argparse.ArgumentParser(description='arg parser') 20 | # parser.add_argument('--question', type=str, default="") 21 | # parser.add_argument('--material', type=str, default="") 22 | # parser.add_argument('--material_title', type=str, default="") 23 | # parser.add_argument('--material_type', type=str, default="material") 24 | parser.add_argument('--base_model', type=str, default="/data1/pretrained-models/llama-7b-hf") 25 | parser.add_argument('--cache_dir', type=str, default="./cache") 26 | parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning') 27 | parser.add_argument('--flash_attn', type=bool, default=False, help='') 28 | parser.add_argument('--temperature', type=float, default=0.6, help='') 29 | parser.add_argument('--top_p', type=float, default=0.9, help='') 30 | parser.add_argument('--max_gen_len', type=int, default=512, help='') 31 | parser.add_argument('--input_data_file', type=str, default='input_data/', help='') 32 | parser.add_argument('--output_data_file', type=str, default='output_data/', help='') 33 | args = parser.parse_args() 34 | return args 35 | 36 | def generate_prompt(instruction, question, input_seg=None): 37 | if input: 38 | return PROMPT_DICT["prompt_input"].format(instruction=instruction, input_seg=input_seg, question=question) 39 | else: 40 | return PROMPT_DICT["prompt_no_input"].format(instruction=instruction) 41 | 42 | # def format_prompt(material, message, material_type="book", material_title=""): 43 | # if material_type == "paper": 44 | # prompt = f"Below is a paper. Memorize the material and answer my question after the paper.\n {material} \n " 45 | # elif material_type == "book": 46 | # material_title = ", %s"%material_title if len(material_title)>0 else "" 47 | # prompt = f"Below is some paragraphs in the book{material_title}. 
Memorize the content and answer my question after the book.\n {material} \n " 48 | # else: 49 | # prompt = f"Below is a material. Memorize the material and answer my question after the material. \n {material} \n " 50 | # message = str(message).strip() 51 | # prompt += f"Now the material ends. {message}" 52 | 53 | # return prompt 54 | 55 | # def read_txt_file(material_txt): 56 | # if not material_txt.split(".")[-1]=='txt': 57 | # raise ValueError("Only support txt or pdf file.") 58 | # content = "" 59 | # with open(material_txt) as f: 60 | # for line in f.readlines(): 61 | # content += line 62 | # return content 63 | 64 | def build_generator( 65 | item, model, tokenizer, temperature=0.6, top_p=0.9, max_gen_len=4096, use_cache=True 66 | ): 67 | def response(item): 68 | # def response(material, question, material_type="", material_title=None): 69 | # material = read_txt_file(material) 70 | # prompt = format_prompt(material, question, material_type, material_title) 71 | prompt = generate_prompt(instruction = item["instruction"], input_seg = item["input_seg"], question = item["question"]) 72 | inputs = tokenizer(prompt, return_tensors="pt").to(model.device) 73 | 74 | output = model.generate( 75 | **inputs, 76 | max_new_tokens=max_gen_len, 77 | temperature=temperature, 78 | top_p=top_p, 79 | use_cache=use_cache 80 | ) 81 | out = tokenizer.decode(output[0], skip_special_tokens=False, clean_up_tokenization_spaces=False) 82 | 83 | out = out.split(prompt)[1].strip() 84 | return out 85 | 86 | return response 87 | 88 | def main(args): 89 | if args.flash_attn: 90 | replace_llama_attn() 91 | 92 | # Set RoPE scaling factor 93 | config = transformers.AutoConfig.from_pretrained( 94 | args.base_model, 95 | cache_dir=args.cache_dir, 96 | ) 97 | 98 | orig_ctx_len = getattr(config, "max_position_embeddings", None) 99 | if orig_ctx_len and args.context_size > orig_ctx_len: 100 | scaling_factor = float(math.ceil(args.context_size / orig_ctx_len)) 101 | config.rope_scaling = {"type": "linear", "factor": scaling_factor} 102 | 103 | # Load model and tokenizer 104 | model = transformers.AutoModelForCausalLM.from_pretrained( 105 | args.base_model, 106 | config=config, 107 | cache_dir=args.cache_dir, 108 | torch_dtype=torch.float16, 109 | device_map="auto", 110 | ) 111 | model.resize_token_embeddings(32001) 112 | 113 | tokenizer = transformers.AutoTokenizer.from_pretrained( 114 | args.base_model, 115 | cache_dir=args.cache_dir, 116 | model_max_length=args.context_size if args.context_size > orig_ctx_len else orig_ctx_len, 117 | # padding_side="right", 118 | padding_side="left", 119 | use_fast=False, 120 | ) 121 | 122 | model.eval() 123 | if torch.__version__ >= "2" and sys.platform != "win32": 124 | model = torch.compile(model) 125 | 126 | with open(args.input_data_file, "r") as f: 127 | test_data = json.load(f) 128 | 129 | # import random 130 | # test_data = random.sample(test_data, k=5) 131 | 132 | test_data_pred = [] 133 | for i in tqdm(range(len(test_data))): 134 | item = test_data[i] 135 | new_item = {} 136 | respond = build_generator(item, model, tokenizer, temperature=args.temperature, top_p=args.top_p, 137 | max_gen_len=args.max_gen_len, use_cache=not args.flash_attn) # the temperature and top_p are highly different with previous alpaca exp, pay attention to this if there is sth wrong later 138 | output = respond(item) 139 | 140 | new_item["idx"] = i 141 | new_item["table_id"] = test_data[i]["table_id"] 142 | new_item["instruction"] = test_data[i]["instruction"] 143 | new_item["input_seg"] = 
test_data[i]["input_seg"] 144 | new_item["question"] = test_data[i]["question"] 145 | new_item["target"] = test_data[i]["target"] 146 | new_item["output_list"] = test_data[i]["output_list"] 147 | new_item["output"] = test_data[i]["output"] 148 | new_item["predict"] = output 149 | 150 | test_data_pred.append(new_item) 151 | # import pdb 152 | # pdb.set_trace() 153 | with open(args.output_data_file, "w") as f: 154 | json.dump(test_data_pred, f, indent = 2) 155 | 156 | # output = respond(args.material, args.question, args.material_type, args.material_title) 157 | # print("output", output) 158 | 159 | if __name__ == "__main__": 160 | args = parse_config() 161 | main(args) 162 | 163 | 164 | # from dataclasses import dataclass, field 165 | 166 | # import numpy as np 167 | # import torch 168 | # import transformers 169 | # from transformers import GenerationConfig 170 | 171 | # from train_llama2_long_context_reformat import ModelArguments, smart_tokenizer_and_embedding_resize, DEFAULT_PAD_TOKEN, DEFAULT_EOS_TOKEN, \ 172 | # DEFAULT_BOS_TOKEN, DEFAULT_UNK_TOKEN, PROMPT_DICT 173 | 174 | # import json 175 | # from tqdm import tqdm 176 | # import math 177 | # import argparse 178 | 179 | # @dataclass 180 | # class InferenceArguments: 181 | # model_max_length: int = field( 182 | # # default=512, 183 | # # default=1024, 184 | # default=1536, 185 | # metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."}, 186 | # ) 187 | # load_in_8bit: bool = field( 188 | # default=False, 189 | # metadata={"help": "Load the model in 8-bit mode."}, 190 | # ) 191 | # inference_dtype: torch.dtype = field( 192 | # default=torch.float32, 193 | # metadata={"help": "The dtype to use for inference."}, 194 | # ) 195 | # max_new_tokens: int = field( 196 | # default=64, 197 | # metadata={"help": "Maximum sequence length. 
Sequences will be right padded (and possibly truncated)."}, 198 | # ) 199 | 200 | # @dataclass 201 | # class FileArguments: 202 | # input_data_file: str = field( 203 | # default="", 204 | # metadata={"help": ""}, 205 | # ) 206 | # output_data_file: str = field( 207 | # default="", 208 | # metadata={"help": ""}, 209 | # ) 210 | 211 | 212 | 213 | 214 | # def batch_process(data_list, model, tokenizer, generation_config, batch_size, max_new_tokens): 215 | # pred = [] 216 | # for i in tqdm(range(math.ceil(len(data_list)/batch_size))): 217 | # if i != math.ceil(len(data_list)/batch_size) - 1: 218 | # batch_data = data_list[i * batch_size: i * batch_size + batch_size] 219 | # else: 220 | # batch_data = data_list[i * batch_size:] 221 | # batch_prompt =[generate_prompt(item["instruction"], item["input_seg"], item["question"]) for item in batch_data] 222 | # inputs = tokenizer(batch_prompt, 223 | # return_tensors="pt", 224 | # padding="longest", 225 | # max_length=tokenizer.model_max_length, 226 | # truncation=True) 227 | # outputs = model.generate(input_ids=inputs["input_ids"].cuda(), generation_config=generation_config, max_new_tokens = max_new_tokens) 228 | 229 | # # import pdb 230 | # # pdb.set_trace() 231 | # # input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1] 232 | # # generated_tokens = outputs.sequences[:, input_length:] 233 | # # pred += tokenizer.batch_decode(generated_tokens, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0] 234 | # pred += tokenizer.batch_decode(outputs, skip_special_tokens=False, clean_up_tokenization_spaces=False) 235 | # # import pdb 236 | # # pdb.set_trace() 237 | # return pred 238 | 239 | 240 | # def inference(test_data, model_args, inference_args): 241 | # # parser = transformers.HfArgumentParser((ModelArguments, InferenceArguments)) 242 | # # model_args, inference_args = parser.parse_args_into_dataclasses() 243 | 244 | # model = transformers.AutoModelForCausalLM.from_pretrained( 245 | # model_args.model_name_or_path, 246 | # load_in_8bit=inference_args.load_in_8bit, 247 | # torch_dtype=inference_args.inference_dtype, 248 | # device_map="auto", 249 | # ) 250 | # model.cuda() 251 | # model.eval() 252 | 253 | # generation_config = GenerationConfig( 254 | # temperature=0.1, 255 | # top_p=0.75, 256 | # # num_beams=4, 257 | # num_beams=1, 258 | # # num_beams=2, 259 | # ) 260 | 261 | # tokenizer = transformers.AutoTokenizer.from_pretrained( 262 | # model_args.model_name_or_path, 263 | # use_fast=False, 264 | # model_max_length=inference_args.model_max_length, 265 | # padding_side="left" ### important to add this in inference 266 | # ) 267 | 268 | # if tokenizer.pad_token is None: 269 | # smart_tokenizer_and_embedding_resize( 270 | # special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN), 271 | # tokenizer=tokenizer, 272 | # model=model, 273 | # ) 274 | # tokenizer.add_special_tokens( 275 | # { 276 | # "eos_token": DEFAULT_EOS_TOKEN, 277 | # "bos_token": DEFAULT_BOS_TOKEN, 278 | # "unk_token": DEFAULT_UNK_TOKEN, 279 | # } 280 | # ) 281 | 282 | # pred = batch_process(test_data, model, tokenizer, generation_config, 1, inference_args.max_new_tokens) 283 | 284 | # new_test_list = [] 285 | # for i in tqdm(range(len(test_data))): 286 | # # for i in tqdm(range(90, 101)): 287 | # # for i in tqdm(range(3)): 288 | # instruction = test_data[i]["instruction"] 289 | # item = {} 290 | # item["idx"] = i 291 | # # item["table_id"] = test_data[i]["table_id"] 292 | # # item["entity"] = test_data[i]["entity"] 293 | # item["instruction"] = 
instruction 294 | # # item["input"] = input 295 | # item["input_seg"] = test_data[i]["input_seg"] 296 | # # item["tokenizer_tensor_shape"] = inputs["input_ids"].shape 297 | # item["output"] = test_data[i]["output"] 298 | # item["predict"] = pred[i] 299 | 300 | # new_test_list.append(item) 301 | # return new_test_list 302 | 303 | 304 | # if __name__ == "__main__": 305 | 306 | # parser = transformers.HfArgumentParser((ModelArguments, InferenceArguments, FileArguments)) 307 | # model_args, inference_args, file_args = parser.parse_args_into_dataclasses() 308 | 309 | # # num = 0 310 | # # with open("/users/PAA0201/shubaobao/stanford_alpaca/table_all_tasks_fair/test/split_16_col_type/test_" + str(file_args.input_data_file_num) + ".json", "r") as f: 311 | # with open(file_args.input_data_file, "r") as f: 312 | # test_data = json.load(f) 313 | 314 | # # import random 315 | # # test_data = random.sample(test_data, k=5) 316 | # test_list = inference(test_data, model_args, inference_args) 317 | 318 | # # with open("/users/PAA0201/shubaobao/stanford_alpaca/table_all_tasks_fair/pred_ser_20000_seg/test_beam_search/test_" + str(file_args.output_data_file_num) + ".json", "w") as f: 319 | # with open(file_args.output_data_file, "w") as f: 320 | # json.dump(test_list, f, indent = 2) 321 | 322 | # print("input file:", str(file_args.input_data_file)) 323 | # print("output file:", str(file_args.output_data_file)) 324 | -------------------------------------------------------------------------------- /llama_attn_replace.py: -------------------------------------------------------------------------------- 1 | # Modified based on https://github.com/lm-sys/FastChat 2 | 3 | from typing import Optional, Tuple 4 | import warnings 5 | import torch 6 | from torch import nn 7 | import transformers 8 | from einops import rearrange 9 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv 10 | 11 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func 12 | from flash_attn.bert_padding import unpad_input, pad_input 13 | 14 | group_size_ratio = 1/4 15 | def forward_flashattn( 16 | self, 17 | hidden_states: torch.Tensor, 18 | attention_mask: Optional[torch.Tensor] = None, 19 | position_ids: Optional[torch.Tensor] = None, 20 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 21 | output_attentions: bool = False, 22 | use_cache: bool = False, 23 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 24 | """Input shape: Batch x Time x Channel 25 | 26 | attention_mask: [bsz, q_len] 27 | """ 28 | if output_attentions: 29 | warnings.warn( 30 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 
31 | ) 32 | 33 | bsz, q_len, _ = hidden_states.size() 34 | 35 | query_states = ( 36 | self.q_proj(hidden_states) 37 | .view(bsz, q_len, self.num_heads, self.head_dim) 38 | .transpose(1, 2) 39 | ) 40 | key_states = ( 41 | self.k_proj(hidden_states) 42 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 43 | .transpose(1, 2) 44 | ) 45 | value_states = ( 46 | self.v_proj(hidden_states) 47 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 48 | .transpose(1, 2) 49 | ) 50 | # [bsz, q_len, nh, hd] 51 | # [bsz, nh, q_len, hd] 52 | 53 | kv_seq_len = key_states.shape[-2] 54 | if past_key_value is not None: 55 | kv_seq_len += past_key_value[0].shape[-2] 56 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 57 | query_states, key_states = apply_rotary_pos_emb( 58 | query_states, key_states, cos, sin, position_ids 59 | ) 60 | 61 | # Past Key value support 62 | if past_key_value is not None: 63 | # reuse k, v, self_attention 64 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 65 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 66 | 67 | past_key_value = (key_states, value_states) if use_cache else None 68 | 69 | # repeat k/v heads if n_kv_heads < n_heads 70 | key_states = repeat_kv(key_states, self.num_key_value_groups) 71 | value_states = repeat_kv(value_states, self.num_key_value_groups) 72 | 73 | # Flash attention codes from 74 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py 75 | 76 | # transform the data into the format required by flash attention 77 | qkv = torch.stack( 78 | [query_states, key_states, value_states], dim=2 79 | ) # [bsz, nh, 3, q_len, hd] 80 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] 81 | 82 | # shift 83 | if self.training: 84 | group_size = int(q_len * group_size_ratio) 85 | if q_len % group_size > 0: 86 | raise ValueError("q_len %d should be divisible by group size %d." 
% (q_len, group_size)) 87 | num_group = q_len // group_size 88 | qkv[:, :, :, self.num_heads//2:] = qkv[:, :, :, self.num_heads//2:].roll(-group_size//2, dims=1) 89 | qkv = qkv.reshape(bsz*num_group, group_size, 3, self.num_heads, self.head_dim) 90 | 91 | # We have disabled _prepare_decoder_attention_mask in LlamaModel 92 | # the attention_mask should be the same as the key_padding_mask 93 | 94 | key_padding_mask = attention_mask[:, :group_size].repeat(num_group, 1) if self.training else attention_mask 95 | nheads = qkv.shape[-2] 96 | x = rearrange(qkv, "b s three h d -> b s (three h d)") 97 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) 98 | x_unpad = rearrange( 99 | x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads 100 | ) 101 | output_unpad = flash_attn_varlen_qkvpacked_func( 102 | x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 103 | ) 104 | output = rearrange( 105 | pad_input( 106 | rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz * num_group if self.training else bsz, group_size if self.training else q_len 107 | ), 108 | "b s (h d) -> b s h d", 109 | h=nheads, 110 | ) 111 | output = output.reshape(bsz, q_len, self.num_heads, self.head_dim) 112 | if self.training: 113 | # shift back 114 | output[:, :, self.num_heads//2:] = output[:, :, self.num_heads//2:].roll(group_size//2, dims=1) 115 | 116 | return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value 117 | 118 | def forward_flashattn_full( 119 | self, 120 | hidden_states: torch.Tensor, 121 | attention_mask: Optional[torch.Tensor] = None, 122 | position_ids: Optional[torch.Tensor] = None, 123 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 124 | output_attentions: bool = False, 125 | use_cache: bool = False, 126 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 127 | """Input shape: Batch x Time x Channel 128 | 129 | attention_mask: [bsz, q_len] 130 | """ 131 | if output_attentions: 132 | warnings.warn( 133 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 
134 | ) 135 | 136 | bsz, q_len, _ = hidden_states.size() 137 | 138 | query_states = ( 139 | self.q_proj(hidden_states) 140 | .view(bsz, q_len, self.num_heads, self.head_dim) 141 | .transpose(1, 2) 142 | ) 143 | key_states = ( 144 | self.k_proj(hidden_states) 145 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 146 | .transpose(1, 2) 147 | ) 148 | value_states = ( 149 | self.v_proj(hidden_states) 150 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 151 | .transpose(1, 2) 152 | ) 153 | # [bsz, q_len, nh, hd] 154 | # [bsz, nh, q_len, hd] 155 | 156 | kv_seq_len = key_states.shape[-2] 157 | if past_key_value is not None: 158 | kv_seq_len += past_key_value[0].shape[-2] 159 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 160 | query_states, key_states = apply_rotary_pos_emb( 161 | query_states, key_states, cos, sin, position_ids 162 | ) 163 | 164 | # Past Key value support 165 | if past_key_value is not None: 166 | # reuse k, v, self_attention 167 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 168 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 169 | 170 | past_key_value = (key_states, value_states) if use_cache else None 171 | 172 | # repeat k/v heads if n_kv_heads < n_heads 173 | key_states = repeat_kv(key_states, self.num_key_value_groups) 174 | value_states = repeat_kv(value_states, self.num_key_value_groups) 175 | 176 | # Flash attention codes from 177 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py 178 | 179 | # transform the data into the format required by flash attention 180 | qkv = torch.stack( 181 | [query_states, key_states, value_states], dim=2 182 | ) # [bsz, nh, 3, q_len, hd] 183 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] 184 | 185 | # We have disabled _prepare_decoder_attention_mask in LlamaModel 186 | # the attention_mask should be the same as the key_padding_mask 187 | 188 | key_padding_mask = attention_mask 189 | nheads = qkv.shape[-2] 190 | x = rearrange(qkv, "b s three h d -> b s (three h d)") 191 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) 192 | x_unpad = rearrange( 193 | x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads 194 | ) 195 | output_unpad = flash_attn_varlen_qkvpacked_func( 196 | x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 197 | ) 198 | output = rearrange( 199 | pad_input( 200 | rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len 201 | ), 202 | "b s (h d) -> b s h d", 203 | h=nheads, 204 | ) 205 | output = output.reshape(bsz, q_len, self.num_heads, self.head_dim) 206 | 207 | return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value 208 | 209 | 210 | def forward_noflashattn( 211 | self, 212 | hidden_states: torch.Tensor, 213 | attention_mask: Optional[torch.Tensor] = None, 214 | position_ids: Optional[torch.LongTensor] = None, 215 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 216 | output_attentions: bool = False, 217 | use_cache: bool = False, 218 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 219 | bsz, q_len, _ = hidden_states.size() 220 | 221 | group_size = int(q_len * group_size_ratio) 222 | 223 | if q_len % group_size > 0: 224 | raise ValueError("q_len %d should be divisible by group size %d."%(q_len, group_size)) 225 | num_group = q_len // group_size 226 | 227 | if self.config.pretraining_tp > 1: 228 | key_value_slicing = (self.num_key_value_heads * self.head_dim) // 
self.config.pretraining_tp 229 | query_slices = self.q_proj.weight.split( 230 | (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 231 | ) 232 | key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) 233 | value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) 234 | 235 | query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] 236 | query_states = torch.cat(query_states, dim=-1) 237 | 238 | key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] 239 | key_states = torch.cat(key_states, dim=-1) 240 | 241 | value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] 242 | value_states = torch.cat(value_states, dim=-1) 243 | 244 | else: 245 | query_states = self.q_proj(hidden_states) 246 | key_states = self.k_proj(hidden_states) 247 | value_states = self.v_proj(hidden_states) 248 | 249 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 250 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 251 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 252 | 253 | kv_seq_len = key_states.shape[-2] 254 | if past_key_value is not None: 255 | kv_seq_len += past_key_value[0].shape[-2] 256 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 257 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 258 | 259 | if past_key_value is not None: 260 | # reuse k, v, self_attention 261 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 262 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 263 | 264 | past_key_value = (key_states, value_states) if use_cache else None 265 | 266 | # repeat k/v heads if n_kv_heads < n_heads 267 | key_states = repeat_kv(key_states, self.num_key_value_groups) 268 | value_states = repeat_kv(value_states, self.num_key_value_groups) 269 | 270 | # shift 271 | def shift(qkv, bsz, q_len, group_size, num_heads, head_dim): 272 | qkv[:, num_heads // 2:] = qkv[:, num_heads // 2:].roll(-group_size // 2, dims=2) 273 | qkv = qkv.transpose(1, 2).reshape(bsz * (q_len // group_size), group_size, num_heads, head_dim).transpose(1, 2) 274 | return qkv 275 | 276 | query_states = shift(query_states, bsz, q_len, group_size, self.num_heads, self.head_dim) 277 | key_states = shift(key_states, bsz, q_len, group_size, self.num_heads, self.head_dim) 278 | value_states = shift(value_states, bsz, q_len, group_size, self.num_heads, self.head_dim) 279 | 280 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) 281 | 282 | if attn_weights.size() != (bsz * num_group, self.num_heads, group_size, group_size): 283 | raise ValueError( 284 | f"Attention weights should be of size {(bsz * num_group, self.num_heads, group_size, group_size)}, but is" 285 | f" {attn_weights.size()}" 286 | ) 287 | 288 | attention_mask = attention_mask[:, :, :group_size, :group_size].repeat(num_group, 1, 1, 1) 289 | if attention_mask is not None: 290 | if attention_mask.size() != (bsz * num_group, 1, group_size, group_size): 291 | raise ValueError( 292 | f"Attention mask should be of size {(bsz * num_group, 1, group_size, group_size)}, but is {attention_mask.size()}" 293 | ) 294 | attn_weights = attn_weights + attention_mask 295 | 296 | # upcast attention to fp32 297 | attn_weights = 
nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 298 | attn_output = torch.matmul(attn_weights, value_states) 299 | 300 | if attn_output.size() != (bsz * num_group, self.num_heads, group_size, self.head_dim): 301 | raise ValueError( 302 | f"`attn_output` should be of size {(bsz * num_group, self.num_heads, group_size, self.head_dim)}, but is" 303 | f" {attn_output.size()}" 304 | ) 305 | attn_output = attn_output.transpose(1, 2).contiguous() 306 | 307 | attn_output = attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim) 308 | 309 | # shift back 310 | output[:, :, self.num_heads//2:] = output[:, :, self.num_heads//2:].roll(group_size//2, dims=1) 311 | 312 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 313 | 314 | if self.config.pretraining_tp > 1: 315 | attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) 316 | o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) 317 | attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) 318 | else: 319 | attn_output = self.o_proj(attn_output) 320 | 321 | if not output_attentions: 322 | attn_weights = None 323 | 324 | return attn_output, attn_weights, past_key_value 325 | 326 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 327 | # requires the attention mask to be the same as the key_padding_mask 328 | def _prepare_decoder_attention_mask( 329 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 330 | ): 331 | # [bsz, seq_len] 332 | return attention_mask 333 | 334 | 335 | def replace_llama_attn(use_flash_attn=True, use_full=False): 336 | if use_flash_attn: 337 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 338 | if cuda_major < 8: 339 | warnings.warn( 340 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 341 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 342 | ) 343 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 344 | _prepare_decoder_attention_mask 345 | ) 346 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_flashattn_full if use_full else forward_flashattn 347 | else: 348 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_noflashattn 349 | --------------------------------------------------------------------------------
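For intuition about `llama_attn_replace.py`: during training, `forward_flashattn` and `forward_noflashattn` replace full self-attention with grouped "shift short attention", where attention is computed inside groups of length `group_size = q_len * group_size_ratio` (with `group_size_ratio = 1/4` in this file), and half of the heads are rolled by half a group so information can still flow between neighboring groups. The toy sketch below is not part of the repository; it only illustrates the grouping-and-shift pattern using standard scaled dot-product attention, with made-up shapes:

```python
import torch
import torch.nn.functional as F

def shift_short_attention(q, k, v, group_size_ratio=1/4):
    """Toy illustration of the shift-short-attention pattern.
    q, k, v: [bsz, num_heads, q_len, head_dim]; q_len must be divisible by the group size."""
    bsz, num_heads, q_len, head_dim = q.shape
    group_size = int(q_len * group_size_ratio)
    assert q_len % group_size == 0, "q_len should be divisible by group size"
    num_group = q_len // group_size

    def shift(x):
        # Roll half of the heads by half a group so adjacent groups overlap.
        x = x.clone()
        x[:, num_heads // 2:] = x[:, num_heads // 2:].roll(-group_size // 2, dims=2)
        # Split the sequence into groups: [bsz * num_group, num_heads, group_size, head_dim]
        return (x.transpose(1, 2)
                 .reshape(bsz * num_group, group_size, num_heads, head_dim)
                 .transpose(1, 2))

    q, k, v = shift(q), shift(k), shift(v)
    # Causal attention is computed independently within each group.
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    # Undo the grouping, then roll the shifted heads back.
    out = (out.transpose(1, 2)
              .reshape(bsz, q_len, num_heads, head_dim)
              .transpose(1, 2))
    out[:, num_heads // 2:] = out[:, num_heads // 2:].roll(group_size // 2, dims=2)
    return out  # [bsz, num_heads, q_len, head_dim]

# Example: 2 sequences of length 16, 4 heads, head_dim 8 -> group_size = 4
q = k = v = torch.randn(2, 4, 16, 8)
print(shift_short_attention(q, k, v).shape)  # torch.Size([2, 4, 16, 8])
```

At inference time (`self.training` is False) the patched forwards attend over the full sequence, and the fine-tuning scripts in this repo call `replace_llama_attn(use_flash_attn, True)`, which selects the non-shifted `forward_flashattn_full` variant.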