├── assets ├── main.png ├── pareto.png ├── theorem.png ├── specdec1.png ├── specdec2.png ├── tradeoff.png ├── teaser_result.png └── specdec-illustration.png ├── requirements.txt ├── data ├── gen_data.sh ├── split_dataset.py ├── readme.md ├── gen_acceptance.py ├── util.py ├── gen_log_p.py ├── gen_assistant.py └── gen_dataset.py ├── LICENSE ├── specdec_pp ├── wrap_model.py ├── sample.py ├── evaluate.py ├── train.py └── hf_generation.py └── README.md /assets/main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kaffaljidhmah2/SpecDec_pp/HEAD/assets/main.png -------------------------------------------------------------------------------- /assets/pareto.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kaffaljidhmah2/SpecDec_pp/HEAD/assets/pareto.png -------------------------------------------------------------------------------- /assets/theorem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kaffaljidhmah2/SpecDec_pp/HEAD/assets/theorem.png -------------------------------------------------------------------------------- /assets/specdec1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kaffaljidhmah2/SpecDec_pp/HEAD/assets/specdec1.png -------------------------------------------------------------------------------- /assets/specdec2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kaffaljidhmah2/SpecDec_pp/HEAD/assets/specdec2.png -------------------------------------------------------------------------------- /assets/tradeoff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kaffaljidhmah2/SpecDec_pp/HEAD/assets/tradeoff.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.34.1 2 | huggingface_hub 3 | datasets 4 | accelerate 5 | wandb 6 | scipy -------------------------------------------------------------------------------- /assets/teaser_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kaffaljidhmah2/SpecDec_pp/HEAD/assets/teaser_result.png -------------------------------------------------------------------------------- /assets/specdec-illustration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kaffaljidhmah2/SpecDec_pp/HEAD/assets/specdec-illustration.png -------------------------------------------------------------------------------- /data/gen_data.sh: -------------------------------------------------------------------------------- 1 | OUT=$1 2 | FINAL=$2 3 | DATASET=$3 4 | 5 | TARGET=70b 6 | DRAFT=7b 7 | 8 | mkdir -p ${OUT} 9 | mkdir -p ${OUT}/tmp 10 | 11 | TMP1=${OUT}/tmp/tmp1.json 12 | TMP2=${OUT}/tmp/tmp2.json 13 | TMP3=${OUT}/tmp/tmp3.json 14 | TMP4=${OUT}/tmp/tmp4.json 15 | 16 | python3 gen_dataset.py --dataset_name ${DATASET} --model_name ${TARGET} --mode hf --do_sample --output_file $TMP1 17 | python3 gen_assistant.py --model_name ${DRAFT} --do_sample --input_file $TMP1 --output_file $TMP2 18 | python3 gen_log_p.py --model_name ${DRAFT} --input_file $TMP2 --output_file 
$TMP3 19 | python3 gen_log_p.py --model_name ${TARGET} --input_file $TMP3 --output_file $TMP4 20 | python3 gen_acceptance.py --target_name ${TARGET} --draft_name ${DRAFT} --input_file $TMP4 --output_file ${OUT}/${FINAL} 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Kaixuan Huang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /data/split_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | 5 | def process(filename, num_train, num_dev, num_test): 6 | data = json.load(open(filename, 'r')) 7 | if num_train + num_dev + num_test != len(data): 8 | print(f"warning: the number of total datapoints {len(data)} does not match {num_train} + {num_dev} + {num_test}.") 9 | 10 | random.shuffle(data) 11 | 12 | train_data = data[:num_train] 13 | dev_data = data[num_train:num_train + num_dev] 14 | test_data = data[num_train + num_dev:num_train + num_dev + num_test] 15 | 16 | # Save the splits into separate files 17 | base_dir = base_dir = os.path.dirname(filename) or '.' 18 | 19 | with open(f'{base_dir}/train.json', 'w') as train_file: 20 | json.dump(train_data, train_file, indent=2) 21 | 22 | with open(f'{base_dir}/dev.json', 'w') as dev_file: 23 | json.dump(dev_data, dev_file, indent=2) 24 | 25 | with open(f'{base_dir}/test.json', 'w') as test_file: 26 | json.dump(test_data, test_file, indent=2) 27 | 28 | if __name__ == "__main__": 29 | import sys 30 | 31 | all_file = sys.argv[1] 32 | num_train = int(sys.argv[2]) 33 | num_dev = int(sys.argv[3]) 34 | num_test = int(sys.argv[4]) 35 | process(all_file, num_train, num_dev, num_test) 36 | 37 | -------------------------------------------------------------------------------- /data/readme.md: -------------------------------------------------------------------------------- 1 | ### Dataset Preparation 2 | 3 | First `cd data` and run the following commands: 4 | 5 | 1. Run the following script to generate the alpaca dataset: 6 | 7 | ``` 8 | sh gen_data.sh alpaca_data all.json tatsu-lab/alpaca 9 | ``` 10 | 11 | Next, run the following python script to split the alpaca_data `all.json` into `train.json` (40k), `dev.json` (10k), `test.json` (2k). 
12 | 13 | ``` 14 | python split_dataset.py all.json 40000 10000 2000 15 | ``` 16 | 17 | 1. Run the following script to generate the humanEval dataset (test-only): 18 | 19 | ``` 20 | sh gen_data.sh humaneval_data test.json openai_humaneval 21 | ``` 22 | 23 | 24 | 3. Run the following script to generate the GSM8K dataset (test-only): 25 | 26 | ``` 27 | sh gen_data.sh gsm8k_test_data test.json gsm8k_test 28 | ``` 29 | 30 | ### Data Format 31 | 32 | Each `json` file is a `list` of `dict` containing the following keys: 33 | 34 | ``` 35 | prompt: str, the prompt. 36 | prefix: list[int], tokenized prompt. 37 | continuation: str, the response generated by the target model. 38 | tokens: list[int], tokenized continuation. 39 | draft: list[int], next tokens generated from the draft model conditioned on target model's generation. 40 | log_p_7b: list[float], the log probabilities of the draft tokens predicted by the draft model (7b). 41 | log_p_70b: list[float], the log probabilities of the draft tokens predicted by the target model (70b). 42 | p_acc: list[float], the acceptance probabilities of the draft tokens. 43 | ``` -------------------------------------------------------------------------------- /data/gen_acceptance.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig 2 | import time 3 | import json 4 | import torch 5 | from datasets import load_dataset, load_from_disk 6 | import os 7 | import argparse 8 | from tqdm import tqdm 9 | from ast import literal_eval as eval 10 | from util import CKPT, get_model, pretty_format 11 | 12 | def read_data(filename): 13 | data = json.load(open(filename,'r')) 14 | return data 15 | 16 | 17 | def get_acc_prob(data, target_name, draft_name): 18 | for item in data: 19 | target_log_p = eval(item['log_p_' + target_name]) 20 | target_log_p = torch.tensor(target_log_p) 21 | 22 | draft_log_p = eval(item['log_p_' + draft_name]) 23 | draft_log_p = torch.tensor(draft_log_p) 24 | 25 | diff = target_log_p - draft_log_p 26 | diff[diff>0] = 0 27 | acc_prob = torch.exp(diff) 28 | item[f'p_acc'] = acc_prob.tolist() 29 | 30 | return data 31 | 32 | def parse_args(): 33 | parser = argparse.ArgumentParser(description='data generator') 34 | 35 | parser.add_argument('--target_name', type=str, choices=["7b", "13b", "70b"], default='70b') 36 | parser.add_argument('--draft_name', type=str, choices=["7b", "13b", "70b"], default='7b') 37 | parser.add_argument('--input_file', type=str) 38 | parser.add_argument('--output_file', type=str, default=None) 39 | 40 | args = parser.parse_args() 41 | 42 | return args 43 | if __name__ == "__main__": 44 | args = parse_args() 45 | data = read_data(args.input_file) 46 | data = get_acc_prob(data, target_name = args.target_name, draft_name = args.draft_name) 47 | 48 | if args.output_file is None or len(args.output_file) == 0: 49 | args.output_file = args.input_file.rstrip('.json') + '_acc.json' 50 | 51 | data = pretty_format(data) 52 | with open(args.output_file, 'w') as f: 53 | f.write(json.dumps(data, indent=2)) 54 | 55 | -------------------------------------------------------------------------------- /specdec_pp/wrap_model.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | from transformers import AutoTokenizer, EsmForMaskedLM, AutoModelForCausalLM, Trainer, TrainingArguments 3 | from tokenizers import Tokenizer 4 | from dataclasses import dataclass, field 5 | from typing 
import Optional 6 | import torch 7 | import torch.nn as nn 8 | from transformers.modeling_utils import PreTrainedModel 9 | from transformers.configuration_utils import PretrainedConfig 10 | from transformers.utils import logging 11 | from huggingface_hub import PyTorchModelHubMixin 12 | 13 | logger = logging.get_logger(__name__) 14 | 15 | 16 | class ResBlock(nn.Module): 17 | def __init__(self, hidden_size): 18 | super().__init__() 19 | self.linear = nn.Linear(hidden_size, hidden_size) 20 | #torch.nn.init.zeros_(self.linear.weight) 21 | self.act = nn.SiLU() 22 | 23 | def forward(self, x): 24 | return x + self.act(self.linear(x)) 25 | 26 | class AcceptancePredictionHead(nn.Module, PyTorchModelHubMixin): 27 | def __init__(self, config): 28 | self.config=config 29 | hidden_size = config['hidden_size'] 30 | num_layers = config.get('num_layers', 0) 31 | super().__init__() 32 | self.model = nn.Sequential( *([ResBlock(hidden_size)] * num_layers), nn.Linear(hidden_size, 2) ) 33 | 34 | def forward(self, x): 35 | return self.model(x) 36 | 37 | class WrapModel(PreTrainedModel): 38 | def __init__(self, model, head): 39 | super().__init__(model.config) 40 | self.model = model 41 | self.assist_acc_head = head 42 | 43 | def forward(self, input_ids = None, labels = None, **kwargs): 44 | return self.model(input_ids = input_ids, labels = labels, **kwargs) 45 | 46 | 47 | if __name__ == "__main__": 48 | #input_ids = labels = torch.LongTensor([[1,2,3]]) 49 | #model = transformers.AutoModelForCausalLM.from_pretrained("ckpt/hf-llama2-7b-chat") 50 | #wrapped = WrapModel(model, num_layers=2) 51 | #AcceptancePredictionHead.from_pretrained('../exp-weight6-layer3') 52 | 53 | import pdb;pdb.set_trace() 54 | -------------------------------------------------------------------------------- /data/util.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig 2 | import time 3 | import json 4 | import torch 5 | from datasets import load_dataset, load_from_disk 6 | import os 7 | import argparse 8 | 9 | 10 | CKPT = { 11 | '7b': "meta-llama/Llama-2-7b-chat-hf", 12 | '13b': "meta-llama/Llama-2-13b-chat-hf", 13 | '70b': "meta-llama/Llama-2-70b-chat-hf" 14 | } 15 | 16 | 17 | def get_model(model_name): 18 | checkpoint = CKPT[model_name] 19 | dtype = torch.bfloat16 20 | print('model checkpoint: ', checkpoint) 21 | print('model dtype: ', dtype) 22 | model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=dtype, device_map='auto') 23 | tokenizer = AutoTokenizer.from_pretrained(checkpoint) 24 | return tokenizer, model 25 | 26 | def get_dataset(name): 27 | if name=="tatsu-lab/alpaca": 28 | dataset_file='alpaca' 29 | if not os.path.exists(dataset_file): 30 | dataset = load_dataset("tatsu-lab/alpaca")['train'] 31 | dataset.save_to_disk(dataset_file) 32 | else: 33 | dataset = load_from_disk(dataset_file) 34 | elif name=='openai_humaneval': 35 | dataset_file='humaneval' 36 | if not os.path.exists(dataset_file): 37 | dataset = load_dataset('openai_humaneval')['test'] 38 | dataset.save_to_disk(dataset_file) 39 | else: 40 | dataset = load_from_disk(dataset_file) 41 | elif name=='gsm8k_test': 42 | dataset_file='gsm8k' 43 | if not os.path.exists(dataset_file): 44 | dataset = load_dataset('gsm8k', 'main')['test'] 45 | dataset.save_to_disk(dataset_file) 46 | else: 47 | dataset = load_from_disk(dataset_file) 48 | else: 49 | raise NotImplementedError 50 | return dataset 51 | 52 | def pretty_format(data): 53 | for item in 
data: 54 | for key, value in item.items(): 55 | if isinstance(value, list) and isinstance(value[0], int): 56 | item[key] = str(value) 57 | if isinstance(value, list) and isinstance(value[0], float): 58 | item[key] = str(value) 59 | return data 60 | 61 | if __name__ == "__main__": 62 | dataset = get_dataset('gsm8k_test') 63 | print(len(dataset), dataset[0]) -------------------------------------------------------------------------------- /data/gen_log_p.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig 2 | import time 3 | import json 4 | import torch 5 | from datasets import load_dataset, load_from_disk 6 | import os 7 | import argparse 8 | from tqdm import tqdm 9 | from ast import literal_eval as eval 10 | from util import CKPT, get_model, pretty_format 11 | 12 | def read_data(filename): 13 | data = json.load(open(filename,'r')) 14 | for item in data: 15 | item['prefix'] = eval(item['prefix']) 16 | item['tokens'] = eval(item['tokens']) 17 | item['draft'] = eval(item['draft']) 18 | return data 19 | 20 | 21 | @torch.no_grad() 22 | def get_log_prob(data, model, model_name): 23 | for item in data: 24 | joint = item['prefix'] + item['tokens'] 25 | joint = torch.LongTensor(joint).to(model.device) 26 | joint = joint.unsqueeze(0) 27 | 28 | index = item['draft'] 29 | index = torch.LongTensor(index).to(model.device) 30 | index = index.unsqueeze(-1) 31 | 32 | logits = model(input_ids = joint).logits 33 | 34 | log_probs = logits.log_softmax(dim=-1) # bs * seq_len * vocab_size 35 | log_probs_shifted = log_probs[0, len(item['prefix'])-1 : -1] # next_token_log_prob for the continuation 36 | 37 | # take along "item['draft']" 38 | log_p = torch.take_along_dim(log_probs_shifted, index, dim=-1) # seq_len * 1 39 | item[f'log_p_{model_name}'] = log_p[:, 0].tolist() 40 | 41 | return data 42 | 43 | def parse_args(): 44 | parser = argparse.ArgumentParser(description='data generator') 45 | 46 | parser.add_argument('--model_name', type=str, choices=["7b", "13b", "70b"], default='7b') 47 | parser.add_argument('--input_file', type=str) 48 | parser.add_argument('--output_file', type=str, default=None) 49 | 50 | args = parser.parse_args() 51 | 52 | return args 53 | if __name__ == "__main__": 54 | args = parse_args() 55 | data = read_data(args.input_file) 56 | 57 | tokenizer, model = get_model(args.model_name) 58 | data = get_log_prob(data, model, args.model_name) 59 | suffix = 'logP' 60 | if args.output_file is None or len(args.output_file) == 0: 61 | args.output_file = args.input_file.rstrip('.json') + '_' + args.model_name + suffix + '.json' 62 | 63 | data = pretty_format(data) 64 | 65 | 66 | with open(args.output_file, 'w') as f: 67 | f.write(json.dumps(data, indent=2)) 68 | 69 | -------------------------------------------------------------------------------- /specdec_pp/sample.py: -------------------------------------------------------------------------------- 1 | """example code for sampling using SpecDec++""" 2 | 3 | from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig 4 | import time 5 | import torch 6 | from hf_generation import my_generate 7 | from wrap_model import AcceptancePredictionHead 8 | 9 | device = "cuda" 10 | 11 | 12 | 13 | def set_up(): 14 | checkpoint = "meta-llama/Llama-2-70b-chat-hf" 15 | assistant_checkpoint = "meta-llama/Llama-2-7b-chat-hf" 16 | assist_acc_head_dir = "hacky/acchead-llama2-chat-7bx70b" 17 | 18 | tokenizer = 
AutoTokenizer.from_pretrained(checkpoint) 19 | 20 | assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint, torch_dtype=torch.bfloat16, device_map='cuda:0') 21 | assist_acc_head = AcceptancePredictionHead.from_pretrained(assist_acc_head_dir).to('cuda:0') 22 | model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, device_map='auto') 23 | 24 | 25 | return tokenizer, model, assistant_model, assist_acc_head 26 | 27 | 28 | def format_prompt(prompt): 29 | """ 30 | wrap the prompt in llama-2-chat format. 31 | """ 32 | 33 | B_INST, E_INST = "[INST]", "[/INST]" 34 | return f"{B_INST} {prompt.strip()} {E_INST}" 35 | 36 | 37 | 38 | def main(prompt): 39 | 40 | ### load target/draft/Acceptance Head and set generation config 41 | tokenizer, model, assistant_model, assist_acc_head = set_up() 42 | stop_threshold = 0.7 43 | bound = (2, 20) 44 | max_length = 512 45 | 46 | before=time.time() 47 | 48 | ### format and tokenize prompt 49 | prompt = format_prompt(prompt) 50 | inputs = tokenizer(prompt, return_tensors="pt").to(device) 51 | 52 | outputs, mismatched_tokens, LM_call = my_generate(model=model, **inputs, assistant_model=assistant_model, \ 53 | max_length=max_length, num_assistant_tokens_schedule='ada', \ 54 | do_sample=True, \ 55 | assist_acc_head=assist_acc_head, \ 56 | stop_threshold=stop_threshold, bound=bound) 57 | 58 | after = time.time() 59 | assisted_time = (after - before) 60 | 61 | print(tokenizer.decode(outputs[0])) 62 | 63 | print("assisted time: {:.2f}".format(assisted_time)) 64 | print("# mismatched_tokens: {:.2f}".format(mismatched_tokens)) 65 | print("# LM_call: {:.2f}".format(LM_call)) 66 | 67 | return outputs 68 | 69 | 70 | 71 | if __name__ == "__main__": 72 | prompt = "List 10 methods to be a successful PHD." 
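    # main() prints the decoded output, the wall-clock time, the number of discarded (mismatched) draft tokens, and the number of target-model calls returned by my_generate.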
73 | main(prompt) 74 | 75 | -------------------------------------------------------------------------------- /data/gen_assistant.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig 2 | import time 3 | import json 4 | import torch 5 | from datasets import load_dataset, load_from_disk 6 | import os 7 | import argparse 8 | from tqdm import tqdm 9 | from ast import literal_eval as eval 10 | from util import CKPT, get_model, pretty_format 11 | 12 | 13 | def read_data(filename): 14 | data = json.load(open(filename,'r')) 15 | for item in data: 16 | item['prefix'] = eval(item['prefix']) 17 | item['tokens'] = eval(item['tokens']) 18 | 19 | return data 20 | 21 | 22 | @torch.no_grad() 23 | def get_assistant_result(data, assistant_model, model_name, do_sample): 24 | for item in data: 25 | joint = item['prefix'] + item['tokens'] 26 | joint = torch.LongTensor(joint).to(assistant_model.device) 27 | joint = joint.unsqueeze(0) 28 | sm_logits = assistant_model(input_ids = joint).logits 29 | if do_sample: 30 | probs = sm_logits.softmax(dim=-1) # bs * seq_len * vocab_size 31 | new_token = torch.multinomial(probs[0], num_samples=1).squeeze(-1) 32 | item['draft'] = new_token[len(item['prefix'])-1 : -1].tolist() 33 | else: 34 | new_token = sm_logits.argmax(dim=-1) # bs * seq_len 35 | item['draft'] = new_token[0, len(item['prefix'])-1 : -1].tolist() 36 | return data 37 | 38 | def parse_args(): 39 | parser = argparse.ArgumentParser(description='data generator') 40 | 41 | parser.add_argument('--model_name', type=str, choices=["7b"], default='7b') 42 | parser.add_argument('--input_file', type=str) 43 | parser.add_argument('--output_file', type=str, default=None) 44 | parser.add_argument('--do_sample', action='store_true') 45 | 46 | args = parser.parse_args() 47 | 48 | return args 49 | if __name__ == "__main__": 50 | args = parse_args() 51 | data = read_data(args.input_file) 52 | 53 | tokenizer, model = get_model(args.model_name) 54 | data = get_assistant_result(data, model, args.model_name, args.do_sample) 55 | 56 | if args.output_file is None or len(args.output_file) == 0: 57 | if args.do_sample: 58 | suffix = 'stochastic' 59 | else: 60 | suffix = 'greedy' 61 | args.output_file = args.input_file.rstrip('.json') + '_' + args.model_name + suffix + '.json' 62 | 63 | data = pretty_format(data) 64 | 65 | with open(args.output_file, 'w') as f: 66 | f.write(json.dumps(data, indent=2)) 67 | 68 | -------------------------------------------------------------------------------- /data/gen_dataset.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig 2 | import time 3 | import json 4 | import torch 5 | from datasets import load_dataset, load_from_disk 6 | import os 7 | import argparse 8 | from tqdm import tqdm 9 | from util import CKPT, get_model, get_dataset, pretty_format 10 | 11 | B_INST, E_INST = "[INST]", "[/INST]" 12 | B_SYS, E_SYS = "<>\n", "\n<>\n\n" 13 | 14 | SPECIAL_TAGS = [B_INST, E_INST, "<>", "<>"] 15 | UNSAFE_ERROR = "Error: special tags are not allowed as part of the prompt." 16 | 17 | def get_prompt(sample, dataset_name): 18 | """ 19 | wrap the prompt in llama-2-chat format. 
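    The sample is first converted with a dataset-specific template (alpaca / humaneval / gsm8k) and the result is wrapped as "[INST] ... [/INST]" using B_INST / E_INST.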
20 | """ 21 | if dataset_name == 'tatsu-lab/alpaca': 22 | prompt = get_prompt_alpaca(sample) 23 | elif dataset_name == 'openai_humaneval': 24 | prompt = get_prompt_humaneval(sample) 25 | elif dataset_name == 'gsm8k_test': 26 | prompt = get_prompt_gsm8k(sample) 27 | 28 | return f"{B_INST} {prompt.strip()} {E_INST}" 29 | 30 | 31 | def get_prompt_alpaca(sample): 32 | """ 33 | for alpaca format only 34 | """ 35 | if sample['input'] is None or len(sample['input'].strip()) == 0: 36 | prompt = sample['instruction'] 37 | else: 38 | prompt = sample['instruction'] + '\nInput: ' + sample['input'] 39 | return prompt 40 | 41 | 42 | def get_prompt_humaneval(sample): 43 | """ 44 | OpenAI HumanEval 45 | prompt format https://github.com/nlpxucan/WizardLM/blob/main/WizardCoder/src/humaneval_gen.py 46 | """ 47 | INSTRUCTION = """Below is an instruction that describes a task. Write a response that appropriately completes the request. 48 | 49 | 50 | ### Instruction: 51 | Create a Python script for this problem: 52 | {prompt} 53 | 54 | ### Response:""" 55 | return INSTRUCTION.format(prompt=sample['prompt']) 56 | 57 | def get_prompt_gsm8k(sample): 58 | """ 59 | gsm8K 60 | prompt format https://github.com/meta-math/MetaMath/blob/main/eval_gsm8k.py 61 | """ 62 | problem_prompt = ( 63 | "Below is an instruction that describes a task. " 64 | "Write a response that appropriately completes the request.\n\n" 65 | "### Instruction:\n{instruction}\n\n### Response: Let's think step by step." 66 | ) 67 | return problem_prompt.format(instruction=sample['question']) 68 | 69 | 70 | def sanity_check(sample, tokenizer): 71 | inputs = tokenizer(get_prompt(sample), return_tensors='pt') 72 | input_ids = inputs['input_ids'] 73 | assert input_ids[0][0] == tokenizer.bos_token_id, 'the first should be ' 74 | assert input_ids[0][-1] != tokenizer.eos_token_id, 'the last should not be ' 75 | print("sanity check passed") 76 | 77 | 78 | def infer(prompt, tokenizer, model, max_length = 32, do_sample=False): 79 | 80 | device = "cuda" 81 | inputs = tokenizer(prompt, return_tensors="pt").to(device) 82 | max_length += len(inputs['input_ids'][0]) 83 | model_output = model.generate(**inputs, do_sample=do_sample, max_length = max_length)[0] 84 | 85 | ret = tokenizer.decode(model_output, skip_special_tokens=True) 86 | 87 | 88 | # Is it possible that this is not reversible? after decoding, the tokens changed? Yes! 
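    # Because decode/encode round-trips are not guaranteed to be exact, we keep the generated token ids from model_output directly and only use the decoded string for the human-readable continuation text.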
89 | 90 | #re_token = tokenizer.encode(ret, return_tensors='pt').to(device) 91 | #if re_token.shape != model_output.shape or (re_token - model_output).sum().item() != 0: 92 | #print("mismatch!") 93 | 94 | prefix_token_id = model_output[:len(inputs['input_ids'][0])] 95 | gen_token_id = model_output[len(inputs['input_ids'][0]) :] 96 | ret = ret[len(prompt):] # remove prefix 97 | 98 | return prefix_token_id.cpu().tolist() , gen_token_id.cpu().tolist(), ret 99 | 100 | 101 | 102 | def parse_args(): 103 | parser = argparse.ArgumentParser(description='data generator') 104 | 105 | parser.add_argument('--dataset_name', type=str) 106 | parser.add_argument('--model_name', type=str, choices=["7b", "13b", "70b"]) 107 | parser.add_argument('--mode', type=str, choices=['hf']) 108 | parser.add_argument('--do_sample', action='store_true') 109 | parser.add_argument('--n_begin', type=int, default=0) 110 | parser.add_argument('--n_end', type=int, default=-1) 111 | parser.add_argument('--max_length', type=int, default=512) 112 | parser.add_argument('--output_file', type=str, default=None) 113 | 114 | 115 | 116 | args = parser.parse_args() 117 | 118 | return args 119 | 120 | 121 | 122 | def main(args): 123 | #sanity_check(dataset[0], tokenizer) 124 | 125 | print(f"we are using do sample = {args.do_sample}") 126 | 127 | tokenizer, model = get_model(args.model_name) 128 | dataset = get_dataset(args.dataset_name) 129 | 130 | if args.n_end == -1: 131 | args.n_end = len(dataset) 132 | args.n_end = min(args.n_end, len(dataset)) 133 | 134 | res_dict = [] 135 | for i in tqdm(range(args.n_begin, args.n_end)): 136 | sample = dataset[i] 137 | prompt = get_prompt(sample, args.dataset_name) 138 | prefix_token, gen_token, s = infer(prompt, tokenizer, model, max_length=args.max_length, do_sample=args.do_sample) 139 | res_dict.append( 140 | { 141 | 'prompt': prompt, 142 | 'continuation': s, 143 | 'prefix': str(prefix_token) if prefix_token is not None else "", 144 | 'tokens': str(gen_token) if gen_token is not None else "" 145 | } 146 | ) 147 | 148 | if args.output_file is None: 149 | args.output_file = f'dataset{args.n_begin}to{args.n_end}_{args.mode}{args.model_name}.json' 150 | with open(args.output_file, 'w') as f: 151 | f.write(json.dumps(res_dict, indent=2)) 152 | 153 | 154 | if __name__ == "__main__": 155 | args = parse_args() 156 | main(args) 157 | -------------------------------------------------------------------------------- /specdec_pp/evaluate.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig 2 | import time 3 | from datetime import datetime 4 | import torch 5 | from hf_generation import my_generate 6 | import argparse 7 | import json 8 | import os 9 | import numpy as np 10 | from ast import literal_eval as eval 11 | 12 | device = "cuda" 13 | 14 | def set_up(args): 15 | if args.do_sample: 16 | print("do_sample for SpeculativeDecoding") 17 | 18 | np.random.seed(args.random_seed) 19 | torch.manual_seed(args.random_seed) 20 | 21 | checkpoint = args.model_name 22 | assistant_checkpoint = args.assistant_name 23 | 24 | tokenizer = AutoTokenizer.from_pretrained(checkpoint) 25 | 26 | assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint, torch_dtype=torch.bfloat16, device_map='cuda:0') 27 | 28 | if args.num_assistant_tokens_schedule == 'ada': 29 | from wrap_model import AcceptancePredictionHead 30 | print("Loading from acc_head checkpoint:", 
args.assist_acc_head_dir) 31 | 32 | assist_acc_head = AcceptancePredictionHead.from_pretrained(args.assist_acc_head_dir).to('cuda:0') 33 | 34 | else: 35 | assist_acc_head = None 36 | 37 | model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, device_map='auto') 38 | 39 | # print(model.hf_device_map) 40 | # print(assistant_model.hf_device_map) 41 | return model, assistant_model, tokenizer, assist_acc_head 42 | 43 | def assist(model, assistant_model, tokenizer, assist_acc_head, inputs, max_length, num_assistant_tokens=None): 44 | # outputs = model.generate(**inputs, generation_config=generation_config, assistant_model=assistant_model, max_length=max_length) 45 | before=time.time() 46 | assistant_model.max_assistant_tokens = None 47 | outputs, mismatched_tokens, LM_call = my_generate(model=model, **inputs, assistant_model=assistant_model, \ 48 | max_length=max_length, num_assistant_tokens_schedule=args.num_assistant_tokens_schedule, \ 49 | num_assistant_tokens=num_assistant_tokens, do_sample=args.do_sample, \ 50 | assist_acc_head=assist_acc_head, \ 51 | stop_threshold=args.stop_threshold, bound=args.bound) 52 | 53 | after = time.time() 54 | assisted_time = (after - before) 55 | 56 | print("assisted time: {:.2f}".format(assisted_time)) 57 | print("mismatched_tokens: {:.2f}".format(mismatched_tokens)) 58 | print("LM_call: {:.2f}".format(LM_call)) 59 | 60 | return outputs, mismatched_tokens, LM_call, assisted_time 61 | 62 | def target(model, tokenizer, inputs, max_length): 63 | before=time.time() 64 | outputs = model.generate(**inputs, max_length=max_length, do_sample=args.do_sample) 65 | after = time.time() 66 | target_time = (after - before) 67 | print("target time {:.2f}".format(target_time)) 68 | return outputs, target_time 69 | 70 | def draft(assistant_model, tokenizer, inputs, max_length): 71 | before=time.time() 72 | outputs = assistant_model.generate(**inputs, max_length=max_length, do_sample=args.do_sample) 73 | after = time.time() 74 | draft_time = (after - before) 75 | print("draft time {:.2f}".format(draft_time)) 76 | return outputs, draft_time 77 | 78 | 79 | def run(model, assistant_model, tokenizer, assist_acc_head, args, item): 80 | len_prefix = len(eval(item['prefix'])) 81 | inputs = {'input_ids': torch.LongTensor([eval(item['prefix'])]).to(device)} 82 | max_length = args.max_length 83 | print("max_length:", max_length) 84 | 85 | 86 | if args.num_assistant_tokens_schedule in ['constant', 'heuristic', 'ada']: 87 | 88 | if args.num_assistant_tokens_schedule == 'ada': 89 | num_assistant_tokens = None 90 | else: 91 | num_assistant_tokens = args.num_assistant_tokens 92 | print("num_assistant_tokens:", num_assistant_tokens) 93 | 94 | res_a, num_mismatched_tokens, num_LM_call, assisted_time = assist(model, assistant_model, tokenizer, assist_acc_head, inputs, max_length, num_assistant_tokens=num_assistant_tokens) 95 | elif args.num_assistant_tokens_schedule == 'none': 96 | res_a = [[-1]] 97 | num_mismatched_tokens = -1 98 | num_LM_call = -1 99 | assisted_time = -1 100 | else: 101 | raise ValueError(f"{args.num_assistant_tokens_schedule} not supported") 102 | 103 | 104 | if args.num_assistant_tokens_schedule == 'none': 105 | res_b, target_time = target(model, tokenizer, inputs, max_length) 106 | generated_length_target = len(res_b[0]) - len_prefix 107 | 108 | res_c, draft_time = draft(assistant_model, tokenizer, inputs, max_length) 109 | generated_length_draft = len(res_c[0]) - len_prefix 110 | else: 111 | target_time = -1 112 | generated_length_target = -1 
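        # Standalone target/draft baselines are only measured when the schedule is 'none'; -1 marks the skipped measurements.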
113 | draft_time = -1 114 | generated_length_draft = -1 115 | 116 | 117 | 118 | generated_length = len(res_a[0]) - len_prefix 119 | print("generated_length: {:.2f}".format(generated_length)) 120 | 121 | return assisted_time, target_time, draft_time, num_mismatched_tokens, num_LM_call, generated_length, generated_length_target, generated_length_draft 122 | 123 | def parse_args(): 124 | parser = argparse.ArgumentParser(description='benchmark performance') 125 | 126 | parser.add_argument('--model_name', type=str, default=None) 127 | parser.add_argument('--assistant_name', type=str, default=None) 128 | parser.add_argument('--max_length', type=int, default=512) 129 | parser.add_argument('--do_sample', action='store_true') 130 | 131 | parser.add_argument('--num_assistant_tokens', type=int, default=5) 132 | parser.add_argument('--num_assistant_tokens_schedule', type=str, default="constant", choices=['constant', 'heuristic', 'ada', 'none']) 133 | parser.add_argument('--assist_acc_head_dir', type=str, default=None) 134 | parser.add_argument('--data_path', type=str, default='data/alpaca_data/test.json') 135 | parser.add_argument('--save_path', type=str, default='./test_results') 136 | parser.add_argument('--random_seed', type=int, default=47) 137 | parser.add_argument('--stop_threshold', type=float, default=None) 138 | parser.add_argument('--bound', nargs='+', type=int, default=None) 139 | 140 | parser.add_argument('--n_begin', type=int, default=0) 141 | parser.add_argument('--n_end', type=int, default=None) 142 | 143 | args = parser.parse_args() 144 | print(args) 145 | 146 | return args 147 | 148 | 149 | if __name__ == "__main__": 150 | args = parse_args() 151 | data = json.load(open(args.data_path,'r')) 152 | if args.n_end is None: 153 | args.n_end = len(data) 154 | args.n_end = min(len(data), args.n_end) 155 | 156 | 157 | os.makedirs(args.save_path, exist_ok=True) 158 | 159 | model, assistant_model, tokenizer, assist_acc_head = set_up(args) 160 | 161 | results = [] 162 | 163 | for i, item in enumerate(data[args.n_begin:args.n_end]): 164 | print("---------------------------------") 165 | print(f"data {i + args.n_begin}") 166 | before=time.time() 167 | 168 | assisted_time, target_time, draft_time, num_mismatched_tokens, num_LM_call, generated_length, generated_length_target, generated_length_draft = run(model, assistant_model, tokenizer, assist_acc_head, args, item) 169 | item.update({ 170 | 'id': i+args.n_begin, 171 | 'spec_time': assisted_time, 172 | 'target_time': target_time, 173 | 'draft_time': draft_time, 174 | 'num_mismatched_tokens': num_mismatched_tokens, 175 | 'num_LM_call': num_LM_call, 176 | 'generated_length': generated_length, 177 | 'generated_length_target': generated_length_target, 178 | 'generated_length_draft': generated_length_draft, 179 | }) 180 | results.append(item) 181 | 182 | after=time.time() 183 | print("total time: {:.2f}".format(after-before)) 184 | save_file = f"{args.save_path}/results_{args.n_begin}to{args.n_end}.json" 185 | with open(save_file, 'w') as f: 186 | f.write(json.dumps(results, indent=2)) 187 | -------------------------------------------------------------------------------- /specdec_pp/train.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/tatsu-lab/stanford_alpaca/blob/main/train.py 2 | 3 | import copy 4 | import logging 5 | from dataclasses import dataclass, field 6 | from typing import Dict, Optional, Sequence, List, TYPE_CHECKING, Any, Callable, Tuple, Union 7 | 8 | import torch 
9 | import transformers 10 | from torch.nn import CrossEntropyLoss 11 | from torch.utils.data import Dataset 12 | from transformers import Trainer 13 | import json 14 | import numpy 15 | import scipy.special 16 | from ast import literal_eval as eval 17 | 18 | from wrap_model import WrapModel, AcceptancePredictionHead 19 | from transformers import EvalPrediction 20 | 21 | IGNORE_INDEX = -100 22 | DEFAULT_PAD_TOKEN = "[PAD]" 23 | 24 | def compute_metrics(eval_pred: "EvalPrediction") -> Dict: 25 | num_class = 2 26 | logits= eval_pred[0] 27 | soft_labels = eval_pred[1] 28 | 29 | logits = logits.reshape(-1, num_class) 30 | soft_labels = soft_labels.reshape(-1) 31 | 32 | not_ignore = (soft_labels - IGNORE_INDEX) > 0.1 33 | 34 | target_prob = soft_labels[not_ignore] 35 | logits = logits[not_ignore] 36 | predicted_log_prob = scipy.special.log_softmax(logits, axis=-1) 37 | 38 | # KL divergence: 39 | CrossEnt = target_prob * ( - predicted_log_prob[:,1]) + (1-target_prob) * ( - predicted_log_prob[:,0]) 40 | Ent = target_prob * numpy.log(target_prob) + (1-target_prob) * numpy.log(1-target_prob) 41 | Ent[numpy.isnan(Ent)] = 0. # hack for binary entropy 42 | KL_binary = CrossEnt - Ent 43 | KL_binary = numpy.mean(KL_binary) 44 | 45 | return {'KL': KL_binary} 46 | 47 | 48 | class MyTrainer(Trainer): 49 | 50 | def compute_loss(self, model, inputs, return_outputs=False): 51 | soft_labels = inputs.pop('soft_labels') 52 | mask = (soft_labels - IGNORE_INDEX).abs() > 0.1 53 | 54 | soft_labels_1 = soft_labels 55 | soft_labels_0 = soft_labels_1.clone() 56 | soft_labels_0[mask] = 1 - soft_labels_1[mask] 57 | 58 | label_0 = torch.ones_like(soft_labels, dtype=torch.long).to(soft_labels.device) * IGNORE_INDEX 59 | label_0[mask] = 0 60 | label_1 = torch.ones_like(soft_labels, dtype=torch.long).to(soft_labels.device) * IGNORE_INDEX 61 | label_1[mask] = 1 62 | 63 | outputs = model.model(**inputs, output_hidden_states = True, return_dict=True) 64 | hidden_states = outputs.get("hidden_states") 65 | orignal_logits = model.assist_acc_head(hidden_states[-1]) 66 | orignal_logits = orignal_logits.float() 67 | 68 | num_class = 2 69 | 70 | weight = torch.tensor([self.args.weight_mismatch, 1]).to(orignal_logits.device) 71 | loss_fct = CrossEntropyLoss(weight=weight, reduction='none') 72 | 73 | logits = orignal_logits.view(-1, num_class) 74 | label_0 = label_0.view(-1) 75 | label_1 = label_1.view(-1) 76 | soft_labels_0 = soft_labels_0.view(-1) 77 | soft_labels_1 = soft_labels_1.view(-1) 78 | mask = mask.view(-1) 79 | 80 | loss_0 = loss_fct(logits, label_0) # (bs * seq_len), num_class 81 | loss_1 = loss_fct(logits, label_1) # (bs * seq_len), num_class 82 | 83 | # reduce with soft labels, coresponding to BCELoss 84 | loss = (loss_0 * soft_labels_0 + loss_1 * soft_labels_1).sum() / (self.args.weight_mismatch * soft_labels_0[mask].sum() + soft_labels_1[mask].sum() ) 85 | 86 | if model.training: 87 | # KL divergence: 88 | target_prob = soft_labels_1[mask] 89 | predicted_logits = logits[mask, :] 90 | predicted_log_prob = torch.log_softmax(predicted_logits, dim=-1) 91 | 92 | #KL_binary = target_prob * (target_prob.log() - predicted_log_prob[:,1]) + (1-target_prob) * ( (1-target_prob).log() - predicted_log_prob[:,0]) 93 | 94 | CrossEnt = target_prob * ( - predicted_log_prob[:,1]) + (1-target_prob) * ( - predicted_log_prob[:,0]) 95 | Ent = target_prob * target_prob.log() + (1-target_prob) * (1-target_prob).log() 96 | Ent[Ent.isnan()] = 0. 
# hack for binary entropy 97 | KL_binary = CrossEnt - Ent 98 | KL_binary = KL_binary.mean().item() 99 | 100 | self.log({'KL': KL_binary}) 101 | 102 | 103 | if return_outputs: 104 | outputs = (loss, orignal_logits) 105 | return (loss, outputs) 106 | else: 107 | return loss 108 | 109 | @dataclass 110 | class TrainingArguments(transformers.TrainingArguments): 111 | bf16: bool = True 112 | model_name_or_path: Optional[str] = field(default=None) 113 | data_path: str = field(default=None) 114 | eval_data_path: str = field(default=None) 115 | remove_unused_columns: bool = False 116 | evaluate_only: bool = False 117 | label_names: Optional[List[str]] = field( 118 | default_factory=lambda: ['soft_labels'], metadata={"help": "The list of keys in your dictionary of inputs that correspond to the labels."} 119 | ) 120 | 121 | weight_mismatch: Optional[float] = field(default = 1.) # 6 for balancing classes 122 | resnet_num_layers: Optional[int] = field(default = 1) 123 | mixing_ratio: Optional[float] = field(default = 0.15) 124 | 125 | 126 | def smart_tokenizer_and_embedding_resize( 127 | special_tokens_dict: Dict, 128 | tokenizer: transformers.PreTrainedTokenizer, 129 | model: transformers.PreTrainedModel, 130 | ): 131 | """Resize tokenizer and embedding. 132 | 133 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 134 | """ 135 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) 136 | model.resize_token_embeddings(len(tokenizer)) 137 | 138 | if num_new_tokens > 0: 139 | input_embeddings = model.get_input_embeddings().weight.data 140 | output_embeddings = model.get_output_embeddings().weight.data 141 | 142 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 143 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 144 | 145 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 146 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 147 | 148 | 149 | 150 | class SupervisedDataset(Dataset): 151 | def __init__(self, data_path: str, r: float = 0.15): 152 | super(SupervisedDataset, self).__init__() 153 | logging.warning(f"Loading data... from {data_path}") 154 | data = json.load(open(data_path,'r')) 155 | self.input_ids = [] 156 | self.soft_labels = [] 157 | for item in data: 158 | item['prefix'] = eval(item['prefix']) 159 | item['tokens'] = eval(item['tokens']) 160 | item['draft'] = eval(item['draft']) 161 | 162 | # item['tokens'] are generated autoregressively from target model 163 | # item['draft'] are stochatic next-token predicted by the draft model 164 | 165 | item['p_acc'] = eval(item['p_acc']) 166 | 167 | prefix = torch.LongTensor(item['prefix']) 168 | Xs = torch.LongTensor(item['tokens']) 169 | # Ys = torch.LongTensor(item['draft']) 170 | 171 | # take r from Xs and (1-r) from Ys. 172 | mask = (torch.rand(*Xs.shape) < r) 173 | Zs = torch.LongTensor(item['draft']) 174 | Zs[mask] = Xs[mask] 175 | 176 | self.input_ids.append(torch.cat([prefix, Zs])) 177 | 178 | label_prefix = torch.tensor([IGNORE_INDEX] * len(item['prefix'])) 179 | p_acc = torch.tensor(item['p_acc']) 180 | 181 | # don't calculate loss on Xs. 
182 | p_acc[mask] = IGNORE_INDEX 183 | 184 | self.soft_labels.append(torch.cat([label_prefix, p_acc])) 185 | 186 | def __len__(self): 187 | return len(self.input_ids) 188 | 189 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 190 | return dict(input_ids=self.input_ids[i], soft_labels=self.soft_labels[i]) 191 | 192 | 193 | @dataclass 194 | class DataCollatorForSupervisedDataset(object): 195 | """Collate examples for supervised fine-tuning.""" 196 | 197 | tokenizer: transformers.PreTrainedTokenizer 198 | 199 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 200 | input_ids, soft_labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "soft_labels")) 201 | input_ids = torch.nn.utils.rnn.pad_sequence( 202 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id 203 | ) 204 | soft_labels = torch.nn.utils.rnn.pad_sequence(soft_labels, batch_first=True, padding_value=IGNORE_INDEX) 205 | return dict( 206 | input_ids=input_ids, 207 | soft_labels=soft_labels, 208 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 209 | ) 210 | 211 | 212 | 213 | 214 | if __name__ == "__main__": 215 | parser = transformers.HfArgumentParser((TrainingArguments)) 216 | training_args = parser.parse_args_into_dataclasses()[0] 217 | 218 | tokenizer = transformers.AutoTokenizer.from_pretrained(training_args.model_name_or_path) 219 | model = transformers.AutoModelForCausalLM.from_pretrained(training_args.model_name_or_path) 220 | special_tokens_dict = dict() 221 | if tokenizer.pad_token is None: 222 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN 223 | 224 | smart_tokenizer_and_embedding_resize( 225 | special_tokens_dict=special_tokens_dict, 226 | tokenizer=tokenizer, 227 | model=model, 228 | ) 229 | 230 | train_dataset = SupervisedDataset(training_args.data_path, r=training_args.mixing_ratio) 231 | if training_args.eval_data_path is not None: 232 | eval_dataset = SupervisedDataset(training_args.eval_data_path, r=training_args.mixing_ratio) 233 | print("num eval example:", len(eval_dataset)) 234 | else: 235 | eval_dataset = None 236 | data_collator = DataCollatorForSupervisedDataset(tokenizer) 237 | 238 | acc_head_config = {'hidden_size': model.config.hidden_size, 'num_layers': training_args.resnet_num_layers} 239 | assist_acc_head = AcceptancePredictionHead(acc_head_config) 240 | wrapped = WrapModel(model, assist_acc_head) 241 | wrapped.model.requires_grad_(False) 242 | print('num training example:', len(train_dataset)) 243 | trainer = MyTrainer(model=wrapped, tokenizer=tokenizer, args=training_args, train_dataset = train_dataset, eval_dataset = eval_dataset, data_collator=data_collator, compute_metrics = compute_metrics) 244 | if training_args.evaluate_only: 245 | print("eval only. Loading from checkpoint:", training_args.output_dir) 246 | wrapped.assist_acc_head = AcceptancePredictionHead.from_pretrained(training_args.output_dir) 247 | trainer.evaluate() 248 | else: 249 | trainer.train() 250 | trainer.save_state() 251 | wrapped.assist_acc_head.save_pretrained(training_args.output_dir, config=acc_head_config) 252 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# SpecDec++: Boosting Speculative Decoding via Adaptive Candidate Lengths

Kaixuan Huang, Xudong Guo, Mengdi Wang

Princeton University

COLM 2025 & ICML 2024 ES-FoMo workshop

[arXiv](https://arxiv.org/abs/2405.19715)
20 | 21 | ----- 22 | We propose SpecDec++, an enhanced version of speculative decoding that adaptively determines the candidate length with the help of a trained acceptance prediction head. Our method can boost the performance of speculative decoding and can be combined with other tricks like fused kernel, quantization, and advanced KV cache management. 23 | 24 | ![./assets/teaser_result.png](./assets/teaser_result.png) 25 | 26 | *Tested with llama-2-chat 7B & 70B model pair (bfloat16) on 2 NVIDIA A100-80G GPUs. 27 | 28 | ---- 29 | 30 | ## Quick Links 31 | 32 | - [Quick Links](#quick-links) 33 | - [Overview of Speculative Decoding](#overview-of-speculative-decoding) 34 | - [Case I: There exists rejected tokens.](#case-i-there-exists-rejected-tokens) 35 | - [Case II: All tokens are accepted.](#case-ii-all-tokens-are-accepted) 36 | - [Problem: Determination of the candidate length $K$.](#problem-determination-of-the-candidate-length-k) 37 | - [Our approach](#our-approach) 38 | - [Performance](#performance) 39 | - [Using `SpecDec++`](#using-specdec) 40 | - [Checkpoint Release \& Sampling Code](#checkpoint-release--sampling-code) 41 | - [Training and Evaluation](#training-and-evaluation) 42 | - [Dataset Preparation](#dataset-preparation) 43 | - [Training the Acceptance Prediction Heads.](#training-the-acceptance-prediction-heads) 44 | - [Benchmarking Performances.](#benchmarking-performances) 45 | - [To benchmark the performance of SpecDec ++, modify and run the following command.](#to-benchmark-the-performance-of-specdec--modify-and-run-the-following-command) 46 | - [To benchmark the performance of SpecDec, modify and run the following command.](#to-benchmark-the-performance-of-specdec-modify-and-run-the-following-command) 47 | - [To benchmark the performance without speculative decoding, modify and run the following command.](#to-benchmark-the-performance-without-speculative-decoding-modify-and-run-the-following-command) 48 | - [Sample results](#sample-results) 49 | - [Bugs or Questions](#bugs-or-questions) 50 | - [Citation](#citation) 51 | 52 | 53 | 54 | ---- 55 | 56 | ## Overview of Speculative Decoding 57 | 58 | In speculative decoding, the draft model first generates $K$ tokens. The target model computes their log probabilities *in parallel* and then sequentially determines whether each token is accepted or not. 59 | 60 | ### Case I: There exists rejected tokens. 61 | 62 | Following the first rejected token, the algorithm discards the remaining tokens and corrects the rejected token with a fresh sample from a modified distribution. 63 | 64 |

![./assets/specdec1.png](./assets/specdec1.png)

67 | 68 | ### Case II: All tokens are accepted. 69 | 70 | If all tokens are accepted, a new token is sampled from the next-token probability given by the target model and appended to the sequence of accepted tokens, and then the process moves forward. 71 | 72 |

![./assets/specdec2.png](./assets/specdec2.png)
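The per-token rule behind Case I and Case II can be written in a few lines. Below is a minimal sketch for a single speculation round with batch size 1; it is not the repository's implementation (which lives in `specdec_pp/hf_generation.py`), and the names `verify_candidates`, `p_target`, `p_draft`, and `draft_tokens` are illustrative placeholders.

```python
import torch

def verify_candidates(p_target, p_draft, draft_tokens):
    """One verification round of speculative decoding (batch size 1).

    p_target: (K+1, V) target-model next-token distributions, one per candidate
              position plus one extra row after the last candidate.
    p_draft:  (K, V)   draft-model next-token distributions for the K candidates.
    draft_tokens: (K,) candidate token ids proposed by the draft model.
    """
    K = draft_tokens.shape[0]
    accepted = []
    for i in range(K):
        tok = draft_tokens[i]
        # Accept candidate i with probability min(1, p_target / p_draft).
        if torch.rand(()) < torch.clamp(p_target[i, tok] / p_draft[i, tok], max=1.0):
            accepted.append(tok)
        else:
            # Case I: first rejection. Discard the remaining candidates and resample
            # the corrected token from max(p_target - p_draft, 0), renormalized.
            residual = torch.clamp(p_target[i] - p_draft[i], min=0.0)
            corrected = torch.multinomial(residual / residual.sum(), num_samples=1)[0]
            return accepted + [corrected]
    # Case II: every candidate accepted. Sample one bonus token from the target model.
    bonus = torch.multinomial(p_target[K], num_samples=1)[0]
    return accepted + [bonus]
```

Each candidate is accepted with probability min(1, p_target/p_draft); this is exactly the quantity stored as `p_acc` by `data/gen_acceptance.py` and used as the soft label when training the acceptance prediction head.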

75 | 76 | ## Problem: Determination of the candidate length $K$. 77 | 78 | `SpecDec++` aims to find a *theoretically justifiable* approach towards the following problem: what is a proper candidate length that generates as many accepted tokens and wastes as few discarded tokens as possible? 79 | 80 | 81 | 82 | 83 | 84 | 85 | ### Our approach 86 | 87 | 88 | We formalize the dynamic choice of candidate length in speculative decoding as a Markov Decision 89 | Process (MDP). We theoretically show that when the probability that at least one token gets rejected 90 | exceeds a threshold, the optimal action is to stop the speculation and submit it for verification: 91 | 92 | 93 | 94 | 95 | We augment the draft model with a trained acceptance prediction head to predict the conditional acceptance probability of the candidate tokens. `SpecDec++` will stop the current speculation round when the predicted probability that at least one token gets rejected exceeds a threshold. 96 | 97 | 98 | ![./assets/main.png](./assets/main.png) 99 | 100 | ### Performance 101 | 102 | `SpecDec++` has better Pareto frontiers than `SpecDec` on both the in-distribution dataset Alpaca and the two out-of-distribution datasets HumanEval and GSM8K. Please check our paper for more details. 103 | 104 | ![./assets/pareto.png](./assets/pareto.png) 105 | 106 | ----- 107 | 108 | ## Using `SpecDec++` 109 | 110 | **Step 0 (Optional)**: To start with, prepare a conda environment with pytorch installed. If not, you can use the following command. 111 | 112 | ``` 113 | conda create -n specdecpp python=3.11 114 | conda activate specdecpp 115 | conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia 116 | ``` 117 | 118 | **Step 1**: Clone the repository and install the required packages. 119 | 120 | ``` 121 | git clone git@github.com:Kaffaljidhmah2/SpecDec_pp.git 122 | cd SpecDec_pp 123 | pip install -r requirements.txt 124 | ``` 125 | 126 | 127 | ### Checkpoint Release & Sampling Code 128 | 129 | The checkpoint of our best acceptance prediction head for llama-2-chat 7B & 70B model pair is available at [huggingface hub](https://huggingface.co/hacky/acchead-llama2-chat-7bx70b). 130 | 131 | Please take a look at [specdec_pp/sample.py](specdec_pp/sample.py) for how to use SpecDec++. 132 | 133 | ---- 134 | 135 | 136 | ## Training and Evaluation 137 | 138 | ### Dataset Preparation 139 | 140 | Follow the instructions in [data/readme.md](./data/readme.md) for dataset preparation. After running the code, you should be able to get the Alpaca dataset (`data/alpaca_data/train.json`, `data/alpaca_data/dev.json`, `data/alpaca_data/test.json`), HumanEval dataset (`data/humaneval_data/test.json`), and GSM8K test dataset (`data/gsm8k_test_data/test.json`) for llama-2-chat models. 141 | 142 | ### Training the Acceptance Prediction Heads. 143 | 144 | 145 | Please modify the following code for training. Here `layer` indicates the number of layers of the ResNet prediction head, `weight` is the loss weight for the mismatched tokens for the BCE loss (the weight for the matched tokens is `1`). The mixing ratio can be set via `--mixing_ratio` (default is 0.15). 
146 | 147 | ```bash 148 | layer=3 149 | weight=6 150 | draft_model=meta-llama/Llama-2-7b-chat-hf 151 | 152 | WANDB_PROJECT=specdecpp python3 specdec_pp/train.py \ 153 | --data_path data/alpaca_data/train.json \ 154 | --eval_data_path data/alpaca_data/dev.json \ 155 | --output_dir exp-weight${weight}-layer${layer} \ 156 | --model_name_or_path ${draft_model} \ 157 | --bf16 True \ 158 | --per_device_train_batch_size 4 \ 159 | --num_train_epochs 3 \ 160 | --gradient_accumulation_steps 8 \ 161 | --logging_steps 5 \ 162 | --evaluation_strategy epoch \ 163 | --per_device_eval_batch_size 4 \ 164 | --weight_mismatch ${weight} \ 165 | --save_strategy no \ 166 | --warmup_ratio 0.03 \ 167 | --lr_scheduler_type cosine \ 168 | --resnet_num_layers ${layer} \ 169 | --mixing_ratio 0.15 170 | ``` 171 | 172 | ### Benchmarking Performances. 173 | 174 | #### To benchmark the performance of SpecDec ++, modify and run the following command. 175 | 176 | Note: `--num_assistant_tokens_schedule ada` indicates the proposed SpecDec++ method, where the checkpoint of the acceptance prediction head should be specified via `--assist_acc_head_dir`. `--stop_threshold` indicates the threshold value (between 0 and 1) used to stop the current speculation round; a larger `stop_threshold` leads to longer speculation rounds. `--bound MIN MAX` indicates the minimum and maximum number of candidate tokens for one speculation round. 177 | 178 | ```bash 179 | layer=3 180 | weight=6 181 | thres=0.3 182 | 183 | ckpt=exp-weight${weight}-layer${layer} 184 | 185 | target_model=meta-llama/Llama-2-70b-chat-hf 186 | draft_model=meta-llama/Llama-2-7b-chat-hf 187 | data=data/alpaca_data/test.json 188 | SAVEPATH=test-results-alpaca/weight${weight}-layer${layer}-thres${thres}-bound2_20/ 189 | 190 | python3 specdec_pp/evaluate.py \ 191 | --model_name ${target_model} \ 192 | --assistant_name ${draft_model} \ 193 | --num_assistant_tokens_schedule ada \ 194 | --data_path ${data} \ 195 | --assist_acc_head_dir ${ckpt} \ 196 | --do_sample \ 197 | --random_seed 42 \ 198 | --save_path ${SAVEPATH} \ 199 | --stop_threshold ${thres} \ 200 | --bound 2 20 201 | ``` 202 | 203 | The result will be stored under the folder `${SAVEPATH}`. 204 | 205 | 206 | 207 | #### To benchmark the performance of SpecDec, modify and run the following command. 208 | 209 | Note: `--num_assistant_tokens_schedule constant` indicates the baseline SpecDec method. `--num_assistant_tokens` is the constant number of candidate tokens generated per speculation round. 210 | 211 | ```bash 212 | target_model=meta-llama/Llama-2-70b-chat-hf 213 | draft_model=meta-llama/Llama-2-7b-chat-hf 214 | K=4 215 | data=data/alpaca_data/test.json 216 | SAVEPATH=test-results-alpaca/baseline-${K}/ 217 | 218 | python3 specdec_pp/evaluate.py \ 219 | --model_name ${target_model} \ 220 | --assistant_name ${draft_model} \ 221 | --num_assistant_tokens_schedule constant \ 222 | --num_assistant_tokens ${K} \ 223 | --data_path ${data} \ 224 | --do_sample \ 225 | --random_seed 42 \ 226 | --save_path ${SAVEPATH} 227 | ``` 228 | 229 | #### To benchmark the performance without speculative decoding, modify and run the following command. 230 | 231 | Note: `--num_assistant_tokens_schedule none` runs the target model and the draft model standalone (no speculative decoding), providing the baseline timings. 
232 | 233 | ```bash 234 | target_model=meta-llama/Llama-2-70b-chat-hf 235 | draft_model=meta-llama/Llama-2-7b-chat-hf 236 | data=data/alpaca_data/test.json 237 | SAVEPATH=test-results-alpaca/standalone/ 238 | 239 | python3 specdec_pp/evaluate.py \ 240 | --model_name ${target_model} \ 241 | --assistant_name ${draft_model} \ 242 | --num_assistant_tokens_schedule none \ 243 | --data_path ${data} \ 244 | --do_sample \ 245 | --random_seed 42 \ 246 | --save_path ${SAVEPATH} \ 247 | ``` 248 | 249 | 250 | #### Sample results 251 | 252 | ``` 253 | [ 254 | { 255 | ## key-value pairs for prompt, continuation, prefix, tokens, draft, p_acc, and id 256 | 257 | ## for SpecDec & SpecDec++ 258 | "spec_time": 15.580421447753906, 259 | "num_mismatched_tokens": 20, 260 | "num_LM_call": 67, 261 | "generated_length": 180, 262 | ## for standalone target model / draft model 263 | "target_time": 25.6504251956939, 264 | "draft_time": 2.795105218887329, 265 | "generated_length_target": 203, 266 | "generated_length_draft": 134 267 | } 268 | ] 269 | ``` 270 | 271 | ------ 272 | 273 | 274 | ### Bugs or Questions 275 | 276 | Feel free to send an email to `kaixuanh@princeton.edu` or create a GitHub Issue/Pull request. 277 | 278 | 279 | ### Citation 280 | 281 | If you find this useful in your research, please consider citing our paper. 282 | 283 | ```bibtex 284 | @article{huang2024specdec++, 285 | title={SpecDec++: Boosting Speculative Decoding via Adaptive Candidate Lengths}, 286 | author={Huang, Kaixuan and Guo, Xudong and Wang, Mengdi}, 287 | journal={arXiv preprint arXiv:2405.19715}, 288 | year={2024} 289 | } 290 | ``` 291 | -------------------------------------------------------------------------------- /specdec_pp/hf_generation.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | # modified from https://raw.githubusercontent.com/huggingface/transformers/v4.34.1/src/transformers/generation/utils.py 18 | 19 | import copy 20 | import inspect 21 | import warnings 22 | from dataclasses import dataclass 23 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union 24 | 25 | import torch 26 | import torch.distributed as dist 27 | from torch import nn 28 | 29 | from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled 30 | from transformers.modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput 31 | from transformers.models.auto import ( 32 | MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING, 33 | MODEL_FOR_CAUSAL_LM_MAPPING, 34 | MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, 35 | MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, 36 | MODEL_FOR_VISION_2_SEQ_MAPPING, 37 | ) 38 | from transformers.utils import ExplicitEnum, ModelOutput, is_accelerate_available, logging 39 | from transformers.generation.beam_constraints import DisjunctiveConstraint, PhrasalConstraint 40 | from transformers.generation.beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer 41 | from transformers.generation.configuration_utils import GenerationConfig 42 | from transformers.generation.logits_process import ( 43 | EncoderNoRepeatNGramLogitsProcessor, 44 | EncoderRepetitionPenaltyLogitsProcessor, 45 | EpsilonLogitsWarper, 46 | EtaLogitsWarper, 47 | ExponentialDecayLengthPenalty, 48 | ForcedBOSTokenLogitsProcessor, 49 | ForcedEOSTokenLogitsProcessor, 50 | ForceTokensLogitsProcessor, 51 | HammingDiversityLogitsProcessor, 52 | InfNanRemoveLogitsProcessor, 53 | LogitNormalization, 54 | LogitsProcessorList, 55 | MinLengthLogitsProcessor, 56 | MinNewTokensLengthLogitsProcessor, 57 | NoBadWordsLogitsProcessor, 58 | NoRepeatNGramLogitsProcessor, 59 | PrefixConstrainedLogitsProcessor, 60 | RepetitionPenaltyLogitsProcessor, 61 | SequenceBiasLogitsProcessor, 62 | SuppressTokensAtBeginLogitsProcessor, 63 | SuppressTokensLogitsProcessor, 64 | TemperatureLogitsWarper, 65 | TopKLogitsWarper, 66 | TopPLogitsWarper, 67 | TypicalLogitsWarper, 68 | UnbatchedClassifierFreeGuidanceLogitsProcessor, 69 | ) 70 | from transformers.generation.stopping_criteria import ( 71 | MaxLengthCriteria, 72 | MaxTimeCriteria, 73 | StoppingCriteria, 74 | StoppingCriteriaList, 75 | validate_stopping_criteria, 76 | ) 77 | 78 | 79 | if TYPE_CHECKING: 80 | from transformers.modeling_utils import PreTrainedModel 81 | from transformers.streamers import BaseStreamer 82 | 83 | logger = logging.get_logger(__name__) 84 | 85 | if is_accelerate_available(): 86 | from accelerate.hooks import AlignDevicesHook, add_hook_to_module 87 | 88 | from transformers.generation.utils import GenerationMixin, GenerateOutput, GenerationMode, _crop_past_key_values, GreedySearchDecoderOnlyOutput, GreedySearchEncoderDecoderOutput 89 | 90 | 91 | @torch.no_grad() 92 | def my_generate( 93 | model: "PreTrainedModel", 94 | inputs: Optional[torch.Tensor] = None, 95 | generation_config: Optional[GenerationConfig] = None, 96 | logits_processor: Optional[LogitsProcessorList] = None, 97 | stopping_criteria: Optional[StoppingCriteriaList] = None, 98 | prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, 99 | synced_gpus: Optional[bool] = None, 100 | assistant_model: Optional["PreTrainedModel"] = None, 101 | streamer: Optional["BaseStreamer"] = None, 102 | negative_prompt_ids: Optional[torch.Tensor] = None, 103 | negative_prompt_attention_mask: Optional[torch.Tensor] = None, 104 | 
num_assistant_tokens_schedule: Optional[str] = 'heuristic', 105 | num_assistant_tokens: Optional[int] = None, 106 | oracle_token_num_list: Optional[List[int]] = None, 107 | assist_acc_head: Optional[nn.Module] = None, 108 | stop_threshold: Optional[float] = None, 109 | bound: Optional[List[int]] = None, 110 | **kwargs, 111 | ) -> Union[GenerateOutput, torch.LongTensor]: 112 | r""" 113 | 114 | Generates sequences of token ids for models with a language modeling head. 115 | 116 | 117 | 118 | Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the 119 | model's default generation configuration. You can override any `generation_config` by passing the corresponding 120 | parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. 121 | 122 | For an overview of generation strategies and code examples, check out the [following 123 | guide](../generation_strategies). 124 | 125 | 126 | 127 | Parameters: 128 | inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): 129 | The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the 130 | method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` 131 | should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of 132 | `input_ids`, `input_values`, `input_features`, or `pixel_values`. 133 | generation_config (`~generation.GenerationConfig`, *optional*): 134 | The generation configuration to be used as base parametrization for the generation call. `**kwargs` 135 | passed to generate matching the attributes of `generation_config` will override them. If 136 | `generation_config` is not provided, the default will be used, which had the following loading 137 | priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model 138 | configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s 139 | default values, whose documentation should be checked to parameterize generation. 140 | logits_processor (`LogitsProcessorList`, *optional*): 141 | Custom logits processors that complement the default logits processors built from arguments and 142 | generation config. If a logit processor is passed that is already created with the arguments or a 143 | generation config an error is thrown. This feature is intended for advanced users. 144 | stopping_criteria (`StoppingCriteriaList`, *optional*): 145 | Custom stopping criteria that complement the default stopping criteria built from arguments and a 146 | generation config. If a stopping criteria is passed that is already created with the arguments or a 147 | generation config an error is thrown. This feature is intended for advanced users. 148 | prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): 149 | If provided, this function constraints the beam search to allowed tokens only at each step. If not 150 | provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and 151 | `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned 152 | on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful 153 | for constrained generation conditioned on the prefix, as described in [Autoregressive Entity 154 | Retrieval](https://arxiv.org/abs/2010.00904). 
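num_assistant_tokens_schedule (`str`, *optional*, defaults to `'heuristic'`):
    How the number of candidate tokens per speculation round is chosen: `'constant'` always
    proposes `num_assistant_tokens` tokens (SpecDec), `'heuristic'` uses the transformers
    heuristic, `'oracle'` replays the per-round lengths in `oracle_token_num_list`, and
    `'ada'` stops each round adaptively using `assist_acc_head` (SpecDec++).
num_assistant_tokens (`int`, *optional*):
    Initial/fixed number of candidate tokens per round (the constant value for the
    `'constant'` schedule).
oracle_token_num_list (`List[int]`, *optional*):
    Per-round candidate lengths used by the `'oracle'` schedule.
assist_acc_head (`nn.Module`, *optional*):
    Acceptance prediction head applied to the assistant model's last hidden state;
    required by the `'ada'` schedule.
stop_threshold (`float`, *optional*):
    For the `'ada'` schedule, the current round stops once the predicted probability
    that at least one candidate token will be rejected exceeds this value.
bound (`List[int]`, *optional*):
    `[min, max]` bounds on the number of candidate tokens per round.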
155 | synced_gpus (`bool`, *optional*): 156 | Whether to continue running the while loop until max_length. Unless overridden this flag will be set to 157 | `True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished 158 | generating before other GPUs. Otherwise it'll be set to `False`. 159 | assistant_model (`PreTrainedModel`, *optional*): 160 | An assistant model that can be used to accelerate generation. The assistant model must have the exact 161 | same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model 162 | is much faster than running generation with the model you're calling generate from. As such, the 163 | assistant model should be much smaller. 164 | streamer (`BaseStreamer`, *optional*): 165 | Streamer object that will be used to stream the generated sequences. Generated tokens are passed 166 | through `streamer.put(token_ids)` and the streamer is responsible for any further processing. 167 | negative_prompt_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 168 | The negative prompt needed for some processors such as CFG. The batch size must match the input batch 169 | size. This is an experimental feature, subject to breaking API changes in future versions. 170 | negative_prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 171 | Attention_mask for `negative_prompt_ids`. 172 | kwargs (`Dict[str, Any]`, *optional*): 173 | Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be 174 | forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder 175 | specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. 176 | 177 | Return: 178 | [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` 179 | or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. 180 | 181 | If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible 182 | [`~utils.ModelOutput`] types are: 183 | 184 | - [`~generation.GreedySearchDecoderOnlyOutput`], 185 | - [`~generation.SampleDecoderOnlyOutput`], 186 | - [`~generation.BeamSearchDecoderOnlyOutput`], 187 | - [`~generation.BeamSampleDecoderOnlyOutput`] 188 | 189 | If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible 190 | [`~utils.ModelOutput`] types are: 191 | 192 | - [`~generation.GreedySearchEncoderDecoderOutput`], 193 | - [`~generation.SampleEncoderDecoderOutput`], 194 | - [`~generation.BeamSearchEncoderDecoderOutput`], 195 | - [`~generation.BeamSampleEncoderDecoderOutput`] 196 | """ 197 | 198 | if synced_gpus is None: 199 | if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1: 200 | synced_gpus = True 201 | else: 202 | synced_gpus = False 203 | 204 | # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call 205 | model._validate_model_class() 206 | 207 | # priority: `generation_config` argument > `model.generation_config` (the default generation config) 208 | if generation_config is None: 209 | # legacy: users may modify the model configuration to control generation. 
To trigger this legacy behavior, 210 | # two conditions must be met 211 | # 1) the generation config must have been created from the model config (`_from_model_config` field); 212 | # 2) the generation config must have seen no modification since its creation (the hash is the same). 213 | if model.generation_config._from_model_config and model.generation_config._original_object_hash == hash( 214 | model.generation_config 215 | ): 216 | new_generation_config = GenerationConfig.from_model_config(model.config) 217 | if new_generation_config != model.generation_config: 218 | warnings.warn( 219 | "You have modified the pretrained model configuration to control generation. This is a" 220 | " deprecated strategy to control generation and will be removed soon, in a future version." 221 | " Please use and modify the model generation configuration (see" 222 | " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )" 223 | ) 224 | model.generation_config = new_generation_config 225 | generation_config = model.generation_config 226 | 227 | generation_config = copy.deepcopy(generation_config) 228 | model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs 229 | generation_config.validate() 230 | model._validate_model_kwargs(model_kwargs.copy()) 231 | 232 | # 2. Set generation parameters if not already defined 233 | logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() 234 | stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() 235 | 236 | if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: 237 | if model_kwargs.get("attention_mask", None) is None: 238 | logger.warning( 239 | "The attention mask and the pad token id were not set. As a consequence, you may observe " 240 | "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." 241 | ) 242 | eos_token_id = generation_config.eos_token_id 243 | if isinstance(eos_token_id, list): 244 | eos_token_id = eos_token_id[0] 245 | logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") 246 | generation_config.pad_token_id = eos_token_id 247 | 248 | # 3. Define model inputs 249 | # inputs_tensor has to be defined 250 | # model_input_name is defined if model-specific keyword input is passed 251 | # otherwise model_input_name is None 252 | # all model-specific keyword inputs are removed from `model_kwargs` 253 | inputs_tensor, model_input_name, model_kwargs = model._prepare_model_inputs( 254 | inputs, generation_config.bos_token_id, model_kwargs 255 | ) 256 | batch_size = inputs_tensor.shape[0] 257 | 258 | # 4. 
Define other model kwargs 259 | model_kwargs["output_attentions"] = generation_config.output_attentions 260 | model_kwargs["output_hidden_states"] = generation_config.output_hidden_states 261 | # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are 262 | # generating the first new token or not, and we only want to use the embeddings for the first new token) 263 | if not model.config.is_encoder_decoder and model_input_name == "inputs_embeds": 264 | model_kwargs["use_cache"] = True 265 | else: 266 | model_kwargs["use_cache"] = generation_config.use_cache 267 | 268 | accepts_attention_mask = "attention_mask" in set(inspect.signature(model.forward).parameters.keys()) 269 | requires_attention_mask = "encoder_outputs" not in model_kwargs 270 | 271 | if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: 272 | model_kwargs["attention_mask"] = model._prepare_attention_mask_for_generation( 273 | inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id 274 | ) 275 | 276 | # decoder-only models should use left-padding for generation 277 | if not model.config.is_encoder_decoder: 278 | # If `input_ids` was given, check if the last id in any sequence is `pad_token_id` 279 | # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off. 280 | if ( 281 | generation_config.pad_token_id is not None 282 | and len(inputs_tensor.shape) == 2 283 | and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0 284 | ): 285 | logger.warning( 286 | "A decoder-only architecture is being used, but right-padding was detected! For correct " 287 | "generation results, please set `padding_side='left'` when initializing the tokenizer." 288 | ) 289 | 290 | if model.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: 291 | # if model is encoder decoder encoder_outputs are created 292 | # and added to `model_kwargs` 293 | model_kwargs = model._prepare_encoder_decoder_kwargs_for_generation( 294 | inputs_tensor, model_kwargs, model_input_name 295 | ) 296 | 297 | # 5. Prepare `input_ids` which will be used for auto-regressive generation 298 | if model.config.is_encoder_decoder: 299 | input_ids, model_kwargs = model._prepare_decoder_input_ids_for_generation( 300 | batch_size=batch_size, 301 | model_input_name=model_input_name, 302 | model_kwargs=model_kwargs, 303 | decoder_start_token_id=generation_config.decoder_start_token_id, 304 | bos_token_id=generation_config.bos_token_id, 305 | device=inputs_tensor.device, 306 | ) 307 | else: 308 | input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") 309 | 310 | if streamer is not None: 311 | streamer.put(input_ids.cpu()) 312 | 313 | # 6. Prepare `max_length` depending on other stopping criteria. 314 | input_ids_length = input_ids.shape[-1] 315 | has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None 316 | if generation_config.max_new_tokens is not None: 317 | if not has_default_max_length and generation_config.max_length is not None: 318 | logger.warning( 319 | f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" 320 | f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " 321 | "Please refer to the documentation for more information. 
" 322 | "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" 323 | ) 324 | generation_config.max_length = generation_config.max_new_tokens + input_ids_length 325 | model._validate_generated_length(generation_config, input_ids_length, has_default_max_length) 326 | 327 | # 7. determine generation mode 328 | generation_mode = model._get_generation_mode(generation_config, assistant_model) 329 | 330 | if streamer is not None and (generation_config.num_beams > 1): 331 | raise ValueError( 332 | "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1." 333 | ) 334 | 335 | if model.device.type != input_ids.device.type: 336 | warnings.warn( 337 | "You are calling .generate() with the `input_ids` being on a device type different" 338 | f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model" 339 | f" is on {model.device.type}. You may experience unexpected behaviors or slower generation." 340 | " Please make sure that you have put `input_ids` to the" 341 | f" correct device by calling for example input_ids = input_ids.to('{model.device.type}') before" 342 | " running `.generate()`.", 343 | UserWarning, 344 | ) 345 | 346 | # 8. prepare distribution pre_processing samplers 347 | logits_processor = model._get_logits_processor( 348 | generation_config=generation_config, 349 | input_ids_seq_length=input_ids_length, 350 | encoder_input_ids=inputs_tensor, 351 | prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, 352 | logits_processor=logits_processor, 353 | model_kwargs=model_kwargs, 354 | negative_prompt_ids=negative_prompt_ids, 355 | negative_prompt_attention_mask=negative_prompt_attention_mask, 356 | ) 357 | 358 | # 9. prepare stopping criteria 359 | stopping_criteria = model._get_stopping_criteria( 360 | generation_config=generation_config, stopping_criteria=stopping_criteria 361 | ) 362 | # 10. go into different generation modes 363 | if generation_mode == GenerationMode.ASSISTED_GENERATION: 364 | if generation_config.num_return_sequences > 1: 365 | raise ValueError( 366 | "num_return_sequences has to be 1 when doing assisted generate, " 367 | f"but is {generation_config.num_return_sequences}." 368 | ) 369 | if batch_size > 1: 370 | raise ValueError("assisted generate is only supported for batch_size = 1") 371 | if not model_kwargs["use_cache"]: 372 | raise ValueError("assisted generate requires `use_cache=True`") 373 | 374 | # 11. If the assistant model is an encoder-decoder, prepare its encoder outputs 375 | if assistant_model.config.is_encoder_decoder: 376 | assistant_model_kwargs = copy.deepcopy(model_kwargs) 377 | inputs_tensor, model_input_name, assistant_model_kwargs = assistant_model._prepare_model_inputs( 378 | inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_model_kwargs 379 | ) 380 | assistant_model_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation( 381 | inputs_tensor, assistant_model_kwargs, model_input_name 382 | ) 383 | model_kwargs["assistant_encoder_outputs"] = assistant_model_kwargs["encoder_outputs"] 384 | 385 | # 12. 
run assisted generate 386 | return my_assisted_decoding( 387 | model, 388 | input_ids, 389 | assistant_model=assistant_model, 390 | do_sample=generation_config.do_sample, 391 | logits_processor=logits_processor, 392 | logits_warper=model._get_logits_warper(generation_config) if generation_config.do_sample else None, 393 | stopping_criteria=stopping_criteria, 394 | pad_token_id=generation_config.pad_token_id, 395 | eos_token_id=generation_config.eos_token_id, 396 | output_scores=generation_config.output_scores, 397 | return_dict_in_generate=generation_config.return_dict_in_generate, 398 | synced_gpus=synced_gpus, 399 | streamer=streamer, 400 | num_assistant_tokens_schedule=num_assistant_tokens_schedule, 401 | num_assistant_tokens=num_assistant_tokens, 402 | oracle_token_num_list=oracle_token_num_list, 403 | assist_acc_head=assist_acc_head, 404 | stop_threshold=stop_threshold, 405 | bound=bound, 406 | **model_kwargs, 407 | ) 408 | 409 | 410 | def my_assisted_decoding( 411 | model: "PreTrainedModel", 412 | input_ids: torch.LongTensor, 413 | assistant_model: "PreTrainedModel", 414 | do_sample: bool = False, 415 | logits_processor: Optional[LogitsProcessorList] = None, 416 | logits_warper: Optional[LogitsProcessorList] = None, 417 | stopping_criteria: Optional[StoppingCriteriaList] = None, 418 | pad_token_id: Optional[int] = None, 419 | eos_token_id: Optional[Union[int, List[int]]] = None, 420 | output_attentions: Optional[bool] = None, 421 | output_hidden_states: Optional[bool] = None, 422 | output_scores: Optional[bool] = None, 423 | return_dict_in_generate: Optional[bool] = None, 424 | synced_gpus: bool = False, 425 | streamer: Optional["BaseStreamer"] = None, 426 | num_assistant_tokens_schedule: Optional[str] = 'heuristic', 427 | num_assistant_tokens: Optional[int] = None, 428 | oracle_token_num_list: Optional[List[int]] = None, 429 | assist_acc_head: Optional[nn.Module] = None, 430 | stop_threshold: Optional[float] = None, 431 | bound: Optional[List[int]] = None, 432 | **model_kwargs, 433 | ): 434 | r""" 435 | Generates sequences of token ids for models with a language modeling head using **greedy decoding** or 436 | **sample** (depending on `do_sample`), assisted by a smaller model. Can be used for text-decoder, text-to-text, 437 | speech-to-text, and vision-to-text models. 438 | 439 | 440 | 441 | In most cases, you do not need to call [`~generation.GenerationMixin.assisted_decoding`] directly. Use 442 | generate() instead. For an overview of generation strategies and code examples, check the [following 443 | guide](../generation_strategies). 444 | 445 | 446 | 447 | Parameters: 448 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 449 | The sequence used as a prompt for the generation. 450 | assistant_model (`PreTrainedModel`, *optional*): 451 | An assistant model that can be used to accelerate generation. The assistant model must have the exact 452 | same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model 453 | is much faster than running generation with the model you're calling generate from. As such, the 454 | assistant model should be much smaller. 455 | do_sample (`bool`, *optional*, defaults to `False`): 456 | Whether or not to use sampling ; use greedy decoding otherwise. 457 | logits_processor (`LogitsProcessorList`, *optional*): 458 | An instance of [`LogitsProcessorList`]. 
List of instances of class derived from [`LogitsProcessor`] 459 | used to modify the prediction scores of the language modeling head applied at each generation step. 460 | logits_warper (`LogitsProcessorList`, *optional*): 461 | An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used 462 | to warp the prediction score distribution of the language modeling head applied before multinomial 463 | sampling at each generation step. 464 | stopping_criteria (`StoppingCriteriaList`, *optional*): 465 | An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] 466 | used to tell if the generation loop should stop. 467 | pad_token_id (`int`, *optional*): 468 | The id of the *padding* token. 469 | eos_token_id (`Union[int, List[int]]`, *optional*): 470 | The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. 471 | output_attentions (`bool`, *optional*, defaults to `False`): 472 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 473 | returned tensors for more details. 474 | output_hidden_states (`bool`, *optional*, defaults to `False`): 475 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 476 | for more details. 477 | output_scores (`bool`, *optional*, defaults to `False`): 478 | Whether or not to return the prediction scores. See `scores` under returned tensors for more details. 479 | return_dict_in_generate (`bool`, *optional*, defaults to `False`): 480 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 481 | synced_gpus (`bool`, *optional*, defaults to `False`): 482 | Whether to continue running the while loop until max_length (needed for ZeRO stage 3) 483 | streamer (`BaseStreamer`, *optional*): 484 | Streamer object that will be used to stream the generated sequences. Generated tokens are passed 485 | through `streamer.put(token_ids)` and the streamer is responsible for any further processing. 486 | model_kwargs: 487 | Additional model specific keyword arguments will be forwarded to the `forward` function of the model. 488 | If model is an encoder-decoder model the kwargs should include `encoder_outputs`. 489 | 490 | Return: 491 | [`~generation.GreedySearchDecoderOnlyOutput`], [`~generation.GreedySearchEncoderDecoderOutput`] or 492 | `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a 493 | [`~generation.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and 494 | `return_dict_in_generate=True` or a [`~generation.GreedySearchEncoderDecoderOutput`] if 495 | `model.config.is_encoder_decoder=True`. 496 | 497 | Examples: 498 | 499 | ```python 500 | >>> from transformers import ( 501 | ... AutoTokenizer, 502 | ... AutoModelForCausalLM, 503 | ... LogitsProcessorList, 504 | ... MinLengthLogitsProcessor, 505 | ... StoppingCriteriaList, 506 | ... MaxLengthCriteria, 507 | ... 
) 508 | 509 | >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") 510 | >>> model = AutoModelForCausalLM.from_pretrained("gpt2") 511 | >>> assistant_model = AutoModelForCausalLM.from_pretrained("distilgpt2") 512 | >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token 513 | >>> model.generation_config.pad_token_id = model.generation_config.eos_token_id 514 | >>> input_prompt = "It might be possible to" 515 | >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids 516 | >>> # instantiate logits processors 517 | >>> logits_processor = LogitsProcessorList( 518 | ... [ 519 | ... MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id), 520 | ... ] 521 | ... ) 522 | >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) 523 | >>> outputs = model.assisted_decoding( 524 | ... input_ids, 525 | ... assistant_model=assistant_model, 526 | ... logits_processor=logits_processor, 527 | ... stopping_criteria=stopping_criteria, 528 | ... ) 529 | >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) 530 | ["It might be possible to get a better understanding of the nature of the problem, but it's not"] 531 | ```""" 532 | # Assistant: initialize assistant-related variables 533 | if num_assistant_tokens_schedule is None: # default to heuristic 534 | num_assistant_tokens_schedule = 'heuristic' 535 | 536 | if num_assistant_tokens is not None: 537 | assistant_model.max_assistant_tokens = num_assistant_tokens 538 | logger.warning("Setting initial assistant model max_assistant_tokens to %d" % num_assistant_tokens) 539 | else: 540 | if not hasattr(assistant_model, "max_assistant_tokens") or assistant_model.max_assistant_tokens is None: 541 | assistant_model.max_assistant_tokens = 5 # default to 5 542 | # this value, which will be updated if heuristic num_assistant_tokens_schedule is applied, persists across calls 543 | 544 | # init values 545 | logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() 546 | logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() 547 | stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() 548 | pad_token_id = pad_token_id if pad_token_id is not None else model.generation_config.pad_token_id 549 | eos_token_id = eos_token_id if eos_token_id is not None else model.generation_config.eos_token_id 550 | if eos_token_id is not None and pad_token_id is None: 551 | raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") 552 | if isinstance(eos_token_id, int): 553 | eos_token_id = [eos_token_id] 554 | eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None 555 | output_scores = output_scores if output_scores is not None else model.generation_config.output_scores 556 | output_attentions = ( 557 | output_attentions if output_attentions is not None else model.generation_config.output_attentions 558 | ) 559 | output_hidden_states = ( 560 | output_hidden_states if output_hidden_states is not None else model.generation_config.output_hidden_states 561 | ) 562 | return_dict_in_generate = ( 563 | return_dict_in_generate 564 | if return_dict_in_generate is not None 565 | else model.generation_config.return_dict_in_generate 566 | ) 567 | 568 | # init attention / hidden states / scores tuples 569 | scores = () if (return_dict_in_generate and output_scores) else None 570 | decoder_attentions = () if 
(return_dict_in_generate and output_attentions) else None 571 | cross_attentions = () if (return_dict_in_generate and output_attentions) else None 572 | decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None 573 | 574 | # if model is an encoder-decoder, retrieve encoder attention weights and hidden states 575 | if return_dict_in_generate and model.config.is_encoder_decoder: 576 | encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None 577 | encoder_hidden_states = ( 578 | model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None 579 | ) 580 | 581 | # keep track of which sequences are already finished 582 | unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) 583 | 584 | # other auxiliary variables 585 | max_len = stopping_criteria[0].max_length 586 | assistant_kv_indexing = ( 587 | 1 588 | if "bloom" in assistant_model.__class__.__name__.lower() 589 | or ( 590 | assistant_model.config.architectures is not None 591 | and "bloom" in assistant_model.config.architectures[0].lower() 592 | ) 593 | else 0 594 | ) 595 | 596 | this_peer_finished = False # used by synced_gpus only 597 | num_mismatched_tokens = 0 598 | assist_rounds = 0 599 | 600 | while True: 601 | if synced_gpus: 602 | # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. 603 | # The following logic allows an early break if all peers finished generating their sequence 604 | this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) 605 | # send 0.0 if we finished, 1.0 otherwise 606 | dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) 607 | # did all peers finish? the reduced sum will be 0.0 then 608 | if this_peer_finished_flag.item() == 0.0: 609 | break 610 | 611 | # Assistant: main logic start 612 | cur_len = input_ids.shape[-1] 613 | 614 | # 1. Forecast next N tokens using the assistant model. This `for` block can be replaced with a 615 | # `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we 616 | # need access to the assistant cache to secure strong speedups. 617 | candidate_input_ids = input_ids 618 | q_prob = [] 619 | assist_steps = 0 620 | 621 | cum_acc_prob = 1. # used for 'ada' schedule 622 | 623 | while True: 624 | # for _ in range(int(assistant_model.max_assistant_tokens)): 625 | if num_assistant_tokens_schedule != 'ada' and assist_steps >= assistant_model.max_assistant_tokens: 626 | break 627 | # 1.1. use the assistant model to obtain the next candidate logits 628 | if "assistant_past_key_values" in model_kwargs: 629 | prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2] 630 | # `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model) 631 | new_token_len = candidate_input_ids.shape[1] - prev_seq_len 632 | assert new_token_len > 0, 'might have bug!' 
633 | assist_inputs = candidate_input_ids[:, -new_token_len:] 634 | assist_attn = torch.ones_like(candidate_input_ids) 635 | # TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2 636 | if assistant_model.config.is_encoder_decoder: 637 | assistant_model_outputs = assistant_model( 638 | decoder_input_ids=assist_inputs, 639 | decoder_attention_mask=assist_attn, 640 | past_key_values=model_kwargs["assistant_past_key_values"], 641 | encoder_outputs=model_kwargs["assistant_encoder_outputs"], 642 | output_hidden_states = True if num_assistant_tokens_schedule == 'ada' else False, 643 | ) 644 | else: 645 | assistant_model_outputs = assistant_model( 646 | assist_inputs, 647 | attention_mask=assist_attn, 648 | past_key_values=model_kwargs["assistant_past_key_values"], 649 | output_hidden_states = True if num_assistant_tokens_schedule == 'ada' else False, 650 | ) 651 | else: 652 | if assistant_model.config.is_encoder_decoder: 653 | assistant_model_outputs = assistant_model( 654 | decoder_input_ids=candidate_input_ids, 655 | encoder_outputs=model_kwargs["assistant_encoder_outputs"], 656 | output_hidden_states = True if num_assistant_tokens_schedule == 'ada' else False, 657 | ) 658 | else: 659 | assistant_model_outputs = assistant_model(candidate_input_ids, 660 | output_hidden_states = True if num_assistant_tokens_schedule == 'ada' else False, 661 | ) 662 | 663 | 664 | model_kwargs["assistant_past_key_values"] = assistant_model_outputs.past_key_values 665 | if len(logits_processor) > 0: 666 | assistant_model_outputs.logits[:, -1, :] = logits_processor( 667 | candidate_input_ids, assistant_model_outputs.logits[:, -1, :] 668 | ) 669 | if len(logits_warper) > 0: 670 | assistant_model_outputs.logits[:, -1, :] = logits_warper( 671 | candidate_input_ids, assistant_model_outputs.logits[:, -1, :] 672 | ) 673 | 674 | # 1.2. greedily select the next candidate token; or do speculative decoding. 675 | if do_sample: 676 | probs = assistant_model_outputs.logits[:, -1, :].softmax(dim=-1) # bs * vocab_size 677 | new_token = torch.multinomial(probs[0, :], num_samples=1) 678 | q_prob.append(probs) 679 | else: 680 | new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1) 681 | 682 | candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1) 683 | 684 | # 1.3. stop assistant generation on EOS 685 | if eos_token_id_tensor is not None: 686 | last_assistant_token_is_eos = new_token.tile(eos_token_id_tensor.shape[0], 1) 687 | last_assistant_token_is_eos = ( 688 | ~last_assistant_token_is_eos.ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool() 689 | ) 690 | if last_assistant_token_is_eos: 691 | break 692 | else: 693 | last_assistant_token_is_eos = False 694 | 695 | # 1.4. stop assistant generation when the max length is reached or when the stop head predcits stop 696 | assist_steps += 1 697 | 698 | # bound = (min, max) 699 | if num_assistant_tokens_schedule == 'ada': 700 | assert assist_acc_head is not None, 'assist_acc_head is None!' 701 | 702 | 703 | ### obtain current acceptance probability with assist_acc_head 704 | hidden_states = assistant_model_outputs.get("hidden_states") # hidden_states[-1] is the last hidden states, size: bs * seq_len * hidden_dim 705 | logits = assist_acc_head(hidden_states[-1][0, -1].float()) 706 | 707 | if stop_threshold is None: 708 | logger.warning("[Deprecated] Stop_threshold not set. 
using the acceptance of current token instead.") 709 | predicted = logits.argmax(dim = -1) 710 | stop_prediction = (predicted == 0) 711 | else: 712 | 713 | ## stop generation when the estimated P(exists one reject) = 1 - P(all proposed tokens are accepted) exceeds threshold. 714 | 715 | if assist_steps == 1: 716 | acc_prob = 1 # skip the first round as all tokens are verified and there are no proposed tokens. 717 | else: 718 | acc_prob = logits.softmax(dim = -1)[1].item() 719 | cum_acc_prob *= acc_prob 720 | rej_prob = 1 - cum_acc_prob 721 | 722 | stop_prediction = (rej_prob > stop_threshold) 723 | 724 | # bound = (min, max): forces the generated tokens to be inside [min, max] (both boundaries are included) 725 | if bound is not None: 726 | if assist_steps >= bound[1]: 727 | is_stop = True 728 | elif assist_steps < bound[0]: 729 | is_stop = False 730 | else: 731 | is_stop = stop_prediction 732 | else: 733 | is_stop = stop_prediction 734 | 735 | 736 | if is_stop: 737 | break 738 | 739 | candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] 740 | 741 | if candidate_length == 0: 742 | last_assistant_token_is_eos = False 743 | 744 | # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain 745 | # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct, 746 | # we use this forward pass to also pick the subsequent logits in the original model. 747 | 748 | # 2.1. Run a forward pass on the candidate sequence 749 | if "past_key_values" in model_kwargs: 750 | model_attn = torch.ones_like(candidate_input_ids) 751 | model_input_ids = candidate_input_ids[:, -candidate_length - 1 :] 752 | if model.config.is_encoder_decoder: 753 | outputs = model( 754 | decoder_input_ids=model_input_ids, 755 | decoder_attention_mask=model_attn, 756 | past_key_values=model_kwargs["past_key_values"], 757 | encoder_outputs=model_kwargs["encoder_outputs"], 758 | output_attentions=output_attentions, 759 | output_hidden_states=output_hidden_states, 760 | use_cache=True, 761 | ) 762 | else: 763 | outputs = model( 764 | model_input_ids, 765 | attention_mask=model_attn, 766 | past_key_values=model_kwargs["past_key_values"], 767 | output_attentions=output_attentions, 768 | output_hidden_states=output_hidden_states, 769 | use_cache=True, 770 | ) 771 | else: 772 | if model.config.is_encoder_decoder: 773 | outputs = model( 774 | decoder_input_ids=candidate_input_ids, 775 | encoder_outputs=model_kwargs["encoder_outputs"], 776 | output_attentions=output_attentions, 777 | output_hidden_states=output_hidden_states, 778 | use_cache=True, 779 | ) 780 | else: 781 | outputs = model( 782 | candidate_input_ids, 783 | output_attentions=output_attentions, 784 | output_hidden_states=output_hidden_states, 785 | use_cache=True, 786 | ) 787 | 788 | # 2.2. Process the new logits 789 | new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present 790 | if len(logits_processor) > 0: 791 | for i in range(candidate_length + 1): 792 | new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) 793 | if len(logits_warper) > 0: 794 | for i in range(candidate_length + 1): 795 | new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) 796 | 797 | # 3. Obtain the next tokens from the original model logits. 798 | if do_sample: 799 | # speculative decoding logit here. 
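# Standard speculative sampling acceptance rule: a candidate token x_i drawn from the
# draft distribution q is accepted with probability min(1, p(x_i) / q(x_i)) under the
# target distribution p. Below, r_i ~ U(0, 1) is compared against p(x_i) / q(x_i), and
# `n_matches` is the length of the accepted prefix (tokens before the first rejection).
# On a rejection, the replacement token is drawn from the normalized residual
# distribution max(p - q, 0); if every candidate is accepted, the bonus token is drawn
# from the target distribution at the next position (`next_p`).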
800 | 801 | probs = new_logits[:, -candidate_length - 1 :, :].softmax(dim=-1) # bs(1) * (candidate_length+1) * vocab_size 802 | p_prob, next_p = probs[:, :-1], probs[:, -1] 803 | 804 | if candidate_length == 0: 805 | n_matches = 0 806 | else: 807 | q_prob = torch.stack(q_prob, dim=1) # bs(1) * candidate_length * vocab_size 808 | 809 | candidate_index = candidate_input_ids[:, -candidate_length:, None] 810 | 811 | q_candidate = q_prob.gather(-1, candidate_index).squeeze(-1) 812 | p_candidate = p_prob.gather(-1, candidate_index).squeeze(-1) 813 | r_candidate = torch.rand_like(q_candidate, device = q_candidate.device) 814 | n_matches = ((r_candidate > (p_candidate/q_candidate)).cumsum(dim=-1) < 1).sum() 815 | 816 | else: 817 | 818 | # greedy decoding logic. 819 | selected_tokens = new_logits[:, -candidate_length - 1 :, :].argmax(dim=-1) 820 | 821 | # 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep 822 | # the assistant forecasted tokens until the first mismatch, or until the max length is reached. 823 | if candidate_length == 0: 824 | n_matches = 0 825 | else: 826 | candidate_new_tokens = candidate_input_ids[:, -candidate_length:] 827 | n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum() 828 | 829 | 830 | 831 | 832 | 833 | 834 | # 5. Update variables according to the number of matching assistant tokens. Remember: the token generated 835 | # by the model after the last candidate match is also valid, as it is generated from a correct sequence. 836 | # Because of this last token, assisted generation search reduces to a normal greedy search/sample if there 837 | # is no match. 838 | 839 | num_mismatched_tokens += (candidate_length - n_matches) 840 | 841 | 842 | if do_sample: 843 | # for speculative decoding 844 | candidate_new_tokens = candidate_input_ids[:, (candidate_input_ids.shape[1] - candidate_length):] 845 | valid_tokens = candidate_new_tokens[:, : n_matches] 846 | 847 | # for the last token 848 | ## case 1. rejected some token. resample from [p-q]+, index = n_matches 849 | if n_matches < candidate_length: 850 | next_p = p_prob[:, n_matches, :] - q_prob[:, n_matches, :] # bs(1) * vocab_size 851 | next_p.clamp_(min=0.) 852 | next_p = next_p / next_p.sum(dim = -1, keepdim=True) 853 | 854 | ## case 2. all tokens accepted: sample from next_p (defined before) 855 | 856 | new_added_token = torch.multinomial(next_p, num_samples=1) 857 | valid_tokens = torch.cat((valid_tokens, new_added_token), dim=-1) 858 | 859 | else: 860 | # for greedy decoding 861 | # 5.2. Get the valid continuation, after the matching tokens 862 | valid_tokens = selected_tokens[:, : n_matches + 1] 863 | 864 | # 5.1. Ensure we don't generate beyond max_len or an EOS token 865 | if last_assistant_token_is_eos and n_matches == candidate_length: 866 | n_matches -= 1 867 | n_matches = min(n_matches, max_len - cur_len - 1) 868 | valid_tokens = valid_tokens[:, : n_matches + 1] 869 | 870 | input_ids = torch.cat((input_ids, valid_tokens), dim=-1) 871 | if streamer is not None: 872 | streamer.put(valid_tokens.cpu()) 873 | new_cur_len = input_ids.shape[-1] 874 | 875 | 876 | # 5.3. 
Discard past key values relative to unused assistant tokens 877 | new_cache_size = new_cur_len - 1 878 | outputs.past_key_values = _crop_past_key_values(model, outputs.past_key_values, new_cache_size) 879 | if "assistant_past_key_values" in model_kwargs: 880 | model_kwargs["assistant_past_key_values"] = _crop_past_key_values( 881 | assistant_model, model_kwargs["assistant_past_key_values"], new_cache_size - 1 882 | ) # the assistant does not have the token after the last match, hence the -1 883 | 884 | # 6. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic, 885 | # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the 886 | # cost of forecasting incorrect assistant tokens. 887 | 888 | assist_rounds += 1 889 | if num_assistant_tokens_schedule == 'heuristic': 890 | if n_matches == int(assistant_model.max_assistant_tokens): 891 | assistant_model.max_assistant_tokens += 2.0 892 | else: 893 | assistant_model.max_assistant_tokens = max(1.0, assistant_model.max_assistant_tokens - 1.0) 894 | elif num_assistant_tokens_schedule == 'oracle': 895 | if assist_rounds < len(oracle_token_num_list): 896 | assistant_model.max_assistant_tokens = oracle_token_num_list[assist_rounds] 897 | else: 898 | logger.warning("warning. assist_rounds exceed len(oracle_token_num_list)") 899 | # print("oracle token num: %d" % assistant_model.max_assistant_tokens) 900 | 901 | 902 | # Assistant: main logic end 903 | 904 | if synced_gpus and this_peer_finished: 905 | continue # don't waste resources running the code we don't need 906 | 907 | # Store scores, attentions and hidden_states when required 908 | # Assistant: modified to append one tuple element per token, as in the other generation methods. 
909 | if return_dict_in_generate: 910 | if output_scores: 911 | scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1)) 912 | 913 | if "past_key_values" not in model_kwargs: 914 | added_len = new_cur_len 915 | else: 916 | added_len = n_matches + 1 917 | 918 | if output_attentions: 919 | if model.config.is_encoder_decoder: 920 | cross_attentions = _split_model_outputs( 921 | cross_attentions, outputs.cross_attentions, cur_len, added_len 922 | ) 923 | decoder_attentions = _split_model_outputs( 924 | decoder_attentions, 925 | outputs.decoder_attentions, 926 | cur_len, 927 | added_len, 928 | is_decoder_attention=True, 929 | ) 930 | else: 931 | decoder_attentions = _split_model_outputs( 932 | decoder_attentions, 933 | outputs.attentions, 934 | cur_len, 935 | added_len, 936 | is_decoder_attention=True, 937 | ) 938 | if output_hidden_states: 939 | if model.config.is_encoder_decoder: 940 | decoder_hidden_states = _split_model_outputs( 941 | decoder_hidden_states, outputs.decoder_hidden_states, cur_len, added_len 942 | ) 943 | else: 944 | decoder_hidden_states = _split_model_outputs( 945 | decoder_hidden_states, outputs.hidden_states, cur_len, added_len 946 | ) 947 | 948 | model_kwargs = model._update_model_kwargs_for_generation( 949 | outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder 950 | ) 951 | 952 | # if eos_token was found in one sentence, set sentence to finished 953 | if eos_token_id_tensor is not None: 954 | unfinished_sequences = unfinished_sequences.mul( 955 | input_ids[:, -1] 956 | .tile(eos_token_id_tensor.shape[0], 1) 957 | .ne(eos_token_id_tensor.unsqueeze(1)) 958 | .prod(dim=0) 959 | ) 960 | 961 | # stop when each sentence is finished 962 | if unfinished_sequences.max() == 0: 963 | this_peer_finished = True 964 | 965 | # stop if we exceed the maximum length 966 | if stopping_criteria(input_ids, scores): 967 | this_peer_finished = True 968 | 969 | if this_peer_finished and not synced_gpus: 970 | break 971 | 972 | if streamer is not None: 973 | streamer.end() 974 | 975 | if return_dict_in_generate: 976 | if model.config.is_encoder_decoder: 977 | return GreedySearchEncoderDecoderOutput( 978 | sequences=input_ids, 979 | scores=scores, 980 | encoder_attentions=encoder_attentions, 981 | encoder_hidden_states=encoder_hidden_states, 982 | decoder_attentions=decoder_attentions, 983 | cross_attentions=cross_attentions, 984 | decoder_hidden_states=decoder_hidden_states, 985 | ) 986 | else: 987 | return GreedySearchDecoderOnlyOutput( 988 | sequences=input_ids, 989 | scores=scores, 990 | attentions=decoder_attentions, 991 | hidden_states=decoder_hidden_states, 992 | ) 993 | else: 994 | return input_ids, num_mismatched_tokens.item(), assist_rounds 995 | 996 | --------------------------------------------------------------------------------
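For reference, a hypothetical, self-contained sketch of calling `my_generate` directly with the `ada` schedule (run from the repository root). The supported entry point is `specdec_pp/evaluate.py`; the GPT-2 model pair and the untrained `nn.Linear` layer standing in for the trained acceptance prediction head are placeholders, not the models or head used in the paper.

```python
import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer

from specdec_pp.hf_generation import my_generate

tokenizer = AutoTokenizer.from_pretrained("gpt2")
target = AutoModelForCausalLM.from_pretrained("gpt2")
draft = AutoModelForCausalLM.from_pretrained("distilgpt2")
target.generation_config.pad_token_id = target.generation_config.eos_token_id

# Untrained placeholder for the acceptance prediction head: it maps the draft model's
# last hidden state to [reject, accept] logits (the real head is trained by train.py).
acc_head = nn.Linear(draft.config.hidden_size, 2)

input_ids = tokenizer("It might be possible to", return_tensors="pt").input_ids
output_ids, num_mismatched, num_rounds = my_generate(
    target,
    inputs=input_ids,
    assistant_model=draft,
    do_sample=True,
    max_new_tokens=64,
    num_assistant_tokens_schedule="ada",
    assist_acc_head=acc_head,
    stop_threshold=0.3,
    bound=[2, 20],
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
print("mismatched tokens:", num_mismatched, "| speculation rounds:", num_rounds)
```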