'
75 | print("sanity check passed")
76 |
77 |
78 | def infer(prompt, tokenizer, model, max_length = 32, do_sample=False):
79 |
80 | device = "cuda"
81 | inputs = tokenizer(prompt, return_tensors="pt").to(device)
82 | max_length += len(inputs['input_ids'][0])
83 | model_output = model.generate(**inputs, do_sample=do_sample, max_length = max_length)[0]
84 |
85 | ret = tokenizer.decode(model_output, skip_special_tokens=True)
86 |
87 |
88 | # Note: decoding and then re-encoding is not always reversible -- the re-encoded tokens can differ from model_output.
89 |
90 | #re_token = tokenizer.encode(ret, return_tensors='pt').to(device)
91 | #if re_token.shape != model_output.shape or (re_token - model_output).sum().item() != 0:
92 | #print("mismatch!")
93 |
94 | prefix_token_id = model_output[:len(inputs['input_ids'][0])]
95 | gen_token_id = model_output[len(inputs['input_ids'][0]) :]
96 | ret = ret[len(prompt):] # remove prefix
97 |
98 | return prefix_token_id.cpu().tolist() , gen_token_id.cpu().tolist(), ret
99 |
100 |
101 |
102 | def parse_args():
103 | parser = argparse.ArgumentParser(description='data generator')
104 |
105 | parser.add_argument('--dataset_name', type=str)
106 | parser.add_argument('--model_name', type=str, choices=["7b", "13b", "70b"])
107 | parser.add_argument('--mode', type=str, choices=['hf'])
108 | parser.add_argument('--do_sample', action='store_true')
109 | parser.add_argument('--n_begin', type=int, default=0)
110 | parser.add_argument('--n_end', type=int, default=-1)
111 | parser.add_argument('--max_length', type=int, default=512)
112 | parser.add_argument('--output_file', type=str, default=None)
113 |
114 |
115 |
116 | args = parser.parse_args()
117 |
118 | return args
119 |
120 |
121 |
122 | def main(args):
123 | #sanity_check(dataset[0], tokenizer)
124 |
125 | print(f"we are using do sample = {args.do_sample}")
126 |
127 | tokenizer, model = get_model(args.model_name)
128 | dataset = get_dataset(args.dataset_name)
129 |
130 | if args.n_end == -1:
131 | args.n_end = len(dataset)
132 | args.n_end = min(args.n_end, len(dataset))
133 |
134 | res_dict = []
135 | for i in tqdm(range(args.n_begin, args.n_end)):
136 | sample = dataset[i]
137 | prompt = get_prompt(sample, args.dataset_name)
138 | prefix_token, gen_token, s = infer(prompt, tokenizer, model, max_length=args.max_length, do_sample=args.do_sample)
139 | res_dict.append(
140 | {
141 | 'prompt': prompt,
142 | 'continuation': s,
143 | 'prefix': str(prefix_token) if prefix_token is not None else "",
144 | 'tokens': str(gen_token) if gen_token is not None else ""
145 | }
146 | )
147 |
148 | if args.output_file is None:
149 | args.output_file = f'dataset{args.n_begin}to{args.n_end}_{args.mode}{args.model_name}.json'
150 | with open(args.output_file, 'w') as f:
151 | f.write(json.dumps(res_dict, indent=2))
152 |
153 |
154 | if __name__ == "__main__":
155 | args = parse_args()
156 | main(args)
157 |
--------------------------------------------------------------------------------
/specdec_pp/evaluate.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig
2 | import time
3 | from datetime import datetime
4 | import torch
5 | from hf_generation import my_generate
6 | import argparse
7 | import json
8 | import os
9 | import numpy as np
10 | from ast import literal_eval as eval
11 |
12 | device = "cuda"
13 |
14 | def set_up(args):
15 | if args.do_sample:
16 | print("do_sample for SpeculativeDecoding")
17 |
18 | np.random.seed(args.random_seed)
19 | torch.manual_seed(args.random_seed)
20 |
21 | checkpoint = args.model_name
22 | assistant_checkpoint = args.assistant_name
23 |
24 | tokenizer = AutoTokenizer.from_pretrained(checkpoint)
25 |
26 | assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint, torch_dtype=torch.bfloat16, device_map='cuda:0')
27 |
28 | if args.num_assistant_tokens_schedule == 'ada':
29 | from wrap_model import AcceptancePredictionHead
30 | print("Loading from acc_head checkpoint:", args.assist_acc_head_dir)
31 |
32 | assist_acc_head = AcceptancePredictionHead.from_pretrained(args.assist_acc_head_dir).to('cuda:0')
33 |
34 | else:
35 | assist_acc_head = None
36 |
37 | model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, device_map='auto')
38 |
39 | # print(model.hf_device_map)
40 | # print(assistant_model.hf_device_map)
41 | return model, assistant_model, tokenizer, assist_acc_head
42 |
43 | def assist(model, assistant_model, tokenizer, assist_acc_head, inputs, max_length, num_assistant_tokens=None):
44 | # outputs = model.generate(**inputs, generation_config=generation_config, assistant_model=assistant_model, max_length=max_length)
45 | before=time.time()
46 | assistant_model.max_assistant_tokens = None
47 | outputs, mismatched_tokens, LM_call = my_generate(model=model, **inputs, assistant_model=assistant_model, \
48 | max_length=max_length, num_assistant_tokens_schedule=args.num_assistant_tokens_schedule, \
49 | num_assistant_tokens=num_assistant_tokens, do_sample=args.do_sample, \
50 | assist_acc_head=assist_acc_head, \
51 | stop_threshold=args.stop_threshold, bound=args.bound)
52 |
53 | after = time.time()
54 | assisted_time = (after - before)
55 |
56 | print("assisted time: {:.2f}".format(assisted_time))
57 | print("mismatched_tokens: {:.2f}".format(mismatched_tokens))
58 | print("LM_call: {:.2f}".format(LM_call))
59 |
60 | return outputs, mismatched_tokens, LM_call, assisted_time
61 |
62 | def target(model, tokenizer, inputs, max_length):
63 | before=time.time()
64 | outputs = model.generate(**inputs, max_length=max_length, do_sample=args.do_sample)
65 | after = time.time()
66 | target_time = (after - before)
67 | print("target time {:.2f}".format(target_time))
68 | return outputs, target_time
69 |
70 | def draft(assistant_model, tokenizer, inputs, max_length):
71 | before=time.time()
72 | outputs = assistant_model.generate(**inputs, max_length=max_length, do_sample=args.do_sample)
73 | after = time.time()
74 | draft_time = (after - before)
75 | print("draft time {:.2f}".format(draft_time))
76 | return outputs, draft_time
77 |
78 |
79 | def run(model, assistant_model, tokenizer, assist_acc_head, args, item):
80 | len_prefix = len(eval(item['prefix']))
81 | inputs = {'input_ids': torch.LongTensor([eval(item['prefix'])]).to(device)}
82 | max_length = args.max_length
83 | print("max_length:", max_length)
84 |
85 |
86 | if args.num_assistant_tokens_schedule in ['constant', 'heuristic', 'ada']:
87 |
88 | if args.num_assistant_tokens_schedule == 'ada':
89 | num_assistant_tokens = None
90 | else:
91 | num_assistant_tokens = args.num_assistant_tokens
92 | print("num_assistant_tokens:", num_assistant_tokens)
93 |
94 | res_a, num_mismatched_tokens, num_LM_call, assisted_time = assist(model, assistant_model, tokenizer, assist_acc_head, inputs, max_length, num_assistant_tokens=num_assistant_tokens)
95 | elif args.num_assistant_tokens_schedule == 'none':
96 | res_a = [[-1]]
97 | num_mismatched_tokens = -1
98 | num_LM_call = -1
99 | assisted_time = -1
100 | else:
101 | raise ValueError(f"{args.num_assistant_tokens_schedule} not supported")
102 |
103 |
104 | if args.num_assistant_tokens_schedule == 'none':
105 | res_b, target_time = target(model, tokenizer, inputs, max_length)
106 | generated_length_target = len(res_b[0]) - len_prefix
107 |
108 | res_c, draft_time = draft(assistant_model, tokenizer, inputs, max_length)
109 | generated_length_draft = len(res_c[0]) - len_prefix
110 | else:
111 | target_time = -1
112 | generated_length_target = -1
113 | draft_time = -1
114 | generated_length_draft = -1
115 |
116 |
117 |
118 | generated_length = len(res_a[0]) - len_prefix
119 | print("generated_length: {:.2f}".format(generated_length))
120 |
121 | return assisted_time, target_time, draft_time, num_mismatched_tokens, num_LM_call, generated_length, generated_length_target, generated_length_draft
122 |
123 | def parse_args():
124 | parser = argparse.ArgumentParser(description='benchmark performance')
125 |
126 | parser.add_argument('--model_name', type=str, default=None)
127 | parser.add_argument('--assistant_name', type=str, default=None)
128 | parser.add_argument('--max_length', type=int, default=512)
129 | parser.add_argument('--do_sample', action='store_true')
130 |
131 | parser.add_argument('--num_assistant_tokens', type=int, default=5)
132 | parser.add_argument('--num_assistant_tokens_schedule', type=str, default="constant", choices=['constant', 'heuristic', 'ada', 'none'])
133 | parser.add_argument('--assist_acc_head_dir', type=str, default=None)
134 | parser.add_argument('--data_path', type=str, default='data/alpaca_data/test.json')
135 | parser.add_argument('--save_path', type=str, default='./test_results')
136 | parser.add_argument('--random_seed', type=int, default=47)
137 | parser.add_argument('--stop_threshold', type=float, default=None)
138 | parser.add_argument('--bound', nargs='+', type=int, default=None)
139 |
140 | parser.add_argument('--n_begin', type=int, default=0)
141 | parser.add_argument('--n_end', type=int, default=None)
142 |
143 | args = parser.parse_args()
144 | print(args)
145 |
146 | return args
147 |
148 |
149 | if __name__ == "__main__":
150 | args = parse_args()
151 | data = json.load(open(args.data_path,'r'))
152 | if args.n_end is None:
153 | args.n_end = len(data)
154 | args.n_end = min(len(data), args.n_end)
155 |
156 |
157 | os.makedirs(args.save_path, exist_ok=True)
158 |
159 | model, assistant_model, tokenizer, assist_acc_head = set_up(args)
160 |
161 | results = []
162 |
163 | for i, item in enumerate(data[args.n_begin:args.n_end]):
164 | print("---------------------------------")
165 | print(f"data {i + args.n_begin}")
166 | before=time.time()
167 |
168 | assisted_time, target_time, draft_time, num_mismatched_tokens, num_LM_call, generated_length, generated_length_target, generated_length_draft = run(model, assistant_model, tokenizer, assist_acc_head, args, item)
169 | item.update({
170 | 'id': i+args.n_begin,
171 | 'spec_time': assisted_time,
172 | 'target_time': target_time,
173 | 'draft_time': draft_time,
174 | 'num_mismatched_tokens': num_mismatched_tokens,
175 | 'num_LM_call': num_LM_call,
176 | 'generated_length': generated_length,
177 | 'generated_length_target': generated_length_target,
178 | 'generated_length_draft': generated_length_draft,
179 | })
180 | results.append(item)
181 |
182 | after=time.time()
183 | print("total time: {:.2f}".format(after-before))
184 | save_file = f"{args.save_path}/results_{args.n_begin}to{args.n_end}.json"
185 | with open(save_file, 'w') as f:
186 | f.write(json.dumps(results, indent=2))
187 |
--------------------------------------------------------------------------------
/specdec_pp/train.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/tatsu-lab/stanford_alpaca/blob/main/train.py
2 |
3 | import copy
4 | import logging
5 | from dataclasses import dataclass, field
6 | from typing import Dict, Optional, Sequence, List, TYPE_CHECKING, Any, Callable, Tuple, Union
7 |
8 | import torch
9 | import transformers
10 | from torch.nn import CrossEntropyLoss
11 | from torch.utils.data import Dataset
12 | from transformers import Trainer
13 | import json
14 | import numpy
15 | import scipy.special
16 | from ast import literal_eval as eval
17 |
18 | from wrap_model import WrapModel, AcceptancePredictionHead
19 | from transformers import EvalPrediction
20 |
21 | IGNORE_INDEX = -100
22 | DEFAULT_PAD_TOKEN = "[PAD]"
23 |
24 | def compute_metrics(eval_pred: "EvalPrediction") -> Dict:
25 | num_class = 2
26 | logits= eval_pred[0]
27 | soft_labels = eval_pred[1]
28 |
29 | logits = logits.reshape(-1, num_class)
30 | soft_labels = soft_labels.reshape(-1)
31 |
32 | not_ignore = (soft_labels - IGNORE_INDEX) > 0.1
33 |
34 | target_prob = soft_labels[not_ignore]
35 | logits = logits[not_ignore]
36 | predicted_log_prob = scipy.special.log_softmax(logits, axis=-1)
37 |
38 | # KL divergence:
39 | CrossEnt = target_prob * ( - predicted_log_prob[:,1]) + (1-target_prob) * ( - predicted_log_prob[:,0])
40 | Ent = target_prob * numpy.log(target_prob) + (1-target_prob) * numpy.log(1-target_prob)
41 | Ent[numpy.isnan(Ent)] = 0. # hack for binary entropy
42 | KL_binary = CrossEnt - Ent
43 | KL_binary = numpy.mean(KL_binary)
44 |
45 | return {'KL': KL_binary}
46 |
47 |
48 | class MyTrainer(Trainer):
49 |
50 | def compute_loss(self, model, inputs, return_outputs=False):
51 | soft_labels = inputs.pop('soft_labels')
52 | mask = (soft_labels - IGNORE_INDEX).abs() > 0.1
53 |
54 | soft_labels_1 = soft_labels
55 | soft_labels_0 = soft_labels_1.clone()
56 | soft_labels_0[mask] = 1 - soft_labels_1[mask]
57 |
58 | label_0 = torch.ones_like(soft_labels, dtype=torch.long).to(soft_labels.device) * IGNORE_INDEX
59 | label_0[mask] = 0
60 | label_1 = torch.ones_like(soft_labels, dtype=torch.long).to(soft_labels.device) * IGNORE_INDEX
61 | label_1[mask] = 1
62 |
63 | outputs = model.model(**inputs, output_hidden_states = True, return_dict=True)
64 | hidden_states = outputs.get("hidden_states")
65 | orignal_logits = model.assist_acc_head(hidden_states[-1])
66 | orignal_logits = orignal_logits.float()
67 |
68 | num_class = 2
69 |
70 | weight = torch.tensor([self.args.weight_mismatch, 1]).to(orignal_logits.device)
71 | loss_fct = CrossEntropyLoss(weight=weight, reduction='none')
72 |
73 | logits = orignal_logits.view(-1, num_class)
74 | label_0 = label_0.view(-1)
75 | label_1 = label_1.view(-1)
76 | soft_labels_0 = soft_labels_0.view(-1)
77 | soft_labels_1 = soft_labels_1.view(-1)
78 | mask = mask.view(-1)
79 |
80 | loss_0 = loss_fct(logits, label_0) # (bs * seq_len), num_class
81 | loss_1 = loss_fct(logits, label_1) # (bs * seq_len), num_class
82 |
83 | # reduce with the soft labels, corresponding to a BCE loss with soft targets
84 | loss = (loss_0 * soft_labels_0 + loss_1 * soft_labels_1).sum() / (self.args.weight_mismatch * soft_labels_0[mask].sum() + soft_labels_1[mask].sum() )
85 |
86 | if model.training:
87 | # KL divergence:
88 | target_prob = soft_labels_1[mask]
89 | predicted_logits = logits[mask, :]
90 | predicted_log_prob = torch.log_softmax(predicted_logits, dim=-1)
91 |
92 | #KL_binary = target_prob * (target_prob.log() - predicted_log_prob[:,1]) + (1-target_prob) * ( (1-target_prob).log() - predicted_log_prob[:,0])
93 |
94 | CrossEnt = target_prob * ( - predicted_log_prob[:,1]) + (1-target_prob) * ( - predicted_log_prob[:,0])
95 | Ent = target_prob * target_prob.log() + (1-target_prob) * (1-target_prob).log()
96 | Ent[Ent.isnan()] = 0. # hack for binary entropy
97 | KL_binary = CrossEnt - Ent
98 | KL_binary = KL_binary.mean().item()
99 |
100 | self.log({'KL': KL_binary})
101 |
102 |
103 | if return_outputs:
104 | outputs = (loss, orignal_logits)
105 | return (loss, outputs)
106 | else:
107 | return loss
108 |
109 | @dataclass
110 | class TrainingArguments(transformers.TrainingArguments):
111 | bf16: bool = True
112 | model_name_or_path: Optional[str] = field(default=None)
113 | data_path: str = field(default=None)
114 | eval_data_path: str = field(default=None)
115 | remove_unused_columns: bool = False
116 | evaluate_only: bool = False
117 | label_names: Optional[List[str]] = field(
118 | default_factory=lambda: ['soft_labels'], metadata={"help": "The list of keys in your dictionary of inputs that correspond to the labels."}
119 | )
120 |
121 | weight_mismatch: Optional[float] = field(default = 1.) # 6 for balancing classes
122 | resnet_num_layers: Optional[int] = field(default = 1)
123 | mixing_ratio: Optional[float] = field(default = 0.15)
124 |
125 |
126 | def smart_tokenizer_and_embedding_resize(
127 | special_tokens_dict: Dict,
128 | tokenizer: transformers.PreTrainedTokenizer,
129 | model: transformers.PreTrainedModel,
130 | ):
131 | """Resize tokenizer and embedding.
132 |
133 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
134 | """
135 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
136 | model.resize_token_embeddings(len(tokenizer))
137 |
138 | if num_new_tokens > 0:
139 | input_embeddings = model.get_input_embeddings().weight.data
140 | output_embeddings = model.get_output_embeddings().weight.data
141 |
142 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
143 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
144 |
145 | input_embeddings[-num_new_tokens:] = input_embeddings_avg
146 | output_embeddings[-num_new_tokens:] = output_embeddings_avg
147 |
148 |
149 |
150 | class SupervisedDataset(Dataset):
151 | def __init__(self, data_path: str, r: float = 0.15):
152 | super(SupervisedDataset, self).__init__()
153 | logging.warning(f"Loading data... from {data_path}")
154 | data = json.load(open(data_path,'r'))
155 | self.input_ids = []
156 | self.soft_labels = []
157 | for item in data:
158 | item['prefix'] = eval(item['prefix'])
159 | item['tokens'] = eval(item['tokens'])
160 | item['draft'] = eval(item['draft'])
161 |
162 | # item['tokens'] are generated autoregressively by the target model
163 | # item['draft'] are stochastic next-token predictions from the draft model
164 |
165 | item['p_acc'] = eval(item['p_acc'])
166 |
167 | prefix = torch.LongTensor(item['prefix'])
168 | Xs = torch.LongTensor(item['tokens'])
169 | # Ys = torch.LongTensor(item['draft'])
170 |
171 | # mix the sequences: with probability r take the target token (Xs), otherwise keep the draft token (Zs)
172 | mask = (torch.rand(*Xs.shape) < r)
173 | Zs = torch.LongTensor(item['draft'])
174 | Zs[mask] = Xs[mask]
175 |
176 | self.input_ids.append(torch.cat([prefix, Zs]))
177 |
178 | label_prefix = torch.tensor([IGNORE_INDEX] * len(item['prefix']))
179 | p_acc = torch.tensor(item['p_acc'])
180 |
181 | # don't calculate loss on Xs.
182 | p_acc[mask] = IGNORE_INDEX
183 |
184 | self.soft_labels.append(torch.cat([label_prefix, p_acc]))
185 |
186 | def __len__(self):
187 | return len(self.input_ids)
188 |
189 | def __getitem__(self, i) -> Dict[str, torch.Tensor]:
190 | return dict(input_ids=self.input_ids[i], soft_labels=self.soft_labels[i])
191 |
192 |
193 | @dataclass
194 | class DataCollatorForSupervisedDataset(object):
195 | """Collate examples for supervised fine-tuning."""
196 |
197 | tokenizer: transformers.PreTrainedTokenizer
198 |
199 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
200 | input_ids, soft_labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "soft_labels"))
201 | input_ids = torch.nn.utils.rnn.pad_sequence(
202 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
203 | )
204 | soft_labels = torch.nn.utils.rnn.pad_sequence(soft_labels, batch_first=True, padding_value=IGNORE_INDEX)
205 | return dict(
206 | input_ids=input_ids,
207 | soft_labels=soft_labels,
208 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
209 | )
210 |
211 |
212 |
213 |
214 | if __name__ == "__main__":
215 | parser = transformers.HfArgumentParser((TrainingArguments))
216 | training_args = parser.parse_args_into_dataclasses()[0]
217 |
218 | tokenizer = transformers.AutoTokenizer.from_pretrained(training_args.model_name_or_path)
219 | model = transformers.AutoModelForCausalLM.from_pretrained(training_args.model_name_or_path)
220 | special_tokens_dict = dict()
221 | if tokenizer.pad_token is None:
222 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
223 |
224 | smart_tokenizer_and_embedding_resize(
225 | special_tokens_dict=special_tokens_dict,
226 | tokenizer=tokenizer,
227 | model=model,
228 | )
229 |
230 | train_dataset = SupervisedDataset(training_args.data_path, r=training_args.mixing_ratio)
231 | if training_args.eval_data_path is not None:
232 | eval_dataset = SupervisedDataset(training_args.eval_data_path, r=training_args.mixing_ratio)
233 | print("num eval example:", len(eval_dataset))
234 | else:
235 | eval_dataset = None
236 | data_collator = DataCollatorForSupervisedDataset(tokenizer)
237 |
238 | acc_head_config = {'hidden_size': model.config.hidden_size, 'num_layers': training_args.resnet_num_layers}
239 | assist_acc_head = AcceptancePredictionHead(acc_head_config)
240 | wrapped = WrapModel(model, assist_acc_head)
241 | wrapped.model.requires_grad_(False)
242 | print('num training example:', len(train_dataset))
243 | trainer = MyTrainer(model=wrapped, tokenizer=tokenizer, args=training_args, train_dataset = train_dataset, eval_dataset = eval_dataset, data_collator=data_collator, compute_metrics = compute_metrics)
244 | if training_args.evaluate_only:
245 | print("eval only. Loading from checkpoint:", training_args.output_dir)
246 | wrapped.assist_acc_head = AcceptancePredictionHead.from_pretrained(training_args.output_dir)
247 | trainer.evaluate()
248 | else:
249 | trainer.train()
250 | trainer.save_state()
251 | wrapped.assist_acc_head.save_pretrained(training_args.output_dir, config=acc_head_config)
252 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SpecDec++: Boosting Speculative Decoding via Adaptive Candidate Lengths
2 |
3 | Kaixuan Huang, Xudong Guo,
4 | Mengdi Wang
5 |
6 | Princeton University
7 |
8 |
9 |
10 |
11 | COLM 2025 & ICML 2024 ES-FoMo workshop
12 |
13 |
14 |
15 |
16 |
17 | [arXiv](https://arxiv.org/abs/2405.19715)
18 |
19 |
20 |
21 | -----
22 | We propose SpecDec++, an enhanced version of speculative decoding that adaptively determines the candidate length with the help of a trained acceptance prediction head. Our method can boost the performance of speculative decoding and can be combined with other tricks like fused kernel, quantization, and advanced KV cache management.
23 |
24 | ![overview](static/overview.png)
25 |
26 | *Tested with llama-2-chat 7B & 70B model pair (bfloat16) on 2 NVIDIA A100-80G GPUs.
27 |
28 | ----
29 |
30 | ## Quick Links
31 |
32 | - [Quick Links](#quick-links)
33 | - [Overview of Speculative Decoding](#overview-of-speculative-decoding)
34 | - [Case I: There exists rejected tokens.](#case-i-there-exists-rejected-tokens)
35 | - [Case II: All tokens are accepted.](#case-ii-all-tokens-are-accepted)
36 | - [Problem: Determination of the candidate length $K$.](#problem-determination-of-the-candidate-length-k)
37 | - [Our approach](#our-approach)
38 | - [Performance](#performance)
39 | - [Using `SpecDec++`](#using-specdec)
40 | - [Checkpoint Release \& Sampling Code](#checkpoint-release--sampling-code)
41 | - [Training and Evaluation](#training-and-evaluation)
42 | - [Dataset Preparation](#dataset-preparation)
43 | - [Training the Acceptance Prediction Heads.](#training-the-acceptance-prediction-heads)
44 | - [Benchmarking Performances.](#benchmarking-performances)
45 | - [To benchmark the performance of SpecDec ++, modify and run the following command.](#to-benchmark-the-performance-of-specdec--modify-and-run-the-following-command)
46 | - [To benchmark the performance of SpecDec, modify and run the following command.](#to-benchmark-the-performance-of-specdec-modify-and-run-the-following-command)
47 | - [To benchmark the performance without speculative decoding, modify and run the following command.](#to-benchmark-the-performance-without-speculative-decoding-modify-and-run-the-following-command)
48 | - [Sample results](#sample-results)
49 | - [Bugs or Questions](#bugs-or-questions)
50 | - [Citation](#citation)
51 |
52 |
53 |
54 | ----
55 |
56 | ## Overview of Speculative Decoding
57 |
58 | In speculative decoding, the draft model first generates $K$ tokens. The target model computes their log probabilities *in parallel* and then sequentially determines whether each token is accepted or not.
59 |
60 | ### Case I: There exists rejected tokens.
61 |
62 | Following the first rejected token, the algorithm discards the remaining tokens and corrects the rejected token with a fresh sample from a modified distribution.
63 |
64 |
65 |
66 |
67 |
68 | ### Case II: All tokens are accepted.
69 |
70 | If all tokens are accepted, a new token is sampled from the target model's next-token distribution, appended to the sequence of accepted tokens, and the process moves forward.
71 |
72 |
73 |
74 |
75 |
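The sketch below is an illustrative summary of one verification round covering the two cases above (under sampling); it is not code from this repository, and the function name, tensor shapes, and helper logic are only assumptions for exposition.

```python
import torch

def verify_candidates(p_target, p_draft, candidates):
    """Illustrative sketch of one verification round (not the repository's implementation).

    p_target:   (K+1, vocab) target-model next-token distributions (extra row for the bonus token)
    p_draft:    (K, vocab)   draft-model next-token distributions
    candidates: list of K draft token ids proposed in this round
    """
    accepted = []
    for i, tok in enumerate(candidates):
        # Accept the i-th candidate with probability min(1, p_target / p_draft).
        accept_prob = min(1.0, (p_target[i, tok] / p_draft[i, tok]).item())
        if torch.rand(1).item() < accept_prob:
            accepted.append(tok)
        else:
            # Case I: first rejection -- discard the remaining candidates and resample the
            # rejected position from the residual distribution max(p_target - p_draft, 0), renormalized.
            residual = (p_target[i] - p_draft[i]).clamp(min=0)
            accepted.append(torch.multinomial(residual / residual.sum(), 1).item())
            return accepted
    # Case II: all K candidates accepted -- sample one extra token from the target model.
    accepted.append(torch.multinomial(p_target[-1], 1).item())
    return accepted
```

Under greedy decoding (`do_sample=False`), the acceptance test reduces to checking whether each candidate matches the target model's argmax token.
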
76 | ## Problem: Determination of the candidate length $K$.
77 |
78 | `SpecDec++` aims to provide a *theoretically justifiable* answer to the following problem: what candidate length yields as many accepted tokens as possible while wasting as few discarded tokens as possible?
79 |
80 |
81 |
82 |
83 |
84 |
85 | ### Our approach
86 |
87 |
88 | We formalize the dynamic choice of candidate length in speculative decoding as a Markov Decision
89 | Process (MDP). We theoretically show that when the probability that at least one token gets rejected
90 | exceeds a threshold, the optimal action is to stop the current speculation round and submit the candidate tokens for verification:
91 |
92 |
93 |
94 |
95 | We augment the draft model with a trained acceptance prediction head to predict the conditional acceptance probability of the candidate tokens. `SpecDec++` will stop the current speculation round when the predicted probability that at least one token gets rejected exceeds a threshold.
96 |
97 |
98 | 
99 |
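As a rough sketch of this stopping rule (the actual logic lives in `specdec_pp/hf_generation.py`; the function below and its default values are illustrative only):

```python
def should_stop_speculating(acceptance_probs, stop_threshold, bound=(2, 20)):
    """acceptance_probs: predicted P(accept) for the candidate tokens drafted so far."""
    k = len(acceptance_probs)
    if k < bound[0]:   # always draft at least the minimum number of candidates
        return False
    if k >= bound[1]:  # never draft more than the maximum number of candidates
        return True
    prob_all_accepted = 1.0
    for p in acceptance_probs:
        prob_all_accepted *= p
    # Stop once the predicted probability that at least one candidate is rejected
    # exceeds the threshold.
    return (1.0 - prob_all_accepted) > stop_threshold
```

The `--stop_threshold` and `--bound` arguments of `specdec_pp/evaluate.py` correspond to the `stop_threshold` and `bound` parameters above.
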
100 | ### Performance
101 |
102 | `SpecDec++` has better Pareto frontiers than `SpecDec` on both the in-distribution dataset Alpaca and the two out-of-distribution datasets HumanEval and GSM8K. Please check our paper for more details.
103 |
104 | 
105 |
106 | -----
107 |
108 | ## Using `SpecDec++`
109 |
110 | **Step 0 (Optional)**: To start with, prepare a conda environment with PyTorch installed. If you do not have one, you can create it with the following commands.
111 |
112 | ```
113 | conda create -n specdecpp python=3.11
114 | conda activate specdecpp
115 | conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
116 | ```
117 |
118 | **Step 1**: Clone the repository and install the required packages.
119 |
120 | ```
121 | git clone git@github.com:Kaffaljidhmah2/SpecDec_pp.git
122 | cd SpecDec_pp
123 | pip install -r requirements.txt
124 | ```
125 |
126 |
127 | ### Checkpoint Release & Sampling Code
128 |
129 | The checkpoint of our best acceptance prediction head for llama-2-chat 7B & 70B model pair is available at [huggingface hub](https://huggingface.co/hacky/acchead-llama2-chat-7bx70b).
130 |
131 | Please take a look at [specdec_pp/sample.py](specdec_pp/sample.py) for how to use SpecDec++.
132 |
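For reference, here is a minimal usage sketch assembled from `specdec_pp/evaluate.py`; the prompt, threshold, and bound values are placeholders, and `sample.py` remains the authoritative example.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from hf_generation import my_generate              # specdec_pp/hf_generation.py
from wrap_model import AcceptancePredictionHead    # specdec_pp/wrap_model.py

target = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-chat-hf", torch_dtype=torch.bfloat16, device_map="auto")
draft = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.bfloat16, device_map="cuda:0")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-70b-chat-hf")
acc_head = AcceptancePredictionHead.from_pretrained("hacky/acchead-llama2-chat-7bx70b").to("cuda:0")

inputs = tokenizer("Explain speculative decoding in one paragraph.", return_tensors="pt").to("cuda:0")
outputs, num_mismatched_tokens, num_LM_calls = my_generate(
    model=target,
    **inputs,
    assistant_model=draft,
    num_assistant_tokens_schedule="ada",   # SpecDec++ adaptive candidate lengths
    assist_acc_head=acc_head,
    stop_threshold=0.3,                    # placeholder threshold
    bound=[2, 20],                         # min/max candidates per round (placeholder)
    do_sample=True,
    max_length=512,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

With `num_assistant_tokens_schedule="constant"` and `num_assistant_tokens=K`, the same call runs the baseline SpecDec schedule instead.
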
133 | ----
134 |
135 |
136 | ## Training and Evaluation
137 |
138 | ### Dataset Preparation
139 |
140 | Follow the instructions in [data/readme.md](./data/readme.md) for dataset preparation. After running the code, you should be able to get the Alpaca dataset (`data/alpaca_data/train.json`, `data/alpaca_data/dev.json`, `data/alpaca_data/test.json`), HumanEval dataset (`data/humaneval_data/test.json`), and GSM8K test dataset (`data/gsm8k_test_data/test.json`) for llama-2-chat models.
141 |
142 | ### Training the Acceptance Prediction Heads.
143 |
144 |
145 | Modify and run the following command for training. Here `layer` is the number of layers of the ResNet prediction head, and `weight` is the BCE loss weight for the mismatched tokens (the weight for the matched tokens is `1`). The mixing ratio can be set via `--mixing_ratio` (default: 0.15).
146 |
147 | ```bash
148 | layer=3
149 | weight=6
150 | draft_model=meta-llama/Llama-2-7b-chat-hf
151 |
152 | WANDB_PROJECT=specdecpp python3 specdec_pp/train.py \
153 | --data_path data/alpaca_data/train.json \
154 | --eval_data_path data/alpaca_data/dev.json \
155 | --output_dir exp-weight${weight}-layer${layer} \
156 | --model_name_or_path ${draft_model} \
157 | --bf16 True \
158 | --per_device_train_batch_size 4 \
159 | --num_train_epochs 3 \
160 | --gradient_accumulation_steps 8 \
161 | --logging_steps 5 \
162 | --evaluation_strategy epoch \
163 | --per_device_eval_batch_size 4 \
164 | --weight_mismatch ${weight} \
165 | --save_strategy no \
166 | --warmup_ratio 0.03 \
167 | --lr_scheduler_type cosine \
168 | --resnet_num_layers ${layer} \
169 | --mixing_ratio 0.15
170 | ```
171 |
172 | ### Benchmarking Performances.
173 |
174 | #### To benchmark the performance of SpecDec ++, modify and run the following command.
175 |
176 | Note: `--num_assistant_tokens_schedule ada` indicates the proposed SpecDec++ method, where the checkpoint of the acceptance prediction head should be specified via `--assist_acc_head_dir`. `--stop_threshold` indicates the threshold value (between 0 and 1) used to stop the current speculation round; a larger `stop_threshold` leads to longer speculation rounds. `--bound MIN MAX` specifies the minimum and maximum number of candidate tokens for one speculation round.
177 |
178 | ```bash
179 | layer=3
180 | weight=6
181 | thres=0.3
182 |
183 | ckpt=exp-weight${weight}-layer${layer}
184 |
185 | target_model=meta-llama/Llama-2-70b-chat-hf
186 | draft_model=meta-llama/Llama-2-7b-chat-hf
187 | data=data/alpaca_data/test.json
188 | SAVEPATH=test-results-alpaca/weight${weight}-layer${layer}-thres${thres}-bound2_20/
189 |
190 | python3 specdec_pp/evaluate.py \
191 | --model_name ${target_model} \
192 | --assistant_name ${draft_model} \
193 | --num_assistant_tokens_schedule ada \
194 | --data_path ${data} \
195 | --assist_acc_head_dir ${ckpt} \
196 | --do_sample \
197 | --random_seed 42 \
198 | --save_path ${SAVEPATH} \
199 | --stop_threshold ${thres} \
200 | --bound 2 20
201 | ```
202 |
203 | The result will be stored under the folder `${SAVEPATH}`.
204 |
205 |
206 |
207 | #### To benchmark the performance of SpecDec, modify and run the following command.
208 |
209 | Note: `--num_assistant_tokens_schedule constant` indicates the baseline SpecDec method. `--num_assistant_tokens` specifies the constant number of candidate tokens generated in each speculation round.
210 |
211 | ```bash
212 | target_model=meta-llama/Llama-2-70b-chat-hf
213 | draft_model=meta-llama/Llama-2-7b-chat-hf
214 | K=4
215 | data=data/alpaca_data/test.json
216 | SAVEPATH=test-results-alpaca/baseline-${K}/
217 |
218 | python3 specdec_pp/evaluate.py \
219 | --model_name ${target_model} \
220 | --assistant_name ${draft_model} \
221 | --num_assistant_tokens_schedule constant \
222 | --num_assistant_tokens ${K} \
223 | --data_path ${data} \
224 | --do_sample \
225 | --random_seed 42 \
226 | --save_path ${SAVEPATH} \
227 | --save_path ${SAVEPATH}
228 |
229 | #### To benchmark the performance without speculative decoding, modify and run the following command.
230 |
231 | Note: `--num_assistant_tokens_schedule none` disables speculative decoding and benchmarks the target model (and the draft model) running standalone.
232 |
233 | ```bash
234 | target_model=meta-llama/Llama-2-70b-chat-hf
235 | draft_model=meta-llama/Llama-2-7b-chat-hf
236 | data=data/alpaca_data/test.json
237 | SAVEPATH=test-results-alpaca/standalone/
238 |
239 | python3 specdec_pp/evaluate.py \
240 | --model_name ${target_model} \
241 | --assistant_name ${draft_model} \
242 | --num_assistant_tokens_schedule none \
243 | --data_path ${data} \
244 | --do_sample \
245 | --random_seed 42 \
246 | --save_path ${SAVEPATH}
247 | ```
248 |
249 |
250 | #### Sample results
251 |
252 | ```
253 | [
254 | {
255 | ## key-value pairs for prompt, continuation, prefix, tokens, draft, p_acc, and id
256 |
257 | ## for SpecDec & SpecDec++
258 | "spec_time": 15.580421447753906,
259 | "num_mismatched_tokens": 20,
260 | "num_LM_call": 67,
261 | "generated_length": 180,
262 | ## for standalone target model / draft model
263 | "target_time": 25.6504251956939,
264 | "draft_time": 2.795105218887329,
265 | "generated_length_target": 203,
266 | "generated_length_draft": 134
267 | }
268 | ]
269 | ```
270 |
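The following is a small, hedged post-processing example (not a script shipped with this repository) showing how these fields could be aggregated from a SpecDec++ run; the file path is a placeholder.

```python
import json

# Placeholder path: evaluate.py writes results_{n_begin}to{n_end}.json under --save_path.
with open("test-results-alpaca/weight6-layer3-thres0.3-bound2_20/results_0to100.json") as f:
    results = json.load(f)

num_tokens = sum(r["generated_length"] for r in results)
num_calls = sum(r["num_LM_call"] for r in results)
num_discarded = sum(r["num_mismatched_tokens"] for r in results)
total_time = sum(r["spec_time"] for r in results)

print(f"generated tokens per target-model call: {num_tokens / num_calls:.2f}")
print(f"discarded tokens per target-model call: {num_discarded / num_calls:.2f}")
print(f"total SpecDec++ wall-clock time: {total_time:.2f}s")
```

Note that `target_time`, `draft_time`, and the corresponding generated lengths are only populated by runs with `--num_assistant_tokens_schedule none`; other runs record them as `-1`.
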
271 | ------
272 |
273 |
274 | ### Bugs or Questions
275 |
276 | Feel free to send an email to `kaixuanh@princeton.edu` or create a GitHub Issue/Pull request.
277 |
278 |
279 | ### Citation
280 |
281 | If you find this useful in your research, please consider citing our paper.
282 |
283 | ```bibtex
284 | @article{huang2024specdec++,
285 | title={SpecDec++: Boosting Speculative Decoding via Adaptive Candidate Lengths},
286 | author={Huang, Kaixuan and Guo, Xudong and Wang, Mengdi},
287 | journal={arXiv preprint arXiv:2405.19715},
288 | year={2024}
289 | }
290 | ```
291 |
--------------------------------------------------------------------------------
/specdec_pp/hf_generation.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # modified from https://raw.githubusercontent.com/huggingface/transformers/v4.34.1/src/transformers/generation/utils.py
18 |
19 | import copy
20 | import inspect
21 | import warnings
22 | from dataclasses import dataclass
23 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
24 |
25 | import torch
26 | import torch.distributed as dist
27 | from torch import nn
28 |
29 | from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
30 | from transformers.modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
31 | from transformers.models.auto import (
32 | MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
33 | MODEL_FOR_CAUSAL_LM_MAPPING,
34 | MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
35 | MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
36 | MODEL_FOR_VISION_2_SEQ_MAPPING,
37 | )
38 | from transformers.utils import ExplicitEnum, ModelOutput, is_accelerate_available, logging
39 | from transformers.generation.beam_constraints import DisjunctiveConstraint, PhrasalConstraint
40 | from transformers.generation.beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
41 | from transformers.generation.configuration_utils import GenerationConfig
42 | from transformers.generation.logits_process import (
43 | EncoderNoRepeatNGramLogitsProcessor,
44 | EncoderRepetitionPenaltyLogitsProcessor,
45 | EpsilonLogitsWarper,
46 | EtaLogitsWarper,
47 | ExponentialDecayLengthPenalty,
48 | ForcedBOSTokenLogitsProcessor,
49 | ForcedEOSTokenLogitsProcessor,
50 | ForceTokensLogitsProcessor,
51 | HammingDiversityLogitsProcessor,
52 | InfNanRemoveLogitsProcessor,
53 | LogitNormalization,
54 | LogitsProcessorList,
55 | MinLengthLogitsProcessor,
56 | MinNewTokensLengthLogitsProcessor,
57 | NoBadWordsLogitsProcessor,
58 | NoRepeatNGramLogitsProcessor,
59 | PrefixConstrainedLogitsProcessor,
60 | RepetitionPenaltyLogitsProcessor,
61 | SequenceBiasLogitsProcessor,
62 | SuppressTokensAtBeginLogitsProcessor,
63 | SuppressTokensLogitsProcessor,
64 | TemperatureLogitsWarper,
65 | TopKLogitsWarper,
66 | TopPLogitsWarper,
67 | TypicalLogitsWarper,
68 | UnbatchedClassifierFreeGuidanceLogitsProcessor,
69 | )
70 | from transformers.generation.stopping_criteria import (
71 | MaxLengthCriteria,
72 | MaxTimeCriteria,
73 | StoppingCriteria,
74 | StoppingCriteriaList,
75 | validate_stopping_criteria,
76 | )
77 |
78 |
79 | if TYPE_CHECKING:
80 | from transformers.modeling_utils import PreTrainedModel
81 | from transformers.streamers import BaseStreamer
82 |
83 | logger = logging.get_logger(__name__)
84 |
85 | if is_accelerate_available():
86 | from accelerate.hooks import AlignDevicesHook, add_hook_to_module
87 |
88 | from transformers.generation.utils import GenerationMixin, GenerateOutput, GenerationMode, _crop_past_key_values, GreedySearchDecoderOnlyOutput, GreedySearchEncoderDecoderOutput
89 |
90 |
91 | @torch.no_grad()
92 | def my_generate(
93 | model: "PreTrainedModel",
94 | inputs: Optional[torch.Tensor] = None,
95 | generation_config: Optional[GenerationConfig] = None,
96 | logits_processor: Optional[LogitsProcessorList] = None,
97 | stopping_criteria: Optional[StoppingCriteriaList] = None,
98 | prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
99 | synced_gpus: Optional[bool] = None,
100 | assistant_model: Optional["PreTrainedModel"] = None,
101 | streamer: Optional["BaseStreamer"] = None,
102 | negative_prompt_ids: Optional[torch.Tensor] = None,
103 | negative_prompt_attention_mask: Optional[torch.Tensor] = None,
104 | num_assistant_tokens_schedule: Optional[str] = 'heuristic',
105 | num_assistant_tokens: Optional[int] = None,
106 | oracle_token_num_list: Optional[List[int]] = None,
107 | assist_acc_head: Optional[nn.Module] = None,
108 | stop_threshold: Optional[float] = None,
109 | bound: Optional[List[int]] = None,
110 | **kwargs,
111 | ) -> Union[GenerateOutput, torch.LongTensor]:
112 | r"""
113 |
114 | Generates sequences of token ids for models with a language modeling head.
115 |
116 |
117 |
118 | Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
119 | model's default generation configuration. You can override any `generation_config` by passing the corresponding
120 | parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
121 |
122 | For an overview of generation strategies and code examples, check out the [following
123 | guide](../generation_strategies).
124 |
125 |
126 |
127 | Parameters:
128 | inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
129 | The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
130 | method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
131 | should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
132 | `input_ids`, `input_values`, `input_features`, or `pixel_values`.
133 | generation_config (`~generation.GenerationConfig`, *optional*):
134 | The generation configuration to be used as base parametrization for the generation call. `**kwargs`
135 | passed to generate matching the attributes of `generation_config` will override them. If
136 | `generation_config` is not provided, the default will be used, which has the following loading
137 | priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
138 | configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
139 | default values, whose documentation should be checked to parameterize generation.
140 | logits_processor (`LogitsProcessorList`, *optional*):
141 | Custom logits processors that complement the default logits processors built from arguments and
142 | generation config. If a logit processor is passed that is already created with the arguments or a
143 | generation config an error is thrown. This feature is intended for advanced users.
144 | stopping_criteria (`StoppingCriteriaList`, *optional*):
145 | Custom stopping criteria that complement the default stopping criteria built from arguments and a
146 | generation config. If a stopping criteria is passed that is already created with the arguments or a
147 | generation config an error is thrown. This feature is intended for advanced users.
148 | prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
149 | If provided, this function constraints the beam search to allowed tokens only at each step. If not
150 | provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
151 | `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
152 | on the batch ID `batch_id` and the previously generated tokens `input_ids`. This argument is useful
153 | for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
154 | Retrieval](https://arxiv.org/abs/2010.00904).
155 | synced_gpus (`bool`, *optional*):
156 | Whether to continue running the while loop until max_length. Unless overridden this flag will be set to
157 | `True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished
158 | generating before other GPUs. Otherwise it'll be set to `False`.
159 | assistant_model (`PreTrainedModel`, *optional*):
160 | An assistant model that can be used to accelerate generation. The assistant model must have the exact
161 | same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistant model
162 | is much faster than running generation with the model you're calling generate from. As such, the
163 | assistant model should be much smaller.
164 | streamer (`BaseStreamer`, *optional*):
165 | Streamer object that will be used to stream the generated sequences. Generated tokens are passed
166 | through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
167 | negative_prompt_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
168 | The negative prompt needed for some processors such as CFG. The batch size must match the input batch
169 | size. This is an experimental feature, subject to breaking API changes in future versions.
170 | negative_prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
171 | Attention_mask for `negative_prompt_ids`.
172 | kwargs (`Dict[str, Any]`, *optional*):
173 | Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
174 | forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
175 | specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
176 |
177 | Return:
178 | [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
179 | or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
180 |
181 | If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
182 | [`~utils.ModelOutput`] types are:
183 |
184 | - [`~generation.GreedySearchDecoderOnlyOutput`],
185 | - [`~generation.SampleDecoderOnlyOutput`],
186 | - [`~generation.BeamSearchDecoderOnlyOutput`],
187 | - [`~generation.BeamSampleDecoderOnlyOutput`]
188 |
189 | If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
190 | [`~utils.ModelOutput`] types are:
191 |
192 | - [`~generation.GreedySearchEncoderDecoderOutput`],
193 | - [`~generation.SampleEncoderDecoderOutput`],
194 | - [`~generation.BeamSearchEncoderDecoderOutput`],
195 | - [`~generation.BeamSampleEncoderDecoderOutput`]
196 | """
197 |
198 | if synced_gpus is None:
199 | if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1:
200 | synced_gpus = True
201 | else:
202 | synced_gpus = False
203 |
204 | # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
205 | model._validate_model_class()
206 |
207 | # priority: `generation_config` argument > `model.generation_config` (the default generation config)
208 | if generation_config is None:
209 | # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
210 | # two conditions must be met
211 | # 1) the generation config must have been created from the model config (`_from_model_config` field);
212 | # 2) the generation config must have seen no modification since its creation (the hash is the same).
213 | if model.generation_config._from_model_config and model.generation_config._original_object_hash == hash(
214 | model.generation_config
215 | ):
216 | new_generation_config = GenerationConfig.from_model_config(model.config)
217 | if new_generation_config != model.generation_config:
218 | warnings.warn(
219 | "You have modified the pretrained model configuration to control generation. This is a"
220 | " deprecated strategy to control generation and will be removed soon, in a future version."
221 | " Please use and modify the model generation configuration (see"
222 | " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
223 | )
224 | model.generation_config = new_generation_config
225 | generation_config = model.generation_config
226 |
227 | generation_config = copy.deepcopy(generation_config)
228 | model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
229 | generation_config.validate()
230 | model._validate_model_kwargs(model_kwargs.copy())
231 |
232 | # 2. Set generation parameters if not already defined
233 | logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
234 | stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
235 |
236 | if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
237 | if model_kwargs.get("attention_mask", None) is None:
238 | logger.warning(
239 | "The attention mask and the pad token id were not set. As a consequence, you may observe "
240 | "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
241 | )
242 | eos_token_id = generation_config.eos_token_id
243 | if isinstance(eos_token_id, list):
244 | eos_token_id = eos_token_id[0]
245 | logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
246 | generation_config.pad_token_id = eos_token_id
247 |
248 | # 3. Define model inputs
249 | # inputs_tensor has to be defined
250 | # model_input_name is defined if model-specific keyword input is passed
251 | # otherwise model_input_name is None
252 | # all model-specific keyword inputs are removed from `model_kwargs`
253 | inputs_tensor, model_input_name, model_kwargs = model._prepare_model_inputs(
254 | inputs, generation_config.bos_token_id, model_kwargs
255 | )
256 | batch_size = inputs_tensor.shape[0]
257 |
258 | # 4. Define other model kwargs
259 | model_kwargs["output_attentions"] = generation_config.output_attentions
260 | model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
261 | # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
262 | # generating the first new token or not, and we only want to use the embeddings for the first new token)
263 | if not model.config.is_encoder_decoder and model_input_name == "inputs_embeds":
264 | model_kwargs["use_cache"] = True
265 | else:
266 | model_kwargs["use_cache"] = generation_config.use_cache
267 |
268 | accepts_attention_mask = "attention_mask" in set(inspect.signature(model.forward).parameters.keys())
269 | requires_attention_mask = "encoder_outputs" not in model_kwargs
270 |
271 | if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
272 | model_kwargs["attention_mask"] = model._prepare_attention_mask_for_generation(
273 | inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
274 | )
275 |
276 | # decoder-only models should use left-padding for generation
277 | if not model.config.is_encoder_decoder:
278 | # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
279 | # Note: If using `inputs_embeds`, this check does not work, because we want to be more hands-off.
280 | if (
281 | generation_config.pad_token_id is not None
282 | and len(inputs_tensor.shape) == 2
283 | and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
284 | ):
285 | logger.warning(
286 | "A decoder-only architecture is being used, but right-padding was detected! For correct "
287 | "generation results, please set `padding_side='left'` when initializing the tokenizer."
288 | )
289 |
290 | if model.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
291 | # if model is encoder decoder encoder_outputs are created
292 | # and added to `model_kwargs`
293 | model_kwargs = model._prepare_encoder_decoder_kwargs_for_generation(
294 | inputs_tensor, model_kwargs, model_input_name
295 | )
296 |
297 | # 5. Prepare `input_ids` which will be used for auto-regressive generation
298 | if model.config.is_encoder_decoder:
299 | input_ids, model_kwargs = model._prepare_decoder_input_ids_for_generation(
300 | batch_size=batch_size,
301 | model_input_name=model_input_name,
302 | model_kwargs=model_kwargs,
303 | decoder_start_token_id=generation_config.decoder_start_token_id,
304 | bos_token_id=generation_config.bos_token_id,
305 | device=inputs_tensor.device,
306 | )
307 | else:
308 | input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
309 |
310 | if streamer is not None:
311 | streamer.put(input_ids.cpu())
312 |
313 | # 6. Prepare `max_length` depending on other stopping criteria.
314 | input_ids_length = input_ids.shape[-1]
315 | has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
316 | if generation_config.max_new_tokens is not None:
317 | if not has_default_max_length and generation_config.max_length is not None:
318 | logger.warning(
319 | f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
320 | f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
321 | "Please refer to the documentation for more information. "
322 | "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
323 | )
324 | generation_config.max_length = generation_config.max_new_tokens + input_ids_length
325 | model._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
326 |
327 | # 7. determine generation mode
328 | generation_mode = model._get_generation_mode(generation_config, assistant_model)
329 |
330 | if streamer is not None and (generation_config.num_beams > 1):
331 | raise ValueError(
332 | "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
333 | )
334 |
335 | if model.device.type != input_ids.device.type:
336 | warnings.warn(
337 | "You are calling .generate() with the `input_ids` being on a device type different"
338 | f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
339 | f" is on {model.device.type}. You may experience unexpected behaviors or slower generation."
340 | " Please make sure that you have put `input_ids` to the"
341 | f" correct device by calling for example input_ids = input_ids.to('{model.device.type}') before"
342 | " running `.generate()`.",
343 | UserWarning,
344 | )
345 |
346 | # 8. prepare distribution pre_processing samplers
347 | logits_processor = model._get_logits_processor(
348 | generation_config=generation_config,
349 | input_ids_seq_length=input_ids_length,
350 | encoder_input_ids=inputs_tensor,
351 | prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
352 | logits_processor=logits_processor,
353 | model_kwargs=model_kwargs,
354 | negative_prompt_ids=negative_prompt_ids,
355 | negative_prompt_attention_mask=negative_prompt_attention_mask,
356 | )
357 |
358 | # 9. prepare stopping criteria
359 | stopping_criteria = model._get_stopping_criteria(
360 | generation_config=generation_config, stopping_criteria=stopping_criteria
361 | )
362 | # 10. go into different generation modes
363 | if generation_mode == GenerationMode.ASSISTED_GENERATION:
364 | if generation_config.num_return_sequences > 1:
365 | raise ValueError(
366 | "num_return_sequences has to be 1 when doing assisted generate, "
367 | f"but is {generation_config.num_return_sequences}."
368 | )
369 | if batch_size > 1:
370 | raise ValueError("assisted generate is only supported for batch_size = 1")
371 | if not model_kwargs["use_cache"]:
372 | raise ValueError("assisted generate requires `use_cache=True`")
373 |
374 | # 11. If the assistant model is an encoder-decoder, prepare its encoder outputs
375 | if assistant_model.config.is_encoder_decoder:
376 | assistant_model_kwargs = copy.deepcopy(model_kwargs)
377 | inputs_tensor, model_input_name, assistant_model_kwargs = assistant_model._prepare_model_inputs(
378 | inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_model_kwargs
379 | )
380 | assistant_model_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation(
381 | inputs_tensor, assistant_model_kwargs, model_input_name
382 | )
383 | model_kwargs["assistant_encoder_outputs"] = assistant_model_kwargs["encoder_outputs"]
384 |
385 | # 12. run assisted generate
386 | return my_assisted_decoding(
387 | model,
388 | input_ids,
389 | assistant_model=assistant_model,
390 | do_sample=generation_config.do_sample,
391 | logits_processor=logits_processor,
392 | logits_warper=model._get_logits_warper(generation_config) if generation_config.do_sample else None,
393 | stopping_criteria=stopping_criteria,
394 | pad_token_id=generation_config.pad_token_id,
395 | eos_token_id=generation_config.eos_token_id,
396 | output_scores=generation_config.output_scores,
397 | return_dict_in_generate=generation_config.return_dict_in_generate,
398 | synced_gpus=synced_gpus,
399 | streamer=streamer,
400 | num_assistant_tokens_schedule=num_assistant_tokens_schedule,
401 | num_assistant_tokens=num_assistant_tokens,
402 | oracle_token_num_list=oracle_token_num_list,
403 | assist_acc_head=assist_acc_head,
404 | stop_threshold=stop_threshold,
405 | bound=bound,
406 | **model_kwargs,
407 | )
408 |
409 |
410 | def my_assisted_decoding(
411 | model: "PreTrainedModel",
412 | input_ids: torch.LongTensor,
413 | assistant_model: "PreTrainedModel",
414 | do_sample: bool = False,
415 | logits_processor: Optional[LogitsProcessorList] = None,
416 | logits_warper: Optional[LogitsProcessorList] = None,
417 | stopping_criteria: Optional[StoppingCriteriaList] = None,
418 | pad_token_id: Optional[int] = None,
419 | eos_token_id: Optional[Union[int, List[int]]] = None,
420 | output_attentions: Optional[bool] = None,
421 | output_hidden_states: Optional[bool] = None,
422 | output_scores: Optional[bool] = None,
423 | return_dict_in_generate: Optional[bool] = None,
424 | synced_gpus: bool = False,
425 | streamer: Optional["BaseStreamer"] = None,
426 | num_assistant_tokens_schedule: Optional[str] = 'heuristic',
427 | num_assistant_tokens: Optional[int] = None,
428 | oracle_token_num_list: Optional[List[int]] = None,
429 | assist_acc_head: Optional[nn.Module] = None,
430 | stop_threshold: Optional[float] = None,
431 | bound: Optional[List[int]] = None,
432 | **model_kwargs,
433 | ):
434 | r"""
435 | Generates sequences of token ids for models with a language modeling head using **greedy decoding** or
436 | **sample** (depending on `do_sample`), assisted by a smaller model. Can be used for text-decoder, text-to-text,
437 | speech-to-text, and vision-to-text models.
438 |
439 |
440 |
441 | In most cases, you do not need to call [`~generation.GenerationMixin.assisted_decoding`] directly. Use
442 | generate() instead. For an overview of generation strategies and code examples, check the [following
443 | guide](../generation_strategies).
444 |
445 |
446 |
447 | Parameters:
448 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
449 | The sequence used as a prompt for the generation.
450 | assistant_model (`PreTrainedModel`, *optional*):
451 | An assistant model that can be used to accelerate generation. The assistant model must have the exact
452 | same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistant model
453 | is much faster than running generation with the model you're calling generate from. As such, the
454 | assistant model should be much smaller.
455 | do_sample (`bool`, *optional*, defaults to `False`):
456 | Whether or not to use sampling; use greedy decoding otherwise.
457 | logits_processor (`LogitsProcessorList`, *optional*):
458 | An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
459 | used to modify the prediction scores of the language modeling head applied at each generation step.
460 | logits_warper (`LogitsProcessorList`, *optional*):
461 | An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
462 | to warp the prediction score distribution of the language modeling head applied before multinomial
463 | sampling at each generation step.
464 | stopping_criteria (`StoppingCriteriaList`, *optional*):
465 | An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
466 | used to tell if the generation loop should stop.
467 | pad_token_id (`int`, *optional*):
468 | The id of the *padding* token.
469 | eos_token_id (`Union[int, List[int]]`, *optional*):
470 | The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
471 | output_attentions (`bool`, *optional*, defaults to `False`):
472 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under
473 | returned tensors for more details.
474 | output_hidden_states (`bool`, *optional*, defaults to `False`):
475 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
476 | for more details.
477 | output_scores (`bool`, *optional*, defaults to `False`):
478 | Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
479 | return_dict_in_generate (`bool`, *optional*, defaults to `False`):
480 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
481 | synced_gpus (`bool`, *optional*, defaults to `False`):
482 | Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
483 | streamer (`BaseStreamer`, *optional*):
484 | Streamer object that will be used to stream the generated sequences. Generated tokens are passed
485 | through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
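num_assistant_tokens_schedule (`str`, *optional*, defaults to `'heuristic'`):
Schedule controlling how many tokens the assistant drafts per round: `'heuristic'` (raise the budget when the
whole draft is accepted and lower it otherwise), `'oracle'` (read the per-round budget from
`oracle_token_num_list`), or `'ada'` (stop drafting adaptively based on `assist_acc_head`).
num_assistant_tokens (`int`, *optional*):
Initial value for `assistant_model.max_assistant_tokens`; if unset and the assistant has no existing value, 5 is used.
oracle_token_num_list (`List[int]`, *optional*):
Per-round draft-token budgets consumed by the `'oracle'` schedule.
assist_acc_head (`nn.Module`, *optional*):
Head applied to the assistant model's last hidden state to estimate whether the current draft token will be
accepted; required by the `'ada'` schedule.
stop_threshold (`float`, *optional*):
Under the `'ada'` schedule, drafting stops once the estimated probability of at least one rejection,
`1 - prod(per-token acceptance probabilities)`, exceeds this value.
bound (`List[int]`, *optional*):
`[min, max]` bounds (inclusive) on the number of tokens drafted per round under the `'ada'` schedule.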
486 | model_kwargs:
487 | Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
488 | If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
489 |
490 | Return:
491 | [`~generation.GreedySearchDecoderOnlyOutput`], [`~generation.GreedySearchEncoderDecoderOutput`] or
492 | `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
493 | [`~generation.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
494 | `return_dict_in_generate=True` or a [`~generation.GreedySearchEncoderDecoderOutput`] if
495 | `model.config.is_encoder_decoder=True`.
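When `return_dict_in_generate=False` (the default), this modified implementation instead returns a tuple
`(input_ids, num_mismatched_tokens, assist_rounds)`: the generated sequence, the total number of drafted
tokens rejected by the main model, and the number of drafting rounds.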
496 |
497 | Examples:
498 |
499 | ```python
500 | >>> from transformers import (
501 | ... AutoTokenizer,
502 | ... AutoModelForCausalLM,
503 | ... LogitsProcessorList,
504 | ... MinLengthLogitsProcessor,
505 | ... StoppingCriteriaList,
506 | ... MaxLengthCriteria,
507 | ... )
508 |
509 | >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
510 | >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
511 | >>> assistant_model = AutoModelForCausalLM.from_pretrained("distilgpt2")
512 | >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
513 | >>> model.generation_config.pad_token_id = model.generation_config.eos_token_id
514 | >>> input_prompt = "It might be possible to"
515 | >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
516 | >>> # instantiate logits processors
517 | >>> logits_processor = LogitsProcessorList(
518 | ... [
519 | ... MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id),
520 | ... ]
521 | ... )
522 | >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
523 | >>> outputs, num_mismatched_tokens, assist_rounds = my_assisted_decoding(
524 | ... model, input_ids,
525 | ... assistant_model=assistant_model,
526 | ... logits_processor=logits_processor,
527 | ... stopping_criteria=stopping_criteria,
528 | ... )
529 | >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
530 | ["It might be possible to get a better understanding of the nature of the problem, but it's not"]
531 | ```"""
532 | # Assistant: initialize assistant-related variables
533 | if num_assistant_tokens_schedule is None: # default to heuristic
534 | num_assistant_tokens_schedule = 'heuristic'
535 |
536 | if num_assistant_tokens is not None:
537 | assistant_model.max_assistant_tokens = num_assistant_tokens
538 | logger.warning("Setting initial assistant model max_assistant_tokens to %d" % num_assistant_tokens)
539 | else:
540 | if not hasattr(assistant_model, "max_assistant_tokens") or assistant_model.max_assistant_tokens is None:
541 | assistant_model.max_assistant_tokens = 5 # default to 5
542 | # this value persists across calls to this function and is updated when the 'heuristic' (or 'oracle') schedule is applied
543 |
544 | # init values
545 | logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
546 | logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
547 | stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
548 | pad_token_id = pad_token_id if pad_token_id is not None else model.generation_config.pad_token_id
549 | eos_token_id = eos_token_id if eos_token_id is not None else model.generation_config.eos_token_id
550 | if eos_token_id is not None and pad_token_id is None:
551 | raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
552 | if isinstance(eos_token_id, int):
553 | eos_token_id = [eos_token_id]
554 | eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
555 | output_scores = output_scores if output_scores is not None else model.generation_config.output_scores
556 | output_attentions = (
557 | output_attentions if output_attentions is not None else model.generation_config.output_attentions
558 | )
559 | output_hidden_states = (
560 | output_hidden_states if output_hidden_states is not None else model.generation_config.output_hidden_states
561 | )
562 | return_dict_in_generate = (
563 | return_dict_in_generate
564 | if return_dict_in_generate is not None
565 | else model.generation_config.return_dict_in_generate
566 | )
567 |
568 | # init attention / hidden states / scores tuples
569 | scores = () if (return_dict_in_generate and output_scores) else None
570 | decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
571 | cross_attentions = () if (return_dict_in_generate and output_attentions) else None
572 | decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
573 |
574 | # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
575 | if return_dict_in_generate and model.config.is_encoder_decoder:
576 | encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
577 | encoder_hidden_states = (
578 | model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
579 | )
580 |
581 | # keep track of which sequences are already finished
582 | unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
583 |
584 | # other auxiliary variables
585 | max_len = stopping_criteria[0].max_length
586 | assistant_kv_indexing = (
587 | 1
588 | if "bloom" in assistant_model.__class__.__name__.lower()
589 | or (
590 | assistant_model.config.architectures is not None
591 | and "bloom" in assistant_model.config.architectures[0].lower()
592 | )
593 | else 0
594 | )
595 |
596 | this_peer_finished = False # used by synced_gpus only
597 | num_mismatched_tokens = 0
598 | assist_rounds = 0
599 |
600 | while True:
601 | if synced_gpus:
602 | # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
603 | # The following logic allows an early break if all peers finished generating their sequence
604 | this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
605 | # send 0.0 if we finished, 1.0 otherwise
606 | dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
607 | # did all peers finish? the reduced sum will be 0.0 then
608 | if this_peer_finished_flag.item() == 0.0:
609 | break
610 |
611 | # Assistant: main logic start
612 | cur_len = input_ids.shape[-1]
613 |
614 | # 1. Forecast next N tokens using the assistant model. This `for` block can be replaced with a
615 | # `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
616 | # need access to the assistant cache to secure strong speedups.
617 | candidate_input_ids = input_ids
618 | q_prob = []
619 | assist_steps = 0
620 |
621 | cum_acc_prob = 1. # used for 'ada' schedule
622 |
623 | while True:
624 | # for _ in range(int(assistant_model.max_assistant_tokens)):
625 | if num_assistant_tokens_schedule != 'ada' and assist_steps >= assistant_model.max_assistant_tokens:
626 | break
627 | # 1.1. use the assistant model to obtain the next candidate logits
628 | if "assistant_past_key_values" in model_kwargs:
629 | prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2]
630 | # `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
631 | new_token_len = candidate_input_ids.shape[1] - prev_seq_len
632 | assert new_token_len > 0, 'assistant KV cache must be shorter than the candidate sequence'
633 | assist_inputs = candidate_input_ids[:, -new_token_len:]
634 | assist_attn = torch.ones_like(candidate_input_ids)
635 | # TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2
636 | if assistant_model.config.is_encoder_decoder:
637 | assistant_model_outputs = assistant_model(
638 | decoder_input_ids=assist_inputs,
639 | decoder_attention_mask=assist_attn,
640 | past_key_values=model_kwargs["assistant_past_key_values"],
641 | encoder_outputs=model_kwargs["assistant_encoder_outputs"],
642 | output_hidden_states=(num_assistant_tokens_schedule == 'ada'),
643 | )
644 | else:
645 | assistant_model_outputs = assistant_model(
646 | assist_inputs,
647 | attention_mask=assist_attn,
648 | past_key_values=model_kwargs["assistant_past_key_values"],
649 | output_hidden_states=(num_assistant_tokens_schedule == 'ada'),
650 | )
651 | else:
652 | if assistant_model.config.is_encoder_decoder:
653 | assistant_model_outputs = assistant_model(
654 | decoder_input_ids=candidate_input_ids,
655 | encoder_outputs=model_kwargs["assistant_encoder_outputs"],
656 | output_hidden_states=(num_assistant_tokens_schedule == 'ada'),
657 | )
658 | else:
659 | assistant_model_outputs = assistant_model(candidate_input_ids,
660 | output_hidden_states=(num_assistant_tokens_schedule == 'ada'),
661 | )
662 |
663 |
664 | model_kwargs["assistant_past_key_values"] = assistant_model_outputs.past_key_values
665 | if len(logits_processor) > 0:
666 | assistant_model_outputs.logits[:, -1, :] = logits_processor(
667 | candidate_input_ids, assistant_model_outputs.logits[:, -1, :]
668 | )
669 | if len(logits_warper) > 0:
670 | assistant_model_outputs.logits[:, -1, :] = logits_warper(
671 | candidate_input_ids, assistant_model_outputs.logits[:, -1, :]
672 | )
673 |
674 | # 1.2. greedily select the next candidate token, or sample it from the assistant distribution when do_sample is set (speculative decoding)
675 | if do_sample:
676 | probs = assistant_model_outputs.logits[:, -1, :].softmax(dim=-1) # bs * vocab_size
677 | new_token = torch.multinomial(probs[0, :], num_samples=1)
678 | q_prob.append(probs)
679 | else:
680 | new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1)
681 |
682 | candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1)
683 |
684 | # 1.3. stop assistant generation on EOS
685 | if eos_token_id_tensor is not None:
686 | last_assistant_token_is_eos = new_token.tile(eos_token_id_tensor.shape[0], 1)
687 | last_assistant_token_is_eos = (
688 | ~last_assistant_token_is_eos.ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool()
689 | )
690 | if last_assistant_token_is_eos:
691 | break
692 | else:
693 | last_assistant_token_is_eos = False
694 |
695 | # 1.4. stop assistant generation when the max length is reached or when the stop head predicts stop
696 | assist_steps += 1
697 |
698 | # bound = (min, max)
699 | if num_assistant_tokens_schedule == 'ada':
700 | assert assist_acc_head is not None, "the 'ada' schedule requires assist_acc_head"
701 |
702 |
703 | ### obtain current acceptance probability with assist_acc_head
704 | hidden_states = assistant_model_outputs.get("hidden_states") # hidden_states[-1] is the last layer's hidden states, size: bs * seq_len * hidden_dim
705 | logits = assist_acc_head(hidden_states[-1][0, -1].float())
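# the head output is used as classification logits: index 0 is treated as reject/stop and index 1 as accept (see the argmax / softmax indexing below)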
706 |
707 | if stop_threshold is None:
708 | logger.warning("[Deprecated] stop_threshold not set; using the predicted acceptance of the current token instead.")
709 | predicted = logits.argmax(dim = -1)
710 | stop_prediction = (predicted == 0)
711 | else:
712 |
713 | ## stop drafting when the estimated P(at least one rejection) = 1 - P(all proposed tokens are accepted) exceeds the threshold.
714 |
715 | if assist_steps == 1:
716 | acc_prob = 1 # skip the first step: the prediction is made from an already-verified token, so there is no proposed token to reject yet.
717 | else:
718 | acc_prob = logits.softmax(dim = -1)[1].item()
719 | cum_acc_prob *= acc_prob
720 | rej_prob = 1 - cum_acc_prob
721 |
722 | stop_prediction = (rej_prob > stop_threshold)
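# e.g., with per-step acceptance estimates 0.9 and 0.8, cum_acc_prob = 0.72 and rej_prob = 0.28, so drafting stops once 0.28 exceeds stop_threshold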
723 |
724 | # bound = (min, max): forces the number of tokens drafted per round to lie within [min, max] (both boundaries included)
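# e.g., bound = (2, 8) drafts at least 2 tokens per round (unless EOS is produced first) and never more than 8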
725 | if bound is not None:
726 | if assist_steps >= bound[1]:
727 | is_stop = True
728 | elif assist_steps < bound[0]:
729 | is_stop = False
730 | else:
731 | is_stop = stop_prediction
732 | else:
733 | is_stop = stop_prediction
734 |
735 |
736 | if is_stop:
737 | break
738 |
739 | candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
740 |
741 | if candidate_length == 0:
742 | last_assistant_token_is_eos = False
743 |
744 | # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
745 | # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct,
746 | # we use this forward pass to also pick the subsequent logits in the original model.
747 |
748 | # 2.1. Run a forward pass on the candidate sequence
749 | if "past_key_values" in model_kwargs:
750 | model_attn = torch.ones_like(candidate_input_ids)
751 | model_input_ids = candidate_input_ids[:, -candidate_length - 1 :]
752 | if model.config.is_encoder_decoder:
753 | outputs = model(
754 | decoder_input_ids=model_input_ids,
755 | decoder_attention_mask=model_attn,
756 | past_key_values=model_kwargs["past_key_values"],
757 | encoder_outputs=model_kwargs["encoder_outputs"],
758 | output_attentions=output_attentions,
759 | output_hidden_states=output_hidden_states,
760 | use_cache=True,
761 | )
762 | else:
763 | outputs = model(
764 | model_input_ids,
765 | attention_mask=model_attn,
766 | past_key_values=model_kwargs["past_key_values"],
767 | output_attentions=output_attentions,
768 | output_hidden_states=output_hidden_states,
769 | use_cache=True,
770 | )
771 | else:
772 | if model.config.is_encoder_decoder:
773 | outputs = model(
774 | decoder_input_ids=candidate_input_ids,
775 | encoder_outputs=model_kwargs["encoder_outputs"],
776 | output_attentions=output_attentions,
777 | output_hidden_states=output_hidden_states,
778 | use_cache=True,
779 | )
780 | else:
781 | outputs = model(
782 | candidate_input_ids,
783 | output_attentions=output_attentions,
784 | output_hidden_states=output_hidden_states,
785 | use_cache=True,
786 | )
787 |
788 | # 2.2. Process the new logits
789 | new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present
790 | if len(logits_processor) > 0:
791 | for i in range(candidate_length + 1):
792 | new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
793 | if len(logits_warper) > 0:
794 | for i in range(candidate_length + 1):
795 | new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
796 |
797 | # 3. Obtain the next tokens from the original model logits.
798 | if do_sample:
799 | # speculative decoding logic here.
800 |
801 | probs = new_logits[:, -candidate_length - 1 :, :].softmax(dim=-1) # bs(1) * (candidate_length+1) * vocab_size
802 | p_prob, next_p = probs[:, :-1], probs[:, -1]
803 |
804 | if candidate_length == 0:
805 | n_matches = 0
806 | else:
807 | q_prob = torch.stack(q_prob, dim=1) # bs(1) * candidate_length * vocab_size
808 |
809 | candidate_index = candidate_input_ids[:, -candidate_length:, None]
810 |
811 | q_candidate = q_prob.gather(-1, candidate_index).squeeze(-1)
812 | p_candidate = p_prob.gather(-1, candidate_index).squeeze(-1)
813 | r_candidate = torch.rand_like(q_candidate, device = q_candidate.device)
814 | n_matches = ((r_candidate > (p_candidate/q_candidate)).cumsum(dim=-1) < 1).sum()
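# standard speculative-sampling acceptance: draft token i is kept iff r_i <= p_i / q_i (i.e., with probability min(1, p_i/q_i));
# n_matches counts the accepted tokens before the first rejection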
815 |
816 | else:
817 |
818 | # greedy decoding logic.
819 | selected_tokens = new_logits[:, -candidate_length - 1 :, :].argmax(dim=-1)
820 |
821 | # 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep
822 | # the assistant forecasted tokens until the first mismatch, or until the max length is reached.
823 | if candidate_length == 0:
824 | n_matches = 0
825 | else:
826 | candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
827 | n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
828 |
829 |
830 |
831 |
832 |
833 |
834 | # 5. Update variables according to the number of matching assistant tokens. Remember: the token generated
835 | # by the model after the last candidate match is also valid, as it is generated from a correct sequence.
836 | # Because of this last token, assisted generation search reduces to a normal greedy search/sample if there
837 | # is no match.
838 |
839 | num_mismatched_tokens += (candidate_length - n_matches)
840 |
841 |
842 | if do_sample:
843 | # for speculative decoding
844 | candidate_new_tokens = candidate_input_ids[:, (candidate_input_ids.shape[1] - candidate_length):]
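# note: the explicit start index avoids the -0 pitfall: [:, -candidate_length:] would select the whole sequence when candidate_length == 0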
845 | valid_tokens = candidate_new_tokens[:, : n_matches]
846 |
847 | # for the last token
848 | ## case 1: some token was rejected; resample from the residual [p - q]_+ at position n_matches
849 | if n_matches < candidate_length:
850 | next_p = p_prob[:, n_matches, :] - q_prob[:, n_matches, :] # bs(1) * vocab_size
851 | next_p.clamp_(min=0.)
852 | next_p = next_p / next_p.sum(dim = -1, keepdim=True)
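# resampling from the normalized residual [p - q]_+ after a rejection keeps the overall output distribution equal to sampling from p (the standard speculative-sampling correction)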
853 |
854 | ## case 2: all tokens accepted; sample the bonus token from next_p (the model's distribution at the last position, computed above)
855 |
856 | new_added_token = torch.multinomial(next_p, num_samples=1)
857 | valid_tokens = torch.cat((valid_tokens, new_added_token), dim=-1)
858 |
859 | else:
860 | # for greedy decoding
861 | # 5.1. Get the valid continuation, after the matching tokens
862 | valid_tokens = selected_tokens[:, : n_matches + 1]
863 |
864 | # 5.2. Ensure we don't generate beyond max_len or an EOS token
865 | if last_assistant_token_is_eos and n_matches == candidate_length:
866 | n_matches -= 1
867 | n_matches = min(n_matches, max_len - cur_len - 1)
868 | valid_tokens = valid_tokens[:, : n_matches + 1]
869 |
870 | input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
871 | if streamer is not None:
872 | streamer.put(valid_tokens.cpu())
873 | new_cur_len = input_ids.shape[-1]
874 |
875 |
876 | # 5.3. Discard past key values relative to unused assistant tokens
877 | new_cache_size = new_cur_len - 1
878 | outputs.past_key_values = _crop_past_key_values(model, outputs.past_key_values, new_cache_size)
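# the main model's cache now covers everything up to (but not including) the last accepted token; that token is re-fed on the next forward pass together with the new draft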
879 | if "assistant_past_key_values" in model_kwargs:
880 | model_kwargs["assistant_past_key_values"] = _crop_past_key_values(
881 | assistant_model, model_kwargs["assistant_past_key_values"], new_cache_size - 1
882 | ) # the assistant does not have the token after the last match, hence the -1
883 |
884 | # 6. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
885 | # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
886 | # cost of forecasting incorrect assistant tokens.
887 |
888 | assist_rounds += 1
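# 'heuristic' schedule: if the whole draft was accepted this round, draft 2 more tokens next round; otherwise draft one fewer (never below 1)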
889 | if num_assistant_tokens_schedule == 'heuristic':
890 | if n_matches == int(assistant_model.max_assistant_tokens):
891 | assistant_model.max_assistant_tokens += 2.0
892 | else:
893 | assistant_model.max_assistant_tokens = max(1.0, assistant_model.max_assistant_tokens - 1.0)
894 | elif num_assistant_tokens_schedule == 'oracle':
895 | if assist_rounds < len(oracle_token_num_list):
896 | assistant_model.max_assistant_tokens = oracle_token_num_list[assist_rounds]
897 | else:
898 | logger.warning("assist_rounds exceeds len(oracle_token_num_list); keeping the previous max_assistant_tokens")
899 | # print("oracle token num: %d" % assistant_model.max_assistant_tokens)
900 |
901 |
902 | # Assistant: main logic end
903 |
904 | if synced_gpus and this_peer_finished:
905 | continue # don't waste resources running the code we don't need
906 |
907 | # Store scores, attentions and hidden_states when required
908 | # Assistant: modified to append one tuple element per token, as in the other generation methods.
909 | if return_dict_in_generate:
910 | if output_scores:
911 | scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1))
912 |
913 | if "past_key_values" not in model_kwargs:
914 | added_len = new_cur_len
915 | else:
916 | added_len = n_matches + 1
917 |
918 | if output_attentions:
919 | if model.config.is_encoder_decoder:
920 | cross_attentions = _split_model_outputs(
921 | cross_attentions, outputs.cross_attentions, cur_len, added_len
922 | )
923 | decoder_attentions = _split_model_outputs(
924 | decoder_attentions,
925 | outputs.decoder_attentions,
926 | cur_len,
927 | added_len,
928 | is_decoder_attention=True,
929 | )
930 | else:
931 | decoder_attentions = _split_model_outputs(
932 | decoder_attentions,
933 | outputs.attentions,
934 | cur_len,
935 | added_len,
936 | is_decoder_attention=True,
937 | )
938 | if output_hidden_states:
939 | if model.config.is_encoder_decoder:
940 | decoder_hidden_states = _split_model_outputs(
941 | decoder_hidden_states, outputs.decoder_hidden_states, cur_len, added_len
942 | )
943 | else:
944 | decoder_hidden_states = _split_model_outputs(
945 | decoder_hidden_states, outputs.hidden_states, cur_len, added_len
946 | )
947 |
948 | model_kwargs = model._update_model_kwargs_for_generation(
949 | outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
950 | )
951 |
952 | # if eos_token was found in one sentence, set sentence to finished
953 | if eos_token_id_tensor is not None:
954 | unfinished_sequences = unfinished_sequences.mul(
955 | input_ids[:, -1]
956 | .tile(eos_token_id_tensor.shape[0], 1)
957 | .ne(eos_token_id_tensor.unsqueeze(1))
958 | .prod(dim=0)
959 | )
960 |
961 | # stop when each sentence is finished
962 | if unfinished_sequences.max() == 0:
963 | this_peer_finished = True
964 |
965 | # stop if we exceed the maximum length
966 | if stopping_criteria(input_ids, scores):
967 | this_peer_finished = True
968 |
969 | if this_peer_finished and not synced_gpus:
970 | break
971 |
972 | if streamer is not None:
973 | streamer.end()
974 |
975 | if return_dict_in_generate:
976 | if model.config.is_encoder_decoder:
977 | return GreedySearchEncoderDecoderOutput(
978 | sequences=input_ids,
979 | scores=scores,
980 | encoder_attentions=encoder_attentions,
981 | encoder_hidden_states=encoder_hidden_states,
982 | decoder_attentions=decoder_attentions,
983 | cross_attentions=cross_attentions,
984 | decoder_hidden_states=decoder_hidden_states,
985 | )
986 | else:
987 | return GreedySearchDecoderOnlyOutput(
988 | sequences=input_ids,
989 | scores=scores,
990 | attentions=decoder_attentions,
991 | hidden_states=decoder_hidden_states,
992 | )
993 | else:
994 | return input_ids, num_mismatched_tokens.item(), assist_rounds
995 |
996 |
--------------------------------------------------------------------------------