├── benchmarks ├── __init__.py ├── fineinfer │ ├── __init__.py │ ├── README.md │ ├── fi-gen.py │ ├── baseline-ht.py │ └── fi-ht.py ├── huggingface │ ├── __init__.py │ ├── README.md │ ├── hf-gen.py │ ├── hf-peft-gen.py │ └── hf-peft.py └── utils.py ├── fineinfer ├── __init__.py └── engine │ ├── __init__.py │ └── llm_engine.py ├── requirements.txt ├── LICENSE └── README.md /benchmarks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fineinfer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fineinfer/engine/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /benchmarks/fineinfer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /benchmarks/huggingface/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu121 2 | torch == 2.3.0 3 | torchaudio == 2.3.0 4 | torchvision == 0.18.0 5 | transformers == 4.41.0 6 | peft == 0.11.0 7 | datasets == 2.19.0 8 | fastapi == 0.111.0 9 | -------------------------------------------------------------------------------- /benchmarks/huggingface/README.md: -------------------------------------------------------------------------------- 1 | ``` 2 | conda create -n FineInfer python=3.12 3 | conda activate FineInfer 4 | pip install -r requirements.txt 5 | ``` 6 | 7 | HuggingFace 8 | ``` 9 | CUDA_VISIBLE_DEVICES=0 python hf-gen.py -m meta-llama/Meta-Llama-3-8B --batch_size 1 10 | CUDA_VISIBLE_DEVICES=0 python hf-peft-gen.py -m meta-llama/Meta-Llama-3-8B --batch_size 1 11 | CUDA_VISIBLE_DEVICES=0 python hf-peft.py -m meta-llama/Meta-Llama-3-8B --batch_size 1 12 | ``` 13 | -------------------------------------------------------------------------------- /benchmarks/fineinfer/README.md: -------------------------------------------------------------------------------- 1 | ``` 2 | conda create -n FineInfer python=3.12 3 | conda activate FineInfer 4 | pip install -r requirements.txt 5 | ``` 6 | 7 | FineInfer-inference 8 | ``` 9 | CUDA_VISIBLE_DEVICES=0 python fi-gen.py -m meta-llama/Meta-Llama-3-8B --batch_size 1 10 | ``` 11 | 12 | FineInfer-heterogeneous 13 | ``` 14 | CUDA_VISIBLE_DEVICES=0 python baseline-ht.py -m meta-llama/Meta-Llama-3-8B --batch_size 1 15 | CUDA_VISIBLE_DEVICES=0 python fi-ht.py -m meta-llama/Meta-Llama-3-8B --batch_size 1 16 | ``` 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 FineInfer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, 
and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | # FineInfer
3 |
4 |
5 |
6 | | Paper |
7 |
8 | 9 | FineInfer is a research prototype for fine-tuning and serving large language models. 10 | 11 | FineInfer supports concurrent parameter-efficient fine-tuning and inference through the following features: 12 | * Deferred continuous batching 13 | * Hybrid system architecture 14 | * Heterogeneous batching 15 | 16 | ## Get Started 17 | [Installation and examples](https://github.com/llm-db/FineInfer/tree/main/benchmarks/fineinfer) 18 | 19 | The current version removes some previous features and functionalities. If you need them, please download [previous versions](https://github.com/llm-db/FineInfer/releases). 20 | 21 | ## Citation 22 | ``` 23 | @inproceedings{FineInfer, 24 | author = {He, Yongjun and Lu, Yao and Alonso, Gustavo}, 25 | title = {Deferred Continuous Batching in Resource-Efficient Large Language Model Serving}, 26 | year = {2024}, 27 | booktitle = {Proceedings of the 4th Workshop on Machine Learning and Systems}, 28 | pages = {98–106}, 29 | series = {EuroMLSys '24} 30 | } 31 | ``` 32 | -------------------------------------------------------------------------------- /benchmarks/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import time 4 | 5 | KB = 1 << 10 6 | MB = 1 << 20 7 | GB = 1 << 30 8 | TB = 1e12 9 | 10 | 11 | def get_ht_workloads( 12 | total_latency: float, 13 | req_gap_bound: float, 14 | ): 15 | random.seed(2024) 16 | 17 | workloads = [0] 18 | while True: 19 | req_gap = random.uniform(0, req_gap_bound) 20 | if len(workloads): 21 | # avoid the total latency is larger than the setup 22 | if workloads[-1] + req_gap > total_latency - req_gap_bound: 23 | break 24 | workloads.append(workloads[-1] + req_gap) 25 | 26 | return workloads 27 | 28 | def get_quant_config(model_config, quant_bits: int, quant_group_size: int): 29 | qaunt_config = { 30 | 'weight_quantization': { 31 | 'quantized_initialization' : { 32 | 'num_bits': quant_bits, 33 | 'group_size': quant_group_size, 34 | "group_dim": 1, 35 | "symmetric": False 36 | } 37 | } 38 | } 39 | 40 | return qaunt_config 41 | 42 | def model_bytes(config): 43 | h = config.hidden_size 44 | return 2 * (config.num_hidden_layers * ( 45 | # config-attention 46 | h * (3 * h + 1) + h * (h + 1) + 47 | # mlp 48 | h * (4 * h + 1) + h * 4 * (h + 1) + 49 | # layer norm 50 | h * 4) + 51 | # embedding 52 | config.vocab_size * (h + 1)) 53 | 54 | def cache_bytes(config, batch_size, seq_len): 55 | return 2 * batch_size * seq_len * config.num_hidden_layers * config.hidden_size * 2 56 | 57 | def write_gen_benchmark_log(model_size, cache_size, gpu_peak_mem, 58 | prefill_latency, prefill_throughput, 59 | decode_latency, decode_throughput, 60 | total_latency, total_throughput): 61 | 62 | log_str = (f"model size: {model_size/GB:.3f} GB\t" 63 | f"cache size: {cache_size/GB:.3f} GB\t" 64 | f"peak gpu mem: {gpu_peak_mem / GB:.3f} GB\n" 65 | f"prefill latency: {prefill_latency:.3f} s\t" 66 | f"prefill throughput: {prefill_throughput:.3f} token/s\n" 67 | f"decode latency: {decode_latency:.3f} s\t" 68 | f"decode throughput: {decode_throughput:.3f} token/s\n" 69 | f"total latency: {total_latency:.3f} s\t" 70 | f"total throughput: {total_throughput:.3f} token/s") 71 | 72 | return log_str 73 | 74 | def write_peft_benchmark_log(model_size, activation_size, gpu_peak_mem, 75 | forward_latency, forward_throughput, 76 | backward_latency, backward_throughput, 77 | total_latency, total_throughput): 78 | 79 | log_str = (f"model size: {model_size/GB:.3f} GB\t" 80 | f"activation size: 
{activation_size/GB:.3f} GB\t" 81 | f"peak gpu mem: {gpu_peak_mem / GB:.3f} GB\n" 82 | f"forward latency: {forward_latency:.3f} s\t" 83 | f"forward throughput: {forward_throughput:.3f} sample/s\n" 84 | f"backward latency: {backward_latency:.3f} s\t" 85 | f"backward throughput: {backward_throughput:.3f} sample/s\n" 86 | f"total latency: {total_latency:.3f} s\t" 87 | f"total throughput: {total_throughput:.3f} sample/s") 88 | 89 | return log_str 90 | 91 | 92 | def write_ht_benchmark_log(model_size, activation_size, gpu_peak_mem, 93 | gen_trials, gen_exec_total_latency, gen_exec_throughput, 94 | peft_trials, peft_total_latency, peft_throughput, total_latency): 95 | 96 | log_str = (f"model size: {model_size/GB:.3f} GB\t" 97 | f"activation size: {activation_size/GB:.3f} GB\t" 98 | f"peak gpu mem: {gpu_peak_mem / GB:.3f} GB\n" 99 | f"gen trials: {gen_trials}\t" 100 | f"gen exec total latency: {gen_exec_total_latency:.3f} s\t" 101 | f"gen exec throughput: {gen_exec_throughput:.3f} token/s\n" 102 | f"peft trials: {peft_trials}\t" 103 | f"peft total latency: {peft_total_latency:.3f} s\t" 104 | f"peft throughput: {peft_throughput:.3f} sample/s\n" 105 | f"total latency: {total_latency:.3f} s") 106 | 107 | return log_str 108 | 109 | # add timing hooks 110 | def add_model_hooks(model: torch.nn.Module): 111 | 112 | def start_time_hook(module, input): 113 | if hasattr(module, 'stage') and module.stage == "decode": 114 | return 115 | elif hasattr(module, 'stage') and module.stage == 'prefill': 116 | torch.cuda.synchronize() 117 | module.__start_time__ = time.time() 118 | 119 | def end_time_hook(module, input, output): 120 | if hasattr(module, 'stage') and module.stage == "decode": 121 | return 122 | elif hasattr(module, 'stage') and module.stage == 'prefill': 123 | torch.cuda.synchronize() 124 | module.__duration__ = time.time() - module.__start_time__ 125 | module.stage = "decode" 126 | 127 | if not hasattr(model, '__start_time_hook_handle'): 128 | model.__start_time_hook_handle__ = model.register_forward_pre_hook( 129 | start_time_hook, ) 130 | 131 | if not hasattr(model, '__end_time_hook_handle__'): 132 | model.__end_time_hook_handle__ = model.register_forward_hook( 133 | end_time_hook, ) 134 | 135 | # remove timing hooks 136 | def remove_model_hooks(module): 137 | if hasattr(module, "__start_time_hook_handle__"): 138 | module.__start_time_hook_handle__.remove() 139 | del module.__start_time_hook_handle__ 140 | if hasattr(module, "__end_time_hook_handle__"): 141 | module.__end_time_hook_handle__.remove() 142 | del module.__end_time_hook_handle__ 143 | if hasattr(module, "stage"): 144 | del module.stage 145 | if hasattr(module, "__duration__"): 146 | del module.__duration__ 147 | 148 | -------------------------------------------------------------------------------- /benchmarks/huggingface/hf-gen.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run Llama3 with huggingface 3 | 4 | Reference: 5 | https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/run_model.py 6 | """ 7 | 8 | import argparse 9 | import gc 10 | import os 11 | import time 12 | 13 | import torch 14 | from transformers import ( 15 | AutoConfig, 16 | AutoModelForCausalLM, 17 | AutoTokenizer, 18 | ) 19 | 20 | import sys 21 | sys.path.append("..") 22 | import utils 23 | 24 | 25 | def get_hf_model( 26 | model_name, 27 | pin_memory, 28 | quant_bits, 29 | quant_group_size, 30 | cache_dir, 31 | ): 32 | config = AutoConfig.from_pretrained(model_name, 
cache_dir=cache_dir) 33 | pin_memory = bool(args.pin_memory) 34 | dtype = torch.float16 35 | 36 | if quant_bits == 4: 37 | raise NotImplementedError() 38 | 39 | model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir, torch_dtype=dtype) 40 | model.to(torch.cuda.current_device()) 41 | model = model.eval() 42 | 43 | return model 44 | 45 | 46 | def run_generation( 47 | model_name, 48 | trials, 49 | batch_size, 50 | prompt_len, 51 | gen_len, 52 | local_rank, 53 | pin_memory, 54 | quant_bits, 55 | quant_group_size, 56 | cache_dir 57 | ): 58 | # Load tokenizer 59 | config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir) 60 | tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left', cache_dir=cache_dir) 61 | tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token 62 | 63 | print("load model") 64 | with torch.no_grad(): 65 | model = get_hf_model( 66 | model_name, 67 | pin_memory, 68 | quant_bits, 69 | quant_group_size, 70 | cache_dir, 71 | ) 72 | 73 | utils.add_model_hooks(model) 74 | 75 | prompts = ["Paris is the capital city of"] * batch_size 76 | input_tokens = tokenizer.batch_encode_plus(prompts, return_tensors="pt", 77 | padding="max_length", max_length=prompt_len) 78 | input_tokens.to(torch.cuda.current_device()) 79 | 80 | # Run generation 81 | print(f"benchmark, prompt_len = {prompt_len}, gen_len = {gen_len}, input_ids.shape = {input_tokens.input_ids.shape}") 82 | 83 | prefill_timings = [] 84 | total_timings = [] 85 | for _ in range(trials): 86 | start = time.time() 87 | with torch.no_grad(): 88 | model.stage = "prefill" 89 | output_ids = model.generate(**input_tokens, max_new_tokens=gen_len, do_sample=False) 90 | prefill_timings.append(model.__duration__) 91 | end = time.time() 92 | total_timings.append(end - start) 93 | 94 | if local_rank != 0: 95 | return 96 | 97 | utils.remove_model_hooks(model) 98 | # Check lengths 99 | input_lens = [len(x) for x in input_tokens.input_ids] 100 | output_lens = [len(x) for x in output_ids] 101 | assert all(x == prompt_len for x in input_lens) 102 | assert all(x == prompt_len + gen_len for x in output_lens) 103 | 104 | # Log output 105 | print(f"Summary:") 106 | print(f"total_timings = {total_timings}") 107 | print(f"prefill_timings = {prefill_timings}") 108 | total_latency = total_timings[-1] 109 | prefill_latency = prefill_timings[-1] 110 | 111 | prefill_throughput = batch_size * prompt_len / prefill_latency 112 | decode_latency = total_latency - prefill_latency 113 | decode_throughput = batch_size * (gen_len - 1) / max(decode_latency, 1e-10) 114 | num_generated_tokens = batch_size * gen_len 115 | total_throughput = num_generated_tokens / total_latency 116 | gpu_peak_mem = torch.cuda.max_memory_allocated(torch.device("cuda")) 117 | 118 | model_size = utils.model_bytes(config) 119 | cache_size = utils.cache_bytes(config, batch_size, prompt_len + gen_len) 120 | log_str = utils.write_gen_benchmark_log( 121 | model_size, 122 | cache_size, 123 | gpu_peak_mem, 124 | prefill_latency, 125 | prefill_throughput, 126 | decode_latency, 127 | decode_throughput, 128 | total_latency, 129 | total_throughput, 130 | ) 131 | print(log_str) 132 | 133 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True) 134 | show_str = "Outputs:\n" + 30 * "-" + "\n" 135 | for i in [0, (len(outputs) - 1) // 2, len(outputs) - 1]: 136 | show_str += f"{i}: {outputs[i]}\n" 137 | show_str += 30 * "-" + "\n" 138 | print(show_str) 139 | 140 | 141 | if __name__ == "__main__": 142 | parser = argparse.ArgumentParser() 143 
| parser.add_argument("--model_name", "-m", type=str, default="meta-llama/Meta-Llama-3-8B", help="model name or path") 144 | parser.add_argument("--trials", type=int, default=3, help="Number of token generation iterations") 145 | parser.add_argument("--batch_size", type=int, default=1) 146 | parser.add_argument("--prompt_len", type=int, default=512, help="prompt length") 147 | parser.add_argument("--gen_len", type=int, default=32, help="number of tokens to generate") 148 | parser.add_argument("--local_rank", type=int, default=int(os.getenv("LOCAL_RANK", "0")), help="local rank for distributed inference") 149 | parser.add_argument("--pin_memory", type=int, default=0, help="whether to pinned CPU memory for ZeRO offloading") 150 | parser.add_argument("--quant_bits", type=int, default=16, help="model weight quantization bits; either 4 or 8") 151 | parser.add_argument("--quant_group_size", type=int, default=64, help="model weight quantization group size") 152 | parser.add_argument("--cache_dir", type=str, default="/scratch/yonghe", help="cache dir for model name") 153 | args = parser.parse_args() 154 | 155 | gc.collect() 156 | 157 | run_generation( 158 | args.model_name, 159 | args.trials, 160 | args.batch_size, 161 | args.prompt_len, 162 | args.gen_len, 163 | args.local_rank, 164 | args.pin_memory, 165 | args.quant_bits, 166 | args.quant_group_size, 167 | args.cache_dir, 168 | ) 169 | 170 | -------------------------------------------------------------------------------- /benchmarks/huggingface/hf-peft-gen.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run Llama3 with huggingface 3 | 4 | Reference: 5 | https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/run_model.py 6 | """ 7 | 8 | import argparse 9 | import gc 10 | import os 11 | import time 12 | 13 | import torch 14 | from transformers import ( 15 | AutoConfig, 16 | AutoModelForCausalLM, 17 | AutoTokenizer, 18 | ) 19 | from peft import LoraConfig, get_peft_model 20 | 21 | import sys 22 | sys.path.append("..") 23 | import utils 24 | 25 | 26 | def get_hf_model( 27 | model_name, 28 | pin_memory, 29 | quant_bits, 30 | quant_group_size, 31 | cache_dir, 32 | ): 33 | config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir) 34 | pin_memory = bool(args.pin_memory) 35 | dtype = torch.float16 36 | 37 | if quant_bits == 4: 38 | raise NotImplementedError() 39 | 40 | base_model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir, torch_dtype=dtype) 41 | lora_config = LoraConfig( 42 | r=64, 43 | lora_alpha=32, 44 | target_modules=["q_proj", "v_proj"], 45 | lora_dropout=0.01, 46 | bias="none", 47 | task_type="CAUSAL_LM", 48 | ) 49 | model = get_peft_model(base_model, lora_config) 50 | model.to(torch.cuda.current_device()) 51 | model = model.eval() 52 | 53 | return model 54 | 55 | 56 | def run_generation( 57 | model_name, 58 | trials, 59 | batch_size, 60 | prompt_len, 61 | gen_len, 62 | local_rank, 63 | pin_memory, 64 | quant_bits, 65 | quant_group_size, 66 | cache_dir 67 | ): 68 | # Load tokenizer 69 | config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir) 70 | tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left', cache_dir=cache_dir) 71 | tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token 72 | 73 | print("load model") 74 | with torch.no_grad(): 75 | model = get_hf_model( 76 | model_name, 77 | pin_memory, 78 | quant_bits, 79 | quant_group_size, 80 | cache_dir, 81 | ) 82 | 83 | 
utils.add_model_hooks(model.base_model.model) 84 | 85 | prompts = ["Paris is the capital city of"] * batch_size 86 | input_tokens = tokenizer.batch_encode_plus(prompts, return_tensors="pt", 87 | padding="max_length", max_length=prompt_len) 88 | input_tokens.to(torch.cuda.current_device()) 89 | 90 | # Run generation 91 | print(f"benchmark, prompt_len = {prompt_len}, gen_len = {gen_len}, input_ids.shape = {input_tokens.input_ids.shape}") 92 | 93 | prefill_timings = [] 94 | total_timings = [] 95 | for _ in range(trials): 96 | start = time.time() 97 | with torch.no_grad(): 98 | model.base_model.model.stage = "prefill" 99 | output_ids = model.generate(**input_tokens, max_new_tokens=gen_len, do_sample=False) 100 | prefill_timings.append(model.base_model.model.__duration__) 101 | end = time.time() 102 | total_timings.append(end - start) 103 | 104 | if local_rank != 0: 105 | return 106 | 107 | utils.remove_model_hooks(model.base_model.model) 108 | # Check lengths 109 | input_lens = [len(x) for x in input_tokens.input_ids] 110 | output_lens = [len(x) for x in output_ids] 111 | assert all(x == prompt_len for x in input_lens) 112 | assert all(x == prompt_len + gen_len for x in output_lens) 113 | 114 | # Log output 115 | print(f"Summary:") 116 | print(f"total_timings = {total_timings}") 117 | print(f"prefill_timings = {prefill_timings}") 118 | total_latency = total_timings[-1] 119 | prefill_latency = prefill_timings[-1] 120 | 121 | prefill_throughput = batch_size * prompt_len / prefill_latency 122 | decode_latency = total_latency - prefill_latency 123 | decode_throughput = batch_size * (gen_len - 1) / max(decode_latency, 1e-10) 124 | num_generated_tokens = batch_size * gen_len 125 | total_throughput = num_generated_tokens / total_latency 126 | gpu_peak_mem = torch.cuda.max_memory_allocated(torch.device("cuda")) 127 | 128 | model_size = utils.model_bytes(config) 129 | cache_size = utils.cache_bytes(config, batch_size, prompt_len + gen_len) 130 | log_str = utils.write_gen_benchmark_log( 131 | model_size, 132 | cache_size, 133 | gpu_peak_mem, 134 | prefill_latency, 135 | prefill_throughput, 136 | decode_latency, 137 | decode_throughput, 138 | total_latency, 139 | total_throughput, 140 | ) 141 | print(log_str) 142 | 143 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True) 144 | show_str = "Outputs:\n" + 30 * "-" + "\n" 145 | for i in [0, (len(outputs) - 1) // 2, len(outputs) - 1]: 146 | show_str += f"{i}: {outputs[i]}\n" 147 | show_str += 30 * "-" + "\n" 148 | print(show_str) 149 | 150 | 151 | if __name__ == "__main__": 152 | parser = argparse.ArgumentParser() 153 | parser.add_argument("--model_name", "-m", type=str, default="meta-llama/Meta-Llama-3-8B", help="model name or path") 154 | parser.add_argument("--trials", type=int, default=3, help="Number of token generation iterations") 155 | parser.add_argument("--batch_size", type=int, default=1) 156 | parser.add_argument("--prompt_len", type=int, default=512, help="prompt length") 157 | parser.add_argument("--gen_len", type=int, default=32, help="number of tokens to generate") 158 | parser.add_argument("--local_rank", type=int, default=int(os.getenv("LOCAL_RANK", "0")), help="local rank for distributed inference") 159 | parser.add_argument("--pin_memory", type=int, default=0, help="whether to pinned CPU memory for ZeRO offloading") 160 | parser.add_argument("--quant_bits", type=int, default=16, help="model weight quantization bits; either 4 or 8") 161 | parser.add_argument("--quant_group_size", type=int, default=64, help="model weight 
quantization group size") 162 | parser.add_argument("--cache_dir", type=str, default="/scratch/yonghe", help="cache dir for model name") 163 | args = parser.parse_args() 164 | 165 | gc.collect() 166 | 167 | run_generation( 168 | args.model_name, 169 | args.trials, 170 | args.batch_size, 171 | args.prompt_len, 172 | args.gen_len, 173 | args.local_rank, 174 | args.pin_memory, 175 | args.quant_bits, 176 | args.quant_group_size, 177 | args.cache_dir, 178 | ) 179 | 180 | -------------------------------------------------------------------------------- /benchmarks/huggingface/hf-peft.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run Llama3 with huggingface 3 | 4 | Reference: 5 | https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/run_model.py 6 | """ 7 | 8 | import argparse 9 | import gc 10 | import itertools 11 | import os 12 | import time 13 | 14 | import datasets 15 | import torch 16 | from transformers import ( 17 | AutoConfig, 18 | AutoModelForCausalLM, 19 | AutoTokenizer, 20 | DataCollatorForLanguageModeling, 21 | ) 22 | from peft import LoraConfig, get_peft_model 23 | 24 | import sys 25 | sys.path.append("..") 26 | import utils 27 | 28 | 29 | def get_hf_model( 30 | model_name, 31 | pin_memory, 32 | quant_bits, 33 | quant_group_size, 34 | cache_dir, 35 | ): 36 | config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir) 37 | pin_memory = bool(args.pin_memory) 38 | dtype = torch.float16 39 | 40 | if quant_bits == 4: 41 | raise NotImplementedError() 42 | 43 | base_model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir, torch_dtype=dtype) 44 | lora_config = LoraConfig( 45 | r=64, 46 | lora_alpha=32, 47 | target_modules=["q_proj", "v_proj"], 48 | lora_dropout=0.01, 49 | bias="none", 50 | task_type="CAUSAL_LM", 51 | ) 52 | model = get_peft_model(base_model, lora_config) 53 | model.to(torch.cuda.current_device()) 54 | 55 | return model 56 | 57 | 58 | def run_peft( 59 | model_name, 60 | dataset_name, 61 | trials, 62 | batch_size, 63 | gradient_accumulation_steps, 64 | seq_len, 65 | local_rank, 66 | pin_memory, 67 | quant_bits, 68 | quant_group_size, 69 | cache_dir 70 | ): 71 | # Load tokenizer 72 | config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir) 73 | tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left', cache_dir=cache_dir) 74 | tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token 75 | 76 | print("load model") 77 | model = get_hf_model( 78 | model_name, 79 | pin_memory, 80 | quant_bits, 81 | quant_group_size, 82 | cache_dir, 83 | ) 84 | 85 | def prepare_alpaca(sample_raw): 86 | template = { 87 | "description": "A shorter template to experiment with.", 88 | "prompt_input": "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n", 89 | "prompt_no_input": "### Instruction:\n{instruction}\n\n### Response:\n", 90 | "response_split": "### Response:" 91 | } 92 | if len(sample_raw["input"]): 93 | sample_text = template["prompt_input"].format( 94 | instruction=sample_raw["instruction"], input=sample_raw["input"] 95 | ) 96 | else: 97 | sample_text = template["prompt_no_input"].format( 98 | instruction=sample_raw["instruction"] 99 | ) 100 | if len(sample_raw["output"]): 101 | sample_text += sample_raw["output"] 102 | sample_tokens = tokenizer(sample_text, padding='max_length', truncation=True, max_length=seq_len) 103 | return sample_tokens 104 | 105 | dataset = datasets.load_dataset(dataset_name, 
cache_dir=cache_dir) 106 | dataset = dataset.map(lambda sample_raw: prepare_alpaca(sample_raw), remove_columns=dataset["train"].column_names) 107 | dataloader = torch.utils.data.DataLoader( 108 | dataset["train"], shuffle=True, collate_fn=DataCollatorForLanguageModeling(tokenizer, mlm=False), 109 | batch_size=batch_size, pin_memory=pin_memory, 110 | ) 111 | dataloader_iter = itertools.cycle(iter(enumerate(dataloader))) 112 | 113 | optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4) 114 | 115 | # Run peft 116 | # print(f"benchmark, seq_len = {seq_len}, input_ids.shape = {inputs.input_ids.shape}") 117 | 118 | forward_timings = [] 119 | total_timings = [] 120 | for _ in range(trials): 121 | start = time.time() 122 | 123 | step, inputs = next(dataloader_iter) 124 | inputs.to(torch.cuda.current_device()) 125 | outputs = model(**inputs) 126 | forward_timings.append(time.time() - start) 127 | 128 | loss = outputs.loss 129 | loss.backward() 130 | if step % gradient_accumulation_steps == 0: 131 | optimizer.step() 132 | optimizer.zero_grad() 133 | 134 | end = time.time() 135 | total_timings.append(end - start) 136 | 137 | if local_rank != 0: 138 | return 139 | 140 | # Check lengths 141 | input_lens = [len(x) for x in inputs.input_ids] 142 | output_lens = [len(x) for x in outputs.logits] 143 | assert all(x == seq_len for x in input_lens) 144 | assert all(x == seq_len for x in output_lens) 145 | 146 | # Log output 147 | print(f"Summary:") 148 | print(f"total_timings = {total_timings}") 149 | print(f"forward_timings = {forward_timings}") 150 | total_latency = total_timings[-1] 151 | forward_latency = forward_timings[-1] 152 | backward_latency = total_latency - forward_latency 153 | 154 | total_throughput = batch_size / total_latency 155 | forward_throughput = batch_size / forward_latency 156 | backward_throughput = batch_size / backward_latency 157 | gpu_peak_mem = torch.cuda.max_memory_allocated(torch.device("cuda")) 158 | 159 | model_size = utils.model_bytes(config) 160 | log_str = utils.write_peft_benchmark_log( 161 | model_size, 162 | 0, 163 | gpu_peak_mem, 164 | forward_latency, 165 | forward_throughput, 166 | backward_latency, 167 | backward_throughput, 168 | total_latency, 169 | total_throughput, 170 | ) 171 | print(log_str) 172 | 173 | 174 | if __name__ == "__main__": 175 | parser = argparse.ArgumentParser() 176 | parser.add_argument("--model_name", "-m", type=str, default="meta-llama/Meta-Llama-3-8B", help="model name or path") 177 | parser.add_argument("--dataset_name", type=str, default="yahma/alpaca-cleaned", help="dataset name or path") 178 | parser.add_argument("--trials", type=int, default=5, help="Number of peft iterations") 179 | parser.add_argument("--batch_size", type=int, default=1) 180 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1) 181 | parser.add_argument("--seq_len", type=int, default=256, help="sequence length") 182 | parser.add_argument("--local_rank", type=int, default=int(os.getenv("LOCAL_RANK", "0")), help="local rank for distributed inference") 183 | parser.add_argument("--pin_memory", type=int, default=0, help="whether to pinned CPU memory for ZeRO offloading") 184 | parser.add_argument("--quant_bits", type=int, default=16, help="model weight quantization bits; either 4 or 8") 185 | parser.add_argument("--quant_group_size", type=int, default=64, help="model weight quantization group size") 186 | parser.add_argument("--cache_dir", type=str, default="/scratch/yonghe", help="cache dir for model name") 187 | args = parser.parse_args() 188 | 
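# Drop any unreferenced Python objects before the timed PEFT run.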
189 | gc.collect() 190 | 191 | run_peft( 192 | args.model_name, 193 | args.dataset_name, 194 | args.trials, 195 | args.batch_size, 196 | args.gradient_accumulation_steps, 197 | args.seq_len, 198 | args.local_rank, 199 | args.pin_memory, 200 | args.quant_bits, 201 | args.quant_group_size, 202 | args.cache_dir, 203 | ) 204 | 205 | -------------------------------------------------------------------------------- /benchmarks/fineinfer/fi-gen.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run Llama2 with FineInfer 3 | 4 | Reference: 5 | https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/run_model.py 6 | """ 7 | 8 | import argparse 9 | import copy 10 | import gc 11 | import inspect 12 | import os 13 | import time 14 | 15 | import torch 16 | import transformers 17 | from transformers import ( 18 | AutoConfig, 19 | AutoModelForCausalLM, 20 | AutoTokenizer, 21 | ) 22 | 23 | import sys 24 | sys.path.append("../..") 25 | from benchmarks import utils 26 | from fineinfer.engine import llm_engine 27 | 28 | 29 | def get_hf_model( 30 | model_name, 31 | pin_memory, 32 | quant_bits, 33 | quant_group_size, 34 | cache_dir, 35 | ): 36 | config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir) 37 | pin_memory = bool(args.pin_memory) 38 | dtype = torch.float16 39 | 40 | if quant_bits == 4: 41 | raise NotImplementedError() 42 | 43 | model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir, torch_dtype=dtype) 44 | model.to(torch.cuda.current_device()) 45 | model = model.eval() 46 | 47 | return model 48 | 49 | 50 | def run_generation( 51 | model_name, 52 | trials, 53 | batch_size, 54 | prompt_len, 55 | gen_len, 56 | local_rank, 57 | pin_memory, 58 | quant_bits, 59 | quant_group_size, 60 | cache_dir 61 | ): 62 | # Load tokenizer 63 | config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir) 64 | tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left', cache_dir=cache_dir) 65 | tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token 66 | 67 | print("load model") 68 | with torch.no_grad(): 69 | model = get_hf_model( 70 | model_name, 71 | pin_memory, 72 | quant_bits, 73 | quant_group_size, 74 | cache_dir, 75 | ) 76 | 77 | utils.add_model_hooks(model) 78 | 79 | prompts = ["Paris is the capital city of"] * batch_size 80 | input_tokens = tokenizer.batch_encode_plus(prompts, return_tensors="pt", 81 | padding="max_length", max_length=prompt_len) 82 | input_tokens.to(torch.cuda.current_device()) 83 | 84 | # Run generation 85 | print(f"benchmark, prompt_len = {prompt_len}, gen_len = {gen_len}, input_ids.shape = {input_tokens.input_ids.shape}") 86 | 87 | # Prepare 88 | prepare_output = llm_engine.prepare_inputs_and_config(self=model, 89 | **input_tokens, max_new_tokens=gen_len, do_sample=False) 90 | 91 | input_ids = copy.deepcopy(prepare_output.input_ids) 92 | model_kwargs = copy.deepcopy(prepare_output.model_kwargs) 93 | batch_meta = llm_engine.BatchMeta( 94 | prompt_lens = torch.full(size=(batch_size,), fill_value=prompt_len, 95 | dtype=torch.long, device=torch.cuda.current_device()), 96 | gen_lens = torch.full(size=(batch_size,), fill_value=gen_len, 97 | dtype=torch.long, device=torch.cuda.current_device()), 98 | cur_lens = torch.full(size=(batch_size,), fill_value=prompt_len, 99 | dtype=torch.long, device=torch.cuda.current_device()), 100 | ) 101 | unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=torch.cuda.current_device()) 102 | 
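# Continuous-batching loop: after the running batch has decoded a few tokens, a new
# request batch is prefilled and merged into it via llm_engine.add_new_request.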
this_peer_finished = False 103 | 104 | output_ids = [] 105 | prefill_timings = [] 106 | total_timings = [] 107 | 108 | start = time.time() 109 | with torch.no_grad(): 110 | model.stage = "prefill" 111 | model_kwargs = model._get_initial_cache_position(input_ids, model_kwargs) 112 | 113 | while True: 114 | if input_ids.shape[0] + len(output_ids) < trials * batch_size \ 115 | and torch.min(batch_meta.cur_lens) == prompt_len + 5: 116 | new_unfinished_sequences = torch.ones(prepare_output.input_ids.shape[0], 117 | dtype=torch.long, device=input_ids.device) 118 | new_input_ids = copy.deepcopy(prepare_output.input_ids) 119 | new_model_kwargs = copy.deepcopy(prepare_output.model_kwargs) 120 | 121 | prefill_timings.append(model.__duration__) 122 | model.stage = "prefill" 123 | new_model_kwargs = model._get_initial_cache_position(new_input_ids, new_model_kwargs) 124 | 125 | new_unfinished_sequences, new_input_ids, new_model_kwargs = llm_engine.generate_step( 126 | self=model, 127 | unfinished_sequences=new_unfinished_sequences, 128 | input_ids=new_input_ids, 129 | logits_processor=prepare_output.logits_processor, 130 | stopping_criteria=prepare_output.stopping_criteria, 131 | generation_config=prepare_output.generation_config, 132 | **new_model_kwargs 133 | ) 134 | 135 | unfinished_sequences, input_ids, model_kwargs = llm_engine.add_new_request( 136 | unfinished_sequences = unfinished_sequences, 137 | input_ids = input_ids, 138 | model_kwargs = model_kwargs, 139 | new_unfinished_sequences = new_unfinished_sequences, 140 | new_input_ids = new_input_ids, 141 | new_model_kwargs = new_model_kwargs, 142 | ) 143 | 144 | batch_meta.prompt_lens = torch.cat(tensors=(batch_meta.prompt_lens, \ 145 | torch.full(size=(batch_size,), fill_value=prompt_len, \ 146 | dtype=torch.long, device=torch.cuda.current_device())), dim=0) 147 | batch_meta.gen_lens = torch.cat(tensors=(batch_meta.gen_lens, \ 148 | torch.full(size=(batch_size,), fill_value=gen_len, \ 149 | device=torch.cuda.current_device(), dtype=torch.long)), dim=0) 150 | batch_meta.cur_lens = torch.cat(tensors=(batch_meta.cur_lens, \ 151 | torch.full(size=(batch_size,), fill_value=prompt_len + 1, \ 152 | device=torch.cuda.current_device(), dtype=torch.long)), dim=0) 153 | 154 | unfinished_sequences, input_ids, model_kwargs = llm_engine.generate_step( 155 | self=model, 156 | unfinished_sequences=unfinished_sequences, 157 | input_ids=input_ids, 158 | logits_processor=prepare_output.logits_processor, 159 | stopping_criteria=prepare_output.stopping_criteria, 160 | generation_config=prepare_output.generation_config, 161 | **model_kwargs 162 | ) 163 | 164 | batch_meta.cur_lens += 1 165 | 166 | this_peer_finished = unfinished_sequences.max() == 0 167 | 168 | if not model._has_unfinished_sequences(this_peer_finished, prepare_output.synced_gpus, input_ids.device): 169 | unfinished_sequences, input_ids, model_kwargs, batch_meta, output_ids = llm_engine.remove_old_request( 170 | unfinished_sequences=unfinished_sequences, 171 | input_ids=input_ids, 172 | model_kwargs=model_kwargs, 173 | batch_meta=batch_meta, 174 | output_ids=output_ids, 175 | ) 176 | 177 | if len(output_ids) >= trials * batch_size: 178 | break 179 | 180 | end = time.time() 181 | total_timings.append(end - start) 182 | prefill_timings.append(model.__duration__) 183 | 184 | if local_rank != 0: 185 | return 186 | 187 | utils.remove_model_hooks(model) 188 | # Check lengths 189 | input_lens = [len(x) for x in input_tokens.input_ids] 190 | output_lens = [len(x) for x in output_ids] 191 | assert all(x == 
prompt_len for x in input_lens) 192 | assert all(x == prompt_len + gen_len for x in output_lens) 193 | 194 | # Log output 195 | print(f"Summary:") 196 | print(f"total_timings = {total_timings}") 197 | print(f"prefill_timings = {prefill_timings}") 198 | total_latency = total_timings[-1] 199 | prefill_latency = prefill_timings[-1] 200 | 201 | prefill_throughput = batch_size * prompt_len / prefill_latency 202 | decode_latency = total_latency - sum(prefill_timings) 203 | decode_throughput = trials * batch_size * (gen_len - 1) / max(decode_latency, 1e-10) 204 | num_generated_tokens = trials * batch_size * gen_len 205 | total_throughput = num_generated_tokens / total_latency 206 | gpu_peak_mem = torch.cuda.max_memory_allocated(torch.cuda.current_device()) 207 | 208 | model_size = utils.model_bytes(config) 209 | cache_size = utils.cache_bytes(config, batch_size, prompt_len + gen_len) 210 | log_str = utils.write_gen_benchmark_log( 211 | model_size, 212 | cache_size, 213 | gpu_peak_mem, 214 | prefill_latency, 215 | prefill_throughput, 216 | decode_latency, 217 | decode_throughput, 218 | total_latency, 219 | total_throughput, 220 | ) 221 | print(log_str) 222 | 223 | outputs = [tokenizer.decode(output, skip_special_tokens=True) for output in output_ids] 224 | show_str = "Outputs:\n" + 30 * "-" + "\n" 225 | for i in [0, (len(outputs) - 1) // 2, len(outputs) - 1]: 226 | show_str += f"{i}: {outputs[i]}\n" 227 | show_str += 30 * "-" + "\n" 228 | print(show_str) 229 | 230 | 231 | if __name__ == "__main__": 232 | parser = argparse.ArgumentParser() 233 | parser.add_argument("--model_name", "-m", type=str, default="meta-llama/Meta-Llama-3-8B", help="model name or path") 234 | parser.add_argument("--trials", type=int, default=3, help="Number of token generation iterations") 235 | parser.add_argument("--batch_size", type=int, default=1) 236 | parser.add_argument("--prompt_len", type=int, default=512, help="prompt length") 237 | parser.add_argument("--gen_len", type=int, default=32, help="number of tokens to generate") 238 | parser.add_argument("--local_rank", type=int, default=int(os.getenv("LOCAL_RANK", "0")), help="local rank for distributed inference") 239 | parser.add_argument("--pin_memory", type=int, default=0, help="whether to pinned CPU memory for ZeRO offloading") 240 | parser.add_argument("--quant_bits", type=int, default=16, help="model weight quantization bits; either 4 or 8") 241 | parser.add_argument("--quant_group_size", type=int, default=64, help="model weight quantization group size") 242 | parser.add_argument("--cache_dir", type=str, default="/scratch/yonghe", help="cache dir for model name") 243 | args = parser.parse_args() 244 | 245 | gc.collect() 246 | 247 | run_generation( 248 | args.model_name, 249 | args.trials, 250 | args.batch_size, 251 | args.prompt_len, 252 | args.gen_len, 253 | args.local_rank, 254 | args.pin_memory, 255 | args.quant_bits, 256 | args.quant_group_size, 257 | args.cache_dir, 258 | ) 259 | -------------------------------------------------------------------------------- /benchmarks/fineinfer/baseline-ht.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run Llama2 with huggingface 3 | Reference: 4 | https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/run_model.py 5 | """ 6 | 7 | import argparse 8 | import copy 9 | import gc 10 | import itertools 11 | import os 12 | import time 13 | 14 | import datasets 15 | import torch 16 | from transformers import ( 17 | AutoConfig, 18 | 
AutoModelForCausalLM, 19 | AutoTokenizer, 20 | DataCollatorForLanguageModeling, 21 | ) 22 | from peft import LoraConfig, get_peft_model 23 | 24 | import sys 25 | sys.path.append("../..") 26 | from benchmarks import utils 27 | from fineinfer.engine import llm_engine 28 | 29 | 30 | def get_hf_model( 31 | model_name, 32 | adapter_names, 33 | pin_memory, 34 | quant_bits, 35 | quant_group_size, 36 | cache_dir, 37 | ): 38 | config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir) 39 | pin_memory = bool(args.pin_memory) 40 | dtype = torch.float16 41 | 42 | if quant_bits == 4: 43 | raise NotImplementedError() 44 | 45 | base_model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir, torch_dtype=dtype) 46 | lora_config = LoraConfig( 47 | r=64, 48 | lora_alpha=32, 49 | target_modules=["q_proj", "v_proj"], 50 | lora_dropout=0.01, 51 | bias="none", 52 | task_type="CAUSAL_LM", 53 | ) 54 | model = get_peft_model(base_model, lora_config, adapter_name=adapter_names[0]) 55 | for idx, name in enumerate(adapter_names): 56 | if idx: 57 | model.add_adapter(name, lora_config) 58 | model.to(torch.cuda.current_device()) 59 | 60 | return model 61 | 62 | 63 | def run_ht( 64 | model_name, 65 | adapter_size, 66 | dataset_name, 67 | batch_size, 68 | prompt_len, 69 | gen_len, 70 | gradient_accumulation_steps, 71 | seq_len, 72 | local_rank, 73 | pin_memory, 74 | quant_bits, 75 | quant_group_size, 76 | cache_dir 77 | ): 78 | # Load tokenizer 79 | config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir) 80 | tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left', cache_dir=cache_dir) 81 | tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token 82 | 83 | print("load model") 84 | adapter_names = ["adapter_" + str(i) for i in range(adapter_size)] 85 | model = get_hf_model( 86 | model_name, 87 | adapter_names, 88 | pin_memory, 89 | quant_bits, 90 | quant_group_size, 91 | cache_dir, 92 | ) 93 | 94 | prompts = ["Paris is the capital city of"] * batch_size 95 | gen_inputs = tokenizer.batch_encode_plus(prompts, return_tensors="pt", 96 | padding="max_length", max_length=prompt_len) 97 | gen_inputs.to(torch.cuda.current_device()) 98 | prepare_output = llm_engine.prepare_inputs_and_config(self=model, 99 | **gen_inputs, max_new_tokens=gen_len, do_sample=False) 100 | 101 | def prepare_alpaca(sample_raw): 102 | template = { 103 | "description": "A shorter template to experiment with.", 104 | "prompt_input": "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n", 105 | "prompt_no_input": "### Instruction:\n{instruction}\n\n### Response:\n", 106 | "response_split": "### Response:" 107 | } 108 | if len(sample_raw["input"]): 109 | sample_text = template["prompt_input"].format( 110 | instruction=sample_raw["instruction"], input=sample_raw["input"] 111 | ) 112 | else: 113 | sample_text = template["prompt_no_input"].format( 114 | instruction=sample_raw["instruction"] 115 | ) 116 | if len(sample_raw["output"]): 117 | sample_text += sample_raw["output"] 118 | sample_tokens = tokenizer(sample_text, padding='max_length', truncation=True, max_length=seq_len) 119 | return sample_tokens 120 | 121 | dataset = datasets.load_dataset(dataset_name, cache_dir=cache_dir) 122 | dataset = dataset.map(lambda sample_raw: prepare_alpaca(sample_raw), remove_columns=dataset["train"].column_names) 123 | dataloader = torch.utils.data.DataLoader( 124 | dataset["train"], shuffle=True, collate_fn=DataCollatorForLanguageModeling(tokenizer, mlm=False), 125 | batch_size=batch_size, 
pin_memory=pin_memory, 126 | ) 127 | dataloader_iter = itertools.cycle(iter(enumerate(dataloader))) 128 | 129 | optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4) 130 | 131 | # Run heterogeneous workload 132 | print(f"benchmark, seq_len = {seq_len}, prompt_len = {prompt_len}, gen_len = {gen_len}") 133 | 134 | total_latency = 60.0 135 | req_gap_bound = 2.0 136 | ht_workloads = utils.get_ht_workloads(total_latency, req_gap_bound) 137 | gen_trials = len(ht_workloads) 138 | cursor = 0 139 | 140 | gen_outputs = [] 141 | gen_timings = [] 142 | peft_timings = [] 143 | start = time.time() 144 | while time.time() - start < total_latency: 145 | if cursor < gen_trials and ht_workloads[cursor] <= time.time() - start: 146 | cursor += 1 147 | with torch.no_grad(): 148 | input_ids = copy.deepcopy(prepare_output.input_ids) 149 | model_kwargs = copy.deepcopy(prepare_output.model_kwargs) 150 | model_kwargs = model._get_initial_cache_position(input_ids, model_kwargs) 151 | batch_meta = llm_engine.BatchMeta( 152 | prompt_lens = torch.full(size=(batch_size,), fill_value=prompt_len, 153 | dtype=torch.long, device=torch.cuda.current_device()), 154 | gen_lens = torch.full(size=(batch_size,), fill_value=gen_len, 155 | dtype=torch.long, device=torch.cuda.current_device()), 156 | cur_lens = torch.full(size=(batch_size,), fill_value=prompt_len, 157 | dtype=torch.long, device=torch.cuda.current_device()), 158 | ) 159 | unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=torch.cuda.current_device()) 160 | 161 | while True: 162 | unfinished_sequences, input_ids, model_kwargs = llm_engine.generate_step( 163 | self=model, 164 | unfinished_sequences=unfinished_sequences, 165 | input_ids=input_ids, 166 | logits_processor=prepare_output.logits_processor, 167 | stopping_criteria=prepare_output.stopping_criteria, 168 | generation_config=prepare_output.generation_config, 169 | **model_kwargs 170 | ) 171 | batch_meta.cur_lens += 1 172 | 173 | this_peer_finished = unfinished_sequences.max() == 0 174 | 175 | if not model._has_unfinished_sequences(this_peer_finished, prepare_output.synced_gpus, input_ids.device): 176 | gen_timings.append(time.time() - start - ht_workloads[int(len(gen_outputs) / batch_size)]) 177 | unfinished_sequences, input_ids, model_kwargs, batch_meta, output_ids = llm_engine.remove_old_request( 178 | unfinished_sequences=unfinished_sequences, 179 | input_ids=input_ids, 180 | model_kwargs=model_kwargs, 181 | batch_meta=batch_meta, 182 | output_ids=gen_outputs, 183 | ) 184 | 185 | if batch_meta.cur_lens.shape[0] == 0: 186 | break 187 | 188 | if cursor < gen_trials and ht_workloads[cursor] <= time.time() - start: 189 | cursor += 1 190 | new_unfinished_sequences = torch.ones(prepare_output.input_ids.shape[0], 191 | dtype=torch.long, device=input_ids.device) 192 | new_input_ids = copy.deepcopy(prepare_output.input_ids) 193 | new_model_kwargs = copy.deepcopy(prepare_output.model_kwargs) 194 | new_model_kwargs = model._get_initial_cache_position(new_input_ids, new_model_kwargs) 195 | 196 | new_unfinished_sequences, new_input_ids, new_model_kwargs = llm_engine.generate_step( 197 | self=model, 198 | unfinished_sequences=new_unfinished_sequences, 199 | input_ids=new_input_ids, 200 | logits_processor=prepare_output.logits_processor, 201 | stopping_criteria=prepare_output.stopping_criteria, 202 | generation_config=prepare_output.generation_config, 203 | **new_model_kwargs 204 | ) 205 | 206 | unfinished_sequences, input_ids, model_kwargs = llm_engine.add_new_request( 207 | 
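# fold the freshly prefilled request into the batch that is already decoding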
unfinished_sequences = unfinished_sequences, 208 | input_ids = input_ids, 209 | model_kwargs = model_kwargs, 210 | new_unfinished_sequences = new_unfinished_sequences, 211 | new_input_ids = new_input_ids, 212 | new_model_kwargs = new_model_kwargs, 213 | ) 214 | 215 | batch_meta.prompt_lens = torch.cat(tensors=(batch_meta.prompt_lens, \ 216 | torch.full(size=(batch_size,), fill_value=prompt_len, \ 217 | dtype=torch.long, device=torch.cuda.current_device())), dim=0) 218 | batch_meta.gen_lens = torch.cat(tensors=(batch_meta.gen_lens, \ 219 | torch.full(size=(batch_size,), fill_value=gen_len, \ 220 | device=torch.cuda.current_device(), dtype=torch.long)), dim=0) 221 | batch_meta.cur_lens = torch.cat(tensors=(batch_meta.cur_lens, \ 222 | torch.full(size=(batch_size,), fill_value=prompt_len + 1, \ 223 | device=torch.cuda.current_device(), dtype=torch.long)), dim=0) 224 | 225 | peft_start_time = time.time() 226 | step, peft_inputs = next(dataloader_iter) 227 | peft_inputs.to(torch.cuda.current_device()) 228 | model.set_adapter(adapter_names[0]) 229 | peft_outputs = model(**peft_inputs) 230 | 231 | loss = peft_outputs.loss 232 | loss.backward() 233 | if step % gradient_accumulation_steps == 0: 234 | optimizer.step() 235 | optimizer.zero_grad() 236 | peft_timings.append(time.time() - peft_start_time) 237 | 238 | total_latency = time.time() - start 239 | 240 | if local_rank != 0: 241 | return 242 | 243 | # Check lengths 244 | gen_input_lens = [len(x) for x in gen_inputs.input_ids] 245 | gen_output_lens = [len(x) for x in gen_outputs] 246 | assert all(x == prompt_len for x in gen_input_lens) 247 | assert all(x == prompt_len + gen_len for x in gen_output_lens) 248 | peft_input_lens = [len(x) for x in peft_inputs.input_ids] 249 | peft_output_lens = [len(x) for x in peft_outputs.logits] 250 | assert all(x == seq_len for x in peft_input_lens) 251 | assert all(x == seq_len for x in peft_output_lens) 252 | 253 | # Log output 254 | print(f"Summary:") 255 | print(f"gen_timings = {gen_timings[-3:]}") 256 | print(f"peft_timings = {peft_timings[-3:]}") 257 | 258 | gen_exec_total_latency = total_latency - sum(peft_timings) 259 | gen_exec_throughput = gen_trials * batch_size * gen_len / gen_exec_total_latency 260 | 261 | peft_trials = len(peft_timings) 262 | peft_total_latency = sum(peft_timings) 263 | peft_throughput = peft_trials * batch_size / peft_total_latency 264 | 265 | gpu_peak_mem = torch.cuda.max_memory_allocated(torch.device("cuda")) 266 | model_size = utils.model_bytes(config) 267 | 268 | log_str = utils.write_ht_benchmark_log( 269 | model_size, 270 | 0, 271 | gpu_peak_mem, 272 | gen_trials, 273 | gen_exec_total_latency, 274 | gen_exec_throughput, 275 | peft_trials, 276 | peft_total_latency, 277 | peft_throughput, 278 | total_latency, 279 | ) 280 | print(log_str) 281 | 282 | 283 | if __name__ == "__main__": 284 | parser = argparse.ArgumentParser() 285 | parser.add_argument("--model_name", "-m", type=str, default="meta-llama/Meta-Llama-3-8B", help="model name or path") 286 | parser.add_argument("--adapter_size", type=int, default=2, help="lora adapters swapping") 287 | parser.add_argument("--dataset_name", type=str, default="yahma/alpaca-cleaned", help="dataset name or path") 288 | parser.add_argument("--batch_size", type=int, default=1) 289 | parser.add_argument("--prompt_len", type=int, default=512, help="prompt length") 290 | parser.add_argument("--gen_len", type=int, default=32, help="number of tokens to generate") 291 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1) 292 
| parser.add_argument("--seq_len", type=int, default=256, help="sequence length") 293 | parser.add_argument("--local_rank", type=int, default=int(os.getenv("LOCAL_RANK", "0")), help="local rank for distributed inference") 294 | parser.add_argument("--pin_memory", type=int, default=0, help="whether to pinned CPU memory for ZeRO offloading") 295 | parser.add_argument("--quant_bits", type=int, default=16, help="model weight quantization bits; either 4 or 8") 296 | parser.add_argument("--quant_group_size", type=int, default=64, help="model weight quantization group size") 297 | parser.add_argument("--cache_dir", type=str, default="/scratch/yonghe", help="cache dir for model name") 298 | args = parser.parse_args() 299 | 300 | gc.collect() 301 | 302 | run_ht( 303 | args.model_name, 304 | args.adapter_size, 305 | args.dataset_name, 306 | args.batch_size, 307 | args.prompt_len, 308 | args.gen_len, 309 | args.gradient_accumulation_steps, 310 | args.seq_len, 311 | args.local_rank, 312 | args.pin_memory, 313 | args.quant_bits, 314 | args.quant_group_size, 315 | args.cache_dir, 316 | ) 317 | -------------------------------------------------------------------------------- /benchmarks/fineinfer/fi-ht.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run Llama2 with huggingface 3 | 4 | Reference: 5 | https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/run_model.py 6 | """ 7 | 8 | import argparse 9 | import copy 10 | import gc 11 | import itertools 12 | import os 13 | import time 14 | 15 | import datasets 16 | import torch 17 | from transformers import ( 18 | AutoConfig, 19 | AutoModelForCausalLM, 20 | AutoTokenizer, 21 | DataCollatorForLanguageModeling, 22 | ) 23 | from peft import LoraConfig, get_peft_model 24 | 25 | import sys 26 | sys.path.append("../..") 27 | from benchmarks import utils 28 | from fineinfer.engine import llm_engine 29 | 30 | 31 | def get_fg_model( 32 | model_name, 33 | adapter_names, 34 | pin_memory, 35 | quant_bits, 36 | quant_group_size, 37 | cache_dir, 38 | ): 39 | config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir) 40 | pin_memory = bool(args.pin_memory) 41 | dtype = torch.float16 42 | 43 | if quant_bits == 4: 44 | raise NotImplementedError() 45 | 46 | base_model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir, torch_dtype=dtype) 47 | lora_config = LoraConfig( 48 | r=64, 49 | lora_alpha=32, 50 | target_modules=["q_proj", "v_proj"], 51 | lora_dropout=0.01, 52 | bias="none", 53 | task_type="CAUSAL_LM", 54 | ) 55 | model = get_peft_model(base_model, lora_config, adapter_name=adapter_names[0]) 56 | for idx, name in enumerate(adapter_names): 57 | if idx: 58 | model.add_adapter(name, lora_config) 59 | model.to(torch.cuda.current_device()) 60 | 61 | return model 62 | 63 | 64 | def run_ht( 65 | model_name, 66 | adapter_size, 67 | dataset_name, 68 | batch_size, 69 | prompt_len, 70 | gen_len, 71 | gradient_accumulation_steps, 72 | seq_len, 73 | local_rank, 74 | pin_memory, 75 | quant_bits, 76 | quant_group_size, 77 | cache_dir 78 | ): 79 | # Load tokenizer 80 | config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir) 81 | tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left', cache_dir=cache_dir) 82 | tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token 83 | 84 | print("load model") 85 | adapter_names = ["adapter_" + str(i) for i in range(adapter_size)] 86 | model = get_fg_model( 87 | model_name, 88 | adapter_names, 
89 | pin_memory, 90 | quant_bits, 91 | quant_group_size, 92 | cache_dir, 93 | ) 94 | 95 | prompts = ["Paris is the capital city of"] * batch_size 96 | gen_inputs = tokenizer.batch_encode_plus(prompts, return_tensors="pt", 97 | padding="max_length", max_length=prompt_len) 98 | gen_inputs.to(torch.cuda.current_device()) 99 | prepare_output = llm_engine.prepare_inputs_and_config(self=model, 100 | **gen_inputs, max_new_tokens=gen_len, do_sample=False) 101 | 102 | def prepare_alpaca(sample_raw): 103 | template = { 104 | "description": "A shorter template to experiment with.", 105 | "prompt_input": "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n", 106 | "prompt_no_input": "### Instruction:\n{instruction}\n\n### Response:\n", 107 | "response_split": "### Response:" 108 | } 109 | if len(sample_raw["input"]): 110 | sample_text = template["prompt_input"].format( 111 | instruction=sample_raw["instruction"], input=sample_raw["input"] 112 | ) 113 | else: 114 | sample_text = template["prompt_no_input"].format( 115 | instruction=sample_raw["instruction"] 116 | ) 117 | if len(sample_raw["output"]): 118 | sample_text += sample_raw["output"] 119 | sample_tokens = tokenizer(sample_text, padding='max_length', truncation=True, max_length=seq_len) 120 | return sample_tokens 121 | 122 | dataset = datasets.load_dataset(dataset_name, cache_dir=cache_dir) 123 | dataset = dataset.map(lambda sample_raw: prepare_alpaca(sample_raw), remove_columns=dataset["train"].column_names) 124 | dataloader = torch.utils.data.DataLoader( 125 | dataset["train"], shuffle=True, collate_fn=DataCollatorForLanguageModeling(tokenizer, mlm=False), 126 | batch_size=batch_size, pin_memory=pin_memory, 127 | ) 128 | dataloader_iter = itertools.cycle(iter(enumerate(dataloader))) 129 | 130 | optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4) 131 | 132 | # Run heterogeneous workload 133 | print(f"benchmark, seq_len = {seq_len}, prompt_len = {prompt_len}, gen_len = {gen_len}") 134 | 135 | total_latency = 60.0 136 | req_gap_bound = 2.0 137 | ht_workloads = utils.get_ht_workloads(total_latency, req_gap_bound) 138 | gen_trials = len(ht_workloads) 139 | cursor = 0 140 | 141 | deferral_bound = 1.0 142 | 143 | gen_outputs = [] 144 | gen_timings = [] 145 | peft_timings = [] 146 | start = time.time() 147 | while time.time() - start < total_latency: 148 | if cursor < gen_trials and ht_workloads[cursor] + deferral_bound <= time.time() - start: 149 | cursor += 1 150 | with torch.no_grad(): 151 | input_ids = copy.deepcopy(prepare_output.input_ids) 152 | model_kwargs = copy.deepcopy(prepare_output.model_kwargs) 153 | model_kwargs = model._get_initial_cache_position(input_ids, model_kwargs) 154 | batch_meta = llm_engine.BatchMeta( 155 | prompt_lens = torch.full(size=(batch_size,), fill_value=prompt_len, 156 | dtype=torch.long, device=torch.cuda.current_device()), 157 | gen_lens = torch.full(size=(batch_size,), fill_value=gen_len, 158 | dtype=torch.long, device=torch.cuda.current_device()), 159 | cur_lens = torch.full(size=(batch_size,), fill_value=prompt_len, 160 | dtype=torch.long, device=torch.cuda.current_device()), 161 | ) 162 | unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=torch.cuda.current_device()) 163 | 164 | while True: 165 | unfinished_sequences, input_ids, model_kwargs = llm_engine.generate_step( 166 | self=model, 167 | unfinished_sequences=unfinished_sequences, 168 | input_ids=input_ids, 169 | logits_processor=prepare_output.logits_processor, 170 | 
stopping_criteria=prepare_output.stopping_criteria, 171 | generation_config=prepare_output.generation_config, 172 | **model_kwargs 173 | ) 174 | batch_meta.cur_lens += 1 175 | 176 | this_peer_finished = unfinished_sequences.max() == 0 177 | 178 | if not model._has_unfinished_sequences(this_peer_finished, prepare_output.synced_gpus, input_ids.device): 179 | gen_timings.append(time.time() - start - ht_workloads[int(len(gen_outputs) / batch_size)]) 180 | unfinished_sequences, input_ids, model_kwargs, batch_meta, output_ids = llm_engine.remove_old_request( 181 | unfinished_sequences=unfinished_sequences, 182 | input_ids=input_ids, 183 | model_kwargs=model_kwargs, 184 | batch_meta=batch_meta, 185 | output_ids=gen_outputs, 186 | ) 187 | 188 | if batch_meta.cur_lens.shape[0] == 0: 189 | break 190 | 191 | if cursor < gen_trials and ht_workloads[cursor] <= time.time() - start: 192 | cursor += 1 193 | new_unfinished_sequences = torch.ones(prepare_output.input_ids.shape[0], 194 | dtype=torch.long, device=input_ids.device) 195 | new_input_ids = copy.deepcopy(prepare_output.input_ids) 196 | new_model_kwargs = copy.deepcopy(prepare_output.model_kwargs) 197 | new_model_kwargs = model._get_initial_cache_position(new_input_ids, new_model_kwargs) 198 | 199 | new_unfinished_sequences, new_input_ids, new_model_kwargs = llm_engine.generate_step( 200 | self=model, 201 | unfinished_sequences=new_unfinished_sequences, 202 | input_ids=new_input_ids, 203 | logits_processor=prepare_output.logits_processor, 204 | stopping_criteria=prepare_output.stopping_criteria, 205 | generation_config=prepare_output.generation_config, 206 | **new_model_kwargs 207 | ) 208 | 209 | unfinished_sequences, input_ids, model_kwargs = llm_engine.add_new_request( 210 | unfinished_sequences = unfinished_sequences, 211 | input_ids = input_ids, 212 | model_kwargs = model_kwargs, 213 | new_unfinished_sequences = new_unfinished_sequences, 214 | new_input_ids = new_input_ids, 215 | new_model_kwargs = new_model_kwargs, 216 | ) 217 | 218 | batch_meta.prompt_lens = torch.cat(tensors=(batch_meta.prompt_lens, \ 219 | torch.full(size=(batch_size,), fill_value=prompt_len, \ 220 | dtype=torch.long, device=torch.cuda.current_device())), dim=0) 221 | batch_meta.gen_lens = torch.cat(tensors=(batch_meta.gen_lens, \ 222 | torch.full(size=(batch_size,), fill_value=gen_len, \ 223 | device=torch.cuda.current_device(), dtype=torch.long)), dim=0) 224 | batch_meta.cur_lens = torch.cat(tensors=(batch_meta.cur_lens, \ 225 | torch.full(size=(batch_size,), fill_value=prompt_len + 1, \ 226 | device=torch.cuda.current_device(), dtype=torch.long)), dim=0) 227 | 228 | peft_start_time = time.time() 229 | step, peft_inputs = next(dataloader_iter) 230 | peft_inputs.to(torch.cuda.current_device()) 231 | model.set_adapter(adapter_names[0]) 232 | peft_outputs = model(**peft_inputs) 233 | 234 | loss = peft_outputs.loss 235 | loss.backward() 236 | if step % gradient_accumulation_steps == 0: 237 | optimizer.step() 238 | optimizer.zero_grad() 239 | peft_timings.append(time.time() - peft_start_time) 240 | 241 | total_latency = time.time() - start 242 | 243 | if local_rank != 0: 244 | return 245 | 246 | # Check lengths 247 | gen_input_lens = [len(x) for x in gen_inputs.input_ids] 248 | gen_output_lens = [len(x) for x in gen_outputs] 249 | assert all(x == prompt_len for x in gen_input_lens) 250 | assert all(x == prompt_len + gen_len for x in gen_output_lens) 251 | peft_input_lens = [len(x) for x in peft_inputs.input_ids] 252 | peft_output_lens = [len(x) for x in peft_outputs.logits] 
253 | assert all(x == seq_len for x in peft_input_lens) 254 | assert all(x == seq_len for x in peft_output_lens) 255 | 256 | # Log output 257 | print("Summary:") 258 | print(f"gen_timings = {gen_timings[-3:]}") 259 | print(f"peft_timings = {peft_timings[-3:]}") 260 | 261 | gen_total_latency = total_latency - sum(peft_timings) 262 | gen_exec_throughput = gen_trials * batch_size * gen_len / gen_total_latency 263 | 264 | peft_trials = len(peft_timings) 265 | peft_total_latency = sum(peft_timings) 266 | peft_throughput = peft_trials * batch_size / peft_total_latency 267 | 268 | gpu_peak_mem = torch.cuda.max_memory_allocated(torch.device("cuda")) 269 | model_size = utils.model_bytes(config) 270 | 271 | log_str = utils.write_ht_benchmark_log( 272 | model_size, 273 | 0, 274 | gpu_peak_mem, 275 | gen_trials, 276 | gen_total_latency, 277 | gen_exec_throughput, 278 | peft_trials, 279 | peft_total_latency, 280 | peft_throughput, 281 | total_latency, 282 | ) 283 | print(log_str) 284 | 285 | 286 | if __name__ == "__main__": 287 | parser = argparse.ArgumentParser() 288 | parser.add_argument("--model_name", "-m", type=str, default="meta-llama/Meta-Llama-3-8B", help="model name or path") 289 | parser.add_argument("--adapter_size", type=int, default=2, help="number of LoRA adapters to load") 290 | parser.add_argument("--dataset_name", type=str, default="yahma/alpaca-cleaned", help="dataset name or path") 291 | parser.add_argument("--batch_size", type=int, default=1) 292 | parser.add_argument("--prompt_len", type=int, default=512, help="prompt length") 293 | parser.add_argument("--gen_len", type=int, default=32, help="number of tokens to generate") 294 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1) 295 | parser.add_argument("--seq_len", type=int, default=256, help="sequence length") 296 | parser.add_argument("--local_rank", type=int, default=int(os.getenv("LOCAL_RANK", "0")), help="local rank for distributed inference") 297 | parser.add_argument("--pin_memory", type=int, default=0, help="whether to pin CPU memory for ZeRO offloading") 298 | parser.add_argument("--quant_bits", type=int, default=16, help="model weight quantization bits; 4 or 8 (16 = no quantization)") 299 | parser.add_argument("--quant_group_size", type=int, default=64, help="model weight quantization group size") 300 | parser.add_argument("--cache_dir", type=str, default="/scratch/yonghe", help="cache directory for model and dataset downloads") 301 | args = parser.parse_args() 302 | 303 | gc.collect() 304 | 305 | run_ht( 306 | args.model_name, 307 | args.adapter_size, 308 | args.dataset_name, 309 | args.batch_size, 310 | args.prompt_len, 311 | args.gen_len, 312 | args.gradient_accumulation_steps, 313 | args.seq_len, 314 | args.local_rank, 315 | args.pin_memory, 316 | args.quant_bits, 317 | args.quant_group_size, 318 | args.cache_dir, 319 | ) 320 | 321 | 322 | -------------------------------------------------------------------------------- /fineinfer/engine/llm_engine.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from dataclasses import dataclass 3 | import inspect 4 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 5 | import warnings  # needed by prepare_inputs_and_config below 6 | import torch 7 | from torch import nn  # needed by generate_step for sampling 8 | from transformers import PreTrainedModel 9 | from transformers.cache_utils import StaticCache 10 | from transformers.utils import ModelOutput, is_torchdynamo_compiling, logging 11 | from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled 12 | from transformers.generation.configuration_utils import GenerationConfig 13 |
from transformers.generation.logits_process import ( 14 | LogitsProcessorList 15 | ) 16 | from transformers.generation.stopping_criteria import ( 17 | StoppingCriteriaList 18 | ) 19 | import torch.distributed as dist  # needed by the ZeRO-3 world-size check below 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | NEED_SETUP_CACHE_CLASSES_MAPPING = { 24 | "static": StaticCache, 25 | } 26 | 27 | 28 | @dataclass 29 | class PrepareOutput(ModelOutput): 30 | input_ids: Optional[torch.LongTensor] = None 31 | logits_processor: Optional[LogitsProcessorList] = None 32 | logits_warper: Optional[LogitsProcessorList] = None 33 | stopping_criteria: Optional[StoppingCriteriaList] = None 34 | generation_config: Optional[GenerationConfig] = None 35 | synced_gpus: Optional[bool] = None 36 | streamer: Optional["BaseStreamer"] = None 37 | model_kwargs: Optional[Dict[str, Any]] = None 38 | 39 | @dataclass 40 | class BatchMeta(ModelOutput): 41 | prompt_lens: Optional[torch.LongTensor] = None 42 | gen_lens: Optional[torch.LongTensor] = None 43 | cur_lens: Optional[torch.LongTensor] = None 44 | 45 | @torch.no_grad() 46 | def prepare_inputs_and_config( 47 | self: PreTrainedModel, 48 | inputs: Optional[torch.Tensor] = None, 49 | generation_config: Optional[GenerationConfig] = None, 50 | logits_processor: Optional[LogitsProcessorList] = None, 51 | stopping_criteria: Optional[StoppingCriteriaList] = None, 52 | prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, 53 | synced_gpus: Optional[bool] = None, 54 | assistant_model: Optional["PreTrainedModel"] = None, 55 | streamer: Optional["BaseStreamer"] = None, 56 | negative_prompt_ids: Optional[torch.Tensor] = None, 57 | negative_prompt_attention_mask: Optional[torch.Tensor] = None, 58 | **kwargs, 59 | ) -> PrepareOutput: 60 | 61 | # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call 62 | self._validate_model_class() 63 | tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria 64 | generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs) 65 | self._validate_model_kwargs(model_kwargs.copy()) 66 | 67 | # 2. Set generation parameters if not already defined 68 | if synced_gpus is None: 69 | if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1: 70 | synced_gpus = True 71 | else: 72 | synced_gpus = False 73 | 74 | logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() 75 | stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() 76 | 77 | accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) 78 | requires_attention_mask = "encoder_outputs" not in model_kwargs 79 | kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None 80 | 81 | # 3. Define model inputs 82 | inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( 83 | inputs, generation_config.bos_token_id, model_kwargs 84 | ) 85 | batch_size = inputs_tensor.shape[0] 86 | 87 | device = inputs_tensor.device 88 | self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device) 89 | 90 | # decoder-only models must use left-padding for batched generation. 91 | if not self.config.is_encoder_decoder and not is_torchdynamo_compiling(): 92 | # If `input_ids` was given, check if the last id in any sequence is `pad_token_id` 93 | # Note: If using `inputs_embeds`, this check does not work, because we want to be more hands-off.
94 | if ( 95 | generation_config.pad_token_id is not None 96 | and batch_size > 1 97 | and len(inputs_tensor.shape) == 2 98 | and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0 99 | ): 100 | logger.warning( 101 | "A decoder-only architecture is being used, but right-padding was detected! For correct " 102 | "generation results, please set `padding_side='left'` when initializing the tokenizer." 103 | ) 104 | 105 | # 4. Define other model kwargs 106 | # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are 107 | # generating the first new token or not, and we only want to use the embeddings for the first new token) 108 | if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds": 109 | model_kwargs["use_cache"] = True 110 | else: 111 | model_kwargs["use_cache"] = generation_config.use_cache 112 | 113 | if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask: 114 | model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( 115 | inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id 116 | ) 117 | 118 | if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: 119 | # if model is encoder decoder encoder_outputs are created and added to `model_kwargs` 120 | model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( 121 | inputs_tensor, model_kwargs, model_input_name, generation_config 122 | ) 123 | 124 | # 5. Prepare `input_ids` which will be used for auto-regressive generation 125 | if self.config.is_encoder_decoder: 126 | input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( 127 | batch_size=batch_size, 128 | model_input_name=model_input_name, 129 | model_kwargs=model_kwargs, 130 | decoder_start_token_id=generation_config.decoder_start_token_id, 131 | device=inputs_tensor.device, 132 | ) 133 | else: 134 | input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") 135 | 136 | # 6. Prepare `max_length` depending on other stopping criteria. 137 | input_ids_length = input_ids.shape[-1] 138 | has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None 139 | has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None 140 | generation_config = self._prepare_generated_length( 141 | generation_config=generation_config, 142 | has_default_max_length=has_default_max_length, 143 | has_default_min_length=has_default_min_length, 144 | model_input_name=model_input_name, 145 | inputs_tensor=inputs_tensor, 146 | input_ids_length=input_ids_length, 147 | ) 148 | 149 | if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None: 150 | raise ValueError( 151 | "Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a " 152 | "Cache object) is unsupported. Please use only one of the two." 153 | ) 154 | elif generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: 155 | if not self._supports_cache_class: 156 | raise ValueError( 157 | "This model does not support the `cache_implementation` argument. Please check the following " 158 | "issue: https://github.com/huggingface/transformers/issues/28981." 
159 | ) 160 | if generation_config.cache_implementation == "static": 161 | if not self._supports_static_cache: 162 | raise ValueError( 163 | "This model does not support `cache_implementation='static'`. Please check the following " 164 | "issue: https://github.com/huggingface/transformers/issues/28981" 165 | ) 166 | model_kwargs["past_key_values"] = self._get_static_cache(batch_size, generation_config.max_length) 167 | 168 | self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) 169 | 170 | # 7. determine generation mode 171 | generation_mode = generation_config.get_generation_mode(assistant_model) 172 | 173 | if streamer is not None and (generation_config.num_beams > 1): 174 | raise ValueError( 175 | "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1." 176 | ) 177 | 178 | if self.device.type != input_ids.device.type: 179 | warnings.warn( 180 | "You are calling .generate() with the `input_ids` being on a device type different" 181 | f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model" 182 | f" is on {self.device.type}. You may experience unexpected behaviors or slower generation." 183 | " Please make sure that you have put `input_ids` to the" 184 | f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before" 185 | " running `.generate()`.", 186 | UserWarning, 187 | ) 188 | 189 | # 8. prepare distribution pre_processing samplers 190 | prepared_logits_processor = self._get_logits_processor( 191 | generation_config=generation_config, 192 | input_ids_seq_length=input_ids_length, 193 | encoder_input_ids=inputs_tensor, 194 | prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, 195 | logits_processor=logits_processor, 196 | device=inputs_tensor.device, 197 | model_kwargs=model_kwargs, 198 | negative_prompt_ids=negative_prompt_ids, 199 | negative_prompt_attention_mask=negative_prompt_attention_mask, 200 | ) 201 | 202 | # 9. prepare stopping criteria 203 | prepared_stopping_criteria = self._get_stopping_criteria( 204 | generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs 205 | ) 206 | 207 | # 10. go into different generation modes 208 | # 11. prepare logits warper 209 | prepared_logits_warper = ( 210 | self._get_logits_warper(generation_config) if generation_config.do_sample else None 211 | ) 212 | 213 | # 12. 
expand input_ids with `num_return_sequences` additional sequences per batch 214 | input_ids, model_kwargs = self._expand_inputs_for_generation( 215 | input_ids=input_ids, 216 | expand_size=generation_config.num_return_sequences, 217 | is_encoder_decoder=self.config.is_encoder_decoder, 218 | **model_kwargs, 219 | ) 220 | 221 | return PrepareOutput( 222 | input_ids = input_ids, 223 | logits_processor = prepared_logits_processor, 224 | logits_warper = prepared_logits_warper, 225 | stopping_criteria = prepared_stopping_criteria, 226 | generation_config = generation_config, 227 | synced_gpus = synced_gpus, 228 | streamer = streamer, 229 | model_kwargs = model_kwargs 230 | ) 231 | 232 | @torch.no_grad() 233 | def generate_step( 234 | self: PreTrainedModel, 235 | unfinished_sequences: torch.LongTensor, 236 | input_ids: torch.LongTensor, 237 | logits_processor: Optional[LogitsProcessorList], 238 | stopping_criteria: Optional[StoppingCriteriaList], 239 | generation_config: GenerationConfig, 240 | synced_gpus: bool = False, 241 | streamer: Optional["BaseStreamer"] = None, 242 | logits_warper: Optional[LogitsProcessorList] = None, 243 | **model_kwargs, 244 | ) -> Tuple[torch.LongTensor, torch.LongTensor, Dict[str, Any]]: 245 | #init values 246 | pad_token_id = generation_config.pad_token_id 247 | output_attentions = generation_config.output_attentions 248 | output_hidden_states = generation_config.output_hidden_states 249 | output_scores = generation_config.output_scores 250 | output_logits = generation_config.output_logits 251 | return_dict_in_generate = generation_config.return_dict_in_generate 252 | has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) 253 | do_sample = generation_config.do_sample 254 | if do_sample is True and not isinstance(logits_warper, LogitsProcessorList): 255 | raise ValueError( 256 | "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is " 257 | f"{logits_warper})." 
258 | ) 259 | 260 | # init attention / hidden states / scores tuples 261 | scores = () if (return_dict_in_generate and output_scores) else None 262 | raw_logits = () if (return_dict_in_generate and output_logits) else None 263 | decoder_attentions = () if (return_dict_in_generate and output_attentions) else None 264 | cross_attentions = () if (return_dict_in_generate and output_attentions) else None 265 | decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None 266 | 267 | # if model is an encoder-decoder, retrieve encoder attention weights and hidden states 268 | if return_dict_in_generate and self.config.is_encoder_decoder: 269 | encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None 270 | encoder_hidden_states = ( 271 | model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None 272 | ) 273 | 274 | # prepare model inputs 275 | model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) 276 | 277 | # forward pass to get next token 278 | outputs = self( 279 | **model_inputs, 280 | return_dict=True, 281 | output_attentions=output_attentions, 282 | output_hidden_states=output_hidden_states, 283 | ) 284 | 285 | next_token_logits = outputs.logits[:, -1, :] 286 | 287 | # pre-process distribution 288 | next_token_scores = logits_processor(input_ids, next_token_logits) 289 | if do_sample: 290 | next_token_scores = logits_warper(input_ids, next_token_scores) 291 | 292 | # token selection 293 | if do_sample: 294 | probs = nn.functional.softmax(next_token_scores, dim=-1) 295 | next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) 296 | else: 297 | next_tokens = torch.argmax(next_token_scores, dim=-1) 298 | 299 | # finished sentences should have their next token be a padding token 300 | if has_eos_stopping_criteria: 301 | next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) 302 | 303 | # update generated ids, model inputs, and length for next step 304 | input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) 305 | model_kwargs = self._update_model_kwargs_for_generation( 306 | outputs, 307 | model_kwargs, 308 | is_encoder_decoder=self.config.is_encoder_decoder, 309 | ) 310 | 311 | unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) 312 | 313 | return unfinished_sequences, input_ids, model_kwargs 314 | 315 | @torch.no_grad() 316 | def add_new_request( 317 | unfinished_sequences: torch.LongTensor, 318 | input_ids: torch.LongTensor, 319 | model_kwargs: Dict[str, Any], 320 | new_unfinished_sequences: torch.LongTensor, 321 | new_input_ids: torch.LongTensor, 322 | new_model_kwargs: Dict[str, Any], 323 | ) -> Tuple[torch.LongTensor, torch.LongTensor, Dict[str, Any]]: 324 | device = torch.cuda.current_device() 325 | 326 | unfinished_sequences = torch.cat((unfinished_sequences, new_unfinished_sequences), dim=0) 327 | 328 | cur_input_ids = torch.zeros(new_input_ids.shape[0], input_ids.shape[1], dtype=torch.long, device=device) 329 | cur_input_ids[:, -new_input_ids.shape[1]:] = new_input_ids 330 | input_ids = torch.cat((input_ids, cur_input_ids), dim=0) 331 | 332 | cur_attention_mask = torch.zeros(new_input_ids.shape[0], input_ids.shape[1], dtype=torch.long, device=device) 333 | cur_attention_mask[:, -new_input_ids.shape[1]:] = new_model_kwargs['attention_mask'] 334 | model_kwargs['attention_mask'] = torch.cat((model_kwargs['attention_mask'], cur_attention_mask), dim=0) 335 | 336 | if 
model_kwargs['use_cache']: 337 | model_kwargs['past_key_values'] = list(model_kwargs['past_key_values']) 338 | for layer_idx, past_key_value in enumerate(model_kwargs['past_key_values']): 339 | past_key, past_value = past_key_value 340 | new_past_key, new_past_value = new_model_kwargs['past_key_values'][layer_idx] 341 | 342 | cur_past_key = torch.zeros_like(past_key, dtype=past_key.dtype, device=device)[:new_input_ids.shape[0]] 343 | cur_past_value = torch.zeros_like(past_value, dtype=past_value.dtype, device=device)[:new_input_ids.shape[0]] 344 | cur_past_key[:, :, -new_past_key.shape[2]:, :] = new_past_key 345 | cur_past_value[:, :, -new_past_value.shape[2]:, :] = new_past_value 346 | 347 | new_past_key = torch.cat((past_key, cur_past_key), dim=0) 348 | new_past_value = torch.cat((past_value, cur_past_value), dim=0) 349 | model_kwargs['past_key_values'][layer_idx] = (new_past_key, new_past_value) 350 | model_kwargs['past_key_values'] = tuple(model_kwargs['past_key_values']) 351 | 352 | return unfinished_sequences, input_ids, model_kwargs 353 | 354 | @torch.no_grad() 355 | def remove_old_request( 356 | unfinished_sequences: torch.LongTensor, 357 | input_ids: torch.LongTensor, 358 | model_kwargs: Dict[str, Any], 359 | batch_meta: BatchMeta, 360 | output_ids: List[torch.LongTensor], 361 | ) : 362 | device = torch.cuda.current_device() 363 | masks = torch.ones(unfinished_sequences.shape, dtype=torch.bool, device=device) 364 | for _id in range(len(batch_meta.cur_lens)): 365 | if batch_meta.cur_lens[_id] == (batch_meta.prompt_lens[_id] + batch_meta.gen_lens[_id]): 366 | masks[_id] = False 367 | 368 | unfinished_sequences = unfinished_sequences[masks] 369 | batch_meta.prompt_lens = batch_meta.prompt_lens[masks] 370 | batch_meta.gen_lens = batch_meta.gen_lens[masks] 371 | batch_meta.cur_lens = batch_meta.cur_lens[masks] 372 | 373 | for _id, _mask in enumerate(masks): 374 | if _mask == False: 375 | output_ids.append(input_ids[_id]) 376 | input_ids = input_ids[masks] 377 | 378 | if len(input_ids): 379 | max_sequence_length = torch.max(batch_meta.cur_lens) 380 | input_ids = input_ids[:, -max_sequence_length:] 381 | model_kwargs['attention_mask'] = model_kwargs['attention_mask'][masks][:, -max_sequence_length:] 382 | if model_kwargs['use_cache']: 383 | model_kwargs['past_key_values'] = list(model_kwargs['past_key_values']) 384 | for layer_idx, past_key_value in enumerate(model_kwargs['past_key_values']): 385 | past_key, past_value = past_key_value 386 | model_kwargs['past_key_values'][layer_idx] = (past_key[masks][:, :, 1-max_sequence_length:, :], \ 387 | past_value[masks][:, :, 1-max_sequence_length:, :]) 388 | 389 | model_kwargs['past_key_values'] = tuple(model_kwargs['past_key_values']) 390 | 391 | return unfinished_sequences, input_ids, model_kwargs, batch_meta, output_ids 392 | --------------------------------------------------------------------------------
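
As a quick orientation to the engine API above, here is a minimal, hypothetical usage sketch (it is not a file in the repository). It mirrors how `benchmarks/fineinfer/fi-ht.py` drives `fineinfer.engine.llm_engine`, but strips out the LoRA adapters and the fine-tuning loop; it assumes a single CUDA device and the package versions pinned in requirements.txt, and names such as `model_name` and `prompts` are illustrative only.

```python
# Minimal sketch of driving llm_engine's step-wise decoding loop directly.
# Assumptions: CUDA available, pinned transformers/torch versions, greedy decoding.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from fineinfer.engine import llm_engine

model_name = "meta-llama/Meta-Llama-3-8B"  # illustrative; any decoder-only causal LM
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
model.to(torch.cuda.current_device())

prompts = ["Paris is the capital city of"]
inputs = tokenizer(prompts, return_tensors="pt").to(torch.cuda.current_device())

# One-time setup: logits processors, stopping criteria, generation config, model kwargs.
prep = llm_engine.prepare_inputs_and_config(
    self=model, **inputs, max_new_tokens=16, do_sample=False
)

input_ids = prep.input_ids
model_kwargs = model._get_initial_cache_position(input_ids, prep.model_kwargs)
unfinished = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

# Decode one token per call; between steps, fi-ht.py merges or retires requests
# with llm_engine.add_new_request / llm_engine.remove_old_request.
while True:
    unfinished, input_ids, model_kwargs = llm_engine.generate_step(
        self=model,
        unfinished_sequences=unfinished,
        input_ids=input_ids,
        logits_processor=prep.logits_processor,
        stopping_criteria=prep.stopping_criteria,
        generation_config=prep.generation_config,
        **model_kwargs,
    )
    if unfinished.max() == 0:
        break

print(tokenizer.batch_decode(input_ids, skip_special_tokens=True))
```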