├── ds_configs ├── stage2.json └── stage3.json ├── src ├── run.sh ├── inference.py └── sft_minicpm.py ├── eval_scripts ├── eval_hitab.py ├── eval_ent_link.py ├── eval_fetaqa.py ├── qa_datadump_utils.py ├── table_utils.py ├── eval_rel_extraction.py ├── eval_col_type.py └── metric.py ├── README.md ├── TPE-Llama ├── __init__.py ├── configuration_llama.py ├── tokenization_llama_fast.py ├── convert_llama_weights_to_hf.py ├── tokenization_llama.py └── modeling_llama.py └── requirements.txt /ds_configs/stage2.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": "auto", 3 | "gradient_accumulation_steps": "auto", 4 | "gradient_clipping": "auto", 5 | "zero_allow_untested_optimizer": true, 6 | "bf16": { 7 | "enabled": "auto", 8 | "loss_scale": 0, 9 | "initial_scale_power": 16, 10 | "loss_scale_window": 1000, 11 | "hysteresis": 2, 12 | "min_loss_scale": 1 13 | }, 14 | "zero_optimization": { 15 | "stage": 2, 16 | "allgather_partitions": true, 17 | "allgather_bucket_size": 1e9, 18 | "reduce_scatter": true, 19 | "reduce_bucket_size": 1e9, 20 | "overlap_comm": true, 21 | "contiguous_gradients": true 22 | } 23 | } -------------------------------------------------------------------------------- /src/run.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=7,6,5,4,3,2,1,0 torchrun --nproc_per_node=8 --master_port=20213 sft_minicpm.py \ 2 | --model_name_or_path /model/MiniCPM-2B-sft-bf16-llama-format \ 3 | --bf16 True \ 4 | --output_dir /output/rel_extraction_2d \ 5 | --model_max_length 4096 \ 6 | --use_flash_attn True \ 7 | --data_path /data/rel_extraction_train_62954.json \ 8 | --low_rank_training False \ 9 | --num_train_epochs 2 \ 10 | --per_device_train_batch_size 2 \ 11 | --gradient_accumulation_steps 4 \ 12 | --evaluation_strategy "no" \ 13 | --save_strategy "epoch" \ 14 | --save_total_limit 1 \ 15 | --learning_rate 2e-5 \ 16 | --weight_decay 0.0 \ 17 | --warmup_ratio 0.03 \ 18 | --lr_scheduler_type "cosine" \ 19 | --logging_steps 10 \ 20 | --deepspeed ds_configs/stage2.json \ 21 | --tf32 True -------------------------------------------------------------------------------- /eval_scripts/eval_hitab.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | import argparse 4 | from table_utils import evaluate 5 | 6 | 7 | def main(args): 8 | data = [] 9 | with open(args.pred_file, "r") as f: 10 | for line in f: 11 | data.append(json.loads(line)) 12 | 13 | pred_list = [] 14 | gold_list = [] 15 | for i in range(len(data)): 16 | if len(data[i]["predict"].strip("").split(">, <")) > 1: 17 | instance_pred_list = data[i]["predict"].strip("").split(">, <") 18 | pred_list.append(instance_pred_list) 19 | gold_list.append(data[i]["output"].strip("").split(">, <")) 20 | else: 21 | pred_list.append(data[i]["predict"].strip("")) 22 | gold_list.append(data[i]["output"].strip("")) 23 | 24 | print(evaluate(gold_list, pred_list)) 25 | 26 | 27 | if __name__ == "__main__": 28 | parser = argparse.ArgumentParser(description='arg parser') 29 | parser.add_argument('--pred_file', type=str, default='../res/hitab_res.json', help='') 30 | args = parser.parse_args() 31 | main(args) -------------------------------------------------------------------------------- /eval_scripts/eval_ent_link.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | 4 | 5 | def main(args): 6 | 
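# The prediction file is JSON-lines; each record carries the gold entity in "output" and the
# model's answer in "predict". Accuracy below is exact string match after stripping the
# surrounding "<>" markers and lowercasing both sides.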
data = [] 7 | with open(args.pred_file, "r") as f: 8 | for line in f: 9 | data.append(json.loads(line)) 10 | 11 | correct_count = 0 12 | multi_candidates_example_count = 0 13 | for i in range(len(data)): 14 | # candidate_list = data[i]["candidates_entity_desc_list"] 15 | ground_truth = data[i]["output"].strip("<>").lower() 16 | predict = data[i]["predict"].strip("<>").lower() 17 | # import pdb 18 | # pdb.set_trace() 19 | 20 | if ground_truth == predict: 21 | correct_count += 1 22 | # if len(candidate_list) > 1: 23 | # multi_candidates_example_count += 1 24 | 25 | 26 | print("correct_count:", correct_count) 27 | print("acc:", correct_count/len(data)) 28 | 29 | # print("multi_candidates_example_count:", multi_candidates_example_count) 30 | # print("multi_candidates_example_ratio:", multi_candidates_example_count/len(data)) 31 | 32 | if __name__ == "__main__": 33 | parser = argparse.ArgumentParser(description='arg parser') 34 | parser.add_argument('--pred_file', type=str, default='../res/ent_link_res.json', help='') 35 | args = parser.parse_args() 36 | main(args) -------------------------------------------------------------------------------- /ds_configs/stage3.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | "optimizer": { 6 | "type": "AdamW", 7 | "params": { 8 | "lr": "auto", 9 | "betas": "auto", 10 | "eps": "auto", 11 | "weight_decay": "auto" 12 | } 13 | }, 14 | "scheduler": { 15 | "type": "WarmupDecayLR", 16 | "params": { 17 | "total_num_steps": "auto", 18 | "warmup_min_lr": "auto", 19 | "warmup_max_lr": "auto", 20 | "warmup_num_steps": "auto" 21 | } 22 | }, 23 | "zero_optimization": { 24 | "stage": 3, 25 | "offload_optimizer": { 26 | "device": "cpu", 27 | "pin_memory": true 28 | }, 29 | "offload_param": { 30 | "device": "cpu", 31 | "pin_memory": true 32 | }, 33 | "overlap_comm": true, 34 | "contiguous_gradients": true, 35 | "sub_group_size": 1e9, 36 | "reduce_bucket_size": "auto", 37 | "stage3_prefetch_bucket_size": "auto", 38 | "stage3_param_persistence_threshold": "auto", 39 | "stage3_max_live_parameters": 1e9, 40 | "stage3_max_reuse_distance": 1e9, 41 | "stage3_gather_16bit_weights_on_model_save": false 42 | }, 43 | "gradient_accumulation_steps": "auto", 44 | "gradient_clipping": "auto", 45 | "steps_per_print": 5, 46 | "train_batch_size": "auto", 47 | "train_micro_batch_size_per_gpu": "auto", 48 | "wall_clock_breakdown": false 49 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 2D-TPE: Two-Dimensional Positional Encoding Enhances Table Understanding for Large Language Models 2 | 3 | This repository is the official implementation of 2D-TPE: Two-Dimensional Positional Encoding Enhances Table Understanding for Large Language Models. 4 | 5 | ## Requirements 6 | 7 | To install requirements: 8 | 9 | ```setup 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | ## Dataset 14 | 15 | You can find the training and testing data at [osunlp/TableInstruct](https://huggingface.co/datasets/osunlp/TableInstruct). 16 | 17 | ## Training 18 | 19 | To train the model(s) in the paper, run this command: 20 | 21 | ```train 22 | cd src 23 | ./run.sh 24 | ``` 25 | 26 | Replace `model_name_or_path`, `output_dir`, and `data_path` with the paths to your local model, the trained model's output directory, and the location of the training data, respectively. 
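Concretely, the flags to point at your own environment in `src/run.sh` are (the paths below are placeholders):

```train
--model_name_or_path /path/to/your/base_model \
--output_dir /path/to/save/checkpoints \
--data_path /path/to/your_training_data.json \
```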
27 | 28 | ## Evaluation 29 | 30 | ### Inference 31 | 32 | ```eval 33 | cd src 34 | python inference.py 35 | ``` 36 | 37 | ### Evaluate 38 | 39 | >evaluate various metrics 40 | ```eval 41 | cd eval_scripts 42 | python eval_hitab.py 43 | ``` 44 | 45 | ## Pre-trained Models 46 | 47 | You can download pretrained models here: 48 | 49 | - [openbmb/MiniCPM-2B-sft-bf16-llama-format](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16-llama-format) 50 | - [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) 51 | -------------------------------------------------------------------------------- /eval_scripts/eval_fetaqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | # import evaluate 3 | import json 4 | from rouge import Rouge 5 | import numpy as np 6 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction 7 | 8 | 9 | def compute_rouge(answers, targets): 10 | rouger = Rouge() 11 | rouge_1_f_scores = [] 12 | rouge_2_f_scores = [] 13 | rouge_l_f_scores = [] 14 | 15 | for idx, (answer, target) in enumerate(zip(answers, targets)): 16 | try: 17 | scores = rouger.get_scores(answer, target)[0] 18 | rouge_1_f_scores.append(scores['rouge-1']['f']) 19 | rouge_2_f_scores.append(scores['rouge-2']['f']) 20 | rouge_l_f_scores.append(scores['rouge-l']['f']) 21 | except ValueError as e: 22 | print(f"Error at index {idx}: {e}") 23 | print(f"Answer: {answer}") 24 | print(f"Target: {target}") 25 | continue 26 | 27 | avg_rouge_1_f = sum(rouge_1_f_scores) / len(rouge_1_f_scores) 28 | avg_rouge_2_f = sum(rouge_2_f_scores) / len(rouge_2_f_scores) 29 | avg_rouge_l_f = sum(rouge_l_f_scores) / len(rouge_l_f_scores) 30 | 31 | return {'rouge_1': avg_rouge_1_f, 32 | 'rouge_2': avg_rouge_2_f, 33 | 'rouge_l': avg_rouge_l_f} 34 | 35 | def compute_bleu(labels, preds, weights=None): 36 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction 37 | weights = weights or (0.25, 0.25, 0.25, 0.25) 38 | return np.mean([sentence_bleu(references=[label], 39 | hypothesis=pred, 40 | smoothing_function=SmoothingFunction().method1, 41 | weights=weights) for label, pred in zip(labels, preds)]) 42 | 43 | def main(args): 44 | data = [] 45 | with open(args.pred_file, "r") as f: 46 | for line in f: 47 | data.append(json.loads(line)) 48 | 49 | test_examples_answer = [x["output"] for x in data] 50 | test_predictions_pred = [x["predict"].strip("") for x in data] 51 | predictions = test_predictions_pred 52 | references = test_examples_answer 53 | 54 | results = compute_rouge(answers=predictions, targets=references) 55 | print(results) 56 | 57 | results = compute_bleu(labels=predictions, preds=references) 58 | print(results) 59 | 60 | if __name__ == "__main__": 61 | parser = argparse.ArgumentParser(description='arg parser') 62 | parser.add_argument('--pred_file', type=str, default='../res/fetaqa_res.json', help='') 63 | args = parser.parse_args() 64 | main(args) -------------------------------------------------------------------------------- /TPE-Llama/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from typing import TYPE_CHECKING 15 | 16 | from transformers.utils import ( 17 | OptionalDependencyNotAvailable, 18 | _LazyModule, 19 | is_sentencepiece_available, 20 | is_tokenizers_available, 21 | is_torch_available, 22 | ) 23 | 24 | 25 | _import_structure = { 26 | "configuration_llama": ["LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP", "LlamaConfig"], 27 | } 28 | 29 | try: 30 | if not is_sentencepiece_available(): 31 | raise OptionalDependencyNotAvailable() 32 | except OptionalDependencyNotAvailable: 33 | pass 34 | else: 35 | _import_structure["tokenization_llama"] = ["LlamaTokenizer"] 36 | 37 | try: 38 | if not is_tokenizers_available(): 39 | raise OptionalDependencyNotAvailable() 40 | except OptionalDependencyNotAvailable: 41 | pass 42 | else: 43 | _import_structure["tokenization_llama_fast"] = ["LlamaTokenizerFast"] 44 | 45 | try: 46 | if not is_torch_available(): 47 | raise OptionalDependencyNotAvailable() 48 | except OptionalDependencyNotAvailable: 49 | pass 50 | else: 51 | _import_structure["modeling_llama"] = [ 52 | "LlamaForCausalLM", 53 | "LlamaModel", 54 | "LlamaPreTrainedModel", 55 | "LlamaForSequenceClassification", 56 | ] 57 | 58 | 59 | if TYPE_CHECKING: 60 | from .configuration_llama import LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP, LlamaConfig 61 | 62 | try: 63 | if not is_sentencepiece_available(): 64 | raise OptionalDependencyNotAvailable() 65 | except OptionalDependencyNotAvailable: 66 | pass 67 | else: 68 | from .tokenization_llama import LlamaTokenizer 69 | 70 | try: 71 | if not is_tokenizers_available(): 72 | raise OptionalDependencyNotAvailable() 73 | except OptionalDependencyNotAvailable: 74 | pass 75 | else: 76 | from .tokenization_llama_fast import LlamaTokenizerFast 77 | 78 | try: 79 | if not is_torch_available(): 80 | raise OptionalDependencyNotAvailable() 81 | except OptionalDependencyNotAvailable: 82 | pass 83 | else: 84 | from .modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel 85 | 86 | 87 | else: 88 | import sys 89 | 90 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) 91 | -------------------------------------------------------------------------------- /eval_scripts/qa_datadump_utils.py: -------------------------------------------------------------------------------- 1 | """ Utility functions for datadumping.""" 2 | import unicodedata 3 | import re 4 | from openpyxl.utils import get_column_letter, column_index_from_string 5 | import functools 6 | 7 | 8 | # Compare and sort cells 9 | def find_column(coord): 10 | """ Parse column letter from 'E3'. """ 11 | return re.findall('[a-zA-Z]+', coord) 12 | 13 | 14 | def find_row(coord): 15 | """ Parse row number from 'E3'. 
""" 16 | return re.findall('[0-9]+', coord) 17 | 18 | 19 | def cell_compare(cell1, cell2): 20 | """ Compare cell coord by row, then by column.""" 21 | col1, col2 = find_column(cell1)[0], find_column(cell2)[0] 22 | row1, row2 = find_row(cell1)[0], find_row(cell2)[0] 23 | if int(row1) < int(row2): 24 | return -1 25 | elif int(row1) > int(row2): 26 | return 1 27 | else: 28 | if column_index_from_string(col1) < column_index_from_string(col2): 29 | return -1 30 | else: 31 | return 1 32 | 33 | 34 | def linked_cell_compare(linked_cell_a, linked_cell_b): 35 | """ Compare answer cell coord by row, then by column.""" 36 | if isinstance(linked_cell_a[0], str) and isinstance(linked_cell_b[0], str): 37 | coord_a, coord_b = eval(linked_cell_a[0]), eval(linked_cell_b[0]) 38 | else: 39 | coord_a, coord_b = linked_cell_a[0], linked_cell_b[0] 40 | if coord_a[0] < coord_b[0]: 41 | return -1 42 | elif coord_a[0] > coord_b[0]: 43 | return 1 44 | else: 45 | if coord_a[1] < coord_b[1]: 46 | return -1 47 | else: 48 | return 1 49 | 50 | 51 | def sort_region_by_coord(cells): 52 | """ Sort cells by coords, according to cell_compare(). """ 53 | cell_list = sorted(cells, key=functools.cmp_to_key(cell_compare)) 54 | cell_matrix = [] 55 | last_row = None 56 | for cell in cell_list: 57 | col, row = find_column(cell), find_row(cell) 58 | if row == last_row: 59 | cell_matrix[-1].append(cell) 60 | else: 61 | last_row = row 62 | cell_matrix.append([cell]) 63 | return cell_list, cell_matrix 64 | 65 | 66 | # -------------------------------------------- 67 | # Normalize and Inferring Types. 68 | def normalize(x): 69 | """ Normalize header string. """ 70 | # Copied from WikiTableQuestions dataset official evaluator. 71 | if x is None: 72 | return None 73 | # Remove diacritics 74 | x = ''.join(c for c in unicodedata.normalize('NFKD', x) 75 | if unicodedata.category(c) != 'Mn') 76 | # Normalize quotes and dashes 77 | x = re.sub("[‘’´`]", "'", x) 78 | x = re.sub("[“”]", "\"", x) 79 | x = re.sub("[‐‑‒–—−]", "-", x) 80 | while True: 81 | old_x = x 82 | # Remove citations 83 | x = re.sub("((?")[0].split(", ") 92 | ground_truth_list.append(ground_truth) 93 | pred_list.append(pred) 94 | 95 | get_r_p_f1_for_each_type(ground_truth_list, pred_list) 96 | 97 | 98 | if __name__ == "__main__": 99 | parser = argparse.ArgumentParser(description='arg parser') 100 | parser.add_argument('--pred_file', type=str, default='../res/rel_extraction_res.json', help='') 101 | args = parser.parse_args() 102 | main(args) -------------------------------------------------------------------------------- /eval_scripts/eval_col_type.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from collections import Counter 4 | from metric import * 5 | 6 | 7 | def r_p_f1(true_positive, pred, targets): 8 | if pred != 0: 9 | precision = true_positive/pred 10 | else: 11 | precision = 0 12 | recall = true_positive/targets 13 | if (precision + recall) != 0: 14 | f1 = 2 * precision * recall / (precision + recall) 15 | else: 16 | f1 = 0 17 | return recall, precision, f1 18 | 19 | def get_r_p_f1_for_each_type(ground_truth_list, pred_list): 20 | 21 | # import pdb 22 | # pdb.set_trace() 23 | 24 | total_ground_truth_col_types = 0 25 | total_pred_col_types = 0 26 | joint_items_list = [] 27 | for i in range(len(ground_truth_list)): 28 | total_ground_truth_col_types += len(ground_truth_list[i]) 29 | total_pred_col_types += len(pred_list[i]) 30 | joint_items = [item for item in pred_list[i] if item in 
ground_truth_list[i]] 31 | # total_ground_truth_col_types += len(list(set(ground_truth_list[i]))) 32 | # total_pred_col_types += len(list(set(pred_list[i]))) 33 | # joint_items = [item for item in list(set(pred_list[i])) if item in list(set(ground_truth_list[i]))] 34 | joint_items_list += joint_items 35 | 36 | # import pdb 37 | # pdb.set_trace() 38 | 39 | gt_entire_col_type = {} 40 | for i in range(len(ground_truth_list)): 41 | gt = list(set(ground_truth_list[i])) 42 | for k in range(len(gt)): 43 | if gt[k] not in gt_entire_col_type.keys(): 44 | gt_entire_col_type[gt[k]] = 1 45 | else: 46 | gt_entire_col_type[gt[k]] += 1 47 | # print(len(gt_entire_col_type.keys())) 48 | 49 | pd_entire_col_type = {} 50 | for i in range(len(pred_list)): 51 | pd = list(set(pred_list[i])) 52 | for k in range(len(pd)): 53 | if pd[k] not in pd_entire_col_type.keys(): 54 | pd_entire_col_type[pd[k]] = 1 55 | else: 56 | pd_entire_col_type[pd[k]] += 1 57 | # print(len(pd_entire_col_type.keys())) 58 | 59 | joint_entire_col_type = {} 60 | for i in range(len(joint_items_list)): 61 | if joint_items_list[i] not in joint_entire_col_type.keys(): 62 | joint_entire_col_type[joint_items_list[i]] = 1 63 | else: 64 | joint_entire_col_type[joint_items_list[i]] += 1 65 | # print(len(joint_entire_col_type.keys())) 66 | 67 | precision = len(joint_items_list)/total_pred_col_types 68 | recall = len(joint_items_list)/total_ground_truth_col_types 69 | f1 = 2 * precision * recall / (precision + recall) 70 | 71 | sorted_gt = sorted(gt_entire_col_type.items(), key=lambda x: x[1], reverse = True) 72 | # print(sorted_gt) 73 | # print("len(joint_items_list):", len(joint_items_list)) 74 | # print("total_ground_truth_col_types:", total_ground_truth_col_types) 75 | # print("total_pred_col_types:", total_pred_col_types) 76 | print("precision::", precision) 77 | print("recall:", recall) 78 | print("f1:", f1) 79 | 80 | # print('r_p_f1(joint_entire_col_type["people.person"])', r_p_f1(joint_entire_col_type["people.person"], pd_entire_col_type["people.person"], gt_entire_col_type["people.person"])) 81 | # print('r_p_f1(joint_entire_col_type["sports.pro_athlete"])', r_p_f1(joint_entire_col_type["sports.pro_athlete"], pd_entire_col_type["sports.pro_athlete"], gt_entire_col_type["sports.pro_athlete"])) 82 | # print('r_p_f1(joint_entire_col_type["film.actor"])', r_p_f1(joint_entire_col_type["film.actor"], pd_entire_col_type["film.actor"], gt_entire_col_type["film.actor"])) 83 | # print('r_p_f1(joint_entire_col_type["location.location"])', r_p_f1(joint_entire_col_type["location.location"], pd_entire_col_type["location.location"], gt_entire_col_type["location.location"])) 84 | # print('r_p_f1(joint_entire_col_type["location.citytown"])', r_p_f1(joint_entire_col_type["location.citytown"], pd_entire_col_type["location.citytown"], gt_entire_col_type["location.citytown"])) 85 | 86 | 87 | def remove_ele(list, ele): 88 | while True: 89 | if ele in list: 90 | list.remove(ele) 91 | continue 92 | else: 93 | break 94 | return list 95 | 96 | def get_index(lst=None, item=''): 97 | return [i for i in range(len(lst)) if lst[i] == item] 98 | 99 | def main(args): 100 | col_type = [] 101 | with open(args.pred_file, "r") as f: 102 | for line in f: 103 | col_type.append(json.loads(line)) 104 | 105 | ground_truth_list = [] 106 | pred_list = [] 107 | for i in range(len(col_type)): 108 | item = col_type[i] 109 | # ground_truth = item["ground_truth"] 110 | ground_truth = item["output"].split(", ") 111 | # pred = item["predict"].strip("").split(",") 112 | pred = 
item["predict"].split("")[0].split(", ") 113 | ground_truth_list.append(ground_truth) 114 | pred_list.append(pred) 115 | 116 | get_r_p_f1_for_each_type(ground_truth_list, pred_list) 117 | 118 | 119 | if __name__ == "__main__": 120 | parser = argparse.ArgumentParser(description='arg parser') 121 | parser.add_argument('--pred_file', type=str, default='../res/col_type_res.json', help='') 122 | args = parser.parse_args() 123 | main(args) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | accelerate==0.21.0 3 | adabench==1.2.64 4 | aii-pypai==0.1.40.45 5 | aiohttp==3.9.5 6 | aiosignal==1.3.1 7 | aistudio-analyzer==0.0.4.102 8 | aistudio-checkpoint==0.0.240422 9 | aistudio-common==0.0.28.51 10 | aistudio-notebook==2.0.127 11 | aistudio-serving==0.0.0.74 12 | alipay-pcache==0.1.8 13 | aliyun-python-sdk-core==2.15.1 14 | aliyun-python-sdk-kms==2.16.2 15 | ant-couler==0.0.1rc17 16 | anyio==4.3.0 17 | argo-workflows==3.5.1 18 | argon2-cffi==23.1.0 19 | argon2-cffi-bindings==21.2.0 20 | arrow==1.3.0 21 | astroid==3.1.0 22 | async-timeout==4.0.3 23 | atorch==1.2.1 24 | autopep8==2.0.4 25 | babel==2.16.0 26 | boto3==1.34.88 27 | botocore==1.34.88 28 | cheroot==10.0.0 29 | click==8.1.7 30 | click-config-file==0.6.0 31 | cloudpickle==3.0.0 32 | comm==0.2.2 33 | configobj==5.0.8 34 | configparser==7.0.0 35 | contourpy==1.1.1 36 | couler-core==0.1.1rc11 37 | crcmod==1.7 38 | cycler==0.12.1 39 | Cython==3.0.10 40 | datasets==2.15.0 41 | debugpy==1.8.1 42 | deepspeed==0.10.3 43 | defusedxml==0.7.1 44 | delta-center-client==0.0.4 45 | Deprecated==1.2.14 46 | deprecation==2.1.0 47 | diffusers==0.18.2 48 | dill==0.3.7 49 | distlib==0.3.8 50 | dlrover==0.4.2 51 | docker==4.1.0 52 | docopt==0.6.2 53 | docstring-to-markdown==0.15 54 | easydl-sdk==0.0.6 55 | einops==0.7.0 56 | entrypoints==0.4 57 | et-xmlfile==1.1.0 58 | evaluate==0.4.0 59 | exceptiongroup==1.2.1 60 | fairscale==0.4.1 61 | fastjsonschema==2.19.1 62 | fasttext==0.9.2 63 | fe==0.3.33 64 | filesplit==4.0.1 65 | fire==0.5.0 66 | flake8==7.0.0 67 | flash-attn==2.3.6 68 | Flask==3.0.3 69 | fonttools==4.51.0 70 | fqdn==1.5.1 71 | frozenlist==1.4.1 72 | fsspec==2023.10.0 73 | ftfy==6.2.0 74 | gitdb==4.0.11 75 | GitPython==3.1.43 76 | google-auth==2.29.0 77 | google-auth-oauthlib==0.4.6 78 | greenlet==3.0.3 79 | grpcio==1.34.1 80 | grpcio-tools==1.34.1 81 | hjson==3.1.0 82 | huggingface-hub==0.17.3 83 | icetk==0.0.7 84 | importlib_metadata==7.1.0 85 | iniconfig==2.0.0 86 | invisible-watermark==0.2.0 87 | ipykernel==6.29.4 88 | ipython-genutils==0.2.0 89 | isodate==0.6.1 90 | isoduration==20.11.0 91 | isort==5.13.2 92 | itsdangerous==2.2.0 93 | jaraco.functools==4.0.1 94 | jdcal==1.4.1 95 | jedi==0.19.1 96 | jedi-language-server==0.41.4 97 | Jinja2==2.11.3 98 | jinjasql==0.1.8 99 | jmespath==0.10.0 100 | joblib==1.4.0 101 | jsonpath-ng==1.6.1 102 | jsonpointer==2.1 103 | jupyter-events==0.10.0 104 | jupyter-lsp==2.2.1 105 | jupyter_client==8.6.1 106 | jupyter_core==5.7.2 107 | jupyter_server==2.10.1 108 | jupyter_server_terminals==0.5.3 109 | jupyterlab_pygments==0.3.0 110 | kiwisolver==1.4.5 111 | kmitool==0.0.9 112 | kubemaker==0.2.17 113 | kubernetes==29.0.0 114 | langdetect==1.0.9 115 | loralib==0.1.1 116 | lsh==0.1.2 117 | lsprotocol==2023.0.1 118 | lxml==5.2.1 119 | M2Crypto==0.38.0 120 | Markdown==3.6 121 | markdown-it-py==3.0.0 122 | MarkupSafe==2.0.1 123 | marshmallow==3.21.1 124 | matplotlib==3.7.5 
125 | max-common==0.20240411.1 126 | mccabe==0.7.0 127 | mdatasets==0.3.5 128 | mdurl==0.1.2 129 | mpi4py==3.1.6 130 | mpmath==1.3.0 131 | msgpack==1.0.8 132 | multidict==6.0.5 133 | multiprocess==0.70.15 134 | nbclient==0.5.13 135 | nbconvert==6.4.4 136 | nbformat==5.10.4 137 | nest-asyncio==1.6.0 138 | networkx==3.0 139 | ninja==1.11.1.1 140 | nltk==3.8.1 141 | notebook==6.4.6 142 | oauthlib==3.2.2 143 | odps==3.5.1 144 | openai==0.28.1 145 | opencv-python-headless==4.9.0.80 146 | opendelta==0.3.2 147 | openpyxl==2.4.11 148 | oss2==2.17.0 149 | ossfs==2021.8.0 150 | overrides==3.1.0 151 | pandas==1.0.0 152 | pandocfilters==1.5.1 153 | parameterized==0.9.0 154 | pathos==0.3.0 155 | peft==0.3.0 156 | peppercorn==0.6 157 | pillow==10.3.0 158 | ply==3.11 159 | pox==0.3.4 160 | ppft==1.7.6.8 161 | prettytable==3.10.0 162 | prometheus_client==0.20.0 163 | py==1.11.0 164 | py-cpuinfo==9.0.0 165 | py-spy==0.3.14 166 | pyaml==21.10.1 167 | pyarrow==12.0.0 168 | pyarrow-hotfix==0.6 169 | pyasn1==0.6.0 170 | pyasn1_modules==0.4.0 171 | pycryptodome==3.20.0 172 | pydantic==1.10.15 173 | pyDes==2.0.1 174 | pydocstyle==6.3.0 175 | pyflakes==3.2.0 176 | pygls==1.3.1 177 | pyhocon==0.3.60 178 | pyinotify==0.9.6 179 | pylint==3.1.0 180 | pynvml==11.4.1 181 | pyodps==0.11.6 182 | Pyomo==6.7.1 183 | pyparsing==2.0.3 184 | pytest==7.4.3 185 | python-dateutil==2.9.0.post0 186 | python-json-logger==2.0.7 187 | python-lsp-jsonrpc==1.1.2 188 | python-lsp-server==1.11.0 189 | pytoolconfig==1.3.1 190 | requests-file==2.0.0 191 | requests-oauthlib==2.0.0 192 | requests-toolbelt==1.0.0 193 | responses==0.18.0 194 | retry==0.9.2 195 | rfc3339-validator==0.1.4 196 | rfc3986-validator==0.1.1 197 | rich==13.7.1 198 | rope==1.13.0 199 | rouge-chinese==1.0.3 200 | rouge-score==0.1.2 201 | rsa==4.9 202 | ruamel.yaml==0.16.10 203 | ruamel.yaml.clib==0.2.8 204 | ruff==0.4.1 205 | ruff-lsp==0.0.53 206 | s3transfer==0.10.1 207 | safetensors==0.4.3 208 | scikit-learn==1.3.2 209 | scipy==1.10.1 210 | Send2Trash==1.8.3 211 | sentencepiece==0.1.97 212 | smmap==5.0.1 213 | sniffio==1.3.1 214 | snowballstemmer==2.2.0 215 | stringcase==1.2.0 216 | StringGenerator==0.4.4 217 | sympy==1.12 218 | tablib==3.6.1 219 | tabulate==0.9.0 220 | tenacity==9.0.0 221 | tensorboard==2.11.0 222 | tensorboard-data-server==0.6.1 223 | tensorboard-plugin-wit==1.8.1 224 | tensorboardX==2.6 225 | termcolor==2.4.0 226 | terminado==0.18.1 227 | testpath==0.6.0 228 | threadpoolctl==3.4.0 229 | tiktoken==0.6.0 230 | tinycss2==1.2.1 231 | titans==0.0.7 232 | tldextract==5.1.2 233 | tokenizers==0.14.1 234 | tomlkit==0.12.4 235 | torch==2.1.2 236 | torchvision==0.16.0 237 | tornado==6.4 238 | transformers==4.34.1 239 | transformers-stream-generator==0.0.5 240 | tzdata==2024.1 241 | ujson==5.9.0 242 | uncertainty-calibration==0.1.4 243 | Unidecode==1.3.8 244 | unifile-sdk==0.1.17 245 | uri-template==1.3.0 246 | urllib3==1.26.18 247 | virtualenv==20.25.3 248 | watchdog==2.3.1 249 | wcwidth==0.2.13 250 | web.py==0.62 251 | webcolors==1.13 252 | webencodings==0.5.1 253 | websocket-client==1.7.0 254 | Werkzeug==3.0.2 255 | wfbuilder==1.0.56.43 256 | wget==3.2 257 | whatthepatch==1.0.5 258 | wrapt==1.16.0 259 | xattr==1.1.0 260 | xxhash==3.4.1 261 | yacs==0.1.8 262 | yapf==0.40.2 263 | yarl==1.9.4 264 | zdfs-dfs==2.3.2 265 | zeep==4.2.1 266 | -------------------------------------------------------------------------------- /eval_scripts/metric.py: -------------------------------------------------------------------------------- 1 | """Information Retrieval metrics 2 
| Useful Resources: 3 | http://www.cs.utexas.edu/~mooney/ir-course/slides/Evaluation.ppt 4 | http://www.nii.ac.jp/TechReports/05-014E.pdf 5 | http://www.stanford.edu/class/cs276/handouts/EvaluationNew-handout-6-per.pdf 6 | http://hal.archives-ouvertes.fr/docs/00/72/67/60/PDF/07-busa-fekete.pdf 7 | Learning to Rank for Information Retrieval (Tie-Yan Liu) 8 | """ 9 | import numpy as np 10 | import pdb 11 | import os 12 | import pickle 13 | 14 | 15 | def mean_reciprocal_rank(rs): 16 | """Score is reciprocal of the rank of the first relevant item 17 | First element is 'rank 1'. Relevance is binary (nonzero is relevant). 18 | Example from http://en.wikipedia.org/wiki/Mean_reciprocal_rank 19 | >>> rs = [[0, 0, 1], [0, 1, 0], [1, 0, 0]] 20 | >>> mean_reciprocal_rank(rs) 21 | 0.61111111111111105 22 | >>> rs = np.array([[0, 0, 0], [0, 1, 0], [1, 0, 0]]) 23 | >>> mean_reciprocal_rank(rs) 24 | 0.5 25 | >>> rs = [[0, 0, 0, 1], [1, 0, 0], [1, 0, 0]] 26 | >>> mean_reciprocal_rank(rs) 27 | 0.75 28 | Args: 29 | rs: Iterator of relevance scores (list or numpy) in rank order 30 | (first element is the first item) 31 | Returns: 32 | Mean reciprocal rank 33 | """ 34 | rs = (np.asarray(r).nonzero()[0] for r in rs) 35 | return np.mean([1. / (r[0] + 1) if r.size else 0. for r in rs]) 36 | 37 | 38 | def r_precision(r): 39 | """Score is precision after all relevant documents have been retrieved 40 | Relevance is binary (nonzero is relevant). 41 | >>> r = [0, 0, 1] 42 | >>> r_precision(r) 43 | 0.33333333333333331 44 | >>> r = [0, 1, 0] 45 | >>> r_precision(r) 46 | 0.5 47 | >>> r = [1, 0, 0] 48 | >>> r_precision(r) 49 | 1.0 50 | Args: 51 | r: Relevance scores (list or numpy) in rank order 52 | (first element is the first item) 53 | Returns: 54 | R Precision 55 | """ 56 | r = np.asarray(r) != 0 57 | z = r.nonzero()[0] 58 | if not z.size: 59 | return 0. 60 | return np.mean(r[:z[-1] + 1]) 61 | 62 | 63 | def precision_at_k(r, k): 64 | """Score is precision @ k 65 | Relevance is binary (nonzero is relevant). 66 | >>> r = [0, 0, 1] 67 | >>> precision_at_k(r, 1) 68 | 0.0 69 | >>> precision_at_k(r, 2) 70 | 0.0 71 | >>> precision_at_k(r, 3) 72 | 0.33333333333333331 73 | >>> precision_at_k(r, 4) 74 | Traceback (most recent call last): 75 | File "", line 1, in ? 76 | ValueError: Relevance score length < k 77 | Args: 78 | r: Relevance scores (list or numpy) in rank order 79 | (first element is the first item) 80 | Returns: 81 | Precision @ k 82 | Raises: 83 | ValueError: len(r) must be >= k 84 | """ 85 | assert k >= 1 86 | r = np.asarray(r)[:k] != 0 87 | if r.size != k: 88 | raise ValueError('Relevance score length < k') 89 | return np.mean(r) 90 | 91 | 92 | def average_precision(r): 93 | """Score is average precision (area under PR curve) 94 | Relevance is binary (nonzero is relevant). 95 | >>> r = [1, 1, 0, 1, 0, 1, 0, 0, 0, 1] 96 | >>> delta_r = 1. / sum(r) 97 | >>> sum([sum(r[:x + 1]) / (x + 1.) * delta_r for x, y in enumerate(r) if y]) 98 | 0.7833333333333333 99 | >>> average_precision(r) 100 | 0.78333333333333333 101 | Args: 102 | r: Relevance scores (list or numpy) in rank order 103 | (first element is the first item) 104 | Returns: 105 | Average precision 106 | """ 107 | r = np.asarray(r) != 0 108 | out = [precision_at_k(r, k + 1) for k in range(r.size) if r[k]] 109 | if not out: 110 | return 0. 111 | return np.mean(out) 112 | 113 | 114 | def mean_average_precision(rs): 115 | """Score is mean average precision 116 | Relevance is binary (nonzero is relevant). 
117 | >>> rs = [[1, 1, 0, 1, 0, 1, 0, 0, 0, 1]] 118 | >>> mean_average_precision(rs) 119 | 0.78333333333333333 120 | >>> rs = [[1, 1, 0, 1, 0, 1, 0, 0, 0, 1], [0]] 121 | >>> mean_average_precision(rs) 122 | 0.39166666666666666 123 | Args: 124 | rs: Iterator of relevance scores (list or numpy) in rank order 125 | (first element is the first item) 126 | Returns: 127 | Mean average precision 128 | """ 129 | return np.mean([average_precision(r) for r in rs]) 130 | 131 | def row_pop_average_precision(r, target): 132 | """Score is average precision (area under PR curve) 133 | Relevance is binary (nonzero is relevant). 134 | >>> r = [1, 1, 0, 1, 0, 1, 0, 0, 0, 1] 135 | >>> delta_r = 1. / sum(r) 136 | >>> sum([sum(r[:x + 1]) / (x + 1.) * delta_r for x, y in enumerate(r) if y]) 137 | 0.7833333333333333 138 | >>> average_precision(r) 139 | 0.78333333333333333 140 | Args: 141 | r: Relevance scores (list or numpy) in rank order 142 | (first element is the first item) 143 | Returns: 144 | Average precision 145 | """ 146 | r = np.asarray(r) != 0 147 | out = [precision_at_k(r, k + 1) for k in range(r.size) if r[k]] 148 | if len(out) < len(target): 149 | out += [0] * (len(target) - len(out)) 150 | if not out: 151 | return 0. 152 | return np.mean(out) 153 | 154 | 155 | def dcg_at_k(r, k, method=0): 156 | """Score is discounted cumulative gain (dcg) 157 | Relevance is positive real values. Can use binary 158 | as the previous methods. 159 | Example from 160 | http://www.stanford.edu/class/cs276/handouts/EvaluationNew-handout-6-per.pdf 161 | >>> r = [3, 2, 3, 0, 0, 1, 2, 2, 3, 0] 162 | >>> dcg_at_k(r, 1) 163 | 3.0 164 | >>> dcg_at_k(r, 1, method=1) 165 | 3.0 166 | >>> dcg_at_k(r, 2) 167 | 5.0 168 | >>> dcg_at_k(r, 2, method=1) 169 | 4.2618595071429155 170 | >>> dcg_at_k(r, 10) 171 | 9.6051177391888114 172 | >>> dcg_at_k(r, 11) 173 | 9.6051177391888114 174 | Args: 175 | r: Relevance scores (list or numpy) in rank order 176 | (first element is the first item) 177 | k: Number of results to consider 178 | method: If 0 then weights are [1.0, 1.0, 0.6309, 0.5, 0.4307, ...] 179 | If 1 then weights are [1.0, 0.6309, 0.5, 0.4307, ...] 180 | Returns: 181 | Discounted cumulative gain 182 | """ 183 | r = np.asfarray(r)[:k] 184 | if r.size: 185 | if method == 0: 186 | return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1))) 187 | elif method == 1: 188 | return np.sum(r / np.log2(np.arange(2, r.size + 2))) 189 | else: 190 | raise ValueError('method must be 0 or 1.') 191 | return 0. 192 | 193 | 194 | def ndcg_at_k(r, k, method=0): 195 | """Score is normalized discounted cumulative gain (ndcg) 196 | Relevance is positive real values. Can use binary 197 | as the previous methods. 198 | Example from 199 | http://www.stanford.edu/class/cs276/handouts/EvaluationNew-handout-6-per.pdf 200 | >>> r = [3, 2, 3, 0, 0, 1, 2, 2, 3, 0] 201 | >>> ndcg_at_k(r, 1) 202 | 1.0 203 | >>> r = [2, 1, 2, 0] 204 | >>> ndcg_at_k(r, 4) 205 | 0.9203032077642922 206 | >>> ndcg_at_k(r, 4, method=1) 207 | 0.96519546960144276 208 | >>> ndcg_at_k([0], 1) 209 | 0.0 210 | >>> ndcg_at_k([1], 2) 211 | 1.0 212 | Args: 213 | r: Relevance scores (list or numpy) in rank order 214 | (first element is the first item) 215 | k: Number of results to consider 216 | method: If 0 then weights are [1.0, 1.0, 0.6309, 0.5, 0.4307, ...] 217 | If 1 then weights are [1.0, 0.6309, 0.5, 0.4307, ...] 218 | Returns: 219 | Normalized discounted cumulative gain 220 | """ 221 | dcg_max = dcg_at_k(sorted(r, reverse=True), k, method) 222 | if not dcg_max: 223 | return 0. 
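# Normalize this ranking's DCG by the ideal DCG, i.e. the DCG of the same relevance
# scores sorted in descending order.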
224 | return dcg_at_k(r, k, method) / dcg_max -------------------------------------------------------------------------------- /TPE-Llama/configuration_llama.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | """ LLaMA model configuration""" 21 | 22 | from transformers.configuration_utils import PretrainedConfig 23 | from transformers.utils import logging 24 | 25 | 26 | logger = logging.get_logger(__name__) 27 | 28 | LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} 29 | 30 | 31 | class LlamaConfig(PretrainedConfig): 32 | r""" 33 | This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA 34 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the 35 | defaults will yield a similar configuration to that of the LLaMA-7B. 36 | 37 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 38 | documentation from [`PretrainedConfig`] for more information. 39 | 40 | 41 | Args: 42 | vocab_size (`int`, *optional*, defaults to 32000): 43 | Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the 44 | `inputs_ids` passed when calling [`LlamaModel`] 45 | hidden_size (`int`, *optional*, defaults to 4096): 46 | Dimension of the hidden representations. 47 | intermediate_size (`int`, *optional*, defaults to 11008): 48 | Dimension of the MLP representations. 49 | num_hidden_layers (`int`, *optional*, defaults to 32): 50 | Number of hidden layers in the Transformer encoder. 51 | num_attention_heads (`int`, *optional*, defaults to 32): 52 | Number of attention heads for each attention layer in the Transformer encoder. 53 | num_key_value_heads (`int`, *optional*): 54 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 55 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 56 | `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When 57 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 58 | by meanpooling all the original heads within that group. For more details checkout [this 59 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to 60 | `num_attention_heads`. 61 | pretraining_tp (`int`, *optional*, defaults to `1`): 62 | Experimental feature. 
Tensor parallelism rank used during pretraining. Please refer to [this 63 | document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is 64 | necessary to ensure exact reproducibility of the pretraining results. Please refer to [this 65 | issue](https://github.com/pytorch/pytorch/issues/76232). 66 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 67 | The non-linear activation function (function or string) in the decoder. 68 | max_position_embeddings (`int`, *optional*, defaults to 2048): 69 | The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens, 70 | Llama 2 up to 4096, CodeLlama up to 16384. 71 | initializer_range (`float`, *optional*, defaults to 0.02): 72 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 73 | rms_norm_eps (`float`, *optional*, defaults to 1e-12): 74 | The epsilon used by the rms normalization layers. 75 | use_cache (`bool`, *optional*, defaults to `True`): 76 | Whether or not the model should return the last key/values attentions (not used by all models). Only 77 | relevant if `config.is_decoder=True`. 78 | tie_word_embeddings(`bool`, *optional*, defaults to `False`): 79 | Whether to tie weight embeddings 80 | rope_theta (`float`, *optional*, defaults to 10000.0): 81 | The base period of the RoPE embeddings. 82 | rope_scaling (`Dict`, *optional*): 83 | Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling 84 | strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format 85 | is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update 86 | `max_position_embeddings` to the expected new maximum. See the following thread for more information on how 87 | these scaling strategies behave: 88 | https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an 89 | experimental feature, subject to breaking API changes in future versions. 90 | attention_bias (`bool`, defaults to `False`): 91 | Whether to use a bias in the query, key, value and output projection layers during self-attention. 
92 | 93 | Example: 94 | 95 | ```python 96 | >>> from transformers import LlamaModel, LlamaConfig 97 | 98 | >>> # Initializing a LLaMA llama-7b style configuration 99 | >>> configuration = LlamaConfig() 100 | 101 | >>> # Initializing a model from the llama-7b style configuration 102 | >>> model = LlamaModel(configuration) 103 | 104 | >>> # Accessing the model configuration 105 | >>> configuration = model.config 106 | ```""" 107 | model_type = "llama" 108 | keys_to_ignore_at_inference = ["past_key_values"] 109 | 110 | def __init__( 111 | self, 112 | vocab_size=32000, 113 | hidden_size=4096, 114 | intermediate_size=11008, 115 | num_hidden_layers=32, 116 | num_attention_heads=32, 117 | num_key_value_heads=None, 118 | hidden_act="silu", 119 | max_position_embeddings=2048, 120 | initializer_range=0.02, 121 | rms_norm_eps=1e-6, 122 | use_cache=True, 123 | pad_token_id=None, 124 | bos_token_id=1, 125 | eos_token_id=2, 126 | pretraining_tp=1, 127 | tie_word_embeddings=False, 128 | rope_theta=10000.0, 129 | rope_scaling=None, 130 | attention_bias=False, 131 | **kwargs, 132 | ): 133 | self.vocab_size = vocab_size 134 | self.max_position_embeddings = max_position_embeddings 135 | self.hidden_size = hidden_size 136 | self.intermediate_size = intermediate_size 137 | self.num_hidden_layers = num_hidden_layers 138 | self.num_attention_heads = num_attention_heads 139 | 140 | # for backward compatibility 141 | if num_key_value_heads is None: 142 | num_key_value_heads = num_attention_heads 143 | 144 | self.num_key_value_heads = num_key_value_heads 145 | self.hidden_act = hidden_act 146 | self.initializer_range = initializer_range 147 | self.rms_norm_eps = rms_norm_eps 148 | self.pretraining_tp = pretraining_tp 149 | self.use_cache = use_cache 150 | self.rope_theta = rope_theta 151 | self.rope_scaling = rope_scaling 152 | self._rope_scaling_validation() 153 | self.attention_bias = attention_bias 154 | 155 | super().__init__( 156 | pad_token_id=pad_token_id, 157 | bos_token_id=bos_token_id, 158 | eos_token_id=eos_token_id, 159 | tie_word_embeddings=tie_word_embeddings, 160 | **kwargs, 161 | ) 162 | 163 | def _rope_scaling_validation(self): 164 | """ 165 | Validate the `rope_scaling` configuration. 166 | """ 167 | if self.rope_scaling is None: 168 | return 169 | 170 | if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: 171 | raise ValueError( 172 | "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " 173 | f"got {self.rope_scaling}" 174 | ) 175 | rope_scaling_type = self.rope_scaling.get("type", None) 176 | rope_scaling_factor = self.rope_scaling.get("factor", None) 177 | if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: 178 | raise ValueError( 179 | f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" 180 | ) 181 | if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: 182 | raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") 183 | -------------------------------------------------------------------------------- /TPE-Llama/tokenization_llama_fast.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import os 16 | from shutil import copyfile 17 | from typing import Optional, Tuple 18 | 19 | from tokenizers import processors 20 | 21 | from transformers.tokenization_utils_fast import PreTrainedTokenizerFast 22 | from transformers.utils import is_sentencepiece_available, logging 23 | from transformers.utils.versions import require_version 24 | 25 | 26 | require_version("tokenizers>=0.13.3") 27 | 28 | if is_sentencepiece_available(): 29 | from .tokenization_llama import LlamaTokenizer 30 | else: 31 | LlamaTokenizer = None 32 | 33 | logger = logging.get_logger(__name__) 34 | VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model", "tokenizer_file": "tokenizer.json"} 35 | 36 | PRETRAINED_VOCAB_FILES_MAP = { 37 | "vocab_file": { 38 | "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", 39 | }, 40 | "tokenizer_file": { 41 | "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", 42 | }, 43 | } 44 | B_INST, E_INST = "[INST]", "[/INST]" 45 | B_SYS, E_SYS = "<>\n", "\n<>\n\n" 46 | 47 | # fmt: off 48 | DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \ 49 | answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\ 50 | that your responses are socially unbiased and positive in nature. 51 | 52 | If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \ 53 | correct. If you don't know the answer to a question, please don't share false information.""" 54 | # fmt: on 55 | 56 | 57 | class LlamaTokenizerFast(PreTrainedTokenizerFast): 58 | """ 59 | Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding. 60 | 61 | This uses notably ByteFallback and no normalization. 62 | 63 | ``` 64 | from transformers import LlamaTokenizerFast 65 | 66 | tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer") 67 | tokenizer.encode("Hello this is a test") 68 | >>> [1, 15043, 445, 338, 263, 1243] 69 | ``` 70 | 71 | If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or 72 | call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the 73 | values of the first token and final token of an encoded sequence will not be correct). For more details, checkout 74 | [post-processors] (https://huggingface.co/docs/tokenizers/api/post-processors) documentation. 75 | 76 | 77 | This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should 78 | refer to this superclass for more information regarding those methods. 79 | 80 | Args: 81 | vocab_file (`str`): 82 | [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that 83 | contains the vocabulary necessary to instantiate a tokenizer. 
84 | tokenizer_file (`str`): 85 | [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that 86 | contains everything needed to load the tokenizer. 87 | 88 | clean_up_tokenization_spaces (`str`, *optional*, defaults to `False`): 89 | Wether to cleanup spaces after decoding, cleanup consists in removing potential artifacts like extra 90 | spaces. 91 | 92 | bos_token (`str`, *optional*, defaults to `""`): 93 | The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. 94 | 95 | eos_token (`str`, *optional*, defaults to `""`): 96 | The end of sequence token. 97 | 98 | unk_token (`str`, *optional*, defaults to `""`): 99 | The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this 100 | token instead. 101 | """ 102 | 103 | vocab_files_names = VOCAB_FILES_NAMES 104 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 105 | slow_tokenizer_class = LlamaTokenizer 106 | padding_side = "left" 107 | model_input_names = ["input_ids", "attention_mask"] 108 | 109 | def __init__( 110 | self, 111 | vocab_file=None, 112 | tokenizer_file=None, 113 | clean_up_tokenization_spaces=False, 114 | unk_token="", 115 | bos_token="", 116 | eos_token="", 117 | add_bos_token=True, 118 | add_eos_token=False, 119 | use_default_system_prompt=True, 120 | **kwargs, 121 | ): 122 | super().__init__( 123 | vocab_file=vocab_file, 124 | tokenizer_file=tokenizer_file, 125 | clean_up_tokenization_spaces=clean_up_tokenization_spaces, 126 | unk_token=unk_token, 127 | bos_token=bos_token, 128 | eos_token=eos_token, 129 | use_default_system_prompt=use_default_system_prompt, 130 | **kwargs, 131 | ) 132 | self._add_bos_token = add_bos_token 133 | self._add_eos_token = add_eos_token 134 | self.update_post_processor() 135 | self.use_default_system_prompt = use_default_system_prompt 136 | self.vocab_file = vocab_file 137 | 138 | @property 139 | def can_save_slow_tokenizer(self) -> bool: 140 | return os.path.isfile(self.vocab_file) if self.vocab_file else False 141 | 142 | def update_post_processor(self): 143 | """ 144 | Updates the underlying post processor with the current `bos_token` and `eos_token`. 
145 | """ 146 | bos = self.bos_token 147 | bos_token_id = self.bos_token_id 148 | 149 | eos = self.eos_token 150 | eos_token_id = self.eos_token_id 151 | 152 | single = f"{(bos+':0 ') * self.add_bos_token}$A:0{(' '+eos+':0') if self.add_eos_token else ''}" 153 | pair = f"{single}{(' '+bos+':1') * self.add_bos_token} $B:1{(' '+eos+':1') if self.add_eos_token else ''}" 154 | 155 | special_tokens = [] 156 | if self.add_bos_token: 157 | special_tokens.append((bos, bos_token_id)) 158 | if self.add_eos_token: 159 | special_tokens.append((eos, eos_token_id)) 160 | self._tokenizer.post_processor = processors.TemplateProcessing( 161 | single=single, pair=pair, special_tokens=special_tokens 162 | ) 163 | 164 | @property 165 | def add_eos_token(self): 166 | return self._add_eos_token 167 | 168 | @property 169 | def add_bos_token(self): 170 | return self._add_bos_token 171 | 172 | @add_eos_token.setter 173 | def add_eos_token(self, value): 174 | self._add_eos_token = value 175 | self.update_post_processor() 176 | 177 | @add_bos_token.setter 178 | def add_bos_token(self, value): 179 | self._add_bos_token = value 180 | self.update_post_processor() 181 | 182 | def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: 183 | if not self.can_save_slow_tokenizer: 184 | raise ValueError( 185 | "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow " 186 | "tokenizer." 187 | ) 188 | 189 | if not os.path.isdir(save_directory): 190 | logger.error(f"Vocabulary path ({save_directory}) should be a directory") 191 | return 192 | out_vocab_file = os.path.join( 193 | save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] 194 | ) 195 | 196 | if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): 197 | copyfile(self.vocab_file, out_vocab_file) 198 | 199 | return (out_vocab_file,) 200 | 201 | @property 202 | # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.default_chat_template 203 | def default_chat_template(self): 204 | """ 205 | LLaMA uses [INST] and [/INST] to indicate user messages, and <> and <> to indicate system messages. 206 | Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict 207 | user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering 208 | rather than needing special tokens. The system message is partly 'embedded' in the first user message, which 209 | results in an unusual token ordering when it is present. This template should definitely be changed if you wish 210 | to fine-tune a model with more flexible role ordering! 
211 | 212 | The output should look something like: 213 | 214 | [INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer [INST] Prompt [/INST] Answer 215 | [INST] Prompt [/INST] 216 | """ 217 | 218 | template = ( 219 | "{% if messages[0]['role'] == 'system' %}" 220 | "{% set loop_messages = messages[1:] %}" # Extract system message if it's present 221 | "{% set system_message = messages[0]['content'] %}" 222 | "{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}" 223 | "{% set loop_messages = messages %}" # Or use the default system message if the flag is set 224 | "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}" 225 | "{% else %}" 226 | "{% set loop_messages = messages %}" 227 | "{% set system_message = false %}" 228 | "{% endif %}" 229 | "{% for message in loop_messages %}" # Loop over all non-system messages 230 | "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" 231 | "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" 232 | "{% endif %}" 233 | "{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message 234 | "{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}" 235 | "{% else %}" 236 | "{% set content = message['content'] %}" 237 | "{% endif %}" 238 | "{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way 239 | "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}" 240 | "{% elif message['role'] == 'system' %}" 241 | "{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}" 242 | "{% elif message['role'] == 'assistant' %}" 243 | "{{ ' ' + content.strip() + ' ' + eos_token }}" 244 | "{% endif %}" 245 | "{% endfor %}" 246 | ) 247 | template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false") 248 | default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'") 249 | template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message) 250 | 251 | return template 252 | 253 | # TODO ArthurZ let's rely on the template processor instead, refactor all fast tokenizers 254 | # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens 255 | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): 256 | bos_token_id = [self.bos_token_id] if self.add_bos_token else [] 257 | eos_token_id = [self.eos_token_id] if self.add_eos_token else [] 258 | 259 | output = bos_token_id + token_ids_0 + eos_token_id 260 | 261 | if token_ids_1 is not None: 262 | output = output + bos_token_id + token_ids_1 + eos_token_id 263 | 264 | return output 265 | -------------------------------------------------------------------------------- /src/inference.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoTokenizer 2 | import torch 3 | import numpy as np 4 | from typing import List 5 | import multiprocessing 6 | import datasets 7 | import queue 8 | import time 9 | import os 10 | import pickle 11 | import json 12 | import logging 13 | import re 14 | from tqdm import tqdm 15 | from sft_minicpm import PROMPT_DICT 16 | import sys 17 | import os 18 | import random 19 | from tqdm import tqdm 20 | project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) 21 | 22 | if project_root not in sys.path: 23 | sys.path.append(project_root) 24 | 25 | from TPE-Llama.modeling_llama import LlamaForCausalLM 26 | 27 | 
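# Inference-time globals: 'spawn' start method for CUDA-safe multiprocessing, 32 worker
# processes sharing 8 GPUs (each worker loads its own model copy in tokenize_data),
# a 4096-token maximum length, and up to 100 newly generated tokens per example.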
torch.set_printoptions(profile="full") 28 | torch.multiprocessing.set_start_method('spawn',force=True) 29 | num_workers = 32 30 | gpu_num = 8 31 | max_length = 4096 32 | max_new_tokens = 100 33 | 34 | def generate_prompt(instruction, question, input_seg=None): 35 | if input_seg: 36 | return PROMPT_DICT["prompt_input"].format(instruction=instruction, input_seg=input_seg, question=question) 37 | else: 38 | return PROMPT_DICT["prompt_no_input"].format(instruction=instruction) 39 | 40 | 41 | def read_data(input_file, to_tokenize_queue): 42 | with open(input_file, "r") as f: 43 | ds = json.load(f) 44 | 45 | print(len(ds)) 46 | 47 | for i, data in tqdm(enumerate(ds), total=len(ds)): 48 | data['idx'] = i 49 | to_tokenize_queue.put(data) 50 | 51 | for i in range(num_workers): 52 | to_tokenize_queue.put(None) 53 | 54 | 55 | def encode_and_insert_separators(table_array, tokenizer): 56 | separator_col = [1425] # '▁|' 57 | separator_row = [48017] # '-' 58 | 59 | separator_row_end = [3] # '' 60 | separator_col_end = [4] # '' 61 | 62 | new_table = [] 63 | 64 | for k, row in enumerate(table_array): 65 | new_row, new_separator = [], [] 66 | for col in row: 67 | encoded_col = tokenizer.encode(str(col), add_special_tokens=False) 68 | new_row.append(encoded_col) 69 | new_row.append(separator_col) # Insert '|' between each coded column 70 | 71 | new_separator.append(separator_col_end if k == len(table_array) - 1 else separator_row) 72 | new_separator.append(separator_col) 73 | new_row.append(separator_row_end) 74 | new_separator.append(separator_row_end) 75 | new_table.append(new_row) 76 | new_table.append(new_separator) 77 | return new_table 78 | 79 | 80 | def tokenize_data(to_tokenize_queue, to_output_queue, rank): 81 | model_name = '/output/rel_extraction_2d' 82 | config = AutoConfig.from_pretrained(model_name) 83 | config.remove_unused_columns = False 84 | config._flash_attn_2_enabled = True 85 | config.output_loss = False 86 | config.pad_token_id = 0 87 | model = LlamaForCausalLM.from_pretrained(model_name, config=config, torch_dtype=torch.bfloat16).to(f"cuda:{rank%gpu_num}") 88 | model = model.to(dtype=torch.bfloat16) 89 | model.eval() 90 | 91 | while True: 92 | data = to_tokenize_queue.get() 93 | if data is None: 94 | break 95 | 96 | tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") 97 | tokenizer.add_special_tokens({"pad_token":""}) 98 | tokenizer.pad_token_id = 0 99 | tokenizer.truncation_side = "left" 100 | tokenizer.padding_side = "left" 101 | 102 | source = generate_prompt(instruction = data["instruction"], input_seg = data["input_seg"], question = data["question"]) 103 | 104 | parts = re.split(r'(\[TAB\] )|(\n\n### Response)', source) 105 | parts = [part for part in parts if part is not None] 106 | 107 | part1 = parts[0] + parts[1] 108 | table_data = parts[2] 109 | part3 = parts[3] + parts[4] 110 | 111 | # Convert a table from text format to list format 112 | if 'col:' in table_data and 'row 1:' in table_data: 113 | headers_part, rows_part = table_data.split(' row 1:', 1) 114 | headers = headers_part.strip('col: ').split(' | ') 115 | headers = [header.strip(" |") if header.strip(" |") else 'None' for header in headers] 116 | rows_part = 'row 1:' + rows_part 117 | 118 | rows = rows_part.split(' [SEP]') 119 | data_rows = [] 120 | for row in rows: 121 | if row: 122 | parts = row.strip().split(' | ')[1:] 123 | cleaned_parts = [part.strip(" |") if part.strip(" |") else 'None' for part in parts] 124 | data_rows.append(cleaned_parts) 125 | 126 | table_array = [headers] + 
data_rows 127 | elif 'col:' in table_data and 'row 1:' not in table_data: 128 | rows = table_data.split(" [SEP] ") 129 | headers = rows[0].split(" | ") if rows[0].endswith("|") else (rows[0] + " |").split(" | ") 130 | headers = [header.strip(" |") if header.strip(" |") else 'None' for header in headers][1:] 131 | data_rows = [] 132 | for row in rows[1:]: 133 | if row: 134 | parts = row.strip("").split(' | ') 135 | cleaned_parts = [part.strip(" |") if part.strip(" |") else 'None' for part in parts] 136 | data_rows.append(cleaned_parts) 137 | 138 | table_array = [headers] + data_rows 139 | else: 140 | rows = table_data.split(" [SEP] ") 141 | headers = rows[0].split(" | ") if rows[0].endswith("|") else (rows[0] + " |").split(" | ") 142 | headers = [header.strip(" |") if header.strip(" |") else 'None' for header in headers] 143 | data_rows = [] 144 | for row in rows[1:]: 145 | if row: 146 | parts = row.strip("").split(' | ') 147 | cleaned_parts = [part.strip(" |") if part.strip(" |") else 'None' for part in parts] 148 | data_rows.append(cleaned_parts) 149 | 150 | table_array = [headers] + data_rows 151 | 152 | 153 | # Determine whether the table is a rectangle 154 | expected_columns = len(table_array[0]) 155 | flag = True 156 | for row in table_array: 157 | if len(row) != expected_columns: 158 | flag = False 159 | break 160 | 161 | if not flag: 162 | logging.error(f"table_array at index {data['idx']}") 163 | continue 164 | 165 | #add sep 166 | new_table = encode_and_insert_separators(table_array, tokenizer) 167 | 168 | 169 | # Determine whether the table is a rectangle 170 | expected_columns = len(new_table[0]) 171 | flag = True 172 | for row in new_table: 173 | if len(row) != expected_columns: 174 | flag = False 175 | break 176 | 177 | if not flag: 178 | problematic_indices.append(idx) 179 | logging.error(f"new_table at index {idx}") 180 | continue 181 | 182 | input_ids = [tokenizer.bos_token_id] + tokenizer.encode(text=part1, add_special_tokens=False) 183 | l_part1 = len(input_ids) 184 | tx = list(range(l_part1)) 185 | ty = list(range(l_part1)) 186 | 187 | px = list(range(l_part1)) 188 | py = list(range(l_part1)) 189 | 190 | height = len(new_table) 191 | for i, row in enumerate(new_table): 192 | width = len(row) 193 | row_x = l_part1 - 1 + (width + 1) * (i + 1) 194 | for j, item in enumerate(row): 195 | row_y = l_part1 - 1 + (height + 1) * (j + 1) 196 | item_en = item 197 | px.extend([row_x] * len(item_en)) 198 | py.extend([row_y] * len(item_en)) 199 | input_ids.extend(item_en) 200 | 201 | for i, row in enumerate(new_table): 202 | for j, item in enumerate(row): 203 | tx_count = len(tx) 204 | tx.extend(list(range(tx_count, tx_count + len(item)))) 205 | transpose_new_table = np.transpose(new_table).tolist() 206 | ty_list, count = [], len(ty) 207 | for i, row in enumerate(transpose_new_table): 208 | ty_list.append([]) 209 | for j, item in enumerate(row): 210 | ty_list[-1].append(list(range(count, count + len(item)))) 211 | count += len(item) 212 | transpose_ty_list = np.transpose(ty_list).tolist() 213 | for i, row in enumerate(transpose_ty_list): 214 | for j, item in enumerate(row): 215 | ty.extend(item) 216 | 217 | k_part3_start = l_part1 - 1 + (width + 1) * (height + 1) 218 | 219 | part3_en = tokenizer.encode(text=part3, add_special_tokens=False) 220 | input_ids.extend(part3_en) 221 | tx_count = len(tx) 222 | ty_count = len(ty) 223 | assert tx_count == ty_count 224 | tx.extend(list(range(tx_count, tx_count + len(part3_en)))) 225 | ty.extend(list(range(ty_count, ty_count + len(part3_en)))) 226 
| 227 | k_part3_end = k_part3_start + len(part3_en) 228 | px.extend(list(range(k_part3_start, k_part3_end))) 229 | py.extend(list(range(k_part3_start, k_part3_end))) 230 | 231 | if len(input_ids) > tokenizer.model_max_length-1: 232 | continue 233 | input_ids = input_ids[-tokenizer.model_max_length+1:] 234 | px = px[-tokenizer.model_max_length+1:] 235 | py = py[-tokenizer.model_max_length+1:] 236 | tx = tx[-tokenizer.model_max_length+1:] 237 | ty = ty[-tokenizer.model_max_length+1:] 238 | 239 | pi = np.concatenate([px, py]) 240 | ti = np.concatenate([tx, ty]) 241 | 242 | input_ids = torch.tensor(input_ids).reshape(1, -1) 243 | token_ids = torch.tensor(ti).reshape(1, -1) 244 | position_ids = torch.tensor(pi).reshape(1, -1) 245 | 246 | with torch.no_grad(): 247 | outputs = model( 248 | input_ids=input_ids.to(f"cuda:{rank%gpu_num}"), 249 | token_ids=token_ids.to(f"cuda:{rank%gpu_num}"), 250 | position_ids=position_ids.to(f"cuda:{rank%gpu_num}"), 251 | use_cache=True, 252 | ) 253 | pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1) 254 | past_key_values = outputs.past_key_values 255 | npx = [px[-1] + 1] 256 | npy = [py[-1] + 1] 257 | 258 | ntx = [tx[-1] + 1] 259 | nty = [ty[-1] + 1] 260 | 261 | pi = np.concatenate([npx, npy]) 262 | position_ids = torch.tensor(pi).reshape(1, -1) 263 | ti = np.concatenate([ntx, nty]) 264 | token_ids = torch.tensor(ti).reshape(1, -1) 265 | generated_ids = [pred_token_idx.item()] 266 | 267 | for _ in range(max_new_tokens - 1): 268 | outputs = model( 269 | input_ids=pred_token_idx, 270 | past_key_values=past_key_values, 271 | token_ids=token_ids.to(f"cuda:{rank%gpu_num}"), 272 | position_ids=position_ids.to(f"cuda:{rank%gpu_num}"), 273 | use_cache=True, 274 | ) 275 | past_key_values = outputs.past_key_values 276 | pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1) 277 | npx = [npx[-1] + 1] 278 | npy = [npy[-1] + 1] 279 | ntx = [ntx[-1] + 1] 280 | nty = [nty[-1] + 1] 281 | 282 | pi = np.concatenate([npx, npy]) 283 | position_ids = torch.tensor(pi).reshape(1, -1) 284 | ti = np.concatenate([ntx, nty]) 285 | token_ids = torch.tensor(ti).reshape(1, -1) 286 | generated_ids.append(pred_token_idx.item()) 287 | 288 | if pred_token_idx == tokenizer.eos_token_id: 289 | break 290 | 291 | generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True) 292 | 293 | result = { 'idx': data['idx'], 294 | 'instruction': data['instruction'], 295 | 'input_seg': data['input_seg'], 296 | 'question': data['question'], 297 | 'output': data['output'], 298 | 'predict': generated_text} 299 | 300 | to_output_queue.put(result) 301 | 302 | to_output_queue.put(None) 303 | 304 | def output_data(to_output_queue): 305 | count = 0 306 | start_time = None 307 | finish_tag = 0 308 | 309 | while True: 310 | data = to_output_queue.get() 311 | if start_time is None: 312 | start_time = time.time() 313 | if data is None: 314 | finish_tag += 1 315 | if finish_tag == num_workers: 316 | print("End") 317 | break 318 | else: 319 | continue 320 | else: 321 | with open('./res/rel_extraction_2d_res.json', 'a') as f: 322 | try: 323 | json.dump(data, f) 324 | f.write('\n') 325 | except: 326 | continue 327 | 328 | count += 1 329 | if count % 100 == 0: 330 | end_time = time.time() 331 | print(count) 332 | print(f"Spend:{(end_time-start_time)} s") 333 | 334 | 335 | if __name__ == "__main__": 336 | import sys 337 | 338 | to_tokenize_queue = multiprocessing.Queue(maxsize=100000) 339 | to_output_queue = multiprocessing.Queue(maxsize=100000) 340 | 341 | # start 342 | 
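    # Process layout: a single reader streams the test file into `to_tokenize_queue`; each of the
    # `num_workers` (32) worker processes loads its own copy of the 2D-TPE model on
    # `cuda:{rank % gpu_num}` (4 replicas per GPU with gpu_num = 8), rebuilds the table-aware
    # token/position ids, greedily decodes up to `max_new_tokens` tokens, and pushes the result to
    # `to_output_queue`; a single writer appends the predictions as JSON lines to
    # ./res/rel_extraction_2d_res.json and stops after receiving one `None` sentinel per worker.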
reader_process = multiprocessing.Process(target=read_data, args=("/eval_data/rel_extraction_test.json", to_tokenize_queue)) 343 | tokenizer_processes = [multiprocessing.Process(target=tokenize_data, args=(to_tokenize_queue, to_output_queue, rank)) for rank in range(num_workers)] 344 | output_process = multiprocessing.Process(target=output_data, args=(to_output_queue,)) 345 | 346 | reader_process.start() 347 | for p in tokenizer_processes: 348 | p.start() 349 | output_process.start() 350 | 351 | start_time = time.time() 352 | reader_process.join() 353 | for p in tokenizer_processes: 354 | p.join() 355 | output_process.join() 356 | end_time = time.time() 357 | print(end_time-start_time) -------------------------------------------------------------------------------- /TPE-Llama/convert_llama_weights_to_hf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import argparse 15 | import gc 16 | import json 17 | import os 18 | import shutil 19 | import warnings 20 | 21 | import torch 22 | 23 | from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer 24 | 25 | 26 | try: 27 | from transformers import LlamaTokenizerFast 28 | except ImportError as e: 29 | warnings.warn(e) 30 | warnings.warn( 31 | "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion" 32 | ) 33 | LlamaTokenizerFast = None 34 | 35 | """ 36 | Sample usage: 37 | 38 | ``` 39 | python src/transformers/models/llama/convert_llama_weights_to_hf.py \ 40 | --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path 41 | ``` 42 | 43 | Thereafter, models can be loaded via: 44 | 45 | ```py 46 | from transformers import LlamaForCausalLM, LlamaTokenizer 47 | 48 | model = LlamaForCausalLM.from_pretrained("/output/path") 49 | tokenizer = LlamaTokenizer.from_pretrained("/output/path") 50 | ``` 51 | 52 | Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions 53 | come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM). 
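For the 2D-TPE experiments in this repository, the converted checkpoint can also be loaded through the
modified modeling file instead of the stock `transformers` class (a sketch, assuming the repository root
is on `sys.path`; the hyphenated package name has to be resolved via `importlib`):

```py
import importlib

import torch

tpe_llama = importlib.import_module("TPE-Llama.modeling_llama")
model = tpe_llama.LlamaForCausalLM.from_pretrained("/output/path", torch_dtype=torch.bfloat16)
```

The extra `token_ids`/`position_ids` inputs expected by this class are constructed as in `src/inference.py`.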
54 | """ 55 | 56 | NUM_SHARDS = { 57 | "7B": 1, 58 | "7Bf": 1, 59 | "13B": 2, 60 | "13Bf": 2, 61 | "34B": 4, 62 | "30B": 4, 63 | "65B": 8, 64 | "70B": 8, 65 | "70Bf": 8, 66 | } 67 | 68 | 69 | def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256): 70 | return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of) 71 | 72 | 73 | def read_json(path): 74 | with open(path, "r") as f: 75 | return json.load(f) 76 | 77 | 78 | def write_json(text, path): 79 | with open(path, "w") as f: 80 | json.dump(text, f) 81 | 82 | 83 | def write_model(model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True): 84 | # for backward compatibility, before you needed the repo to be called `my_repo/model_size` 85 | if not os.path.isfile(os.path.join(input_base_path, "params.json")): 86 | input_base_path = os.path.join(input_base_path, model_size) 87 | 88 | os.makedirs(model_path, exist_ok=True) 89 | tmp_model_path = os.path.join(model_path, "tmp") 90 | os.makedirs(tmp_model_path, exist_ok=True) 91 | 92 | params = read_json(os.path.join(input_base_path, "params.json")) 93 | num_shards = NUM_SHARDS[model_size] 94 | n_layers = params["n_layers"] 95 | n_heads = params["n_heads"] 96 | n_heads_per_shard = n_heads // num_shards 97 | dim = params["dim"] 98 | dims_per_head = dim // n_heads 99 | base = params.get("rope_theta", 10000.0) 100 | inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) 101 | if base > 10000.0: 102 | max_position_embeddings = 16384 103 | else: 104 | max_position_embeddings = 2048 105 | 106 | tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast 107 | if tokenizer_path is not None: 108 | tokenizer = tokenizer_class(tokenizer_path) 109 | tokenizer.save_pretrained(model_path) 110 | vocab_size = tokenizer.vocab_size if tokenizer_path is not None else 32000 111 | 112 | if "n_kv_heads" in params: 113 | num_key_value_heads = params["n_kv_heads"] # for GQA / MQA 114 | num_local_key_value_heads = n_heads_per_shard // num_key_value_heads 115 | key_value_dim = dim // num_key_value_heads 116 | else: # compatibility with other checkpoints 117 | num_key_value_heads = n_heads 118 | num_local_key_value_heads = n_heads_per_shard 119 | key_value_dim = dim 120 | 121 | # permute for sliced rotary 122 | def permute(w, n_heads=n_heads, dim1=dim, dim2=dim): 123 | return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) 124 | 125 | print(f"Fetching all parameters from the checkpoint at {input_base_path}.") 126 | # Load weights 127 | if model_size == "7B": 128 | # Not sharded 129 | # (The sharded implementation would also work, but this is simpler.) 
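        # `consolidated.00.pth` is the only checkpoint file in the 7B release; the sharded branch
        # below instead loads every `consolidated.{i:02d}.pth`, each of which holds a slice of all
        # weights, so the per-layer loop concatenates the slices shard by shard.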
130 | loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu") 131 | else: 132 | # Sharded 133 | loaded = [ 134 | torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu") 135 | for i in range(num_shards) 136 | ] 137 | param_count = 0 138 | index_dict = {"weight_map": {}} 139 | for layer_i in range(n_layers): 140 | filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin" 141 | if model_size == "7B": 142 | # Unsharded 143 | state_dict = { 144 | f"model.layers.{layer_i}.self_attn.q_proj.weight": permute( 145 | loaded[f"layers.{layer_i}.attention.wq.weight"] 146 | ), 147 | f"model.layers.{layer_i}.self_attn.k_proj.weight": permute( 148 | loaded[f"layers.{layer_i}.attention.wk.weight"] 149 | ), 150 | f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"], 151 | f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"], 152 | f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"], 153 | f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"], 154 | f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"], 155 | f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"layers.{layer_i}.attention_norm.weight"], 156 | f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"layers.{layer_i}.ffn_norm.weight"], 157 | } 158 | else: 159 | # Sharded 160 | # Note that attention.w{q,k,v,o}, feed_fordward.w[1,2,3], attention_norm.weight and ffn_norm.weight share 161 | # the same storage object, saving attention_norm and ffn_norm will save other weights too, which is 162 | # redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned. 
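            # As in the unsharded branch, wq/wk are passed through `permute` so that Meta's
            # interleaved rotary-embedding layout is rearranged into the half-split (rotate_half)
            # layout expected by the Hugging Face and TPE-Llama rotary implementation.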
163 | 164 | state_dict = { 165 | f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][ 166 | f"layers.{layer_i}.attention_norm.weight" 167 | ].clone(), 168 | f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][ 169 | f"layers.{layer_i}.ffn_norm.weight" 170 | ].clone(), 171 | } 172 | state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute( 173 | torch.cat( 174 | [ 175 | loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim) 176 | for i in range(num_shards) 177 | ], 178 | dim=0, 179 | ).reshape(dim, dim) 180 | ) 181 | state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute( 182 | torch.cat( 183 | [ 184 | loaded[i][f"layers.{layer_i}.attention.wk.weight"].view( 185 | num_local_key_value_heads, dims_per_head, dim 186 | ) 187 | for i in range(num_shards) 188 | ], 189 | dim=0, 190 | ).reshape(key_value_dim, dim), 191 | num_key_value_heads, 192 | key_value_dim, 193 | dim, 194 | ) 195 | state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat( 196 | [ 197 | loaded[i][f"layers.{layer_i}.attention.wv.weight"].view( 198 | num_local_key_value_heads, dims_per_head, dim 199 | ) 200 | for i in range(num_shards) 201 | ], 202 | dim=0, 203 | ).reshape(key_value_dim, dim) 204 | 205 | state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( 206 | [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1 207 | ) 208 | state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat( 209 | [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0 210 | ) 211 | state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat( 212 | [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1 213 | ) 214 | state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat( 215 | [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0 216 | ) 217 | 218 | state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq 219 | for k, v in state_dict.items(): 220 | index_dict["weight_map"][k] = filename 221 | param_count += v.numel() 222 | torch.save(state_dict, os.path.join(tmp_model_path, filename)) 223 | 224 | filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin" 225 | if model_size == "7B": 226 | # Unsharded 227 | state_dict = { 228 | "model.embed_tokens.weight": loaded["tok_embeddings.weight"], 229 | "model.norm.weight": loaded["norm.weight"], 230 | "lm_head.weight": loaded["output.weight"], 231 | } 232 | else: 233 | state_dict = { 234 | "model.norm.weight": loaded[0]["norm.weight"], 235 | "model.embed_tokens.weight": torch.cat( 236 | [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1 237 | ), 238 | "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0), 239 | } 240 | 241 | for k, v in state_dict.items(): 242 | index_dict["weight_map"][k] = filename 243 | param_count += v.numel() 244 | torch.save(state_dict, os.path.join(tmp_model_path, filename)) 245 | 246 | # Write configs 247 | index_dict["metadata"] = {"total_size": param_count * 2} 248 | write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json")) 249 | ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1 250 | multiple_of = params["multiple_of"] if "multiple_of" in params else 256 251 | config = LlamaConfig( 252 | hidden_size=dim, 253 | 
intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of), 254 | num_attention_heads=params["n_heads"], 255 | num_hidden_layers=params["n_layers"], 256 | rms_norm_eps=params["norm_eps"], 257 | num_key_value_heads=num_key_value_heads, 258 | vocab_size=vocab_size, 259 | rope_theta=base, 260 | max_position_embeddings=max_position_embeddings, 261 | ) 262 | config.save_pretrained(tmp_model_path) 263 | 264 | # Make space so we can load the model properly now. 265 | del state_dict 266 | del loaded 267 | gc.collect() 268 | 269 | print("Loading the checkpoint in a Llama model.") 270 | model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True) 271 | # Avoid saving this as part of the config. 272 | del model.config._name_or_path 273 | model.config.torch_dtype = torch.float16 274 | print("Saving in the Transformers format.") 275 | model.save_pretrained(model_path, safe_serialization=safe_serialization) 276 | shutil.rmtree(tmp_model_path) 277 | 278 | 279 | def write_tokenizer(tokenizer_path, input_tokenizer_path): 280 | # Initialize the tokenizer based on the `spm` model 281 | tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast 282 | print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.") 283 | tokenizer = tokenizer_class(input_tokenizer_path) 284 | tokenizer.save_pretrained(tokenizer_path) 285 | 286 | 287 | def main(): 288 | parser = argparse.ArgumentParser() 289 | parser.add_argument( 290 | "--input_dir", 291 | help="Location of LLaMA weights, which contains tokenizer.model and model folders", 292 | ) 293 | parser.add_argument( 294 | "--model_size", 295 | choices=["7B", "7Bf", "13B", "13Bf", "30B", "34B", "65B", "70B", "70Bf", "tokenizer_only"], 296 | help="'f' models correspond to the finetuned versions, and are specific to the Llama2 official release. For more details on Llama2, checkout the original repo: https://huggingface.co/meta-llama", 297 | ) 298 | parser.add_argument( 299 | "--output_dir", 300 | help="Location to write HF model and tokenizer", 301 | ) 302 | parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.") 303 | args = parser.parse_args() 304 | spm_path = os.path.join(args.input_dir, "tokenizer.model") 305 | if args.model_size != "tokenizer_only": 306 | write_model( 307 | model_path=args.output_dir, 308 | input_base_path=args.input_dir, 309 | model_size=args.model_size, 310 | safe_serialization=args.safe_serialization, 311 | tokenizer_path=spm_path, 312 | ) 313 | else: 314 | write_tokenizer(args.output_dir, spm_path) 315 | 316 | 317 | if __name__ == "__main__": 318 | main() 319 | -------------------------------------------------------------------------------- /TPE-Llama/tokenization_llama.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 
11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | 21 | """Tokenization classes for LLaMA.""" 22 | import os 23 | from shutil import copyfile 24 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple 25 | 26 | import sentencepiece as spm 27 | 28 | from transformers.convert_slow_tokenizer import import_protobuf 29 | from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer 30 | from transformers.utils import logging 31 | 32 | 33 | if TYPE_CHECKING: 34 | from transformers.tokenization_utils_base import TextInput 35 | 36 | logger = logging.get_logger(__name__) 37 | 38 | VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} 39 | 40 | PRETRAINED_VOCAB_FILES_MAP = { 41 | "vocab_file": { 42 | "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", 43 | }, 44 | "tokenizer_file": { 45 | "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", 46 | }, 47 | } 48 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 49 | "hf-internal-testing/llama-tokenizer": 2048, 50 | } 51 | SPIECE_UNDERLINE = "▁" 52 | 53 | B_INST, E_INST = "[INST]", "[/INST]" 54 | B_SYS, E_SYS = "<>\n", "\n<>\n\n" 55 | 56 | # fmt: off 57 | DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \ 58 | answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\ 59 | that your responses are socially unbiased and positive in nature. 60 | 61 | If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \ 62 | correct. If you don't know the answer to a question, please don't share false information.""" 63 | # fmt: on 64 | 65 | 66 | class LlamaTokenizer(PreTrainedTokenizer): 67 | """ 68 | Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is 69 | no padding token in the original model. 70 | 71 | Args: 72 | vocab_file (`str`): 73 | Path to the vocabulary file. 74 | legacy (`bool`, *optional*): 75 | Whether or not the `legacy` behavior of the tokenizer should be used. Legacy is before the merge of #24622 76 | and #25224 which includes fixes to properly handle tokens that appear after special tokens. A simple 77 | example: 78 | 79 | - `legacy=True`: 80 | ```python 81 | >>> from transformers import T5Tokenizer 82 | 83 | >>> tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=True) 84 | >>> tokenizer.encode("Hello .") 85 | [8774, 32099, 3, 5, 1] 86 | ``` 87 | - `legacy=False`: 88 | ```python 89 | >>> from transformers import T5Tokenizer 90 | 91 | >>> tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=False) 92 | >>> tokenizer.encode("Hello .") # the extra space `[3]` is no longer here 93 | [8774, 32099, 5, 1] 94 | ``` 95 | Checkout the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details. 
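            The same flag is accepted by this tokenizer; for instance (a sketch, assuming a local
            sentencepiece `tokenizer.model` file):

            ```python
            >>> tokenizer = LlamaTokenizer("tokenizer.model", legacy=False)
            ```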
96 | 97 | """ 98 | 99 | vocab_files_names = VOCAB_FILES_NAMES 100 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 101 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 102 | model_input_names = ["input_ids", "attention_mask"] 103 | 104 | def __init__( 105 | self, 106 | vocab_file, 107 | unk_token="", 108 | bos_token="", 109 | eos_token="", 110 | pad_token=None, 111 | sp_model_kwargs: Optional[Dict[str, Any]] = None, 112 | add_bos_token=True, 113 | add_eos_token=False, 114 | clean_up_tokenization_spaces=False, 115 | use_default_system_prompt=True, 116 | spaces_between_special_tokens=False, 117 | legacy=None, 118 | **kwargs, 119 | ): 120 | self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs 121 | bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token 122 | eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token 123 | unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token 124 | pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token 125 | 126 | if legacy is None: 127 | logger.warning_once( 128 | f"You are using the default legacy behaviour of the {self.__class__}. This is" 129 | " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you." 130 | " If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it" 131 | " means, and thouroughly read the reason why this was added as explained in" 132 | " https://github.com/huggingface/transformers/pull/24565" 133 | ) 134 | legacy = True 135 | 136 | self.legacy = legacy 137 | self.vocab_file = vocab_file 138 | self.add_bos_token = add_bos_token 139 | self.add_eos_token = add_eos_token 140 | self.use_default_system_prompt = use_default_system_prompt 141 | self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False)) 142 | 143 | super().__init__( 144 | bos_token=bos_token, 145 | eos_token=eos_token, 146 | unk_token=unk_token, 147 | pad_token=pad_token, 148 | add_bos_token=add_bos_token, 149 | add_eos_token=add_eos_token, 150 | sp_model_kwargs=self.sp_model_kwargs, 151 | clean_up_tokenization_spaces=clean_up_tokenization_spaces, 152 | use_default_system_prompt=use_default_system_prompt, 153 | spaces_between_special_tokens=spaces_between_special_tokens, 154 | legacy=legacy, 155 | **kwargs, 156 | ) 157 | 158 | @property 159 | def unk_token_length(self): 160 | return len(self.sp_model.encode(str(self.unk_token))) 161 | 162 | # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor 163 | def get_spm_processor(self, from_slow=False): 164 | tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs) 165 | if self.legacy or from_slow: # no dependency on protobuf 166 | tokenizer.Load(self.vocab_file) 167 | return tokenizer 168 | 169 | with open(self.vocab_file, "rb") as f: 170 | sp_model = f.read() 171 | model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)") 172 | model = model_pb2.ModelProto.FromString(sp_model) 173 | normalizer_spec = model_pb2.NormalizerSpec() 174 | normalizer_spec.add_dummy_prefix = False 175 | model.normalizer_spec.MergeFrom(normalizer_spec) 176 | sp_model = model.SerializeToString() 177 | tokenizer.LoadFromSerializedProto(sp_model) 178 | return tokenizer 179 | 180 | def __getstate__(self): 181 
|         state = self.__dict__.copy()
182 |         state["sp_model"] = None
183 |         state["sp_model_proto"] = self.sp_model.serialized_model_proto()
184 |         return state
185 | 
186 |     def __setstate__(self, d):
187 |         self.__dict__ = d
188 |         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
189 |         self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
190 | 
191 |     @property
192 |     def vocab_size(self):
193 |         """Returns vocab size"""
194 |         return self.sp_model.get_piece_size()
195 | 
196 |     def get_vocab(self):
197 |         """Returns vocab as a dict"""
198 |         vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
199 |         vocab.update(self.added_tokens_encoder)
200 |         return vocab
201 | 
202 |     # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize
203 |     def tokenize(self, text: "TextInput", add_special_tokens=False, **kwargs) -> List[str]:
204 |         """
205 |         Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the
206 |         first token is special.
207 |         """
208 |         if self.legacy or len(text) == 0:
209 |             return super().tokenize(text, **kwargs)
210 | 
211 |         tokens = super().tokenize(SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " "), **kwargs)
212 | 
213 |         if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
214 |             tokens = tokens[1:]
215 |         return tokens
216 | 
217 |     # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize
218 |     def _tokenize(self, text, **kwargs):
219 |         """
220 |         Returns a tokenized string.
221 | 
222 |         We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
223 |         SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give
224 |         `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the
225 |         `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
226 |         `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
227 |         """
228 |         tokens = self.sp_model.encode(text, out_type=str)
229 |         if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
230 |             return tokens
231 | 
232 |         # 1. Encode string + prefix ex: "<unk> Hey"
233 |         tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
234 |         # 2.
Remove self.unk_token from ['<','unk','>', '▁Hey'] 235 | return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens 236 | 237 | def _convert_token_to_id(self, token): 238 | """Converts a token (str) in an id using the vocab.""" 239 | return self.sp_model.piece_to_id(token) 240 | 241 | def _convert_id_to_token(self, index): 242 | """Converts an index (integer) in a token (str) using the vocab.""" 243 | token = self.sp_model.IdToPiece(index) 244 | return token 245 | 246 | def convert_tokens_to_string(self, tokens): 247 | """Converts a sequence of tokens (string) in a single string.""" 248 | # since we manually add the prefix space, we have to remove it when decoding 249 | if tokens[0].startswith(SPIECE_UNDERLINE): 250 | tokens[0] = tokens[0][1:] 251 | 252 | current_sub_tokens = [] 253 | out_string = "" 254 | prev_is_special = False 255 | for i, token in enumerate(tokens): 256 | # make sure that special tokens are not decoded using sentencepiece model 257 | if token in self.all_special_tokens: 258 | if not prev_is_special and i != 0 and self.legacy: 259 | out_string += " " 260 | out_string += self.sp_model.decode(current_sub_tokens) + token 261 | prev_is_special = True 262 | current_sub_tokens = [] 263 | else: 264 | current_sub_tokens.append(token) 265 | prev_is_special = False 266 | out_string += self.sp_model.decode(current_sub_tokens) 267 | return out_string 268 | 269 | def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]: 270 | """ 271 | Save the vocabulary and special tokens file to a directory. 272 | 273 | Args: 274 | save_directory (`str`): 275 | The directory in which to save the vocabulary. 276 | 277 | Returns: 278 | `Tuple(str)`: Paths to the files saved. 279 | """ 280 | if not os.path.isdir(save_directory): 281 | logger.error(f"Vocabulary path ({save_directory}) should be a directory") 282 | return 283 | out_vocab_file = os.path.join( 284 | save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] 285 | ) 286 | 287 | if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): 288 | copyfile(self.vocab_file, out_vocab_file) 289 | elif not os.path.isfile(self.vocab_file): 290 | with open(out_vocab_file, "wb") as fi: 291 | content_spiece_model = self.sp_model.serialized_model_proto() 292 | fi.write(content_spiece_model) 293 | 294 | return (out_vocab_file,) 295 | 296 | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): 297 | bos_token_id = [self.bos_token_id] if self.add_bos_token else [] 298 | eos_token_id = [self.eos_token_id] if self.add_eos_token else [] 299 | 300 | output = bos_token_id + token_ids_0 + eos_token_id 301 | 302 | if token_ids_1 is not None: 303 | output = output + bos_token_id + token_ids_1 + eos_token_id 304 | 305 | return output 306 | 307 | def get_special_tokens_mask( 308 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False 309 | ) -> List[int]: 310 | """ 311 | Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding 312 | special tokens using the tokenizer `prepare_for_model` method. 313 | 314 | Args: 315 | token_ids_0 (`List[int]`): 316 | List of IDs. 317 | token_ids_1 (`List[int]`, *optional*): 318 | Optional second list of IDs for sequence pairs. 
319 | already_has_special_tokens (`bool`, *optional*, defaults to `False`): 320 | Whether or not the token list is already formatted with special tokens for the model. 321 | 322 | Returns: 323 | `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 324 | """ 325 | if already_has_special_tokens: 326 | return super().get_special_tokens_mask( 327 | token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True 328 | ) 329 | 330 | bos_token_id = [1] if self.add_bos_token else [] 331 | eos_token_id = [1] if self.add_eos_token else [] 332 | 333 | if token_ids_1 is None: 334 | return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id 335 | return ( 336 | bos_token_id 337 | + ([0] * len(token_ids_0)) 338 | + eos_token_id 339 | + bos_token_id 340 | + ([0] * len(token_ids_1)) 341 | + eos_token_id 342 | ) 343 | 344 | def create_token_type_ids_from_sequences( 345 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 346 | ) -> List[int]: 347 | """ 348 | Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT 349 | sequence pair mask has the following format: 350 | 351 | ``` 352 | 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 353 | | first sequence | second sequence | 354 | ``` 355 | 356 | if token_ids_1 is None, only returns the first portion of the mask (0s). 357 | 358 | Args: 359 | token_ids_0 (`List[int]`): 360 | List of ids. 361 | token_ids_1 (`List[int]`, *optional*): 362 | Optional second list of IDs for sequence pairs. 363 | 364 | Returns: 365 | `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). 366 | """ 367 | bos_token_id = [self.bos_token_id] if self.add_bos_token else [] 368 | eos_token_id = [self.eos_token_id] if self.add_eos_token else [] 369 | 370 | output = [0] * len(bos_token_id + token_ids_0 + eos_token_id) 371 | 372 | if token_ids_1 is not None: 373 | output += [1] * len(bos_token_id + token_ids_1 + eos_token_id) 374 | 375 | return output 376 | 377 | @property 378 | def default_chat_template(self): 379 | """ 380 | LLaMA uses [INST] and [/INST] to indicate user messages, and <> and <> to indicate system messages. 381 | Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict 382 | user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering 383 | rather than needing special tokens. The system message is partly 'embedded' in the first user message, which 384 | results in an unusual token ordering when it is present. This template should definitely be changed if you wish 385 | to fine-tune a model with more flexible role ordering! 
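        In practice this template is rendered through `tokenizer.apply_chat_template` (a sketch; the exact
        special-token strings in the result depend on the tokenizer's configuration):

            messages = [
                {"role": "user", "content": "Hi"},
                {"role": "assistant", "content": "Hello!"},
                {"role": "user", "content": "How are you?"},
            ]
            prompt = tokenizer.apply_chat_template(messages, tokenize=False)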
386 | 387 | The output should look something like: 388 | 389 | [INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer [INST] Prompt [/INST] Answer 390 | [INST] Prompt [/INST] 391 | """ 392 | 393 | template = ( 394 | "{% if messages[0]['role'] == 'system' %}" 395 | "{% set loop_messages = messages[1:] %}" # Extract system message if it's present 396 | "{% set system_message = messages[0]['content'] %}" 397 | "{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}" 398 | "{% set loop_messages = messages %}" # Or use the default system message if the flag is set 399 | "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}" 400 | "{% else %}" 401 | "{% set loop_messages = messages %}" 402 | "{% set system_message = false %}" 403 | "{% endif %}" 404 | "{% for message in loop_messages %}" # Loop over all non-system messages 405 | "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" 406 | "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" 407 | "{% endif %}" 408 | "{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message 409 | "{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}" 410 | "{% else %}" 411 | "{% set content = message['content'] %}" 412 | "{% endif %}" 413 | "{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way 414 | "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}" 415 | "{% elif message['role'] == 'system' %}" 416 | "{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}" 417 | "{% elif message['role'] == 'assistant' %}" 418 | "{{ ' ' + content.strip() + ' ' + eos_token }}" 419 | "{% endif %}" 420 | "{% endfor %}" 421 | ) 422 | template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false") 423 | default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'") 424 | template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message) 425 | 426 | return template 427 | -------------------------------------------------------------------------------- /src/sft_minicpm.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | import io 3 | import os 4 | import copy 5 | import re 6 | import json 7 | import math 8 | import logging 9 | import numpy as np 10 | from dataclasses import dataclass, field 11 | from typing import Dict, Optional, Sequence 12 | from multiprocessing import cpu_count 13 | from datasets import load_dataset 14 | from tqdm import tqdm 15 | import psutil 16 | 17 | import torch 18 | import transformers 19 | from torch.utils.data import Dataset, IterableDataset 20 | from datasets.iterable_dataset import IterableDataset 21 | from transformers import Trainer, DataCollatorForLanguageModeling 22 | from peft import LoraConfig, get_peft_model 23 | from torch.distributed import barrier 24 | import sys 25 | import os 26 | import random 27 | from tqdm import tqdm 28 | project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) 29 | 30 | if project_root not in sys.path: 31 | sys.path.append(project_root) 32 | 33 | from TPE-Llama.modeling_llama import LlamaForCausalLM 34 | 35 | 36 | IGNORE_INDEX = -100 37 | DEFAULT_PAD_TOKEN = "" 38 | 39 | 40 | def _make_r_io_base(f, mode: str): 41 | if not isinstance(f, io.IOBase): 42 | f = open(f, mode=mode) 43 | return f 44 | 45 | def jload(f, mode="r"): 46 | """Load a .json file into a dictionary.""" 47 | f = 
_make_r_io_base(f, mode) 48 | jdict = json.load(f) 49 | f.close() 50 | return jdict 51 | 52 | def findAllFile(base): 53 | for root, ds, fs in os.walk(base): 54 | for f in fs: 55 | if f.endswith('.json'): 56 | fullname = os.path.join(root,f) 57 | yield fullname 58 | 59 | 60 | PROMPT_DICT = { 61 | "prompt_input": ( 62 | "Below is an instruction that describes a task, paired with an input that provides further context. " 63 | "Write a response that appropriately completes the request.\n\n" 64 | "### Instruction:\n{instruction}\n\n### Question:\n{question}\n\n### Input:\n{input_seg}\n\n### Response:" 65 | ), 66 | "prompt_no_input": ( 67 | "Below is an instruction that describes a task. " 68 | "Write a response that appropriately completes the request.\n\n" 69 | "### Instruction:\n{instruction}\n\n### Response:" 70 | ), 71 | } 72 | 73 | @dataclass 74 | class ModelArguments: 75 | model_name_or_path: Optional[str] = field(default="/model/MiniCPM-2B-sft-bf16-llama-format") 76 | 77 | 78 | @dataclass 79 | class DataArguments: 80 | data_path: str = field(default=None, metadata={"help": "Path to the training data."}) 81 | data_size: int = field(default=None, metadata={"help": "for calculate max steps."}) 82 | gpu_size: int = field(default=None, metadata={"help": "for calculate max steps and for logging for calcuated intervel."}) 83 | 84 | 85 | @dataclass 86 | class TrainingArguments(transformers.TrainingArguments): 87 | cache_dir: Optional[str] = field(default=None) 88 | optim: str = field(default="adamw_torch") 89 | model_max_length: int = field( 90 | default=8192 * 4, 91 | metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."}, 92 | ) 93 | use_flash_attn: bool = field( 94 | default=True, 95 | metadata={"help": "Whether use flash attention for training."}, 96 | ) 97 | low_rank_training: bool = field( 98 | default=True, 99 | metadata={"help": "Whether use low rank adaptation for training."}, 100 | ) 101 | trainable_params: str = field( 102 | default="embed,norm", 103 | metadata={"help": "Additional trainable parameters except LoRA weights, if low rank training."}, 104 | ) 105 | 106 | def smart_tokenizer_and_embedding_resize( 107 | special_tokens_dict: Dict, 108 | tokenizer: transformers.PreTrainedTokenizer, 109 | model: transformers.PreTrainedModel, 110 | ): 111 | """Resize tokenizer and embedding. 112 | 113 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 
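    In this script the function is only used to register the pad token when the base tokenizer lacks one,
    e.g. (a sketch mirroring the call in `train()` below):

        smart_tokenizer_and_embedding_resize(
            special_tokens_dict={"pad_token": DEFAULT_PAD_TOKEN},
            tokenizer=tokenizer,
            model=model,
        )

    Rows added to the input/output embedding matrices are initialized to the mean of the existing rows.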
114 | """ 115 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) 116 | model.resize_token_embeddings(len(tokenizer)) 117 | 118 | if num_new_tokens > 0: 119 | input_embeddings = model.get_input_embeddings().weight.data 120 | output_embeddings = model.get_output_embeddings().weight.data 121 | 122 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 123 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 124 | 125 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 126 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 127 | 128 | 129 | def encode_and_insert_separators(table_array, tokenizer): 130 | separator_col = [1425] # '▁|' 131 | separator_row = [48017] # '-' 132 | 133 | separator_row_end = [3] # '' 134 | separator_col_end = [4] # '' 135 | 136 | new_table = [] 137 | 138 | for k, row in enumerate(table_array): 139 | new_row, new_separator = [], [] 140 | for col in row: 141 | encoded_col = tokenizer.encode(str(col), add_special_tokens=False) 142 | new_row.append(encoded_col) 143 | new_row.append(separator_col) # Insert '|' between each coded column 144 | 145 | new_separator.append(separator_col_end if k == len(table_array) - 1 else separator_row) 146 | new_separator.append(separator_col) 147 | new_row.append(separator_row_end) 148 | new_separator.append(separator_row_end) 149 | new_table.append(new_row) 150 | new_table.append(new_separator) 151 | return new_table 152 | 153 | tok_example_count = 0 154 | 155 | class SupervisedDataset(Dataset): 156 | """Dataset for supervised fine-tuning.""" 157 | 158 | def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer): 159 | super(SupervisedDataset, self).__init__() 160 | logging.warning("Loading data...") 161 | list_data_dict = jload(data_path) 162 | print(len(list_data_dict)) 163 | global tok_example_count 164 | 165 | input_ids_all = [] 166 | labels_all = [] 167 | token_ids_all = [] 168 | position_ids_all = [] 169 | problematic_indices = [] 170 | substart_all = [] 171 | subend_all = [] 172 | 173 | for idx, example in enumerate(list_data_dict): 174 | try: 175 | tok_example_count += 1 176 | if tok_example_count % 128 == 0: 177 | logging.warning(f"tok_example_count: {tok_example_count}") 178 | 179 | # logging.warning("Formatting inputs...") 180 | prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"] 181 | source = prompt_input.format_map(example) if example.get("input_seg", "") != "" else prompt_no_input.format_map(example) 182 | target = f"{example['output']}" 183 | 184 | parts = re.split(r'(\[TAB\] )|(\n\n### Response)', source) 185 | parts = [part for part in parts if part is not None] 186 | 187 | part1 = parts[0] + parts[1] 188 | table_data = parts[2] 189 | part3 = parts[3] + parts[4] 190 | 191 | # Convert a table from text format to list format 192 | if 'col:' in table_data and 'row 1:' in table_data: 193 | headers_part, rows_part = table_data.split(' row 1:', 1) 194 | headers = headers_part.strip('col: ').split(' | ') 195 | headers = [header.strip(" |") if header.strip(" |") else 'None' for header in headers] 196 | rows_part = 'row 1:' + rows_part 197 | 198 | rows = rows_part.split(' [SEP]') 199 | data_rows = [] 200 | for row in rows: 201 | if row: 202 | parts = row.strip().split(' | ')[1:] 203 | cleaned_parts = [part.strip(" |") if part.strip(" |") else 'None' for part in parts] 204 | data_rows.append(cleaned_parts) 205 | 206 | table_array = [headers] + data_rows 207 | elif 'col:' 
in table_data and 'row 1:' not in table_data: 208 | rows = table_data.split(" [SEP] ") 209 | headers = rows[0].split(" | ") if rows[0].endswith("|") else (rows[0] + " |").split(" | ") 210 | headers = [header.strip(" |") if header.strip(" |") else 'None' for header in headers][1:] 211 | data_rows = [] 212 | for row in rows[1:]: 213 | if row: 214 | parts = row.strip("").split(' | ') 215 | cleaned_parts = [part.strip(" |") if part.strip(" |") else 'None' for part in parts] 216 | data_rows.append(cleaned_parts) 217 | 218 | table_array = [headers] + data_rows 219 | else: 220 | rows = table_data.split(" [SEP] ") 221 | headers = rows[0].split(" | ") if rows[0].endswith("|") else (rows[0] + " |").split(" | ") 222 | headers = [header.strip(" |") if header.strip(" |") else 'None' for header in headers] 223 | data_rows = [] 224 | for row in rows[1:]: 225 | if row: 226 | parts = row.strip("").split(' | ') 227 | cleaned_parts = [part.strip(" |") if part.strip(" |") else 'None' for part in parts] 228 | data_rows.append(cleaned_parts) 229 | 230 | table_array = [headers] + data_rows 231 | 232 | # Determine whether the table is a rectangle 233 | expected_columns = len(table_array[0]) 234 | flag = True 235 | for row in table_array: 236 | if len(row) != expected_columns: 237 | flag = False 238 | break 239 | 240 | if not flag: 241 | problematic_indices.append(idx) 242 | logging.error(f"table_array at index {idx}") 243 | continue 244 | 245 | # add sep 246 | new_table = encode_and_insert_separators(table_array, tokenizer) 247 | 248 | # Determine whether the table is a rectangle 249 | expected_columns = len(new_table[0]) 250 | flag = True 251 | for row in new_table: 252 | if len(row) != expected_columns: 253 | flag = False 254 | break 255 | 256 | if not flag: 257 | problematic_indices.append(idx) 258 | logging.error(f"new_table at index {idx}") 259 | continue 260 | 261 | # Part I Encoded 262 | input_ids = [tokenizer.bos_token_id] + tokenizer.encode(text=part1, add_special_tokens=False) 263 | l_part1 = len(input_ids) 264 | tx = list(range(l_part1)) 265 | ty = list(range(l_part1)) 266 | 267 | px = list(range(l_part1)) 268 | py = list(range(l_part1)) 269 | 270 | substart = input_ids[-4:] 271 | 272 | 273 | # Table Encoded 274 | height = len(new_table) 275 | for i, row in enumerate(new_table): 276 | width = len(row) 277 | row_x = l_part1 - 1 + (width + 1) * (i + 1) 278 | for j, item in enumerate(row): 279 | row_y = l_part1 - 1 + (height + 1) * (j + 1) 280 | item_en = item 281 | px.extend([row_x] * len(item_en)) 282 | py.extend([row_y] * len(item_en)) 283 | input_ids.extend(item_en) 284 | 285 | for i, row in enumerate(new_table): 286 | for j, item in enumerate(row): 287 | tx_count = len(tx) 288 | tx.extend(list(range(tx_count, tx_count + len(item)))) 289 | transpose_new_table = np.transpose(new_table).tolist() 290 | ty_list, count = [], len(ty) 291 | for i, row in enumerate(transpose_new_table): 292 | ty_list.append([]) 293 | for j, item in enumerate(row): 294 | ty_list[-1].append(list(range(count, count + len(item)))) 295 | count += len(item) 296 | transpose_ty_list = np.transpose(ty_list).tolist() 297 | for i, row in enumerate(transpose_ty_list): 298 | for j, item in enumerate(row): 299 | ty.extend(item) 300 | 301 | k_part3_start = l_part1 - 1 + (width + 1) * (height + 1) 302 | 303 | part3_target = part3 + target 304 | part3_target_en = tokenizer.encode(text=part3_target, add_special_tokens=False) + [tokenizer.eos_token_id] 305 | input_ids.extend(part3_target_en) 306 | 307 | tx_count = len(tx) 308 | ty_count = 
len(ty) 309 | assert tx_count == ty_count 310 | tx.extend(list(range(tx_count, tx_count + len(part3_target_en)))) 311 | ty.extend(list(range(ty_count, ty_count + len(part3_target_en)))) 312 | 313 | 314 | # Part III Encoded 315 | 316 | k_part3_end = k_part3_start + len(part3_target_en) 317 | px.extend(list(range(k_part3_start, k_part3_end))) 318 | py.extend(list(range(k_part3_start, k_part3_end))) 319 | 320 | subend = part3_target_en[:4] 321 | 322 | target_en = tokenizer.encode(text=target, add_special_tokens=False) + [tokenizer.eos_token_id] 323 | target_len = len(target_en) 324 | labels = copy.deepcopy(input_ids) 325 | labels[:-target_len] = [IGNORE_INDEX] * (len(input_ids) - target_len) 326 | 327 | 328 | if len(input_ids) > tokenizer.model_max_length: 329 | problematic_indices.append(idx) 330 | continue 331 | input_ids = input_ids[-tokenizer.model_max_length:] 332 | labels = labels[-tokenizer.model_max_length:] 333 | px = px[-tokenizer.model_max_length:] 334 | py = py[-tokenizer.model_max_length:] 335 | tx = tx[-tokenizer.model_max_length:] 336 | ty = ty[-tokenizer.model_max_length:] 337 | 338 | pi = np.concatenate([px, py]) 339 | ti = np.concatenate([tx, ty]) 340 | 341 | input_ids_all.append(torch.tensor(input_ids)) 342 | labels_all.append(torch.tensor(labels)) 343 | token_ids_all.append(torch.tensor(ti)) 344 | position_ids_all.append(torch.tensor(pi)) 345 | substart_all.append(torch.tensor(substart)) 346 | subend_all.append(torch.tensor(subend)) 347 | 348 | except Exception as e: 349 | problematic_indices.append(idx) 350 | logging.error(f"Error processing example at index {idx}: {str(e)}") 351 | 352 | 353 | 354 | print(len(input_ids_all)) 355 | self.input_ids = input_ids_all 356 | self.labels = labels_all 357 | self.token_ids = token_ids_all 358 | self.position_ids = position_ids_all 359 | self.substart = substart_all 360 | self.subend = subend_all 361 | 362 | def __len__(self): 363 | return len(self.input_ids) 364 | 365 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 366 | return dict(input_ids=self.input_ids[i], labels=self.labels[i], token_ids=self.token_ids[i], position_ids=self.position_ids[i], substart=self.substart[i], subend=self.subend[i]) 367 | 368 | 369 | 370 | @dataclass 371 | class DataCollatorForSupervisedDataset(object): 372 | """Collate examples for supervised fine-tuning.""" 373 | 374 | tokenizer: transformers.PreTrainedTokenizer 375 | 376 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 377 | # logging.warning(f"instances: {instances}") 378 | input_ids, labels, token_ids, position_ids, substart, subend = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels", "token_ids", "position_ids", "substart", "subend")) 379 | input_ids = torch.nn.utils.rnn.pad_sequence( 380 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id 381 | ) 382 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) 383 | substart = torch.nn.utils.rnn.pad_sequence(substart, batch_first=True, padding_value=IGNORE_INDEX) 384 | subend = torch.nn.utils.rnn.pad_sequence(subend, batch_first=True, padding_value=IGNORE_INDEX) 385 | # token_ids = torch.nn.utils.rnn.pad_sequence(token_ids, batch_first=True, padding_value=0) 386 | 387 | px_list = [] 388 | py_list = [] 389 | 390 | for pid in position_ids: 391 | s = len(pid) // 2 392 | px = pid[:s] 393 | py = pid[s:] 394 | px_list.append(px) 395 | py_list.append(py) 396 | 397 | px_padded = self.efficient_custom_pad_sequences(px_list) 398 | 
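        # `position_ids` for each example is the concatenation [px, py]: px carries the x (row-direction)
        # position of every token and py the y (column-direction) one. Tokens outside the table get
        # px == py (ordinary 1-D positions), while every token of the table cell at row i, column j shares
        # that cell's (row_x, row_y) anchor computed in SupervisedDataset. The two halves are split above,
        # padded separately by continuing each sequence's last value (see efficient_custom_pad_sequences),
        # and concatenated back together below; the tx/ty halves of `token_ids` receive the same treatment.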
py_padded = self.efficient_custom_pad_sequences(py_list) 399 | position_ids = torch.cat((px_padded, py_padded), dim=-1) 400 | 401 | 402 | 403 | tx_list = [] 404 | ty_list = [] 405 | 406 | for tid in token_ids: 407 | s = len(tid) // 2 408 | tx = tid[:s] 409 | ty = tid[s:] 410 | tx_list.append(tx) 411 | ty_list.append(ty) 412 | 413 | tx_padded = self.efficient_custom_pad_sequences(tx_list) 414 | ty_padded = self.efficient_custom_pad_sequences(ty_list) 415 | 416 | token_ids = torch.cat((tx_padded, ty_padded), dim=-1) 417 | 418 | 419 | return dict( 420 | input_ids=input_ids, 421 | labels=labels, 422 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 423 | token_ids=token_ids, 424 | position_ids=position_ids, 425 | substart=substart, 426 | subend=subend, 427 | ) 428 | 429 | def efficient_custom_pad_sequences(self, sequence_list): 430 | tensors = [torch.tensor(seq) for seq in sequence_list] 431 | max_len = max(t.size(0) for t in tensors) 432 | 433 | pad_sizes = [max_len - t.size(0) for t in tensors] 434 | 435 | max_pad_size = max(pad_sizes) 436 | increment_ranges = torch.arange(1, max_pad_size + 1).unsqueeze(0) 437 | 438 | padded_tensors = [] 439 | for tensor, pad_size in zip(tensors, pad_sizes): 440 | if pad_size > 0: 441 | padded_tensor = torch.cat([tensor, tensor[-1] + increment_ranges[:, :pad_size].squeeze(0)]) 442 | else: 443 | padded_tensor = tensor 444 | padded_tensors.append(padded_tensor) 445 | 446 | padded_tensor_batch = torch.stack(padded_tensors) 447 | 448 | return padded_tensor_batch 449 | 450 | 451 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict: 452 | """Make dataset and collator for supervised fine-tuning.""" 453 | train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path) 454 | data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) 455 | return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator) 456 | 457 | 458 | def train(): 459 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 460 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 461 | 462 | # Set RoPE scaling factor 463 | config = transformers.AutoConfig.from_pretrained( 464 | model_args.model_name_or_path, 465 | cache_dir=training_args.cache_dir, 466 | ) 467 | 468 | orig_ctx_len = getattr(config, "max_position_embeddings", None) 469 | if orig_ctx_len and training_args.model_max_length > orig_ctx_len: 470 | scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len)) 471 | config.rope_scaling = {"type": "linear", "factor": scaling_factor} 472 | 473 | # Load model and tokenizer 474 | tokenizer = transformers.AutoTokenizer.from_pretrained( 475 | model_args.model_name_or_path, 476 | cache_dir=training_args.cache_dir, 477 | model_max_length=training_args.model_max_length, 478 | padding_side="right", 479 | use_fast=False, 480 | ) 481 | 482 | special_tokens_dict = dict() 483 | if tokenizer.pad_token is None: 484 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN 485 | 486 | data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) 487 | 488 | training_args.remove_unused_columns = False 489 | config.use_cache = False 490 | 491 | config._flash_attn_2_enabled = True 492 | config.output_loss = True 493 | config.pad_token_id = 0 494 | 495 | config.lamda = 1 496 | 497 | model = LlamaForCausalLM.from_pretrained( 498 | model_args.model_name_or_path, 499 | config=config, 500 | 
torch_dtype=torch.bfloat16, 501 | cache_dir=training_args.cache_dir, 502 | ) 503 | 504 | smart_tokenizer_and_embedding_resize( 505 | special_tokens_dict=special_tokens_dict, 506 | tokenizer=tokenizer, 507 | model=model, 508 | ) 509 | 510 | if training_args.low_rank_training: 511 | config = LoraConfig( 512 | r=8, 513 | lora_alpha=16, 514 | target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], 515 | lora_dropout=0.01, 516 | bias="none", 517 | task_type="CAUSAL_LM", 518 | ) 519 | model = get_peft_model(model, config) 520 | # enable trainable params 521 | [p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])] 522 | 523 | model.enable_input_require_grads() # required for gradient checkpointing 524 | model.gradient_checkpointing_enable() # enable gradient checkpointing 525 | 526 | logging.warning(f"data_module: {data_module}") 527 | 528 | trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module) 529 | trainer.train() 530 | trainer.save_state() 531 | trainer.save_model(output_dir=training_args.output_dir) 532 | 533 | 534 | if __name__ == "__main__": 535 | train() -------------------------------------------------------------------------------- /TPE-Llama/modeling_llama.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 
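One detail of `train()` above worth calling out before the modeling code: when `--model_max_length` exceeds the base checkpoint's `max_position_embeddings`, the script switches on linear RoPE scaling with a ceil-rounded factor. A small sketch of that check; the 2048-token base context in the example call is an assumption for illustration, not a value taken from the repository.

```python
import math
from typing import Optional

def linear_rope_scaling(model_max_length: int, orig_ctx_len: Optional[int]) -> Optional[dict]:
    """Mirror of the check in train(): only scale when the training length exceeds the base context."""
    if orig_ctx_len and model_max_length > orig_ctx_len:
        return {"type": "linear", "factor": float(math.ceil(model_max_length / orig_ctx_len))}
    return None

# run.sh trains with --model_max_length 4096; 2048 here is an illustrative base context
print(linear_rope_scaling(4096, 2048))  # {'type': 'linear', 'factor': 2.0}
print(linear_rope_scaling(4096, 4096))  # None -> config.rope_scaling stays unset
```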
20 | """ PyTorch LLaMA model.""" 21 | import math 22 | import copy 23 | from typing import List, Optional, Tuple, Union 24 | 25 | import torch 26 | import torch.nn.functional as F 27 | import torch.utils.checkpoint 28 | from torch import nn 29 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 30 | 31 | from transformers.activations import ACT2FN 32 | from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast 33 | from transformers.modeling_utils import PreTrainedModel 34 | from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS 35 | from transformers.utils import ( 36 | add_start_docstrings, 37 | add_start_docstrings_to_model_forward, 38 | is_flash_attn_available, 39 | logging, 40 | replace_return_docstrings, 41 | ) 42 | from .configuration_llama import LlamaConfig 43 | 44 | 45 | if is_flash_attn_available(): 46 | from flash_attn import flash_attn_func, flash_attn_varlen_func 47 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa 48 | 49 | 50 | logger = logging.get_logger(__name__) 51 | 52 | _CONFIG_FOR_DOC = "LlamaConfig" 53 | 54 | 55 | def _get_unpad_data(padding_mask): 56 | seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) 57 | indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() 58 | max_seqlen_in_batch = seqlens_in_batch.max().item() 59 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) 60 | return ( 61 | indices, 62 | cu_seqlens, 63 | max_seqlen_in_batch, 64 | ) 65 | 66 | 67 | # Copied from transformers.models.bart.modeling_bart._make_causal_mask 68 | def _make_causal_mask( 69 | # input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 70 | input_ids, position_ids, substart, subend, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 71 | ): 72 | """ 73 | Make causal mask used for bi-directional self-attention. 74 | """ 75 | bsz, tgt_len = input_ids.shape 76 | mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) 77 | mask_cond = torch.arange(mask.size(-1), device=device) 78 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) 79 | mask = mask.to(dtype) 80 | 81 | if past_key_values_length > 0: 82 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) 83 | 84 | pos_A = None 85 | pos_B = None 86 | 87 | if substart is not None: 88 | sub_len = subend.size(1) 89 | 90 | windows = input_ids.unfold(dimension=1, size=sub_len, step=1) 91 | 92 | substart = substart[:, None, :] # [bsz, 1, sub_len] 93 | subend = subend[:, None, :] # [bsz, 1, sub_len] 94 | 95 | matches_start = (windows == substart).all(dim=2) # [bsz, seq_len - sub_len + 1] 96 | matches_end = (windows == subend).all(dim=2) # [bsz, seq_len - sub_len + 1] 97 | 98 | pos_A = matches_start.long().argmax(dim=1) # [bsz] 99 | pos_B = matches_end.long().argmax(dim=1) # [bsz] 100 | 101 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length), pos_A, pos_B 102 | 103 | 104 | # Copied from transformers.models.bart.modeling_bart._expand_mask 105 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): 106 | """ 107 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
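`_make_causal_mask` above does more than build the usual lower-triangular mask: it also slides a window of the marker length over `input_ids` and returns `pos_A`/`pos_B`, the first positions where the `substart` and `subend` token sequences occur. A toy reproduction of that matching step; the marker token ids below are invented for the example, the real ones arrive through the data collator.

```python
import torch

input_ids = torch.tensor([[5, 9, 9, 1, 2, 3, 4, 8, 8, 7]])
substart  = torch.tensor([[9, 9]])   # hypothetical start-marker token ids
subend    = torch.tensor([[8, 8]])   # hypothetical end-marker token ids

sub_len = subend.size(1)
windows = input_ids.unfold(dimension=1, size=sub_len, step=1)        # [bsz, seq - sub_len + 1, sub_len]
pos_A = (windows == substart[:, None, :]).all(dim=2).long().argmax(dim=1)
pos_B = (windows == subend[:, None, :]).all(dim=2).long().argmax(dim=1)
print(pos_A.item(), pos_B.item())                                    # 1 7
```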
108 | """ 109 | bsz, src_len = mask.size() 110 | tgt_len = tgt_len if tgt_len is not None else src_len 111 | 112 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) 113 | 114 | inverted_mask = 1.0 - expanded_mask 115 | 116 | return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) 117 | 118 | 119 | class LlamaRMSNorm(nn.Module): 120 | def __init__(self, hidden_size, eps=1e-6): 121 | """ 122 | LlamaRMSNorm is equivalent to T5LayerNorm 123 | """ 124 | super().__init__() 125 | self.weight = nn.Parameter(torch.ones(hidden_size)) 126 | self.variance_epsilon = eps 127 | 128 | def forward(self, hidden_states): 129 | input_dtype = hidden_states.dtype 130 | hidden_states = hidden_states.to(torch.float32) 131 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 132 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 133 | return self.weight * hidden_states.to(input_dtype) 134 | 135 | 136 | ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) 137 | 138 | 139 | class LlamaRotaryEmbedding(nn.Module): 140 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): 141 | super().__init__() 142 | 143 | self.dim = dim 144 | self.max_position_embeddings = max_position_embeddings 145 | self.base = base 146 | inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 147 | self.register_buffer("inv_freq", inv_freq, persistent=False) 148 | 149 | # Build here to make `torch.jit.trace` work. 150 | self._set_cos_sin_cache( 151 | seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() 152 | ) 153 | 154 | def _set_cos_sin_cache(self, seq_len, device, dtype): 155 | self.max_seq_len_cached = seq_len 156 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 157 | 158 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 159 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 160 | emb = torch.cat((freqs, freqs), dim=-1) 161 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) 162 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) 163 | 164 | def forward(self, x, seq_len=None): 165 | # x: [bs, num_attention_heads, seq_len, head_size] 166 | if seq_len > self.max_seq_len_cached: 167 | self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) 168 | 169 | return ( 170 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 171 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 172 | ) 173 | 174 | 175 | class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): 176 | """LlamaRotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" 177 | 178 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): 179 | self.scaling_factor = scaling_factor 180 | super().__init__(dim, max_position_embeddings, base, device) 181 | 182 | def _set_cos_sin_cache(self, seq_len, device, dtype): 183 | self.max_seq_len_cached = seq_len 184 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 185 | t = t / self.scaling_factor 186 | 187 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 188 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 189 | emb = torch.cat((freqs, freqs), dim=-1) 190 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) 191 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) 192 | 193 | 194 | class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): 195 | """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" 196 | 197 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): 198 | self.scaling_factor = scaling_factor 199 | super().__init__(dim, max_position_embeddings, base, device) 200 | 201 | def _set_cos_sin_cache(self, seq_len, device, dtype): 202 | self.max_seq_len_cached = seq_len 203 | 204 | if seq_len > self.max_position_embeddings: 205 | base = self.base * ( 206 | (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) 207 | ) ** (self.dim / (self.dim - 2)) 208 | inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 209 | self.register_buffer("inv_freq", inv_freq, persistent=False) 210 | 211 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 212 | 213 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 214 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 215 | emb = torch.cat((freqs, freqs), dim=-1) 216 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) 217 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) 218 | 219 | 220 | def rotate_half(x): 221 | """Rotates half the hidden dims of the input.""" 222 | x1 = x[..., : x.shape[-1] // 2] 223 | x2 = x[..., x.shape[-1] // 2 :] 224 | return torch.cat((-x2, x1), dim=-1) 225 | 226 | 227 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids): 228 | # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 
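Continuing `apply_rotary_pos_emb` below; first, a worked example of the dynamic-NTK rescaling defined just above, which enlarges the RoPE base once the sequence outgrows `max_position_embeddings`. The sizes are illustrative only, and note that the training script in this repository actually configures the *linear* scaling type, not this one.

```python
import torch

def ntk_scaled_base(base, dim, seq_len, max_pos, scaling_factor=1.0):
    """Same expression as LlamaDynamicNTKScalingRotaryEmbedding._set_cos_sin_cache."""
    return base * ((scaling_factor * seq_len / max_pos) - (scaling_factor - 1)) ** (dim / (dim - 2))

# illustrative head_dim=128, base context 2048, doubled sequence length
new_base = ntk_scaled_base(10000.0, dim=128, seq_len=4096, max_pos=2048)
inv_freq = 1.0 / (new_base ** (torch.arange(0, 128, 2).float() / 128))
print(round(new_base, 1), inv_freq.shape)   # ≈ 20221.3 torch.Size([64])
```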
229 | cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] 230 | sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] 231 | cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 232 | sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 233 | 234 | q_embed = (q * cos) + (rotate_half(q) * sin) 235 | k_embed = (k * cos) + (rotate_half(k) * sin) 236 | return q_embed, k_embed 237 | 238 | class LlamaMLP(nn.Module): 239 | def __init__(self, config): 240 | super().__init__() 241 | self.config = config 242 | self.hidden_size = config.hidden_size 243 | self.intermediate_size = config.intermediate_size 244 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 245 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 246 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) 247 | self.act_fn = ACT2FN[config.hidden_act] 248 | 249 | def forward(self, x): 250 | if self.config.pretraining_tp > 1: 251 | slice = self.intermediate_size // self.config.pretraining_tp 252 | gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) 253 | up_proj_slices = self.up_proj.weight.split(slice, dim=0) 254 | down_proj_slices = self.down_proj.weight.split(slice, dim=1) 255 | 256 | gate_proj = torch.cat( 257 | [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 258 | ) 259 | up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) 260 | 261 | intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) 262 | down_proj = [ 263 | F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) 264 | ] 265 | down_proj = sum(down_proj) 266 | else: 267 | down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) 268 | 269 | return down_proj 270 | 271 | 272 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 273 | """ 274 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, 275 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) 276 | """ 277 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape 278 | if n_rep == 1: 279 | return hidden_states 280 | hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) 281 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) 282 | 283 | 284 | class LlamaAttention(nn.Module): 285 | """Multi-headed attention from 'Attention Is All You Need' paper""" 286 | 287 | def __init__(self, config: LlamaConfig): 288 | super().__init__() 289 | self.config = config 290 | self.hidden_size = config.hidden_size 291 | self.num_heads = config.num_attention_heads 292 | self.head_dim = self.hidden_size // self.num_heads 293 | self.num_key_value_heads = config.num_key_value_heads 294 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads 295 | self.max_position_embeddings = config.max_position_embeddings 296 | self.rope_theta = config.rope_theta 297 | 298 | if (self.head_dim * self.num_heads) != self.hidden_size: 299 | raise ValueError( 300 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" 301 | f" and `num_heads`: {self.num_heads})." 
302 | ) 303 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) 304 | self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) 305 | self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) 306 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) 307 | self._init_rope() 308 | 309 | self.expert_nums = 2 310 | self.gate_1 = nn.Linear(self.head_dim, 4 * self.head_dim, bias=True) 311 | self.gate_2 = nn.Linear(self.head_dim, 4 * self.head_dim, bias=True) 312 | self.gate_3 = nn.Linear(4 * self.head_dim, self.expert_nums, bias=True) 313 | self.act_fn = ACT2FN[config.hidden_act] 314 | 315 | def _init_rope(self): 316 | if self.config.rope_scaling is None: 317 | self.rotary_emb = LlamaRotaryEmbedding( 318 | self.head_dim, 319 | max_position_embeddings=self.max_position_embeddings, 320 | base=self.rope_theta, 321 | ) 322 | else: 323 | scaling_type = self.config.rope_scaling["type"] 324 | scaling_factor = self.config.rope_scaling["factor"] 325 | if scaling_type == "linear": 326 | self.rotary_emb = LlamaLinearScalingRotaryEmbedding( 327 | self.head_dim, 328 | max_position_embeddings=self.max_position_embeddings, 329 | scaling_factor=scaling_factor, 330 | base=self.rope_theta, 331 | ) 332 | elif scaling_type == "dynamic": 333 | self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( 334 | self.head_dim, 335 | max_position_embeddings=self.max_position_embeddings, 336 | scaling_factor=scaling_factor, 337 | base=self.rope_theta, 338 | ) 339 | else: 340 | raise ValueError(f"Unknown RoPE scaling type {scaling_type}") 341 | 342 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 343 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 344 | 345 | def forward( 346 | self, 347 | hidden_states: torch.Tensor, 348 | attention_mask: Optional[torch.Tensor] = None, 349 | position_ids: Optional[torch.LongTensor] = None, 350 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 351 | output_attentions: bool = False, 352 | use_cache: bool = False, 353 | padding_mask: Optional[torch.LongTensor] = None, 354 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 355 | bsz, q_len, _ = hidden_states.size() 356 | 357 | if self.config.pretraining_tp > 1: 358 | key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp 359 | query_slices = self.q_proj.weight.split( 360 | (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 361 | ) 362 | key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) 363 | value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) 364 | 365 | query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] 366 | query_states = torch.cat(query_states, dim=-1) 367 | 368 | key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] 369 | key_states = torch.cat(key_states, dim=-1) 370 | 371 | value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] 372 | value_states = torch.cat(value_states, dim=-1) 373 | 374 | else: 375 | query_states = self.q_proj(hidden_states) 376 | key_states = self.k_proj(hidden_states) 377 | value_states = self.v_proj(hidden_states) 378 | 379 | query_states = query_states.view(bsz, q_len, self.num_heads, 
self.head_dim).transpose(1, 2) 380 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 381 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 382 | 383 | kv_seq_len = key_states.shape[-2] 384 | if past_key_value is not None: 385 | kv_seq_len += past_key_value[0].shape[-2] 386 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 387 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 388 | 389 | if past_key_value is not None: 390 | # reuse k, v, self_attention 391 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 392 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 393 | 394 | past_key_value = (key_states, value_states) if use_cache else None 395 | 396 | key_states = repeat_kv(key_states, self.num_key_value_groups) 397 | value_states = repeat_kv(value_states, self.num_key_value_groups) 398 | 399 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) 400 | 401 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 402 | raise ValueError( 403 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" 404 | f" {attn_weights.size()}" 405 | ) 406 | 407 | if attention_mask is not None: 408 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 409 | raise ValueError( 410 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 411 | ) 412 | attn_weights = attn_weights + attention_mask 413 | 414 | # upcast attention to fp32 415 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 416 | attn_output = torch.matmul(attn_weights, value_states) 417 | 418 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 419 | raise ValueError( 420 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 421 | f" {attn_output.size()}" 422 | ) 423 | 424 | attn_output = attn_output.transpose(1, 2).contiguous() 425 | 426 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 427 | 428 | if self.config.pretraining_tp > 1: 429 | attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) 430 | o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) 431 | attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) 432 | else: 433 | attn_output = self.o_proj(attn_output) 434 | 435 | if not output_attentions: 436 | attn_weights = None 437 | 438 | return attn_output, attn_weights, past_key_value 439 | 440 | 441 | class LlamaFlashAttention2Ours(LlamaAttention): 442 | """ 443 | Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays 444 | untouched. The only required change would be on the forward pass where it needs to correctly call the public API of 445 | flash attention and deal with padding tokens in case the input contains any of them. 
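The eager attention path above relies on `repeat_kv` (defined earlier) to expand the shared grouped-query K/V heads before the matmuls. A tiny sanity sketch of that expansion, equivalent to `torch.repeat_interleave` on the head dimension as its docstring claims:

```python
import torch

kv = torch.randn(1, 2, 5, 4)   # [bsz, num_key_value_heads, seq_len, head_dim]
n_rep = 4                      # num_attention_heads // num_key_value_heads

expanded = kv[:, :, None, :, :].expand(1, 2, n_rep, 5, 4).reshape(1, 2 * n_rep, 5, 4)
reference = torch.repeat_interleave(kv, repeats=n_rep, dim=1)
print(expanded.shape, torch.equal(expanded, reference))   # torch.Size([1, 8, 5, 4]) True
```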
446 | """ 447 | 448 | def forward( 449 | self, 450 | hidden_states: torch.Tensor, 451 | input_ids: torch.LongTensor = None, 452 | attention_mask: Optional[torch.Tensor] = None, 453 | position_ids: Optional[torch.LongTensor] = None, 454 | token_ids: Optional[torch.LongTensor] = None, 455 | pos_A: Optional[torch.Tensor] = None, 456 | pos_B: Optional[torch.Tensor] = None, 457 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 458 | output_attentions: bool = False, 459 | use_cache: bool = False, 460 | padding_mask: Optional[torch.LongTensor] = None, 461 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 462 | # LlamaFlashAttention2 attention does not support output_attentions 463 | output_attentions = False 464 | 465 | bsz, q_len, _ = hidden_states.size() 466 | 467 | query_states = self.q_proj(hidden_states) 468 | key_states = self.k_proj(hidden_states) 469 | value_states = self.v_proj(hidden_states) 470 | 471 | # Flash attention requires the input to have the shape 472 | # batch_size x seq_length x head_dime x hidden_dim 473 | # therefore we just need to keep the original shape 474 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 475 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 476 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 477 | 478 | kv_seq_len = key_states.shape[-2] 479 | if past_key_value is not None: 480 | kv_seq_len += past_key_value[0].shape[-2] 481 | 482 | cos, sin = self.rotary_emb(value_states, seq_len=max(position_ids.max()+1,token_ids.max()+1)) 483 | 484 | input_dtype = query_states.dtype 485 | if input_dtype == torch.float32: 486 | query_states = query_states.to(torch.float16) 487 | key_states = key_states.to(torch.float16) 488 | value_states = value_states.to(torch.float16) 489 | 490 | 491 | assert token_ids.size()[1] == 2*q_len 492 | q_x, k_x = apply_rotary_pos_emb(query_states, key_states, cos, sin, token_ids[..., :q_len]) 493 | q_y, k_y = apply_rotary_pos_emb(query_states, key_states, cos, sin, token_ids[..., q_len:]) 494 | 495 | 496 | if past_key_value is not None: 497 | k_x = torch.cat([past_key_value[0], k_x], dim=2) 498 | k_y = torch.cat([past_key_value[1], k_y], dim=2) 499 | value_states = torch.cat([past_key_value[2], value_states], dim=2) 500 | 501 | 502 | past_key_value = (k_x, k_y, value_states) if use_cache else None 503 | k_x = repeat_kv(k_x, self.num_key_value_groups) 504 | k_y = repeat_kv(k_y, self.num_key_value_groups) 505 | value_states = repeat_kv(value_states, self.num_key_value_groups) 506 | 507 | 508 | q_x, q_y = q_x.transpose(1, 2), q_y.transpose(1, 2) # [bsz, q_len, heads, head_dim] 509 | k_x, k_y = k_x.transpose(1, 2), k_y.transpose(1, 2) # [bsz, q_len, heads, head_dim] 510 | value_states = value_states.transpose(1, 2) # [bsz, q_len, heads, head_dim] 511 | 512 | v_x, v_y = value_states.clone(), value_states.clone() # [bsz, q_len, heads, head_dim] 513 | 514 | #sort 515 | if q_len > 1: 516 | _, token_col = torch.sort(token_ids[..., q_len:], descending=False, dim=-1) 517 | token_col = token_col.unsqueeze(2).unsqueeze(3).expand(-1, -1, self.num_heads, self.head_dim) 518 | q_y, k_y, v_y = torch.gather(q_y, 1, token_col), torch.gather(k_y, 1, token_col), torch.gather(value_states, 1, token_col) 519 | inverse_indices = token_ids[..., q_len:].unsqueeze(2).unsqueeze(3).expand(-1, -1, self.num_heads, self.head_dim) 520 | 521 | padding_mask = padding_mask.long() 522 | 
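Before the two flash-attention calls below, a sketch of the bookkeeping done just above for the column (y) stream: the second half of `token_ids` holds column positions, sorting by them puts tokens in column-major order for the causal pass, and because those column ids form a permutation of `0..q_len-1`, gathering with the original ids afterwards restores token order. Toy sizes, no attention call, purely to illustrate the sort/inverse-gather pair.

```python
import torch

q_len, heads, head_dim = 5, 2, 3
col_ids = torch.tensor([[2, 0, 4, 1, 3]])                    # y half of token_ids (a permutation)
states  = torch.randn(1, q_len, heads, head_dim)

_, order = torch.sort(col_ids, dim=-1)                       # indices that sort tokens by column
order_e  = order.unsqueeze(2).unsqueeze(3).expand(-1, -1, heads, head_dim)
col_major = torch.gather(states, 1, order_e)                 # column-major order for the y pass

inverse_e = col_ids.unsqueeze(2).unsqueeze(3).expand(-1, -1, heads, head_dim)
restored  = torch.gather(col_major, 1, inverse_e)            # back to original token order
print(torch.equal(restored, states))                         # True
```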
dropout_rate = 0.0 # if not self.training else self.attn_dropout 523 | attn_output_x = self._flash_attention_forward(q_x, k_x, v_x, padding_mask, q_len, dropout=dropout_rate) # [bsz, q_len, heads, head_dim] 524 | attn_output_y = self._flash_attention_forward(q_y, k_y, v_y, padding_mask, q_len, dropout=dropout_rate) # [bsz, q_len, heads, head_dim] 525 | 526 | if q_len > 1: 527 | attn_output_y = torch.gather(attn_output_y, 1, inverse_indices) # [bsz, q_len, heads, head_dim] 528 | 529 | #router 530 | hidden_states_for_router = hidden_states.view(bsz, q_len, self.num_heads, self.head_dim) 531 | router_logits = self.gate_3(self.act_fn(self.gate_1(hidden_states_for_router)) * self.gate_2(hidden_states_for_router)) 532 | routing_weights = nn.functional.softmax(router_logits, dim=-1, dtype=torch.float) # [bsz, q_len, heads, expert_nums] 533 | routing_weights = routing_weights.to(hidden_states.dtype) 534 | routing_weights = routing_weights.permute(3,0,1,2).contiguous() # [expert_nums,bsz, q_len, heads] 535 | 536 | attn_output = torch.stack([attn_output_x, attn_output_y], dim=0) # [expert_nums, bsz, q_len, heads, head_dim] 537 | 538 | attn_output = routing_weights.unsqueeze(-1) * attn_output.view(self.expert_nums,bsz,q_len,self.num_heads,self.head_dim) 539 | attn_output = torch.sum(attn_output,dim=0) #[batch_size,q_len, heads,head_dim] 540 | 541 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() 542 | attn_output = self.o_proj(attn_output) 543 | 544 | if not output_attentions: 545 | attn_weights = None 546 | 547 | #loss 548 | epsilon = 1e-10 549 | entropy_avg = 0.0 550 | entropy_list = [] 551 | if self.config.output_loss: 552 | for i in range(bsz): 553 | routing_back = routing_weights.clone()[:, i, pos_A[i]+4:pos_B[i], ...] 554 | routing_back = torch.clamp(routing_back, min=epsilon, max=1.0) 555 | entropy = -torch.sum(routing_back * torch.log(routing_back), dim=0) 556 | entropy = entropy.mean(dim=(0,1)) 557 | entropy_list.append(entropy) 558 | 559 | entropy_avg = torch.mean(torch.stack(entropy_list)) 560 | 561 | return attn_output, attn_weights, past_key_value, entropy_avg 562 | 563 | def _flash_attention_forward( 564 | self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None 565 | ): 566 | """ 567 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token 568 | first unpad the input, then computes the attention scores and pad the final attention scores. 569 | 570 | Args: 571 | query_states (`torch.Tensor`): 572 | Input query states to be passed to Flash Attention API 573 | key_states (`torch.Tensor`): 574 | Input key states to be passed to Flash Attention API 575 | value_states (`torch.Tensor`): 576 | Input value states to be passed to Flash Attention API 577 | padding_mask (`torch.Tensor`): 578 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the 579 | position of padding tokens and 1 for the position of non-padding tokens. 580 | dropout (`int`, *optional*): 581 | Attention dropout 582 | softmax_scale (`float`, *optional*): 583 | The scaling of QK^T before applying softmax. 
Default to 1 / sqrt(head_dim) 584 | """ 585 | # Contains at least one padding token in the sequence 586 | if padding_mask is not None: 587 | batch_size = query_states.shape[0] 588 | query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( 589 | query_states, key_states, value_states, padding_mask, query_length 590 | ) 591 | 592 | cu_seqlens_q, cu_seqlens_k = cu_seq_lens 593 | max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens 594 | 595 | attn_output_unpad = flash_attn_varlen_func( 596 | query_states, 597 | key_states, 598 | value_states, 599 | cu_seqlens_q=cu_seqlens_q, 600 | cu_seqlens_k=cu_seqlens_k, 601 | max_seqlen_q=max_seqlen_in_batch_q, 602 | max_seqlen_k=max_seqlen_in_batch_k, 603 | dropout_p=dropout, 604 | softmax_scale=softmax_scale, 605 | causal=True, 606 | ) 607 | 608 | attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) 609 | else: 610 | attn_output = flash_attn_func( 611 | query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True 612 | ) 613 | 614 | return attn_output 615 | 616 | def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length): 617 | indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask) 618 | batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape 619 | 620 | key_layer = index_first_axis( 621 | key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k 622 | ) 623 | value_layer = index_first_axis( 624 | value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k 625 | ) 626 | num_heads = query_layer.size()[-2] 627 | if query_length == kv_seq_len: 628 | query_layer = index_first_axis( 629 | query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k 630 | ) 631 | cu_seqlens_q = cu_seqlens_k 632 | max_seqlen_in_batch_q = max_seqlen_in_batch_k 633 | indices_q = indices_k 634 | elif query_length == 1: 635 | max_seqlen_in_batch_q = 1 636 | cu_seqlens_q = torch.arange( 637 | batch_size + 1, dtype=torch.int32, device=query_layer.device 638 | ) # There is a memcpy here, that is very bad. 639 | indices_q = cu_seqlens_q[:-1] 640 | query_layer = query_layer.squeeze(1) 641 | else: 642 | # The -q_len: slice assumes left padding. 643 | padding_mask = padding_mask[:, -query_length:] 644 | query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask) 645 | 646 | return ( 647 | query_layer, 648 | key_layer, 649 | value_layer, 650 | indices_q, 651 | (cu_seqlens_q, cu_seqlens_k), 652 | (max_seqlen_in_batch_q, max_seqlen_in_batch_k), 653 | ) 654 | 655 | 656 | class LlamaFlashAttention2(LlamaAttention): 657 | """ 658 | Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays 659 | untouched. The only required change would be on the forward pass where it needs to correctly call the public API of 660 | flash attention and deal with padding tokens in case the input contains any of them. 
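Stepping back to the router used in `LlamaFlashAttention2Ours.forward` above: each head mixes the row-wise and column-wise attention outputs with softmax weights, and the mean entropy of those weights is returned for the auxiliary loss (in the real code the entropy is only taken over the span between `pos_A + 4` and `pos_B`). A toy version with made-up sizes; SiLU stands in here for whatever `config.hidden_act` selects.

```python
import torch
import torch.nn as nn

bsz, q_len, heads, head_dim, experts = 1, 6, 2, 4, 2
hidden = torch.randn(bsz, q_len, heads, head_dim)
attn_x = torch.randn(bsz, q_len, heads, head_dim)   # row-wise attention output
attn_y = torch.randn(bsz, q_len, heads, head_dim)   # column-wise attention output

gate_1, gate_2 = nn.Linear(head_dim, 4 * head_dim), nn.Linear(head_dim, 4 * head_dim)
gate_3, act = nn.Linear(4 * head_dim, experts), nn.SiLU()

weights = torch.softmax(gate_3(act(gate_1(hidden)) * gate_2(hidden)), dim=-1)  # [bsz, q_len, heads, experts]
stacked = torch.stack([attn_x, attn_y], dim=0)                                 # [experts, bsz, q_len, heads, head_dim]
mixed = (weights.permute(3, 0, 1, 2).unsqueeze(-1) * stacked).sum(dim=0)       # [bsz, q_len, heads, head_dim]

p = weights.clamp(min=1e-10, max=1.0)
entropy = -(p * p.log()).sum(dim=-1).mean()          # <= ln(2) for two experts
print(mixed.shape, round(entropy.item(), 4))
```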
661 | """ 662 | 663 | def forward( 664 | self, 665 | hidden_states: torch.Tensor, 666 | attention_mask: Optional[torch.Tensor] = None, 667 | position_ids: Optional[torch.LongTensor] = None, 668 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 669 | output_attentions: bool = False, 670 | use_cache: bool = False, 671 | padding_mask: Optional[torch.LongTensor] = None, 672 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 673 | # LlamaFlashAttention2 attention does not support output_attentions 674 | output_attentions = False 675 | 676 | bsz, q_len, _ = hidden_states.size() 677 | 678 | query_states = self.q_proj(hidden_states) 679 | key_states = self.k_proj(hidden_states) 680 | value_states = self.v_proj(hidden_states) 681 | 682 | # Flash attention requires the input to have the shape 683 | # batch_size x seq_length x head_dime x hidden_dim 684 | # therefore we just need to keep the original shape 685 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 686 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 687 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 688 | 689 | kv_seq_len = key_states.shape[-2] 690 | if past_key_value is not None: 691 | kv_seq_len += past_key_value[0].shape[-2] 692 | 693 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 694 | 695 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 696 | 697 | if past_key_value is not None: 698 | # reuse k, v, self_attention 699 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 700 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 701 | 702 | past_key_value = (key_states, value_states) if use_cache else None 703 | 704 | query_states = query_states.transpose(1, 2) 705 | key_states = key_states.transpose(1, 2) 706 | value_states = value_states.transpose(1, 2) 707 | 708 | # TODO: llama does not have dropout in the config?? 709 | # It is recommended to use dropout with FA according to the docs 710 | # when training. 711 | dropout_rate = 0.0 # if not self.training else self.attn_dropout 712 | 713 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons 714 | # therefore the input hidden states gets silently casted in float32. Hence, we need 715 | # cast them back in float16 just to be sure everything works as expected. 716 | # This might slowdown training & inference so it is recommended to not cast the LayerNorms 717 | # in fp32. (LlamaRMSNorm handles it correctly) 718 | input_dtype = query_states.dtype 719 | if input_dtype == torch.float32: 720 | logger.warning_once( 721 | "The input hidden states seems to be silently casted in float32, this might be related to" 722 | " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" 723 | " float16." 
724 | ) 725 | 726 | query_states = query_states.to(torch.float16) 727 | key_states = key_states.to(torch.float16) 728 | value_states = value_states.to(torch.float16) 729 | 730 | attn_output = self._flash_attention_forward( 731 | query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate 732 | ) 733 | 734 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() 735 | attn_output = self.o_proj(attn_output) 736 | 737 | if not output_attentions: 738 | attn_weights = None 739 | 740 | return attn_output, attn_weights, past_key_value 741 | 742 | def _flash_attention_forward( 743 | self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None 744 | ): 745 | """ 746 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token 747 | first unpad the input, then computes the attention scores and pad the final attention scores. 748 | 749 | Args: 750 | query_states (`torch.Tensor`): 751 | Input query states to be passed to Flash Attention API 752 | key_states (`torch.Tensor`): 753 | Input key states to be passed to Flash Attention API 754 | value_states (`torch.Tensor`): 755 | Input value states to be passed to Flash Attention API 756 | padding_mask (`torch.Tensor`): 757 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the 758 | position of padding tokens and 1 for the position of non-padding tokens. 759 | dropout (`int`, *optional*): 760 | Attention dropout 761 | softmax_scale (`float`, *optional*): 762 | The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) 763 | """ 764 | # Contains at least one padding token in the sequence 765 | if padding_mask is not None: 766 | batch_size = query_states.shape[0] 767 | query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( 768 | query_states, key_states, value_states, padding_mask, query_length 769 | ) 770 | 771 | cu_seqlens_q, cu_seqlens_k = cu_seq_lens 772 | max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens 773 | 774 | attn_output_unpad = flash_attn_varlen_func( 775 | query_states, 776 | key_states, 777 | value_states, 778 | cu_seqlens_q=cu_seqlens_q, 779 | cu_seqlens_k=cu_seqlens_k, 780 | max_seqlen_q=max_seqlen_in_batch_q, 781 | max_seqlen_k=max_seqlen_in_batch_k, 782 | dropout_p=dropout, 783 | softmax_scale=softmax_scale, 784 | causal=True, 785 | ) 786 | 787 | attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) 788 | else: 789 | attn_output = flash_attn_func( 790 | query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True 791 | ) 792 | 793 | return attn_output 794 | 795 | def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length): 796 | indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask) 797 | batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape 798 | 799 | key_layer = index_first_axis( 800 | key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k 801 | ) 802 | value_layer = index_first_axis( 803 | value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k 804 | ) 805 | if query_length == kv_seq_len: 806 | query_layer = index_first_axis( 807 | query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k 808 | ) 809 | cu_seqlens_q = cu_seqlens_k 810 | max_seqlen_in_batch_q = max_seqlen_in_batch_k 
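The unpadding above (and the identical logic in `_get_unpad_data`) boils down to three pieces of metadata derived from the padding mask. A small sketch of how they are computed for `flash_attn_varlen_func`:

```python
import torch
import torch.nn.functional as F

padding_mask = torch.tensor([[1, 1, 1, 0, 0],      # 1 = real token, 0 = padding
                             [1, 1, 1, 1, 1]])

seqlens = padding_mask.sum(dim=-1, dtype=torch.int32)                      # per-sequence lengths
indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()  # flat indices of real tokens
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))

print(seqlens.tolist(), cu_seqlens.tolist(), indices.tolist())
# [3, 5] [0, 3, 8] [0, 1, 2, 5, 6, 7, 8, 9]
```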
811 | indices_q = indices_k 812 | elif query_length == 1: 813 | max_seqlen_in_batch_q = 1 814 | cu_seqlens_q = torch.arange( 815 | batch_size + 1, dtype=torch.int32, device=query_layer.device 816 | ) # There is a memcpy here, that is very bad. 817 | indices_q = cu_seqlens_q[:-1] 818 | query_layer = query_layer.squeeze(1) 819 | else: 820 | # The -q_len: slice assumes left padding. 821 | padding_mask = padding_mask[:, -query_length:] 822 | query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask) 823 | 824 | return ( 825 | query_layer, 826 | key_layer, 827 | value_layer, 828 | indices_q, 829 | (cu_seqlens_q, cu_seqlens_k), 830 | (max_seqlen_in_batch_q, max_seqlen_in_batch_k), 831 | ) 832 | 833 | 834 | class LlamaDecoderLayer(nn.Module): 835 | def __init__(self, config: LlamaConfig): 836 | super().__init__() 837 | self.hidden_size = config.hidden_size 838 | self.self_attn = ( 839 | LlamaAttention(config=config) 840 | if not getattr(config, "_flash_attn_2_enabled", False) 841 | else LlamaFlashAttention2Ours(config=config) 842 | ) 843 | self.mlp = LlamaMLP(config) 844 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 845 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 846 | 847 | def forward( 848 | self, 849 | hidden_states: torch.Tensor, 850 | input_ids: torch.LongTensor = None, 851 | token_ids: Optional[torch.LongTensor] = None, 852 | pos_A: Optional[torch.Tensor] = None, 853 | pos_B: Optional[torch.Tensor] = None, 854 | attention_mask: Optional[torch.Tensor] = None, 855 | position_ids: Optional[torch.LongTensor] = None, 856 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 857 | output_attentions: Optional[bool] = False, 858 | output_loss: Optional[bool] = False, 859 | use_cache: Optional[bool] = False, 860 | padding_mask: Optional[torch.LongTensor] = None, 861 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 862 | """ 863 | Args: 864 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 865 | attention_mask (`torch.FloatTensor`, *optional*): attention mask of size 866 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 867 | output_attentions (`bool`, *optional*): 868 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 869 | returned tensors for more detail. 870 | use_cache (`bool`, *optional*): 871 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 872 | (see `past_key_values`). 
873 | past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states 874 | """ 875 | 876 | residual = hidden_states 877 | 878 | hidden_states = self.input_layernorm(hidden_states) 879 | 880 | # Self Attention 881 | hidden_states, self_attn_weights, present_key_value, entropy = self.self_attn( 882 | hidden_states=hidden_states, 883 | input_ids=input_ids, 884 | token_ids=token_ids, 885 | pos_A=pos_A, 886 | pos_B=pos_B, 887 | attention_mask=attention_mask, 888 | position_ids=position_ids, 889 | past_key_value=past_key_value, 890 | output_attentions=output_attentions, 891 | use_cache=use_cache, 892 | padding_mask=padding_mask, 893 | ) 894 | hidden_states = residual + hidden_states 895 | 896 | # Fully Connected 897 | residual = hidden_states 898 | hidden_states = self.post_attention_layernorm(hidden_states) 899 | hidden_states = self.mlp(hidden_states) 900 | hidden_states = residual + hidden_states 901 | 902 | outputs = (hidden_states, ) 903 | 904 | if output_loss: 905 | outputs += (entropy,) 906 | 907 | if output_attentions: 908 | outputs += (self_attn_weights,) 909 | 910 | if use_cache: 911 | outputs += (present_key_value,) 912 | 913 | return outputs 914 | 915 | 916 | LLAMA_START_DOCSTRING = r""" 917 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the 918 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 919 | etc.) 920 | 921 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 922 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage 923 | and behavior. 924 | 925 | Parameters: 926 | config ([`LlamaConfig`]): 927 | Model configuration class with all the parameters of the model. Initializing with a config file does not 928 | load the weights associated with the model, only the configuration. Check out the 929 | [`~PreTrainedModel.from_pretrained`] method to load the model weights. 930 | """ 931 | 932 | 933 | @add_start_docstrings( 934 | "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", 935 | LLAMA_START_DOCSTRING, 936 | ) 937 | class LlamaPreTrainedModel(PreTrainedModel): 938 | config_class = LlamaConfig 939 | base_model_prefix = "model" 940 | supports_gradient_checkpointing = True 941 | _no_split_modules = ["LlamaDecoderLayer"] 942 | _skip_keys_device_placement = "past_key_values" 943 | _supports_flash_attn_2 = True 944 | 945 | def _init_weights(self, module): 946 | std = self.config.initializer_range 947 | if isinstance(module, nn.Linear): 948 | module.weight.data.normal_(mean=0.0, std=std) 949 | if module.bias is not None: 950 | module.bias.data.zero_() 951 | elif isinstance(module, nn.Embedding): 952 | module.weight.data.normal_(mean=0.0, std=std) 953 | if module.padding_idx is not None: 954 | module.weight.data[module.padding_idx].zero_() 955 | 956 | def _set_gradient_checkpointing(self, module, value=False): 957 | if isinstance(module, LlamaModel): 958 | module.gradient_checkpointing = value 959 | 960 | 961 | LLAMA_INPUTS_DOCSTRING = r""" 962 | Args: 963 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 964 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide 965 | it. 966 | 967 | Indices can be obtained using [`AutoTokenizer`]. 
See [`PreTrainedTokenizer.encode`] and 968 | [`PreTrainedTokenizer.__call__`] for details. 969 | 970 | [What are input IDs?](../glossary#input-ids) 971 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 972 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 973 | 974 | - 1 for tokens that are **not masked**, 975 | - 0 for tokens that are **masked**. 976 | 977 | [What are attention masks?](../glossary#attention-mask) 978 | 979 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 980 | [`PreTrainedTokenizer.__call__`] for details. 981 | 982 | If `past_key_values` is used, optionally only the last `input_ids` have to be input (see 983 | `past_key_values`). 984 | 985 | If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] 986 | and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more 987 | information on the default strategy. 988 | 989 | - 1 indicates the head is **not masked**, 990 | - 0 indicates the head is **masked**. 991 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 992 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, 993 | config.n_positions - 1]`. 994 | 995 | [What are position IDs?](../glossary#position-ids) 996 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 997 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape 998 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape 999 | `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 1000 | 1001 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention 1002 | blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 1003 | 1004 | If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't 1005 | have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` 1006 | of shape `(batch_size, sequence_length)`. 1007 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 1008 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 1009 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 1010 | model's internal embedding lookup matrix. 1011 | use_cache (`bool`, *optional*): 1012 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 1013 | `past_key_values`). 1014 | output_attentions (`bool`, *optional*): 1015 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 1016 | tensors for more detail. 1017 | output_hidden_states (`bool`, *optional*): 1018 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 1019 | more detail. 1020 | return_dict (`bool`, *optional*): 1021 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
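On top of the stock inputs documented above, this variant threads several extra tensors through `forward`. A shape-only sketch of what a batch looks like; every value below is a placeholder (the real tensors come from `SupervisedDataset` and `DataCollatorForSupervisedDataset` in `src/sft_minicpm.py`), and a real call additionally requires the `substart`/`subend` marker sequences to actually occur inside `input_ids`.

```python
import torch

bsz, seq_len, marker_len = 2, 16, 4
batch = dict(
    input_ids=torch.randint(10, 100, (bsz, seq_len)),
    attention_mask=torch.ones(bsz, seq_len, dtype=torch.bool),
    position_ids=torch.arange(seq_len).repeat(bsz, 2),   # [bsz, 2*seq_len]: x-positions then y-positions
    token_ids=torch.arange(seq_len).repeat(bsz, 2),      # [bsz, 2*seq_len]: row ids then column ids
    substart=torch.randint(10, 100, (bsz, marker_len)),  # token ids opening the table span
    subend=torch.randint(10, 100, (bsz, marker_len)),    # token ids closing the table span
)
for name, tensor in batch.items():
    print(name, tuple(tensor.shape))
# model(**batch) would exercise the 2D-TPE forward once the markers really appear in input_ids
```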
1022 | """ 1023 | 1024 | 1025 | @add_start_docstrings( 1026 | "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", 1027 | LLAMA_START_DOCSTRING, 1028 | ) 1029 | class LlamaModel(LlamaPreTrainedModel): 1030 | """ 1031 | Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] 1032 | 1033 | Args: 1034 | config: LlamaConfig 1035 | """ 1036 | 1037 | def __init__(self, config: LlamaConfig): 1038 | super().__init__(config) 1039 | self.padding_idx = config.pad_token_id 1040 | self.vocab_size = config.vocab_size 1041 | 1042 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) 1043 | self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) 1044 | self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 1045 | 1046 | self.gradient_checkpointing = False 1047 | # Initialize weights and apply final processing 1048 | self.post_init() 1049 | 1050 | def get_input_embeddings(self): 1051 | return self.embed_tokens 1052 | 1053 | def set_input_embeddings(self, value): 1054 | self.embed_tokens = value 1055 | 1056 | # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask 1057 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, input_ids, position_ids, substart, subend, inputs_embeds, past_key_values_length): 1058 | # create causal mask 1059 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 1060 | combined_attention_mask = None 1061 | pos_A = None 1062 | pos_B = None 1063 | if input_shape[-1] > 1: 1064 | # combined_attention_mask = _make_causal_mask( 1065 | # input_shape, 1066 | # inputs_embeds.dtype, 1067 | # device=inputs_embeds.device, 1068 | # past_key_values_length=past_key_values_length, 1069 | # ) 1070 | combined_attention_mask, pos_A, pos_B = _make_causal_mask( 1071 | input_ids, 1072 | position_ids, 1073 | substart, 1074 | subend, 1075 | inputs_embeds.dtype, 1076 | device=inputs_embeds.device, 1077 | past_key_values_length=past_key_values_length, 1078 | ) 1079 | 1080 | if attention_mask is not None: 1081 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 1082 | expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( 1083 | inputs_embeds.device 1084 | ) 1085 | combined_attention_mask = ( 1086 | expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask 1087 | ) 1088 | 1089 | return combined_attention_mask, pos_A, pos_B 1090 | 1091 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) 1092 | def forward( 1093 | self, 1094 | input_ids: torch.LongTensor = None, 1095 | attention_mask: Optional[torch.Tensor] = None, 1096 | position_ids: Optional[torch.LongTensor] = None, 1097 | token_ids: Optional[torch.LongTensor] = None, 1098 | substart: Optional[torch.LongTensor] = None, 1099 | subend: Optional[torch.LongTensor] = None, 1100 | past_key_values: Optional[List[torch.FloatTensor]] = None, 1101 | inputs_embeds: Optional[torch.FloatTensor] = None, 1102 | use_cache: Optional[bool] = None, 1103 | output_attentions: Optional[bool] = None, 1104 | output_loss: Optional[bool] = None, 1105 | output_hidden_states: Optional[bool] = None, 1106 | return_dict: Optional[bool] = None, 1107 | ) -> Union[Tuple, BaseModelOutputWithPast]: 1108 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1109 | output_loss = output_loss if 
output_loss is not None else self.config.output_loss 1110 | output_hidden_states = ( 1111 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1112 | ) 1113 | use_cache = use_cache if use_cache is not None else self.config.use_cache 1114 | 1115 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1116 | 1117 | # retrieve input_ids and inputs_embeds 1118 | if input_ids is not None and inputs_embeds is not None: 1119 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 1120 | elif input_ids is not None: 1121 | batch_size, seq_length = input_ids.shape 1122 | elif inputs_embeds is not None: 1123 | batch_size, seq_length, _ = inputs_embeds.shape 1124 | else: 1125 | raise ValueError("You have to specify either input_ids or inputs_embeds") 1126 | 1127 | seq_length_with_past = seq_length 1128 | past_key_values_length = 0 1129 | 1130 | if past_key_values is not None: 1131 | past_key_values_length = past_key_values[0][0].shape[2] 1132 | seq_length_with_past = seq_length_with_past + past_key_values_length 1133 | 1134 | if position_ids is None: 1135 | device = input_ids.device if input_ids is not None else inputs_embeds.device 1136 | position_ids = torch.arange( 1137 | past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device 1138 | ) 1139 | position_ids = position_ids.unsqueeze(0).view(-1, seq_length) 1140 | else: 1141 | position_ids = position_ids.view(-1, seq_length*2).long() 1142 | 1143 | if inputs_embeds is None: 1144 | inputs_embeds = self.embed_tokens(input_ids) 1145 | 1146 | # embed positions 1147 | if attention_mask is None: 1148 | attention_mask = torch.ones( 1149 | (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device 1150 | ) 1151 | padding_mask = None 1152 | else: 1153 | if 0 in attention_mask: 1154 | padding_mask = attention_mask 1155 | else: 1156 | padding_mask = None 1157 | padding_mask = attention_mask 1158 | 1159 | attention_mask, pos_A, pos_B = self._prepare_decoder_attention_mask( 1160 | attention_mask, (batch_size, seq_length), input_ids, position_ids, substart, subend, inputs_embeds, past_key_values_length 1161 | ) 1162 | 1163 | hidden_states = inputs_embeds 1164 | 1165 | if self.gradient_checkpointing and self.training: 1166 | if use_cache: 1167 | logger.warning_once( 1168 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
1169 | ) 1170 | use_cache = False 1171 | 1172 | # decoder layers 1173 | all_hidden_states = () if output_hidden_states else None 1174 | all_self_attns = () if output_attentions else None 1175 | next_decoder_cache = () if use_cache else None 1176 | 1177 | entropy_list = [] 1178 | 1179 | for idx, decoder_layer in enumerate(self.layers): 1180 | if output_hidden_states: 1181 | all_hidden_states += (hidden_states,) 1182 | 1183 | past_key_value = past_key_values[idx] if past_key_values is not None else None 1184 | 1185 | if self.gradient_checkpointing and self.training: 1186 | 1187 | def create_custom_forward(module): 1188 | def custom_forward(*inputs): 1189 | # None for past_key_value 1190 | return module(*inputs, past_key_value, output_attentions, output_loss=output_loss, padding_mask=padding_mask) 1191 | 1192 | return custom_forward 1193 | 1194 | layer_outputs = torch.utils.checkpoint.checkpoint( 1195 | create_custom_forward(decoder_layer), hidden_states, input_ids, token_ids, pos_A, pos_B, attention_mask, position_ids 1196 | ) 1197 | else: 1198 | layer_outputs = decoder_layer( 1199 | hidden_states, 1200 | input_ids=input_ids, 1201 | token_ids=token_ids, 1202 | pos_A=pos_A, 1203 | pos_B=pos_B, 1204 | attention_mask=attention_mask, 1205 | position_ids=position_ids, 1206 | past_key_value=past_key_value, 1207 | output_attentions=output_attentions, 1208 | output_loss=output_loss, 1209 | use_cache=use_cache, 1210 | padding_mask=padding_mask, 1211 | ) 1212 | 1213 | hidden_states = layer_outputs[0] 1214 | 1215 | if output_loss: 1216 | entropy_list.append(layer_outputs[1]) 1217 | 1218 | if use_cache: 1219 | next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) 1220 | 1221 | if output_attentions: 1222 | all_self_attns += (layer_outputs[1],) 1223 | 1224 | if output_loss: 1225 | entropy_all = torch.mean(torch.stack(entropy_list)) 1226 | 1227 | hidden_states = self.norm(hidden_states) 1228 | 1229 | # add hidden states from the last decoder layer 1230 | if output_hidden_states: 1231 | all_hidden_states += (hidden_states,) 1232 | 1233 | next_cache = next_decoder_cache if use_cache else None 1234 | if not return_dict: 1235 | return tuple(v for v in [hidden_states, entropy_all, next_cache, all_hidden_states, all_self_attns] if v is not None) 1236 | return BaseModelOutputWithPast( 1237 | last_hidden_state=hidden_states, 1238 | past_key_values=next_cache, 1239 | hidden_states=all_hidden_states, 1240 | attentions=all_self_attns, 1241 | ) 1242 | 1243 | 1244 | class LlamaForCausalLM(LlamaPreTrainedModel): 1245 | _tied_weights_keys = ["lm_head.weight"] 1246 | 1247 | def __init__(self, config): 1248 | super().__init__(config) 1249 | self.model = LlamaModel(config) 1250 | self.vocab_size = config.vocab_size 1251 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 1252 | 1253 | # Initialize weights and apply final processing 1254 | self.post_init() 1255 | 1256 | def get_input_embeddings(self): 1257 | return self.model.embed_tokens 1258 | 1259 | def set_input_embeddings(self, value): 1260 | self.model.embed_tokens = value 1261 | 1262 | def get_output_embeddings(self): 1263 | return self.lm_head 1264 | 1265 | def set_output_embeddings(self, new_embeddings): 1266 | self.lm_head = new_embeddings 1267 | 1268 | def set_decoder(self, decoder): 1269 | self.model = decoder 1270 | 1271 | def get_decoder(self): 1272 | return self.model 1273 | 1274 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) 1275 | @replace_return_docstrings(output_type=CausalLMOutputWithPast, 
1276 |     def forward(
1277 |         self,
1278 |         input_ids: torch.LongTensor = None,
1279 |         attention_mask: Optional[torch.Tensor] = None,
1280 |         position_ids: Optional[torch.LongTensor] = None,
1281 |         token_ids: Optional[torch.LongTensor] = None,
1282 |         substart: Optional[torch.LongTensor] = None,
1283 |         subend: Optional[torch.LongTensor] = None,
1284 |         past_key_values: Optional[List[torch.FloatTensor]] = None,
1285 |         inputs_embeds: Optional[torch.FloatTensor] = None,
1286 |         labels: Optional[torch.LongTensor] = None,
1287 |         use_cache: Optional[bool] = None,
1288 |         output_attentions: Optional[bool] = None,
1289 |         output_loss: Optional[bool] = None,
1290 |         output_hidden_states: Optional[bool] = None,
1291 |         return_dict: Optional[bool] = None,
1292 |     ) -> Union[Tuple, CausalLMOutputWithPast]:
1293 |         r"""
1294 |         Args:
1295 |             labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1296 |                 Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1297 |                 config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1298 |                 (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1299 | 
1300 |         Returns:
1301 | 
1302 |         Example:
1303 | 
1304 |         ```python
1305 |         >>> from transformers import AutoTokenizer, LlamaForCausalLM
1306 | 
1307 |         >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1308 |         >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1309 | 
1310 |         >>> prompt = "Hey, are you conscious? Can you talk to me?"
1311 |         >>> inputs = tokenizer(prompt, return_tensors="pt")
1312 | 
1313 |         >>> # Generate
1314 |         >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1315 |         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1316 |         "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1317 |         ```"""
1318 | 
1319 |         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1320 |         output_loss = output_loss if output_loss is not None else self.config.output_loss
1321 |         output_hidden_states = (
1322 |             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1323 |         )
1324 |         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1325 | 
1326 |         if output_loss:
1327 |             return_dict = False
1328 | 
1329 |         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1330 |         outputs = self.model(
1331 |             input_ids=input_ids,
1332 |             attention_mask=attention_mask,
1333 |             position_ids=position_ids,
1334 |             token_ids=token_ids,
1335 |             substart=substart,
1336 |             subend=subend,
1337 |             past_key_values=past_key_values,
1338 |             inputs_embeds=inputs_embeds,
1339 |             use_cache=use_cache,
1340 |             output_attentions=output_attentions,
1341 |             output_hidden_states=output_hidden_states,
1342 |             output_loss=output_loss,  # propagate the flag so the entropy term is actually returned below
1343 |             return_dict=return_dict,
1344 |         )
1345 | 
1346 |         hidden_states = outputs[0]
1347 | 
1348 |         if output_loss:
1349 |             entropy = outputs[1]
1350 | 
1351 |         if self.config.lamda:
1352 |             lamda = self.config.lamda
1353 |         else:
1354 |             lamda = 1
1355 | 
1356 |         if self.config.pretraining_tp > 1:
1357 |             lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
1358 |             logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
1359 |             logits = torch.cat(logits, dim=-1)
1360 |         else:
1361 |             logits = self.lm_head(hidden_states)
1362 |         logits = logits.float()
1363 | 
1364 |         loss = None
1365 |         if labels is not None:
1366 |             # Shift so that tokens < n predict n
1367 |             shift_logits = logits[..., :-1, :].contiguous()
1368 |             shift_labels = labels[..., 1:].contiguous()
1369 |             # Flatten the tokens
1370 |             loss_fct = CrossEntropyLoss()
1371 |             shift_logits = shift_logits.view(-1, self.config.vocab_size)
1372 |             shift_labels = shift_labels.view(-1)
1373 |             # Enable model parallelism
1374 |             shift_labels = shift_labels.to(shift_logits.device)
1375 |             loss = loss_fct(shift_logits, shift_labels)
1376 |             if output_loss:
1377 |                 loss = loss + lamda * entropy
1378 | 
1379 |         if not return_dict:
1380 |             output = (logits,) + outputs[1:]
1381 |             return (loss,) + output if loss is not None else output
1382 | 
1383 |         return CausalLMOutputWithPast(
1384 |             loss=loss,
1385 |             logits=logits,
1386 |             past_key_values=outputs.past_key_values,
1387 |             hidden_states=outputs.hidden_states,
1388 |             attentions=outputs.attentions,
1389 |         )
1390 | 
1391 |     def prepare_inputs_for_generation(
1392 |         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1393 |     ):
1394 |         if past_key_values:
1395 |             input_ids = input_ids[:, -1:]
1396 | 
1397 |         position_ids = kwargs.get("position_ids", None)
1398 |         token_ids = kwargs.get("token_ids", None)
1399 |         substart = kwargs.get("substart", None)
1400 |         subend = kwargs.get("subend", None)
1401 |         if attention_mask is not None and position_ids is None:
1402 |             # create position_ids on the fly for batch generation
1403 |             position_ids = attention_mask.long().cumsum(-1) - 1
1404 |             position_ids.masked_fill_(attention_mask == 0, 1)
1405 |             if past_key_values:
1406 |                 position_ids = position_ids[:, -1].unsqueeze(-1)
1407 | 
1408 |         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1409 |         if inputs_embeds is not None and past_key_values is None:
1410 |             model_inputs = {"inputs_embeds": inputs_embeds}
1411 |         else:
1412 |             model_inputs = {"input_ids": input_ids}
1413 | 
1414 |         model_inputs.update(
1415 |             {
1416 |                 "position_ids": position_ids,
1417 |                 "past_key_values": past_key_values,
1418 |                 "use_cache": kwargs.get("use_cache"),
1419 |                 "attention_mask": attention_mask,
1420 |                 "token_ids": token_ids,
1421 |                 "substart": substart,
1422 |                 "subend": subend,
1423 |             }
1424 |         )
1425 |         return model_inputs
1426 | 
1427 |     @staticmethod
1428 |     def _reorder_cache(past_key_values, beam_idx):
1429 |         reordered_past = ()
1430 |         for layer_past in past_key_values:
1431 |             reordered_past += (
1432 |                 tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1433 |             )
1434 |         return reordered_past
1435 | 
1436 | 
1437 | @add_start_docstrings(
1438 |     """
1439 |     The LLaMa Model transformer with a sequence classification head on top (linear layer).
1440 | 
1441 |     [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1442 |     (e.g. GPT-2) do.
1443 | 
1444 |     Since it does classification on the last token, it requires to know the position of the last token. If a
1445 |     `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1446 |     no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1447 |     padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1448 |     each row of the batch).
1449 |     """,
1450 |     LLAMA_START_DOCSTRING,
1451 | )
1452 | class LlamaForSequenceClassification(LlamaPreTrainedModel):
1453 |     def __init__(self, config):
1454 |         super().__init__(config)
1455 |         self.num_labels = config.num_labels
1456 |         self.model = LlamaModel(config)
1457 |         self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1458 | 
1459 |         # Initialize weights and apply final processing
1460 |         self.post_init()
1461 | 
1462 |     def get_input_embeddings(self):
1463 |         return self.model.embed_tokens
1464 | 
1465 |     def set_input_embeddings(self, value):
1466 |         self.model.embed_tokens = value
1467 | 
1468 |     @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1469 |     def forward(
1470 |         self,
1471 |         input_ids: torch.LongTensor = None,
1472 |         attention_mask: Optional[torch.Tensor] = None,
1473 |         position_ids: Optional[torch.LongTensor] = None,
1474 |         past_key_values: Optional[List[torch.FloatTensor]] = None,
1475 |         inputs_embeds: Optional[torch.FloatTensor] = None,
1476 |         labels: Optional[torch.LongTensor] = None,
1477 |         use_cache: Optional[bool] = None,
1478 |         output_attentions: Optional[bool] = None,
1479 |         output_hidden_states: Optional[bool] = None,
1480 |         return_dict: Optional[bool] = None,
1481 |     ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1482 |         r"""
1483 |         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1484 |             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1485 |             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1486 |             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1487 | """ 1488 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1489 | 1490 | transformer_outputs = self.model( 1491 | input_ids, 1492 | attention_mask=attention_mask, 1493 | position_ids=position_ids, 1494 | past_key_values=past_key_values, 1495 | inputs_embeds=inputs_embeds, 1496 | use_cache=use_cache, 1497 | output_attentions=output_attentions, 1498 | output_hidden_states=output_hidden_states, 1499 | return_dict=return_dict, 1500 | ) 1501 | hidden_states = transformer_outputs[0] 1502 | logits = self.score(hidden_states) 1503 | 1504 | if input_ids is not None: 1505 | batch_size = input_ids.shape[0] 1506 | else: 1507 | batch_size = inputs_embeds.shape[0] 1508 | 1509 | if self.config.pad_token_id is None and batch_size != 1: 1510 | raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") 1511 | if self.config.pad_token_id is None: 1512 | sequence_lengths = -1 1513 | else: 1514 | if input_ids is not None: 1515 | sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to( 1516 | logits.device 1517 | ) 1518 | else: 1519 | sequence_lengths = -1 1520 | 1521 | pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] 1522 | 1523 | loss = None 1524 | if labels is not None: 1525 | labels = labels.to(logits.device) 1526 | if self.config.problem_type is None: 1527 | if self.num_labels == 1: 1528 | self.config.problem_type = "regression" 1529 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 1530 | self.config.problem_type = "single_label_classification" 1531 | else: 1532 | self.config.problem_type = "multi_label_classification" 1533 | 1534 | if self.config.problem_type == "regression": 1535 | loss_fct = MSELoss() 1536 | if self.num_labels == 1: 1537 | loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) 1538 | else: 1539 | loss = loss_fct(pooled_logits, labels) 1540 | elif self.config.problem_type == "single_label_classification": 1541 | loss_fct = CrossEntropyLoss() 1542 | loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) 1543 | elif self.config.problem_type == "multi_label_classification": 1544 | loss_fct = BCEWithLogitsLoss() 1545 | loss = loss_fct(pooled_logits, labels) 1546 | if not return_dict: 1547 | output = (pooled_logits,) + transformer_outputs[1:] 1548 | return ((loss,) + output) if loss is not None else output 1549 | 1550 | return SequenceClassifierOutputWithPast( 1551 | loss=loss, 1552 | logits=pooled_logits, 1553 | past_key_values=transformer_outputs.past_key_values, 1554 | hidden_states=transformer_outputs.hidden_states, 1555 | attentions=transformer_outputs.attentions, 1556 | ) 1557 | --------------------------------------------------------------------------------