├── imgs
│   ├── hitab.png
│   ├── fetaqa.png
│   ├── tabfact.png
│   ├── hybridqa.png
│   ├── tablellama_figure1.png
│   └── tablellama_figure2.png
├── requirements.txt
├── ds_configs
│   ├── stage2.json
│   └── stage3.json
├── eval_scripts
│   ├── eval_tabfact.py
│   ├── eval_hitab.py
│   ├── eval_ent_link.py
│   ├── eval_fetaqa.py
│   ├── eval_col_pop.py
│   ├── eval_row_pop.py
│   ├── qa_datadump_utils.py
│   ├── table_utils.py
│   ├── eval_rel_extraction.py
│   ├── eval_col_type.py
│   └── metric.py
├── LICENSE
├── inference_rel_extraction_col_type.py
├── inference_hitab_tabfact_fetaqa.py
├── inference_ent_link.py
├── inference_row_pop.py
├── README.md
├── supervised_fine_tune_stream.py
├── supervised_fine_tune.py
├── inference_schema_aug.py
└── llama_attn_replace.py
/imgs/hitab.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSU-NLP-Group/TableLlama/HEAD/imgs/hitab.png
--------------------------------------------------------------------------------
/imgs/fetaqa.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSU-NLP-Group/TableLlama/HEAD/imgs/fetaqa.png
--------------------------------------------------------------------------------
/imgs/tabfact.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSU-NLP-Group/TableLlama/HEAD/imgs/tabfact.png
--------------------------------------------------------------------------------
/imgs/hybridqa.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSU-NLP-Group/TableLlama/HEAD/imgs/hybridqa.png
--------------------------------------------------------------------------------
/imgs/tablellama_figure1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSU-NLP-Group/TableLlama/HEAD/imgs/tablellama_figure1.png
--------------------------------------------------------------------------------
/imgs/tablellama_figure2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSU-NLP-Group/TableLlama/HEAD/imgs/tablellama_figure2.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | rouge_score
3 | fire
4 | openai
5 | transformers>=4.28.1
6 | torch>=2.0
7 | sentencepiece
8 | tokenizers>=0.13.3
9 | wandb
10 | accelerate
11 | datasets
12 | deepspeed
13 | peft
14 | partial
15 | gradio
16 | einops
--------------------------------------------------------------------------------
/ds_configs/stage2.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_micro_batch_size_per_gpu": "auto",
3 | "gradient_accumulation_steps": "auto",
4 | "gradient_clipping": "auto",
5 | "zero_allow_untested_optimizer": true,
6 | "bf16": {
7 | "enabled": "auto",
8 | "loss_scale": 0,
9 | "initial_scale_power": 16,
10 | "loss_scale_window": 1000,
11 | "hysteresis": 2,
12 | "min_loss_scale": 1
13 | },
14 | "zero_optimization": {
15 | "stage": 2,
16 | "allgather_partitions": true,
17 | "allgather_bucket_size": 1e9,
18 | "reduce_scatter": true,
19 | "reduce_bucket_size": 1e9,
20 | "overlap_comm": true,
21 | "contiguous_gradients": true
22 | }
23 | }
24 |
--------------------------------------------------------------------------------
/eval_scripts/eval_tabfact.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 |
4 |
5 | def main(args):
6 | with open(args.pred_file, "r") as f:
7 | data = json.load(f)
8 |
9 | correct = 0
10 | remove_count = 0
11 | for i in range(len(data)):
12 | ground_truth = data[i]["output"]
13 |         prediction = data[i]["predict"].strip("</s>")
14 | # if prediction.find(ground_truth) == 0:
15 | if prediction == ground_truth:
16 | correct += 1
17 |         if prediction.find("</s>") == 0:
18 | remove_count += 1
19 |
20 | print("correct:", correct)
21 | # print("remove_count:", remove_count)
22 | print("accuracy:", correct/(len(data)-remove_count))
23 |
24 | if __name__ == "__main__":
25 | parser = argparse.ArgumentParser(description='arg parser')
26 | parser.add_argument('--pred_file', type=str, default='/TableLlama/ckpfinal_pred/tabfact_pred.json', help='')
27 | args = parser.parse_args()
28 | main(args)
29 |
--------------------------------------------------------------------------------
/eval_scripts/eval_hitab.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 | import argparse
4 | from table_utils import evaluate
5 |
6 |
7 | def main(args):
8 | with open(args.pred_file, "r") as f:
9 | data = json.load(f)
10 |
11 | pred_list = []
12 | gold_list = []
13 | for i in range(len(data)):
14 |         if len(data[i]["predict"].strip("</s>").split(">, <")) > 1:
15 |             instance_pred_list = data[i]["predict"].strip("</s>").split(">, <")
16 |             pred_list.append(instance_pred_list)
17 |             gold_list.append(data[i]["output"].strip("</s>").split(">, <"))
18 |         else:
19 |             pred_list.append(data[i]["predict"].strip("</s>"))
20 |             gold_list.append(data[i]["output"].strip("</s>"))
21 |
22 | print(evaluate(gold_list, pred_list))
23 |
24 |
25 | if __name__ == "__main__":
26 | parser = argparse.ArgumentParser(description='arg parser')
27 | parser.add_argument('--pred_file', type=str, default='/TableLlama/ckpfinal_pred/hitab_pred.json', help='')
28 | args = parser.parse_args()
29 | main(args)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 OSU Natural Language Processing
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/eval_scripts/eval_ent_link.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 |
4 |
5 | def main(args):
6 | with open(args.pred_file, "r") as f:
7 | data = json.load(f)
8 |
9 | assert len(data) == 2000
10 | correct_count = 0
11 | multi_candidates_example_count = 0
12 | for i in range(len(data)):
13 | candidate_list = data[i]["candidates_entity_desc_list"]
14 | ground_truth = data[i]["output"].strip("<>").lower()
15 | predict = data[i]["predict"][:-4].strip("<>").lower()
16 | # import pdb
17 | # pdb.set_trace()
18 |
19 | if ground_truth == predict:
20 | correct_count += 1
21 | if len(candidate_list) > 1:
22 | multi_candidates_example_count += 1
23 |
24 |
25 | print("correct_count:", correct_count)
26 | print("acc:", correct_count/len(data))
27 |
28 | # print("multi_candidates_example_count:", multi_candidates_example_count)
29 | # print("multi_candidates_example_ratio:", multi_candidates_example_count/len(data))
30 |
31 | if __name__ == "__main__":
32 | parser = argparse.ArgumentParser(description='arg parser')
33 | parser.add_argument('--pred_file', type=str, default='/TableLlama/ckpfinal_pred/ent_link_pred.json', help='')
34 | args = parser.parse_args()
35 | main(args)
36 |
--------------------------------------------------------------------------------
/ds_configs/stage3.json:
--------------------------------------------------------------------------------
1 | {
2 | "bf16": {
3 | "enabled": "auto"
4 | },
5 | "optimizer": {
6 | "type": "AdamW",
7 | "params": {
8 | "lr": "auto",
9 | "betas": "auto",
10 | "eps": "auto",
11 | "weight_decay": "auto"
12 | }
13 | },
14 | "scheduler": {
15 | "type": "WarmupDecayLR",
16 | "params": {
17 | "total_num_steps": "auto",
18 | "warmup_min_lr": "auto",
19 | "warmup_max_lr": "auto",
20 | "warmup_num_steps": "auto"
21 | }
22 | },
23 | "zero_optimization": {
24 | "stage": 3,
25 | "offload_optimizer": {
26 | "device": "cpu",
27 | "pin_memory": true
28 | },
29 | "offload_param": {
30 | "device": "cpu",
31 | "pin_memory": true
32 | },
33 | "overlap_comm": true,
34 | "contiguous_gradients": true,
35 | "sub_group_size": 1e9,
36 | "reduce_bucket_size": "auto",
37 | "stage3_prefetch_bucket_size": "auto",
38 | "stage3_param_persistence_threshold": "auto",
39 | "stage3_max_live_parameters": 1e9,
40 | "stage3_max_reuse_distance": 1e9,
41 | "stage3_gather_16bit_weights_on_model_save": false
42 | },
43 | "gradient_accumulation_steps": "auto",
44 | "gradient_clipping": "auto",
45 | "steps_per_print": 5,
46 | "train_batch_size": "auto",
47 | "train_micro_batch_size_per_gpu": "auto",
48 | "wall_clock_breakdown": false
49 | }
50 |
--------------------------------------------------------------------------------
/eval_scripts/eval_fetaqa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import evaluate
3 | import json
4 |
5 | def main(args):
6 | with open(args.pred_file, "r") as f:
7 | data = json.load(f)
8 |
9 | test_examples_answer = [x["output"] for x in data]
10 |     test_predictions_pred = [x["predict"].strip("</s>") for x in data]
11 | predictions = test_predictions_pred
12 | references = test_examples_answer
13 |
14 | rouge = evaluate.load('rouge')
15 | results = rouge.compute(predictions=predictions, references=references)
16 | print(results)
17 |
18 | # bleu = evaluate.load('bleu')
19 | # results = bleu.compute(predictions=predictions, references=references)
20 | # print(results)
21 |
22 | sacrebleu = evaluate.load('sacrebleu')
23 | results = sacrebleu.compute(predictions=predictions, references=references)
24 | print(results)
25 |
26 | meteor = evaluate.load('meteor')
27 | results = meteor.compute(predictions=predictions, references=references)
28 | print(results)
29 |
30 | # bleurt = evaluate.load('bleurt')
31 | # results = bleurt.compute(predictions=predictions, references=references)
32 | # print(results)
33 |
34 | # bertscore = evaluate.load('bertscore')
35 | # results = bertscore.compute(predictions=predictions, references=references)
36 | # print(results)
37 |
38 | if __name__ == "__main__":
39 | parser = argparse.ArgumentParser(description='arg parser')
40 | parser.add_argument('--pred_file', type=str, default='/TableLlama/ckpfinal_pred/fetaqa_pred.json', help='')
41 | args = parser.parse_args()
42 | main(args)
--------------------------------------------------------------------------------
/eval_scripts/eval_col_pop.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | from metric import *
4 |
5 |
6 | def get_map_recall(data, data_name):
7 | rs = []
8 | recall = []
9 | for i in range(len(data)):
10 | ground_truth = data[i]["target"].strip(".")
11 | # ground_truth = data[i]["target"].strip(".")
12 | pred = data[i]["predict"].strip(".")
13 |         if "</s>" in pred:
14 |             end_tok_ix = pred.rfind("</s>")
15 | pred = pred[:end_tok_ix]
16 | ground_truth_list = ground_truth.split(", ")
17 | # ground_truth_list = test_col_pop_rank[i]["target"].strip(".").split(", ")
18 | pred_list = pred.split(", ")
19 | for k in range(len(pred_list)):
20 | pred_list[k] = pred_list[k].strip("<>")
21 |
22 | # print(len(ground_truth_list), len(pred_list))
23 |
24 | # import pdb
25 | # pdb.set_trace()
26 | # add to remove repeated generated item
27 | new_pred_list = list(set(pred_list))
28 | new_pred_list.sort(key = pred_list.index)
29 | r = [1 if z in ground_truth_list else 0 for z in new_pred_list]
30 | ap = average_precision(r)
31 | # print("ap:", ap)
32 | rs.append(r)
33 |
34 | # if sum(r) != 0:
35 | # recall.append(sum(r)/len(ground_truth_list))
36 | # else:
37 | # recall.append(0)
38 | recall.append(sum(r)/len(ground_truth_list))
39 | map = mean_average_precision(rs)
40 | m_recall = sum(recall)/len(data)
41 | f1 = 2 * map * m_recall / (map + m_recall)
42 | print(data_name, len(data))
43 | print("mean_average_precision:", map)
44 |
45 |
46 | def main(args):
47 | file = args.pred_file
48 |
49 | with open(file, "r") as f:
50 | col_pop = json.load(f)
51 |
52 | get_map_recall(col_pop, 'col_pop')
53 |
54 | if __name__ == "__main__":
55 | parser = argparse.ArgumentParser(description='arg parser')
56 | parser.add_argument('--pred_file', type=str, default='/TableLlama/ckpfinal_pred/col_pop_pred.json', help='')
57 | args = parser.parse_args()
58 | main(args)
59 |
--------------------------------------------------------------------------------
/eval_scripts/eval_row_pop.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | from metric import *
4 |
5 |
6 | def get_map_recall(data, data_name):
7 | rs = []
8 | recall = []
9 | ap_list = []
10 | for i in range(len(data)):
11 | pred = data[i]["predict"].strip(".")
12 |         if "</s>" in pred:
13 |             end_tok_ix = pred.rfind("</s>")
14 | pred = pred[:end_tok_ix]
15 | ground_truth_list = data[i]["target"]
16 | pred_list = pred.split(", ")
17 | for k in range(len(pred_list)):
18 | pred_list[k] = pred_list[k].strip("<>")
19 |
20 | # add to remove repeated generated item
21 | new_pred_list = list(set(pred_list))
22 | new_pred_list.sort(key = pred_list.index)
23 | # r = [1 if z in ground_truth_list else 0 for z in pred_list]
24 | r = [1 if z in ground_truth_list else 0 for z in new_pred_list]
25 | # ap = average_precision(r)
26 | ap = row_pop_average_precision(r, ground_truth_list)
27 | # print("ap:", ap)
28 | ap_list.append(ap)
29 |
30 | map = sum(ap_list)/len(data)
31 | m_recall = sum(recall)/len(data)
32 |     f1 = 2 * map * m_recall / (map + m_recall) if (map + m_recall) else 0.0
33 | print(data_name, len(data))
34 | print("mean_average_precision:", map)
35 |
36 |
37 | # def merge_pred_from_multi_files(data_path, store_path):
38 | # merged_data = []
39 | # for i in range(6):
40 | # with open(data_path + "row_pop_pred_" + str(i) + ".json", "r") as f:
41 | # temp_data = json.load(f)
42 | # merged_data += temp_data
43 | # with open(store_path, "w") as f:
44 | # json.dump(merged_data, f, indent = 2)
45 |
46 |
47 | def main(args):
48 | with open(args.pred_file, "r") as f:
49 | row_pop = json.load(f)
50 | get_map_recall(row_pop, 'row_pop')
51 |
52 |
53 | if __name__ == "__main__":
54 | parser = argparse.ArgumentParser(description='arg parser')
55 | parser.add_argument('--pred_file', type=str, default='/TableLlama/ckpfinal_pred/row_pop_pred.json', help='')
56 | args = parser.parse_args()
57 | main(args)
58 |
59 |
--------------------------------------------------------------------------------
/eval_scripts/qa_datadump_utils.py:
--------------------------------------------------------------------------------
1 | """ Utility functions for datadumping."""
2 | import unicodedata
3 | import re
4 | from openpyxl.utils import get_column_letter, column_index_from_string
5 | import functools
6 |
7 |
8 | # Compare and sort cells
9 | def find_column(coord):
10 | """ Parse column letter from 'E3'. """
11 | return re.findall('[a-zA-Z]+', coord)
12 |
13 |
14 | def find_row(coord):
15 | """ Parse row number from 'E3'. """
16 | return re.findall('[0-9]+', coord)
17 |
18 |
19 | def cell_compare(cell1, cell2):
20 | """ Compare cell coord by row, then by column."""
21 | col1, col2 = find_column(cell1)[0], find_column(cell2)[0]
22 | row1, row2 = find_row(cell1)[0], find_row(cell2)[0]
23 | if int(row1) < int(row2):
24 | return -1
25 | elif int(row1) > int(row2):
26 | return 1
27 | else:
28 | if column_index_from_string(col1) < column_index_from_string(col2):
29 | return -1
30 | else:
31 | return 1
32 |
33 |
34 | def linked_cell_compare(linked_cell_a, linked_cell_b):
35 | """ Compare answer cell coord by row, then by column."""
36 | if isinstance(linked_cell_a[0], str) and isinstance(linked_cell_b[0], str):
37 | coord_a, coord_b = eval(linked_cell_a[0]), eval(linked_cell_b[0])
38 | else:
39 | coord_a, coord_b = linked_cell_a[0], linked_cell_b[0]
40 | if coord_a[0] < coord_b[0]:
41 | return -1
42 | elif coord_a[0] > coord_b[0]:
43 | return 1
44 | else:
45 | if coord_a[1] < coord_b[1]:
46 | return -1
47 | else:
48 | return 1
49 |
50 |
51 | def sort_region_by_coord(cells):
52 | """ Sort cells by coords, according to cell_compare(). """
53 | cell_list = sorted(cells, key=functools.cmp_to_key(cell_compare))
54 | cell_matrix = []
55 | last_row = None
56 | for cell in cell_list:
57 | col, row = find_column(cell), find_row(cell)
58 | if row == last_row:
59 | cell_matrix[-1].append(cell)
60 | else:
61 | last_row = row
62 | cell_matrix.append([cell])
63 | return cell_list, cell_matrix
64 |
65 |
66 | # --------------------------------------------
67 | # Normalize and Inferring Types.
68 | def normalize(x):
69 | """ Normalize header string. """
70 | # Copied from WikiTableQuestions dataset official evaluator.
71 | if x is None:
72 | return None
73 | # Remove diacritics
74 | x = ''.join(c for c in unicodedata.normalize('NFKD', x)
75 | if unicodedata.category(c) != 'Mn')
76 | # Normalize quotes and dashes
77 | x = re.sub("[‘’´`]", "'", x)
78 | x = re.sub("[“”]", "\"", x)
79 | x = re.sub("[‐‑‒–—−]", "-", x)
80 | while True:
81 | old_x = x
82 | # Remove citations
83 |         x = re.sub("((?<!^)\[[^\]]*\]|\[\d+\]|[•♦†‡*#+])*$", "", x.strip())
--------------------------------------------------------------------------------
/eval_scripts/eval_rel_extraction.py:
--------------------------------------------------------------------------------
90 |         pred = item["predict"].split("</s>")[0].split(", ")
91 | ground_truth_list.append(ground_truth)
92 | pred_list.append(pred)
93 |
94 | get_r_p_f1_for_each_type(ground_truth_list, pred_list)
95 |
96 |
97 | if __name__ == "__main__":
98 | parser = argparse.ArgumentParser(description='arg parser')
99 | parser.add_argument('--pred_file', type=str, default='/TableLlama/ckpfinal_pred/rel_extraction_pred.json', help='')
100 | args = parser.parse_args()
101 | main(args)
--------------------------------------------------------------------------------
/eval_scripts/eval_col_type.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | from collections import Counter
4 | from metric import *
5 |
6 |
7 | def r_p_f1(true_positive, pred, targets):
8 | if pred != 0:
9 | precision = true_positive/pred
10 | else:
11 | precision = 0
12 | recall = true_positive/targets
13 | if (precision + recall) != 0:
14 | f1 = 2 * precision * recall / (precision + recall)
15 | else:
16 | f1 = 0
17 | return recall, precision, f1
18 |
19 | def get_r_p_f1_for_each_type(ground_truth_list, pred_list):
20 |
21 | # import pdb
22 | # pdb.set_trace()
23 |
24 | total_ground_truth_col_types = 0
25 | total_pred_col_types = 0
26 | joint_items_list = []
27 | for i in range(len(ground_truth_list)):
28 | total_ground_truth_col_types += len(ground_truth_list[i])
29 | total_pred_col_types += len(pred_list[i])
30 | joint_items = [item for item in pred_list[i] if item in ground_truth_list[i]]
31 | # total_ground_truth_col_types += len(list(set(ground_truth_list[i])))
32 | # total_pred_col_types += len(list(set(pred_list[i])))
33 | # joint_items = [item for item in list(set(pred_list[i])) if item in list(set(ground_truth_list[i]))]
34 | joint_items_list += joint_items
35 |
36 | # import pdb
37 | # pdb.set_trace()
38 |
39 | gt_entire_col_type = {}
40 | for i in range(len(ground_truth_list)):
41 | gt = list(set(ground_truth_list[i]))
42 | for k in range(len(gt)):
43 | if gt[k] not in gt_entire_col_type.keys():
44 | gt_entire_col_type[gt[k]] = 1
45 | else:
46 | gt_entire_col_type[gt[k]] += 1
47 | # print(len(gt_entire_col_type.keys()))
48 |
49 | pd_entire_col_type = {}
50 | for i in range(len(pred_list)):
51 | pd = list(set(pred_list[i]))
52 | for k in range(len(pd)):
53 | if pd[k] not in pd_entire_col_type.keys():
54 | pd_entire_col_type[pd[k]] = 1
55 | else:
56 | pd_entire_col_type[pd[k]] += 1
57 | # print(len(pd_entire_col_type.keys()))
58 |
59 | joint_entire_col_type = {}
60 | for i in range(len(joint_items_list)):
61 | if joint_items_list[i] not in joint_entire_col_type.keys():
62 | joint_entire_col_type[joint_items_list[i]] = 1
63 | else:
64 | joint_entire_col_type[joint_items_list[i]] += 1
65 | # print(len(joint_entire_col_type.keys()))
66 |
67 | precision = len(joint_items_list)/total_pred_col_types
68 | recall = len(joint_items_list)/total_ground_truth_col_types
69 | f1 = 2 * precision * recall / (precision + recall)
70 |
71 | sorted_gt = sorted(gt_entire_col_type.items(), key=lambda x: x[1], reverse = True)
72 | # print(sorted_gt)
73 | # print("len(joint_items_list):", len(joint_items_list))
74 | # print("total_ground_truth_col_types:", total_ground_truth_col_types)
75 | # print("total_pred_col_types:", total_pred_col_types)
76 | print("precision::", precision)
77 | print("recall:", recall)
78 | print("f1:", f1)
79 |
80 | # print('r_p_f1(joint_entire_col_type["people.person"])', r_p_f1(joint_entire_col_type["people.person"], pd_entire_col_type["people.person"], gt_entire_col_type["people.person"]))
81 | # print('r_p_f1(joint_entire_col_type["sports.pro_athlete"])', r_p_f1(joint_entire_col_type["sports.pro_athlete"], pd_entire_col_type["sports.pro_athlete"], gt_entire_col_type["sports.pro_athlete"]))
82 | # print('r_p_f1(joint_entire_col_type["film.actor"])', r_p_f1(joint_entire_col_type["film.actor"], pd_entire_col_type["film.actor"], gt_entire_col_type["film.actor"]))
83 | # print('r_p_f1(joint_entire_col_type["location.location"])', r_p_f1(joint_entire_col_type["location.location"], pd_entire_col_type["location.location"], gt_entire_col_type["location.location"]))
84 | # print('r_p_f1(joint_entire_col_type["location.citytown"])', r_p_f1(joint_entire_col_type["location.citytown"], pd_entire_col_type["location.citytown"], gt_entire_col_type["location.citytown"]))
85 |
86 |
87 | def remove_ele(list, ele):
88 | while True:
89 | if ele in list:
90 | list.remove(ele)
91 | continue
92 | else:
93 | break
94 | return list
95 |
96 | def get_index(lst=None, item=''):
97 | return [i for i in range(len(lst)) if lst[i] == item]
98 |
99 | def main(args):
100 | with open(args.pred_file, "r") as f:
101 | col_type = json.load(f)
102 |
103 | ground_truth_list = []
104 | pred_list = []
105 | for i in range(len(col_type)):
106 | item = col_type[i]
107 | ground_truth = item["ground_truth"]
108 |         # pred = item["predict"].strip("</s>").split(",")
109 |         pred = item["predict"].split("</s>")[0].split(", ")
110 | ground_truth_list.append(ground_truth)
111 | pred_list.append(pred)
112 |
113 | get_r_p_f1_for_each_type(ground_truth_list, pred_list)
114 |
115 |
116 | if __name__ == "__main__":
117 | parser = argparse.ArgumentParser(description='arg parser')
118 | parser.add_argument('--pred_file', type=str, default='/TableLlama/ckpfinal_pred/col_type_pred.json', help='')
119 | args = parser.parse_args()
120 | main(args)
121 |
--------------------------------------------------------------------------------
/inference_rel_extraction_col_type.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import sys
4 | import math
5 | import torch
6 | import argparse
7 | # import textwrap
8 | import transformers
9 | from peft import PeftModel
10 | from transformers import GenerationConfig
11 | from llama_attn_replace import replace_llama_attn
12 | from supervised_fine_tune import PROMPT_DICT
13 | from tqdm import tqdm
14 | # from queue import Queue
15 | # from threading import Thread
16 | # import gradio as gr
17 |
18 | def parse_config():
19 | parser = argparse.ArgumentParser(description='arg parser')
20 | parser.add_argument('--base_model', type=str, default="/data1/pretrained-models/llama-7b-hf")
21 | parser.add_argument('--cache_dir', type=str, default="./cache")
22 | parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning')
23 | parser.add_argument('--flash_attn', type=bool, default=False, help='')
24 | parser.add_argument('--temperature', type=float, default=0.6, help='')
25 | parser.add_argument('--top_p', type=float, default=0.9, help='')
26 | parser.add_argument('--max_gen_len', type=int, default=512, help='')
27 | parser.add_argument('--input_data_file', type=str, default='input_data/', help='')
28 | parser.add_argument('--output_data_file', type=str, default='output_data/', help='')
29 | args = parser.parse_args()
30 | return args
31 |
32 | def generate_prompt(instruction, question, input_seg=None):
33 |     if input_seg is not None:
34 | return PROMPT_DICT["prompt_input"].format(instruction=instruction, input_seg=input_seg, question=question)
35 | else:
36 | return PROMPT_DICT["prompt_no_input"].format(instruction=instruction)
37 |
38 |
39 | def build_generator(
40 | item, model, tokenizer, temperature=0.6, top_p=0.9, max_gen_len=4096, use_cache=True
41 | ):
42 | def response(item):
43 | # def response(material, question, material_type="", material_title=None):
44 | # material = read_txt_file(material)
45 | # prompt = format_prompt(material, question, material_type, material_title)
46 | prompt = generate_prompt(instruction = item["instruction"], input_seg = item["input_seg"], question = item["question"])
47 | inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
48 |
49 | output = model.generate(
50 | **inputs,
51 | max_new_tokens=max_gen_len,
52 | temperature=temperature,
53 | top_p=top_p,
54 | use_cache=use_cache
55 | )
56 | out = tokenizer.decode(output[0], skip_special_tokens=False, clean_up_tokenization_spaces=False)
57 |
58 | out = out.split(prompt)[1].strip()
59 | return out
60 |
61 | return response
62 |
63 | def main(args):
64 | if args.flash_attn:
65 | replace_llama_attn()
66 |
67 | # Set RoPE scaling factor
68 | config = transformers.AutoConfig.from_pretrained(
69 | args.base_model,
70 | cache_dir=args.cache_dir,
71 | )
72 |
73 | orig_ctx_len = getattr(config, "max_position_embeddings", None)
74 | if orig_ctx_len and args.context_size > orig_ctx_len:
75 | scaling_factor = float(math.ceil(args.context_size / orig_ctx_len))
76 | config.rope_scaling = {"type": "linear", "factor": scaling_factor}
77 |
78 | # Load model and tokenizer
79 | model = transformers.AutoModelForCausalLM.from_pretrained(
80 | args.base_model,
81 | config=config,
82 | cache_dir=args.cache_dir,
83 | torch_dtype=torch.float16,
84 | device_map="auto",
85 | )
86 | model.resize_token_embeddings(32001)
87 |
88 | tokenizer = transformers.AutoTokenizer.from_pretrained(
89 | args.base_model,
90 | cache_dir=args.cache_dir,
91 | model_max_length=args.context_size if args.context_size > orig_ctx_len else orig_ctx_len,
92 | # padding_side="right",
93 | padding_side="left",
94 | use_fast=False,
95 | )
96 |
97 | model.eval()
98 | if torch.__version__ >= "2" and sys.platform != "win32":
99 | model = torch.compile(model)
100 |
101 | with open(args.input_data_file, "r") as f:
102 | test_data = json.load(f)
103 |
104 | # import random
105 | # test_data = random.sample(test_data, k=2)
106 |
107 | test_data_pred = []
108 | for i in tqdm(range(len(test_data))):
109 | item = test_data[i]
110 | new_item = {}
111 | respond = build_generator(item, model, tokenizer, temperature=args.temperature, top_p=args.top_p,
112 |                                   max_gen_len=args.max_gen_len, use_cache=not args.flash_attn)  # note: temperature and top_p differ from the earlier Alpaca setup; revisit these if generations look off
113 | output = respond(item)
114 |
115 | new_item["idx"] = i
116 | new_item["table_id"] = test_data[i]["table_id"]
117 | new_item["instruction"] = test_data[i]["instruction"]
118 | new_item["input_seg"] = test_data[i]["input_seg"]
119 | new_item["question"] = test_data[i]["question"]
120 | new_item["ground_truth"] = test_data[i]["ground_truth"]
121 | new_item["output"] = test_data[i]["output"]
122 | new_item["predict"] = output
123 |
124 | test_data_pred.append(new_item)
125 | # import pdb
126 | # pdb.set_trace()
127 | with open(args.output_data_file, "w") as f:
128 | json.dump(test_data_pred, f, indent = 2)
129 |
130 |
131 | if __name__ == "__main__":
132 | args = parse_config()
133 | main(args)
134 |
135 |
--------------------------------------------------------------------------------
/inference_hitab_tabfact_fetaqa.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import sys
4 | import math
5 | import torch
6 | import argparse
7 | # import textwrap
8 | import transformers
9 | from peft import PeftModel
10 | from transformers import GenerationConfig
11 | from llama_attn_replace import replace_llama_attn
12 | from supervised_fine_tune import PROMPT_DICT
13 | from tqdm import tqdm
14 | # from queue import Queue
15 | # from threading import Thread
16 | # import gradio as gr
17 |
18 |
19 | def parse_config():
20 | parser = argparse.ArgumentParser(description='arg parser')
21 | parser.add_argument('--base_model', type=str, default="/data1/pretrained-models/llama-7b-hf")
22 | parser.add_argument('--cache_dir', type=str, default="./cache")
23 | parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning')
24 | parser.add_argument('--flash_attn', type=bool, default=False, help='')
25 | parser.add_argument('--temperature', type=float, default=0.6, help='')
26 | parser.add_argument('--top_p', type=float, default=0.9, help='')
27 | parser.add_argument('--max_gen_len', type=int, default=512, help='')
28 | parser.add_argument('--input_data_file', type=str, default='input_data/', help='')
29 | parser.add_argument('--output_data_file', type=str, default='output_data/', help='')
30 | args = parser.parse_args()
31 | return args
32 |
33 | def generate_prompt(instruction, question, input_seg=None):
34 |     if input_seg is not None:
35 | return PROMPT_DICT["prompt_input"].format(instruction=instruction, input_seg=input_seg, question=question)
36 | else:
37 | return PROMPT_DICT["prompt_no_input"].format(instruction=instruction)
38 |
39 |
40 | def build_generator(
41 | item, model, tokenizer, temperature=0.6, top_p=0.9, max_gen_len=4096, use_cache=True
42 | ):
43 | def response(item):
44 | # def response(material, question, material_type="", material_title=None):
45 | # material = read_txt_file(material)
46 | # prompt = format_prompt(material, question, material_type, material_title)
47 | prompt = generate_prompt(instruction = item["instruction"], input_seg = item["input_seg"], question = item["question"])
48 | inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
49 |
50 | output = model.generate(
51 | **inputs,
52 | max_new_tokens=max_gen_len,
53 | temperature=temperature,
54 | top_p=top_p,
55 | use_cache=use_cache
56 | )
57 | out = tokenizer.decode(output[0], skip_special_tokens=False, clean_up_tokenization_spaces=False)
58 |
59 | out = out.split(prompt)[1].strip()
60 | return out
61 |
62 | return response
63 |
64 | def main(args):
65 | if args.flash_attn:
66 | replace_llama_attn()
67 |
68 | # Set RoPE scaling factor
69 | config = transformers.AutoConfig.from_pretrained(
70 | args.base_model,
71 | cache_dir=args.cache_dir,
72 | )
73 |
74 | orig_ctx_len = getattr(config, "max_position_embeddings", None)
75 | if orig_ctx_len and args.context_size > orig_ctx_len:
76 | scaling_factor = float(math.ceil(args.context_size / orig_ctx_len))
77 | config.rope_scaling = {"type": "linear", "factor": scaling_factor}
78 |
79 | # Load model and tokenizer
80 | model = transformers.AutoModelForCausalLM.from_pretrained(
81 | args.base_model,
82 | config=config,
83 | cache_dir=args.cache_dir,
84 | torch_dtype=torch.float16,
85 | device_map="auto",
86 | )
87 | model.resize_token_embeddings(32001)
88 |
89 | tokenizer = transformers.AutoTokenizer.from_pretrained(
90 | args.base_model,
91 | cache_dir=args.cache_dir,
92 | model_max_length=args.context_size if args.context_size > orig_ctx_len else orig_ctx_len,
93 | # padding_side="right",
94 | padding_side="left",
95 | use_fast=False,
96 | )
97 |
98 | model.eval()
99 | if torch.__version__ >= "2" and sys.platform != "win32":
100 | model = torch.compile(model)
101 |
102 | with open(args.input_data_file, "r") as f:
103 | test_data = json.load(f)
104 |
105 | # import random
106 | # test_data = random.sample(test_data, k=3)
107 |
108 | test_data_pred = []
109 | for i in tqdm(range(len(test_data))):
110 | item = test_data[i]
111 | new_item = {}
112 | respond = build_generator(item, model, tokenizer, temperature=args.temperature, top_p=args.top_p,
113 |                                   max_gen_len=args.max_gen_len, use_cache=not args.flash_attn)  # note: temperature and top_p differ from the earlier Alpaca setup; revisit these if generations look off
114 | output = respond(item)
115 |
116 | new_item["idx"] = i
117 | # new_item["table_id"] = test_data[i]["table_id"]
118 | new_item["instruction"] = test_data[i]["instruction"]
119 | new_item["input_seg"] = test_data[i]["input_seg"]
120 | new_item["question"] = test_data[i]["question"]
121 | # new_item["ground_truth"] = test_data[i]["ground_truth"]
122 | new_item["output"] = test_data[i]["output"]
123 | new_item["predict"] = output
124 |
125 | test_data_pred.append(new_item)
126 | # import pdb
127 | # pdb.set_trace()
128 | with open(args.output_data_file, "w") as f:
129 | json.dump(test_data_pred, f, indent = 2)
130 |
131 |
132 | if __name__ == "__main__":
133 | args = parse_config()
134 | main(args)
135 |
136 |
--------------------------------------------------------------------------------
/inference_ent_link.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import sys
4 | import math
5 | import torch
6 | import argparse
7 | # import textwrap
8 | import transformers
9 | from peft import PeftModel
10 | from transformers import GenerationConfig
11 | from llama_attn_replace import replace_llama_attn
12 | from supervised_fine_tune import PROMPT_DICT
13 | from tqdm import tqdm
14 | # from queue import Queue
15 | # from threading import Thread
16 | # import gradio as gr
17 |
18 | def parse_config():
19 | parser = argparse.ArgumentParser(description='arg parser')
20 | parser.add_argument('--base_model', type=str, default="/data1/pretrained-models/llama-7b-hf")
21 | parser.add_argument('--cache_dir', type=str, default="./cache")
22 | parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning')
23 | parser.add_argument('--flash_attn', type=bool, default=False, help='')
24 | parser.add_argument('--temperature', type=float, default=0.6, help='')
25 | parser.add_argument('--top_p', type=float, default=0.9, help='')
26 | parser.add_argument('--max_gen_len', type=int, default=512, help='')
27 | parser.add_argument('--input_data_file', type=str, default='input_data/', help='')
28 | parser.add_argument('--output_data_file', type=str, default='output_data/', help='')
29 | args = parser.parse_args()
30 | return args
31 |
32 | def generate_prompt(instruction, question, input_seg=None):
33 |     if input_seg is not None:
34 | return PROMPT_DICT["prompt_input"].format(instruction=instruction, input_seg=input_seg, question=question)
35 | else:
36 | return PROMPT_DICT["prompt_no_input"].format(instruction=instruction)
37 |
38 |
39 | def build_generator(
40 | item, model, tokenizer, temperature=0.6, top_p=0.9, max_gen_len=4096, use_cache=True
41 | ):
42 | def response(item):
43 | # def response(material, question, material_type="", material_title=None):
44 | # material = read_txt_file(material)
45 | # prompt = format_prompt(material, question, material_type, material_title)
46 | prompt = generate_prompt(instruction = item["instruction"], input_seg = item["input_seg"], question = item["question"])
47 | inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
48 |
49 | output = model.generate(
50 | **inputs,
51 | max_new_tokens=max_gen_len,
52 | temperature=temperature,
53 | top_p=top_p,
54 | use_cache=use_cache
55 | )
56 | out = tokenizer.decode(output[0], skip_special_tokens=False, clean_up_tokenization_spaces=False)
57 |
58 | out = out.split(prompt)[1].strip()
59 | return out
60 |
61 | return response
62 |
63 | def main(args):
64 | if args.flash_attn:
65 | replace_llama_attn()
66 |
67 | # Set RoPE scaling factor
68 | config = transformers.AutoConfig.from_pretrained(
69 | args.base_model,
70 | cache_dir=args.cache_dir,
71 | )
72 |
73 | orig_ctx_len = getattr(config, "max_position_embeddings", None)
74 | if orig_ctx_len and args.context_size > orig_ctx_len:
75 | scaling_factor = float(math.ceil(args.context_size / orig_ctx_len))
76 | config.rope_scaling = {"type": "linear", "factor": scaling_factor}
77 |
78 | # Load model and tokenizer
79 | model = transformers.AutoModelForCausalLM.from_pretrained(
80 | args.base_model,
81 | config=config,
82 | cache_dir=args.cache_dir,
83 | torch_dtype=torch.float16,
84 | device_map="auto",
85 | )
86 | model.resize_token_embeddings(32001)
87 |
88 | tokenizer = transformers.AutoTokenizer.from_pretrained(
89 | args.base_model,
90 | cache_dir=args.cache_dir,
91 | model_max_length=args.context_size if args.context_size > orig_ctx_len else orig_ctx_len,
92 | # padding_side="right",
93 | padding_side="left",
94 | use_fast=False,
95 | )
96 |
97 | model.eval()
98 | if torch.__version__ >= "2" and sys.platform != "win32":
99 | model = torch.compile(model)
100 |
101 | with open(args.input_data_file, "r") as f:
102 | test_data = json.load(f)
103 |
104 | # import random
105 | # test_data = random.sample(test_data, k=5)
106 |
107 | test_data_pred = []
108 | for i in tqdm(range(len(test_data))):
109 | item = test_data[i]
110 | new_item = {}
111 | respond = build_generator(item, model, tokenizer, temperature=args.temperature, top_p=args.top_p,
112 |                                   max_gen_len=args.max_gen_len, use_cache=not args.flash_attn)  # note: temperature and top_p differ from the earlier Alpaca setup; revisit these if generations look off
113 | output = respond(item)
114 |
115 | new_item["idx"] = i
116 | new_item["table_id"] = test_data[i]["id"]
117 | new_item["instruction"] = test_data[i]["instruction"]
118 | new_item["input_seg"] = test_data[i]["input_seg"]
119 | new_item["question"] = test_data[i]["question"]
120 | new_item["candidates_list"] = test_data[i]["candidates_list"]
121 | new_item["candidates_entity_desc_list"] = test_data[i]["candidates_entity_desc_list"]
122 | new_item["output"] = test_data[i]["output"]
123 | new_item["predict"] = output
124 |
125 | test_data_pred.append(new_item)
126 | # import pdb
127 | # pdb.set_trace()
128 | with open(args.output_data_file, "w") as f:
129 | json.dump(test_data_pred, f, indent = 2)
130 |
131 | if __name__ == "__main__":
132 | args = parse_config()
133 | main(args)
134 |
135 |
136 |
--------------------------------------------------------------------------------
/inference_row_pop.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import sys
4 | import math
5 | import torch
6 | import argparse
7 | # import textwrap
8 | import transformers
9 | from peft import PeftModel
10 | from transformers import GenerationConfig
11 | from llama_attn_replace import replace_llama_attn
12 | from supervised_fine_tune import PROMPT_DICT
13 | from tqdm import tqdm
14 | # from queue import Queue
15 | # from threading import Thread
16 | # import gradio as gr
17 |
18 | def parse_config():
19 | parser = argparse.ArgumentParser(description='arg parser')
20 | parser.add_argument('--base_model', type=str, default="/data1/pretrained-models/llama-7b-hf")
21 | parser.add_argument('--cache_dir', type=str, default="./cache")
22 | parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning')
23 | parser.add_argument('--flash_attn', type=bool, default=False, help='')
24 | parser.add_argument('--temperature', type=float, default=0.6, help='')
25 | parser.add_argument('--top_p', type=float, default=0.9, help='')
26 | parser.add_argument('--max_gen_len', type=int, default=512, help='')
27 | parser.add_argument('--input_data_file', type=str, default='input_data/', help='')
28 | parser.add_argument('--output_data_file', type=str, default='output_data/', help='')
29 | args = parser.parse_args()
30 | return args
31 |
32 | def generate_prompt(instruction, question, input_seg=None):
33 |     if input_seg is not None:
34 | return PROMPT_DICT["prompt_input"].format(instruction=instruction, input_seg=input_seg, question=question)
35 | else:
36 | return PROMPT_DICT["prompt_no_input"].format(instruction=instruction)
37 |
38 |
39 | def build_generator(
40 | item, model, tokenizer, temperature=0.6, top_p=0.9, max_gen_len=4096, use_cache=True
41 | ):
42 | def response(item):
43 | # def response(material, question, material_type="", material_title=None):
44 | # material = read_txt_file(material)
45 | # prompt = format_prompt(material, question, material_type, material_title)
46 | prompt = generate_prompt(instruction = item["instruction"], input_seg = item["input_seg"], question = item["question"])
47 | inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
48 |
49 | output = model.generate(
50 | **inputs,
51 | max_new_tokens=max_gen_len,
52 | temperature=temperature,
53 | top_p=top_p,
54 | use_cache=use_cache
55 | )
56 | out = tokenizer.decode(output[0], skip_special_tokens=False, clean_up_tokenization_spaces=False)
57 |
58 | out = out.split(prompt)[1].strip()
59 | return out
60 |
61 | return response
62 |
63 | def main(args):
64 | if args.flash_attn:
65 | replace_llama_attn()
66 |
67 | # Set RoPE scaling factor
68 | config = transformers.AutoConfig.from_pretrained(
69 | args.base_model,
70 | cache_dir=args.cache_dir,
71 | )
72 |
73 | orig_ctx_len = getattr(config, "max_position_embeddings", None)
74 | if orig_ctx_len and args.context_size > orig_ctx_len:
75 | scaling_factor = float(math.ceil(args.context_size / orig_ctx_len))
76 | config.rope_scaling = {"type": "linear", "factor": scaling_factor}
77 |
78 | # Load model and tokenizer
79 | model = transformers.AutoModelForCausalLM.from_pretrained(
80 | args.base_model,
81 | config=config,
82 | cache_dir=args.cache_dir,
83 | torch_dtype=torch.float16,
84 | device_map="auto",
85 | )
86 | model.resize_token_embeddings(32001)
87 |
88 | tokenizer = transformers.AutoTokenizer.from_pretrained(
89 | args.base_model,
90 | cache_dir=args.cache_dir,
91 | model_max_length=args.context_size if args.context_size > orig_ctx_len else orig_ctx_len,
92 | # padding_side="right",
93 | padding_side="left",
94 | use_fast=False,
95 | )
96 |
97 | model.eval()
98 | if torch.__version__ >= "2" and sys.platform != "win32":
99 | model = torch.compile(model)
100 |
101 | with open(args.input_data_file, "r") as f:
102 | test_data = json.load(f)
103 |
104 | # import random
105 | # test_data = random.sample(test_data, k=1)
106 |
107 | test_data_pred = []
108 | for i in tqdm(range(len(test_data))):
109 | item = test_data[i]
110 | new_item = {}
111 | respond = build_generator(item, model, tokenizer, temperature=args.temperature, top_p=args.top_p,
112 |                                   max_gen_len=args.max_gen_len, use_cache=not args.flash_attn)  # note: temperature and top_p differ from the earlier Alpaca setup; revisit these if generations look off
113 | output = respond(item)
114 |
115 | new_item["idx"] = i
116 | new_item["table_id"] = test_data[i]["table_id"]
117 | new_item["seed_ent"] = test_data[i]["seed_ent"]
118 | new_item["instruction"] = test_data[i]["instruction"]
119 | new_item["input_seg"] = test_data[i]["input_seg"]
120 | new_item["question"] = test_data[i]["question"]
121 | new_item["cand_list"] = test_data[i]["cand_list"]
122 | new_item["target"] = test_data[i]["target"]
123 | # new_item["output"] = test_data[i]["output"]
124 | new_item["predict"] = output
125 |
126 | test_data_pred.append(new_item)
127 | # import pdb
128 | # pdb.set_trace()
129 | with open(args.output_data_file, "w") as f:
130 | json.dump(test_data_pred, f, indent = 2)
131 |
132 | if __name__ == "__main__":
133 | args = parse_config()
134 | main(args)
135 |
136 |
--------------------------------------------------------------------------------
/eval_scripts/metric.py:
--------------------------------------------------------------------------------
1 | """Information Retrieval metrics
2 | Useful Resources:
3 | http://www.cs.utexas.edu/~mooney/ir-course/slides/Evaluation.ppt
4 | http://www.nii.ac.jp/TechReports/05-014E.pdf
5 | http://www.stanford.edu/class/cs276/handouts/EvaluationNew-handout-6-per.pdf
6 | http://hal.archives-ouvertes.fr/docs/00/72/67/60/PDF/07-busa-fekete.pdf
7 | Learning to Rank for Information Retrieval (Tie-Yan Liu)
8 | """
9 | import numpy as np
10 | import pdb
11 | import os
12 | import pickle
13 |
14 |
15 | def mean_reciprocal_rank(rs):
16 | """Score is reciprocal of the rank of the first relevant item
17 | First element is 'rank 1'. Relevance is binary (nonzero is relevant).
18 | Example from http://en.wikipedia.org/wiki/Mean_reciprocal_rank
19 | >>> rs = [[0, 0, 1], [0, 1, 0], [1, 0, 0]]
20 | >>> mean_reciprocal_rank(rs)
21 | 0.61111111111111105
22 | >>> rs = np.array([[0, 0, 0], [0, 1, 0], [1, 0, 0]])
23 | >>> mean_reciprocal_rank(rs)
24 | 0.5
25 | >>> rs = [[0, 0, 0, 1], [1, 0, 0], [1, 0, 0]]
26 | >>> mean_reciprocal_rank(rs)
27 | 0.75
28 | Args:
29 | rs: Iterator of relevance scores (list or numpy) in rank order
30 | (first element is the first item)
31 | Returns:
32 | Mean reciprocal rank
33 | """
34 | rs = (np.asarray(r).nonzero()[0] for r in rs)
35 | return np.mean([1. / (r[0] + 1) if r.size else 0. for r in rs])
36 |
37 |
38 | def r_precision(r):
39 | """Score is precision after all relevant documents have been retrieved
40 | Relevance is binary (nonzero is relevant).
41 | >>> r = [0, 0, 1]
42 | >>> r_precision(r)
43 | 0.33333333333333331
44 | >>> r = [0, 1, 0]
45 | >>> r_precision(r)
46 | 0.5
47 | >>> r = [1, 0, 0]
48 | >>> r_precision(r)
49 | 1.0
50 | Args:
51 | r: Relevance scores (list or numpy) in rank order
52 | (first element is the first item)
53 | Returns:
54 | R Precision
55 | """
56 | r = np.asarray(r) != 0
57 | z = r.nonzero()[0]
58 | if not z.size:
59 | return 0.
60 | return np.mean(r[:z[-1] + 1])
61 |
62 |
63 | def precision_at_k(r, k):
64 | """Score is precision @ k
65 | Relevance is binary (nonzero is relevant).
66 | >>> r = [0, 0, 1]
67 | >>> precision_at_k(r, 1)
68 | 0.0
69 | >>> precision_at_k(r, 2)
70 | 0.0
71 | >>> precision_at_k(r, 3)
72 | 0.33333333333333331
73 | >>> precision_at_k(r, 4)
74 | Traceback (most recent call last):
75 |     File "<stdin>", line 1, in ?
76 | ValueError: Relevance score length < k
77 | Args:
78 | r: Relevance scores (list or numpy) in rank order
79 | (first element is the first item)
80 | Returns:
81 | Precision @ k
82 | Raises:
83 | ValueError: len(r) must be >= k
84 | """
85 | assert k >= 1
86 | r = np.asarray(r)[:k] != 0
87 | if r.size != k:
88 | raise ValueError('Relevance score length < k')
89 | return np.mean(r)
90 |
91 |
92 | def average_precision(r):
93 | """Score is average precision (area under PR curve)
94 | Relevance is binary (nonzero is relevant).
95 | >>> r = [1, 1, 0, 1, 0, 1, 0, 0, 0, 1]
96 | >>> delta_r = 1. / sum(r)
97 | >>> sum([sum(r[:x + 1]) / (x + 1.) * delta_r for x, y in enumerate(r) if y])
98 | 0.7833333333333333
99 | >>> average_precision(r)
100 | 0.78333333333333333
101 | Args:
102 | r: Relevance scores (list or numpy) in rank order
103 | (first element is the first item)
104 | Returns:
105 | Average precision
106 | """
107 | r = np.asarray(r) != 0
108 | out = [precision_at_k(r, k + 1) for k in range(r.size) if r[k]]
109 | if not out:
110 | return 0.
111 | return np.mean(out)
112 |
113 |
114 | def mean_average_precision(rs):
115 | """Score is mean average precision
116 | Relevance is binary (nonzero is relevant).
117 | >>> rs = [[1, 1, 0, 1, 0, 1, 0, 0, 0, 1]]
118 | >>> mean_average_precision(rs)
119 | 0.78333333333333333
120 | >>> rs = [[1, 1, 0, 1, 0, 1, 0, 0, 0, 1], [0]]
121 | >>> mean_average_precision(rs)
122 | 0.39166666666666666
123 | Args:
124 | rs: Iterator of relevance scores (list or numpy) in rank order
125 | (first element is the first item)
126 | Returns:
127 | Mean average precision
128 | """
129 | return np.mean([average_precision(r) for r in rs])
130 |
131 | def row_pop_average_precision(r, target):
132 | """Score is average precision (area under PR curve)
133 | Relevance is binary (nonzero is relevant).
134 | >>> r = [1, 1, 0, 1, 0, 1, 0, 0, 0, 1]
135 | >>> delta_r = 1. / sum(r)
136 | >>> sum([sum(r[:x + 1]) / (x + 1.) * delta_r for x, y in enumerate(r) if y])
137 | 0.7833333333333333
138 | >>> average_precision(r)
139 | 0.78333333333333333
140 | Args:
141 | r: Relevance scores (list or numpy) in rank order
142 | (first element is the first item)
143 | Returns:
144 | Average precision
145 | """
146 | r = np.asarray(r) != 0
147 | out = [precision_at_k(r, k + 1) for k in range(r.size) if r[k]]
148 | if len(out) < len(target):
149 | out += [0] * (len(target) - len(out))
150 | if not out:
151 | return 0.
152 | return np.mean(out)
153 |
154 |
155 | def dcg_at_k(r, k, method=0):
156 | """Score is discounted cumulative gain (dcg)
157 | Relevance is positive real values. Can use binary
158 | as the previous methods.
159 | Example from
160 | http://www.stanford.edu/class/cs276/handouts/EvaluationNew-handout-6-per.pdf
161 | >>> r = [3, 2, 3, 0, 0, 1, 2, 2, 3, 0]
162 | >>> dcg_at_k(r, 1)
163 | 3.0
164 | >>> dcg_at_k(r, 1, method=1)
165 | 3.0
166 | >>> dcg_at_k(r, 2)
167 | 5.0
168 | >>> dcg_at_k(r, 2, method=1)
169 | 4.2618595071429155
170 | >>> dcg_at_k(r, 10)
171 | 9.6051177391888114
172 | >>> dcg_at_k(r, 11)
173 | 9.6051177391888114
174 | Args:
175 | r: Relevance scores (list or numpy) in rank order
176 | (first element is the first item)
177 | k: Number of results to consider
178 | method: If 0 then weights are [1.0, 1.0, 0.6309, 0.5, 0.4307, ...]
179 | If 1 then weights are [1.0, 0.6309, 0.5, 0.4307, ...]
180 | Returns:
181 | Discounted cumulative gain
182 | """
183 | r = np.asfarray(r)[:k]
184 | if r.size:
185 | if method == 0:
186 | return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1)))
187 | elif method == 1:
188 | return np.sum(r / np.log2(np.arange(2, r.size + 2)))
189 | else:
190 | raise ValueError('method must be 0 or 1.')
191 | return 0.
192 |
193 |
194 | def ndcg_at_k(r, k, method=0):
195 | """Score is normalized discounted cumulative gain (ndcg)
196 | Relevance is positive real values. Can use binary
197 | as the previous methods.
198 | Example from
199 | http://www.stanford.edu/class/cs276/handouts/EvaluationNew-handout-6-per.pdf
200 | >>> r = [3, 2, 3, 0, 0, 1, 2, 2, 3, 0]
201 | >>> ndcg_at_k(r, 1)
202 | 1.0
203 | >>> r = [2, 1, 2, 0]
204 | >>> ndcg_at_k(r, 4)
205 | 0.9203032077642922
206 | >>> ndcg_at_k(r, 4, method=1)
207 | 0.96519546960144276
208 | >>> ndcg_at_k([0], 1)
209 | 0.0
210 | >>> ndcg_at_k([1], 2)
211 | 1.0
212 | Args:
213 | r: Relevance scores (list or numpy) in rank order
214 | (first element is the first item)
215 | k: Number of results to consider
216 | method: If 0 then weights are [1.0, 1.0, 0.6309, 0.5, 0.4307, ...]
217 | If 1 then weights are [1.0, 0.6309, 0.5, 0.4307, ...]
218 | Returns:
219 | Normalized discounted cumulative gain
220 | """
221 | dcg_max = dcg_at_k(sorted(r, reverse=True), k, method)
222 | if not dcg_max:
223 | return 0.
224 | return dcg_at_k(r, k, method) / dcg_max
225 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TableLlama: Towards Open Large Generalist Models for Tables
2 |
3 |
4 | 🔥 🔥 🔥 This repo contains the code, data, and models for TableLlama.
Check out our [Project Page](https://osu-nlp-group.github.io/TableLlama/) for more results and analysis!
6 |
7 |
8 |
9 |
10 |
11 |
12 | Figure 1: An overview of TableInstruct and TableLlama. TableInstruct includes a wide variety of realistic tables and tasks with instructions. We make the first step towards developing open-source generalist models for tables with TableInstruct and TableLlama.
13 |
14 |
15 |
16 |
17 |
18 | Figure 2: Illustration of three exemplary tasks: (a) Column type annotation. This task is to annotate the selected column with the correct semantic types. (b) Row population. This task is to populate rows given table metadata and partial row entities. (c) Hierarchical table QA. For subfigures (a) and (b), we mark candidates with red color in the "task instruction" part. The candidate set size can be hundreds to thousands in TableInstruct.
19 |
20 |
21 |
## Release Progress
22 |
23 | - :ballot_box_with_check: Training Dataset for TableLlama (check `/data_v3` of 🤗 [TableInstruct Dataset](https://huggingface.co/datasets/osunlp/TableInstruct/))
24 | - :ballot_box_with_check: TableLlama-7B model
25 | - :ballot_box_with_check: Code for Fine-tuning and Inference
- :ballot_box_with_check: Evaluation data of TableInstruct (check `/eval_data` of 🤗 [TableInstruct Dataset](https://huggingface.co/datasets/osunlp/TableInstruct/))
27 |
28 |
29 |
## Updates
30 |
31 | - 2024/3/13: Our paper has been accepted by NAACL 2024!
32 | - 2024/3/21: We refined the prompts of the 4 out-of-domain evaluation datasets (FEVEROUS, HybridQA, WikiSQL and WikiTQ) in [TableInstruct](https://huggingface.co/datasets/osunlp/TableInstruct/) and updated the results. Check out the new results!
33 | - 2024/3/21: We added the results of closed-source LLMs: GPT-3.5 and GPT-4.
34 |
35 | ### Datasets and Models
36 | Our dataset and models are all available on Hugging Face.
37 |
38 | 🤗 [TableInstruct Dataset](https://huggingface.co/datasets/osunlp/TableInstruct/)
39 |
40 | 🤗 [TableLlama-7B](https://osu-nlp-group.github.io/TableLlama/)
41 |
42 | The model is fine-tuned on the TableInstruct dataset using LongLoRA (7B), the fully fine-tuning version, as the base model, which replaces the vanilla attention mechanism of the original Llama-2 (7B) with shift short attention. Training takes 9 days on a cluster of 48 A100 80GB GPUs. Check out our paper for more details.
43 |
44 | TableInstruct is a comprehensive table-based instruction-tuning dataset that covers a variety of real-world tables and realistic tasks. It includes 14 datasets spanning 11 tasks in total.
45 |
46 | The model is evaluated on 8 in-domain datasets of 8 tasks and 6 out-of-domain datasets of 4 tasks.
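
For a quick start, the released checkpoint can be loaded the same way the `inference_*.py` scripts in this repo load it. The sketch below condenses that logic; the model ID and context size are placeholders you may need to adjust:

```python
# Condensed from the inference_*.py scripts in this repo; the model path is a placeholder.
import math
import torch
import transformers

base_model = "osunlp/TableLlama"   # placeholder: HF model ID or local checkpoint directory
context_size = 8192

config = transformers.AutoConfig.from_pretrained(base_model)
orig_ctx_len = getattr(config, "max_position_embeddings", None)
if orig_ctx_len and context_size > orig_ctx_len:
    # Linear RoPE scaling, as in the inference scripts, to reach the 8K context length.
    config.rope_scaling = {"type": "linear", "factor": float(math.ceil(context_size / orig_ctx_len))}

model = transformers.AutoModelForCausalLM.from_pretrained(
    base_model, config=config, torch_dtype=torch.float16, device_map="auto"
)
# The repo's scripts additionally call model.resize_token_embeddings(32001) before generation.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    base_model, model_max_length=context_size, padding_side="left", use_fast=False
)
model.eval()
```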
47 |
48 |
49 | ## **Introduction**
50 | We introduce TableLlama and TableInstruct: the FIRST open-source generalist LLM and instruction-tuning dataset for tables. TableLlama is trained on TableInstruct, a meticulously curated instruction-tuning dataset for tables; it is tuned on **2.6 million** table-based task examples and can handle context lengths of up to **8K** tokens!
51 |
52 |
53 | ## **Installation**
54 |
55 | Clone this repository and install the required packages:
56 |
57 | ```bash
58 | git clone https://github.com/OSU-NLP-Group/TableLlama.git
59 | cd TableLlama
60 | pip install -r requirements.txt
61 | pip install flash-attn --no-build-isolation
62 | ```
63 |
64 | ## **Training and Inference**
65 |
66 | ### **Fine-tuning**
67 |
68 | To train the 7B model, run:
69 |
70 | ```bash
71 | torchrun --nproc_per_node=8 supervised_fine_tune.py \
72 | --model_name_or_path $MODEL_DIR \
73 | --bf16 True \
74 | --output_dir $OUTPUT_DIR \
75 | --model_max_length 8192 \
76 | --use_flash_attn True \
77 | --data_path $DATA_DIR \
78 | --cache_dir /ML-A800/hf_cache \
79 | --low_rank_training False \
80 | --num_train_epochs 2 \
81 | --per_device_train_batch_size 3 \
82 | --per_device_eval_batch_size 2 \
83 | --gradient_accumulation_steps 1 \
84 | --evaluation_strategy "no" \
85 | --save_strategy "steps" \
86 | --save_steps 2000 \
87 | --save_total_limit 4 \
88 | --learning_rate 2e-5 \
89 | --weight_decay 0.0 \
90 | --warmup_ratio 0.03 \
91 | --lr_scheduler_type "cosine" \
92 | --logging_steps 1 \
93 |        --deepspeed "ds_configs/stage2.json" \
94 | --tf32 True \
95 | --run_name $RUN_NAME
96 | ```
97 |
98 | **Addressing OOM**
99 | If you hit OOM issues when training the 7B model on a very large dataset, we provide a streaming version of the training code. You can run:
100 | ```bash
101 | torchrun --nproc_per_node=8 supervised_fine_tune_stream.py \
102 | --model_name_or_path $MODEL_DIR \
103 | --bf16 True \
104 | --output_dir $OUTPUT_DIR \
105 | --model_max_length 8192 \
106 | --use_flash_attn True \
107 | --data_path $DATA_DIR \
108 | --gpu_size $GPU_SIZE \
109 | --data_size $DATA_SIZE \
110 | --cache_dir /ML-A800/hf_cache \
111 | --low_rank_training False \
112 | --num_train_epochs 2 \
113 | --per_device_train_batch_size 3 \
114 | --per_device_eval_batch_size 2 \
115 | --gradient_accumulation_steps 1 \
116 | --evaluation_strategy "no" \
117 | --save_strategy "steps" \
118 | --save_steps 2000 \
119 | --save_total_limit 4 \
120 | --learning_rate 2e-5 \
121 | --weight_decay 0.0 \
122 | --warmup_ratio 0.03 \
123 | --lr_scheduler_type "cosine" \
124 | --logging_steps 1 \
125 |         --deepspeed "ds_configs/stage2.json" \
126 | --tf32 True \
127 | --run_name $RUN_NAME
128 |
129 | ```
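The streaming script needs `--data_size` and `--gpu_size` because a streaming dataset has no known length, so the Trainer's `max_steps` must be set explicitly. Below is a rough sketch of the calculation performed in `supervised_fine_tune_stream.py`; the numbers are placeholders matching the command above.

```python
import math

# Placeholder values: ~2.6M TableInstruct examples, 8 GPUs, batch size 3 per GPU, 2 epochs.
data_size = 2_636_762
gpu_size = 8
per_device_train_batch_size = 3
num_train_epochs = 2

# Mirrors the max_steps computation in supervised_fine_tune_stream.py.
max_steps = math.ceil(num_train_epochs * data_size / (per_device_train_batch_size * gpu_size))
print(max_steps)  # 219731
```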
130 |
131 | ### **Inference**
132 | ```bash
133 | python3 inference_rel_extraction_col_type.py \
134 | --base_model $MODEL_DIR \
135 | --context_size 8192 \
136 | --max_gen_len 128 \
137 | --flash_attn True \
138 | --input_data_file /test_data/test_col_type.json \
139 | --output_data_file $OUTPUT_DIR/col_type_pred.json
140 | ```
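The inference scripts write their predictions to `--output_data_file` as a JSON list; in `inference_schema_aug.py` (included in this repo) each item keeps the original fields and adds a `predict` field. A minimal sketch for inspecting such a file, assuming the column-type script produces the same layout:

```python
import json

# Hypothetical path; corresponds to the --output_data_file argument above.
with open("col_type_pred.json") as f:
    predictions = json.load(f)

# Each item carries the prompt fields plus the gold "output" and the model's "predict".
for item in predictions[:3]:
    print("question:", item["question"][:80])
    print("gold    :", item["output"])
    print("predict :", item["predict"])
```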
141 |
142 | ## **Evaluation**
143 |
144 | The folder `eval_scripts` includes evaluation scripts for all the in-domain test sets. To run a script, take HiTab (hierarchical table QA) as an example:
145 |
146 | ```bash
147 | cd eval_scripts
148 | python eval_hitab.py --file_pred $OUTPUT_DIR/hitab_pred.json
149 | ```
150 |
151 | ## Prompt Format
152 |
153 | ```
154 | Below is an instruction that describes a task, paired with an input that provides further context. Write a response that
155 | appropriately completes the request.
156 |
157 | ### Instruction:
158 | {instruction}
159 |
160 | ### Input:
161 | {input}
162 |
163 | ### Question:
164 | {question}
165 |
166 | ### Response:
167 | ```
168 |
169 |
170 | - The instruction points out the task and gives a detailed task description.
171 | - The input provides the information about the table. We concatenate the table metadata (if any), such as the Wikipedia page title, section
172 | title and table caption, with the serialized table as the table input. '[TLE]' marks the beginning of the table metadata and '[TAB]' marks the beginning of the serialized table.
173 | - The question accommodates all the remaining information the model needs to complete the task and prompts the model to generate an answer.
174 | - Task prompt examples (for more example prompts for other tasks, please refer to Appendix E in our paper):
175 |
176 |
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 | **Note:**
185 |
186 | - If you directly use our model for inference on your own data, please make sure you organize the data in the same format as the examples shown above and in the Appendix of our paper; performance varies significantly with the prompt format. A minimal prompt-construction sketch is given below.
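
For reference, the scripts in this repo assemble the prompt from a JSON item with `instruction`, `input_seg` and `question` fields via `PROMPT_DICT` in `supervised_fine_tune.py`. Below is a minimal sketch; the field values are illustrative only, not real TableInstruct data.

```python
from supervised_fine_tune import PROMPT_DICT

# Illustrative item: real items come from the TableInstruct JSON files and
# must use the same field names (instruction, input_seg, question).
item = {
    "instruction": "This is a hierarchical table question answering task. ...",
    "input_seg": "[TLE] The table caption. [TAB] | year | value | ...",
    "question": "Which year has the highest value?",
}

# Same template string the fine-tuning and inference code uses.
prompt = PROMPT_DICT["prompt_input"].format_map(item)
print(prompt)
```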
187 |
188 |
189 | ## **Citation**
190 |
191 | Please cite our paper if you use our data, model or code. Please also kindly cite the original dataset papers.
192 |
193 | ```
194 | @misc{zhang2023tablellama,
195 | title={TableLlama: Towards Open Large Generalist Models for Tables},
196 | author={Tianshu Zhang and Xiang Yue and Yifei Li and Huan Sun},
197 | year={2023},
198 | eprint={2311.09206},
199 | archivePrefix={arXiv},
200 | primaryClass={cs.CL}
201 | }
202 | ```
203 |
204 |
205 |
--------------------------------------------------------------------------------
/supervised_fine_tune_stream.py:
--------------------------------------------------------------------------------
1 |
2 | # Some code based on https://github.com/epfml/landmark-attention
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import io
17 | import os
18 | import copy
19 | import json
20 | import math
21 | import logging
22 | from dataclasses import dataclass, field
23 | from typing import Dict, Optional, Sequence
24 | from multiprocessing import cpu_count
25 | from datasets import load_dataset
26 | from tqdm import tqdm
27 | import psutil
28 |
29 | import torch
30 | import transformers
31 | # from torch.utils.data import Dataset, IterableDataset
32 | from datasets.iterable_dataset import IterableDataset
33 | from transformers import Trainer, DataCollatorForLanguageModeling
34 | from llama_attn_replace import replace_llama_attn
35 | from peft import LoraConfig, get_peft_model
36 | from torch.distributed import barrier
37 |
38 |
39 | IGNORE_INDEX = -100
40 | DEFAULT_PAD_TOKEN = "[PAD]"
41 | DEFAULT_EOS_TOKEN = "</s>"
42 | DEFAULT_BOS_TOKEN = "<s>"
43 | DEFAULT_UNK_TOKEN = "<unk>"
44 |
45 | def _make_r_io_base(f, mode: str):
46 | if not isinstance(f, io.IOBase):
47 | f = open(f, mode=mode)
48 | return f
49 |
50 | def jload(f, mode="r"):
51 | """Load a .json file into a dictionary."""
52 | f = _make_r_io_base(f, mode)
53 | jdict = json.load(f)
54 | f.close()
55 | return jdict
56 |
57 | def findAllFile(base):
58 | for root, ds, fs in os.walk(base):
59 | for f in fs:
60 | if f.endswith('.json'):
61 | fullname = os.path.join(root,f)
62 | yield fullname
63 |
64 |
65 |
66 | PROMPT_DICT = {
67 | "prompt_input": (
68 | "Below is an instruction that describes a task, paired with an input that provides further context. "
69 | "Write a response that appropriately completes the request.\n\n"
70 | "### Instruction:\n{instruction}\n\n### Input:\n{input_seg}\n\n### Question:\n{question}\n\n### Response:"
71 | ),
72 | "prompt_no_input": (
73 | "Below is an instruction that describes a task. "
74 | "Write a response that appropriately completes the request.\n\n"
75 | "### Instruction:\n{instruction}\n\n### Response:"
76 | ),
77 | }
78 |
79 | @dataclass
80 | class ModelArguments:
81 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
82 |
83 |
84 | @dataclass
85 | class DataArguments:
86 | data_path: str = field(default=None, metadata={"help": "Path to the training data."})
87 |     data_size: int = field(default=None, metadata={"help": "Number of training examples; used to calculate max_steps."})
88 |     gpu_size: int = field(default=None, metadata={"help": "Number of GPUs; used to calculate max_steps and the logging interval."})
89 |
90 |
91 | @dataclass
92 | class TrainingArguments(transformers.TrainingArguments):
93 | cache_dir: Optional[str] = field(default=None)
94 | optim: str = field(default="adamw_torch")
95 | model_max_length: int = field(
96 | default=8192 * 4,
97 | metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
98 | )
99 | use_flash_attn: bool = field(
100 | default=True,
101 | metadata={"help": "Whether use flash attention for training."},
102 | )
103 | low_rank_training: bool = field(
104 | default=True,
105 | metadata={"help": "Whether use low rank adaptation for training."},
106 | )
107 | trainable_params: str = field(
108 | default="embed,norm",
109 | metadata={"help": "Additional trainable parameters except LoRA weights, if low rank training."},
110 | )
111 |
112 | def smart_tokenizer_and_embedding_resize(
113 | special_tokens_dict: Dict,
114 | tokenizer: transformers.PreTrainedTokenizer,
115 | model: transformers.PreTrainedModel,
116 | ):
117 | """Resize tokenizer and embedding.
118 |
119 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
120 | """
121 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
122 | model.resize_token_embeddings(len(tokenizer))
123 |
124 | if num_new_tokens > 0:
125 | input_embeddings = model.get_input_embeddings().weight.data
126 | output_embeddings = model.get_output_embeddings().weight.data
127 |
128 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
129 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
130 |
131 | input_embeddings[-num_new_tokens:] = input_embeddings_avg
132 | output_embeddings[-num_new_tokens:] = output_embeddings_avg
133 |
134 |
135 |
136 | tok_example_count = 0
137 |
138 | def preprocess_data(data_path, tokenizer):
139 |
140 | def _tokenize_fn(text: str) -> Dict:
141 | """Tokenize a list of strings."""
142 | tokenized = tokenizer(
143 | text,
144 | return_tensors="pt",
145 | # padding="longest",
146 | padding="max_length",
147 | max_length=tokenizer.model_max_length,
148 | truncation=True,
149 | )
150 |
151 | input_ids = labels = tokenized.input_ids[0]
152 | input_ids_lens = labels_lens = tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
153 |
154 | return dict(
155 | input_ids=input_ids,
156 | labels=labels,
157 | input_ids_lens=input_ids_lens,
158 | labels_lens=labels_lens,
159 | )
160 |
161 | def _process_function(example):
162 |
163 | global tok_example_count
164 | tok_example_count += 1
165 | if tok_example_count % 128 == 0:
166 | logging.warning(f"tok_example_count: {tok_example_count}")
167 |
168 | # logging.warning("Formatting inputs...")
169 | prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
170 | source = prompt_input.format_map(example) if example.get("input_seg", "") != "" else prompt_no_input.format_map(example)
171 |
172 | # targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
173 | target = f"{example['output']}{DEFAULT_EOS_TOKEN}"
174 |
175 | source_target = source + target
176 |
177 | example_tokenized = _tokenize_fn(source_target)
178 | source_tokenized = _tokenize_fn(source)
179 |
180 | input_ids = example_tokenized["input_ids"]
181 | label = copy.deepcopy(input_ids)
182 | label[:source_tokenized["input_ids_lens"]] = IGNORE_INDEX
183 |
184 | new_example = {"input_ids": input_ids, "labels": label}
185 |
186 | return new_example
187 |
188 | base = data_path
189 | sub_data_path_list = []
190 | for sub_data_path in findAllFile(base):
191 | sub_data_path_list.append(sub_data_path)
192 |
193 | logging.warning(f"before loading data, RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
194 | list_data_dict = load_dataset('json', data_files=sub_data_path_list, split = f'train', streaming=True)
195 | logging.warning(f"list_data_dict: {list_data_dict}")
196 | logging.warning(f"after loading data, RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
197 |
198 | logging.warning("Tokenizing inputs... This may take some time...")
199 | tokenized_dataset = list_data_dict.map(_process_function, remove_columns=["instruction", "input_seg", "question", "output"])
200 |
201 | return tokenized_dataset
202 |
203 |
204 | @dataclass
205 | class DataCollatorForSupervisedDataset(object):
206 | """Collate examples for supervised fine-tuning."""
207 |
208 | tokenizer: transformers.PreTrainedTokenizer
209 |
210 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
211 | # logging.warning(f"instances: {instances}")
212 | input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
213 | # logging.warning(f"input_ids: {input_ids}")
214 | # logging.warning(f"labels: {labels}")
215 | input_ids = torch.nn.utils.rnn.pad_sequence(
216 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
217 | )
218 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
219 | return dict(
220 | input_ids=input_ids,
221 | labels=labels,
222 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
223 | )
224 |
225 |
226 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
227 | """Make dataset and collator for supervised fine-tuning."""
228 | train_dataset = preprocess_data(tokenizer=tokenizer, data_path=data_args.data_path)
229 | data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
230 | # logging.warning(f"train_dataset: {train_dataset}")
231 | return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
232 |
233 |
234 | def train():
235 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
236 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
237 |
238 | replace_llama_attn(training_args.use_flash_attn, True)
239 |
240 | # Set RoPE scaling factor
241 | config = transformers.AutoConfig.from_pretrained(
242 | model_args.model_name_or_path,
243 | cache_dir=training_args.cache_dir,
244 | )
245 |
246 | orig_ctx_len = getattr(config, "max_position_embeddings", None)
247 | if orig_ctx_len and training_args.model_max_length > orig_ctx_len:
248 | scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len))
249 | config.rope_scaling = {"type": "linear", "factor": scaling_factor}
250 |
251 | # Load model and tokenizer
252 | model = transformers.AutoModelForCausalLM.from_pretrained(
253 | model_args.model_name_or_path,
254 | config=config,
255 | cache_dir=training_args.cache_dir,
256 | )
257 |
258 | tokenizer = transformers.AutoTokenizer.from_pretrained(
259 | model_args.model_name_or_path,
260 | cache_dir=training_args.cache_dir,
261 | model_max_length=training_args.model_max_length,
262 | padding_side="right",
263 | use_fast=False,
264 | )
265 |
266 | special_tokens_dict = dict()
267 | if tokenizer.pad_token is None:
268 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
269 | if tokenizer.eos_token is None:
270 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
271 | if tokenizer.bos_token is None:
272 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
273 | if tokenizer.unk_token is None:
274 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
275 |
276 | smart_tokenizer_and_embedding_resize(
277 | special_tokens_dict=special_tokens_dict,
278 | tokenizer=tokenizer,
279 | model=model,
280 | )
281 |
282 | data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
283 |
284 | if training_args.low_rank_training:
285 | config = LoraConfig(
286 | r=8,
287 | lora_alpha=16,
288 | target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
289 | lora_dropout=0,
290 | bias="none",
291 | task_type="CAUSAL_LM",
292 | )
293 | model = get_peft_model(model, config)
294 | # enable trainable params
295 | [p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])]
296 |
297 | model.enable_input_require_grads() # required for gradient checkpointing
298 | model.gradient_checkpointing_enable() # enable gradient checkpointing
299 |
300 | logging.warning(f"data_module: {data_module}")
301 | # data_size = 2636762
302 | # data_size = 7325
303 | # GPU_size = 8
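    # A streaming IterableDataset has no __len__, so the Trainer cannot infer the number of
    # training steps; max_steps is therefore derived from data_size, gpu_size and the batch size.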
304 | training_args.max_steps = math.ceil(training_args.num_train_epochs * data_args.data_size/(training_args.per_device_train_batch_size * data_args.gpu_size))
305 | trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
306 | trainer.train()
307 | trainer.save_state()
308 | trainer.save_model(output_dir=training_args.output_dir)
309 |
310 |
311 | if __name__ == "__main__":
312 | train()
313 |
--------------------------------------------------------------------------------
/supervised_fine_tune.py:
--------------------------------------------------------------------------------
1 | # Some code based on https://github.com/epfml/landmark-attention
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import io
16 | import os
17 | import copy
18 | import json
19 | import math
20 | import logging
21 | from dataclasses import dataclass, field
22 | from typing import Dict, Optional, Sequence
23 |
24 | import torch
25 | import transformers
26 | from torch.utils.data import Dataset
27 | from transformers import Trainer, DataCollatorForLanguageModeling
28 | from llama_attn_replace import replace_llama_attn
29 | from peft import LoraConfig, get_peft_model
30 | from torch.distributed import barrier
31 |
32 |
33 |
34 | IGNORE_INDEX = -100
35 | DEFAULT_PAD_TOKEN = "[PAD]"
36 | DEFAULT_EOS_TOKEN = "</s>"
37 | DEFAULT_BOS_TOKEN = "<s>"
38 | DEFAULT_UNK_TOKEN = "<unk>"
39 |
40 | def _make_r_io_base(f, mode: str):
41 | if not isinstance(f, io.IOBase):
42 | f = open(f, mode=mode)
43 | return f
44 |
45 | def jload(f, mode="r"):
46 | """Load a .json file into a dictionary."""
47 | f = _make_r_io_base(f, mode)
48 | jdict = json.load(f)
49 | f.close()
50 | return jdict
51 |
52 | # PROMPT_DICT = {
53 | # "prompt_input": (
54 | # "Below is an instruction that describes a task, paired with an input that provides further context. "
55 | # "Write a response that appropriately completes the request.\n\n"
56 | # "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
57 | # ),
58 | # "prompt_no_input": (
59 | # "Below is an instruction that describes a task. "
60 | # "Write a response that appropriately completes the request.\n\n"
61 | # "### Instruction:\n{instruction}\n\n### Response:"
62 | # ),
63 | # }
64 |
65 | PROMPT_DICT = {
66 | "prompt_input": (
67 | "Below is an instruction that describes a task, paired with an input that provides further context. "
68 | "Write a response that appropriately completes the request.\n\n"
69 | "### Instruction:\n{instruction}\n\n### Input:\n{input_seg}\n\n### Question:\n{question}\n\n### Response:"
70 | ),
71 | "prompt_no_input": (
72 | "Below is an instruction that describes a task. "
73 | "Write a response that appropriately completes the request.\n\n"
74 | "### Instruction:\n{instruction}\n\n### Response:"
75 | ),
76 | }
77 |
78 | @dataclass
79 | class ModelArguments:
80 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
81 |
82 |
83 | @dataclass
84 | class DataArguments:
85 | data_path: str = field(default=None, metadata={"help": "Path to the training data."})
86 |
87 |
88 | @dataclass
89 | class TrainingArguments(transformers.TrainingArguments):
90 | cache_dir: Optional[str] = field(default=None)
91 | optim: str = field(default="adamw_torch")
92 | model_max_length: int = field(
93 | default=8192 * 4,
94 | metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
95 | )
96 | use_flash_attn: bool = field(
97 | default=True,
98 | metadata={"help": "Whether use flash attention for training."},
99 | )
100 | low_rank_training: bool = field(
101 | default=True,
102 | metadata={"help": "Whether use low rank adaptation for training."},
103 | )
104 | trainable_params: str = field(
105 | default="embed,norm",
106 | metadata={"help": "Additional trainable parameters except LoRA weights, if low rank training."},
107 | )
108 |
109 | def smart_tokenizer_and_embedding_resize(
110 | special_tokens_dict: Dict,
111 | tokenizer: transformers.PreTrainedTokenizer,
112 | model: transformers.PreTrainedModel,
113 | ):
114 | """Resize tokenizer and embedding.
115 |
116 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
117 | """
118 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
119 | model.resize_token_embeddings(len(tokenizer))
120 |
121 | if num_new_tokens > 0:
122 | input_embeddings = model.get_input_embeddings().weight.data
123 | output_embeddings = model.get_output_embeddings().weight.data
124 |
125 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
126 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
127 |
128 | input_embeddings[-num_new_tokens:] = input_embeddings_avg
129 | output_embeddings[-num_new_tokens:] = output_embeddings_avg
130 |
131 |
132 | def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
133 | """Tokenize a list of strings."""
134 | tokenized_list = [
135 | tokenizer(
136 | text,
137 | return_tensors="pt",
138 | padding="longest",
139 | max_length=tokenizer.model_max_length,
140 | truncation=True,
141 | )
142 | for text in strings
143 | ]
144 | input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
145 | input_ids_lens = labels_lens = [
146 | tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
147 | ]
148 | return dict(
149 | input_ids=input_ids,
150 | labels=labels,
151 | input_ids_lens=input_ids_lens,
152 | labels_lens=labels_lens,
153 | )
154 |
155 |
156 | def preprocess(
157 | sources: Sequence[str],
158 | targets: Sequence[str],
159 | tokenizer: transformers.PreTrainedTokenizer,
160 | ) -> Dict:
161 | """Preprocess the data by tokenizing."""
162 | examples = [s + t for s, t in zip(sources, targets)]
163 | examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
164 | input_ids = examples_tokenized["input_ids"]
165 | labels = copy.deepcopy(input_ids)
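    # Mask the prompt (source) tokens with IGNORE_INDEX so the loss is computed only on the response.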
166 | for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
167 | label[:source_len] = IGNORE_INDEX
168 | return dict(input_ids=input_ids, labels=labels)
169 |
170 |
171 | class SupervisedDataset(Dataset):
172 | """Dataset for supervised fine-tuning."""
173 |
174 | def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
175 | super(SupervisedDataset, self).__init__()
176 | logging.warning("Loading data...")
177 | list_data_dict = jload(data_path)
178 |
179 | logging.warning("Formatting inputs...")
180 | '''
181 | prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
182 | sources = [
183 | prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
184 | for example in list_data_dict
185 | ]
186 | '''
187 |
188 | prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
189 |
190 | # print(prompt_input.format_map(list_data_dict[1]))
191 | # print("****")
192 | # print(f"{list_data_dict[1]['output']}{DEFAULT_EOS_TOKEN}")
193 | # import pdb
194 | # pdb.set_trace()
195 |
196 | sources = [
197 | # prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
198 | prompt_input.format_map(example) if example.get("input_seg", "") != "" else prompt_no_input.format_map(example)
199 | for example in list_data_dict
200 | ]
201 | # targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
202 | targets = [f"{example['output']}{DEFAULT_EOS_TOKEN}" for example in list_data_dict]
203 |
204 | # sources = [example["instruction"] for example in list_data_dict]
205 |
206 | # targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
207 |
208 | logging.warning("Tokenizing inputs... This may take some time...")
209 | data_dict = preprocess(sources, targets, tokenizer)
210 |
211 | self.input_ids = data_dict["input_ids"]
212 | self.labels = data_dict["labels"]
213 |
214 | def __len__(self):
215 | return len(self.input_ids)
216 |
217 | def __getitem__(self, i) -> Dict[str, torch.Tensor]:
218 | return dict(input_ids=self.input_ids[i], labels=self.labels[i])
219 |
220 |
221 | @dataclass
222 | class DataCollatorForSupervisedDataset(object):
223 | """Collate examples for supervised fine-tuning."""
224 |
225 | tokenizer: transformers.PreTrainedTokenizer
226 |
227 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
228 | input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
229 | input_ids = torch.nn.utils.rnn.pad_sequence(
230 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
231 | )
232 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
233 | return dict(
234 | input_ids=input_ids,
235 | labels=labels,
236 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
237 | )
238 |
239 |
240 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
241 | """Make dataset and collator for supervised fine-tuning."""
242 | train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path)
243 | data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
244 | return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
245 |
246 |
247 | def train():
248 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
249 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
250 |
251 | replace_llama_attn(training_args.use_flash_attn, True)
252 |
253 | # Set RoPE scaling factor
254 | config = transformers.AutoConfig.from_pretrained(
255 | model_args.model_name_or_path,
256 | cache_dir=training_args.cache_dir,
257 | )
258 |
259 | orig_ctx_len = getattr(config, "max_position_embeddings", None)
260 | if orig_ctx_len and training_args.model_max_length > orig_ctx_len:
261 | scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len))
262 | config.rope_scaling = {"type": "linear", "factor": scaling_factor}
263 |
264 | # Load model and tokenizer
265 | model = transformers.AutoModelForCausalLM.from_pretrained(
266 | model_args.model_name_or_path,
267 | config=config,
268 | cache_dir=training_args.cache_dir,
269 | )
270 |
271 | tokenizer = transformers.AutoTokenizer.from_pretrained(
272 | model_args.model_name_or_path,
273 | cache_dir=training_args.cache_dir,
274 | model_max_length=training_args.model_max_length,
275 | padding_side="right",
276 | use_fast=False,
277 | )
278 |
279 | special_tokens_dict = dict()
280 | if tokenizer.pad_token is None:
281 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
282 | if tokenizer.eos_token is None:
283 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
284 | if tokenizer.bos_token is None:
285 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
286 | if tokenizer.unk_token is None:
287 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
288 |
289 | smart_tokenizer_and_embedding_resize(
290 | special_tokens_dict=special_tokens_dict,
291 | tokenizer=tokenizer,
292 | model=model,
293 | )
294 |
295 | data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
296 |
297 | if training_args.low_rank_training:
298 | config = LoraConfig(
299 | r=8,
300 | lora_alpha=16,
301 | target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
302 | lora_dropout=0,
303 | bias="none",
304 | task_type="CAUSAL_LM",
305 | )
306 | model = get_peft_model(model, config)
307 | # enable trainable params
308 | [p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])]
309 |
310 | model.enable_input_require_grads() # required for gradient checkpointing
311 | model.gradient_checkpointing_enable() # enable gradient checkpointing
312 |
313 | trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
314 | trainer.train()
315 | trainer.save_state()
316 | trainer.save_model(output_dir=training_args.output_dir)
317 |
318 |
319 | if __name__ == "__main__":
320 | train()
321 |
--------------------------------------------------------------------------------
/inference_schema_aug.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import sys
4 | import math
5 | import torch
6 | import argparse
7 | # import textwrap
8 | import transformers
9 | from peft import PeftModel
10 | from transformers import GenerationConfig
11 | from llama_attn_replace import replace_llama_attn
12 | from supervised_fine_tune import PROMPT_DICT
13 | from tqdm import tqdm
14 | # from queue import Queue
15 | # from threading import Thread
16 | # import gradio as gr
17 |
18 | def parse_config():
19 | parser = argparse.ArgumentParser(description='arg parser')
20 | # parser.add_argument('--question', type=str, default="")
21 | # parser.add_argument('--material', type=str, default="")
22 | # parser.add_argument('--material_title', type=str, default="")
23 | # parser.add_argument('--material_type', type=str, default="material")
24 | parser.add_argument('--base_model', type=str, default="/data1/pretrained-models/llama-7b-hf")
25 | parser.add_argument('--cache_dir', type=str, default="./cache")
26 | parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning')
27 | parser.add_argument('--flash_attn', type=bool, default=False, help='')
28 | parser.add_argument('--temperature', type=float, default=0.6, help='')
29 | parser.add_argument('--top_p', type=float, default=0.9, help='')
30 | parser.add_argument('--max_gen_len', type=int, default=512, help='')
31 | parser.add_argument('--input_data_file', type=str, default='input_data/', help='')
32 | parser.add_argument('--output_data_file', type=str, default='output_data/', help='')
33 | args = parser.parse_args()
34 | return args
35 |
36 | def generate_prompt(instruction, question, input_seg=None):
37 |     if input_seg:
38 | return PROMPT_DICT["prompt_input"].format(instruction=instruction, input_seg=input_seg, question=question)
39 | else:
40 | return PROMPT_DICT["prompt_no_input"].format(instruction=instruction)
41 |
42 | # def format_prompt(material, message, material_type="book", material_title=""):
43 | # if material_type == "paper":
44 | # prompt = f"Below is a paper. Memorize the material and answer my question after the paper.\n {material} \n "
45 | # elif material_type == "book":
46 | # material_title = ", %s"%material_title if len(material_title)>0 else ""
47 | # prompt = f"Below is some paragraphs in the book{material_title}. Memorize the content and answer my question after the book.\n {material} \n "
48 | # else:
49 | # prompt = f"Below is a material. Memorize the material and answer my question after the material. \n {material} \n "
50 | # message = str(message).strip()
51 | # prompt += f"Now the material ends. {message}"
52 |
53 | # return prompt
54 |
55 | # def read_txt_file(material_txt):
56 | # if not material_txt.split(".")[-1]=='txt':
57 | # raise ValueError("Only support txt or pdf file.")
58 | # content = ""
59 | # with open(material_txt) as f:
60 | # for line in f.readlines():
61 | # content += line
62 | # return content
63 |
64 | def build_generator(
65 | item, model, tokenizer, temperature=0.6, top_p=0.9, max_gen_len=4096, use_cache=True
66 | ):
67 | def response(item):
68 | # def response(material, question, material_type="", material_title=None):
69 | # material = read_txt_file(material)
70 | # prompt = format_prompt(material, question, material_type, material_title)
71 | prompt = generate_prompt(instruction = item["instruction"], input_seg = item["input_seg"], question = item["question"])
72 | inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
73 |
74 | output = model.generate(
75 | **inputs,
76 | max_new_tokens=max_gen_len,
77 | temperature=temperature,
78 | top_p=top_p,
79 | use_cache=use_cache
80 | )
81 | out = tokenizer.decode(output[0], skip_special_tokens=False, clean_up_tokenization_spaces=False)
82 |
83 | out = out.split(prompt)[1].strip()
84 | return out
85 |
86 | return response
87 |
88 | def main(args):
89 | if args.flash_attn:
90 | replace_llama_attn()
91 |
92 | # Set RoPE scaling factor
93 | config = transformers.AutoConfig.from_pretrained(
94 | args.base_model,
95 | cache_dir=args.cache_dir,
96 | )
97 |
98 | orig_ctx_len = getattr(config, "max_position_embeddings", None)
99 | if orig_ctx_len and args.context_size > orig_ctx_len:
100 | scaling_factor = float(math.ceil(args.context_size / orig_ctx_len))
101 | config.rope_scaling = {"type": "linear", "factor": scaling_factor}
102 |
103 | # Load model and tokenizer
104 | model = transformers.AutoModelForCausalLM.from_pretrained(
105 | args.base_model,
106 | config=config,
107 | cache_dir=args.cache_dir,
108 | torch_dtype=torch.float16,
109 | device_map="auto",
110 | )
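    # 32,000 Llama-2 vocabulary entries plus the [PAD] token added during fine-tuning.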
111 | model.resize_token_embeddings(32001)
112 |
113 | tokenizer = transformers.AutoTokenizer.from_pretrained(
114 | args.base_model,
115 | cache_dir=args.cache_dir,
116 | model_max_length=args.context_size if args.context_size > orig_ctx_len else orig_ctx_len,
117 | # padding_side="right",
118 | padding_side="left",
119 | use_fast=False,
120 | )
121 |
122 | model.eval()
123 | if torch.__version__ >= "2" and sys.platform != "win32":
124 | model = torch.compile(model)
125 |
126 | with open(args.input_data_file, "r") as f:
127 | test_data = json.load(f)
128 |
129 | # import random
130 | # test_data = random.sample(test_data, k=5)
131 |
132 | test_data_pred = []
133 | for i in tqdm(range(len(test_data))):
134 | item = test_data[i]
135 | new_item = {}
136 | respond = build_generator(item, model, tokenizer, temperature=args.temperature, top_p=args.top_p,
137 |                                   max_gen_len=args.max_gen_len, use_cache=not args.flash_attn)  # temperature and top_p differ from the earlier Alpaca experiments; revisit them if generation looks off
138 | output = respond(item)
139 |
140 | new_item["idx"] = i
141 | new_item["table_id"] = test_data[i]["table_id"]
142 | new_item["instruction"] = test_data[i]["instruction"]
143 | new_item["input_seg"] = test_data[i]["input_seg"]
144 | new_item["question"] = test_data[i]["question"]
145 | new_item["target"] = test_data[i]["target"]
146 | new_item["output_list"] = test_data[i]["output_list"]
147 | new_item["output"] = test_data[i]["output"]
148 | new_item["predict"] = output
149 |
150 | test_data_pred.append(new_item)
151 | # import pdb
152 | # pdb.set_trace()
153 | with open(args.output_data_file, "w") as f:
154 | json.dump(test_data_pred, f, indent = 2)
155 |
156 | # output = respond(args.material, args.question, args.material_type, args.material_title)
157 | # print("output", output)
158 |
159 | if __name__ == "__main__":
160 | args = parse_config()
161 | main(args)
162 |
163 |
164 | # from dataclasses import dataclass, field
165 |
166 | # import numpy as np
167 | # import torch
168 | # import transformers
169 | # from transformers import GenerationConfig
170 |
171 | # from train_llama2_long_context_reformat import ModelArguments, smart_tokenizer_and_embedding_resize, DEFAULT_PAD_TOKEN, DEFAULT_EOS_TOKEN, \
172 | # DEFAULT_BOS_TOKEN, DEFAULT_UNK_TOKEN, PROMPT_DICT
173 |
174 | # import json
175 | # from tqdm import tqdm
176 | # import math
177 | # import argparse
178 |
179 | # @dataclass
180 | # class InferenceArguments:
181 | # model_max_length: int = field(
182 | # # default=512,
183 | # # default=1024,
184 | # default=1536,
185 | # metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
186 | # )
187 | # load_in_8bit: bool = field(
188 | # default=False,
189 | # metadata={"help": "Load the model in 8-bit mode."},
190 | # )
191 | # inference_dtype: torch.dtype = field(
192 | # default=torch.float32,
193 | # metadata={"help": "The dtype to use for inference."},
194 | # )
195 | # max_new_tokens: int = field(
196 | # default=64,
197 | # metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
198 | # )
199 |
200 | # @dataclass
201 | # class FileArguments:
202 | # input_data_file: str = field(
203 | # default="",
204 | # metadata={"help": ""},
205 | # )
206 | # output_data_file: str = field(
207 | # default="",
208 | # metadata={"help": ""},
209 | # )
210 |
211 |
212 |
213 |
214 | # def batch_process(data_list, model, tokenizer, generation_config, batch_size, max_new_tokens):
215 | # pred = []
216 | # for i in tqdm(range(math.ceil(len(data_list)/batch_size))):
217 | # if i != math.ceil(len(data_list)/batch_size) - 1:
218 | # batch_data = data_list[i * batch_size: i * batch_size + batch_size]
219 | # else:
220 | # batch_data = data_list[i * batch_size:]
221 | # batch_prompt =[generate_prompt(item["instruction"], item["input_seg"], item["question"]) for item in batch_data]
222 | # inputs = tokenizer(batch_prompt,
223 | # return_tensors="pt",
224 | # padding="longest",
225 | # max_length=tokenizer.model_max_length,
226 | # truncation=True)
227 | # outputs = model.generate(input_ids=inputs["input_ids"].cuda(), generation_config=generation_config, max_new_tokens = max_new_tokens)
228 |
229 | # # import pdb
230 | # # pdb.set_trace()
231 | # # input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1]
232 | # # generated_tokens = outputs.sequences[:, input_length:]
233 | # # pred += tokenizer.batch_decode(generated_tokens, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
234 | # pred += tokenizer.batch_decode(outputs, skip_special_tokens=False, clean_up_tokenization_spaces=False)
235 | # # import pdb
236 | # # pdb.set_trace()
237 | # return pred
238 |
239 |
240 | # def inference(test_data, model_args, inference_args):
241 | # # parser = transformers.HfArgumentParser((ModelArguments, InferenceArguments))
242 | # # model_args, inference_args = parser.parse_args_into_dataclasses()
243 |
244 | # model = transformers.AutoModelForCausalLM.from_pretrained(
245 | # model_args.model_name_or_path,
246 | # load_in_8bit=inference_args.load_in_8bit,
247 | # torch_dtype=inference_args.inference_dtype,
248 | # device_map="auto",
249 | # )
250 | # model.cuda()
251 | # model.eval()
252 |
253 | # generation_config = GenerationConfig(
254 | # temperature=0.1,
255 | # top_p=0.75,
256 | # # num_beams=4,
257 | # num_beams=1,
258 | # # num_beams=2,
259 | # )
260 |
261 | # tokenizer = transformers.AutoTokenizer.from_pretrained(
262 | # model_args.model_name_or_path,
263 | # use_fast=False,
264 | # model_max_length=inference_args.model_max_length,
265 | # padding_side="left" ### important to add this in inference
266 | # )
267 |
268 | # if tokenizer.pad_token is None:
269 | # smart_tokenizer_and_embedding_resize(
270 | # special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
271 | # tokenizer=tokenizer,
272 | # model=model,
273 | # )
274 | # tokenizer.add_special_tokens(
275 | # {
276 | # "eos_token": DEFAULT_EOS_TOKEN,
277 | # "bos_token": DEFAULT_BOS_TOKEN,
278 | # "unk_token": DEFAULT_UNK_TOKEN,
279 | # }
280 | # )
281 |
282 | # pred = batch_process(test_data, model, tokenizer, generation_config, 1, inference_args.max_new_tokens)
283 |
284 | # new_test_list = []
285 | # for i in tqdm(range(len(test_data))):
286 | # # for i in tqdm(range(90, 101)):
287 | # # for i in tqdm(range(3)):
288 | # instruction = test_data[i]["instruction"]
289 | # item = {}
290 | # item["idx"] = i
291 | # # item["table_id"] = test_data[i]["table_id"]
292 | # # item["entity"] = test_data[i]["entity"]
293 | # item["instruction"] = instruction
294 | # # item["input"] = input
295 | # item["input_seg"] = test_data[i]["input_seg"]
296 | # # item["tokenizer_tensor_shape"] = inputs["input_ids"].shape
297 | # item["output"] = test_data[i]["output"]
298 | # item["predict"] = pred[i]
299 |
300 | # new_test_list.append(item)
301 | # return new_test_list
302 |
303 |
304 | # if __name__ == "__main__":
305 |
306 | # parser = transformers.HfArgumentParser((ModelArguments, InferenceArguments, FileArguments))
307 | # model_args, inference_args, file_args = parser.parse_args_into_dataclasses()
308 |
309 | # # num = 0
310 | # # with open("/users/PAA0201/shubaobao/stanford_alpaca/table_all_tasks_fair/test/split_16_col_type/test_" + str(file_args.input_data_file_num) + ".json", "r") as f:
311 | # with open(file_args.input_data_file, "r") as f:
312 | # test_data = json.load(f)
313 |
314 | # # import random
315 | # # test_data = random.sample(test_data, k=5)
316 | # test_list = inference(test_data, model_args, inference_args)
317 |
318 | # # with open("/users/PAA0201/shubaobao/stanford_alpaca/table_all_tasks_fair/pred_ser_20000_seg/test_beam_search/test_" + str(file_args.output_data_file_num) + ".json", "w") as f:
319 | # with open(file_args.output_data_file, "w") as f:
320 | # json.dump(test_list, f, indent = 2)
321 |
322 | # print("input file:", str(file_args.input_data_file))
323 | # print("output file:", str(file_args.output_data_file))
324 |
--------------------------------------------------------------------------------
/llama_attn_replace.py:
--------------------------------------------------------------------------------
1 | # Modified based on https://github.com/lm-sys/FastChat
2 |
3 | from typing import Optional, Tuple
4 | import math, warnings
5 | import torch
6 | from torch import nn; import torch.nn.functional as F  # math and F are used in forward_noflashattn
7 | import transformers
8 | from einops import rearrange
9 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
10 |
11 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
12 | from flash_attn.bert_padding import unpad_input, pad_input
13 |
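# Shift short attention (S2-Attn): during training, attention is computed within groups of
# q_len * group_size_ratio tokens, and half of the heads are shifted by half a group so that
# information can flow across group boundaries.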
14 | group_size_ratio = 1/4
15 | def forward_flashattn(
16 | self,
17 | hidden_states: torch.Tensor,
18 | attention_mask: Optional[torch.Tensor] = None,
19 | position_ids: Optional[torch.Tensor] = None,
20 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
21 | output_attentions: bool = False,
22 | use_cache: bool = False,
23 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
24 | """Input shape: Batch x Time x Channel
25 |
26 | attention_mask: [bsz, q_len]
27 | """
28 | if output_attentions:
29 | warnings.warn(
30 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
31 | )
32 |
33 | bsz, q_len, _ = hidden_states.size()
34 |
35 | query_states = (
36 | self.q_proj(hidden_states)
37 | .view(bsz, q_len, self.num_heads, self.head_dim)
38 | .transpose(1, 2)
39 | )
40 | key_states = (
41 | self.k_proj(hidden_states)
42 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
43 | .transpose(1, 2)
44 | )
45 | value_states = (
46 | self.v_proj(hidden_states)
47 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
48 | .transpose(1, 2)
49 | )
50 | # [bsz, q_len, nh, hd]
51 | # [bsz, nh, q_len, hd]
52 |
53 | kv_seq_len = key_states.shape[-2]
54 | if past_key_value is not None:
55 | kv_seq_len += past_key_value[0].shape[-2]
56 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
57 | query_states, key_states = apply_rotary_pos_emb(
58 | query_states, key_states, cos, sin, position_ids
59 | )
60 |
61 | # Past Key value support
62 | if past_key_value is not None:
63 | # reuse k, v, self_attention
64 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
65 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
66 |
67 | past_key_value = (key_states, value_states) if use_cache else None
68 |
69 | # repeat k/v heads if n_kv_heads < n_heads
70 | key_states = repeat_kv(key_states, self.num_key_value_groups)
71 | value_states = repeat_kv(value_states, self.num_key_value_groups)
72 |
73 | # Flash attention codes from
74 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
75 |
76 | # transform the data into the format required by flash attention
77 | qkv = torch.stack(
78 | [query_states, key_states, value_states], dim=2
79 | ) # [bsz, nh, 3, q_len, hd]
80 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
81 |
82 | # shift
83 | if self.training:
84 | group_size = int(q_len * group_size_ratio)
85 | if q_len % group_size > 0:
86 | raise ValueError("q_len %d should be divisible by group size %d." % (q_len, group_size))
87 | num_group = q_len // group_size
88 | qkv[:, :, :, self.num_heads//2:] = qkv[:, :, :, self.num_heads//2:].roll(-group_size//2, dims=1)
89 | qkv = qkv.reshape(bsz*num_group, group_size, 3, self.num_heads, self.head_dim)
90 |
91 | # We have disabled _prepare_decoder_attention_mask in LlamaModel
92 | # the attention_mask should be the same as the key_padding_mask
93 |
94 | key_padding_mask = attention_mask[:, :group_size].repeat(num_group, 1) if self.training else attention_mask
95 | nheads = qkv.shape[-2]
96 | x = rearrange(qkv, "b s three h d -> b s (three h d)")
97 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
98 | x_unpad = rearrange(
99 | x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
100 | )
101 | output_unpad = flash_attn_varlen_qkvpacked_func(
102 | x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
103 | )
104 | output = rearrange(
105 | pad_input(
106 | rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz * num_group if self.training else bsz, group_size if self.training else q_len
107 | ),
108 | "b s (h d) -> b s h d",
109 | h=nheads,
110 | )
111 | output = output.reshape(bsz, q_len, self.num_heads, self.head_dim)
112 | if self.training:
113 | # shift back
114 | output[:, :, self.num_heads//2:] = output[:, :, self.num_heads//2:].roll(group_size//2, dims=1)
115 |
116 | return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value
117 |
118 | def forward_flashattn_full(
119 | self,
120 | hidden_states: torch.Tensor,
121 | attention_mask: Optional[torch.Tensor] = None,
122 | position_ids: Optional[torch.Tensor] = None,
123 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
124 | output_attentions: bool = False,
125 | use_cache: bool = False,
126 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
127 | """Input shape: Batch x Time x Channel
128 |
129 | attention_mask: [bsz, q_len]
130 | """
131 | if output_attentions:
132 | warnings.warn(
133 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
134 | )
135 |
136 | bsz, q_len, _ = hidden_states.size()
137 |
138 | query_states = (
139 | self.q_proj(hidden_states)
140 | .view(bsz, q_len, self.num_heads, self.head_dim)
141 | .transpose(1, 2)
142 | )
143 | key_states = (
144 | self.k_proj(hidden_states)
145 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
146 | .transpose(1, 2)
147 | )
148 | value_states = (
149 | self.v_proj(hidden_states)
150 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
151 | .transpose(1, 2)
152 | )
153 | # [bsz, q_len, nh, hd]
154 | # [bsz, nh, q_len, hd]
155 |
156 | kv_seq_len = key_states.shape[-2]
157 | if past_key_value is not None:
158 | kv_seq_len += past_key_value[0].shape[-2]
159 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
160 | query_states, key_states = apply_rotary_pos_emb(
161 | query_states, key_states, cos, sin, position_ids
162 | )
163 |
164 | # Past Key value support
165 | if past_key_value is not None:
166 | # reuse k, v, self_attention
167 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
168 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
169 |
170 | past_key_value = (key_states, value_states) if use_cache else None
171 |
172 | # repeat k/v heads if n_kv_heads < n_heads
173 | key_states = repeat_kv(key_states, self.num_key_value_groups)
174 | value_states = repeat_kv(value_states, self.num_key_value_groups)
175 |
176 | # Flash attention codes from
177 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
178 |
179 | # transform the data into the format required by flash attention
180 | qkv = torch.stack(
181 | [query_states, key_states, value_states], dim=2
182 | ) # [bsz, nh, 3, q_len, hd]
183 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
184 |
185 | # We have disabled _prepare_decoder_attention_mask in LlamaModel
186 | # the attention_mask should be the same as the key_padding_mask
187 |
188 | key_padding_mask = attention_mask
189 | nheads = qkv.shape[-2]
190 | x = rearrange(qkv, "b s three h d -> b s (three h d)")
191 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
192 | x_unpad = rearrange(
193 | x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
194 | )
195 | output_unpad = flash_attn_varlen_qkvpacked_func(
196 | x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
197 | )
198 | output = rearrange(
199 | pad_input(
200 | rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
201 | ),
202 | "b s (h d) -> b s h d",
203 | h=nheads,
204 | )
205 | output = output.reshape(bsz, q_len, self.num_heads, self.head_dim)
206 |
207 | return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value
208 |
209 |
210 | def forward_noflashattn(
211 | self,
212 | hidden_states: torch.Tensor,
213 | attention_mask: Optional[torch.Tensor] = None,
214 | position_ids: Optional[torch.LongTensor] = None,
215 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
216 | output_attentions: bool = False,
217 | use_cache: bool = False,
218 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
219 | bsz, q_len, _ = hidden_states.size()
220 |
221 | group_size = int(q_len * group_size_ratio)
222 |
223 | if q_len % group_size > 0:
224 | raise ValueError("q_len %d should be divisible by group size %d."%(q_len, group_size))
225 | num_group = q_len // group_size
226 |
227 | if self.config.pretraining_tp > 1:
228 | key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
229 | query_slices = self.q_proj.weight.split(
230 | (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
231 | )
232 | key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
233 | value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
234 |
235 | query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
236 | query_states = torch.cat(query_states, dim=-1)
237 |
238 | key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
239 | key_states = torch.cat(key_states, dim=-1)
240 |
241 | value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
242 | value_states = torch.cat(value_states, dim=-1)
243 |
244 | else:
245 | query_states = self.q_proj(hidden_states)
246 | key_states = self.k_proj(hidden_states)
247 | value_states = self.v_proj(hidden_states)
248 |
249 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
250 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
251 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
252 |
253 | kv_seq_len = key_states.shape[-2]
254 | if past_key_value is not None:
255 | kv_seq_len += past_key_value[0].shape[-2]
256 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
257 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
258 |
259 | if past_key_value is not None:
260 | # reuse k, v, self_attention
261 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
262 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
263 |
264 | past_key_value = (key_states, value_states) if use_cache else None
265 |
266 | # repeat k/v heads if n_kv_heads < n_heads
267 | key_states = repeat_kv(key_states, self.num_key_value_groups)
268 | value_states = repeat_kv(value_states, self.num_key_value_groups)
269 |
270 | # shift
271 | def shift(qkv, bsz, q_len, group_size, num_heads, head_dim):
272 | qkv[:, num_heads // 2:] = qkv[:, num_heads // 2:].roll(-group_size // 2, dims=2)
273 | qkv = qkv.transpose(1, 2).reshape(bsz * (q_len // group_size), group_size, num_heads, head_dim).transpose(1, 2)
274 | return qkv
275 |
276 | query_states = shift(query_states, bsz, q_len, group_size, self.num_heads, self.head_dim)
277 | key_states = shift(key_states, bsz, q_len, group_size, self.num_heads, self.head_dim)
278 | value_states = shift(value_states, bsz, q_len, group_size, self.num_heads, self.head_dim)
279 |
280 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
281 |
282 | if attn_weights.size() != (bsz * num_group, self.num_heads, group_size, group_size):
283 | raise ValueError(
284 | f"Attention weights should be of size {(bsz * num_group, self.num_heads, group_size, group_size)}, but is"
285 | f" {attn_weights.size()}"
286 | )
287 |
288 | attention_mask = attention_mask[:, :, :group_size, :group_size].repeat(num_group, 1, 1, 1)
289 | if attention_mask is not None:
290 | if attention_mask.size() != (bsz * num_group, 1, group_size, group_size):
291 | raise ValueError(
292 | f"Attention mask should be of size {(bsz * num_group, 1, group_size, group_size)}, but is {attention_mask.size()}"
293 | )
294 | attn_weights = attn_weights + attention_mask
295 |
296 | # upcast attention to fp32
297 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
298 | attn_output = torch.matmul(attn_weights, value_states)
299 |
300 | if attn_output.size() != (bsz * num_group, self.num_heads, group_size, self.head_dim):
301 | raise ValueError(
302 | f"`attn_output` should be of size {(bsz * num_group, self.num_heads, group_size, self.head_dim)}, but is"
303 | f" {attn_output.size()}"
304 | )
305 | attn_output = attn_output.transpose(1, 2).contiguous()
306 |
307 | attn_output = attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
308 |
309 | # shift back
310 |     attn_output[:, :, self.num_heads//2:] = attn_output[:, :, self.num_heads//2:].roll(group_size//2, dims=1)
311 |
312 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
313 |
314 | if self.config.pretraining_tp > 1:
315 | attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
316 | o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
317 | attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
318 | else:
319 | attn_output = self.o_proj(attn_output)
320 |
321 | if not output_attentions:
322 | attn_weights = None
323 |
324 | return attn_output, attn_weights, past_key_value
325 |
326 | # Disable the transformation of the attention mask in LlamaModel as the flash attention
327 | # requires the attention mask to be the same as the key_padding_mask
328 | def _prepare_decoder_attention_mask(
329 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length
330 | ):
331 | # [bsz, seq_len]
332 | return attention_mask
333 |
334 |
335 | def replace_llama_attn(use_flash_attn=True, use_full=False):
336 | if use_flash_attn:
337 | cuda_major, cuda_minor = torch.cuda.get_device_capability()
338 | if cuda_major < 8:
339 | warnings.warn(
340 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
341 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
342 | )
343 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
344 | _prepare_decoder_attention_mask
345 | )
346 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_flashattn_full if use_full else forward_flashattn
347 | else:
348 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_noflashattn
349 |
--------------------------------------------------------------------------------