├── .gitignore ├── README.md ├── evaluation ├── equal.py ├── eval.py ├── eval_llama3.py ├── eval_vicuna.py ├── inference_baseline.py ├── inference_eagle.py ├── inference_eagle2.py ├── inference_pld.py ├── inference_sam_only.py ├── inference_sam_only.py.bak ├── inference_samd.py ├── inference_token_recycle.py ├── model │ ├── eagle │ │ ├── __init__.py │ │ ├── choices.py │ │ ├── cnets.py │ │ ├── config.json │ │ ├── configs.py │ │ ├── ea_model.py │ │ ├── kv_cache.py │ │ ├── modeling_Mixtral_kv.py │ │ ├── modeling_llama_kv.py │ │ ├── utils.py │ │ ├── utils_alpha.py │ │ └── utils_c.py │ ├── eagle2 │ │ ├── __init__.py │ │ ├── choices.py │ │ ├── cnets.py │ │ ├── configs.py │ │ ├── ea_model.py │ │ ├── kv_cache.py │ │ ├── modeling_llama_kv.py │ │ ├── modeling_mixtral_kv.py │ │ ├── modeling_qwen2_kv.py │ │ ├── utils.py │ │ ├── utils_alpha.py │ │ └── utils_c.py │ ├── pld │ │ └── pld.py │ ├── sam_only │ │ ├── __init__.py │ │ ├── cache.py │ │ ├── cache.py.bak │ │ ├── config │ │ │ ├── default_tree.json │ │ │ ├── default_tree.json.bak0 │ │ │ ├── default_tree.json.bak1 │ │ │ ├── default_tree.json.bak2 │ │ │ ├── default_tree.json.bak3 │ │ │ ├── default_tree.json.bak4 │ │ │ ├── default_tree.json.bak5 │ │ │ ├── default_tree_1_1.json │ │ │ ├── default_tree_6_60.json │ │ │ ├── eagle.json │ │ │ └── token_recycle.json │ │ ├── draft.py │ │ ├── model_patch │ │ │ ├── __init__.py │ │ │ └── llama.py │ │ ├── sam │ │ │ ├── __init__.py │ │ │ ├── sam.py │ │ │ ├── sam.py.bak │ │ │ └── utils.py │ │ ├── samd_config.py │ │ ├── samd_model.py │ │ └── utils.py │ └── token_recycle │ │ ├── __init__.py │ │ ├── attn_patch │ │ ├── __init__.py │ │ └── llama.py │ │ ├── cache.py │ │ ├── config │ │ ├── default_tree.json │ │ └── default_tree_80.json │ │ ├── draft.py │ │ ├── token_recycle.py │ │ ├── token_recycle_config.py │ │ ├── token_recycle_model.py │ │ └── utils.py ├── profile_entry.py ├── profile_sam_only.py ├── profile_samd.py └── speed.py ├── profile_utils.py ├── sam_data └── list.txt ├── samd ├── __init__.py ├── cache.py ├── cache.py.bak ├── config │ ├── default_tree.json │ ├── default_tree.json.bak0 │ ├── default_tree.json.bak1 │ ├── default_tree.json.bak2 │ ├── default_tree.json.bak3 │ ├── default_tree.json.bak4 │ ├── default_tree.json.bak5 │ ├── default_tree_1_1.json │ ├── default_tree_6_60.json │ ├── eagle.json │ └── token_recycle.json ├── draft.py ├── inference │ ├── __init__.py │ └── cli.py ├── model_patch │ ├── __init__.py │ └── llama.py ├── sam │ ├── __init__.py │ ├── dyn_sam.py │ ├── sam.py.bak │ ├── static_sam.py │ └── utils.py ├── samd_config.py ├── samd_model.py ├── tree_model │ ├── __init__.py │ ├── eagle │ │ ├── __init__.py │ │ ├── eagle.py │ │ ├── eagle_config.py │ │ ├── eagle_model.py │ │ ├── eagle_utils.py │ │ └── utils.py │ ├── eagle2 │ │ ├── __init__.py │ │ ├── eagle2.py │ │ ├── eagle2_config.py │ │ ├── eagle2_model.py │ │ ├── eagle2_utils.py │ │ └── utils.py │ ├── token_recycle │ │ ├── __init__.py │ │ ├── token_recycle.py │ │ └── utils.py │ └── tree.py └── utils.py ├── samd_sam_only ├── __init__.py ├── cache.py ├── config │ ├── default_tree.json │ ├── default_tree.json.bak0 │ ├── default_tree.json.bak1 │ ├── default_tree.json.bak2 │ ├── default_tree.json.bak3 │ ├── default_tree.json.bak4 │ ├── default_tree.json.bak5 │ ├── default_tree_1_1.json │ ├── default_tree_6_60.json │ ├── eagle.json │ └── token_recycle.json ├── draft.py ├── inference │ ├── __init__.py │ ├── cli.py │ └── cli_baseline.py ├── model_patch │ ├── __init__.py │ └── llama.py ├── sam │ ├── __init__.py │ ├── dyn_sam.py │ ├── static_sam.py │ └── 
utils.py ├── samd_config.py ├── samd_model.py └── utils.py ├── scripts ├── equal.sh ├── inference_baseline.sh ├── inference_eagle.sh ├── inference_eagle2.sh ├── inference_pld.sh ├── inference_samd.sh ├── inference_samd_sam_only.sh ├── inference_token_recycle.sh ├── speed.sh ├── test_samd.sh └── test_samd_sam_only.sh ├── tests ├── test_samd.py ├── test_samd_sam_only.py └── test_token_recycle.py └── tools ├── data_utils.py ├── gen_default_tree.py ├── gen_response.py ├── gen_sam_alpaca.py ├── gen_sam_alpaca_sam_only.py ├── gen_sam_none.py ├── gen_sam_none_sam_only.py ├── prepare_prompts.py └── prompter.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | sam_data/* 3 | local_cache/* 4 | third_party/ 5 | evaluation/data 6 | test.py 7 | rsync.txt 8 | misc 9 | evaluation/model/pia 10 | -------------------------------------------------------------------------------- /evaluation/equal.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import os 4 | 5 | 6 | def get_content(jsonfile_path="../data/mt_bench/model_answer/vicuna-7b-v1.3-pld-float32.jsonl", 7 | output_path="../data/mt_bench/model_answer/txt/vicuna-7b-v1.3-pld-float32.txt"): 8 | data = [] 9 | with open(jsonfile_path, 'r', encoding='utf-8') as file: 10 | for line in file: 11 | json_obj = json.loads(line) 12 | data.append(json_obj) 13 | 14 | contents=[] 15 | for datapoint in data: 16 | turns=datapoint["choices"][0]['turns'] 17 | contents.append(str(turns)) 18 | 19 | with open(output_path, 'w') as file: 20 | for content in contents: 21 | file.write(content) 22 | file.write('\n') 23 | 24 | 25 | def txt_compare(file_path1, file_path2): 26 | cnt_neq = 0 27 | cnt = 0 28 | with open(file_path1, 'r', encoding='utf-8') as f1: 29 | lines1 = f1.readlines() 30 | with open(file_path2, 'r', encoding='utf-8') as f2: 31 | lines2 = f2.readlines() 32 | for l1, l2 in zip(lines1, lines2): 33 | if l1 != l2: 34 | cnt_neq += 1 35 | print(l1, "\n", l2) 36 | cnt += 1 37 | print(f"neq: {cnt_neq}, all: {cnt}, ratio: {cnt_neq / cnt}") 38 | return cnt_neq == 0 39 | 40 | 41 | def run_compare(file_path, jsonfile1, jsonfile2): 42 | jsonfile_path1 = os.path.join(file_path, jsonfile1) 43 | jsonfile_path2 = os.path.join(file_path, jsonfile2) 44 | output_path1 = file_path + "txt/" + jsonfile1.replace("jsonl", "txt") 45 | output_path2 = file_path + "txt/" + jsonfile2.replace("jsonl", "txt") 46 | if not os.path.exists(file_path + "txt/"): 47 | os.makedirs(file_path + "txt/") 48 | get_content(jsonfile_path1, output_path1) 49 | get_content(jsonfile_path2, output_path2) 50 | if txt_compare(output_path1, output_path2): 51 | print("Result totally Equal!") 52 | else: 53 | print("Not Equal!") 54 | 55 | 56 | if __name__ == "__main__": 57 | parser = argparse.ArgumentParser() 58 | 59 | parser.add_argument( 60 | "--file-path", 61 | default='evaluation/data/spec_bench/model_answer/', 62 | type=str, 63 | help="The file path of model answers.", 64 | ) 65 | parser.add_argument( 66 | "--jsonfile1", 67 | default='vicuna-7b-v1.3-vanilla-float32-temp-0.0.jsonl', 68 | type=str, 69 | help="The file name of the first evaluated method.", 70 | ) 71 | parser.add_argument( 72 | "--jsonfile2", 73 | default='vicuna-7b-v1.3-sps-68m-float32-temp-0.0.jsonl', 74 | type=str, 75 | help="The file name of the second evaluated method.", 76 | ) 77 | args = parser.parse_args() 78 | run_compare(args.file_path, args.jsonfile1, args.jsonfile2) 
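Note: equal.py can also be driven programmatically to verify that a speculative-decoding run reproduces the vanilla greedy outputs token-for-token. Below is a minimal sketch, not part of the repository: it assumes the repo root is on PYTHONPATH (as the other evaluation scripts do) and that both answer files already exist under the default model_answer directory; the candidate file name is a placeholder.

# hypothetical helper script; the candidate .jsonl name below is a placeholder
from evaluation.equal import run_compare

answer_dir = "evaluation/data/spec_bench/model_answer/"                  # argparse default in equal.py
baseline  = "vicuna-7b-v1.3-vanilla-float32-temp-0.0.jsonl"              # vanilla autoregressive answers
candidate = "vicuna-7b-v1.3-samd-float32-temp-0.0.jsonl"                 # placeholder: answers from the method under test
# run_compare dumps each JSONL's turns to plain text under <answer_dir>txt/,
# diffs them line by line, and prints "Result totally Equal!" when every turn matches.
run_compare(answer_dir, baseline, candidate)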
-------------------------------------------------------------------------------- /evaluation/eval.py: -------------------------------------------------------------------------------- 1 | from evaluation.eval_vicuna import ( 2 | run_eval as run_eval_vicuna, 3 | reorg_answer_file as reorg_answer_file_vicuna 4 | ) 5 | from evaluation.eval_llama3 import ( 6 | run_eval as run_eval_llama3, 7 | reorg_answer_file as reorg_answer_file_llama3, 8 | ) 9 | 10 | run_evals = { 11 | "vicuna": run_eval_vicuna, 12 | "llama3": run_eval_llama3 13 | } 14 | reorg_answer_files = { 15 | "vicuna": reorg_answer_file_vicuna, 16 | "llama3": reorg_answer_file_llama3 17 | } -------------------------------------------------------------------------------- /evaluation/inference_baseline.py: -------------------------------------------------------------------------------- 1 | """Generate answers with local models. 2 | 3 | Usage: 4 | python3 gen_model_answer.py --model-path lmsys/fastchat-t5-3b-v1.0 --model-id fastchat-t5-3b-v1.0 5 | """ 6 | import argparse 7 | from fastchat.utils import str_to_torch_dtype 8 | 9 | from evaluation.eval import run_evals, reorg_answer_files 10 | 11 | from transformers import AutoModelForCausalLM, AutoTokenizer 12 | 13 | 14 | def baseline_forward(inputs, model, tokenizer, max_new_tokens, temperature=0.0, do_sample=False): 15 | input_ids = inputs.input_ids 16 | output_ids = model.generate( 17 | input_ids, 18 | do_sample=do_sample, 19 | temperature=temperature, 20 | max_new_tokens=max_new_tokens, 21 | ) 22 | new_token = len(output_ids[0][len(input_ids[0]):]) 23 | step = new_token 24 | accept_length_list = [1] * new_token 25 | return output_ids, new_token, step, accept_length_list 26 | 27 | 28 | if __name__ == "__main__": 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument( 31 | "--template", 32 | type=str, 33 | default="vicuna", 34 | choices=["vicuna", "llama3"] 35 | ) 36 | parser.add_argument( 37 | "--model-type", 38 | type=str, 39 | required=True, 40 | choices=["vicuna", "llama3"] 41 | ) 42 | parser.add_argument( 43 | "--model-path", 44 | type=str, 45 | required=True, 46 | ) 47 | parser.add_argument("--model-id", type=str, required=True) 48 | parser.add_argument( 49 | "--bench-name", 50 | type=str, 51 | default="mt_bench", 52 | help="The name of the benchmark question set.", 53 | ) 54 | parser.add_argument( 55 | "--question-begin", 56 | type=int, 57 | help="A debug option. The begin index of questions.", 58 | ) 59 | parser.add_argument( 60 | "--question-end", 61 | type=int, 62 | help="A debug option. The end index of questions." 63 | ) 64 | parser.add_argument("--answer-file", type=str, help="The output answer file.") 65 | parser.add_argument( 66 | "--max-new-tokens", 67 | type=int, 68 | default=1024, 69 | help="The maximum number of new generated tokens.", 70 | ) 71 | parser.add_argument( 72 | "--num-choices", 73 | type=int, 74 | default=1, 75 | help="How many completion choices to generate.", 76 | ) 77 | parser.add_argument( 78 | "--num-gpus-per-model", 79 | type=int, 80 | default=1, 81 | help="The number of GPUs per model.", 82 | ) 83 | parser.add_argument( 84 | "--num-gpus-total", type=int, default=1, help="The total number of GPUs." 85 | ) 86 | parser.add_argument( 87 | "--temperature", 88 | type=float, 89 | default=0.0, 90 | help="The temperature for medusa sampling.", 91 | ) 92 | parser.add_argument( 93 | "--dtype", 94 | type=str, 95 | default="float16", 96 | choices=["float32", "float64", "float16", "bfloat16"], 97 | help="Override the default dtype. 
If not set, it will use float16 on GPU.", 98 | ) 99 | 100 | args = parser.parse_args() 101 | 102 | question_file = f"evaluation/data/{args.bench_name}/question.jsonl" 103 | 104 | if args.answer_file: 105 | answer_file = args.answer_file 106 | else: 107 | answer_file = f"evaluation/data/{args.bench_name}/model_answer/{args.model_id}.jsonl" 108 | 109 | print(f"Output to {answer_file}") 110 | 111 | model = AutoModelForCausalLM.from_pretrained( 112 | args.model_path, 113 | torch_dtype=str_to_torch_dtype(args.dtype), 114 | low_cpu_mem_usage=True, 115 | device_map="auto" 116 | ) 117 | 118 | tokenizer = AutoTokenizer.from_pretrained(args.model_path) 119 | 120 | if args.temperature > 0: 121 | do_sample = True 122 | else: 123 | do_sample = False 124 | 125 | if args.model_type == "llama3": 126 | tokenizer.pad_token_id = tokenizer.eos_token_id 127 | 128 | run_evals[args.template]( 129 | model=model, 130 | tokenizer=tokenizer, 131 | forward_func=baseline_forward, 132 | model_id=args.model_id, 133 | question_file=question_file, 134 | question_begin=args.question_begin, 135 | question_end=args.question_end, 136 | answer_file=answer_file, 137 | max_new_tokens=args.max_new_tokens, 138 | num_choices=args.num_choices, 139 | num_gpus_per_model=args.num_gpus_per_model, 140 | num_gpus_total=args.num_gpus_total, 141 | temperature=args.temperature, 142 | do_sample=do_sample, 143 | ) 144 | 145 | reorg_answer_files[args.template](answer_file) 146 | -------------------------------------------------------------------------------- /evaluation/inference_eagle2.py: -------------------------------------------------------------------------------- 1 | """Generate answers with local models. 2 | 3 | Usage: 4 | python3 gen_model_answer.py --model-path lmsys/fastchat-t5-3b-v1.0 --model-id fastchat-t5-3b-v1.0 5 | """ 6 | import torch 7 | import argparse 8 | from fastchat.utils import str_to_torch_dtype 9 | 10 | from evaluation.eval import run_evals, reorg_answer_files 11 | from evaluation.model.eagle2.ea_model import EaModel 12 | from functools import partial 13 | 14 | def ea_forward(inputs, model, tokenizer, max_new_tokens, temperature=0.0, is_llama3=False): 15 | input_ids = inputs.input_ids 16 | max_length=model.config.max_position_embeddings 17 | assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!" 18 | input_ids, new_token, step, accept_length_list = model.eagenerate( 19 | torch.as_tensor(input_ids).cuda(), 20 | temperature=temperature, 21 | max_new_tokens=max_new_tokens, 22 | max_length=max_length, 23 | log=True, 24 | is_llama3=is_llama3, 25 | ) 26 | return input_ids, new_token, step, accept_length_list 27 | 28 | 29 | if __name__ == "__main__": 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument( 32 | "--template", 33 | type=str, 34 | default="vicuna", 35 | choices=["vicuna", "llama3"] 36 | ) 37 | parser.add_argument( 38 | "--model-type", 39 | type=str, 40 | required=True, 41 | choices=["vicuna", "llama3"] 42 | ) 43 | parser.add_argument( 44 | "--ea-model-path", 45 | type=str, 46 | default="down_checkpoints/LC70B", 47 | help="The path to the weights. 
This can be a local folder or a Hugging Face repo ID.", 48 | ) 49 | parser.add_argument("--base-model-path", type=str, default="/home/lyh/weights/hf/llama2chat/70B/", 50 | help="1") 51 | parser.add_argument( 52 | "--load-in-8bit", action="store_false", help="Use 8-bit quantization" 53 | ) 54 | parser.add_argument("--model-id", type=str, default="ess-vicuna-70b-fp16") 55 | parser.add_argument( 56 | "--bench-name", 57 | type=str, 58 | default="mt_bench", 59 | help="The name of the benchmark question set.", 60 | ) 61 | parser.add_argument( 62 | "--question-begin", 63 | type=int, 64 | help="A debug option. The begin index of questions.", 65 | ) 66 | parser.add_argument( 67 | "--question-end", type=int, help="A debug option. The end index of questions." 68 | ) 69 | parser.add_argument("--answer-file", type=str, help="The output answer file.") 70 | parser.add_argument( 71 | "--max-new-tokens", 72 | type=int, 73 | default=1024, 74 | help="The maximum number of new generated tokens.", 75 | ) 76 | parser.add_argument( 77 | "--total-token", 78 | type=int, 79 | default=60, 80 | help="The maximum number of new generated tokens.", 81 | ) 82 | parser.add_argument( 83 | "--depth", 84 | type=int, 85 | default=5, 86 | help="The maximum number of new generated tokens.", 87 | ) 88 | parser.add_argument( 89 | "--top-k", 90 | type=int, 91 | default=10, 92 | help="The maximum number of new generated tokens.", 93 | ) 94 | parser.add_argument( 95 | "--num-choices", 96 | type=int, 97 | default=1, 98 | help="How many completion choices to generate.", 99 | ) 100 | parser.add_argument( 101 | "--num-gpus-per-model", 102 | type=int, 103 | default=1, 104 | help="The number of GPUs per model.", 105 | ) 106 | parser.add_argument( 107 | "--num-gpus-total", type=int, default=1, help="The total number of GPUs." 108 | ) 109 | parser.add_argument( 110 | "--temperature", 111 | type=float, 112 | default=0.0, 113 | ) 114 | parser.add_argument( 115 | "--tree-choices", 116 | type=str, 117 | default="mc_sim_7b_63", 118 | ) 119 | parser.add_argument( 120 | "--dtype", 121 | type=str, 122 | default="float16", 123 | choices=["float32", "float64", "float16", "bfloat16"], 124 | help="Override the default dtype. 
If not set, it will use float16 on GPU.", 125 | ) 126 | parser.add_argument("--is_llama3", action="store_true") 127 | 128 | args = parser.parse_args() 129 | 130 | args.model_id = args.model_id + "-temperature-" + str(args.temperature) 131 | 132 | question_file = f"evaluation/data/{args.bench_name}/question.jsonl" 133 | 134 | if args.answer_file: 135 | answer_file = args.answer_file 136 | else: 137 | answer_file = f"evaluation/data/{args.bench_name}/model_answer/{args.model_id}.jsonl" 138 | 139 | print(f"Output to {answer_file}") 140 | 141 | model = EaModel.from_pretrained( 142 | base_model_path=args.base_model_path, 143 | ea_model_path=args.ea_model_path, 144 | total_token=args.total_token, 145 | depth=args.depth, 146 | top_k=args.top_k, 147 | torch_dtype=str_to_torch_dtype(args.dtype), 148 | low_cpu_mem_usage=True, 149 | # load_in_8bit=True, 150 | device_map="auto" 151 | ) 152 | 153 | tokenizer = model.get_tokenizer() 154 | 155 | if args.model_type == "llama3": 156 | ea_forward = partial(ea_forward, is_llama3=True) 157 | 158 | run_evals[args.template]( 159 | model=model, 160 | tokenizer=tokenizer, 161 | forward_func=ea_forward, 162 | model_id=args.model_id, 163 | question_file=question_file, 164 | question_begin=args.question_begin, 165 | question_end=args.question_end, 166 | answer_file=answer_file, 167 | max_new_tokens=args.max_new_tokens, 168 | num_choices=args.num_choices, 169 | num_gpus_per_model=args.num_gpus_per_model, 170 | num_gpus_total=args.num_gpus_total, 171 | temperature=args.temperature, 172 | ) 173 | 174 | reorg_answer_files[args.template](answer_file) 175 | -------------------------------------------------------------------------------- /evaluation/inference_pld.py: -------------------------------------------------------------------------------- 1 | """Generate answers with local models. 
2 | 3 | Usage: 4 | python3 gen_model_answer.py --model-path lmsys/fastchat-t5-3b-v1.0 --model-id fastchat-t5-3b-v1.0 5 | """ 6 | import argparse 7 | 8 | from evaluation.eval import run_evals, reorg_answer_files 9 | 10 | from fastchat.utils import str_to_torch_dtype 11 | 12 | from transformers import StoppingCriteriaList, MaxLengthCriteria 13 | from transformers import AutoModelForCausalLM, AutoTokenizer 14 | from evaluation.model.pld.pld import greedy_search_pld 15 | 16 | 17 | def pld_forward(inputs, model, tokenizer, max_new_tokens): 18 | input_ids = inputs.input_ids 19 | output_ids, idx, accept_length_list = model.greedy_search_pld( 20 | inputs.input_ids, 21 | attention_mask=inputs.attention_mask, 22 | stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=len(inputs.input_ids[0]) + max_new_tokens)]), 23 | draft_matching_window_size=3, 24 | draft_num_candidate_tokens=10, 25 | use_cache=True, 26 | pad_token_id=tokenizer.pad_token_id, 27 | eos_token_id=tokenizer.eos_token_id, 28 | return_dict_in_generate=False) 29 | input_len = len(input_ids[0]) 30 | new_token = len(output_ids[0][input_len:]) 31 | if tokenizer.eos_token_id in output_ids[0, input_len:].tolist(): 32 | for i, id in enumerate(output_ids[0, input_len:]): 33 | if id == tokenizer.eos_token_id: 34 | eos_token_ids_index = i 35 | invalid_len = len(output_ids[0, input_len:]) - eos_token_ids_index - 1 36 | if invalid_len > 0: 37 | accept_length_list[-1] -= invalid_len 38 | new_token -= invalid_len 39 | return output_ids, new_token, idx+1, accept_length_list 40 | 41 | 42 | if __name__ == "__main__": 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument( 45 | "--template", 46 | type=str, 47 | default="vicuna", 48 | choices=["vicuna", "llama3"] 49 | ) 50 | parser.add_argument( 51 | "--model-type", 52 | type=str, 53 | required=True, 54 | choices=["vicuna", "llama3"] 55 | ) 56 | parser.add_argument( 57 | "--model-path", 58 | type=str, 59 | required=True, 60 | ) 61 | parser.add_argument("--model-id", type=str, required=True) 62 | parser.add_argument( 63 | "--bench-name", 64 | type=str, 65 | default="mt_bench", 66 | help="The name of the benchmark question set.", 67 | ) 68 | parser.add_argument( 69 | "--question-begin", 70 | type=int, 71 | help="A debug option. The begin index of questions.", 72 | ) 73 | parser.add_argument( 74 | "--question-end", 75 | type=int, 76 | help="A debug option. The end index of questions." 77 | ) 78 | parser.add_argument("--answer-file", type=str, help="The output answer file.") 79 | parser.add_argument( 80 | "--max-new-tokens", 81 | type=int, 82 | default=1024, 83 | help="The maximum number of new generated tokens.", 84 | ) 85 | parser.add_argument( 86 | "--num-choices", 87 | type=int, 88 | default=1, 89 | help="How many completion choices to generate.", 90 | ) 91 | parser.add_argument( 92 | "--num-gpus-per-model", 93 | type=int, 94 | default=1, 95 | help="The number of GPUs per model.", 96 | ) 97 | parser.add_argument( 98 | "--num-gpus-total", type=int, default=1, help="The total number of GPUs." 99 | ) 100 | parser.add_argument( 101 | "--dtype", 102 | type=str, 103 | default="float16", 104 | choices=["float32", "float64", "float16", "bfloat16"], 105 | help="Override the default dtype. 
If not set, it will use float16 on GPU.", 106 | ) 107 | 108 | args = parser.parse_args() 109 | 110 | question_file = f"evaluation/data/{args.bench_name}/question.jsonl" 111 | 112 | if args.answer_file: 113 | answer_file = args.answer_file 114 | else: 115 | answer_file = f"evaluation/data/{args.bench_name}/model_answer/{args.model_id}.jsonl" 116 | 117 | print(f"Output to {answer_file}") 118 | 119 | 120 | model = AutoModelForCausalLM.from_pretrained( 121 | args.model_path, 122 | torch_dtype=str_to_torch_dtype(args.dtype), 123 | low_cpu_mem_usage=True, 124 | device_map="auto" 125 | ) 126 | 127 | tokenizer = AutoTokenizer.from_pretrained(args.model_path) 128 | 129 | model.greedy_search_pld = greedy_search_pld.__get__(model, type(model)) 130 | 131 | run_evals[args.template]( 132 | model=model, 133 | tokenizer=tokenizer, 134 | forward_func=pld_forward, 135 | model_id=args.model_id, 136 | question_file=question_file, 137 | question_begin=args.question_begin, 138 | question_end=args.question_end, 139 | answer_file=answer_file, 140 | max_new_tokens=args.max_new_tokens, 141 | num_choices=args.num_choices, 142 | num_gpus_per_model=args.num_gpus_per_model, 143 | num_gpus_total=args.num_gpus_total, 144 | ) 145 | 146 | reorg_answer_files[args.template](answer_file) 147 | -------------------------------------------------------------------------------- /evaluation/inference_sam_only.py.bak: -------------------------------------------------------------------------------- 1 | """Generate answers with local models. 2 | 3 | Usage: 4 | python3 gen_model_answer.py --model-path lmsys/fastchat-t5-3b-v1.0 --model-id fastchat-t5-3b-v1.0 5 | """ 6 | import argparse 7 | from fastchat.utils import str_to_torch_dtype 8 | from evaluation.eval import run_evals, reorg_answer_files 9 | from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer 10 | from evaluation.model.sam_only import SamdConfig, SamdModel, SamdGenerationConfig, DraftModel, load_sam 11 | 12 | def sam_only_forward( 13 | inputs, 14 | model: SamdModel, 15 | tokenizer: PreTrainedTokenizer, 16 | max_new_tokens: int, 17 | temperature: float = 0.0, 18 | do_sample: bool = False 19 | ): 20 | max_cache_len = model.lm.config.max_position_embeddings 21 | input_ids = inputs.input_ids 22 | outputs = model.generate( 23 | input_ids, 24 | generation_config=SamdGenerationConfig( 25 | max_new_tokens=max_new_tokens, 26 | max_cache_len=max_cache_len, 27 | ), 28 | ) 29 | output_ids = outputs.output_ids 30 | new_token = outputs.decode_tokens 31 | step = outputs.decode_steps 32 | accept_length_list = outputs.accepet_length_per_step 33 | return output_ids, new_token, step, accept_length_list 34 | 35 | 36 | if __name__ == "__main__": 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument( 39 | "--template", 40 | type=str, 41 | default="vicuna", 42 | choices=["vicuna", "llama3"] 43 | ) 44 | parser.add_argument( 45 | "--model-type", 46 | type=str, 47 | required=True, 48 | choices=["vicuna", "llama3"] 49 | ) 50 | parser.add_argument( 51 | "--model-path", 52 | type=str, 53 | required=True, 54 | ) 55 | parser.add_argument("--model-id", type=str, required=True) 56 | parser.add_argument( 57 | "--bench-name", 58 | type=str, 59 | default="mt_bench", 60 | help="The name of the benchmark question set.", 61 | ) 62 | parser.add_argument( 63 | "--question-begin", 64 | type=int, 65 | help="A debug option. The begin index of questions.", 66 | ) 67 | parser.add_argument( 68 | "--question-end", 69 | type=int, 70 | help="A debug option. The end index of questions." 
71 | ) 72 | parser.add_argument("--answer-file", type=str, help="The output answer file.") 73 | parser.add_argument( 74 | "--max-new-tokens", 75 | type=int, 76 | default=1024, 77 | help="The maximum number of new generated tokens.", 78 | ) 79 | parser.add_argument( 80 | "--num-choices", 81 | type=int, 82 | default=1, 83 | help="How many completion choices to generate.", 84 | ) 85 | parser.add_argument( 86 | "--num-gpus-per-model", 87 | type=int, 88 | default=1, 89 | help="The number of GPUs per model.", 90 | ) 91 | parser.add_argument( 92 | "--num-gpus-total", type=int, default=1, help="The total number of GPUs." 93 | ) 94 | parser.add_argument( 95 | "--temperature", 96 | type=float, 97 | default=0.0, 98 | help="The temperature for medusa sampling.", 99 | ) 100 | parser.add_argument( 101 | "--dtype", 102 | type=str, 103 | default="float16", 104 | choices=["float32", "float64", "float16", "bfloat16"], 105 | help="Override the default dtype. If not set, it will use float16 on GPU.", 106 | ) 107 | parser.add_argument( 108 | "--sam_path", 109 | type=str, 110 | default=None 111 | ) 112 | parser.add_argument( 113 | "--samd_max_predicts", 114 | type=int, 115 | default=15 116 | ) 117 | parser.add_argument( 118 | "--samd_len_bias", 119 | type=int, 120 | default=5 121 | ) 122 | args = parser.parse_args() 123 | 124 | question_file = f"evaluation/data/{args.bench_name}/question.jsonl" 125 | 126 | if args.answer_file: 127 | answer_file = args.answer_file 128 | else: 129 | answer_file = f"evaluation/data/{args.bench_name}/model_answer/{args.model_id}.jsonl" 130 | 131 | print(f"Output to {answer_file}") 132 | 133 | print("len_bias:", args.samd_len_bias) 134 | 135 | model = AutoModelForCausalLM.from_pretrained( 136 | args.model_path, 137 | torch_dtype=str_to_torch_dtype(args.dtype), 138 | low_cpu_mem_usage=True, 139 | device_map="cuda", 140 | # attn_implementation="eager", 141 | ) 142 | 143 | tokenizer = AutoTokenizer.from_pretrained(args.model_path) 144 | 145 | samd_config = SamdConfig( 146 | max_predicts=args.samd_max_predicts, 147 | len_bias=args.samd_len_bias, 148 | ) 149 | if args.sam_path is not None: 150 | sam = load_sam(args.sam_path) 151 | else: 152 | sam = None 153 | draft = DraftModel( 154 | samd_config, 155 | sam_dyn=None, 156 | sam_static=sam, 157 | lm=model, 158 | dtype=str_to_torch_dtype(args.dtype), 159 | device="cuda" 160 | ) 161 | samd_model = SamdModel( 162 | samd_config, 163 | model, 164 | draft, 165 | tokenizer.eos_token_id, 166 | str_to_torch_dtype(args.dtype), 167 | "cuda", 168 | ) 169 | 170 | if args.temperature > 0: 171 | do_sample = True 172 | else: 173 | do_sample = False 174 | 175 | run_evals[args.template]( 176 | model=samd_model, 177 | tokenizer=tokenizer, 178 | forward_func=sam_only_forward, 179 | model_id=args.model_id, 180 | question_file=question_file, 181 | question_begin=args.question_begin, 182 | question_end=args.question_end, 183 | answer_file=answer_file, 184 | max_new_tokens=args.max_new_tokens, 185 | num_choices=args.num_choices, 186 | num_gpus_per_model=args.num_gpus_per_model, 187 | num_gpus_total=args.num_gpus_total, 188 | temperature=args.temperature, 189 | do_sample=do_sample, 190 | ) 191 | 192 | reorg_answer_files[args.template](answer_file) 193 | -------------------------------------------------------------------------------- /evaluation/inference_token_recycle.py: -------------------------------------------------------------------------------- 1 | """Generate answers with local models. 
2 | 3 | Usage: 4 | python3 gen_model_answer.py --model-path lmsys/fastchat-t5-3b-v1.0 --model-id fastchat-t5-3b-v1.0 5 | """ 6 | import argparse 7 | from fastchat.utils import str_to_torch_dtype 8 | from evaluation.eval import run_evals, reorg_answer_files 9 | from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer 10 | from evaluation.model.token_recycle import ( 11 | TokenRecycleConfig, 12 | TokenRecycleModel, 13 | TokenRecycleGenerationConfig, 14 | DraftModel 15 | ) 16 | 17 | def token_recycle_forward( 18 | inputs, 19 | model: TokenRecycleModel, 20 | tokenizer: PreTrainedTokenizer, 21 | max_new_tokens: int, 22 | temperature: float = 0.0, 23 | do_sample: bool = False 24 | ): 25 | max_cache_len = model.lm.config.max_position_embeddings 26 | input_ids = inputs.input_ids 27 | outputs = model.generate( 28 | input_ids, 29 | generation_config=TokenRecycleGenerationConfig( 30 | max_new_tokens=max_new_tokens, 31 | max_cache_len=max_cache_len, 32 | temperature=temperature 33 | ), 34 | ) 35 | output_ids = outputs.output_ids 36 | new_token = outputs.decode_tokens 37 | step = outputs.decode_steps 38 | accept_length_list = outputs.accepet_length_per_step 39 | return output_ids, new_token, step, accept_length_list 40 | 41 | 42 | if __name__ == "__main__": 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument( 45 | "--template", 46 | type=str, 47 | default="vicuna", 48 | choices=["vicuna", "llama3"] 49 | ) 50 | parser.add_argument( 51 | "--model-type", 52 | type=str, 53 | required=True, 54 | choices=["vicuna", "llama3"] 55 | ) 56 | parser.add_argument( 57 | "--model-path", 58 | type=str, 59 | required=True, 60 | ) 61 | parser.add_argument("--model-id", type=str, required=True) 62 | parser.add_argument( 63 | "--bench-name", 64 | type=str, 65 | default="mt_bench", 66 | help="The name of the benchmark question set.", 67 | ) 68 | parser.add_argument( 69 | "--question-begin", 70 | type=int, 71 | help="A debug option. The begin index of questions.", 72 | ) 73 | parser.add_argument( 74 | "--question-end", 75 | type=int, 76 | help="A debug option. The end index of questions." 77 | ) 78 | parser.add_argument("--answer-file", type=str, help="The output answer file.") 79 | parser.add_argument( 80 | "--max-new-tokens", 81 | type=int, 82 | default=1024, 83 | help="The maximum number of new generated tokens.", 84 | ) 85 | parser.add_argument( 86 | "--num-choices", 87 | type=int, 88 | default=1, 89 | help="How many completion choices to generate.", 90 | ) 91 | parser.add_argument( 92 | "--num-gpus-per-model", 93 | type=int, 94 | default=1, 95 | help="The number of GPUs per model.", 96 | ) 97 | parser.add_argument( 98 | "--num-gpus-total", type=int, default=1, help="The total number of GPUs." 99 | ) 100 | parser.add_argument( 101 | "--temperature", 102 | type=float, 103 | default=0.0, 104 | help="The temperature for medusa sampling.", 105 | ) 106 | parser.add_argument( 107 | "--dtype", 108 | type=str, 109 | default="float16", 110 | choices=["float32", "float64", "float16", "bfloat16"], 111 | help="Override the default dtype. 
If not set, it will use float16 on GPU.", 112 | ) 113 | args = parser.parse_args() 114 | 115 | question_file = f"evaluation/data/{args.bench_name}/question.jsonl" 116 | 117 | if args.answer_file: 118 | answer_file = args.answer_file 119 | else: 120 | answer_file = f"evaluation/data/{args.bench_name}/model_answer/{args.model_id}.jsonl" 121 | 122 | print(f"Output to {answer_file}") 123 | 124 | if args.num_gpus_total == 1: 125 | device_map = "cuda" 126 | else: 127 | device_map = "auto" 128 | 129 | model = AutoModelForCausalLM.from_pretrained( 130 | args.model_path, 131 | torch_dtype=str_to_torch_dtype(args.dtype), 132 | low_cpu_mem_usage=True, 133 | device_map=device_map 134 | ) 135 | 136 | tokenizer = AutoTokenizer.from_pretrained(args.model_path) 137 | 138 | token_recycle_config = TokenRecycleConfig() 139 | draft = DraftModel(token_recycle_config) 140 | token_recycle_model = TokenRecycleModel( 141 | token_recycle_config, 142 | model, 143 | draft, 144 | tokenizer.eos_token_id, 145 | str_to_torch_dtype(args.dtype), 146 | "cuda", 147 | ) 148 | 149 | if args.temperature > 0: 150 | do_sample = True 151 | else: 152 | do_sample = False 153 | 154 | run_evals[args.template]( 155 | model=token_recycle_model, 156 | tokenizer=tokenizer, 157 | forward_func=token_recycle_forward, 158 | model_id=args.model_id, 159 | question_file=question_file, 160 | question_begin=args.question_begin, 161 | question_end=args.question_end, 162 | answer_file=answer_file, 163 | max_new_tokens=args.max_new_tokens, 164 | num_choices=args.num_choices, 165 | num_gpus_per_model=args.num_gpus_per_model, 166 | num_gpus_total=args.num_gpus_total, 167 | temperature=args.temperature, 168 | do_sample=do_sample, 169 | ) 170 | 171 | reorg_answer_files[args.template](answer_file) 172 | -------------------------------------------------------------------------------- /evaluation/model/eagle/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyx1999/SAM-Decoding/18c41f055b424fa3fa0bac41a8953d34cea1ed77/evaluation/model/eagle/__init__.py -------------------------------------------------------------------------------- /evaluation/model/eagle/choices.py: -------------------------------------------------------------------------------- 1 | mc_sim_7b_63 = [[0],[1],[2],[3],[0,0],[0,1],[0,2],[1,0],[1,1],[2,0],[2,1],[3,0] 2 | ,[0,0,0],[0,0,1],[0,0,2],[0,1,0],[0,1,1],[0,2,0],[0,2,1],[1,0,0], 3 | [0,0,0,0],[0,0,0,1],[0,0,0,2],[0,0,0,0,0],[0,0,0,0,1]] -------------------------------------------------------------------------------- /evaluation/model/eagle/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LlamaForCausalLM" 4 | ], 5 | "bos_token_id": 1, 6 | "eos_token_id": 2, 7 | "hidden_act": "silu", 8 | "hidden_size": 6656, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 17920, 11 | "max_sequence_length": 2048, 12 | "model_type": "llama", 13 | "num_attention_heads": 52, 14 | "num_key_value_heads": 13, 15 | "num_hidden_layers": 1, 16 | "pad_token_id": 0, 17 | "rms_norm_eps": 1e-06, 18 | "tie_word_embeddings": false, 19 | "torch_dtype": "float16", 20 | "transformers_version": "4.28.0.dev0", 21 | "use_cache": true, 22 | "vocab_size": 32000 23 | } 24 | -------------------------------------------------------------------------------- /evaluation/model/eagle2/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hyx1999/SAM-Decoding/18c41f055b424fa3fa0bac41a8953d34cea1ed77/evaluation/model/eagle2/__init__.py -------------------------------------------------------------------------------- /evaluation/model/eagle2/choices.py: -------------------------------------------------------------------------------- 1 | mc_sim_7b_63 = [[0],[1],[2],[3],[0,0],[0,1],[0,2],[1,0],[1,1],[2,0],[2,1],[3,0] 2 | ,[0,0,0],[0,0,1],[0,0,2],[0,1,0],[0,1,1],[0,2,0],[0,2,1],[1,0,0], 3 | [0,0,0,0],[0,0,0,1],[0,0,0,2],[0,0,0,0,0],[0,0,0,0,1]] 4 | -------------------------------------------------------------------------------- /evaluation/model/sam_only/__init__.py: -------------------------------------------------------------------------------- 1 | from .samd_config import SamdConfig 2 | from .samd_model import SamdModel 3 | from .utils import SamdGenerationConfig 4 | from .sam import build_sam, load_sam, dump_sam 5 | from .draft import DraftModel -------------------------------------------------------------------------------- /evaluation/model/sam_only/cache.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from transformers import PretrainedConfig 4 | from transformers.cache_utils import DynamicCache, Cache 5 | from typing import Optional, Dict, Any, Tuple 6 | from profile_utils import profile_decorator 7 | 8 | class SamdCache(DynamicCache): 9 | 10 | def __init__(self, num_hidden_layers: int | None = None) -> None: 11 | super().__init__(num_hidden_layers) 12 | self.cache_length = 0 13 | 14 | def set_length(self): 15 | self.cache_length = self.get_seq_length() 16 | 17 | @profile_decorator("SamdCache.select_indices") 18 | def select_indices(self, 19 | indices: torch.Tensor | None = None, 20 | accept_length: int = 1, 21 | ): 22 | start = self.cache_length 23 | if indices is not None: 24 | select_indices = start + indices 25 | else: 26 | select_indices = None 27 | for data in self.key_cache + self.value_cache: 28 | if select_indices is not None: 29 | select_indices = select_indices.to(data.device) 30 | tgt = data.index_select(-2, select_indices) 31 | dst = data.narrow(-2, start, accept_length) 32 | dst.copy_(tgt) 33 | self.cache_length += accept_length 34 | self.crop(self.cache_length) 35 | 36 | 37 | class SamdStaticCache(Cache): 38 | 39 | def __init__(self, 40 | num_hidden_layers: int, 41 | num_attention_heads: int, 42 | num_key_value_heads: int, 43 | hidden_size: int, 44 | max_batch_size: int, 45 | max_cache_len: int, 46 | device=None, 47 | dtype=None 48 | ) -> None: 49 | super().__init__() 50 | self.max_batch_size = max_batch_size 51 | self.max_cache_len = max_cache_len 52 | self.cur_length = 0 53 | self.new_length = 0 54 | self.kv_data = torch.zeros( 55 | num_hidden_layers * 2, 56 | max_batch_size, 57 | num_key_value_heads, 58 | max_cache_len, 59 | hidden_size // num_attention_heads, 60 | device=device, 61 | dtype=dtype 62 | ) 63 | self.devcie = device 64 | self.dtype = dtype 65 | 66 | def get_seq_length(self, layer_idx: int | None = 0) -> int: 67 | return self.cur_length 68 | 69 | def get_max_cache_shape(self) -> int: 70 | return self.max_cache_len 71 | 72 | def reorder_cache(self, beam_idx): 73 | raise NotImplementedError 74 | 75 | def reset(self): 76 | self.cur_length = 0 77 | self.new_length = 0 78 | 79 | def update( 80 | self, 81 | key_states: torch.Tensor, 82 | value_states: torch.Tensor, 83 | layer_idx: int, 84 | cache_kwargs: Optional[Dict[str, Any]] = None, 85 | ) -> Tuple[torch.Tensor, torch.Tensor]: 86 | 
start = self.cur_length 87 | length = key_states.shape[-2] 88 | self.new_length = start + length 89 | self.kv_data[2 * layer_idx + 0]\ 90 | .narrow(-2, start, length)\ 91 | .copy_(key_states) 92 | self.kv_data[2 * layer_idx + 1]\ 93 | .narrow(-2, start, length)\ 94 | .copy_(value_states) 95 | 96 | k_out = self.kv_data[2 * layer_idx + 0].narrow(-2, 0, start + length) 97 | v_out = self.kv_data[2 * layer_idx + 1].narrow(-2, 0, start + length) 98 | return k_out, v_out 99 | 100 | def select_indices(self, 101 | indices: torch.Tensor | None = None, 102 | accept_length: int = 1, 103 | ): 104 | start = self.cur_length 105 | if indices is not None: 106 | select_indices = start + indices 107 | else: 108 | select_indices = None 109 | if select_indices is not None: 110 | tgt = self.kv_data.index_select(-2, select_indices) 111 | dst = self.kv_data.narrow(-2, start, accept_length) 112 | dst.copy_(tgt) 113 | self.cur_length += accept_length 114 | 115 | def set_length(self): 116 | self.cur_length = self.new_length 117 | -------------------------------------------------------------------------------- /evaluation/model/sam_only/cache.py.bak: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from transformers.cache_utils import Cache, StaticCache 4 | from transformers.configuration_utils import PretrainedConfig 5 | from typing import List, Optional, Dict, Any, Tuple 6 | from enum import Enum 7 | 8 | 9 | class SamdCache(Cache): 10 | 11 | def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None: 12 | super().__init__() 13 | self.config = config 14 | self.max_batch_size = max_batch_size 15 | self.max_cache_len = max_cache_len 16 | self.cur_length = torch.tensor(0, dtype=torch.long, device=device) 17 | self.kv_data = torch.zeros( 18 | config.num_hidden_layers * 2, 19 | max_batch_size, 20 | config.num_key_value_heads, 21 | max_cache_len, 22 | config.hidden_size // config.num_attention_heads, 23 | device=device, 24 | dtype=dtype 25 | ) 26 | self.devcie = device 27 | self.dtype = dtype 28 | 29 | def get_seq_length(self, layer_idx: int | None = 0) -> int: 30 | return self.cur_length.item() 31 | 32 | def get_max_cache_shape(self) -> int: 33 | return self.max_cache_len 34 | 35 | def reorder_cache(self, beam_idx): 36 | raise NotImplementedError 37 | 38 | def reset(self): 39 | self.kv_data.fill_(0) 40 | self.cur_length.fill_(0) 41 | 42 | def update( 43 | self, 44 | key_states: torch.Tensor, 45 | value_states: torch.Tensor, 46 | layer_idx: int, 47 | cache_kwargs: Optional[Dict[str, Any]] = None, 48 | ) -> Tuple[torch.Tensor, torch.Tensor]: 49 | start = self.cur_length 50 | length = key_states.shape[-2] 51 | self.kv_data[2 * layer_idx + 0]\ 52 | .narrow(-2, start, length)\ 53 | .copy_(key_states) 54 | self.kv_data[2 * layer_idx + 1]\ 55 | .narrow(-2, start, length)\ 56 | .copy_(value_states) 57 | 58 | k_out = self.kv_data[2 * layer_idx + 0].narrow(-2, 0, start + length) 59 | v_out = self.kv_data[2 * layer_idx + 1].narrow(-2, 0, start + length) 60 | return k_out, v_out 61 | 62 | def post_update(self, indices: torch.Tensor): 63 | start = self.cur_length 64 | select_indices = start + indices 65 | accept_length = indices.shape[-1] 66 | tgt = self.kv_data.index_select(-2, select_indices) 67 | dst = self.kv_data.narrow(-2, start, accept_length) 68 | dst.copy_(tgt) 69 | self.cur_length += accept_length 70 | 71 | def set_cache_positions(self, length): 72 | self.cur_length.fill_(length) 73 | 74 | 75 | 
class SamdStaticCache(StaticCache): 76 | 77 | def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None: 78 | super().__init__(config, max_batch_size, max_cache_len, device, dtype) 79 | self.cur_length = torch.tensor(0, dtype=torch.long, device=device) 80 | self.devcie = device 81 | 82 | def get_seq_length(self, layer_idx: int | None = 0) -> int: 83 | return self.cur_length.item() 84 | 85 | def reset(self): 86 | super().reset() 87 | self.cur_length.fill_(0) 88 | 89 | def update( 90 | self, 91 | key_states: torch.Tensor, 92 | value_states: torch.Tensor, 93 | layer_idx: int, 94 | cache_kwargs: Optional[Dict[str, Any]] = None, 95 | ) -> Tuple[torch.Tensor, torch.Tensor]: 96 | start = self.cur_length 97 | self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device) 98 | self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device) 99 | k_out = self.key_cache[layer_idx] 100 | v_out = self.value_cache[layer_idx] 101 | 102 | key_length = value_length = key_states.shape[2] 103 | k_out[:, :, start:start + key_length] = key_states 104 | v_out[:, :, start:start + value_length] = value_states 105 | 106 | return k_out, v_out 107 | 108 | @torch.no_grad() 109 | def post_update(self, indices: torch.Tensor): 110 | start = self.cur_length 111 | select_positions = self.cur_length + indices 112 | accept_length = indices.shape[-1] 113 | for layer_idx in range(len(self.key_cache)): 114 | self.key_cache[layer_idx][:, :, start:start + accept_length] \ 115 | = self.key_cache[layer_idx][:, :, select_positions] 116 | self.value_cache[layer_idx][:, :, start:start + accept_length] \ 117 | = self.value_cache[layer_idx][:, :, select_positions] 118 | self.cur_length += accept_length 119 | 120 | @torch.no_grad() 121 | def set_cache_positions(self, length): 122 | self.cur_length.fill_(length) -------------------------------------------------------------------------------- /evaluation/model/sam_only/config/default_tree.json: -------------------------------------------------------------------------------- 1 | { 2 | "0": [ 3 | 1, 4 | 2, 5 | 3, 6 | 4 7 | ], 8 | "1": [ 9 | 5, 10 | 6, 11 | 7 12 | ], 13 | "2": [ 14 | 8, 15 | 9 16 | ], 17 | "3": [ 18 | 10, 19 | 11 20 | ], 21 | "4": [ 22 | 12 23 | ], 24 | "5": [ 25 | 13, 26 | 14, 27 | 15 28 | ], 29 | "6": [ 30 | 16, 31 | 17 32 | ], 33 | "7": [ 34 | 18, 35 | 19 36 | ], 37 | "8": [ 38 | 20 39 | ], 40 | "13": [ 41 | 21, 42 | 22, 43 | 23 44 | ], 45 | "21": [], 46 | "9": [], 47 | "10": [], 48 | "11": [], 49 | "12": [], 50 | "14": [], 51 | "15": [], 52 | "16": [], 53 | "17": [], 54 | "18": [], 55 | "19": [], 56 | "20": [], 57 | "22": [], 58 | "23": [] 59 | } 60 | -------------------------------------------------------------------------------- /evaluation/model/sam_only/config/default_tree.json.bak0: -------------------------------------------------------------------------------- 1 | { 2 | "0": [ 3 | 1, 4 | 2, 5 | 3, 6 | 4, 7 | 5, 8 | 6, 9 | 7, 10 | 8 11 | ], 12 | "1": [ 13 | 9, 14 | 10, 15 | 11, 16 | 12, 17 | 13, 18 | 14, 19 | 15, 20 | 16 21 | ], 22 | "2": [ 23 | 17, 24 | 18, 25 | 19, 26 | 20 27 | ], 28 | "3": [ 29 | 21, 30 | 22, 31 | 23 32 | ], 33 | "4": [ 34 | 24, 35 | 25 36 | ], 37 | "5": [ 38 | 26 39 | ], 40 | "6": [ 41 | 27 42 | ], 43 | "7": [ 44 | 28 45 | ], 46 | "8": [ 47 | 29 48 | ], 49 | "9": [ 50 | 30, 51 | 31, 52 | 32, 53 | 33, 54 | 34, 55 | 35, 56 | 36, 57 | 37 58 | ], 59 | "10": [ 60 | 38, 61 | 39, 62 | 40 63 | ], 64 | "11": [ 65 | 41, 66 | 42 67 | ], 68 | "12": [ 69 | 43 
70 | ], 71 | "13": [ 72 | 44 73 | ], 74 | "14": [ 75 | 45 76 | ], 77 | "15": [ 78 | 46 79 | ], 80 | "16": [ 81 | 47 82 | ], 83 | "17": [ 84 | 48, 85 | 49 86 | ], 87 | "18": [ 88 | 50 89 | ], 90 | "21": [ 91 | 51 92 | ], 93 | "24": [ 94 | 52 95 | ], 96 | "26": [ 97 | 53 98 | ], 99 | "27": [ 100 | 54 101 | ], 102 | "30": [ 103 | 55, 104 | 56, 105 | 57, 106 | 58, 107 | 59 108 | ], 109 | "31": [ 110 | 60, 111 | 61 112 | ], 113 | "32": [ 114 | 62 115 | ], 116 | "19": [], 117 | "20": [], 118 | "22": [], 119 | "23": [], 120 | "25": [], 121 | "28": [], 122 | "29": [], 123 | "33": [ 124 | 63 125 | ], 126 | "34": [], 127 | "35": [], 128 | "36": [], 129 | "37": [], 130 | "38": [], 131 | "39": [], 132 | "40": [], 133 | "41": [], 134 | "42": [], 135 | "43": [], 136 | "44": [], 137 | "45": [], 138 | "46": [], 139 | "47": [], 140 | "48": [], 141 | "49": [], 142 | "50": [], 143 | "51": [], 144 | "52": [], 145 | "53": [], 146 | "54": [], 147 | "55": [], 148 | "56": [], 149 | "57": [], 150 | "58": [], 151 | "59": [], 152 | "60": [], 153 | "61": [], 154 | "62": [], 155 | "63": [] 156 | } -------------------------------------------------------------------------------- /evaluation/model/sam_only/config/default_tree.json.bak1: -------------------------------------------------------------------------------- 1 | { 2 | "0": [ 3 | 1, 2, 3, 4 4 | ], 5 | "1": [ 6 | 5 7 | ], 8 | "2": [], 9 | "3": [], 10 | "4": [], 11 | "5": [ 12 | 6 13 | ], 14 | "6": [ 15 | 7 16 | ], 17 | "7": [] 18 | } 19 | -------------------------------------------------------------------------------- /evaluation/model/sam_only/config/default_tree.json.bak2: -------------------------------------------------------------------------------- 1 | { 2 | "0": [ 3 | 1 4 | ], 5 | "1": [ 6 | 2 7 | ], 8 | "2": [ 9 | 3 10 | ], 11 | "3": [ 12 | 4 13 | ], 14 | "4": [] 15 | } 16 | -------------------------------------------------------------------------------- /evaluation/model/sam_only/config/default_tree.json.bak3: -------------------------------------------------------------------------------- 1 | { 2 | "0": [ 3 | 1, 4 | 2, 5 | 3, 6 | 4 7 | ], 8 | "1": [ 9 | 5, 10 | 6, 11 | 7 12 | ], 13 | "2": [ 14 | 8, 15 | 9 16 | ], 17 | "3": [ 18 | 10, 19 | 11 20 | ], 21 | "4": [ 22 | 12 23 | ], 24 | "5": [ 25 | 13, 26 | 14, 27 | 15 28 | ], 29 | "6": [ 30 | 16, 31 | 17 32 | ], 33 | "7": [ 34 | 18, 35 | 19 36 | ], 37 | "8": [ 38 | 20 39 | ], 40 | "13": [ 41 | 21, 42 | 22, 43 | 23 44 | ], 45 | "21": [], 46 | "9": [], 47 | "10": [], 48 | "11": [], 49 | "12": [], 50 | "14": [], 51 | "15": [], 52 | "16": [], 53 | "17": [], 54 | "18": [], 55 | "19": [], 56 | "20": [], 57 | "22": [], 58 | "23": [] 59 | } 60 | -------------------------------------------------------------------------------- /evaluation/model/sam_only/config/default_tree.json.bak4: -------------------------------------------------------------------------------- 1 | { 2 | "0": [ 3 | 1, 4 | 2, 5 | 3, 6 | 4 7 | ], 8 | "1": [ 9 | 5, 10 | 6, 11 | 7 12 | ], 13 | "2": [ 14 | 8, 15 | 9 16 | ], 17 | "3": [ 18 | 10, 19 | 11 20 | ], 21 | "4": [ 22 | 12 23 | ], 24 | "5": [ 25 | 13, 26 | 14, 27 | 15 28 | ], 29 | "6": [ 30 | 16, 31 | 17 32 | ], 33 | "7": [ 34 | 18, 35 | 19 36 | ], 37 | "8": [ 38 | 20 39 | ], 40 | "13": [ 41 | 21, 42 | 22, 43 | 23 44 | ], 45 | "21": [ 46 | 24, 47 | 25 48 | ], 49 | "9": [], 50 | "10": [], 51 | "11": [], 52 | "12": [], 53 | "14": [], 54 | "15": [], 55 | "16": [], 56 | "17": [], 57 | "18": [], 58 | "19": [], 59 | "20": [], 60 | "22": [], 61 | "23": [], 62 | "24": [ 63 | 25 64 | ], 65 | "25": [ 66 | 26 
67 | ], 68 | "26": [ 69 | 27 70 | ], 71 | "27": [ 72 | 28 73 | ], 74 | "28": [] 75 | } 76 | -------------------------------------------------------------------------------- /evaluation/model/sam_only/config/default_tree.json.bak5: -------------------------------------------------------------------------------- 1 | { 2 | "0": [ 3 | 1, 4 | 2, 5 | 3, 6 | 4 7 | ], 8 | "1": [ 9 | 5, 10 | 6, 11 | 7 12 | ], 13 | "2": [ 14 | 8, 15 | 9 16 | ], 17 | "3": [ 18 | 10, 19 | 11 20 | ], 21 | "4": [ 22 | 12 23 | ], 24 | "5": [ 25 | 13, 26 | 14, 27 | 15 28 | ], 29 | "6": [ 30 | 16, 31 | 17 32 | ], 33 | "7": [ 34 | 18, 35 | 19 36 | ], 37 | "8": [ 38 | 20 39 | ], 40 | "13": [ 41 | 21, 42 | 22, 43 | 23 44 | ], 45 | "21": [ 46 | 24, 47 | 25 48 | ], 49 | "9": [], 50 | "10": [], 51 | "11": [], 52 | "12": [], 53 | "14": [], 54 | "15": [], 55 | "16": [], 56 | "17": [], 57 | "18": [], 58 | "19": [], 59 | "20": [], 60 | "22": [], 61 | "23": [], 62 | "24": [ 63 | 25 64 | ], 65 | "25": [ 66 | 26 67 | ], 68 | "26": [ 69 | 27 70 | ], 71 | "27": [ 72 | 28 73 | ], 74 | "28": [] 75 | } 76 | -------------------------------------------------------------------------------- /evaluation/model/sam_only/config/default_tree_1_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "tree_adj": { 3 | "0": [] 4 | } 5 | } 6 | -------------------------------------------------------------------------------- /evaluation/model/sam_only/config/default_tree_6_60.json: -------------------------------------------------------------------------------- 1 | { 2 | "tree_adj": { 3 | "0": [ 4 | 1, 5 | 2, 6 | 3, 7 | 4, 8 | 5, 9 | 6, 10 | 7 11 | ], 12 | "1": [ 13 | 8, 14 | 9, 15 | 10, 16 | 11, 17 | 12, 18 | 13 19 | ], 20 | "2": [ 21 | 14, 22 | 15, 23 | 16, 24 | 17, 25 | 18 26 | ], 27 | "3": [ 28 | 19, 29 | 20, 30 | 21 31 | ], 32 | "4": [ 33 | 22, 34 | 23 35 | ], 36 | "5": [ 37 | 24, 38 | 25 39 | ], 40 | "6": [ 41 | 26 42 | ], 43 | "7": [ 44 | 27 45 | ], 46 | "8": [ 47 | 28, 48 | 29, 49 | 30 50 | ], 51 | "13": [ 52 | ], 53 | "9": [ 54 | 31, 55 | 32 56 | ], 57 | "10": [ 58 | 33, 59 | 34 60 | ], 61 | "11": [ 62 | 35 63 | ], 64 | "12": [ 65 | 36 66 | ], 67 | "14": [ 68 | 37, 69 | 38, 70 | 39 71 | ], 72 | "15": [ 73 | 40, 74 | 41 75 | ], 76 | "16": [ 77 | 42 78 | ], 79 | "17": [ 80 | 43 81 | ], 82 | "18": [ 83 | ], 84 | "19": [ 85 | 44 86 | ], 87 | "20": [ 88 | 45 89 | ], 90 | "21": [ 91 | ], 92 | "22": [ 93 | 46 94 | ], 95 | "23": [ 96 | ], 97 | "24": [ 98 | 47 99 | ], 100 | "25": [ 101 | ], 102 | "26": [ 103 | 48 104 | ], 105 | "27": [ 106 | ], 107 | "28": [ 108 | 49, 109 | 50 110 | ], 111 | "29": [ 112 | 51 113 | ], 114 | "30": [ 115 | ], 116 | "31": [ 117 | 52 118 | ], 119 | "32": [], 120 | "33": [ 121 | ], 122 | "34": [ 123 | ], 124 | "35": [ 125 | ], 126 | "36": [], 127 | "37": [ 128 | 53, 129 | 54 130 | ], 131 | "38": [ 132 | ], 133 | "39": [], 134 | "40": [ 135 | 55 136 | ], 137 | "41": [], 138 | "42": [], 139 | "43": [ 140 | ], 141 | "44": [ 142 | 56 143 | ], 144 | "45": [], 145 | "46": [ 146 | ], 147 | "47": [], 148 | "48": [], 149 | "49": [ 150 | 57, 151 | 58 152 | ], 153 | "50": [], 154 | "51": [ 155 | 59 156 | ], 157 | "52": [ 158 | ], 159 | "53": [ 160 | 60 161 | ], 162 | "54": [ 163 | ], 164 | "55": [ 165 | ], 166 | "56": [], 167 | "57": [], 168 | "58": [], 169 | "59": [], 170 | "60": [] 171 | } 172 | } -------------------------------------------------------------------------------- /evaluation/model/sam_only/config/eagle.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "tree_choices": [ 3 | [ 4 | 0 5 | ], 6 | [ 7 | 1 8 | ], 9 | [ 10 | 2 11 | ], 12 | [ 13 | 3 14 | ], 15 | [ 16 | 0, 17 | 0 18 | ], 19 | [ 20 | 0, 21 | 1 22 | ], 23 | [ 24 | 0, 25 | 2 26 | ], 27 | [ 28 | 1, 29 | 0 30 | ], 31 | [ 32 | 1, 33 | 1 34 | ], 35 | [ 36 | 2, 37 | 0 38 | ], 39 | [ 40 | 2, 41 | 1 42 | ], 43 | [ 44 | 3, 45 | 0 46 | ], 47 | [ 48 | 0, 49 | 0, 50 | 0 51 | ], 52 | [ 53 | 0, 54 | 0, 55 | 1 56 | ], 57 | [ 58 | 0, 59 | 0, 60 | 2 61 | ], 62 | [ 63 | 0, 64 | 1, 65 | 0 66 | ], 67 | [ 68 | 0, 69 | 1, 70 | 1 71 | ], 72 | [ 73 | 0, 74 | 2, 75 | 0 76 | ], 77 | [ 78 | 0, 79 | 2, 80 | 1 81 | ], 82 | [ 83 | 1, 84 | 0, 85 | 0 86 | ], 87 | [ 88 | 0, 89 | 0, 90 | 0, 91 | 0 92 | ], 93 | [ 94 | 0, 95 | 0, 96 | 0, 97 | 1 98 | ], 99 | [ 100 | 0, 101 | 0, 102 | 0, 103 | 2 104 | ], 105 | [ 106 | 0, 107 | 0, 108 | 0, 109 | 0, 110 | 0 111 | ], 112 | [ 113 | 0, 114 | 0, 115 | 0, 116 | 0, 117 | 1 118 | ] 119 | ] 120 | } -------------------------------------------------------------------------------- /evaluation/model/sam_only/config/token_recycle.json: -------------------------------------------------------------------------------- 1 | { 2 | "tree_adj": { 3 | "0": [ 4 | 1, 5 | 2, 6 | 3, 7 | 4, 8 | 5, 9 | 6, 10 | 7 11 | ], 12 | "1": [ 13 | 8, 14 | 9, 15 | 10, 16 | 11, 17 | 12, 18 | 13 19 | ], 20 | "2": [ 21 | 14, 22 | 15, 23 | 16, 24 | 17, 25 | 18 26 | ], 27 | "3": [ 28 | 19, 29 | 20, 30 | 21 31 | ], 32 | "4": [ 33 | 22, 34 | 23 35 | ], 36 | "5": [ 37 | 24, 38 | 25 39 | ], 40 | "6": [ 41 | 26 42 | ], 43 | "7": [ 44 | 27 45 | ], 46 | "8": [ 47 | 28, 48 | 29, 49 | 30 50 | ], 51 | "13": [ 52 | ], 53 | "9": [ 54 | 31, 55 | 32 56 | ], 57 | "10": [ 58 | 33, 59 | 34 60 | ], 61 | "11": [ 62 | 35 63 | ], 64 | "12": [ 65 | 36 66 | ], 67 | "14": [ 68 | 37, 69 | 38, 70 | 39 71 | ], 72 | "15": [ 73 | 40, 74 | 41 75 | ], 76 | "16": [ 77 | 42 78 | ], 79 | "17": [ 80 | 43 81 | ], 82 | "18": [ 83 | ], 84 | "19": [ 85 | 44 86 | ], 87 | "20": [ 88 | 45 89 | ], 90 | "21": [ 91 | ], 92 | "22": [ 93 | 46 94 | ], 95 | "23": [ 96 | ], 97 | "24": [ 98 | 47 99 | ], 100 | "25": [ 101 | ], 102 | "26": [ 103 | 48 104 | ], 105 | "27": [ 106 | ], 107 | "28": [ 108 | 49, 109 | 50 110 | ], 111 | "29": [ 112 | 51 113 | ], 114 | "30": [ 115 | ], 116 | "31": [ 117 | 52 118 | ], 119 | "32": [], 120 | "33": [ 121 | ], 122 | "34": [ 123 | ], 124 | "35": [ 125 | ], 126 | "36": [], 127 | "37": [ 128 | 53, 129 | 54 130 | ], 131 | "38": [ 132 | ], 133 | "39": [], 134 | "40": [ 135 | 55 136 | ], 137 | "41": [], 138 | "42": [], 139 | "43": [ 140 | ], 141 | "44": [ 142 | 56 143 | ], 144 | "45": [], 145 | "46": [ 146 | ], 147 | "47": [], 148 | "48": [], 149 | "49": [ 150 | 57, 151 | 58 152 | ], 153 | "50": [], 154 | "51": [ 155 | 59 156 | ], 157 | "52": [ 158 | ], 159 | "53": [ 160 | 60 161 | ], 162 | "54": [ 163 | ], 164 | "55": [ 165 | ], 166 | "56": [], 167 | "57": [], 168 | "58": [], 169 | "59": [], 170 | "60": [] 171 | } 172 | } -------------------------------------------------------------------------------- /evaluation/model/sam_only/draft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Tuple, Dict, Optional 3 | from enum import Enum 4 | from collections import namedtuple 5 | 6 | from .samd_config import SamdConfig 7 | from .sam import DynSAM, StaticSAM 8 | from profile_utils import profile_decorator 9 | from transformers import LlamaConfig, LlamaForCausalLM 10 | 11 | # 
from transformers import LlamaTokenizer 12 | # tokenizer: LlamaTokenizer = LlamaTokenizer.from_pretrained('/data/models/vicuna-7b-v1.3') 13 | 14 | class CandidateType(str, Enum): 15 | sequence = "sequence" 16 | tree = "tree" 17 | 18 | Candidates = namedtuple('Candidates', ['type', 'tokens', 'candidate_tokens', 'buffers_kwargs']) 19 | 20 | TOPK = 8 21 | 22 | class DraftModel(torch.nn.Module): 23 | 24 | def __init__(self, 25 | config: SamdConfig, 26 | sam_dyn: DynSAM = None, 27 | sam_static: StaticSAM = None, 28 | lm: LlamaForCausalLM = None, 29 | dtype: torch.dtype = torch.float16, 30 | device: str = "cuda", 31 | ) -> None: 32 | super().__init__() 33 | self.config = config 34 | self.device = device 35 | self.sam_dyn = sam_dyn if sam_dyn is not None else DynSAM(config.max_predicts) 36 | self.sam_static = sam_static if sam_static is not None else StaticSAM(config.max_predicts) 37 | 38 | self.sam_dyn.max_predicts = config.max_predicts 39 | self.sam_static.max_predicts = config.max_predicts 40 | self.len_bias = config.len_bias 41 | 42 | @profile_decorator("DraftModel.reset") 43 | def reset(self): 44 | self.sam_dyn.reset() 45 | if self.sam_static is not None: 46 | self.sam_static.reset() 47 | 48 | @profile_decorator("DraftModel.lookup") 49 | def lookup(self, start_token: int): 50 | pred_dyn, match_dyn = self.sam_dyn.lookup(start_token) 51 | pred_static, match_static = self.sam_static.lookup(start_token) 52 | match_static -= self.len_bias 53 | if match_dyn >= match_static: 54 | pred, len = pred_dyn, match_dyn 55 | else: 56 | pred, len = pred_static, match_static 57 | pred = pred[:int(self.config.alpha * len)] 58 | position_ids = torch.arange(0, len(pred) + 1, dtype=torch.long, device=self.device).unsqueeze(0) 59 | return (CandidateType.sequence, [start_token] + pred, {"seq_position_ids": position_ids}) 60 | 61 | @profile_decorator("DraftModel.update") 62 | def update(self, 63 | tokens: Optional[torch.Tensor] = None, 64 | ): 65 | tokens_list = tokens.tolist() 66 | self.sam_dyn.add_tokens(tokens_list) 67 | self.sam_static.transfer_tokens(tokens_list) 68 | 69 | @profile_decorator("DraftModel.prefill_update") 70 | def prefill_update(self, 71 | tokens: Optional[torch.Tensor] = None, 72 | ): 73 | self.update(tokens) 74 | -------------------------------------------------------------------------------- /evaluation/model/sam_only/model_patch/__init__.py: -------------------------------------------------------------------------------- 1 | from .llama import llama_patch_dict, llama_attn_patch_dict 2 | 3 | patch_dict = {} 4 | attn_patch_dict = {} 5 | 6 | patch_dict.update(llama_patch_dict) 7 | attn_patch_dict.update(llama_attn_patch_dict) 8 | -------------------------------------------------------------------------------- /evaluation/model/sam_only/sam/__init__.py: -------------------------------------------------------------------------------- 1 | from .sam import DynSAM, StaticSAM 2 | from .utils import build_sam, dump_sam, load_sam 3 | -------------------------------------------------------------------------------- /evaluation/model/sam_only/sam/sam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Tuple, Dict 3 | from dataclasses import dataclass 4 | from copy import deepcopy 5 | from collections import deque 6 | from tqdm import tqdm 7 | 8 | class SAM: 9 | 10 | @dataclass 11 | class SAMState: 12 | next: dict[int, int] 13 | link: int 14 | length: int 15 | min_endpos: int 16 | 17 | def __init__(self, n_predicts: int = 40): 18 
| self.max_predicts = n_predicts 19 | self.states: List[SAM.SAMState] = [SAM.SAMState(next={}, link=-1, length=0, min_endpos=0)] 20 | self.input_ids: List[int] = [-1] 21 | self.last = 0 22 | self.max_length = 0 23 | 24 | # params needed to be reset for each query 25 | self.cur_index = 0 26 | self.cur_length = 0 27 | 28 | def reset(self): 29 | raise NotImplementedError 30 | 31 | def expand_state(self, state: SAMState): 32 | new_index = len(self.states) 33 | self.states.append(state) 34 | return new_index 35 | 36 | def add_state(self, token: int): 37 | self.max_length += 1 38 | cur = self.expand_state( 39 | SAM.SAMState( 40 | next={}, link=-1, 41 | length=self.max_length, 42 | min_endpos=self.max_length 43 | ) 44 | ) 45 | p = self.last 46 | while p != -1 and token not in self.states[p].next: 47 | self.states[p].next[token] = cur 48 | p = self.states[p].link 49 | if p == -1: 50 | self.states[cur].link = 0 51 | else: 52 | q = self.states[p].next[token] 53 | if self.states[p].length + 1 == self.states[q].length: 54 | self.states[cur].link = q 55 | else: 56 | clone = self.expand_state(deepcopy(self.states[q])) 57 | self.states[clone].length = self.states[p].length + 1 58 | while p != -1 and self.states[p].next[token] == q: 59 | self.states[p].next[token] = clone 60 | p = self.states[p].link 61 | self.states[q].link = self.states[cur].link = clone 62 | self.last = cur 63 | 64 | def transfer_state(self, index: int, length: int, token: int): 65 | while index != 0 and token not in self.states[index].next: 66 | index = self.states[index].link 67 | length = self.states[index].length 68 | if token in self.states[index].next: 69 | index = self.states[index].next[token] 70 | length += 1 71 | else: 72 | index = length = 0 73 | return index, length 74 | 75 | def transfer_cur_state(self, token: int): 76 | self.cur_index, self.cur_length = \ 77 | self.transfer_state(self.cur_index, self.cur_length, token) 78 | 79 | def to_anc(self, index: int, length: int): 80 | length_to_end = self.max_length - self.states[index].min_endpos 81 | while index != 0 and self.max_predicts > length_to_end: 82 | index = self.states[index].link 83 | length = self.states[index].length 84 | length_to_end = self.max_length - self.states[index].min_endpos 85 | return index, length 86 | 87 | def add_tokens(self, tokens: List[int]): 88 | for token in tokens: 89 | self.add_state(token) 90 | self.transfer_cur_state(token) 91 | self.input_ids.extend(tokens) 92 | self.cur_index, self.cur_length = \ 93 | self.to_anc(self.cur_index, self.cur_length) 94 | 95 | def transfer_tokens(self, tokens: List[int]): 96 | for token in tokens: 97 | self.transfer_cur_state(token) 98 | self.cur_index, self.cur_length = \ 99 | self.to_anc(self.cur_index, self.cur_length) 100 | 101 | def lookup(self, token: int): 102 | index, length = \ 103 | self.transfer_state(self.cur_index, self.cur_length, token) 104 | index, length = \ 105 | self.to_anc(index, length) 106 | endpos = self.states[index].min_endpos 107 | pred_ids = self.input_ids[endpos + 1:endpos + self.max_predicts + 1] 108 | if len(pred_ids) < self.max_predicts: 109 | pred_ids.extend([0] * (self.max_predicts - len(pred_ids))) 110 | return pred_ids, length 111 | 112 | 113 | class DynSAM(SAM): 114 | 115 | def reset(self): 116 | self.states: List[SAM.SAMState] = [SAM.SAMState(next={}, link=-1, length=0, min_endpos=0)] 117 | self.input_ids: List[int] = [-1] 118 | self.last = 0 119 | self.max_length = 0 120 | self.cur_index = 0 121 | self.cur_length = 0 122 | 123 | 124 | class StaticSAM(SAM): 125 | 126 | def 
reset(self): 127 | self.cur_index = 0 128 | self.cur_length = 0 129 | 130 | def add_batch_tokens(self, batch_tokens: List[List[int]], eos_token: int, verbose: bool): 131 | for tokens in tqdm(batch_tokens, desc="build sam...", disable=not verbose): 132 | self.add_tokens(tokens) 133 | if tokens[-1] != eos_token: 134 | self.add_tokens([eos_token]) 135 | 136 | @staticmethod 137 | def build( 138 | batch_tokens: List[List[int]], 139 | eos_token: int, 140 | n_predict: int, 141 | verbose: bool =True 142 | ): 143 | sam = StaticSAM(n_predict) 144 | sam.add_batch_tokens(batch_tokens, eos_token, verbose) 145 | return sam 146 | -------------------------------------------------------------------------------- /evaluation/model/sam_only/sam/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import pickle 3 | from datasets import Dataset 4 | from transformers import PreTrainedTokenizerFast 5 | from typing import List 6 | 7 | from .sam import StaticSAM 8 | from ..samd_config import SamdConfig 9 | 10 | def build_sam( 11 | config: SamdConfig, 12 | batch_tokens: List[List[int]], 13 | eos_token: int, 14 | ): 15 | sam = StaticSAM.build( 16 | batch_tokens, 17 | eos_token, 18 | config.max_predicts 19 | ) 20 | return sam 21 | 22 | def dump_sam(path: str, sam: StaticSAM): 23 | with open(path, "wb") as f: 24 | pickle.dump(sam, f) 25 | 26 | def load_sam(path: str): 27 | print("load sam...") 28 | start = time.perf_counter() 29 | with open(path, "rb") as f: 30 | _sam = pickle.load(f) 31 | sam = StaticSAM() 32 | for key, value in vars(_sam).items(): 33 | if hasattr(sam, key): 34 | setattr(sam, key, value) 35 | print("load [{}]".format(key)) 36 | end = time.perf_counter() 37 | assert type(sam) is StaticSAM 38 | print("loading ended in {} seconds.".format(end - start)) 39 | return sam 40 | -------------------------------------------------------------------------------- /evaluation/model/sam_only/samd_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | from dataclasses import dataclass, field 5 | from typing import Optional, Union, List, Literal, Dict, Any 6 | from enum import Enum 7 | 8 | 9 | @dataclass 10 | class SamdConfig: 11 | max_predicts: int = field(default=40) 12 | alpha: float = field(default=4.0) 13 | len_bias: int = field(default=5) 14 | 15 | 16 | class ForwardType(str, Enum): 17 | prefill = "prefill" 18 | seq_decode = "seq_decode" 19 | tree_decode = "tree_decode" 20 | 21 | 22 | class ForwardState: 23 | 24 | def __init__(self, forward_type: ForwardType | None) -> None: 25 | self.forward_type = forward_type 26 | 27 | 28 | class MaskState: 29 | 30 | def __init__(self, mask: Optional[torch.Tensor]) -> None: 31 | self.mask = mask 32 | 33 | def set_state(self, mask: Optional[torch.Tensor]) -> None: 34 | self.mask = mask 35 | 36 | 37 | def load_token_recycle(tree_path: Optional[str] = None): 38 | if tree_path is None: 39 | tree_path = "token_recycle.json" 40 | samd_path = os.path.dirname(__file__) 41 | with open(os.path.join(samd_path, "config", tree_path), "r") as f: 42 | tree_adj: dict = json.load(f)["tree_adj"] 43 | num_node = len(tree_adj) 44 | tree: List[List[int]] = [] 45 | for i in range(num_node): 46 | tree.append(tree_adj[str(i)]) 47 | print("tree_path:", tree_path) 48 | print("len_tree:", len(tree)) 49 | return tree 50 | 51 | 52 | def load_eagle(tree_model_path: str, tree_path: Optional[str] = None): 53 | if tree_path is None: 54 | tree_path = "eagle.json" 55 
| samd_path = os.path.dirname(__file__) 56 | with open(os.path.join(samd_path, "config", tree_path), "r") as f: 57 | tree = json.load(f)["tree_choices"] 58 | with open(os.path.join(tree_model_path, "config.json")) as f: 59 | tree_config = json.load(f) 60 | return tree, tree_config 61 | 62 | 63 | def load_eagle2(tree_model_path: str): 64 | with open(os.path.join(tree_model_path, "config.json")) as f: 65 | tree_config = json.load(f) 66 | return tree_config 67 | -------------------------------------------------------------------------------- /evaluation/model/sam_only/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from enum import Enum 4 | from typing import List, Dict, Optional, Callable 5 | from dataclasses import dataclass, field 6 | from collections import namedtuple 7 | 8 | from .samd_config import SamdConfig 9 | from .draft import DraftModel, Candidates, CandidateType 10 | 11 | from profile_utils import profile_decorator 12 | 13 | 14 | class SamplingMethods(str, Enum): 15 | typical = "typical" 16 | nucleus = "nucleus" 17 | 18 | 19 | class OptionalTensor: 20 | 21 | def __init__(self, data: Optional[torch.Tensor] = None): 22 | self.data = data 23 | 24 | def apply(self, fn: Callable) -> 'OptionalTensor': 25 | if self.data is None: 26 | return OptionalTensor(None) 27 | else: 28 | return OptionalTensor(fn(self.data)) 29 | 30 | @dataclass 31 | class SamdGenerationConfig: 32 | max_new_tokens: int = field(default=512) 33 | max_cache_len: int = field(default=2048) 34 | 35 | 36 | @profile_decorator("gen_candidates") 37 | def gen_candidates( 38 | logits: torch.Tensor, 39 | tree_retrieve_indices: torch.Tensor, 40 | draft: DraftModel, 41 | samd_config: SamdConfig, 42 | gen_config: SamdGenerationConfig, 43 | device: torch.device, 44 | ): 45 | """ 46 | Generate candidates based on provided logits and indices. 47 | 48 | Parameters: 49 | - ... 50 | 51 | Returns: 52 | - tuple (torch.Tensor, List[int]): ... 53 | """ 54 | # Greedy decoding: Select the most probable candidate from the original logits. 55 | start_token = torch.argmax(logits[:, -1]).item() 56 | candidate_type, tokens, buffers_kwargs = draft.lookup(start_token) 57 | tree_retrieve_indices = buffers_kwargs.get("tree_retrieve_indices", tree_retrieve_indices) 58 | tokens = torch.tensor([tokens], dtype=torch.long, device=device) 59 | candidate_tokens = tokens 60 | 61 | return Candidates( 62 | candidate_type, 63 | tokens, 64 | candidate_tokens, 65 | buffers_kwargs, 66 | ) 67 | 68 | 69 | @profile_decorator("eval_posterior") 70 | def eval_posterior( 71 | logits: torch.Tensor, 72 | candidates: torch.Tensor, 73 | config: SamdGenerationConfig, 74 | ): 75 | """ 76 | Evaluate the posterior probabilities of the candidates based on the provided logits and choose the best candidate. 77 | 78 | Depending on the temperature value, the function either uses greedy decoding or evaluates posterior 79 | probabilities to select the best candidate. 80 | 81 | Args: 82 | - logits (torch.Tensor): Predicted logits of shape (batch_size, sequence_length, vocab_size). 83 | - candidates (torch.Tensor): Candidate token sequences. 84 | 85 | Returns: 86 | - best_candidate (torch.Tensor): Index of the chosen best candidate. 87 | - accept_length (int): Length of the accepted candidate sequence. 
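    Note: only greedy verification is implemented here. Candidate tokens are compared position-by-position with the argmax of the corresponding logits; the number of consecutive matches gives the accepted length, and the returned value adds one for the first token, which is not re-verified.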
88 | """ 89 | # Greedy decoding based on temperature value 90 | # Find the tokens that match the maximum logits for each position in the sequence 91 | # posterior_mask = ( 92 | # candidates[:, 1:] == torch.argmax(logits[:, :-1], dim=-1) 93 | # ).int() 94 | # candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1) 95 | accept_length = ((~(candidates[:, 1:] == torch.argmax(logits[:, :-1], dim=-1))).cumsum(dim=-1) < 1).sum() 96 | best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device) 97 | return best_candidate, accept_length + 1 98 | -------------------------------------------------------------------------------- /evaluation/model/token_recycle/__init__.py: -------------------------------------------------------------------------------- 1 | from .token_recycle_config import TokenRecycleConfig 2 | from .token_recycle_model import TokenRecycleModel 3 | from .utils import TokenRecycleGenerationConfig 4 | from .draft import DraftModel -------------------------------------------------------------------------------- /evaluation/model/token_recycle/attn_patch/__init__.py: -------------------------------------------------------------------------------- 1 | from .llama import llama_patch_dict 2 | 3 | attn_patch_dict = {} 4 | 5 | attn_patch_dict.update(llama_patch_dict) 6 | -------------------------------------------------------------------------------- /evaluation/model/token_recycle/attn_patch/llama.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers.models.llama.modeling_llama import ( 3 | LlamaModel, 4 | Cache, 5 | StaticCache, 6 | AttentionMaskConverter 7 | ) 8 | from ..token_recycle_config import ForwardType 9 | 10 | try: 11 | from transformers.models.llama.modeling_llama import ( 12 | _prepare_4d_causal_attention_mask_with_cache_position 13 | ) 14 | except: 15 | _prepare_4d_causal_attention_mask_with_cache_position = LlamaModel._prepare_4d_causal_attention_mask_with_cache_position 16 | 17 | 18 | def _update_causal_mask( 19 | self, 20 | attention_mask: torch.Tensor, 21 | input_tensor: torch.Tensor, 22 | cache_position: torch.Tensor, 23 | past_key_values: Cache, 24 | output_attentions: bool, 25 | ): 26 | # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static 27 | # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. 28 | # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using 29 | # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 30 | 31 | if self.config._attn_implementation == "flash_attention_2": 32 | if attention_mask is not None and 0.0 in attention_mask: 33 | return attention_mask 34 | return None 35 | 36 | # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in 37 | # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail 38 | # to infer the attention mask. 
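    # Note: the mask construction below mirrors LlamaModel._update_causal_mask from transformers; the patch's addition is the tree_decode branch near the end, which overlays the draft-tree attention mask (self.tree_attn_mask) onto the causal mask at the current cache positions.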
39 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 40 | using_static_cache = isinstance(past_key_values, StaticCache) 41 | 42 | # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward 43 | if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: 44 | if AttentionMaskConverter._ignore_causal_mask_sdpa( 45 | attention_mask, 46 | inputs_embeds=input_tensor, 47 | past_key_values_length=past_seen_tokens, 48 | is_training=self.training, 49 | ): 50 | return None 51 | 52 | dtype, device = input_tensor.dtype, input_tensor.device 53 | min_dtype = torch.finfo(dtype).min 54 | sequence_length = input_tensor.shape[1] 55 | if using_static_cache: 56 | target_length = past_key_values.get_max_length() 57 | else: 58 | target_length = ( 59 | attention_mask.shape[-1] 60 | if isinstance(attention_mask, torch.Tensor) 61 | else past_seen_tokens + sequence_length + 1 62 | ) 63 | 64 | # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). 65 | causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( 66 | attention_mask, 67 | sequence_length=sequence_length, 68 | target_length=target_length, 69 | dtype=dtype, 70 | device=device, 71 | min_dtype=min_dtype, 72 | cache_position=cache_position, 73 | batch_size=input_tensor.shape[0], 74 | ) 75 | 76 | # assert hasattr(self, "samd_attn_mask") and hasattr(self, "forward_state") 77 | if self.forward_state.forward_type == ForwardType.tree_decode: 78 | samd_attn_mask: torch.Tensor = self.tree_attn_mask.to(causal_mask.device) 79 | causal_mask[:, :, :, cache_position] = causal_mask.min() * (samd_attn_mask == 0) 80 | # if self.forward_state.forward_type == ForwardType.seq_decode: 81 | # # do nothing for seq_decode 82 | 83 | if ( 84 | self.config._attn_implementation == "sdpa" 85 | and attention_mask is not None 86 | and attention_mask.device.type == "cuda" 87 | and not output_attentions 88 | ): 89 | # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when 90 | # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
91 | # Details: https://github.com/pytorch/pytorch/issues/110213 92 | causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) 93 | 94 | return causal_mask 95 | 96 | 97 | llama_patch_dict = { 98 | LlamaModel: [("_update_causal_mask", _update_causal_mask)] 99 | } 100 | -------------------------------------------------------------------------------- /evaluation/model/token_recycle/cache.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from transformers import PretrainedConfig 4 | from transformers.cache_utils import DynamicCache, Cache 5 | from typing import Optional, Dict, Any, Tuple 6 | from profile_utils import profile_decorator 7 | 8 | class SamdCache(DynamicCache): 9 | 10 | def __init__(self, num_hidden_layers: int | None = None) -> None: 11 | super().__init__(num_hidden_layers) 12 | self.cache_length = 0 13 | 14 | def set_length(self): 15 | self.cache_length = self.get_seq_length() 16 | 17 | @profile_decorator("SamdCache.select_indices") 18 | def select_indices(self, 19 | indices: torch.Tensor | None = None, 20 | accept_length: int = 1, 21 | ): 22 | start = self.cache_length 23 | if indices is not None: 24 | select_indices = start + indices 25 | else: 26 | select_indices = None 27 | for data in self.key_cache + self.value_cache: 28 | if select_indices is not None: 29 | select_indices = select_indices.to(data.device) 30 | tgt = data.index_select(-2, select_indices) 31 | dst = data.narrow(-2, start, accept_length) 32 | dst.copy_(tgt) 33 | self.cache_length += accept_length 34 | self.crop(self.cache_length) 35 | 36 | 37 | class SamdStaticCache(Cache): 38 | 39 | def __init__(self, 40 | num_hidden_layers: int, 41 | num_attention_heads: int, 42 | num_key_value_heads: int, 43 | hidden_size: int, 44 | max_batch_size: int, 45 | max_cache_len: int, 46 | device=None, 47 | dtype=None 48 | ) -> None: 49 | super().__init__() 50 | self.max_batch_size = max_batch_size 51 | self.max_cache_len = max_cache_len 52 | self.cur_length = 0 53 | self.new_length = 0 54 | self.kv_data = torch.zeros( 55 | num_hidden_layers * 2, 56 | max_batch_size, 57 | num_key_value_heads, 58 | max_cache_len, 59 | hidden_size // num_attention_heads, 60 | device=device, 61 | dtype=dtype 62 | ) 63 | self.device = device 64 | self.dtype = dtype 65 | 66 | def get_seq_length(self, layer_idx: int | None = 0) -> int: 67 | return self.cur_length 68 | 69 | def get_max_cache_shape(self) -> int: 70 | return self.max_cache_len 71 | 72 | def reorder_cache(self, beam_idx): 73 | raise NotImplementedError 74 | 75 | def reset(self): 76 | self.cur_length = 0 77 | self.new_length = 0 78 | 79 | def update( 80 | self, 81 | key_states: torch.Tensor, 82 | value_states: torch.Tensor, 83 | layer_idx: int, 84 | cache_kwargs: Optional[Dict[str, Any]] = None, 85 | ) -> Tuple[torch.Tensor, torch.Tensor]: 86 | start = self.cur_length 87 | length = key_states.shape[-2] 88 | self.new_length = start + length 89 | self.kv_data[2 * layer_idx + 0]\ 90 | .narrow(-2, start, length)\ 91 | .copy_(key_states) 92 | self.kv_data[2 * layer_idx + 1]\ 93 | .narrow(-2, start, length)\ 94 | .copy_(value_states) 95 | 96 | k_out = self.kv_data[2 * layer_idx + 0].narrow(-2, 0, start + length) 97 | v_out = self.kv_data[2 * layer_idx + 1].narrow(-2, 0, start + length) 98 | return k_out, v_out 99 | 100 | def select_indices(self, 101 | indices: torch.Tensor | None = None, 102 | accept_length: int = 1, 103 | ): 104 | start = self.cur_length 105 | if indices is not None: 106 |
select_indices = start + indices 107 | else: 108 | select_indices = None 109 | if select_indices is not None: 110 | tgt = self.kv_data.index_select(-2, select_indices) 111 | dst = self.kv_data.narrow(-2, start, accept_length) 112 | dst.copy_(tgt) 113 | self.cur_length += accept_length 114 | 115 | def set_length(self): 116 | self.cur_length = self.new_length 117 | -------------------------------------------------------------------------------- /evaluation/model/token_recycle/config/default_tree.json: -------------------------------------------------------------------------------- 1 | { 2 | "0": [ 3 | 1, 4 | 2, 5 | 3, 6 | 4 7 | ], 8 | "1": [ 9 | 5, 10 | 6, 11 | 7 12 | ], 13 | "2": [ 14 | 8, 15 | 9 16 | ], 17 | "3": [ 18 | 10, 19 | 11 20 | ], 21 | "4": [ 22 | 12 23 | ], 24 | "5": [ 25 | 13, 26 | 14, 27 | 15 28 | ], 29 | "6": [ 30 | 16, 31 | 17 32 | ], 33 | "7": [ 34 | 18, 35 | 19 36 | ], 37 | "8": [ 38 | 20 39 | ], 40 | "13": [ 41 | 21, 42 | 22, 43 | 23 44 | ], 45 | "21": [], 46 | "9": [], 47 | "10": [], 48 | "11": [], 49 | "12": [], 50 | "14": [], 51 | "15": [], 52 | "16": [], 53 | "17": [], 54 | "18": [], 55 | "19": [], 56 | "20": [], 57 | "22": [], 58 | "23": [] 59 | } 60 | -------------------------------------------------------------------------------- /evaluation/model/token_recycle/config/default_tree_80.json: -------------------------------------------------------------------------------- 1 | { 2 | "0": [ 3 | 1, 4 | 2, 5 | 3, 6 | 4, 7 | 5, 8 | 6, 9 | 7, 10 | 8 11 | ], 12 | "1": [ 13 | 9, 14 | 10, 15 | 11, 16 | 12, 17 | 13, 18 | 14, 19 | 15, 20 | 16 21 | ], 22 | "2": [ 23 | 17, 24 | 18, 25 | 19, 26 | 20 27 | ], 28 | "3": [ 29 | 21, 30 | 22, 31 | 23 32 | ], 33 | "4": [ 34 | 24, 35 | 25 36 | ], 37 | "5": [ 38 | 26 39 | ], 40 | "6": [ 41 | 27 42 | ], 43 | "7": [ 44 | 28 45 | ], 46 | "8": [ 47 | 29 48 | ], 49 | "9": [ 50 | 30, 51 | 31, 52 | 32, 53 | 33, 54 | 34, 55 | 35, 56 | 36, 57 | 37 58 | ], 59 | "10": [ 60 | 38, 61 | 39, 62 | 40 63 | ], 64 | "11": [ 65 | 41, 66 | 42 67 | ], 68 | "12": [ 69 | 43 70 | ], 71 | "13": [ 72 | 44 73 | ], 74 | "14": [ 75 | 45 76 | ], 77 | "15": [ 78 | 46 79 | ], 80 | "16": [ 81 | 47 82 | ], 83 | "17": [ 84 | 48, 85 | 49 86 | ], 87 | "18": [ 88 | 50 89 | ], 90 | "21": [ 91 | 51 92 | ], 93 | "24": [ 94 | 52 95 | ], 96 | "26": [ 97 | 53 98 | ], 99 | "27": [ 100 | 54 101 | ], 102 | "30": [ 103 | 55, 104 | 56, 105 | 57, 106 | 58, 107 | 59 108 | ], 109 | "31": [ 110 | 60, 111 | 61 112 | ], 113 | "32": [ 114 | 62 115 | ], 116 | "19": [], 117 | "20": [], 118 | "22": [], 119 | "23": [], 120 | "25": [], 121 | "28": [], 122 | "29": [], 123 | "33": [ 124 | 63 125 | ], 126 | "34": [ 127 | 64 128 | ], 129 | "35": [], 130 | "36": [], 131 | "37": [], 132 | "38": [ 133 | 65 134 | ], 135 | "39": [], 136 | "40": [], 137 | "41": [ 138 | 66 139 | ], 140 | "42": [], 141 | "43": [], 142 | "44": [], 143 | "45": [], 144 | "46": [], 145 | "47": [], 146 | "48": [ 147 | 67 148 | ], 149 | "49": [], 150 | "50": [], 151 | "51": [ 152 | 68 153 | ], 154 | "52": [ 155 | 69 156 | ], 157 | "53": [], 158 | "54": [], 159 | "55": [ 160 | 70, 161 | 71, 162 | 72 163 | ], 164 | "56": [ 165 | 73 166 | ], 167 | "57": [ 168 | 74 169 | ], 170 | "58": [], 171 | "59": [], 172 | "60": [ 173 | 75 174 | ], 175 | "61": [], 176 | "62": [], 177 | "63": [], 178 | "64": [], 179 | "65": [ 180 | 76 181 | ], 182 | "66": [], 183 | "67": [ 184 | 77 185 | ], 186 | "68": [], 187 | "69": [], 188 | "70": [ 189 | 78, 190 | 79 191 | ], 192 | "71": [], 193 | "72": [], 194 | "73": [ 195 | 80 196 | ], 197 | "74": [], 198 | "75": 
[], 199 | "76": [], 200 | "77": [], 201 | "78": [], 202 | "79": [], 203 | "80": [] 204 | } -------------------------------------------------------------------------------- /evaluation/model/token_recycle/draft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Tuple, Dict 3 | from enum import Enum 4 | from collections import namedtuple 5 | 6 | from .token_recycle_config import TokenRecycleConfig 7 | from .token_recycle import TokenRecycle 8 | from profile_utils import profile_decorator 9 | 10 | # from transformers import LlamaTokenizer 11 | # tokenizer: LlamaTokenizer = LlamaTokenizer.from_pretrained('/data/models/vicuna-7b-v1.3') 12 | 13 | class CandidateType(str, Enum): 14 | sequence = "sequence" 15 | tree = "tree" 16 | 17 | Candidates = namedtuple('Candidates', ['type', 'tokens', 'candidate_tokens']) 18 | 19 | class DraftModel: 20 | 21 | def __init__(self, 22 | config: TokenRecycleConfig, 23 | tree_model: TokenRecycle | None = None 24 | ) -> None: 25 | self.config = config 26 | self.tree_model = tree_model if tree_model is not None else TokenRecycle(config.tree) 27 | 28 | def reset(self): 29 | self.tree_model.reset() 30 | 31 | def lookup(self, start_token: int): 32 | return CandidateType.tree, self.tree_model.lookup(start_token) 33 | 34 | @profile_decorator("draft.update") 35 | def update(self, tokens: List[int], topk_nest: List[List[int]]): 36 | self.tree_model.update(tokens, topk_nest) 37 | -------------------------------------------------------------------------------- /evaluation/model/token_recycle/token_recycle.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Tuple, Dict 3 | from dataclasses import dataclass 4 | from copy import deepcopy 5 | from collections import deque 6 | from tqdm import tqdm 7 | 8 | from .token_recycle_config import TokenRecycleConfig 9 | 10 | class TokenRecycle: 11 | 12 | def __init__(self, 13 | tree: List[List[int]] 14 | ) -> None: 15 | self.tree = tree 16 | self.cache = {} 17 | 18 | def reset(self): 19 | pass # do nothting 20 | 21 | def update(self, tokens: List[int], topk_nest: List[List[int]]): 22 | for token, topk in zip(tokens, topk_nest): 23 | self.cache[token] = topk 24 | 25 | def lookup(self, start_token: int) -> List[int]: 26 | tree_tokens = [start_token] + [0] * (len(self.tree) - 1) 27 | for node_id, childs in enumerate(self.tree): 28 | token = tree_tokens[node_id] 29 | if token not in self.cache: 30 | continue 31 | topk = self.cache[token] 32 | for child_id, child in enumerate(childs): 33 | tree_tokens[child] = topk[child_id] 34 | return tree_tokens 35 | -------------------------------------------------------------------------------- /evaluation/model/token_recycle/token_recycle_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional, Union, List, Literal 3 | from enum import Enum 4 | 5 | @dataclass 6 | class TokenRecycleConfig: 7 | n_predicts: int = field(default=10) 8 | tree: Optional[List[List[int]]] = field(default=None) 9 | 10 | def __post_init__(self): 11 | if self.tree is None: 12 | self.tree = load_default_tree() 13 | 14 | 15 | class ForwardType(str, Enum): 16 | prefill = "prefill" 17 | seq_decode = "seq_decode" 18 | tree_decode = "tree_decode" 19 | 20 | class ForwardState: 21 | 22 | def __init__(self, forward_type: ForwardType | None) -> None: 23 | self.forward_type = 
forward_type 24 | 25 | 26 | def load_default_tree(): 27 | import os 28 | import json 29 | samd_path = os.path.dirname(__file__) 30 | with open(os.path.join(samd_path, "config", "default_tree_80.json"), "r") as f: 31 | tree_dict: dict = json.load(f) 32 | num_node = len(tree_dict) 33 | tree: List[List[int]] = [] 34 | for i in range(num_node): 35 | tree.append(tree_dict[str(i)]) 36 | return tree 37 | -------------------------------------------------------------------------------- /profile_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Tuple 2 | from collections import defaultdict 3 | from functools import wraps 4 | import time 5 | import pandas as pd 6 | 7 | fn_dict: Dict[str, List[float]] = defaultdict(lambda: list()) 8 | lookup_dict: Dict[str, List[str]] = defaultdict(lambda: list()) 9 | accept_dict: Dict[str, List[int]] = defaultdict(lambda: list()) 10 | 11 | decorator_flag: bool = False 12 | 13 | def enable_decorator(mode: bool): 14 | global decorator_flag 15 | decorator_flag = mode 16 | 17 | def clear_dict(): 18 | fn_dict.clear() 19 | lookup_dict.clear() 20 | 21 | def profile_decorator(fn_name: str): 22 | def decorator(fn): 23 | @wraps(fn) 24 | def wrapper(*args, **kwargs): 25 | if decorator_flag: 26 | start_time = time.perf_counter() 27 | result = fn(*args, **kwargs) 28 | end_time = time.perf_counter() 29 | fn_dict[fn_name].append(end_time - start_time) 30 | else: 31 | result = fn(*args, **kwargs) 32 | return result 33 | return wrapper 34 | return decorator 35 | 36 | 37 | def profile_lookup_decorator(fn_name: str): 38 | def decorator(fn): 39 | @wraps(fn) 40 | def wrapper(*args, **kwargs): 41 | if decorator_flag: 42 | result = fn(*args, **kwargs) 43 | lookup_dict[fn_name].append(result[0]) 44 | else: 45 | result = fn(*args, **kwargs) 46 | return result 47 | return wrapper 48 | return decorator 49 | 50 | def profile_accept_length(name: str, length: int): 51 | if decorator_flag: 52 | accept_dict[name].append(length) 53 | 54 | def export_result(root_name: str = "forward"): 55 | result = [] 56 | if len(fn_dict) == 0: 57 | return None 58 | for name, value in fn_dict.items(): 59 | print("name: {}, len(value): {}".format(name, len(value))) 60 | result.append(( 61 | name, sum(value) 62 | )) 63 | result_dict = dict(result) 64 | sum_time = result_dict.get(root_name, max(result_dict.values())) 65 | if sum_time is None: 66 | sum_time = max(result_dict.values()) 67 | df = pd.DataFrame(result, columns=["name", "time"]) 68 | df["ratio"] = df["time"] / sum_time 69 | return df.to_string() 70 | 71 | def export_lookup_result(): 72 | import json 73 | result1 = {} 74 | result2 = {} 75 | for name, type_names in lookup_dict.items(): 76 | result1[name] = {} 77 | result2[name] = {} 78 | for type_name in type_names: 79 | if type_name not in result1[name]: 80 | result1[name][type_name] = 0 81 | result1[name][type_name] += 1 82 | for type_name, length in zip(type_names, accept_dict[name]): 83 | if type_name not in result2[name]: 84 | result2[name][type_name] = 0 85 | result2[name][type_name] += length 86 | for key in result1[name].keys(): 87 | result2[name][key] /= result1[name][key] 88 | return json.dumps({"result-1": result1, "result-2": result2}, indent=4, ensure_ascii=False) 89 | -------------------------------------------------------------------------------- /sam_data/list.txt: -------------------------------------------------------------------------------- 1 | alpaca-cleand 2 | python_code_instructions_18k_alpaca 3 | gsm8k 4 | 
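A minimal usage sketch for the profiling helpers defined in profile_utils.py above (illustrative only; the timer name "demo.step" and the step() function are made up for this example and are not part of the repository):

import time
from profile_utils import enable_decorator, profile_decorator, export_result

@profile_decorator("demo.step")
def step():
    time.sleep(0.01)  # stand-in for real work; elapsed time is appended to fn_dict["demo.step"]

enable_decorator(True)  # timing is a no-op until the global flag is switched on
for _ in range(10):
    step()
print(export_result(root_name="demo.step"))  # per-name total time and ratio, rendered as a table string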
-------------------------------------------------------------------------------- /samd/__init__.py: -------------------------------------------------------------------------------- 1 | from .samd_config import SamdConfig 2 | from .samd_model import SamdModel 3 | from .utils import SamdGenerationConfig 4 | from .sam import build_sam, load_sam, dump_sam 5 | from .draft import DraftModel -------------------------------------------------------------------------------- /samd/cache.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from transformers import PretrainedConfig, LlamaForCausalLM 4 | from transformers.cache_utils import DynamicCache, StaticCache, Cache 5 | from typing import Optional, Dict, Any, Tuple, List 6 | from profile_utils import profile_decorator 7 | 8 | class SamdCache(DynamicCache): 9 | 10 | def __init__(self, num_hidden_layers: int | None = None) -> None: 11 | super().__init__(num_hidden_layers) 12 | self.cache_length = 0 13 | 14 | def set_length(self): 15 | self.cache_length = self.get_seq_length() 16 | 17 | # @profile_decorator("SamdCache.select_indices") 18 | def select_indices(self, 19 | indices: torch.Tensor | None = None, 20 | accept_length: int = 1, 21 | ): 22 | start = self.cache_length 23 | if indices is not None: 24 | select_indices = start + indices 25 | else: 26 | select_indices = None 27 | for data in self.key_cache + self.value_cache: 28 | if select_indices is not None: 29 | select_indices = select_indices.to(data.device) 30 | tgt = data.index_select(-2, select_indices) 31 | dst = data.narrow(-2, start, accept_length) 32 | dst.copy_(tgt) 33 | self.cache_length += accept_length 34 | self.crop(self.cache_length) 35 | 36 | 37 | class SamdStaticCache(Cache): 38 | 39 | def __init__(self, 40 | config, 41 | batch_size = None, 42 | max_cache_len = None, 43 | device = None, 44 | dtype = torch.float32, 45 | max_batch_size = None, 46 | hf_device_map = None, 47 | ): 48 | super().__init__() 49 | if len(hf_device_map) <= 1: 50 | device = device 51 | layer_device_map = None 52 | else: 53 | device = None 54 | layer_device_map = {} 55 | for i in range(config.num_hidden_layers): 56 | layer_device_map[i] = hf_device_map[f"model.layers.{i}"] 57 | self.batch_size = batch_size or max_batch_size 58 | self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len 59 | 60 | # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads 61 | self.head_dim = ( 62 | config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads 63 | ) 64 | 65 | self.dtype = dtype 66 | self.num_key_value_heads = ( 67 | config.num_attention_heads 68 | if getattr(config, "num_key_value_heads", None) is None 69 | else config.num_key_value_heads 70 | ) 71 | 72 | self.key_cache: List[torch.Tensor] = [] 73 | self.value_cache: List[torch.Tensor] = [] 74 | # Note: There will be significant perf decrease if switching to use 5D tensors instead. 
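        # Each layer gets its own pre-allocated (batch, num_kv_heads, max_cache_len, head_dim) key/value buffer; with a multi-device hf_device_map, each layer's buffers are created directly on that layer's device.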
75 | cache_shape = (self.batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) 76 | for idx in range(config.num_hidden_layers): 77 | if layer_device_map is not None: 78 | layer_device = layer_device_map[idx] 79 | else: 80 | layer_device = device 81 | new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device) 82 | new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device) 83 | self.key_cache.append(new_layer_key_cache) 84 | self.value_cache.append(new_layer_value_cache) 85 | 86 | self.last_length = 0 87 | self.cache_length = 0 88 | 89 | def reset(self): 90 | self.cache_length = 0 91 | self.last_length = 0 92 | 93 | def set_length(self): 94 | self.cache_length = self.last_length 95 | 96 | def get_seq_length(self, layer_idx = 0): 97 | return self.cache_length 98 | 99 | def get_max_cache_shape(self) -> Optional[int]: 100 | """Returns the maximum sequence length of the cache object. DynamicCache does not have a maximum length.""" 101 | return self.max_cache_len 102 | 103 | def update(self, key_states, value_states, layer_idx, cache_kwargs = None): 104 | k_out = self.key_cache[layer_idx] 105 | v_out = self.value_cache[layer_idx] 106 | k_dst = k_out.narrow(2, self.cache_length, key_states.shape[2]) 107 | v_dst = v_out.narrow(2, self.cache_length, value_states.shape[2]) 108 | k_dst.copy_(key_states) 109 | v_dst.copy_(value_states) 110 | if layer_idx == 0: 111 | self.last_length = self.cache_length + key_states.shape[2] 112 | return ( 113 | k_out.narrow(2, 0, self.last_length), 114 | v_out.narrow(2, 0, self.last_length), 115 | ) 116 | 117 | # @profile_decorator("SamdCache.select_indices") 118 | def select_indices(self, 119 | indices: torch.Tensor | None = None, 120 | accept_length: int = 1, 121 | ): 122 | start = self.cache_length 123 | if indices is not None: 124 | select_indices = start + indices 125 | else: 126 | select_indices = None 127 | for data in self.key_cache + self.value_cache: 128 | if select_indices is not None: 129 | select_indices = select_indices.to(data.device) 130 | tgt = data.index_select(-2, select_indices) 131 | dst = data.narrow(-2, start, accept_length) 132 | dst.copy_(tgt, non_blocking=True) 133 | self.cache_length += accept_length 134 | -------------------------------------------------------------------------------- /samd/cache.py.bak: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from transformers.cache_utils import Cache, StaticCache 4 | from transformers.configuration_utils import PretrainedConfig 5 | from typing import List, Optional, Dict, Any, Tuple 6 | from enum import Enum 7 | 8 | 9 | class SamdCache(Cache): 10 | 11 | def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None: 12 | super().__init__() 13 | self.config = config 14 | self.max_batch_size = max_batch_size 15 | self.max_cache_len = max_cache_len 16 | self.cur_length = torch.tensor(0, dtype=torch.long, device=device) 17 | self.kv_data = torch.zeros( 18 | config.num_hidden_layers * 2, 19 | max_batch_size, 20 | config.num_key_value_heads, 21 | max_cache_len, 22 | config.hidden_size // config.num_attention_heads, 23 | device=device, 24 | dtype=dtype 25 | ) 26 | self.devcie = device 27 | self.dtype = dtype 28 | 29 | def get_seq_length(self, layer_idx: int | None = 0) -> int: 30 | return self.cur_length.item() 31 | 32 | def get_max_cache_shape(self) -> int: 33 | return self.max_cache_len 34 | 35 | def 
reorder_cache(self, beam_idx): 36 | raise NotImplementedError 37 | 38 | def reset(self): 39 | self.kv_data.fill_(0) 40 | self.cur_length.fill_(0) 41 | 42 | def update( 43 | self, 44 | key_states: torch.Tensor, 45 | value_states: torch.Tensor, 46 | layer_idx: int, 47 | cache_kwargs: Optional[Dict[str, Any]] = None, 48 | ) -> Tuple[torch.Tensor, torch.Tensor]: 49 | start = self.cur_length 50 | length = key_states.shape[-2] 51 | self.kv_data[2 * layer_idx + 0]\ 52 | .narrow(-2, start, length)\ 53 | .copy_(key_states) 54 | self.kv_data[2 * layer_idx + 1]\ 55 | .narrow(-2, start, length)\ 56 | .copy_(value_states) 57 | 58 | k_out = self.kv_data[2 * layer_idx + 0].narrow(-2, 0, start + length) 59 | v_out = self.kv_data[2 * layer_idx + 1].narrow(-2, 0, start + length) 60 | return k_out, v_out 61 | 62 | def post_update(self, indices: torch.Tensor): 63 | start = self.cur_length 64 | select_indices = start + indices 65 | accept_length = indices.shape[-1] 66 | tgt = self.kv_data.index_select(-2, select_indices) 67 | dst = self.kv_data.narrow(-2, start, accept_length) 68 | dst.copy_(tgt) 69 | self.cur_length += accept_length 70 | 71 | def set_cache_positions(self, length): 72 | self.cur_length.fill_(length) 73 | 74 | 75 | class SamdStaticCache(StaticCache): 76 | 77 | def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None: 78 | super().__init__(config, max_batch_size, max_cache_len, device, dtype) 79 | self.cur_length = torch.tensor(0, dtype=torch.long, device=device) 80 | self.devcie = device 81 | 82 | def get_seq_length(self, layer_idx: int | None = 0) -> int: 83 | return self.cur_length.item() 84 | 85 | def reset(self): 86 | super().reset() 87 | self.cur_length.fill_(0) 88 | 89 | def update( 90 | self, 91 | key_states: torch.Tensor, 92 | value_states: torch.Tensor, 93 | layer_idx: int, 94 | cache_kwargs: Optional[Dict[str, Any]] = None, 95 | ) -> Tuple[torch.Tensor, torch.Tensor]: 96 | start = self.cur_length 97 | self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device) 98 | self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device) 99 | k_out = self.key_cache[layer_idx] 100 | v_out = self.value_cache[layer_idx] 101 | 102 | key_length = value_length = key_states.shape[2] 103 | k_out[:, :, start:start + key_length] = key_states 104 | v_out[:, :, start:start + value_length] = value_states 105 | 106 | return k_out, v_out 107 | 108 | @torch.no_grad() 109 | def post_update(self, indices: torch.Tensor): 110 | start = self.cur_length 111 | select_positions = self.cur_length + indices 112 | accept_length = indices.shape[-1] 113 | for layer_idx in range(len(self.key_cache)): 114 | self.key_cache[layer_idx][:, :, start:start + accept_length] \ 115 | = self.key_cache[layer_idx][:, :, select_positions] 116 | self.value_cache[layer_idx][:, :, start:start + accept_length] \ 117 | = self.value_cache[layer_idx][:, :, select_positions] 118 | self.cur_length += accept_length 119 | 120 | @torch.no_grad() 121 | def set_cache_positions(self, length): 122 | self.cur_length.fill_(length) -------------------------------------------------------------------------------- /samd/config/default_tree.json: -------------------------------------------------------------------------------- 1 | { 2 | "0": [ 3 | 1, 4 | 2, 5 | 3, 6 | 4 7 | ], 8 | "1": [ 9 | 5, 10 | 6, 11 | 7 12 | ], 13 | "2": [ 14 | 8, 15 | 9 16 | ], 17 | "3": [ 18 | 10, 19 | 11 20 | ], 21 | "4": [ 22 | 12 23 | ], 24 | "5": [ 25 | 13, 26 | 14, 27 | 15 28 | 
], 29 | "6": [ 30 | 16, 31 | 17 32 | ], 33 | "7": [ 34 | 18, 35 | 19 36 | ], 37 | "8": [ 38 | 20 39 | ], 40 | "13": [ 41 | 21, 42 | 22, 43 | 23 44 | ], 45 | "21": [], 46 | "9": [], 47 | "10": [], 48 | "11": [], 49 | "12": [], 50 | "14": [], 51 | "15": [], 52 | "16": [], 53 | "17": [], 54 | "18": [], 55 | "19": [], 56 | "20": [], 57 | "22": [], 58 | "23": [] 59 | } 60 | -------------------------------------------------------------------------------- /samd/config/default_tree.json.bak0: -------------------------------------------------------------------------------- 1 | { 2 | "0": [ 3 | 1, 4 | 2, 5 | 3, 6 | 4, 7 | 5, 8 | 6, 9 | 7, 10 | 8 11 | ], 12 | "1": [ 13 | 9, 14 | 10, 15 | 11, 16 | 12, 17 | 13, 18 | 14, 19 | 15, 20 | 16 21 | ], 22 | "2": [ 23 | 17, 24 | 18, 25 | 19, 26 | 20 27 | ], 28 | "3": [ 29 | 21, 30 | 22, 31 | 23 32 | ], 33 | "4": [ 34 | 24, 35 | 25 36 | ], 37 | "5": [ 38 | 26 39 | ], 40 | "6": [ 41 | 27 42 | ], 43 | "7": [ 44 | 28 45 | ], 46 | "8": [ 47 | 29 48 | ], 49 | "9": [ 50 | 30, 51 | 31, 52 | 32, 53 | 33, 54 | 34, 55 | 35, 56 | 36, 57 | 37 58 | ], 59 | "10": [ 60 | 38, 61 | 39, 62 | 40 63 | ], 64 | "11": [ 65 | 41, 66 | 42 67 | ], 68 | "12": [ 69 | 43 70 | ], 71 | "13": [ 72 | 44 73 | ], 74 | "14": [ 75 | 45 76 | ], 77 | "15": [ 78 | 46 79 | ], 80 | "16": [ 81 | 47 82 | ], 83 | "17": [ 84 | 48, 85 | 49 86 | ], 87 | "18": [ 88 | 50 89 | ], 90 | "21": [ 91 | 51 92 | ], 93 | "24": [ 94 | 52 95 | ], 96 | "26": [ 97 | 53 98 | ], 99 | "27": [ 100 | 54 101 | ], 102 | "30": [ 103 | 55, 104 | 56, 105 | 57, 106 | 58, 107 | 59 108 | ], 109 | "31": [ 110 | 60, 111 | 61 112 | ], 113 | "32": [ 114 | 62 115 | ], 116 | "19": [], 117 | "20": [], 118 | "22": [], 119 | "23": [], 120 | "25": [], 121 | "28": [], 122 | "29": [], 123 | "33": [ 124 | 63 125 | ], 126 | "34": [], 127 | "35": [], 128 | "36": [], 129 | "37": [], 130 | "38": [], 131 | "39": [], 132 | "40": [], 133 | "41": [], 134 | "42": [], 135 | "43": [], 136 | "44": [], 137 | "45": [], 138 | "46": [], 139 | "47": [], 140 | "48": [], 141 | "49": [], 142 | "50": [], 143 | "51": [], 144 | "52": [], 145 | "53": [], 146 | "54": [], 147 | "55": [], 148 | "56": [], 149 | "57": [], 150 | "58": [], 151 | "59": [], 152 | "60": [], 153 | "61": [], 154 | "62": [], 155 | "63": [] 156 | } -------------------------------------------------------------------------------- /samd/config/default_tree.json.bak1: -------------------------------------------------------------------------------- 1 | { 2 | "0": [ 3 | 1, 2, 3, 4 4 | ], 5 | "1": [ 6 | 5 7 | ], 8 | "2": [], 9 | "3": [], 10 | "4": [], 11 | "5": [ 12 | 6 13 | ], 14 | "6": [ 15 | 7 16 | ], 17 | "7": [] 18 | } 19 | -------------------------------------------------------------------------------- /samd/config/default_tree.json.bak2: -------------------------------------------------------------------------------- 1 | { 2 | "0": [ 3 | 1 4 | ], 5 | "1": [ 6 | 2 7 | ], 8 | "2": [ 9 | 3 10 | ], 11 | "3": [ 12 | 4 13 | ], 14 | "4": [] 15 | } 16 | -------------------------------------------------------------------------------- /samd/config/default_tree.json.bak3: -------------------------------------------------------------------------------- 1 | { 2 | "0": [ 3 | 1, 4 | 2, 5 | 3, 6 | 4 7 | ], 8 | "1": [ 9 | 5, 10 | 6, 11 | 7 12 | ], 13 | "2": [ 14 | 8, 15 | 9 16 | ], 17 | "3": [ 18 | 10, 19 | 11 20 | ], 21 | "4": [ 22 | 12 23 | ], 24 | "5": [ 25 | 13, 26 | 14, 27 | 15 28 | ], 29 | "6": [ 30 | 16, 31 | 17 32 | ], 33 | "7": [ 34 | 18, 35 | 19 36 | ], 37 | "8": [ 38 | 20 39 | ], 40 | "13": [ 41 | 21, 42 | 
22, 43 | 23 44 | ], 45 | "21": [], 46 | "9": [], 47 | "10": [], 48 | "11": [], 49 | "12": [], 50 | "14": [], 51 | "15": [], 52 | "16": [], 53 | "17": [], 54 | "18": [], 55 | "19": [], 56 | "20": [], 57 | "22": [], 58 | "23": [] 59 | } 60 | -------------------------------------------------------------------------------- /samd/config/default_tree.json.bak4: -------------------------------------------------------------------------------- 1 | { 2 | "0": [ 3 | 1, 4 | 2, 5 | 3, 6 | 4 7 | ], 8 | "1": [ 9 | 5, 10 | 6, 11 | 7 12 | ], 13 | "2": [ 14 | 8, 15 | 9 16 | ], 17 | "3": [ 18 | 10, 19 | 11 20 | ], 21 | "4": [ 22 | 12 23 | ], 24 | "5": [ 25 | 13, 26 | 14, 27 | 15 28 | ], 29 | "6": [ 30 | 16, 31 | 17 32 | ], 33 | "7": [ 34 | 18, 35 | 19 36 | ], 37 | "8": [ 38 | 20 39 | ], 40 | "13": [ 41 | 21, 42 | 22, 43 | 23 44 | ], 45 | "21": [ 46 | 24, 47 | 25 48 | ], 49 | "9": [], 50 | "10": [], 51 | "11": [], 52 | "12": [], 53 | "14": [], 54 | "15": [], 55 | "16": [], 56 | "17": [], 57 | "18": [], 58 | "19": [], 59 | "20": [], 60 | "22": [], 61 | "23": [], 62 | "24": [ 63 | 25 64 | ], 65 | "25": [ 66 | 26 67 | ], 68 | "26": [ 69 | 27 70 | ], 71 | "27": [ 72 | 28 73 | ], 74 | "28": [] 75 | } 76 | -------------------------------------------------------------------------------- /samd/config/default_tree.json.bak5: -------------------------------------------------------------------------------- 1 | { 2 | "0": [ 3 | 1, 4 | 2, 5 | 3, 6 | 4 7 | ], 8 | "1": [ 9 | 5, 10 | 6, 11 | 7 12 | ], 13 | "2": [ 14 | 8, 15 | 9 16 | ], 17 | "3": [ 18 | 10, 19 | 11 20 | ], 21 | "4": [ 22 | 12 23 | ], 24 | "5": [ 25 | 13, 26 | 14, 27 | 15 28 | ], 29 | "6": [ 30 | 16, 31 | 17 32 | ], 33 | "7": [ 34 | 18, 35 | 19 36 | ], 37 | "8": [ 38 | 20 39 | ], 40 | "13": [ 41 | 21, 42 | 22, 43 | 23 44 | ], 45 | "21": [ 46 | 24, 47 | 25 48 | ], 49 | "9": [], 50 | "10": [], 51 | "11": [], 52 | "12": [], 53 | "14": [], 54 | "15": [], 55 | "16": [], 56 | "17": [], 57 | "18": [], 58 | "19": [], 59 | "20": [], 60 | "22": [], 61 | "23": [], 62 | "24": [ 63 | 25 64 | ], 65 | "25": [ 66 | 26 67 | ], 68 | "26": [ 69 | 27 70 | ], 71 | "27": [ 72 | 28 73 | ], 74 | "28": [] 75 | } 76 | -------------------------------------------------------------------------------- /samd/config/default_tree_1_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "tree_adj": { 3 | "0": [] 4 | } 5 | } 6 | -------------------------------------------------------------------------------- /samd/config/default_tree_6_60.json: -------------------------------------------------------------------------------- 1 | { 2 | "tree_adj": { 3 | "0": [ 4 | 1, 5 | 2, 6 | 3, 7 | 4, 8 | 5, 9 | 6, 10 | 7 11 | ], 12 | "1": [ 13 | 8, 14 | 9, 15 | 10, 16 | 11, 17 | 12, 18 | 13 19 | ], 20 | "2": [ 21 | 14, 22 | 15, 23 | 16, 24 | 17, 25 | 18 26 | ], 27 | "3": [ 28 | 19, 29 | 20, 30 | 21 31 | ], 32 | "4": [ 33 | 22, 34 | 23 35 | ], 36 | "5": [ 37 | 24, 38 | 25 39 | ], 40 | "6": [ 41 | 26 42 | ], 43 | "7": [ 44 | 27 45 | ], 46 | "8": [ 47 | 28, 48 | 29, 49 | 30 50 | ], 51 | "13": [ 52 | ], 53 | "9": [ 54 | 31, 55 | 32 56 | ], 57 | "10": [ 58 | 33, 59 | 34 60 | ], 61 | "11": [ 62 | 35 63 | ], 64 | "12": [ 65 | 36 66 | ], 67 | "14": [ 68 | 37, 69 | 38, 70 | 39 71 | ], 72 | "15": [ 73 | 40, 74 | 41 75 | ], 76 | "16": [ 77 | 42 78 | ], 79 | "17": [ 80 | 43 81 | ], 82 | "18": [ 83 | ], 84 | "19": [ 85 | 44 86 | ], 87 | "20": [ 88 | 45 89 | ], 90 | "21": [ 91 | ], 92 | "22": [ 93 | 46 94 | ], 95 | "23": [ 96 | ], 97 | "24": [ 98 | 47 99 | ], 100 | "25": [ 101 | 
], 102 | "26": [ 103 | 48 104 | ], 105 | "27": [ 106 | ], 107 | "28": [ 108 | 49, 109 | 50 110 | ], 111 | "29": [ 112 | 51 113 | ], 114 | "30": [ 115 | ], 116 | "31": [ 117 | 52 118 | ], 119 | "32": [], 120 | "33": [ 121 | ], 122 | "34": [ 123 | ], 124 | "35": [ 125 | ], 126 | "36": [], 127 | "37": [ 128 | 53, 129 | 54 130 | ], 131 | "38": [ 132 | ], 133 | "39": [], 134 | "40": [ 135 | 55 136 | ], 137 | "41": [], 138 | "42": [], 139 | "43": [ 140 | ], 141 | "44": [ 142 | 56 143 | ], 144 | "45": [], 145 | "46": [ 146 | ], 147 | "47": [], 148 | "48": [], 149 | "49": [ 150 | 57, 151 | 58 152 | ], 153 | "50": [], 154 | "51": [ 155 | 59 156 | ], 157 | "52": [ 158 | ], 159 | "53": [ 160 | 60 161 | ], 162 | "54": [ 163 | ], 164 | "55": [ 165 | ], 166 | "56": [], 167 | "57": [], 168 | "58": [], 169 | "59": [], 170 | "60": [] 171 | } 172 | } -------------------------------------------------------------------------------- /samd/config/eagle.json: -------------------------------------------------------------------------------- 1 | { 2 | "tree_choices": [ 3 | [ 4 | 0 5 | ], 6 | [ 7 | 1 8 | ], 9 | [ 10 | 2 11 | ], 12 | [ 13 | 3 14 | ], 15 | [ 16 | 0, 17 | 0 18 | ], 19 | [ 20 | 0, 21 | 1 22 | ], 23 | [ 24 | 0, 25 | 2 26 | ], 27 | [ 28 | 1, 29 | 0 30 | ], 31 | [ 32 | 1, 33 | 1 34 | ], 35 | [ 36 | 2, 37 | 0 38 | ], 39 | [ 40 | 2, 41 | 1 42 | ], 43 | [ 44 | 3, 45 | 0 46 | ], 47 | [ 48 | 0, 49 | 0, 50 | 0 51 | ], 52 | [ 53 | 0, 54 | 0, 55 | 1 56 | ], 57 | [ 58 | 0, 59 | 0, 60 | 2 61 | ], 62 | [ 63 | 0, 64 | 1, 65 | 0 66 | ], 67 | [ 68 | 0, 69 | 1, 70 | 1 71 | ], 72 | [ 73 | 0, 74 | 2, 75 | 0 76 | ], 77 | [ 78 | 0, 79 | 2, 80 | 1 81 | ], 82 | [ 83 | 1, 84 | 0, 85 | 0 86 | ], 87 | [ 88 | 0, 89 | 0, 90 | 0, 91 | 0 92 | ], 93 | [ 94 | 0, 95 | 0, 96 | 0, 97 | 1 98 | ], 99 | [ 100 | 0, 101 | 0, 102 | 0, 103 | 2 104 | ], 105 | [ 106 | 0, 107 | 0, 108 | 0, 109 | 0, 110 | 0 111 | ], 112 | [ 113 | 0, 114 | 0, 115 | 0, 116 | 0, 117 | 1 118 | ] 119 | ] 120 | } -------------------------------------------------------------------------------- /samd/config/token_recycle.json: -------------------------------------------------------------------------------- 1 | { 2 | "tree_adj": { 3 | "0": [ 4 | 1, 5 | 2, 6 | 3, 7 | 4, 8 | 5, 9 | 6, 10 | 7 11 | ], 12 | "1": [ 13 | 8, 14 | 9, 15 | 10, 16 | 11, 17 | 12, 18 | 13 19 | ], 20 | "2": [ 21 | 14, 22 | 15, 23 | 16, 24 | 17, 25 | 18 26 | ], 27 | "3": [ 28 | 19, 29 | 20, 30 | 21 31 | ], 32 | "4": [ 33 | 22, 34 | 23 35 | ], 36 | "5": [ 37 | 24, 38 | 25 39 | ], 40 | "6": [ 41 | 26 42 | ], 43 | "7": [ 44 | 27 45 | ], 46 | "8": [ 47 | 28, 48 | 29, 49 | 30 50 | ], 51 | "13": [ 52 | ], 53 | "9": [ 54 | 31, 55 | 32 56 | ], 57 | "10": [ 58 | 33, 59 | 34 60 | ], 61 | "11": [ 62 | 35 63 | ], 64 | "12": [ 65 | 36 66 | ], 67 | "14": [ 68 | 37, 69 | 38, 70 | 39 71 | ], 72 | "15": [ 73 | 40, 74 | 41 75 | ], 76 | "16": [ 77 | 42 78 | ], 79 | "17": [ 80 | 43 81 | ], 82 | "18": [ 83 | ], 84 | "19": [ 85 | 44 86 | ], 87 | "20": [ 88 | 45 89 | ], 90 | "21": [ 91 | ], 92 | "22": [ 93 | 46 94 | ], 95 | "23": [ 96 | ], 97 | "24": [ 98 | 47 99 | ], 100 | "25": [ 101 | ], 102 | "26": [ 103 | 48 104 | ], 105 | "27": [ 106 | ], 107 | "28": [ 108 | 49, 109 | 50 110 | ], 111 | "29": [ 112 | 51 113 | ], 114 | "30": [ 115 | ], 116 | "31": [ 117 | 52 118 | ], 119 | "32": [], 120 | "33": [ 121 | ], 122 | "34": [ 123 | ], 124 | "35": [ 125 | ], 126 | "36": [], 127 | "37": [ 128 | 53, 129 | 54 130 | ], 131 | "38": [ 132 | ], 133 | "39": [], 134 | "40": [ 135 | 55 136 | ], 137 | "41": [], 138 | "42": [], 139 | "43": [ 
140 | ], 141 | "44": [ 142 | 56 143 | ], 144 | "45": [], 145 | "46": [ 146 | ], 147 | "47": [], 148 | "48": [], 149 | "49": [ 150 | 57, 151 | 58 152 | ], 153 | "50": [], 154 | "51": [ 155 | 59 156 | ], 157 | "52": [ 158 | ], 159 | "53": [ 160 | 60 161 | ], 162 | "54": [ 163 | ], 164 | "55": [ 165 | ], 166 | "56": [], 167 | "57": [], 168 | "58": [], 169 | "59": [], 170 | "60": [] 171 | } 172 | } -------------------------------------------------------------------------------- /samd/draft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Tuple, Dict, Optional 3 | from enum import Enum 4 | from collections import namedtuple 5 | 6 | from .samd_config import SamdConfig 7 | from .sam import DynSAM, StaticSAM, NullStaticSAM 8 | from .tree_model import TreeModel, tree_model_cls 9 | from transformers import LlamaConfig, LlamaForCausalLM 10 | 11 | from profile_utils import profile_decorator, profile_lookup_decorator 12 | 13 | # from transformers import LlamaTokenizer 14 | # tokenizer: LlamaTokenizer = LlamaTokenizer.from_pretrained('/data/models/vicuna-7b-v1.3') 15 | 16 | class CandidateType(str, Enum): 17 | sequence = "sequence" 18 | tree = "tree" 19 | 20 | Candidates = namedtuple('Candidates', ['type', 'tokens', 'candidate_tokens', 'buffers_kwargs']) 21 | 22 | TOPK = 8 23 | 24 | class DraftModel(torch.nn.Module): 25 | 26 | def __init__(self, 27 | config: SamdConfig, 28 | sam_dyn: DynSAM = None, 29 | sam_static: StaticSAM = None, 30 | tree_model: TreeModel = None, 31 | lm: LlamaForCausalLM = None, 32 | dtype: torch.dtype = torch.float16, 33 | device: str = "cuda", 34 | ) -> None: 35 | super().__init__() 36 | tree_cls = tree_model_cls[config.tree_method] 37 | self.config = config 38 | self.sam_dyn = sam_dyn if sam_dyn is not None else DynSAM(config.n_predicts) 39 | self.sam_static = sam_static if sam_static is not None else NullStaticSAM(config.n_predicts) 40 | self.tree_model = tree_model if tree_model is not None else tree_cls(config, lm, dtype, device) 41 | 42 | self.sam_dyn.n_predicts = config.n_predicts 43 | self.sam_static.n_predicts = config.n_predicts 44 | self.len_bias = config.len_bias 45 | self.len_threshold = config.len_threshold 46 | 47 | def reset(self): 48 | self.sam_dyn.reset() 49 | self.sam_static.reset() 50 | self.tree_model.reset() 51 | 52 | def lookup(self, start_token: int): 53 | index_dyn, match_dyn = self.sam_dyn.lookup(start_token) 54 | index_static, match_static = self.sam_static.lookup(start_token) 55 | match_static -= self.len_bias 56 | if max(match_dyn, match_static) >= self.len_threshold: 57 | if match_dyn >= match_static: 58 | seq = self.sam_dyn.gen_draft(index_dyn, start_token) 59 | else: 60 | seq = self.sam_static.gen_draft(index_static, start_token) 61 | return (CandidateType.sequence, seq, {}) 62 | else: 63 | return (CandidateType.tree,) + self.tree_model.gen_draft(start_token) 64 | 65 | def update(self, 66 | tokens: Optional[torch.Tensor] = None, 67 | last_hidden_states: Optional[torch.Tensor] = None, 68 | tree_tokens: Optional[torch.Tensor] = None, 69 | tree_logits: Optional[torch.Tensor] = None, 70 | ): 71 | tokens_list = tokens.tolist() 72 | self.sam_dyn.add_tokens(tokens_list) 73 | self.sam_static.transfer_tokens(tokens_list) 74 | self.tree_model.update( 75 | tokens=tokens, 76 | last_hidden_states=last_hidden_states, 77 | tree_tokens=tree_tokens, 78 | tree_logits=tree_logits, 79 | ) 80 | -------------------------------------------------------------------------------- 
/samd/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyx1999/SAM-Decoding/18c41f055b424fa3fa0bac41a8953d34cea1ed77/samd/inference/__init__.py -------------------------------------------------------------------------------- /samd/model_patch/__init__.py: -------------------------------------------------------------------------------- 1 | from .llama import llama_patch_dict, llama_attn_patch_dict 2 | 3 | patch_dict = {} 4 | attn_patch_dict = {} 5 | 6 | patch_dict.update(llama_patch_dict) 7 | attn_patch_dict.update(llama_attn_patch_dict) 8 | -------------------------------------------------------------------------------- /samd/sam/__init__.py: -------------------------------------------------------------------------------- 1 | from .dyn_sam import DynSAM 2 | from .static_sam import StaticSAM, NullStaticSAM 3 | from .utils import build_sam, dump_sam, load_sam 4 | -------------------------------------------------------------------------------- /samd/sam/dyn_sam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Tuple, Dict 3 | from dataclasses import dataclass 4 | from copy import deepcopy 5 | from collections import deque 6 | from tqdm import tqdm 7 | 8 | class DynSAM: 9 | 10 | @dataclass 11 | class SAMState: 12 | next: dict[int, int] 13 | link: int 14 | length: int 15 | min_endpos: int 16 | 17 | def __init__(self, n_predicts: int = 40): 18 | self.n_predicts = n_predicts 19 | self.states: List[DynSAM.SAMState] = [DynSAM.SAMState(next={}, link=-1, length=0, min_endpos=0)] 20 | self.input_ids: List[int] = [-1] 21 | self.last = 0 22 | self.max_length = 0 23 | 24 | # params needed to be reset for each query 25 | self.cur_index = 0 26 | self.cur_length = 0 27 | 28 | def reset(self): 29 | self.states: List[DynSAM.SAMState] = [DynSAM.SAMState(next={}, link=-1, length=0, min_endpos=0)] 30 | self.input_ids: List[int] = [-1] 31 | self.last = 0 32 | self.max_length = 0 33 | self.cur_index = 0 34 | self.cur_length = 0 35 | 36 | def expand_state(self, state: SAMState): 37 | new_index = len(self.states) 38 | self.states.append(state) 39 | return new_index 40 | 41 | def add_state(self, token: int): 42 | self.max_length += 1 43 | cur = self.expand_state( 44 | DynSAM.SAMState( 45 | next={}, link=-1, 46 | length=self.max_length, 47 | min_endpos=self.max_length 48 | ) 49 | ) 50 | p = self.last 51 | while p != -1 and token not in self.states[p].next: 52 | self.states[p].next[token] = cur 53 | p = self.states[p].link 54 | if p == -1: 55 | self.states[cur].link = 0 56 | else: 57 | q = self.states[p].next[token] 58 | if self.states[p].length + 1 == self.states[q].length: 59 | self.states[cur].link = q 60 | else: 61 | clone = self.expand_state(deepcopy(self.states[q])) 62 | self.states[clone].length = self.states[p].length + 1 63 | while p != -1 and self.states[p].next[token] == q: 64 | self.states[p].next[token] = clone 65 | p = self.states[p].link 66 | self.states[q].link = self.states[cur].link = clone 67 | self.last = cur 68 | 69 | def transfer_state(self, index: int, length: int, token: int): 70 | while index != 0 and token not in self.states[index].next: 71 | index = self.states[index].link 72 | length = self.states[index].length 73 | if token in self.states[index].next: 74 | index = self.states[index].next[token] 75 | length += 1 76 | else: 77 | index = length = 0 78 | return index, length 79 | 80 | def transfer_cur_state(self, token: int): 
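        # Advance the tracked (cur_index, cur_length) automaton state by one token via transfer_state.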
81 | self.cur_index, self.cur_length = \ 82 | self.transfer_state(self.cur_index, self.cur_length, token) 83 | 84 | def add_tokens(self, tokens: List[int]): 85 | for token in tokens: 86 | self.transfer_cur_state(token) 87 | self.add_state(token) 88 | self.input_ids.extend(tokens) 89 | 90 | def transfer_tokens(self, tokens: List[int]): 91 | for token in tokens: 92 | self.transfer_cur_state(token) 93 | 94 | def lookup(self, token: int): 95 | index, length = \ 96 | self.transfer_state(self.cur_index, self.cur_length, token) 97 | return index, length 98 | 99 | def to_anc(self, index: int): 100 | if index != 0: 101 | length_to_end = self.max_length - self.states[index].min_endpos 102 | while self.states[index].link != 0 and self.n_predicts > length_to_end: 103 | index = self.states[index].link 104 | length_to_end = self.max_length - self.states[index].min_endpos 105 | return index 106 | 107 | def gen_draft(self, index: int, start_token: int): 108 | index = self.to_anc(index) 109 | endpos = self.states[index].min_endpos 110 | pred_ids = [start_token] + self.input_ids[endpos + 1:endpos + self.n_predicts] 111 | if len(pred_ids) < self.n_predicts: 112 | pred_ids.extend([0] * (self.n_predicts - len(pred_ids))) 113 | return pred_ids 114 | -------------------------------------------------------------------------------- /samd/sam/static_sam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Tuple, Dict 3 | from dataclasses import dataclass 4 | from copy import deepcopy 5 | from collections import deque 6 | from tqdm import tqdm 7 | 8 | class StaticSAM: 9 | 10 | @dataclass 11 | class SAMState: 12 | next: dict[int, int] 13 | link: int 14 | length: int 15 | min_endpos: int 16 | 17 | def __init__(self, n_predicts: int = 40): 18 | self.n_predicts = n_predicts 19 | self.states: List[StaticSAM.SAMState] = [StaticSAM.SAMState(next={}, link=-1, length=0, min_endpos=0)] 20 | self.input_ids: List[int] = [-1] 21 | self.last = 0 22 | self.max_length = 0 23 | 24 | # params needed to be reset for each query 25 | self.cur_index = 0 26 | self.cur_length = 0 27 | 28 | def reset(self): 29 | self.cur_index = 0 30 | self.cur_length = 0 31 | 32 | def add_batch_tokens(self, batch_tokens: List[List[int]], eos_token: int, verbose: bool): 33 | for tokens in tqdm(batch_tokens, desc="build sam...", disable=not verbose): 34 | self.add_tokens(tokens) 35 | if tokens[-1] != eos_token: 36 | self.add_tokens([eos_token]) 37 | 38 | @staticmethod 39 | def build( 40 | batch_tokens: List[List[int]], 41 | eos_token: int, 42 | verbose: bool =True 43 | ): 44 | sam = StaticSAM() 45 | sam.add_batch_tokens(batch_tokens, eos_token, verbose) 46 | return sam 47 | 48 | def expand_state(self, state: SAMState): 49 | new_index = len(self.states) 50 | self.states.append(state) 51 | return new_index 52 | 53 | def add_state(self, token: int): 54 | self.max_length += 1 55 | cur = self.expand_state( 56 | StaticSAM.SAMState( 57 | next={}, link=-1, 58 | length=self.max_length, 59 | min_endpos=self.max_length 60 | ) 61 | ) 62 | p = self.last 63 | while p != -1 and token not in self.states[p].next: 64 | self.states[p].next[token] = cur 65 | p = self.states[p].link 66 | if p == -1: 67 | self.states[cur].link = 0 68 | else: 69 | q = self.states[p].next[token] 70 | if self.states[p].length + 1 == self.states[q].length: 71 | self.states[cur].link = q 72 | else: 73 | clone = self.expand_state(deepcopy(self.states[q])) 74 | self.states[clone].length = self.states[p].length + 1 75 | while p 
!= -1 and self.states[p].next[token] == q: 76 | self.states[p].next[token] = clone 77 | p = self.states[p].link 78 | self.states[q].link = self.states[cur].link = clone 79 | self.last = cur 80 | 81 | def transfer_state(self, index: int, length: int, token: int): 82 | while index != 0 and token not in self.states[index].next: 83 | index = self.states[index].link 84 | length = self.states[index].length 85 | if token in self.states[index].next: 86 | index = self.states[index].next[token] 87 | length += 1 88 | else: 89 | index = length = 0 90 | return index, length 91 | 92 | def transfer_cur_state(self, token: int): 93 | self.cur_index, self.cur_length = \ 94 | self.transfer_state(self.cur_index, self.cur_length, token) 95 | 96 | def add_tokens(self, tokens: List[int]): 97 | for token in tokens: 98 | self.transfer_cur_state(token) 99 | self.add_state(token) 100 | self.input_ids.extend(tokens) 101 | 102 | def transfer_tokens(self, tokens: List[int]): 103 | for token in tokens: 104 | self.transfer_cur_state(token) 105 | 106 | def lookup(self, token: int): 107 | index, length = \ 108 | self.transfer_state(self.cur_index, self.cur_length, token) 109 | return index, length 110 | 111 | def to_anc(self, index: int): 112 | if index != 0: 113 | length_to_end = self.max_length - self.states[index].min_endpos 114 | while self.states[index].link != 0 and self.n_predicts > length_to_end: 115 | index = self.states[index].link 116 | length_to_end = self.max_length - self.states[index].min_endpos 117 | return index 118 | 119 | def gen_draft(self, index: int, start_token: int): 120 | # index = self.to_anc(index) 121 | endpos = self.states[index].min_endpos 122 | pred_ids = [start_token] + self.input_ids[endpos + 1:endpos + self.n_predicts] 123 | if len(pred_ids) < self.n_predicts: 124 | pred_ids.extend([0] * (self.n_predicts - len(pred_ids))) 125 | return pred_ids 126 | 127 | 128 | class NullStaticSAM(StaticSAM): 129 | 130 | def __init__(self, n_predicts = 40): 131 | super().__init__(n_predicts) 132 | 133 | def transfer_tokens(self, tokens): 134 | pass 135 | 136 | def gen_draft(self, index, start_token): 137 | return -1, -1 138 | -------------------------------------------------------------------------------- /samd/sam/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import pickle 3 | from datasets import Dataset 4 | from transformers import PreTrainedTokenizerFast 5 | from typing import List 6 | 7 | from .static_sam import StaticSAM 8 | from ..samd_config import SamdConfig 9 | 10 | def build_sam( 11 | batch_tokens: List[List[int]], 12 | eos_token: int, 13 | ): 14 | sam = StaticSAM.build( 15 | batch_tokens, 16 | eos_token 17 | ) 18 | return sam 19 | 20 | def dump_sam(path: str, sam: StaticSAM): 21 | with open(path, "wb") as f: 22 | pickle.dump(sam, f) 23 | 24 | def load_sam(path: str): 25 | print("load sam...") 26 | start = time.perf_counter() 27 | with open(path, "rb") as f: 28 | _sam = pickle.load(f) 29 | sam = StaticSAM() 30 | for key, value in vars(_sam).items(): 31 | if hasattr(sam, key): 32 | setattr(sam, key, value) 33 | print("load [{}]".format(key)) 34 | end = time.perf_counter() 35 | assert type(sam) is StaticSAM 36 | print("loading ended in {} seconds.".format(end - start)) 37 | return sam 38 | -------------------------------------------------------------------------------- /samd/samd_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | from 
dataclasses import dataclass, field 5 | from typing import Optional, Union, List, Literal, Dict, Any 6 | from enum import Enum 7 | 8 | 9 | @dataclass 10 | class SamdConfig: 11 | n_predicts: int = field(default=40) 12 | max_predicts: int = field(default=70) 13 | len_threshold: int = field(default=5) 14 | len_bias: int = field(default=5) 15 | 16 | cache_type: Literal["dynamic", "static"] = field( 17 | default="static" 18 | ) 19 | use_last_hidden_states: bool = field(default=False) 20 | 21 | tree_method: Literal["token_recycle", "eagle", "eagle2"] = field( 22 | default="token_recycle" 23 | ) 24 | tree_model_path: Optional[str] = field(default=None) 25 | tree_path: Optional[str] = field(default=None) 26 | tree: Optional[List[List[int]]] = field(default=None) 27 | tree_config: Optional[Dict[str, Any]] = field(default=None) 28 | 29 | def __post_init__(self): 30 | if self.tree is None: 31 | if self.tree_method == "token_recycle": 32 | self.tree = load_token_recycle(self.tree_path) 33 | elif self.tree_method == "eagle": 34 | tree, tree_config = load_eagle(self.tree_model_path, self.tree_path) 35 | self.tree = tree 36 | self.tree_config = tree_config 37 | self.use_last_hidden_states = True 38 | elif self.tree_method == "eagle2": 39 | tree_config = load_eagle2(self.tree_model_path) 40 | self.tree_config = tree_config 41 | self.use_last_hidden_states = True 42 | else: 43 | raise ValueError 44 | 45 | 46 | class ForwardType(str, Enum): 47 | prefill = "prefill" 48 | seq_decode = "seq_decode" 49 | tree_decode = "tree_decode" 50 | 51 | 52 | class ForwardState: 53 | 54 | def __init__(self, forward_type: ForwardType | None) -> None: 55 | self.forward_type = forward_type 56 | 57 | 58 | class MaskState: 59 | 60 | def __init__(self, mask: Optional[torch.Tensor]) -> None: 61 | self.mask = mask 62 | 63 | def set_state(self, mask: Optional[torch.Tensor]) -> None: 64 | self.mask = mask 65 | 66 | 67 | def load_token_recycle(tree_path: Optional[str] = None): 68 | if tree_path is None: 69 | tree_path = "token_recycle.json" 70 | samd_path = os.path.dirname(__file__) 71 | with open(os.path.join(samd_path, "config", tree_path), "r") as f: 72 | tree_adj: dict = json.load(f)["tree_adj"] 73 | num_node = len(tree_adj) 74 | tree: List[List[int]] = [] 75 | for i in range(num_node): 76 | tree.append(tree_adj[str(i)]) 77 | print("tree_path:", tree_path) 78 | print("len_tree:", len(tree)) 79 | return tree 80 | 81 | 82 | def load_eagle(tree_model_path: str, tree_path: Optional[str] = None): 83 | if tree_path is None: 84 | tree_path = "eagle.json" 85 | samd_path = os.path.dirname(__file__) 86 | with open(os.path.join(samd_path, "config", tree_path), "r") as f: 87 | tree = json.load(f)["tree_choices"] 88 | with open(os.path.join(tree_model_path, "config.json")) as f: 89 | tree_config = json.load(f) 90 | return tree, tree_config 91 | 92 | 93 | def load_eagle2(tree_model_path: str): 94 | with open(os.path.join(tree_model_path, "config.json")) as f: 95 | tree_config = json.load(f) 96 | return tree_config 97 | -------------------------------------------------------------------------------- /samd/tree_model/__init__.py: -------------------------------------------------------------------------------- 1 | from .tree import TreeModel 2 | from .token_recycle import TokenRecycle 3 | from .eagle import Eagle 4 | from .eagle2 import Eagle2 5 | from typing import Dict, Union 6 | 7 | tree_model_cls: Dict[ 8 | str, 9 | Union[TokenRecycle, Eagle] 10 | ] = { 11 | "token_recycle": TokenRecycle, 12 | "eagle": Eagle, 13 | "eagle2": Eagle2, 14 | } 15 | 
-------------------------------------------------------------------------------- /samd/tree_model/eagle/__init__.py: -------------------------------------------------------------------------------- 1 | from .eagle import Eagle -------------------------------------------------------------------------------- /samd/tree_model/eagle/eagle.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | from transformers import LlamaConfig, LlamaForCausalLM 4 | from typing import List, Tuple, Dict 5 | 6 | from ...samd_config import SamdConfig 7 | from ..tree import TreeModel 8 | from .eagle_config import EagleConfig 9 | from .eagle_model import EagleModel 10 | 11 | from .utils import gen_buffers, TOPK 12 | 13 | 14 | class Eagle(TreeModel): 15 | 16 | def __init__(self, 17 | config: SamdConfig, 18 | lm: LlamaForCausalLM, 19 | dtype: torch.dtype, 20 | device: str, 21 | ) -> None: 22 | super().__init__() 23 | self.tree = config.tree 24 | self.dtype = dtype 25 | self.device = device 26 | self.head: torch.nn.Linear = lm.lm_head 27 | self.model: EagleModel = EagleModel( 28 | config=EagleConfig(**config.tree_config), 29 | bias=config.tree_config.get("bias", True) 30 | ).to(device=device, dtype=dtype) 31 | self.model.gen_buffers(config.tree, device) 32 | self.model.load_weight(config.tree_model_path) 33 | 34 | self.accpet_tokens: torch.Tensor = None 35 | self.accept_hidden_states: torch.Tensor = None 36 | self.tree_indices: torch.Tensor = None 37 | 38 | def reset(self): 39 | self.model.stable_kv = None 40 | 41 | def update(self, 42 | tokens: torch.Tensor, 43 | last_hidden_states: torch.Tensor, 44 | **kwargs, 45 | ): 46 | tokens = tokens.to(self.device) 47 | if self.accpet_tokens is None: 48 | self.accpet_tokens = tokens 49 | else: 50 | self.accpet_tokens = torch.cat([self.accpet_tokens, tokens], dim=-1) 51 | if self.accept_hidden_states is None: 52 | self.accept_hidden_states = last_hidden_states 53 | else: 54 | self.accept_hidden_states = torch.cat([self.accept_hidden_states, last_hidden_states], dim=-2) 55 | 56 | def gen_draft(self, start_token: int) -> List[int]: 57 | start_token = torch.tensor([start_token], dtype=torch.long, device=self.device) 58 | accpet_tokens = torch.cat((self.accpet_tokens, start_token), dim=-1) 59 | accept_hidden_states = self.accept_hidden_states 60 | self.accpet_tokens = self.accept_hidden_states = None 61 | pred_ids: torch.Tensor = self.model.topk_genrate( 62 | accept_hidden_states, 63 | accpet_tokens, 64 | self.head, 65 | top_k=TOPK 66 | )[0] 67 | pred_ids = torch.cat([start_token, pred_ids.view(-1)]) 68 | pred_ids = pred_ids[self.tree_indices].tolist() 69 | buffers_kwargs = {} 70 | return pred_ids, buffers_kwargs 71 | 72 | def gen_buffers(self): 73 | buffers = gen_buffers(self.tree, self.device) 74 | self.tree_indices = buffers["tree_indices"] 75 | return buffers 76 | -------------------------------------------------------------------------------- /samd/tree_model/eagle/eagle_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # typing 4 | from typing import List 5 | from .utils import TOPK 6 | 7 | 8 | class EagleNode: 9 | def __init__(self, parent=None, value=None, dict_key=None): 10 | self.parent = parent 11 | self.value = value 12 | if parent: 13 | self.depth = parent.depth + 1 14 | parent.children.append(self) 15 | else: 16 | self.depth = 0 17 | self.children = [] 18 | self.dict_key = dict_key 19 | 20 | def is_leaf(self): 21 | return 
len(self.children) == 0 22 | 23 | def all_index(self): 24 | if not self.parent.parent: 25 | return [self.index] 26 | else: 27 | return self.parent.all_index() + [self.index] 28 | 29 | 30 | class EagleTree: 31 | 32 | def __init__(self, tree_list): 33 | sorted_tree_list = sorted(tree_list, key=lambda x: (len(x), x)) 34 | self.root = EagleNode() 35 | self.node_dic = {} 36 | for tree_node in sorted_tree_list: 37 | cur_value = tree_node[-1] 38 | if len(tree_node) == 1: 39 | cur_node = EagleNode( 40 | parent=self.root, value=cur_value, dict_key=tuple(tree_node) 41 | ) 42 | else: 43 | cur_parent = self.node_dic[tuple(tree_node[:-1])] 44 | cur_node = EagleNode( 45 | parent=cur_parent, value=cur_value, dict_key=tuple(tree_node) 46 | ) 47 | self.node_dic[tuple(tree_node)] = cur_node 48 | self.indexnode() 49 | 50 | def max_depth(self): 51 | return max([item.depth for item in self.node_dic.values()]) 52 | 53 | def num_node_wchild(self): 54 | num_c = 0 55 | for item in self.node_dic.values(): 56 | if not item.is_leaf(): 57 | num_c += 1 58 | return num_c 59 | 60 | def get_node_wchild(self): 61 | ns = [] 62 | for item in self.node_dic.values(): 63 | if not item.is_leaf(): 64 | ns.append(item) 65 | return ns 66 | 67 | def indexnode(self): 68 | cur_index = 0 69 | for key in self.node_dic: 70 | cur_node = self.node_dic[key] 71 | if not cur_node.is_leaf(): 72 | cur_node.index = cur_index 73 | cur_index += 1 74 | 75 | 76 | def gen_buffers_eagle(tree_choices, device="cuda"): 77 | tree = EagleTree(tree_choices) 78 | tree_len = tree.num_node_wchild() 79 | 80 | max_depth = tree.max_depth() 81 | nodes_wc = tree.get_node_wchild() 82 | 83 | depth_counts = [0 for _ in range(max_depth - 1)] 84 | for x in nodes_wc: 85 | depth_counts[x.depth - 1] += 1 86 | depth_counts_sum = [sum(depth_counts[: i + 1]) for i in range(len(depth_counts))] 87 | 88 | tree_attn_mask = torch.eye(tree_len, tree_len) 89 | 90 | for id, x in enumerate(nodes_wc): 91 | tree_attn_mask[id, x.all_index()] = 1 92 | 93 | tree_attn_mask_list0 = [tree_attn_mask[:ml, :ml] for ml in depth_counts_sum] 94 | tree_attn_mask_list = [] 95 | for id, x in enumerate(tree_attn_mask_list0): 96 | x = x[-depth_counts[id] :] 97 | tree_attn_mask_list.append(x) 98 | 99 | tree_indices_list = [torch.zeros(ml, dtype=torch.long) for ml in depth_counts] 100 | repeat_nums = [[] for _ in depth_counts] 101 | start = 0 102 | bias = 0 103 | for i in range(len(depth_counts)): 104 | bias = 0 105 | repeat_j = 0 106 | for j in range(depth_counts[i]): 107 | cur_node = nodes_wc[start + j] 108 | cur_parent = cur_node.parent 109 | 110 | if j != 0: 111 | if cur_parent != parent: 112 | bias += 1 113 | parent = cur_parent 114 | repeat_nums[i].append(j - repeat_j) 115 | repeat_j = j 116 | else: 117 | parent = cur_parent 118 | tree_indices_list[i][j] = cur_node.value + TOPK * (bias) 119 | repeat_nums[i].append(j - repeat_j + 1) 120 | start += depth_counts[i] 121 | 122 | position_ids = [torch.zeros(ml, dtype=torch.long) for ml in depth_counts] 123 | 124 | tree_buffers = { 125 | "attn_mask": [i.unsqueeze(0).unsqueeze(0) for i in tree_attn_mask_list], 126 | "tree_indices": tree_indices_list, 127 | "position_ids": position_ids, 128 | "repeat_nums": repeat_nums, 129 | } 130 | 131 | # Move the tensors in the dictionary to the specified device 132 | tree_buffers = { 133 | k: ( 134 | [i.clone().to(device) for i in v] 135 | if isinstance(v[0], torch.Tensor) 136 | else (torch.tensor(v, device=device) if isinstance(v, torch.Tensor) else v) 137 | ) 138 | for k, v in tree_buffers.items() 139 | } 140 | return 
tree_buffers 141 | -------------------------------------------------------------------------------- /samd/tree_model/eagle2/__init__.py: -------------------------------------------------------------------------------- 1 | from .eagle2 import Eagle2 -------------------------------------------------------------------------------- /samd/tree_model/eagle2/eagle2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | from transformers import LlamaConfig, LlamaForCausalLM 4 | from typing import List, Tuple, Dict 5 | 6 | from ...samd_config import SamdConfig 7 | from ..tree import TreeModel 8 | from .eagle2_config import Eagle2Config 9 | from .eagle2_model import Eagle2Model 10 | 11 | 12 | class Eagle2(TreeModel): 13 | 14 | def __init__(self, 15 | config: SamdConfig, 16 | lm: LlamaForCausalLM, 17 | dtype: torch.dtype, 18 | device: str, 19 | ) -> None: 20 | super().__init__() 21 | self.dtype = dtype 22 | self.device = device 23 | self.head: torch.nn.Linear = lm.lm_head 24 | self.model: Eagle2Model = Eagle2Model( 25 | config=Eagle2Config(**config.tree_config), 26 | bias=config.tree_config.get("bias", True) 27 | ).to(device=device, dtype=dtype) 28 | self.model.load_weight(config.tree_model_path) 29 | self.model.init_tree() 30 | 31 | self.accpet_tokens: torch.Tensor = None 32 | self.accept_hidden_states: torch.Tensor = None 33 | 34 | def reset(self): 35 | self.model.stable_kv = None 36 | 37 | def update(self, 38 | tokens: torch.Tensor, 39 | last_hidden_states: torch.Tensor, 40 | **kwargs, 41 | ): 42 | tokens = tokens.to(self.device) 43 | if self.accpet_tokens is None: 44 | self.accpet_tokens = tokens 45 | else: 46 | self.accpet_tokens = torch.cat([self.accpet_tokens, tokens], dim=-1) 47 | if self.accept_hidden_states is None: 48 | self.accept_hidden_states = last_hidden_states 49 | else: 50 | self.accept_hidden_states = torch.cat([self.accept_hidden_states, last_hidden_states], dim=-2) 51 | 52 | def gen_draft(self, start_token: int) -> List[int]: 53 | start_token = torch.tensor([start_token], dtype=torch.long, device=self.device) 54 | accpet_tokens = torch.cat((self.accpet_tokens, start_token), dim=-1) 55 | accept_hidden_states = self.accept_hidden_states 56 | self.accpet_tokens = self.accept_hidden_states = None 57 | pred_ids, buffers_kwargs = self.model.topk_genrate( 58 | accept_hidden_states, 59 | accpet_tokens, 60 | self.head, 61 | ) 62 | pred_ids = pred_ids.view(-1).tolist() 63 | return pred_ids, buffers_kwargs 64 | 65 | def gen_buffers(self): 66 | return { 67 | "tree_attn_mask": None, 68 | "tree_position_ids": None, 69 | "tree_retrieve_indices": None, 70 | } 71 | -------------------------------------------------------------------------------- /samd/tree_model/token_recycle/__init__.py: -------------------------------------------------------------------------------- 1 | from .token_recycle import TokenRecycle -------------------------------------------------------------------------------- /samd/tree_model/token_recycle/token_recycle.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Tuple, Dict 3 | from dataclasses import dataclass 4 | from copy import deepcopy 5 | from collections import deque 6 | from tqdm import tqdm 7 | from transformers import LlamaConfig, LlamaForCausalLM 8 | 9 | from ...samd_config import SamdConfig 10 | from ..tree import TreeModel 11 | from .utils import ( 12 | pad_path, 13 | gen_buffers 14 | ) 15 | 16 | TOPK = 8 17 | 18 | 
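# TokenRecycle keeps a cache mapping each seen token id to the top-K token ids the
# target model predicted immediately after it. update() refreshes this cache from the
# logits of the last tree-decode step, and gen_draft() fills the static draft tree
# (config.tree, a BFS adjacency list) breadth-first from the cache, starting from
# start_token.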
class TokenRecycle(TreeModel): 19 | 20 | def __init__(self, 21 | config: SamdConfig, 22 | lm: LlamaForCausalLM, 23 | dtype: torch.dtype, 24 | device: str, 25 | ) -> None: 26 | super().__init__() 27 | self.samd_config = config 28 | self.dtype = dtype 29 | self.device = device 30 | self.tree = config.tree 31 | self.cache = {} 32 | 33 | def reset(self): 34 | pass # do nothting 35 | 36 | def logits_to_topk(self, logits: torch.Tensor) -> List[List[int]]: 37 | topk_nest = logits.topk(k=TOPK).indices.tolist() 38 | return topk_nest 39 | 40 | def update(self, 41 | tree_tokens: torch.Tensor, 42 | tree_logits: torch.Tensor, 43 | **kwargs 44 | ): 45 | tree_tokens = tree_tokens.tolist() 46 | topk_nest = self.logits_to_topk(tree_logits) 47 | for token, topk in zip(tree_tokens, topk_nest): 48 | self.cache[token] = topk 49 | 50 | def gen_draft(self, start_token: int) -> List[int]: 51 | tree_tokens = [start_token] + [0] * (len(self.tree) - 1) 52 | for node_id, childs in enumerate(self.tree): 53 | token = tree_tokens[node_id] 54 | if token not in self.cache: 55 | continue 56 | topk = self.cache[token] 57 | for child_id, child in enumerate(childs): 58 | tree_tokens[child] = topk[child_id] 59 | buffers_kwargs = {} 60 | return tree_tokens, buffers_kwargs 61 | 62 | def gen_buffers(self) -> Dict[str, torch.Tensor]: 63 | return gen_buffers(self.samd_config.tree, self.device) 64 | -------------------------------------------------------------------------------- /samd/tree_model/token_recycle/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Tuple, Dict 3 | from dataclasses import dataclass 4 | from copy import deepcopy 5 | from collections import deque 6 | from tqdm import tqdm 7 | 8 | from ...samd_config import SamdConfig 9 | 10 | def pad_path(path, length, pad_value=-1): 11 | """ 12 | Pad the given path list with a specific value up to a specified length. 13 | 14 | Parameters: 15 | - path (list): The original list that needs padding. 16 | - length (int): The desired length of the padded list. 17 | - pad_value (optional, default=-1): The value to use for padding. 18 | 19 | Returns: 20 | - list: A new list based on the original path but padded to the desired length. 21 | 22 | Example: 23 | >>> pad_path([1,2,3], 5) 24 | [1, 2, 3, -1, -1] 25 | 26 | Note: 27 | If the given path is already longer than the specified length, 28 | then no padding occurs, and the original path is returned. 29 | """ 30 | 31 | # Calculate the number of padding values needed by subtracting the length 32 | # of the path from the desired length. 33 | # Append the padding values to the original path and return the new list. 34 | return path + [pad_value] * (length - len(path)) 35 | 36 | 37 | def gen_buffers( 38 | tree: List[List[int]], 39 | device: torch.device 40 | ) -> Dict[str, torch.Tensor]: 41 | """ 42 | Generate buffers for the SD based on the provided bfs tree. 43 | 44 | Parameters: 45 | - tree (List[List[int]]): A nested list represent the SD tree structure. 46 | - device (torch.device): The device to save the tensors. 47 | 48 | Returns: 49 | - dict: A dictionary containing buffers related to the SD structure. 
50 | """ 51 | num_nodes = len(tree) 52 | 53 | anc_dict = {0: -1} 54 | for node_id, childs in enumerate(tree): 55 | for child in childs: 56 | anc_dict[child] = node_id 57 | 58 | level_dict = {0: 0} 59 | for node_id in range(1, num_nodes): 60 | level_dict[node_id] = level_dict[anc_dict[node_id]] + 1 61 | 62 | # Create the attention mask for Medusa 63 | tree_attn_mask = torch.eye(num_nodes, num_nodes) 64 | for node_id in range(num_nodes): 65 | ancs = [node_id] 66 | x = node_id 67 | while x != -1: 68 | ancs.append(x) 69 | x = anc_dict[x] 70 | ancs = torch.tensor(ancs, dtype=torch.long) 71 | tree_attn_mask[node_id, ancs] = True 72 | tree_attn_mask = tree_attn_mask.view(1, 1, num_nodes, num_nodes) 73 | 74 | tree_position_ids = torch.zeros((1, num_nodes), dtype=torch.long) 75 | for i in range(num_nodes): 76 | tree_position_ids[:, i] = level_dict[i] 77 | 78 | max_level = max(level_dict.values()) + 1 79 | retrieve_indices_nest = [] 80 | for node_id, childs in enumerate(tree): 81 | if len(childs) != 0: 82 | continue 83 | retrieve_indices = [node_id] 84 | while retrieve_indices[-1] != 0: 85 | retrieve_indices.append(anc_dict[retrieve_indices[-1]]) 86 | retrieve_indices_nest.append(list(reversed(retrieve_indices))) 87 | 88 | retrieve_indices_nest = reversed(retrieve_indices_nest) 89 | retrieve_indices_nest = [pad_path(x, max_level) for x in retrieve_indices_nest] 90 | tree_retrieve_indices = torch.tensor(retrieve_indices_nest, dtype=torch.long) 91 | 92 | tree_buffers = { 93 | "tree_attn_mask": tree_attn_mask, 94 | "tree_position_ids": tree_position_ids, 95 | "tree_retrieve_indices": tree_retrieve_indices, 96 | } 97 | 98 | tree_buffers = {k: (v.to(device) if v is not None else v) for k, v in tree_buffers.items()} 99 | return tree_buffers 100 | -------------------------------------------------------------------------------- /samd/tree_model/tree.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Tuple, Dict 3 | from dataclasses import dataclass 4 | from copy import deepcopy 5 | from collections import deque 6 | from tqdm import tqdm 7 | 8 | 9 | class TreeModel(torch.nn.Module): 10 | 11 | def __init__(self, 12 | samd_config=None, 13 | lm_config=None, 14 | lm=None, 15 | dtype: torch.dtype=None, 16 | device: str=None, 17 | ) -> None: 18 | super().__init__() 19 | 20 | def reset(self): 21 | raise NotImplementedError 22 | 23 | def update(self, tokens: List[int], topk_nest: List[List[int]]): 24 | raise NotImplementedError 25 | 26 | def gen_draft(self, start_token: int) -> Tuple[List[int], Dict[str, torch.Tensor]]: 27 | raise NotImplementedError 28 | 29 | def gen_buffers(self): 30 | raise NotImplementedError 31 | -------------------------------------------------------------------------------- /samd_sam_only/__init__.py: -------------------------------------------------------------------------------- 1 | from .samd_config import SamdConfig 2 | from .samd_model import SamdModel 3 | from .utils import SamdGenerationConfig 4 | from .sam import build_sam, load_sam, dump_sam 5 | from .draft import DraftModel -------------------------------------------------------------------------------- /samd_sam_only/cache.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from transformers import PretrainedConfig 4 | from transformers.cache_utils import DynamicCache, Cache 5 | from typing import Optional, Dict, Any, Tuple, List 6 | from profile_utils import profile_decorator 
7 | 8 | class SamdCache(DynamicCache): 9 | 10 | def __init__(self, num_hidden_layers: int | None = None) -> None: 11 | super().__init__(num_hidden_layers) 12 | self.cache_length = 0 13 | 14 | def set_length(self): 15 | self.cache_length = self.get_seq_length() 16 | 17 | @profile_decorator("SamdCache.select_indices") 18 | def select_indices(self, 19 | indices: torch.Tensor | None = None, 20 | accept_length: int = 1, 21 | ): 22 | start = self.cache_length 23 | if indices is not None: 24 | select_indices = start + indices 25 | else: 26 | select_indices = None 27 | for data in self.key_cache + self.value_cache: 28 | if select_indices is not None: 29 | select_indices = select_indices.to(data.device) 30 | tgt = data.index_select(-2, select_indices) 31 | dst = data.narrow(-2, start, accept_length) 32 | dst.copy_(tgt) 33 | self.cache_length += accept_length 34 | self.crop(self.cache_length) 35 | 36 | 37 | class SamdStaticCache(Cache): 38 | 39 | def __init__(self, 40 | config, 41 | batch_size = None, 42 | max_cache_len = None, 43 | device = None, 44 | dtype = torch.float32, 45 | max_batch_size = None, 46 | hf_device_map = None, 47 | ): 48 | super().__init__() 49 | if len(hf_device_map) <= 1: 50 | device = device 51 | layer_device_map = None 52 | else: 53 | device = None 54 | layer_device_map = {} 55 | for i in range(config.num_hidden_layers): 56 | layer_device_map[i] = hf_device_map[f"model.layers.{i}"] 57 | self.batch_size = batch_size or max_batch_size 58 | self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len 59 | 60 | # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads 61 | self.head_dim = ( 62 | config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads 63 | ) 64 | 65 | self.dtype = dtype 66 | self.num_key_value_heads = ( 67 | config.num_attention_heads 68 | if getattr(config, "num_key_value_heads", None) is None 69 | else config.num_key_value_heads 70 | ) 71 | 72 | self.key_cache: List[torch.Tensor] = [] 73 | self.value_cache: List[torch.Tensor] = [] 74 | # Note: There will be significant perf decrease if switching to use 5D tensors instead. 75 | cache_shape = (self.batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) 76 | for idx in range(config.num_hidden_layers): 77 | if layer_device_map is not None: 78 | layer_device = layer_device_map[idx] 79 | else: 80 | layer_device = device 81 | new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device) 82 | new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device) 83 | self.key_cache.append(new_layer_key_cache) 84 | self.value_cache.append(new_layer_value_cache) 85 | 86 | self.last_length = 0 87 | self.cache_length = 0 88 | 89 | def reset(self): 90 | self.cache_length = 0 91 | self.last_length = 0 92 | 93 | def set_length(self): 94 | self.cache_length = self.last_length 95 | 96 | def get_seq_length(self, layer_idx = 0): 97 | return self.cache_length 98 | 99 | def get_max_cache_shape(self) -> Optional[int]: 100 | """Returns the maximum sequence length of the cache object. 
DynamicCache does not have a maximum length.""" 101 | return self.max_cache_len 102 | 103 | def update(self, key_states, value_states, layer_idx, cache_kwargs = None): 104 | k_out = self.key_cache[layer_idx] 105 | v_out = self.value_cache[layer_idx] 106 | k_dst = k_out.narrow(2, self.cache_length, key_states.shape[2]) 107 | v_dst = v_out.narrow(2, self.cache_length, value_states.shape[2]) 108 | k_dst.copy_(key_states) 109 | v_dst.copy_(value_states) 110 | if layer_idx == 0: 111 | self.last_length = self.cache_length + key_states.shape[2] 112 | return ( 113 | k_out.narrow(2, 0, self.last_length), 114 | v_out.narrow(2, 0, self.last_length), 115 | ) 116 | 117 | # @profile_decorator("SamdCache.select_indices") 118 | def select_indices(self, 119 | indices: torch.Tensor | None = None, 120 | accept_length: int = 1, 121 | ): 122 | start = self.cache_length 123 | if indices is not None: 124 | select_indices = start + indices 125 | else: 126 | select_indices = None 127 | for data in self.key_cache + self.value_cache: 128 | if select_indices is not None: 129 | select_indices = select_indices.to(data.device) 130 | tgt = data.index_select(-2, select_indices) 131 | dst = data.narrow(-2, start, accept_length) 132 | dst.copy_(tgt, non_blocking=True) 133 | self.cache_length += accept_length 134 | -------------------------------------------------------------------------------- /samd_sam_only/config/default_tree.json: -------------------------------------------------------------------------------- 1 | { 2 | "0": [ 3 | 1, 4 | 2, 5 | 3, 6 | 4 7 | ], 8 | "1": [ 9 | 5, 10 | 6, 11 | 7 12 | ], 13 | "2": [ 14 | 8, 15 | 9 16 | ], 17 | "3": [ 18 | 10, 19 | 11 20 | ], 21 | "4": [ 22 | 12 23 | ], 24 | "5": [ 25 | 13, 26 | 14, 27 | 15 28 | ], 29 | "6": [ 30 | 16, 31 | 17 32 | ], 33 | "7": [ 34 | 18, 35 | 19 36 | ], 37 | "8": [ 38 | 20 39 | ], 40 | "13": [ 41 | 21, 42 | 22, 43 | 23 44 | ], 45 | "21": [], 46 | "9": [], 47 | "10": [], 48 | "11": [], 49 | "12": [], 50 | "14": [], 51 | "15": [], 52 | "16": [], 53 | "17": [], 54 | "18": [], 55 | "19": [], 56 | "20": [], 57 | "22": [], 58 | "23": [] 59 | } 60 | -------------------------------------------------------------------------------- /samd_sam_only/config/default_tree.json.bak0: -------------------------------------------------------------------------------- 1 | { 2 | "0": [ 3 | 1, 4 | 2, 5 | 3, 6 | 4, 7 | 5, 8 | 6, 9 | 7, 10 | 8 11 | ], 12 | "1": [ 13 | 9, 14 | 10, 15 | 11, 16 | 12, 17 | 13, 18 | 14, 19 | 15, 20 | 16 21 | ], 22 | "2": [ 23 | 17, 24 | 18, 25 | 19, 26 | 20 27 | ], 28 | "3": [ 29 | 21, 30 | 22, 31 | 23 32 | ], 33 | "4": [ 34 | 24, 35 | 25 36 | ], 37 | "5": [ 38 | 26 39 | ], 40 | "6": [ 41 | 27 42 | ], 43 | "7": [ 44 | 28 45 | ], 46 | "8": [ 47 | 29 48 | ], 49 | "9": [ 50 | 30, 51 | 31, 52 | 32, 53 | 33, 54 | 34, 55 | 35, 56 | 36, 57 | 37 58 | ], 59 | "10": [ 60 | 38, 61 | 39, 62 | 40 63 | ], 64 | "11": [ 65 | 41, 66 | 42 67 | ], 68 | "12": [ 69 | 43 70 | ], 71 | "13": [ 72 | 44 73 | ], 74 | "14": [ 75 | 45 76 | ], 77 | "15": [ 78 | 46 79 | ], 80 | "16": [ 81 | 47 82 | ], 83 | "17": [ 84 | 48, 85 | 49 86 | ], 87 | "18": [ 88 | 50 89 | ], 90 | "21": [ 91 | 51 92 | ], 93 | "24": [ 94 | 52 95 | ], 96 | "26": [ 97 | 53 98 | ], 99 | "27": [ 100 | 54 101 | ], 102 | "30": [ 103 | 55, 104 | 56, 105 | 57, 106 | 58, 107 | 59 108 | ], 109 | "31": [ 110 | 60, 111 | 61 112 | ], 113 | "32": [ 114 | 62 115 | ], 116 | "19": [], 117 | "20": [], 118 | "22": [], 119 | "23": [], 120 | "25": [], 121 | "28": [], 122 | "29": [], 123 | "33": [ 124 | 63 125 | ], 126 | "34": [], 
127 | "35": [], 128 | "36": [], 129 | "37": [], 130 | "38": [], 131 | "39": [], 132 | "40": [], 133 | "41": [], 134 | "42": [], 135 | "43": [], 136 | "44": [], 137 | "45": [], 138 | "46": [], 139 | "47": [], 140 | "48": [], 141 | "49": [], 142 | "50": [], 143 | "51": [], 144 | "52": [], 145 | "53": [], 146 | "54": [], 147 | "55": [], 148 | "56": [], 149 | "57": [], 150 | "58": [], 151 | "59": [], 152 | "60": [], 153 | "61": [], 154 | "62": [], 155 | "63": [] 156 | } -------------------------------------------------------------------------------- /samd_sam_only/config/default_tree.json.bak1: -------------------------------------------------------------------------------- 1 | { 2 | "0": [ 3 | 1, 2, 3, 4 4 | ], 5 | "1": [ 6 | 5 7 | ], 8 | "2": [], 9 | "3": [], 10 | "4": [], 11 | "5": [ 12 | 6 13 | ], 14 | "6": [ 15 | 7 16 | ], 17 | "7": [] 18 | } 19 | -------------------------------------------------------------------------------- /samd_sam_only/config/default_tree.json.bak2: -------------------------------------------------------------------------------- 1 | { 2 | "0": [ 3 | 1 4 | ], 5 | "1": [ 6 | 2 7 | ], 8 | "2": [ 9 | 3 10 | ], 11 | "3": [ 12 | 4 13 | ], 14 | "4": [] 15 | } 16 | -------------------------------------------------------------------------------- /samd_sam_only/config/default_tree.json.bak3: -------------------------------------------------------------------------------- 1 | { 2 | "0": [ 3 | 1, 4 | 2, 5 | 3, 6 | 4 7 | ], 8 | "1": [ 9 | 5, 10 | 6, 11 | 7 12 | ], 13 | "2": [ 14 | 8, 15 | 9 16 | ], 17 | "3": [ 18 | 10, 19 | 11 20 | ], 21 | "4": [ 22 | 12 23 | ], 24 | "5": [ 25 | 13, 26 | 14, 27 | 15 28 | ], 29 | "6": [ 30 | 16, 31 | 17 32 | ], 33 | "7": [ 34 | 18, 35 | 19 36 | ], 37 | "8": [ 38 | 20 39 | ], 40 | "13": [ 41 | 21, 42 | 22, 43 | 23 44 | ], 45 | "21": [], 46 | "9": [], 47 | "10": [], 48 | "11": [], 49 | "12": [], 50 | "14": [], 51 | "15": [], 52 | "16": [], 53 | "17": [], 54 | "18": [], 55 | "19": [], 56 | "20": [], 57 | "22": [], 58 | "23": [] 59 | } 60 | -------------------------------------------------------------------------------- /samd_sam_only/config/default_tree.json.bak4: -------------------------------------------------------------------------------- 1 | { 2 | "0": [ 3 | 1, 4 | 2, 5 | 3, 6 | 4 7 | ], 8 | "1": [ 9 | 5, 10 | 6, 11 | 7 12 | ], 13 | "2": [ 14 | 8, 15 | 9 16 | ], 17 | "3": [ 18 | 10, 19 | 11 20 | ], 21 | "4": [ 22 | 12 23 | ], 24 | "5": [ 25 | 13, 26 | 14, 27 | 15 28 | ], 29 | "6": [ 30 | 16, 31 | 17 32 | ], 33 | "7": [ 34 | 18, 35 | 19 36 | ], 37 | "8": [ 38 | 20 39 | ], 40 | "13": [ 41 | 21, 42 | 22, 43 | 23 44 | ], 45 | "21": [ 46 | 24, 47 | 25 48 | ], 49 | "9": [], 50 | "10": [], 51 | "11": [], 52 | "12": [], 53 | "14": [], 54 | "15": [], 55 | "16": [], 56 | "17": [], 57 | "18": [], 58 | "19": [], 59 | "20": [], 60 | "22": [], 61 | "23": [], 62 | "24": [ 63 | 25 64 | ], 65 | "25": [ 66 | 26 67 | ], 68 | "26": [ 69 | 27 70 | ], 71 | "27": [ 72 | 28 73 | ], 74 | "28": [] 75 | } 76 | -------------------------------------------------------------------------------- /samd_sam_only/config/default_tree.json.bak5: -------------------------------------------------------------------------------- 1 | { 2 | "0": [ 3 | 1, 4 | 2, 5 | 3, 6 | 4 7 | ], 8 | "1": [ 9 | 5, 10 | 6, 11 | 7 12 | ], 13 | "2": [ 14 | 8, 15 | 9 16 | ], 17 | "3": [ 18 | 10, 19 | 11 20 | ], 21 | "4": [ 22 | 12 23 | ], 24 | "5": [ 25 | 13, 26 | 14, 27 | 15 28 | ], 29 | "6": [ 30 | 16, 31 | 17 32 | ], 33 | "7": [ 34 | 18, 35 | 19 36 | ], 37 | "8": [ 38 | 20 39 | ], 40 | "13": [ 41 | 21, 
42 | 22, 43 | 23 44 | ], 45 | "21": [ 46 | 24, 47 | 25 48 | ], 49 | "9": [], 50 | "10": [], 51 | "11": [], 52 | "12": [], 53 | "14": [], 54 | "15": [], 55 | "16": [], 56 | "17": [], 57 | "18": [], 58 | "19": [], 59 | "20": [], 60 | "22": [], 61 | "23": [], 62 | "24": [ 63 | 25 64 | ], 65 | "25": [ 66 | 26 67 | ], 68 | "26": [ 69 | 27 70 | ], 71 | "27": [ 72 | 28 73 | ], 74 | "28": [] 75 | } 76 | -------------------------------------------------------------------------------- /samd_sam_only/config/default_tree_1_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "tree_adj": { 3 | "0": [] 4 | } 5 | } 6 | -------------------------------------------------------------------------------- /samd_sam_only/config/default_tree_6_60.json: -------------------------------------------------------------------------------- 1 | { 2 | "tree_adj": { 3 | "0": [ 4 | 1, 5 | 2, 6 | 3, 7 | 4, 8 | 5, 9 | 6, 10 | 7 11 | ], 12 | "1": [ 13 | 8, 14 | 9, 15 | 10, 16 | 11, 17 | 12, 18 | 13 19 | ], 20 | "2": [ 21 | 14, 22 | 15, 23 | 16, 24 | 17, 25 | 18 26 | ], 27 | "3": [ 28 | 19, 29 | 20, 30 | 21 31 | ], 32 | "4": [ 33 | 22, 34 | 23 35 | ], 36 | "5": [ 37 | 24, 38 | 25 39 | ], 40 | "6": [ 41 | 26 42 | ], 43 | "7": [ 44 | 27 45 | ], 46 | "8": [ 47 | 28, 48 | 29, 49 | 30 50 | ], 51 | "13": [ 52 | ], 53 | "9": [ 54 | 31, 55 | 32 56 | ], 57 | "10": [ 58 | 33, 59 | 34 60 | ], 61 | "11": [ 62 | 35 63 | ], 64 | "12": [ 65 | 36 66 | ], 67 | "14": [ 68 | 37, 69 | 38, 70 | 39 71 | ], 72 | "15": [ 73 | 40, 74 | 41 75 | ], 76 | "16": [ 77 | 42 78 | ], 79 | "17": [ 80 | 43 81 | ], 82 | "18": [ 83 | ], 84 | "19": [ 85 | 44 86 | ], 87 | "20": [ 88 | 45 89 | ], 90 | "21": [ 91 | ], 92 | "22": [ 93 | 46 94 | ], 95 | "23": [ 96 | ], 97 | "24": [ 98 | 47 99 | ], 100 | "25": [ 101 | ], 102 | "26": [ 103 | 48 104 | ], 105 | "27": [ 106 | ], 107 | "28": [ 108 | 49, 109 | 50 110 | ], 111 | "29": [ 112 | 51 113 | ], 114 | "30": [ 115 | ], 116 | "31": [ 117 | 52 118 | ], 119 | "32": [], 120 | "33": [ 121 | ], 122 | "34": [ 123 | ], 124 | "35": [ 125 | ], 126 | "36": [], 127 | "37": [ 128 | 53, 129 | 54 130 | ], 131 | "38": [ 132 | ], 133 | "39": [], 134 | "40": [ 135 | 55 136 | ], 137 | "41": [], 138 | "42": [], 139 | "43": [ 140 | ], 141 | "44": [ 142 | 56 143 | ], 144 | "45": [], 145 | "46": [ 146 | ], 147 | "47": [], 148 | "48": [], 149 | "49": [ 150 | 57, 151 | 58 152 | ], 153 | "50": [], 154 | "51": [ 155 | 59 156 | ], 157 | "52": [ 158 | ], 159 | "53": [ 160 | 60 161 | ], 162 | "54": [ 163 | ], 164 | "55": [ 165 | ], 166 | "56": [], 167 | "57": [], 168 | "58": [], 169 | "59": [], 170 | "60": [] 171 | } 172 | } -------------------------------------------------------------------------------- /samd_sam_only/config/eagle.json: -------------------------------------------------------------------------------- 1 | { 2 | "tree_choices": [ 3 | [ 4 | 0 5 | ], 6 | [ 7 | 1 8 | ], 9 | [ 10 | 2 11 | ], 12 | [ 13 | 3 14 | ], 15 | [ 16 | 0, 17 | 0 18 | ], 19 | [ 20 | 0, 21 | 1 22 | ], 23 | [ 24 | 0, 25 | 2 26 | ], 27 | [ 28 | 1, 29 | 0 30 | ], 31 | [ 32 | 1, 33 | 1 34 | ], 35 | [ 36 | 2, 37 | 0 38 | ], 39 | [ 40 | 2, 41 | 1 42 | ], 43 | [ 44 | 3, 45 | 0 46 | ], 47 | [ 48 | 0, 49 | 0, 50 | 0 51 | ], 52 | [ 53 | 0, 54 | 0, 55 | 1 56 | ], 57 | [ 58 | 0, 59 | 0, 60 | 2 61 | ], 62 | [ 63 | 0, 64 | 1, 65 | 0 66 | ], 67 | [ 68 | 0, 69 | 1, 70 | 1 71 | ], 72 | [ 73 | 0, 74 | 2, 75 | 0 76 | ], 77 | [ 78 | 0, 79 | 2, 80 | 1 81 | ], 82 | [ 83 | 1, 84 | 0, 85 | 0 86 | ], 87 | [ 88 | 0, 89 | 0, 90 | 0, 91 | 0 92 | ], 93 
| [ 94 | 0, 95 | 0, 96 | 0, 97 | 1 98 | ], 99 | [ 100 | 0, 101 | 0, 102 | 0, 103 | 2 104 | ], 105 | [ 106 | 0, 107 | 0, 108 | 0, 109 | 0, 110 | 0 111 | ], 112 | [ 113 | 0, 114 | 0, 115 | 0, 116 | 0, 117 | 1 118 | ] 119 | ] 120 | } -------------------------------------------------------------------------------- /samd_sam_only/config/token_recycle.json: -------------------------------------------------------------------------------- 1 | { 2 | "tree_adj": { 3 | "0": [ 4 | 1, 5 | 2, 6 | 3, 7 | 4, 8 | 5, 9 | 6, 10 | 7 11 | ], 12 | "1": [ 13 | 8, 14 | 9, 15 | 10, 16 | 11, 17 | 12, 18 | 13 19 | ], 20 | "2": [ 21 | 14, 22 | 15, 23 | 16, 24 | 17, 25 | 18 26 | ], 27 | "3": [ 28 | 19, 29 | 20, 30 | 21 31 | ], 32 | "4": [ 33 | 22, 34 | 23 35 | ], 36 | "5": [ 37 | 24, 38 | 25 39 | ], 40 | "6": [ 41 | 26 42 | ], 43 | "7": [ 44 | 27 45 | ], 46 | "8": [ 47 | 28, 48 | 29, 49 | 30 50 | ], 51 | "13": [ 52 | ], 53 | "9": [ 54 | 31, 55 | 32 56 | ], 57 | "10": [ 58 | 33, 59 | 34 60 | ], 61 | "11": [ 62 | 35 63 | ], 64 | "12": [ 65 | 36 66 | ], 67 | "14": [ 68 | 37, 69 | 38, 70 | 39 71 | ], 72 | "15": [ 73 | 40, 74 | 41 75 | ], 76 | "16": [ 77 | 42 78 | ], 79 | "17": [ 80 | 43 81 | ], 82 | "18": [ 83 | ], 84 | "19": [ 85 | 44 86 | ], 87 | "20": [ 88 | 45 89 | ], 90 | "21": [ 91 | ], 92 | "22": [ 93 | 46 94 | ], 95 | "23": [ 96 | ], 97 | "24": [ 98 | 47 99 | ], 100 | "25": [ 101 | ], 102 | "26": [ 103 | 48 104 | ], 105 | "27": [ 106 | ], 107 | "28": [ 108 | 49, 109 | 50 110 | ], 111 | "29": [ 112 | 51 113 | ], 114 | "30": [ 115 | ], 116 | "31": [ 117 | 52 118 | ], 119 | "32": [], 120 | "33": [ 121 | ], 122 | "34": [ 123 | ], 124 | "35": [ 125 | ], 126 | "36": [], 127 | "37": [ 128 | 53, 129 | 54 130 | ], 131 | "38": [ 132 | ], 133 | "39": [], 134 | "40": [ 135 | 55 136 | ], 137 | "41": [], 138 | "42": [], 139 | "43": [ 140 | ], 141 | "44": [ 142 | 56 143 | ], 144 | "45": [], 145 | "46": [ 146 | ], 147 | "47": [], 148 | "48": [], 149 | "49": [ 150 | 57, 151 | 58 152 | ], 153 | "50": [], 154 | "51": [ 155 | 59 156 | ], 157 | "52": [ 158 | ], 159 | "53": [ 160 | 60 161 | ], 162 | "54": [ 163 | ], 164 | "55": [ 165 | ], 166 | "56": [], 167 | "57": [], 168 | "58": [], 169 | "59": [], 170 | "60": [] 171 | } 172 | } -------------------------------------------------------------------------------- /samd_sam_only/draft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Tuple, Dict, Optional 3 | from enum import Enum 4 | from collections import namedtuple 5 | 6 | from .samd_config import SamdConfig 7 | from .sam import DynSAM, StaticSAM 8 | from profile_utils import profile_decorator 9 | from transformers import LlamaConfig, LlamaForCausalLM 10 | 11 | # from transformers import LlamaTokenizer 12 | # tokenizer: LlamaTokenizer = LlamaTokenizer.from_pretrained('/data/models/vicuna-7b-v1.3') 13 | 14 | class CandidateType(str, Enum): 15 | sequence = "sequence" 16 | tree = "tree" 17 | 18 | Candidates = namedtuple('Candidates', ['type', 'tokens', 'candidate_tokens', 'buffers_kwargs']) 19 | 20 | TOPK = 8 21 | 22 | class DraftModel(torch.nn.Module): 23 | 24 | def __init__(self, 25 | config: SamdConfig, 26 | sam_dyn: DynSAM = None, 27 | sam_static: StaticSAM = None, 28 | lm: LlamaForCausalLM = None, 29 | dtype: torch.dtype = torch.float16, 30 | device: str = "cuda", 31 | ) -> None: 32 | super().__init__() 33 | self.config = config 34 | self.sam_dyn = sam_dyn if sam_dyn is not None else DynSAM(config.max_predicts, config.alpha, device) 35 | self.sam_static 
= sam_static if sam_static is not None else StaticSAM(config.max_predicts, config.alpha, device) 36 | self.sam_dyn.max_predicts = config.max_predicts 37 | self.sam_dyn.alpha = config.alpha 38 | self.sam_static.max_predicts = config.max_predicts 39 | self.sam_static.alpha = config.alpha 40 | self.sam_static.K = config.K 41 | self.sam_static.device = device 42 | self.len_bias = config.len_bias 43 | 44 | @profile_decorator("DraftModel.reset") 45 | def reset(self): 46 | self.sam_dyn.reset() 47 | self.sam_static.reset() 48 | 49 | @profile_decorator("DraftModel.lookup") 50 | def lookup(self, start_token: int): 51 | index_dyn, match_dyn = self.sam_dyn.lookup(start_token) 52 | index_static, match_static = self.sam_static.lookup(start_token) 53 | match_static -= self.len_bias 54 | if match_dyn >= match_static: 55 | seq, buffers_kwargs = self.sam_dyn.gen_draft(index_dyn, match_dyn, start_token) 56 | return (CandidateType.sequence, seq, buffers_kwargs) 57 | else: 58 | tree, buffers_kwargs = self.sam_static.gen_draft(index_static, match_static, start_token) 59 | return (CandidateType.tree, tree, buffers_kwargs) 60 | 61 | @profile_decorator("DraftModel.update") 62 | def update(self, 63 | tokens: Optional[torch.Tensor] = None, 64 | ): 65 | tokens_list = tokens.tolist() 66 | self.sam_dyn.add_tokens(tokens_list) 67 | self.sam_static.transfer_tokens(tokens_list) 68 | 69 | @profile_decorator("DraftModel.prefill_update") 70 | def prefill_update(self, 71 | tokens: Optional[torch.Tensor] = None, 72 | ): 73 | self.update(tokens) 74 | -------------------------------------------------------------------------------- /samd_sam_only/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyx1999/SAM-Decoding/18c41f055b424fa3fa0bac41a8953d34cea1ed77/samd_sam_only/inference/__init__.py -------------------------------------------------------------------------------- /samd_sam_only/model_patch/__init__.py: -------------------------------------------------------------------------------- 1 | from .llama import llama_patch_dict, llama_attn_patch_dict 2 | 3 | patch_dict = {} 4 | attn_patch_dict = {} 5 | 6 | patch_dict.update(llama_patch_dict) 7 | attn_patch_dict.update(llama_attn_patch_dict) 8 | -------------------------------------------------------------------------------- /samd_sam_only/sam/__init__.py: -------------------------------------------------------------------------------- 1 | # from .sam import DynSAM, StaticSAM 2 | from .dyn_sam import DynSAM 3 | from .static_sam import StaticSAM 4 | from .utils import build_sam, dump_sam, load_sam 5 | -------------------------------------------------------------------------------- /samd_sam_only/sam/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import pickle 3 | from datasets import Dataset 4 | from transformers import PreTrainedTokenizerFast 5 | from typing import List 6 | 7 | from .static_sam import StaticSAM 8 | from ..samd_config import SamdConfig 9 | 10 | def build_sam( 11 | batch_tokens: List[List[int]], 12 | eos_token: int, 13 | ): 14 | sam = StaticSAM.build( 15 | batch_tokens, 16 | eos_token, 17 | ) 18 | return sam 19 | 20 | def dump_sam(path: str, sam: StaticSAM): 21 | with open(path, "wb") as f: 22 | pickle.dump(sam, f) 23 | 24 | def load_sam(path: str): 25 | print("load sam...") 26 | start = time.perf_counter() 27 | with open(path, "rb") as f: 28 | _sam = pickle.load(f) 29 | sam = StaticSAM() 30 | for key, value in 
vars(_sam).items(): 31 | if hasattr(sam, key): 32 | setattr(sam, key, value) 33 | print("load [{}]".format(key)) 34 | end = time.perf_counter() 35 | assert type(sam) is StaticSAM 36 | print("loading ended in {} seconds.".format(end - start)) 37 | if sam.states_topk_next is None: 38 | sam.init_topk_next() 39 | return sam 40 | -------------------------------------------------------------------------------- /samd_sam_only/samd_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | from dataclasses import dataclass, field 5 | from typing import Optional, Union, List, Literal, Dict, Any 6 | from enum import Enum 7 | 8 | 9 | @dataclass 10 | class SamdConfig: 11 | max_predicts: int = field(default=60) 12 | alpha: float = field(default=4.0) 13 | K: int = field(default=8) 14 | len_bias: int = field(default=5) 15 | cache_type: Literal["dynamic", "static"] = field( 16 | default="static" 17 | ) 18 | 19 | class ForwardType(str, Enum): 20 | prefill = "prefill" 21 | seq_decode = "seq_decode" 22 | tree_decode = "tree_decode" 23 | 24 | 25 | class ForwardState: 26 | 27 | def __init__(self, forward_type: ForwardType | None) -> None: 28 | self.forward_type = forward_type 29 | 30 | 31 | class MaskState: 32 | 33 | def __init__(self, mask: Optional[torch.Tensor]) -> None: 34 | self.mask = mask 35 | 36 | def set_state(self, mask: Optional[torch.Tensor]) -> None: 37 | self.mask = mask 38 | 39 | 40 | def load_token_recycle(tree_path: Optional[str] = None): 41 | if tree_path is None: 42 | tree_path = "token_recycle.json" 43 | samd_path = os.path.dirname(__file__) 44 | with open(os.path.join(samd_path, "config", tree_path), "r") as f: 45 | tree_adj: dict = json.load(f)["tree_adj"] 46 | num_node = len(tree_adj) 47 | tree: List[List[int]] = [] 48 | for i in range(num_node): 49 | tree.append(tree_adj[str(i)]) 50 | print("tree_path:", tree_path) 51 | print("len_tree:", len(tree)) 52 | return tree 53 | 54 | 55 | def load_eagle(tree_model_path: str, tree_path: Optional[str] = None): 56 | if tree_path is None: 57 | tree_path = "eagle.json" 58 | samd_path = os.path.dirname(__file__) 59 | with open(os.path.join(samd_path, "config", tree_path), "r") as f: 60 | tree = json.load(f)["tree_choices"] 61 | with open(os.path.join(tree_model_path, "config.json")) as f: 62 | tree_config = json.load(f) 63 | return tree, tree_config 64 | 65 | 66 | def load_eagle2(tree_model_path: str): 67 | with open(os.path.join(tree_model_path, "config.json")) as f: 68 | tree_config = json.load(f) 69 | return tree_config 70 | -------------------------------------------------------------------------------- /scripts/equal.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | set -x 4 | 5 | cd $(dirname $0)/.. 6 | 7 | devices=0 8 | 9 | CUDA_VISIBLE_DEVICES=${devices} \ 10 | python -m evaluation.equal \ 11 | --jsonfile1 vicuna-7b-v1.3.jsonl \ 12 | --jsonfile2 vicuna-7b-v1.3-sam_alpaca-v0.4.2.jsonl 13 | -------------------------------------------------------------------------------- /scripts/inference_baseline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | set -x 4 | 5 | cd $(dirname $0)/.. 
6 | 7 | devices=0 8 | 9 | # vicuna 10 | CUDA_VISIBLE_DEVICES=${devices} \ 11 | python -m evaluation.inference_baseline \ 12 | --model-type vicuna \ 13 | --bench-name spec_bench \ 14 | --model-path /data/models/vicuna-7b-v1.3 \ 15 | --model-id vicuna-7b-v1.3 16 | -------------------------------------------------------------------------------- /scripts/inference_eagle.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | set -x 4 | 5 | cd $(dirname $0)/.. 6 | 7 | devices=0 8 | 9 | # vicuna-7b-v1.3 10 | CUDA_VISIBLE_DEVICES=${devices} python -m evaluation.inference_eagle \ 11 | --model-type vicuna \ 12 | --ea-model-path /data/models/EAGLE-Vicuna-7B-v1.3 \ 13 | --base-model-path /data/models/vicuna-7b-v1.3 \ 14 | --model-id vicuna-7b-v1.3-eagle \ 15 | --bench-name spec_bench \ 16 | --temperature 0 -------------------------------------------------------------------------------- /scripts/inference_eagle2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | set -x 4 | 5 | cd $(dirname $0)/.. 6 | 7 | devices=0 8 | 9 | # vicuna-7b-v1.3 10 | CUDA_VISIBLE_DEVICES=${devices} python -m evaluation.inference_eagle2 \ 11 | --model-type vicuna \ 12 | --ea-model-path /data/models/EAGLE-Vicuna-7B-v1.3 \ 13 | --base-model-path /data/models/vicuna-7b-v1.3 \ 14 | --model-id vicuna-7b-v1.3-eagle2 \ 15 | --bench-name spec_bench \ 16 | --temperature 0 17 | -------------------------------------------------------------------------------- /scripts/inference_pld.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | set -x 4 | 5 | cd $(dirname $0)/.. 6 | 7 | devices=0 8 | 9 | # vicuna-7b-v1.3 10 | CUDA_VISIBLE_DEVICES=${devices} \ 11 | python -m evaluation.inference_pld \ 12 | --model-type vicuna \ 13 | --bench-name spec_bench \ 14 | --model-path /data/models/vicuna-7b-v1.3 \ 15 | --model-id vicuna-7b-v1.3-pld 16 | -------------------------------------------------------------------------------- /scripts/inference_samd.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | set -x 4 | 5 | cd $(dirname $0)/.. 6 | 7 | devices=0 8 | 9 | # vicuna-7b-v1.3 10 | CUDA_VISIBLE_DEVICES=${devices} \ 11 | python -m evaluation.inference_samd \ 12 | --model-type vicuna \ 13 | --bench-name spec_bench \ 14 | --model-path /data/models/vicuna-7b-v1.3 \ 15 | --model-id vicuna-7b-v1.3-samd-eagle2 \ 16 | --sam_path local_cache/sam_alpaca_vicuna-7b-v1.3_min-endpos.pkl \ 17 | --tree_method eagle2 \ 18 | --samd_n_predicts 40 \ 19 | --samd_len_threshold 5 \ 20 | --samd_len_bias 5 21 | -------------------------------------------------------------------------------- /scripts/inference_samd_sam_only.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | set -x 4 | 5 | cd $(dirname $0)/.. 
6 | 7 | devices=0 8 | 9 | # vicuna-7b-v1.3 10 | CUDA_VISIBLE_DEVICES=${devices} \ 11 | python -m evaluation.inference_sam_only \ 12 | --model-type vicuna \ 13 | --bench-name spec_bench \ 14 | --model-path /data/models/vicuna-7b-v1.3 \ 15 | --model-id vicuna-7b-v1.3-samd_sam_only \ 16 | --sam_path local_cache/sam_alpaca_vicuna-7b-v1.3.pkl \ 17 | --samd_max_predicts 60 \ 18 | --samd_alpha 4.0 \ 19 | --samd_len_bias 0 20 | -------------------------------------------------------------------------------- /scripts/inference_token_recycle.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | set -x 4 | 5 | cd $(dirname $0)/.. 6 | 7 | devices=0 8 | 9 | CUDA_VISIBLE_DEVICES=${devices} \ 10 | python -m evaluation.inference_token_recycle \ 11 | --model-type vicuna \ 12 | --bench-name spec_bench \ 13 | --model-path /data/models/vicuna-7b-v1.3 \ 14 | --model-id vicuna-7b-v1.3-token_recycle 15 | -------------------------------------------------------------------------------- /scripts/speed.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | set -x 4 | 5 | cd $(dirname $0)/.. 6 | 7 | python -m evaluation.speed \ 8 | --file-path evaluation/data/spec_bench/model_answer/vicuna-7b-v1.3-samd_sam_only.jsonl 9 | 10 | python -m evaluation.speed \ 11 | --file-path evaluation/data/spec_bench/model_answer/vicuna-7b-v1.3-samd-token_recycle.jsonl 12 | 13 | python -m evaluation.speed \ 14 | --file-path evaluation/data/spec_bench/model_answer/vicuna-7b-v1.3-samd-eagle2.jsonl 15 | -------------------------------------------------------------------------------- /scripts/test_samd.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | set -x 4 | 5 | cd $(dirname $0)/.. 6 | 7 | devices=0 8 | 9 | CUDA_VISIBLE_DEVICES=${devices} \ 10 | python -m tests.test_samd \ 11 | --sam_path local_cache/sam_alpaca_vicuna-7b-v1.3_min-endpos.pkl \ 12 | --model_path /data/models/vicuna-7b-v1.3 \ 13 | --device "cuda" \ 14 | --tree_method eagle2 15 | -------------------------------------------------------------------------------- /scripts/test_samd_sam_only.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | set -x 4 | 5 | cd $(dirname $0)/.. 
6 | 7 | devices=0 8 | 9 | CUDA_VISIBLE_DEVICES=${devices} \ 10 | python -m tests.test_samd_sam_only \ 11 | --sam_path local_cache/sam_alpaca_vicuna-7b-v1.3.pkl \ 12 | --model_path /data/models/vicuna-7b-v1.3 \ 13 | --device "cuda" 14 | -------------------------------------------------------------------------------- /tests/test_samd.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from transformers import ( 4 | AutoModelForCausalLM, 5 | AutoTokenizer, 6 | GenerationConfig, 7 | GenerationMixin, 8 | LlamaConfig, 9 | LlamaTokenizer 10 | ) 11 | from samd import ( 12 | SamdConfig, 13 | SamdModel, 14 | SamdGenerationConfig, 15 | DraftModel, 16 | load_sam 17 | ) 18 | import time 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--model_path', type=str, required=True) 23 | parser.add_argument('--sam_path', type=str, default=None) 24 | parser.add_argument('--samd_n_predicts', type=int, default=15) 25 | parser.add_argument('--max_new_tokens', type=int, default=512) 26 | parser.add_argument('--max_cache_len', type=int, default=2048) 27 | parser.add_argument("--tree_method", type=str, default="token_recycle") 28 | parser.add_argument("--tree_model_path", type=str, default="/data/models/EAGLE-Vicuna-7B-v1.3") 29 | parser.add_argument('--dtype', type=str, default='float16', choices=['float16', 'float32']) 30 | parser.add_argument('--device', type=str, default="cuda", choices=['cuda', 'cpu']) 31 | args = parser.parse_args() 32 | args.dtype = { 33 | 'float16': torch.float16, 34 | 'float32': torch.float32, 35 | }[args.dtype] 36 | return args 37 | 38 | @torch.inference_mode() 39 | def generate(args, inputs, model, tokenizer): 40 | model.eval() 41 | assert inputs.input_ids.shape[-1] + args.max_new_tokens <= args.max_cache_len 42 | gen_config = SamdGenerationConfig( 43 | max_new_tokens=args.max_new_tokens, 44 | max_cache_len=args.max_cache_len, 45 | greedy=True, 46 | temperature=0.0 47 | ) 48 | st = time.perf_counter() 49 | tokens = model.generate(**inputs, generation_config=gen_config)[0] 50 | ed = time.perf_counter() 51 | response = tokenizer.decode(tokens) 52 | print("model inference time use: {} seconds".format(ed - st)) 53 | print("model response:\n{}".format(repr(response))) 54 | 55 | 56 | @torch.inference_mode() 57 | def samd_generate(args, inputs, model, tokenizer): 58 | assert inputs.input_ids.shape[-1] + args.max_new_tokens <= args.max_cache_len 59 | sam = load_sam(args.sam_path) if args.sam_path is not None else None 60 | samd_config = SamdConfig( 61 | n_predicts=args.samd_n_predicts, 62 | tree_method=args.tree_method, 63 | tree_model_path=args.tree_model_path, 64 | ) 65 | draft = DraftModel( 66 | samd_config, 67 | sam_static=sam, 68 | lm=model, 69 | dtype=args.dtype, 70 | device=args.device 71 | ) 72 | samd_model = SamdModel( 73 | samd_config, 74 | model, 75 | draft, 76 | tokenizer.eos_token_id, 77 | args.dtype, 78 | args.device, 79 | ) 80 | samd_model.eval() 81 | 82 | gen_config = SamdGenerationConfig( 83 | max_new_tokens=args.max_new_tokens, 84 | max_cache_len=args.max_cache_len, 85 | ) 86 | 87 | st = time.perf_counter() 88 | outputs = samd_model.generate(**inputs, generation_config=gen_config) 89 | ed = time.perf_counter() 90 | response = tokenizer.decode(outputs.output_ids[0]) 91 | print("model inference time use: {} seconds".format(ed - st)) 92 | print("samd_model response:\n{}".format(repr(response))) 93 | print("decode_steps: {}".format(outputs.decode_steps)) 94 | 
print("decode_tokens: {}".format(outputs.decode_tokens)) 95 | print("accepect_length_per_step: {}".format(outputs.accepet_length_per_step)) 96 | 97 | def main(): 98 | args = parse_args() 99 | tokenizer = AutoTokenizer.from_pretrained(args.model_path) 100 | model = AutoModelForCausalLM.from_pretrained( 101 | args.model_path, 102 | torch_dtype=args.dtype, 103 | device_map=args.device, 104 | ) 105 | model.eval() 106 | 107 | # prompts = ["A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\nUSER: Give three tips for staying healthy.\n\nASSISTANT: "] 108 | 109 | # prompts = ['A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\'s questions.\n\nUSER: Please generate the following: "1, 2, 3, 4, 5, 6, 7, 8, 9, 10".\n\nASSISTANT: '] 110 | 111 | prompts = ["A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\'s questions.\n\nUSER: Embrace the role of Sheldon from \"The Big Bang Theory\" as we delve into our conversation. Don\u2019t start with phrases like \"As Sheldon\". Let's kick things off with the following question: \"What is your opinion on hand dryers?\"\n\nASSISTANT: "] 112 | 113 | inputs = tokenizer( 114 | prompts, 115 | padding=True, 116 | return_tensors="pt" 117 | ).to(args.device) 118 | 119 | # generate(args, inputs, model, tokenizer) 120 | 121 | samd_generate(args, inputs, model, tokenizer) 122 | 123 | if __name__ == '__main__': 124 | main() 125 | -------------------------------------------------------------------------------- /tests/test_samd_sam_only.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from transformers import ( 4 | AutoModelForCausalLM, 5 | AutoTokenizer, 6 | GenerationConfig, 7 | GenerationMixin, 8 | LlamaConfig, 9 | LlamaTokenizer 10 | ) 11 | from samd_sam_only import ( 12 | SamdConfig, 13 | SamdModel, 14 | SamdGenerationConfig, 15 | DraftModel, 16 | load_sam 17 | ) 18 | import time 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--model_path', type=str, required=True) 23 | parser.add_argument('--sam_path', type=str, default=None) 24 | parser.add_argument("--samd_max_predicts", type=int, default=40) 25 | parser.add_argument("--samd_alpha", type=float, default=4.0) 26 | parser.add_argument("--samd_len_bias", type=int, default=5) 27 | parser.add_argument('--max_new_tokens', type=int, default=512) 28 | parser.add_argument('--max_cache_len', type=int, default=2048) 29 | parser.add_argument('--dtype', type=str, default='float16', choices=['float16', 'float32']) 30 | parser.add_argument('--device', type=str, default="cuda", choices=['cuda', 'cpu']) 31 | args = parser.parse_args() 32 | args.dtype = { 33 | 'float16': torch.float16, 34 | 'float32': torch.float32, 35 | }[args.dtype] 36 | return args 37 | 38 | @torch.inference_mode() 39 | def generate(args, inputs, model, tokenizer): 40 | model.eval() 41 | assert inputs.input_ids.shape[-1] + args.max_new_tokens <= args.max_cache_len 42 | gen_config = SamdGenerationConfig( 43 | max_new_tokens=args.max_new_tokens, 44 | max_cache_len=args.max_cache_len, 45 | greedy=True, 46 | temperature=0.0 47 | ) 48 | st = time.perf_counter() 49 | tokens = model.generate(**inputs, generation_config=gen_config)[0] 50 | ed = time.perf_counter() 51 | 
response = tokenizer.decode(tokens) 52 | print("model inference time use: {} seconds".format(ed - st)) 53 | print("model response:\n{}".format(repr(response))) 54 | 55 | 56 | @torch.inference_mode() 57 | def samd_generate(args, inputs, model, tokenizer): 58 | assert inputs.input_ids.shape[-1] + args.max_new_tokens <= args.max_cache_len 59 | sam = load_sam(args.sam_path) if args.sam_path is not None else None 60 | samd_config = SamdConfig( 61 | max_predicts=args.samd_max_predicts, 62 | alpha=args.samd_alpha, 63 | len_bias=args.samd_len_bias, 64 | ) 65 | draft = DraftModel( 66 | samd_config, 67 | sam_static=sam, 68 | lm=model, 69 | dtype=args.dtype, 70 | device=args.device 71 | ) 72 | samd_model = SamdModel( 73 | samd_config, 74 | model, 75 | draft, 76 | tokenizer.eos_token_id, 77 | args.dtype, 78 | args.device, 79 | ) 80 | samd_model.eval() 81 | 82 | gen_config = SamdGenerationConfig( 83 | max_new_tokens=args.max_new_tokens, 84 | max_cache_len=args.max_cache_len, 85 | ) 86 | 87 | st = time.perf_counter() 88 | outputs = samd_model.generate(**inputs, generation_config=gen_config) 89 | ed = time.perf_counter() 90 | response = tokenizer.decode(outputs.output_ids[0]) 91 | print("model inference time use: {} seconds".format(ed - st)) 92 | print("samd_model response:\n{}".format(repr(response))) 93 | print("decode_steps: {}".format(outputs.decode_steps)) 94 | print("decode_tokens: {}".format(outputs.decode_tokens)) 95 | print("accepect_length_per_step: {}".format(outputs.accepet_length_per_step)) 96 | 97 | def main(): 98 | args = parse_args() 99 | tokenizer = AutoTokenizer.from_pretrained(args.model_path) 100 | model = AutoModelForCausalLM.from_pretrained( 101 | args.model_path, 102 | torch_dtype=args.dtype, 103 | device_map=args.device, 104 | ) 105 | model.eval() 106 | 107 | # prompts = ["A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\nUSER: Give three tips for staying healthy.\n\nASSISTANT: "] 108 | 109 | # prompts = ['A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\'s questions.\n\nUSER: Please generate the following: "1, 2, 3, 4, 5, 6, 7, 8, 9, 10".\n\nASSISTANT: '] 110 | 111 | prompts = ["A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\'s questions.\n\nUSER: Embrace the role of Sheldon from \"The Big Bang Theory\" as we delve into our conversation. Don\u2019t start with phrases like \"As Sheldon\". 
Let's kick things off with the following question: \"What is your opinion on hand dryers?\"\n\nASSISTANT: "] 112 | 113 | inputs = tokenizer( 114 | prompts, 115 | padding=True, 116 | return_tensors="pt" 117 | ).to(args.device) 118 | 119 | # generate(args, inputs, model, tokenizer) 120 | 121 | samd_generate(args, inputs, model, tokenizer) 122 | 123 | if __name__ == '__main__': 124 | main() 125 | -------------------------------------------------------------------------------- /tests/test_token_recycle.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from transformers import ( 4 | AutoModelForCausalLM, 5 | AutoTokenizer, 6 | GenerationConfig, 7 | GenerationMixin, 8 | LlamaConfig, 9 | LlamaTokenizer 10 | ) 11 | from evaluation.model.token_recycle import ( 12 | TokenRecycleConfig, 13 | TokenRecycleModel, 14 | TokenRecycleGenerationConfig, 15 | DraftModel 16 | ) 17 | import time 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--model_path', type=str, required=True) 22 | parser.add_argument('--samd_n_predicts', type=int, default=15) 23 | parser.add_argument('--max_new_tokens', type=int, default=512) 24 | parser.add_argument('--max_cache_len', type=int, default=2048) 25 | parser.add_argument('--dtype', type=str, default='float16', choices=['float16', 'float32']) 26 | parser.add_argument('--device', type=str, default="cuda", choices=['cuda', 'cpu']) 27 | args = parser.parse_args() 28 | args.dtype = { 29 | 'float16': torch.float16, 30 | 'float32': torch.float32, 31 | }[args.dtype] 32 | return args 33 | 34 | @torch.inference_mode() 35 | def generate(args, inputs, model, tokenizer): 36 | assert inputs.input_ids.shape[-1] + args.max_new_tokens <= args.max_cache_len 37 | gen_config = GenerationConfig( 38 | max_new_tokens=args.max_new_tokens, cache_implementation="static", 39 | cache_config = { 40 | "batch_size": 1, 41 | "max_cache_len": args.max_cache_len, 42 | } 43 | ) 44 | st = time.perf_counter() 45 | tokens = model.generate(**inputs, generation_config=gen_config)[0] 46 | ed = time.perf_counter() 47 | response = tokenizer.decode(tokens) 48 | print("model inference time use: {} seconds".format(ed - st)) 49 | print("model response:\n{}".format(repr(response))) 50 | 51 | 52 | @torch.inference_mode() 53 | def token_recycle_generate(args, inputs, model, tokenizer): 54 | assert inputs.input_ids.shape[-1] + args.max_new_tokens <= args.max_cache_len 55 | token_recycle_config = TokenRecycleConfig(n_predicts=args.samd_n_predicts) 56 | draft = DraftModel(token_recycle_config) 57 | token_recycle_model = TokenRecycleModel( 58 | token_recycle_config, 59 | model, 60 | draft, 61 | tokenizer.eos_token_id, 62 | args.dtype, 63 | args.device, 64 | ) 65 | 66 | gen_config = TokenRecycleGenerationConfig( 67 | max_new_tokens=args.max_new_tokens, 68 | max_cache_len=args.max_cache_len, 69 | ) 70 | 71 | st = time.perf_counter() 72 | outputs = token_recycle_model.generate(**inputs, generation_config=gen_config) 73 | ed = time.perf_counter() 74 | response = tokenizer.decode(outputs.output_ids[0]) 75 | print("model inference time use: {} seconds".format(ed - st)) 76 | print("token_recycle_model response:\n{}".format(repr(response))) 77 | print("decode_steps: {}".format(outputs.decode_steps)) 78 | print("decode_tokens: {}".format(outputs.decode_tokens)) 79 | print("accepect_length_per_step: {}".format(outputs.accepet_length_per_step)) 80 | 81 | def main(): 82 | args = parse_args() 83 | tokenizer = 
AutoTokenizer.from_pretrained(args.model_path) 84 | model = AutoModelForCausalLM.from_pretrained( 85 | args.model_path, 86 | torch_dtype=args.dtype, 87 | device_map=args.device, 88 | ) 89 | model.eval() 90 | 91 | prompts = ["A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\nUSER: Give three tips for staying healthy.\n\nASSISTANT: "] 92 | 93 | # prompts = ['A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\'s questions.\n\nUSER: Please generate the following: "1, 2, 3, 4, 5, 6, 7, 8, 9, 10".\n\nASSISTANT: '] 94 | 95 | # prompts = ["A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\'s questions.\n\nUSER: Embrace the role of Sheldon from \"The Big Bang Theory\" as we delve into our conversation. Don\u2019t start with phrases like \"As Sheldon\". Let's kick things off with the following question: \"What is your opinion on hand dryers?\"\n\nASSISTANT: "] 96 | 97 | inputs = tokenizer( 98 | prompts, 99 | padding=True, 100 | return_tensors="pt" 101 | ).to(args.device) 102 | 103 | generate(args, inputs, model, tokenizer) 104 | 105 | token_recycle_generate(args, inputs, model, tokenizer) 106 | 107 | if __name__ == '__main__': 108 | main() 109 | -------------------------------------------------------------------------------- /tools/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from transformers import PreTrainedTokenizer, AutoTokenizer 4 | from datasets import Dataset, DatasetDict, load_dataset, concatenate_datasets 5 | from accelerate.logging import get_logger 6 | from itertools import chain 7 | from .prompter import Prompter 8 | 9 | logger = get_logger(__name__) 10 | 11 | def process_gsm8k( 12 | args, 13 | raw_dataset: Dataset 14 | ): 15 | column_names = raw_dataset.column_names 16 | prompter = Prompter(args.prompt_template_name) 17 | 18 | def generate_and_tokenize_prompt(data_point): 19 | full_prompt = prompter.generate_prompt( 20 | data_point["question"], 21 | ) 22 | return {"prompt": full_prompt} 23 | 24 | lm_dataset = raw_dataset.map( 25 | generate_and_tokenize_prompt, 26 | remove_columns=column_names, 27 | desc=f"Processing gsm8k datasets", 28 | ) 29 | 30 | return lm_dataset 31 | 32 | def process_alpaca( 33 | args, 34 | raw_dataset: Dataset 35 | ): 36 | column_names = raw_dataset.column_names 37 | prompter = Prompter(args.prompt_template_name) 38 | 39 | def generate_and_tokenize_prompt(data_point): 40 | full_prompt = prompter.generate_prompt( 41 | data_point["instruction"], 42 | data_point["input"], 43 | ) 44 | return {"prompt": full_prompt} 45 | 46 | lm_dataset = raw_dataset.map( 47 | generate_and_tokenize_prompt, 48 | remove_columns=column_names, 49 | desc=f"Processing alpaca like datasets", 50 | ) 51 | 52 | return lm_dataset 53 | -------------------------------------------------------------------------------- /tools/gen_default_tree.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | edges = [ 4 | [0, 1], [0, 2], [0, 3], [0, 4], 5 | [1, 5], [1, 6], [1, 7], 6 | [2, 8], [2, 9], 7 | [3, 10], [3, 11], 8 | [4, 12], 9 | [5, 13], 10 | [13, 14], 11 | [14, 15], 12 | [15, 16], 13 | [16, 17], 14 | [17, 18], 15 | [18, 19] 16 | ] 17 | 18 | N = max(sum(edges, [])) + 1 19 | 20 
| print("N = {}".format(N)) 21 | 22 | childs = {} 23 | 24 | for i in range(len(edges)): 25 | u, v = tuple(edges[i]) 26 | if u not in childs: 27 | childs[u] = [] 28 | childs[u].append(v) 29 | 30 | for u in range(N): 31 | if u not in childs: 32 | childs[u] = [] 33 | 34 | print(json.dumps(childs, indent=4)) 35 | -------------------------------------------------------------------------------- /tools/gen_response.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from datasets import load_from_disk, Dataset 4 | from vllm import LLM, SamplingParams 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--model_name', type=str, default='/data/models/vicuna-7b-v1.3') 8 | parser.add_argument('--sam_data_path', type=str, default='sam_data/sam_prompts') 9 | args = parser.parse_args() 10 | 11 | sam_dataset = load_from_disk(args.sam_data_path) 12 | 13 | prompts = sam_dataset["prompt"] 14 | print("number of prompts: {}".format(len(prompts))) 15 | 16 | llm = LLM(model=args.model_name, enable_prefix_caching=True) 17 | sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=1024) 18 | 19 | outputs = llm.generate(prompts, sampling_params) 20 | 21 | sam_dialogues = [] 22 | # Print the outputs. 23 | for output in outputs: 24 | prompt = output.prompt 25 | generated_text = output.outputs[0].text 26 | # print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") 27 | sam_dialogues.append({ 28 | "prompt": prompt, 29 | "response": generated_text, 30 | }) 31 | 32 | sam_dialogues = Dataset.from_list(sam_dialogues) 33 | sam_dialogues.save_to_disk("sam_data/sam_dialogues") 34 | -------------------------------------------------------------------------------- /tools/gen_sam_alpaca.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from transformers import AutoTokenizer 4 | from datasets import load_from_disk, Dataset 5 | from samd import SamdConfig, build_sam, dump_sam 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--model_name', type=str, default='/data/models/vicuna-7b-v1.3') 9 | parser.add_argument('--sam_data_path', type=str, default='sam_data/sam_dialogues') 10 | parser.add_argument('--cutoff_len', type=int, default=2048) 11 | parser.add_argument('--n_predicts', type=int, default=10) 12 | parser.add_argument('--sam_path', type=str, default="local_cache/sam_alpaca_vicuna-7b-v1.3_min-endpos.pkl") 13 | args = parser.parse_args() 14 | 15 | sam_data = load_from_disk(args.sam_data_path) 16 | 17 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 18 | 19 | def tokenize_fn(data_point, add_eos_token=False): 20 | text = data_point["prompt"] + data_point["response"] 21 | # there's probably a way to do this with the tokenizer settings 22 | # but again, gotta move fast 23 | result = tokenizer( 24 | text, 25 | padding=False, 26 | return_tensors=None, 27 | ) 28 | if ( 29 | result["input_ids"][-1] != tokenizer.eos_token_id 30 | and len(result["input_ids"]) < args.cutoff_len 31 | and add_eos_token 32 | ): 33 | result["input_ids"].append(tokenizer.eos_token_id) 34 | result["attention_mask"].append(1) 35 | return result 36 | 37 | column_names = sam_data.column_names 38 | 39 | batch_tokens = sam_data.map( 40 | tokenize_fn, 41 | desc=f"Processing sam dialogue datasets", 42 | )["input_ids"] 43 | for i in range(len(tokenizer)): 44 | batch_tokens.append([i]) 45 | 46 | sam = build_sam(batch_tokens, tokenizer.eos_token_id) 47 | 48 | 
model_name = args.model_name.split("/")[-1] 49 | dump_sam(args.sam_path, sam) 50 | -------------------------------------------------------------------------------- /tools/gen_sam_alpaca_sam_only.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from transformers import AutoTokenizer 4 | from datasets import load_from_disk, Dataset 5 | from samd_sam_only import SamdConfig, build_sam, dump_sam 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--model_name', type=str, default='/data/models/vicuna-7b-v1.3') 9 | parser.add_argument('--sam_data_path', type=str, default='sam_data/sam_dialogues') 10 | parser.add_argument('--cutoff_len', type=int, default=2048) 11 | parser.add_argument('--n_predicts', type=int, default=10) 12 | parser.add_argument('--sam_path', type=str, default="local_cache/sam_alpaca_vicuna-7b-v1.3.pkl") 13 | args = parser.parse_args() 14 | 15 | sam_data = load_from_disk(args.sam_data_path) 16 | 17 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 18 | 19 | def tokenize_fn(data_point, add_eos_token=False): 20 | text = data_point["prompt"] + data_point["response"] 21 | # there's probably a way to do this with the tokenizer settings 22 | # but again, gotta move fast 23 | result = tokenizer( 24 | text, 25 | padding=False, 26 | return_tensors=None, 27 | ) 28 | if ( 29 | result["input_ids"][-1] != tokenizer.eos_token_id 30 | and len(result["input_ids"]) < args.cutoff_len 31 | and add_eos_token 32 | ): 33 | result["input_ids"].append(tokenizer.eos_token_id) 34 | result["attention_mask"].append(1) 35 | return result 36 | 37 | column_names = sam_data.column_names 38 | 39 | batch_tokens = sam_data.map( 40 | tokenize_fn, 41 | desc=f"Processing sam dialogue datasets", 42 | )["input_ids"] 43 | for i in range(len(tokenizer)): 44 | batch_tokens.append([i]) 45 | 46 | sam = build_sam(batch_tokens, tokenizer.eos_token_id) 47 | 48 | model_name = args.model_name.split("/")[-1] 49 | dump_sam(args.sam_path, sam) 50 | -------------------------------------------------------------------------------- /tools/gen_sam_none.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from transformers import AutoTokenizer 4 | from datasets import load_from_disk, Dataset 5 | from samd import SamdConfig, build_sam, dump_sam 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--model_name', type=str, default='/data/models/vicuna-7b-v1.3') 9 | parser.add_argument('--cutoff_len', type=int, default=2048) 10 | parser.add_argument('--n_predicts', type=int, default=10) 11 | parser.add_argument('--sam_path', type=str, default="local_cache/sam_none_min-endpos.pkl") 12 | args = parser.parse_args() 13 | 14 | 15 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 16 | 17 | batch_tokens = [] 18 | for i in range(len(tokenizer)): 19 | batch_tokens.append([i]) 20 | 21 | sam = build_sam(batch_tokens, tokenizer.eos_token_id) 22 | 23 | model_name = args.model_name.split("/")[-1] 24 | dump_sam(args.sam_path, sam) 25 | -------------------------------------------------------------------------------- /tools/gen_sam_none_sam_only.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from transformers import AutoTokenizer 4 | from datasets import load_from_disk, Dataset 5 | from samd_sam_only import SamdConfig, build_sam, dump_sam 6 | 7 | parser = argparse.ArgumentParser() 8 | 
parser.add_argument('--model_name', type=str, default='/data/models/vicuna-7b-v1.3') 9 | parser.add_argument('--cutoff_len', type=int, default=2048) 10 | parser.add_argument('--n_predicts', type=int, default=10) 11 | parser.add_argument('--sam_path', type=str, default="local_cache/sam_none.pkl") 12 | args = parser.parse_args() 13 | 14 | 15 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 16 | 17 | batch_tokens = [] 18 | for i in range(len(tokenizer)): 19 | batch_tokens.append([i]) 20 | 21 | sam = build_sam(batch_tokens, tokenizer.eos_token_id) 22 | 23 | model_name = args.model_name.split("/")[-1] 24 | dump_sam(args.sam_path, sam) 25 | -------------------------------------------------------------------------------- /tools/prepare_prompts.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from transformers import PreTrainedTokenizer, AutoTokenizer 4 | from datasets import Dataset, DatasetDict, load_dataset, concatenate_datasets 5 | from .data_utils import process_alpaca, process_gsm8k 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--model_name', type=str, default='/data/models/vicuna-7b-v1.3') 9 | parser.add_argument('--cutoff_len', type=int, default=1024) 10 | parser.add_argument('--prompt_template_name', type=str, default='vicuna') 11 | args = parser.parse_args() 12 | 13 | model_name = args.model_name.split("/")[-1] 14 | 15 | tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=True) 16 | 17 | alpaca_data = load_dataset('json', data_files='sam_data/alpaca-cleaned/alpaca_data_cleaned.json', split="train") 18 | code_data = load_dataset('parquet', data_files='sam_data/python_code_instructions_18k_alpaca/data/*.parquet', split="train") 19 | math_data = load_dataset('parquet', data_files='sam_data/gsm8k/main/train-*.parquet', split="train") 20 | 21 | alpaca_data = process_alpaca(args, alpaca_data) 22 | code_data = process_alpaca(args, code_data) 23 | math_data = process_gsm8k(args, math_data) 24 | 25 | sam_data: Dataset = concatenate_datasets([alpaca_data, code_data, math_data]) 26 | sam_data.save_to_disk(f'sam_data/sam_prompts') 27 | 28 | print(sam_data) 29 | print(sam_data[:10]) 30 | -------------------------------------------------------------------------------- /tools/prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | A dedicated helper to manage templates and prompt building. 3 | """ 4 | 5 | import json 6 | import os.path as osp 7 | from typing import Union 8 | 9 | 10 | class Prompter(object): 11 | __slots__ = ("template", "_verbose") 12 | 13 | def __init__(self, template_name: str = "", verbose: bool = False): 14 | self._verbose = verbose 15 | if not template_name: 16 | # Enforce the default here, so the constructor can be called with '' and will not break. 17 | template_name = "alpaca" 18 | file_name = osp.join("sam_data", "templates", f"{template_name}.json") 19 | if not osp.exists(file_name): 20 | raise ValueError(f"Can't read {file_name}") 21 | with open(file_name) as fp: 22 | self.template = json.load(fp) 23 | if self._verbose: 24 | print( 25 | f"Using prompt template {template_name}: {self.template['description']}" 26 | ) 27 | 28 | def generate_prompt( 29 | self, 30 | instruction: str, 31 | input: Union[None, str] = None, 32 | label: Union[None, str] = None, 33 | ) -> str: 34 | # returns the full prompt from instruction and optional input 35 | # if a label (=response, =output) is provided, it's also appended. 
36 | if input is not None and len(input) > 0: 37 | res = self.template["prompt_input"].format( 38 | instruction=instruction, input=input 39 | ) 40 | else: 41 | res = self.template["prompt_no_input"].format( 42 | instruction=instruction 43 | ) 44 | if label: 45 | res = f"{res}{label}" 46 | if self._verbose: 47 | print(res) 48 | return res 49 | 50 | def get_response(self, output: str) -> str: 51 | return output.split(self.template["response_split"])[1].strip() 52 | --------------------------------------------------------------------------------