├── .gitignore ├── CITATION.cff ├── README.md ├── README.zh.md ├── apply_weight_convert.py ├── cli.py ├── cli_llava.py ├── docs ├── LlamaForCausalLM.md ├── LlavaForConditionalGeneration.md ├── LlavaNextForConditionalGeneration.md ├── Qwen2ForCausalLM.md ├── benchamrk_kernels.md ├── benchmark.md ├── benchmark_models.md ├── benchmark_models_history.md └── performance_optimization.md ├── examples ├── benchmark.py ├── evaluator │ ├── __init__.py │ └── eval.py ├── example_chat.py ├── example_eval_acc.py └── example_llava.py ├── generate.py ├── images ├── acc_test.jpg ├── anwser.png ├── benchamrk_result │ ├── fused-attention-batch4-head32-d64-fwd-causal=False.png │ ├── fused-attention-batch8-head64-d64-fwd-causal=False.png │ ├── layer-norm-forward.csv │ ├── layer-norm-forward.png │ ├── matmul-performance-fp16.csv │ ├── matmul-performance-fp16.png │ ├── matmul-performance-fp8.csv │ ├── matmul-performance-fp8.png │ ├── mlp-silu-performance.csv │ ├── mlp-silu-performance.png │ ├── mlp-silu-performance_ret.png │ ├── result.png │ ├── results.html │ ├── rms-norm-forward.csv │ ├── rms-norm-forward.png │ ├── skip_rmsnorm_benchmark.png │ ├── softmax-performance.csv │ ├── softmax-performance.png │ ├── te_benchmark.png │ └── token_embedding_benchmark.png ├── cli_stream.png ├── flashattention_nopad_benchamrk.png ├── flashattentionv2_nopad_benchamrk.png ├── flashattentionv2_nopad_benchamrk2.png ├── flashdecoding_benchamrk.png ├── generate.gif ├── generate_stream.png ├── llava_output.gif ├── llava_output1.gif ├── llava_output2.gif ├── llava_output3.gif ├── llava_test │ ├── WechatIMG205.jpg │ ├── dog.jpeg │ ├── dog2.png │ ├── extreme_ironing.jpg │ ├── graduate.png │ ├── kaali.jpg │ ├── markdown.png │ ├── mask.png │ ├── movie.jpeg │ ├── painting.png │ ├── panda.jpg │ ├── pexels-christian-heitz-285904-842711.jpg │ ├── pexels-francesco-ungaro-1525041.jpg │ ├── pexels-sanaan-3052361.jpg │ ├── superJumbo.png │ ├── taitan.jpg │ └── website.png ├── output.gif └── qwen2.5-3b-output.gif ├── lite_llama ├── __init__.py ├── executor │ ├── __init__.py │ ├── cuda_graph.py │ ├── executor_struct.py │ ├── mem_manager.py │ ├── model_executor.py │ └── req_tokens_manager.py ├── generate.py ├── generate_stream.py ├── generete_with_probs.py ├── inference.py ├── kernels │ ├── __init__.py │ ├── activations.py │ ├── flashattention.py │ ├── flashattention2_nopad.py │ ├── flashattentionv2.py │ ├── flashdecoding.py │ ├── others │ │ ├── activation_layers.py │ │ ├── context_flashattention_nopad.py │ │ ├── fused_linear.py │ │ ├── layernorm.py │ │ ├── rmsnorm_layer.py │ │ ├── rmsnorm_v1.py │ │ ├── rope_orig.py │ │ └── rotary_emb_v1.py │ ├── rope_emb.py │ ├── skip_rmsnorm.py │ ├── softmax_split.py │ ├── swiglu.py │ ├── update_kv_buffer.py │ ├── update_kv_index.py │ └── utils.py ├── llava_generate_stream.py ├── models │ ├── RotaryEmbedding.py │ ├── llama.py │ ├── llava.py │ ├── model_config.py │ ├── qwen2.py │ ├── qwen3.py │ └── utils.py └── utils │ ├── common.py │ ├── config_convert.py │ ├── constants.py │ ├── file_interface.py │ ├── image_process.py │ ├── logger.py │ └── prompt_templates.py ├── requirement.txt └── tests ├── __init__.py ├── kernels ├── fused_mlp_silu.py ├── kernels_benchmark.py ├── kernels_test.py ├── softmax_native.py ├── softmax_split.py ├── test_attention.py ├── test_available_blocks.py ├── test_cuda_graph.py ├── test_flashattentionv2.py ├── test_flashdecoding.py ├── test_flashdecoding_stage1.py ├── test_flashdecoding_stage2.py ├── test_mask.py ├── test_mem_manager.py ├── test_merge_input_ids_with_image_features.py └── 
test_rope_forward.py ├── models ├── test_LlamaConfig.py ├── test_LlamaForCausalLM.py ├── test_LlamaModel.py ├── test_LlavaConfig.py ├── test_LlavaForConditionalGeneration.py ├── test_LlavaLlama.py ├── test_Qwen2ForCausalLM.py ├── test_get_model_name.py ├── test_gpt2.py ├── test_qwen2.py └── test_transformers.py ├── others ├── test_embedding_merge.py ├── test_image_process.py ├── test_image_token.py ├── test_load_weight.py └── test_standard_mha.py ├── test_torch_matmul.py └── test_torch_rope.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .pytest_cache 3 | lite_llama/kernels/__pycache__ 4 | lite_llama/models/__pycache__ 5 | lite_llama/executor/__pycache__ 6 | 7 | lite_llama/__pycache__ 8 | 9 | images/tmp 10 | test/tmp 11 | test/debug 12 | my_weight 13 | .idea 14 | logs 15 | lite_llama/logs -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, you can cite it as shown below." 3 | title: "Lite Llama" 4 | abstract: "A light llama-like llm inference framework based on the triton kernel." 5 | date-released: 2023-04-23 6 | authors: 7 | - name: "The Litellama AI team" 8 | url: "https://github.com/harleyszhang/lite_llama.git" -------------------------------------------------------------------------------- /cli.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional 3 | from lite_llama.utils.prompt_templates import get_prompter 4 | from lite_llama.generate_stream import GenerateStreamText # 导入 GenerateText 类 5 | 6 | import warnings 7 | warnings.filterwarnings("ignore", category=UserWarning, module="torch._utils") 8 | 9 | checkpoints_dir = "/home/honggao/lite_llama/my_weight/Qwen3-1.7B" 10 | 11 | def main( 12 | temperature: float = 0.6, 13 | top_p: float = 0.9, 14 | max_seq_len: int = 2048, 15 | max_gpu_num_blocks=40960, 16 | max_gen_len: Optional[int] = 1024, 17 | load_model: bool = True, 18 | compiled_model: bool = False, 19 | triton_weight: bool = True, 20 | ): 21 | device = "cuda" if torch.cuda.is_available() else "cpu" 22 | if max_seq_len <= 1024: 23 | short_prompt = True 24 | else: 25 | short_prompt = False 26 | model_prompter = get_prompter("qwen2", checkpoints_dir, short_prompt) 27 | 28 | # 初始化 LLM 文本生成器 29 | generator = GenerateStreamText( 30 | checkpoints_dir=checkpoints_dir, 31 | tokenizer_path=checkpoints_dir, 32 | max_gpu_num_blocks=max_gpu_num_blocks, 33 | max_seq_len=max_seq_len, 34 | compiled_model=compiled_model, 35 | device=device, 36 | ) 37 | 38 | while True: 39 | prompt = input("请输入您的提示(输入 'exit' 退出):\n") # 提示用户输入 40 | # NOTE: strip() 是字符串方法,用于移除字符串开头和结尾的指定字符(默认为空格或换行符)。 41 | if prompt.strip().lower() == "exit": 42 | print("程序已退出。") 43 | break 44 | 45 | print("\n生成结果: ", end="", flush=True) 46 | 47 | model_prompter.insert_prompt(prompt) 48 | prompts = [model_prompter.model_input] 49 | 50 | # 调用生成函数,开始流式生成 51 | stream = generator.text_completion_stream( 52 | prompts, 53 | temperature=temperature, 54 | top_p=top_p, 55 | max_gen_len=max_gen_len, 56 | ) 57 | 58 | completion = "" # 初始化生成结果 59 | # NOTE: 创建了一个 generator 后,可以通过 for 循环来迭代它 60 | for batch_completions in stream: 61 | new_text = batch_completions[0]["generation"][len(completion) :] 62 | completion = batch_completions[0]["generation"] 63 | print(new_text, end="", flush=True) 64 | 
print("\n\n==================================\n") 65 | 66 | 67 | if __name__ == "__main__": 68 | main() 69 | -------------------------------------------------------------------------------- /cli_llava.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional 3 | 4 | from rich.console import Console 5 | from rich.prompt import Prompt 6 | 7 | import sys, os 8 | import warnings 9 | warnings.filterwarnings("ignore", category=UserWarning, module="torch._utils") 10 | 11 | from lite_llama.llava_generate_stream import LlavaGeneratorStream 12 | from lite_llama.utils.image_process import vis_images 13 | from lite_llama.utils.prompt_templates import get_prompter, get_image_token 14 | 15 | # 模型检查点目录,请根据实际情况修改 16 | checkpoints_dir = "/path/Qwen/llava-v1.5-7b" 17 | 18 | def main( 19 | temperature: float = 0.6, 20 | top_p: float = 0.9, 21 | max_seq_len: int = 2048, 22 | max_gpu_num_blocks=None, 23 | max_gen_len: Optional[int] = 512, 24 | compiled_model: bool = False, 25 | ): 26 | """ 27 | 主函数,处理用户输入并生成响应。 28 | 29 | Args: 30 | temperature (float, optional): 生成文本的温度。默认值为 0.6。 31 | top_p (float, optional): 生成文本的top-p值。默认值为 0.9。 32 | max_seq_len (int, optional): 最大序列长度。默认值为 2048。 33 | max_gpu_num_blocks: 用户自行设置的最大可用 blocks(tokens), 如果设置该值, kv cache 内存管理器的最大可用内存-tokens 由该值决定。 34 | max_gen_len (Optional[int], optional): 生成文本的最大长度。默认值为 512。 35 | load_model (bool, optional): 是否加载模型。默认值为True。 36 | compiled_model (bool, optional): 是否使用编译模型。默认值为True。 37 | triton_weight (bool, optional): 是否使用Triton权重。默认值为True。 38 | """ 39 | console = Console() 40 | device = "cuda" if torch.cuda.is_available() else "cpu" 41 | if max_seq_len <= 1024: 42 | short_prompt = True 43 | else: 44 | short_prompt = False 45 | 46 | model_prompter = get_prompter("llama", checkpoints_dir, short_prompt) 47 | 48 | # 初始化多模态模型文本生成器 49 | try: 50 | generator = LlavaGeneratorStream( 51 | checkpoints_dir=checkpoints_dir, 52 | tokenizer_path=checkpoints_dir, 53 | max_gpu_num_blocks=max_gpu_num_blocks, 54 | max_seq_len=max_seq_len, 55 | compiled_model=compiled_model, 56 | device=device, 57 | ) 58 | except Exception as e: 59 | console.print(f"[red]模型加载失败: {e}[/red]") 60 | sys.exit(1) 61 | 62 | while True: 63 | console.print( 64 | "[bold green]请输入图片路径或URL (输入 'exit' 退出):[/bold green]" 65 | ) # 获取用户输入的图片路径或URL 66 | while True: # 循环判断输入图像路径是否成功, 成功则跳出循环 67 | image_input = Prompt.ask("图片") 68 | if os.path.isfile(image_input): 69 | break 70 | elif image_input.strip().lower() == "exit": 71 | break 72 | else: 73 | print(f"错误:'{image_input}' 不是有效的文件路径!") 74 | image_input = Prompt.ask("图片") 75 | 76 | image_input = image_input.strip() 77 | if image_input.lower() == "exit": 78 | break 79 | 80 | image_items = [image_input] # 准备image_items列表 81 | image_num = len(image_items) # 计算输入图片数量 82 | vis_images(image_items) # 在终端中显示图片 83 | 84 | # console.print("\n[bold blue]请输入提示词(输入 'exit' 退出):[/bold blue]") # 获取用户的提示词 85 | input_prompt = Prompt.ask("[bold green]提示词[/bold green]").strip() 86 | if input_prompt.lower() == "exit": 87 | break 88 | 89 | image_token = get_image_token() 90 | model_prompter.insert_prompt(image_token * image_num + input_prompt) 91 | 92 | # prompts = "USER: \nWhat's the content of the image? 
ASSISTANT:"
 93 |         prompts = [model_prompter.model_input]  # Build the final prompt; the image token(s) have already been inserted
 94 | 
 95 |         # Call the generator to produce text
 96 |         try:
 97 |             stream = generator.text_completion_stream(
 98 |                 prompts,
 99 |                 image_items,
100 |                 temperature=temperature,
101 |                 top_p=top_p,
102 |                 max_gen_len=max_gen_len,
103 |             )
104 |         except Exception as e:
105 |             console.print(f"[red]文本生成失败: {e}[/red]")
106 |             continue
107 | 
108 |         completion = ""  # Accumulated generation result
109 |         console.print("ASSISTANT: ", end="")
110 | 
111 |         for batch_completions in stream:
112 |             next_text = batch_completions[0]["generation"][len(completion) :]
113 |             completion = batch_completions[0]["generation"]
114 |             print(f"\033[91m{next_text}\033[0m", end="", flush=True)  # print the new text in red
115 | 
116 |         console.print("\n[bold green]==================================[/bold green]\n")
117 | 
118 | 
119 | if __name__ == "__main__":
120 |     main()
121 | 
--------------------------------------------------------------------------------
/docs/benchmark.md:
--------------------------------------------------------------------------------
 1 | ## Benchmark Performance Test
 2 | 
 3 | ### Llama-3.2-1B Model Performance Comparison Test
 4 | 
 5 | Tested in a Virtaicloud `B1.small` environment (roughly `1/4` of an RTX `3090`). Running `python benchmark.py`, lite_llama reaches up to `4x` the speed of transformers. Benchmark results with a prompt `batch_size = 16` and `max_gen_len = 1900`:
 6 | 
 7 | ```bash
 8 | lite_llama inference time: 67.8760 s
 9 | Transformers inference time: 131.8708 s
10 | lite_llama throughput: 411.04 tokens/s
11 | Transformers throughput: 104.70 tokens/s
12 | lite_llama per token latency: 2.432831 ms/token
13 | Transformers per token latency: 9.551007 ms/token
14 | ```
15 | 
16 | ### Llama-3.2-3B Model Performance Comparison Test
17 | 
18 | Tested in a Virtaicloud `B1.big` environment (roughly a full RTX `3090`). Running `python benchmark.py`, lite_llama runs up to `4x` faster than transformers. Benchmark results with `max_gen_len = 1900`:
19 | 
20 | ```bash
21 | lite_llama inference time: 31.3463 s
22 | Transformers inference time: 69.1433 s
23 | lite_llama throughput: 730.45 tokens/s
24 | Transformers throughput: 183.95 tokens/s
25 | lite_llama per token latency: 1.369015 ms/token
26 | Transformers per token latency: 5.436221 ms/token
27 | ```
28 | 
29 | For more performance test results, refer to [benchmark_models](./benchmark_models.md) (more model results to be added).
30 | 
--------------------------------------------------------------------------------
/docs/performance_optimization.md:
--------------------------------------------------------------------------------
 1 | ## Performance Test
 2 | 
 3 | Input prompts:
 4 | 
 5 | ```python
 6 | prompts: List[str] = [
 7 |     # For these prompts, the expected answer is the natural continuation of the prompt
 8 |     "I believe the meaning of life is",
 9 |     "Simply put, the theory of relativity states that ",
10 |     """A brief message congratulating the team on the launch:
11 | 
12 |     Hi everyone,
13 | 
14 |     I just """,
15 |     # Few shot prompt (providing a few examples before asking model to complete more);
16 |     "Roosevelt was the first president of the United States, he has",
17 | ]
18 | ```
19 | 
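The tokens/s and ms/token figures quoted in `docs/benchmark.md` above (and in the optimization log below) all reduce to wall-clock time divided into the number of generated tokens. The sketch below shows one way to measure them for a batch of prompts like the one just listed; it reuses the `GenerateText.text_completion` API from `examples/example_chat.py`, but the helper name `run_timed_benchmark` and the tokenizer-based token counting are illustrative assumptions, not the actual `benchmark.py` implementation.

```python
import time
from typing import List, Tuple

from transformers import AutoTokenizer
from lite_llama.generate import GenerateText


def run_timed_benchmark(checkpoints_dir: str, prompts: List[str], max_gen_len: int = 1900) -> Tuple[float, float]:
    """Time a batch completion and derive throughput (tokens/s) and per-token latency (ms/token)."""
    tokenizer = AutoTokenizer.from_pretrained(checkpoints_dir)
    generator = GenerateText(
        checkpoints_dir=checkpoints_dir,
        tokenizer_path=checkpoints_dir,
        max_seq_len=2048,
        compiled_model=True,
        device="cuda",
    )

    start = time.perf_counter()
    results = generator.text_completion(
        prompts, temperature=0.6, top_p=0.9, max_gen_len=max_gen_len
    )
    elapsed = time.perf_counter() - start

    # Count only the newly generated tokens across the whole batch.
    gen_tokens = sum(len(tokenizer.encode(r["generation"])) for r in results)
    throughput = gen_tokens / elapsed            # tokens/s
    latency_ms = 1000.0 * elapsed / gen_tokens   # ms/token

    print(f"inference time: {elapsed:.4f} s")
    print(f"throughput: {throughput:.2f} tokens/s")
    print(f"per token latency: {latency_ms:.6f} ms/token")
    return throughput, latency_ms
```

Running the same harness against a `transformers` `generate()` call with identical prompts and `max_gen_len` is what produces the side-by-side comparison numbers shown above.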
20 | 1. After optimizing the decode phase with CUDA graph, a single decode step takes `8.2402` ms versus `17.2241` ms without CUDA graph, roughly a 2x speedup, close to the improvement seen when CUDA graph is applied in vLLM.
21 | 
22 | ```bash
23 | INFO: After apply cuda graph, Decode inference time: 8.2402 ms
24 | INFO: Before apply cuda graph, Decode inference time: 17.2241 ms
25 | ```
26 | 
27 | 2. Building on the previous step, the original standard attention was replaced with flashattention.
28 | 
29 | > flashattention1 mainly benefits training; its speedup is limited when prompts are very short. The decode phase of inference should use flash-decoding instead.
30 | 
31 | ```bash
32 | INFO: input tokens shape is torch.Size([8, 115])
33 | # Before using flashattention
34 | INFO:lite_llama.generate:Batch inference time: 3152.0476 ms
35 | INFO:lite_llama.generate:Tokens per second: 97.71 tokens/s
36 | # After using flashattention
37 | INFO:lite_llama.generate:Batch inference time: 2681.3823 ms
38 | INFO:lite_llama.generate:Tokens per second: 114.87 tokens/s
39 | ```
40 | 
41 | 3. Continued optimizing by upgrading `flashattention` to `flashattention2` to reduce some computation.
42 | 
43 | ```bash
44 | INFO:lite_llama.generate:Batch inference time: 2103.0737 ms
45 | INFO:lite_llama.generate:Tokens per second: 146.45 tokens/s
46 | ```
47 | 
48 | 4. Further optimized by using `flashdecoding` in the decoding phase to improve the parallelism of attention computation during decoding, thereby fully leveraging the GPU's computational power.
49 | 
50 | ```bash
51 | INFO:lite_llama.generate:Decode stage Batch inference time: 1641.4178 ms
52 | INFO:lite_llama.generate:Decode stage tokens per second : 187.64 tokens/s
53 | ```
54 | 
55 | 5. Further optimization includes efficient dynamic management of the KV cache (similar to TokenAttention), addressing memory waste and inefficient allocation in KV cache usage.
56 | 
57 | ```bash
58 | INFO:lite_llama.generate:Decode stage Batch inference time: 1413.9111 ms
59 | INFO:lite_llama.generate:Decode stage tokens per second : 217.84 tokens/s
60 | ```
61 | 
62 | 6. A simple optimization is to replace the `repeat_kv` function with `GQA_KV_heads_index`.
63 | 
64 | 7. A common and straightforward optimization is fusing the key and value linear layers.
65 | 
66 | 8. A commonly used optimization is operator fusion: fusing the residual connection's skip (add) operation with the `rmsnorm` operator into a new `skip_rmsnorm` operator.
67 | 
68 | 9. Refactored and optimized the `MHA` module, improving the `context_attention` and `token_attention` kernels to support `Nopad attention` as well as dynamic allocation and management of the `kv cache` (see the sketch after this list).
69 | 
70 |    - `token_attention` now supports directly passing kv_cache indices and the actual sequence length `seq_len`, reducing `concat` and `view` operations within the `MHA` module and enabling `Nopad` token_attention.
71 |    - During each prefill/decode step, kv_cache indices are allocated dynamically based on the actual prompt length, instead of pre-allocating a contiguous kv_cache region of `(max(prompt_len) + max_gen_len) * batch_size` tokens before inference.
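As a companion to points 5 and 9 above, the snippet below sketches the core idea behind TokenAttention-style dynamic KV-cache management: a global pool of per-token cache slots from which each prefill/decode step allocates only what it actually needs. The class name `TokenKVAllocator` and its `alloc`/`free` methods are illustrative; the real logic lives in `KVCacheMemoryManager` (`lite_llama/executor/mem_manager.py`) and tracks more state.

```python
import torch


class TokenKVAllocator:
    """Illustrative per-token KV-cache slot allocator (TokenAttention-style)."""

    def __init__(self, num_tokens: int, device: str = "cuda"):
        # One boolean flag per KV-cache slot in the global kv_buffer: True = free.
        self.free_mask = torch.ones(num_tokens, dtype=torch.bool, device=device)

    def alloc(self, need: int) -> torch.Tensor:
        """Hand out `need` slot indices: one per prompt token at prefill, one per sequence at each decode step."""
        free_idx = torch.nonzero(self.free_mask).squeeze(1)
        assert free_idx.numel() >= need, "KV cache has no free token slots left"
        selected = free_idx[:need]
        self.free_mask[selected] = False   # mark the slots as in use
        return selected                    # indices into the global kv_buffer

    def free(self, indices: torch.Tensor) -> None:
        """Return slots to the pool once a request has finished generating."""
        self.free_mask[indices] = True
```

Because slots are handed out per token, a request that stops after a few dozen generated tokens never holds the full `max_gen_len` worth of cache that a statically pre-allocated `(max(prompt_len) + max_gen_len) * batch_size` layout would reserve for it.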
72 | -------------------------------------------------------------------------------- /examples/evaluator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/examples/evaluator/__init__.py -------------------------------------------------------------------------------- /examples/example_chat.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | import torch 3 | 4 | import sys, os, time 5 | 6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))) 7 | from lite_llama.generate import GenerateText 8 | from lite_llama.generate_stream import GenerateStreamText 9 | import warnings 10 | 11 | warnings.filterwarnings("ignore", category=UserWarning, module="torch._utils") 12 | 13 | checkpoints_dir = ( 14 | "/path/lite_llama/my_weight/Qwen2.5-3B" # 改成自己的存放模型路径 15 | ) 16 | 17 | 18 | def cli_generate_stream( 19 | temperature: float = 0.6, 20 | top_p: float = 0.9, 21 | max_seq_len: int = 512, 22 | max_gpu_num_blocks=None, 23 | max_gen_len: Optional[int] = 128, 24 | ): 25 | """ 26 | 程序的入口点,用于使用预训练模型生成文本。 27 | 28 | 参数: 29 | temperature (float): 控制生成随机性的温度值。 30 | top_p (float): 控制生成多样性的 top-p 采样参数。 31 | max_seq_len (int): 输入提示的最大序列长度。 32 | max_batch_size (int): 生成序列的最大批量大小。 33 | max_gen_len (int): 生成序列的最大长度。 34 | 35 | """ 36 | device = "cuda" if torch.cuda.is_available() else "cpu" 37 | 38 | generator = GenerateStreamText( 39 | checkpoints_dir=checkpoints_dir, 40 | tokenizer_path=checkpoints_dir, 41 | max_gpu_num_blocks=max_gpu_num_blocks, 42 | max_seq_len=max_seq_len, 43 | compiled_model=True, 44 | device=device, 45 | ) 46 | 47 | prompts: List[str] = [ 48 | "I believe the meaning of life is", 49 | "Simply put, the theory of relativity states that ", 50 | """A brief message congratulating the team on the launch: 51 | 52 | Hi everyone, 53 | 54 | I just """, 55 | "Roosevelt was the first president of the United States, he has", 56 | "Here are some tips and resources to help you get started:", 57 | ] 58 | 59 | for idx, prompt in enumerate(prompts): 60 | print(f"Prompt {idx}: {prompt}") 61 | print("Generated output:", end="", flush=True) 62 | 63 | stream = generator.text_completion_stream( 64 | [prompt], 65 | temperature=temperature, 66 | top_p=top_p, 67 | max_gen_len=max_gen_len, 68 | ) 69 | 70 | # 初始化生成结果 71 | completion = "" 72 | for batch_completions in stream: 73 | new_text = batch_completions[0]["generation"][len(completion) :] 74 | completion = batch_completions[0]["generation"] 75 | print(new_text, end="", flush=True) 76 | print("\n\n==================================\n") 77 | 78 | 79 | def cli_generate( 80 | temperature: float = 0.6, 81 | top_p: float = 0.9, 82 | max_seq_len: int = 512, 83 | max_gen_len: Optional[int] = 64, 84 | ): 85 | """ 86 | Entry point of the program for generating text using a pretrained model. 87 | 88 | Args: 89 | ckpt_dir (str): The directory containing checkpoint files for the pretrained model. 90 | tokenizer_path (str): The path to the tokenizer model used for text encoding/decoding. 91 | temperature (float, optional): The temperature value for controlling randomness in generation. 92 | Defaults to 0.6. 93 | top_p (float, optional): The top-p sampling parameter for controlling diversity in generation. 94 | Defaults to 0.9. 95 | max_seq_len (int, optional): The maximum sequence length for input prompts. Defaults to 512. 
96 | max_batch_size (int, optional): The maximum batch size for generating sequences. Defaults to 8. 97 | max_gen_len (int, optional): The maximum length of generated sequences. If None, it will be 98 | set to the model's max sequence length. Defaults to None. 99 | """ 100 | device = "cuda" if torch.cuda.is_available() else "cpu" 101 | 102 | generator = GenerateText( 103 | checkpoints_dir=checkpoints_dir, 104 | tokenizer_path=checkpoints_dir, 105 | max_seq_len=max_seq_len, 106 | compiled_model=True, 107 | device=device, 108 | ) 109 | 110 | prompts: List[str] = [ 111 | # For these prompts, the expected answer is the natural continuation of the prompt 112 | "I believe the meaning of life is", 113 | "Simply put, the theory of relativity states that ", 114 | """A brief message congratulating the team on the launch: 115 | 116 | Hi everyone, 117 | 118 | I just """, 119 | "Roosevelt was the first president of the United States, he has", 120 | ] 121 | 122 | results = generator.text_completion( 123 | prompts, 124 | temperature=temperature, 125 | top_p=top_p, 126 | max_gen_len=max_gen_len, 127 | ) 128 | 129 | for prompt, result in zip(prompts, results): 130 | print(prompt) 131 | print(f"> {result['generation']}") 132 | print("\n==================================\n") 133 | 134 | 135 | def main(stream_flag=False): 136 | cli_generate_stream() if stream_flag else cli_generate() 137 | 138 | 139 | if __name__ == "__main__": 140 | main() 141 | -------------------------------------------------------------------------------- /examples/example_eval_acc.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.filterwarnings("ignore", category=UserWarning, module="torch._utils") 4 | import torch 5 | 6 | from .evaluator.eval import * 7 | 8 | import sys, os 9 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))) 10 | from lite_llama.inference import Inference 11 | 12 | class EvaluatorAccuracy(object): 13 | def __init__(self, test_data_path, custom_checkpoints_dir, data_batch=10): 14 | self.custom_checkpoints_dir = custom_checkpoints_dir 15 | self.test_data_path = test_data_path 16 | self.data_batch = data_batch 17 | 18 | # init inference 19 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 20 | 21 | self.model_inference = Inference( 22 | temperature=0.7, 23 | top_p=0.8, 24 | max_seq_len=2048, 25 | max_gen_len=1900, 26 | lite_llama_ckpt_dir=self.custom_checkpoints_dir, 27 | device=self.device, 28 | ) 29 | 30 | def process( 31 | self, 32 | ): 33 | if "hotpot" in self.test_data_path.lower(): 34 | data_obj = HotpotQA(self.test_data_path, self.data_batch) 35 | 36 | elif "hellaswag" in self.test_data_path.lower(): 37 | data_obj = HellaSwag(self.test_data_path, self.data_batch) 38 | 39 | try: 40 | assert data_obj is not None, "data_obj has not been created" 41 | except NameError: 42 | raise AssertionError("Dataset may not be supported") 43 | 44 | ground_truth, prompts, options = data_obj.parse_data() 45 | 46 | predictions = self.model_inference.process(prompts) 47 | 48 | if data_obj.data_type == "mcq": 49 | data_obj.evaluate(predictions, ground_truth, options) 50 | else: 51 | data_obj.evaluate(predictions, ground_truth) 52 | 53 | 54 | if __name__ == "__main__": 55 | ea = EvaluatorAccuracy( 56 | "/path_to/hotpot_dev_distractor_v1.json", "/path_to/Llama-3.2-3B-Instruct" 57 | ) 58 | ea.process() 59 | -------------------------------------------------------------------------------- /examples/example_llava.py: 
--------------------------------------------------------------------------------
 1 | import torch
 2 | from typing import Optional
 3 | 
 4 | import sys, os
 5 | 
 6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
 7 | from lite_llama.llava_generate_stream import (
 8 |     LlavaGeneratorStream,
 9 | )  # import the LlavaGeneratorStream class
10 | 
11 | checkpoints_dir = "/gemini/code/lite_llama/my_weight/llava-1.5-7b-hf"
12 | 
13 | 
14 | def main(
15 |     temperature: float = 0.6,
16 |     top_p: float = 0.9,
17 |     max_seq_len: int = 2048,
18 |     max_gpu_num_blocks=None,
19 |     max_gen_len: Optional[int] = 64,
20 |     load_model: bool = True,
21 |     compiled_model: bool = True,
22 |     triton_weight: bool = True,
23 | ):
24 |     device = "cuda" if torch.cuda.is_available() else "cpu"
25 | 
26 |     generator = LlavaGeneratorStream(
27 |         checkpoints_dir=checkpoints_dir,
28 |         tokenizer_path=checkpoints_dir,
29 |         max_gpu_num_blocks=max_gpu_num_blocks,
30 |         max_seq_len=max_seq_len,
31 |         compiled_model=compiled_model,
32 |         device=device,
33 |     )
34 | 
35 |     # Call the generation function and start streaming generation
36 |     prompts = ["USER: <image>\nWhat's the content of the image? ASSISTANT:"]
37 |     image_items = ["https://www.ilankelman.org/stopsigns/australia.jpg"]
38 | 
39 |     stream = generator.text_completion_stream(
40 |         prompts,
41 |         image_items,
42 |         temperature=temperature,
43 |         top_p=top_p,
44 |         max_gen_len=max_gen_len,
45 |     )
46 | 
47 |     completion = ""  # Accumulated generation result
48 |     # NOTE: once the generator is created, it can be iterated with a for loop
49 |     for batch_completions in stream:
50 |         new_text = batch_completions[0]["generation"][len(completion) :]
51 |         completion = batch_completions[0]["generation"]
52 |         print(new_text, end=" ", flush=True)
53 |     print("\n\n==================================\n")
54 | 
55 | 
56 | if __name__ == "__main__":
57 |     main()
58 | 
--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | from typing import Optional
 3 | import warnings
 4 | warnings.filterwarnings("ignore", category=UserWarning, module="torch._utils")
 5 | from lite_llama.utils.common import get_gpu_memory, detect_device, count_tokens, get_model_type
 6 | from lite_llama.utils.prompt_templates import get_prompter
 7 | from lite_llama.generate_stream import GenerateStreamText  # import the GenerateStreamText class
 8 | 
 9 | import sys, os, time
10 | from pathlib import Path
11 | 
12 | # support running without installing as a package
13 | wd = Path(__file__).parent.parent.resolve()
14 | sys.path.append(str(wd))
15 | import psutil
16 | 
17 | process = psutil.Process(os.getpid())
18 | 
19 | 
20 | def report_resource_usage(ram_before, vram_before, gpu_type) -> None:
21 |     end_time = time.time()
22 |     ram_after = process.memory_info().rss
23 |     vram_after = get_gpu_memory(gpu_type)
24 | 
25 |     ram_used = (ram_after - ram_before) / (1024**3)  # Bytes to GB
26 | 
27 |     if vram_before is not None and vram_after is not None:
28 |         vram_used = vram_after - vram_before
29 |         vram_text = f"{vram_used:.2f} GB"
30 |     else:
31 |         vram_text = "Unavailable"
32 | 
33 |     print(f"CPU RAM Used: {ram_used:.2f} GB")
34 |     print(f"GPU VRAM Used: {vram_text}")
35 | 
36 | 
37 | def main(
38 |     prompt: str = "Hello, my name is",
39 |     *,
40 |     temperature: float = 0.6,
41 |     top_p: float = 0.9,
42 |     max_seq_len: int = 2048,
43 |     max_gpu_num_blocks=40960,
44 |     max_gen_len: Optional[int] = 1024,
45 |     load_model: bool = True,
46 |     compiled_model: bool = False,
47 |     triton_weight: bool = True,
48 |     gpu_type: str = "nvidia",
49 |     checkpoint_path: Path = Path("checkpoints/lit-llama/7B/"),
50 |     quantize: Optional[str] = None,
51 | ):
52 |     device = "cuda" if torch.cuda.is_available() else "cpu"
53 |     assert checkpoint_path.is_dir(), checkpoint_path
54 |     checkpoint_path = str(checkpoint_path)
55 | 
56 |     if max_seq_len <= 1024:
57 |         short_prompt = True
58 |     else:
59 |         short_prompt = False
60 |     model_prompter = get_prompter(
61 |         get_model_type(checkpoint_path), checkpoint_path, short_prompt
62 |     )
63 |     # Start resource tracking
64 |     ram_before = process.memory_info().rss
65 | 
66 |     gpu_type = detect_device()
67 |     vram_before = get_gpu_memory(gpu_type)
68 |     # Init LLM generator
69 |     start = time.perf_counter()
70 | 
71 |     generator = GenerateStreamText(
72 |         checkpoints_dir=checkpoint_path,
73 |         tokenizer_path=checkpoint_path,
74 |         max_gpu_num_blocks=max_gpu_num_blocks,
75 |         max_seq_len=max_seq_len,
76 |         load_model=load_model,
77 |         compiled_model=compiled_model,
78 |         triton_weight=triton_weight,
79 |         device=device,
80 |     )
81 | 
82 |     model_prompter.insert_prompt(prompt)
83 |     prompts = [model_prompter.model_input]
84 |     # Call the generation function and start the stream generation
85 |     stream = generator.text_completion_stream(
86 |         prompts,
87 |         temperature=temperature,
88 |         top_p=top_p,
89 |         max_gen_len=max_gen_len,
90 |     )
91 | 
92 |     completion = ""  # Accumulated generation result
93 |     # NOTE: After creating a generator, it can be iterated through a for loop
94 |     text_msg = ""
95 |     for batch_completions in stream:
96 |         new_text = batch_completions[0]["generation"][len(completion) :]
97 |         completion = batch_completions[0]["generation"]
98 |         print(new_text, end="", flush=True)
99 |         text_msg += new_text
100 |     end = time.perf_counter()  # stop timing only after the stream is fully consumed, so tokens/sec includes generation
101 | 
102 |     print("\n\n==================================\n")
103 |     print(
104 |         f"Time for inference: {(end - start):.2f} sec, {count_tokens(text_msg, generator.tokenizer) / (end - start):.2f} tokens/sec"
105 |     )
106 | 
107 |     # Report resource usage
108 |     report_resource_usage(ram_before, vram_before, gpu_type)
109 | 
110 | 
111 | if __name__ == "__main__":
112 |     from jsonargparse import CLI
113 | 
114 |     torch.set_float32_matmul_precision("high")
115 |     CLI(main)
116 | 
--------------------------------------------------------------------------------
/images/acc_test.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/acc_test.jpg
--------------------------------------------------------------------------------
/images/anwser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/anwser.png
--------------------------------------------------------------------------------
/images/benchamrk_result/fused-attention-batch4-head32-d64-fwd-causal=False.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/benchamrk_result/fused-attention-batch4-head32-d64-fwd-causal=False.png
--------------------------------------------------------------------------------
/images/benchamrk_result/fused-attention-batch8-head64-d64-fwd-causal=False.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/benchamrk_result/fused-attention-batch8-head64-d64-fwd-causal=False.png -------------------------------------------------------------------------------- /images/benchamrk_result/layer-norm-forward.csv: -------------------------------------------------------------------------------- 1 | N,Triton,Torch 2 | 1024.000000,655.360017,630.153853 3 | 1536.000000,702.171410,599.414644 4 | 2048.000000,712.347810,585.142849 5 | 2560.000000,731.428561,568.888869 6 | 3072.000000,712.347810,558.545450 7 | 3584.000000,699.317085,551.384622 8 | 4096.000000,668.734716,555.389814 9 | 4608.000000,624.813540,546.133354 10 | 5120.000000,585.142842,546.133307 11 | 5632.000000,570.329131,546.133317 12 | 6144.000000,537.180338,543.116035 13 | 6656.000000,543.346957,543.346957 14 | 7168.000000,516.612607,543.545008 15 | 7680.000000,501.551014,543.716805 16 | 8192.000000,486.352478,541.619847 17 | 8704.000000,470.486466,544.000001 18 | 9216.000000,462.244495,546.133354 19 | 9728.000000,455.111093,544.223786 20 | 10240.000000,444.010856,544.318950 21 | 10752.000000,437.740464,546.133312 22 | 11264.000000,436.377722,544.483403 23 | 11776.000000,432.146787,546.133321 24 | 12288.000000,430.214454,541.619836 25 | 12800.000000,429.350110,546.133329 26 | 13312.000000,428.555331,546.133332 27 | 13824.000000,428.651187,546.133335 28 | 14336.000000,426.349425,547.436768 29 | 14848.000000,427.280587,546.133340 30 | 15360.000000,427.408686,546.133343 31 | 15872.000000,426.810091,546.133345 32 | -------------------------------------------------------------------------------- /images/benchamrk_result/layer-norm-forward.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/benchamrk_result/layer-norm-forward.png -------------------------------------------------------------------------------- /images/benchamrk_result/matmul-performance-fp16.csv: -------------------------------------------------------------------------------- 1 | M,N,K,cuBLAS,Triton 2 | 256.000000,256.000000,256.000000,4.681143,1.927529 3 | 384.000000,384.000000,384.000000,11.059200,4.808348 4 | 512.000000,512.000000,512.000000,21.845333,9.362286 5 | 640.000000,640.000000,640.000000,20.480001,15.515151 6 | 768.000000,768.000000,768.000000,40.215272,22.685539 7 | 896.000000,896.000000,896.000000,34.266537,31.220622 8 | 1024.000000,1024.000000,1024.000000,44.620254,41.120628 9 | 1152.000000,1152.000000,1152.000000,56.339321,52.385683 10 | 1280.000000,1280.000000,1280.000000,60.235293,34.711863 11 | 1408.000000,1408.000000,1408.000000,42.261832,46.998068 12 | 1536.000000,1536.000000,1536.000000,63.195428,56.173715 13 | 1664.000000,1664.000000,1664.000000,50.555685,45.220663 14 | 1792.000000,1792.000000,1792.000000,66.114259,52.520673 15 | 1920.000000,1920.000000,1920.000000,66.782607,60.104346 16 | 2048.000000,2048.000000,2048.000000,66.313105,52.265470 17 | 2176.000000,2176.000000,2176.000000,66.196213,58.669529 18 | 2304.000000,2304.000000,2304.000000,73.728002,66.725901 19 | 2432.000000,2432.000000,2432.000000,57.102567,60.679187 20 | 2560.000000,2560.000000,2560.000000,75.851852,66.873469 21 | 2688.000000,2688.000000,2688.000000,70.116556,61.982117 22 | 2816.000000,2816.000000,2816.000000,75.982940,67.829250 23 | 2944.000000,2944.000000,2944.000000,76.670818,60.261223 24 | 
3072.000000,3072.000000,3072.000000,73.824123,61.413346 25 | 3200.000000,3200.000000,3200.000000,73.903001,13.155190 26 | 3328.000000,3328.000000,3328.000000,69.624080,13.349026 27 | 3456.000000,3456.000000,3456.000000,16.150154,14.968729 28 | 3584.000000,3584.000000,3584.000000,63.164317,15.901564 29 | 3712.000000,3712.000000,3712.000000,17.018287,16.897386 30 | 3840.000000,3840.000000,3840.000000,17.979515,18.730120 31 | 3968.000000,3968.000000,3968.000000,66.498058,19.870369 32 | 4096.000000,4096.000000,4096.000000,21.360344,15.941295 33 | -------------------------------------------------------------------------------- /images/benchamrk_result/matmul-performance-fp16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/benchamrk_result/matmul-performance-fp16.png -------------------------------------------------------------------------------- /images/benchamrk_result/matmul-performance-fp8.csv: -------------------------------------------------------------------------------- 1 | M,N,K,Triton 2 | 256.000000,256.000000,256.000000,1.820444 3 | 384.000000,384.000000,384.000000,4.608000 4 | 512.000000,512.000000,512.000000,8.738134 5 | 640.000000,640.000000,640.000000,14.222222 6 | 768.000000,768.000000,768.000000,20.575256 7 | 896.000000,896.000000,896.000000,29.269332 8 | 1024.000000,1024.000000,1024.000000,38.836148 9 | 1152.000000,1152.000000,1152.000000,49.766401 10 | 1280.000000,1280.000000,1280.000000,36.247787 11 | 1408.000000,1408.000000,1408.000000,44.323380 12 | 1536.000000,1536.000000,1536.000000,52.428802 13 | 1664.000000,1664.000000,1664.000000,43.056994 14 | 1792.000000,1792.000000,1792.000000,49.731964 15 | 1920.000000,1920.000000,1920.000000,57.123968 16 | 2048.000000,2048.000000,2048.000000,50.231184 17 | 2176.000000,2176.000000,2176.000000,56.527100 18 | 2304.000000,2304.000000,2304.000000,63.195429 19 | 2432.000000,2432.000000,2432.000000,55.853806 20 | 2560.000000,2560.000000,2560.000000,63.258689 21 | 2688.000000,2688.000000,2688.000000,57.387374 22 | 2816.000000,2816.000000,2816.000000,64.375217 23 | 2944.000000,2944.000000,2944.000000,11.482957 24 | 3072.000000,3072.000000,3072.000000,11.473780 25 | 3200.000000,3200.000000,3200.000000,11.823388 26 | 3328.000000,3328.000000,3328.000000,15.349957 27 | 3456.000000,3456.000000,3456.000000,14.168993 28 | 3584.000000,3584.000000,3584.000000,15.659246 29 | 3712.000000,3712.000000,3712.000000,16.594243 30 | 3840.000000,3840.000000,3840.000000,66.163325 31 | 3968.000000,3968.000000,3968.000000,64.975472 32 | 4096.000000,4096.000000,4096.000000,20.494386 33 | -------------------------------------------------------------------------------- /images/benchamrk_result/matmul-performance-fp8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/benchamrk_result/matmul-performance-fp8.png -------------------------------------------------------------------------------- /images/benchamrk_result/mlp-silu-performance.csv: -------------------------------------------------------------------------------- 1 | N,Torch_mlp_silu,Torch_fused_mlp,Triton_mlp_silu,Triton_torch_mlp_silu 2 | 32.000000,0.342901,1.977925,0.398311,1.774257 3 | 288.000000,0.982247,1.246223,1.090688,1.375112 4 | 544.000000,1.537887,1.345078,1.465214,1.300158 5 | 
800.000000,1.588540,1.546321,1.293116,1.419473 6 | 1056.000000,1.841555,1.812126,1.473941,1.510595 7 | 1312.000000,1.830029,1.910771,1.389752,1.574153 8 | 1568.000000,1.783519,1.860142,1.384919,1.524789 9 | 1824.000000,1.611358,1.574159,1.247850,1.486012 10 | -------------------------------------------------------------------------------- /images/benchamrk_result/mlp-silu-performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/benchamrk_result/mlp-silu-performance.png -------------------------------------------------------------------------------- /images/benchamrk_result/mlp-silu-performance_ret.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/benchamrk_result/mlp-silu-performance_ret.png -------------------------------------------------------------------------------- /images/benchamrk_result/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/benchamrk_result/result.png -------------------------------------------------------------------------------- /images/benchamrk_result/results.html: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /images/benchamrk_result/rms-norm-forward.csv: -------------------------------------------------------------------------------- 1 | N,Triton,Triton_rmsnorm_fwd,Torch_Py 2 | 512.000000,585.142849,585.142849,79.533982 3 | 1024.000000,655.360017,682.666643,128.000000 4 | 1536.000000,702.171410,722.823517,142.057809 5 | 2048.000000,728.177767,762.046512,151.004613 6 | 2560.000000,718.596477,772.830192,155.741448 7 | 3072.000000,692.281669,792.774204,159.341904 8 | 3584.000000,690.891575,796.444416,159.732591 9 | 4096.000000,668.734716,799.219525,165.494954 10 | 4608.000000,715.805836,801.391287,163.476718 11 | 5120.000000,694.237267,803.137269,168.907218 12 | 5632.000000,677.533865,811.477150,170.344042 13 | 6144.000000,664.216187,812.429770,171.279977 14 | 6656.000000,653.349676,812.946606,173.019165 15 | 7168.000000,644.314599,813.390089,161.539508 16 | 7680.000000,636.683938,819.199961,29.194582 17 | 8192.000000,627.138740,819.200021,135.685297 18 | 8704.000000,773.688877,819.200003,32.386976 19 | 9216.000000,764.020726,819.199988,28.033460 20 | 9728.000000,759.258522,823.534398,29.386951 21 | 10240.000000,751.559629,823.316575,30.217633 22 | 10752.000000,741.517223,823.365820,31.513465 23 | 11264.000000,732.617855,822.940661,32.461096 24 | 11776.000000,727.474923,826.385949,33.681801 25 | 12288.000000,714.938186,826.084057,35.011665 26 | 12800.000000,711.111086,829.149834,36.428318 27 | 13312.000000,702.943876,825.550434,36.894509 28 | 13824.000000,695.547157,828.404507,37.789851 29 | 14336.000000,688.816809,831.072445,38.557069 30 | 14848.000000,680.710584,827.763059,39.757009 31 | 15360.000000,673.343873,827.474771,40.682005 32 | 15872.000000,668.294715,829.908492,41.583755 33 | 16384.000000,661.979817,829.569645,42.673612 34 | 16896.000000,819.199976,824.195135,43.490349 35 | 17408.000000,814.372108,826.492569,44.635897 36 | 17920.000000,805.534704,826.282436,45.489451 37 | 
18432.000000,801.391287,826.264868,46.472109 38 | 18944.000000,793.465972,828.153021,48.218900 39 | 19456.000000,790.091366,830.122659,48.349267 40 | 19968.000000,784.982774,827.689120,44.382580 41 | 20480.000000,780.190482,827.474771,48.995217 42 | 20992.000000,768.585795,827.270923,31.812086 43 | 21504.000000,762.891353,827.076932,30.868833 44 | 22016.000000,754.295495,826.891993,31.347868 45 | 22528.000000,746.269135,828.616102,34.959314 46 | 23040.000000,744.727294,826.547102,32.496475 47 | 23552.000000,740.337953,828.202196,32.925470 48 | 24064.000000,736.183522,826.231764,33.529915 49 | 24576.000000,720.835903,826.084057,34.390065 50 | 25088.000000,724.563141,827.645366,58.314519 51 | 25600.000000,716.083929,829.149834,36.183747 52 | 26112.000000,716.624338,830.600417,36.793660 53 | 26624.000000,706.441136,828.762623,37.811469 54 | 27136.000000,701.982221,828.580194,39.366759 55 | 27648.000000,694.455238,829.958749,38.413337 56 | 28160.000000,692.104445,826.715570,38.942091 57 | 28672.000000,688.816809,829.569590,39.355895 58 | 29184.000000,681.172885,829.385465,39.726389 59 | 29696.000000,675.388769,829.207682,40.389197 60 | 30208.000000,106.167593,829.035999,41.498068 61 | 30720.000000,553.513508,831.763084,41.624255 62 | 31232.000000,581.060485,830.086399,41.109949 63 | 31744.000000,95.327326,829.231011,41.764987 64 | 32256.000000,104.177634,829.736306,42.035917 65 | -------------------------------------------------------------------------------- /images/benchamrk_result/rms-norm-forward.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/benchamrk_result/rms-norm-forward.png -------------------------------------------------------------------------------- /images/benchamrk_result/skip_rmsnorm_benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/benchamrk_result/skip_rmsnorm_benchmark.png -------------------------------------------------------------------------------- /images/benchamrk_result/softmax-performance.csv: -------------------------------------------------------------------------------- 1 | N,Torch_softmax,Triton_softmax,Triton_online_v2_softmax 2 | 4096.000000,73.142856,85.333330,46.545454 3 | 12288.000000,139.636363,192.000000,127.999995 4 | 20480.000000,170.666671,256.000006,196.923079 5 | 28672.000000,170.666668,298.666656,255.999996 6 | 36864.000000,184.320005,230.400006,307.200008 7 | 45056.000000,187.733338,268.190478,331.294112 8 | 53248.000000,184.888882,302.545452,391.529405 9 | 61440.000000,182.857144,349.090906,426.666652 10 | 69632.000000,197.818180,119.232872,414.476194 11 | 77824.000000,198.530619,131.459454,442.181815 12 | 86016.000000,195.490908,139.636369,430.080011 13 | 94208.000000,193.049184,149.063296,436.148147 14 | 102400.000000,182.857138,154.216869,412.903231 15 | 110592.000000,170.666670,164.571430,406.588228 16 | 118784.000000,185.600005,181.073174,436.705875 17 | 126976.000000,186.626485,188.952382,453.485702 18 | 135168.000000,185.670326,92.835163,444.631599 19 | 143360.000000,188.631573,97.399575,459.487195 20 | 151552.000000,189.440005,102.956519,462.048788 21 | 159744.000000,190.171430,105.372030,453.818178 22 | 167936.000000,177.898300,107.102044,466.488882 23 | 176128.000000,193.122803,116.776395,468.425518 24 | 
184320.000000,193.613451,121.263154,470.204097 25 | 192512.000000,194.064519,124.683938,471.843146 26 | 200704.000000,194.480631,129.319592,464.592592 27 | 208896.000000,193.422230,131.216079,466.285708 28 | 217088.000000,180.906658,135.680003,467.862057 29 | 225280.000000,179.363055,142.944162,478.301479 30 | 233472.000000,195.865772,146.653265,486.400012 31 | 241664.000000,197.437907,151.798994,487.225813 32 | 249856.000000,196.427669,155.383083,488.000001 33 | 258048.000000,195.490903,159.683172,488.727269 34 | 266240.000000,193.488372,86.666664,489.411756 35 | 274432.000000,182.468080,85.546135,490.057130 36 | 282624.000000,198.471908,92.968418,497.577449 37 | 290816.000000,197.565213,95.162304,497.972583 38 | 299008.000000,198.808505,97.333330,505.081059 39 | 307200.000000,198.963731,99.481865,505.263180 40 | 315392.000000,199.111117,101.608251,492.800012 41 | 323584.000000,184.694069,102.399997,470.325581 42 | 331776.000000,185.142855,105.661149,499.662657 43 | 339968.000000,200.452831,108.132317,505.528626 44 | 348160.000000,200.553001,110.177212,512.000002 45 | 356352.000000,200.648646,112.484852,511.999998 46 | 364544.000000,199.859645,114.492461,506.311103 47 | 372736.000000,185.625505,114.758619,506.434771 48 | 380928.000000,181.740467,117.010598,511.999987 49 | 389120.000000,200.991740,120.396042,511.999984 50 | 397312.000000,201.502213,122.024566,512.000019 51 | 405504.000000,202.751990,124.540536,517.224507 52 | 413696.000000,202.000001,126.435208,517.120013 53 | 421888.000000,201.282453,126.162676,512.000008 54 | 430080.000000,186.666660,122.879996,521.941755 55 | 438272.000000,202.154984,132.323513,521.752385 56 | 446464.000000,202.202894,134.153848,521.570094 57 | 454656.000000,202.249120,134.992873,521.394493 58 | 462848.000000,200.888882,138.246122,521.225220 59 | 471040.000000,200.955638,137.249415,521.061938 60 | 479232.000000,187.200005,140.290401,495.330246 61 | 487424.000000,202.418610,143.698114,525.241366 62 | 495616.000000,201.798039,146.285716,525.016933 63 | 503808.000000,203.148385,148.528303,529.210099 64 | 512000.000000,201.892746,150.234737,528.925632 65 | 520192.000000,187.930643,150.170898,520.191975 66 | -------------------------------------------------------------------------------- /images/benchamrk_result/softmax-performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/benchamrk_result/softmax-performance.png -------------------------------------------------------------------------------- /images/benchamrk_result/te_benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/benchamrk_result/te_benchmark.png -------------------------------------------------------------------------------- /images/benchamrk_result/token_embedding_benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/benchamrk_result/token_embedding_benchmark.png -------------------------------------------------------------------------------- /images/cli_stream.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/cli_stream.png -------------------------------------------------------------------------------- /images/flashattention_nopad_benchamrk.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/flashattention_nopad_benchamrk.png -------------------------------------------------------------------------------- /images/flashattentionv2_nopad_benchamrk.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/flashattentionv2_nopad_benchamrk.png -------------------------------------------------------------------------------- /images/flashattentionv2_nopad_benchamrk2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/flashattentionv2_nopad_benchamrk2.png -------------------------------------------------------------------------------- /images/flashdecoding_benchamrk.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/flashdecoding_benchamrk.png -------------------------------------------------------------------------------- /images/generate.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/generate.gif -------------------------------------------------------------------------------- /images/generate_stream.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/generate_stream.png -------------------------------------------------------------------------------- /images/llava_output.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/llava_output.gif -------------------------------------------------------------------------------- /images/llava_output1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/llava_output1.gif -------------------------------------------------------------------------------- /images/llava_output2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/llava_output2.gif -------------------------------------------------------------------------------- /images/llava_output3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/llava_output3.gif -------------------------------------------------------------------------------- /images/llava_test/WechatIMG205.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/llava_test/WechatIMG205.jpg -------------------------------------------------------------------------------- /images/llava_test/dog.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/llava_test/dog.jpeg -------------------------------------------------------------------------------- /images/llava_test/dog2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/llava_test/dog2.png -------------------------------------------------------------------------------- /images/llava_test/extreme_ironing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/llava_test/extreme_ironing.jpg -------------------------------------------------------------------------------- /images/llava_test/graduate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/llava_test/graduate.png -------------------------------------------------------------------------------- /images/llava_test/kaali.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/llava_test/kaali.jpg -------------------------------------------------------------------------------- /images/llava_test/markdown.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/llava_test/markdown.png -------------------------------------------------------------------------------- /images/llava_test/mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/llava_test/mask.png -------------------------------------------------------------------------------- /images/llava_test/movie.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/llava_test/movie.jpeg -------------------------------------------------------------------------------- /images/llava_test/painting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/llava_test/painting.png -------------------------------------------------------------------------------- /images/llava_test/panda.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/llava_test/panda.jpg -------------------------------------------------------------------------------- 
/images/llava_test/pexels-christian-heitz-285904-842711.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/llava_test/pexels-christian-heitz-285904-842711.jpg -------------------------------------------------------------------------------- /images/llava_test/pexels-francesco-ungaro-1525041.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/llava_test/pexels-francesco-ungaro-1525041.jpg -------------------------------------------------------------------------------- /images/llava_test/pexels-sanaan-3052361.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/llava_test/pexels-sanaan-3052361.jpg -------------------------------------------------------------------------------- /images/llava_test/superJumbo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/llava_test/superJumbo.png -------------------------------------------------------------------------------- /images/llava_test/taitan.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/llava_test/taitan.jpg -------------------------------------------------------------------------------- /images/llava_test/website.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/llava_test/website.png -------------------------------------------------------------------------------- /images/output.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/output.gif -------------------------------------------------------------------------------- /images/qwen2.5-3b-output.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/images/qwen2.5-3b-output.gif -------------------------------------------------------------------------------- /lite_llama/__init__.py: -------------------------------------------------------------------------------- 1 | from lite_llama.generate import GenerateText 2 | from lite_llama.generate_stream import GenerateStreamText 3 | from lite_llama.llava_generate_stream import LlavaGeneratorStream 4 | -------------------------------------------------------------------------------- /lite_llama/executor/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/lite_llama/executor/__init__.py -------------------------------------------------------------------------------- /lite_llama/executor/cuda_graph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from copy import 
deepcopy 3 | from typing import Dict 4 | from .executor_struct import AttentionInfo 5 | from .mem_manager import KVCacheMemoryManager 6 | from ..models.utils import weak_ref_tensor 7 | 8 | _BATCH_SIZE_ALIGNMENT = 8 9 | _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ 10 | _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025) 11 | ] 12 | 13 | 14 | class CUDAGraphRunner: 15 | def __init__(self, model): 16 | self.model = model 17 | self._cuda_graph = None 18 | self._graph_inputs: Dict[str, torch.Tensor] = {} 19 | self._graph_output = None 20 | 21 | def capture( 22 | self, 23 | input_ids: torch.Tensor, 24 | position_ids: torch.Tensor, 25 | atten_info: AttentionInfo, 26 | ): 27 | assert self._cuda_graph is None, "Already compiled the model" 28 | # 用于捕获的占位符输入 29 | self._graph_inputs = [input_ids, position_ids, atten_info] 30 | 31 | # Warm up 32 | graph_capture_stream = torch.cuda.Stream() 33 | graph_capture_stream.wait_stream(torch.cuda.current_stream()) 34 | with torch.cuda.stream(graph_capture_stream): 35 | _ = self.model.forward( 36 | input_ids=input_ids, 37 | position_ids=position_ids, 38 | atten_info=atten_info, 39 | ) 40 | torch.cuda.current_stream().wait_stream(graph_capture_stream) 41 | 42 | # Capture the graph 43 | self._cuda_graph = torch.cuda.CUDAGraph() 44 | with torch.cuda.graph(self._cuda_graph): 45 | self._graph_output = self.model.forward( 46 | input_ids=input_ids, 47 | position_ids=position_ids, 48 | atten_info=atten_info, 49 | ) 50 | 51 | # Save the input and output buffers. 52 | self._graph_inputs = { 53 | "input_ids": input_ids, 54 | "position_ids": position_ids, 55 | "kv_buffer": atten_info.kv_buffer, 56 | "cur_select_index": atten_info.cur_select_index, 57 | "b_req_tokens_table": atten_info.b_req_tokens_table, 58 | "b_req_idx": atten_info.b_req_idx, 59 | } 60 | 61 | def forward( 62 | self, 63 | input_ids: torch.Tensor, 64 | position_ids: torch.Tensor, 65 | atten_info: AttentionInfo, 66 | ): 67 | del ( 68 | atten_info.kv_buffer 69 | ) # kv_buffer are fixed tensors, so we don't need to copy them. 
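        # Note: kv_buffer and b_req_tokens_table were captured by reference in capture(),
        # so the replayed graph reads them in place; only the small per-step tensors
        # (input_ids, position_ids, cur_select_index, b_req_idx) are copied into the
        # static input buffers below before replay.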
70 | del atten_info.b_req_tokens_table 71 | # 更新输入缓冲区 72 | self._graph_inputs["input_ids"].copy_(input_ids) # 据填充 graph 的输入内存 73 | self._graph_inputs["position_ids"].copy_(position_ids) 74 | 75 | self._graph_inputs["cur_select_index"].copy_(atten_info.cur_select_index) 76 | self._graph_inputs["b_req_idx"].copy_(atten_info.b_req_idx) 77 | 78 | self._cuda_graph.replay() 79 | 80 | return self._graph_output 81 | 82 | def __call__(self, *args, **kwargs): 83 | return self.forward(*args, **kwargs) 84 | 85 | 86 | class ModelRunner: 87 | def __init__( 88 | self, 89 | model, 90 | model_config, 91 | max_gpu_num_blocks: int, 92 | kv_mem_manager: KVCacheMemoryManager, 93 | req_tokens_manager, 94 | seq_len: int = 1, 95 | start_pos=8, 96 | ): 97 | self.model = model 98 | self.model_config = model_config 99 | self.max_gpu_num_blocks = max_gpu_num_blocks 100 | self.kv_mem_manager = kv_mem_manager 101 | self.req_tokens_manager = req_tokens_manager 102 | 103 | self.vocab_size = self.model_config.vocab_size 104 | self.graph_max_batch_size = self.model_config.max_batch_size 105 | self.max_seq_len = model_config.max_seq_len 106 | 107 | # 随机参数定义 108 | self.seq_len = seq_len 109 | self.start_pos = start_pos 110 | 111 | self.graph_runners = {} 112 | 113 | def build_atten_info(self, batch_size, atten_info, device="cuda"): 114 | """针对 decode 阶段, 构建 attention 输入信息结构体""" 115 | atten_info.kv_buffer = self.kv_mem_manager.gpu_kv_buffer # torch.Tensor 116 | atten_info.b_req_tokens_table = ( 117 | self.req_tokens.manager.b_req_tokens_table 118 | ) # torch.Tensor 119 | 120 | atten_info.b_req_idx = torch.arange(batch_size, device=device) # torch.Tensor 121 | atten_info.b_seq_len = torch.ones( 122 | batch_size, dtype=torch.int32, device="cuda" 123 | ) # torch.Tensor 124 | (atten_info.cur_select_index,) = self.kv_mem_manager.alloc_kvcache_index( 125 | batch_size 126 | ) # torch.Tensor 127 | 128 | return atten_info 129 | 130 | def capture_decode_graph( 131 | self, 132 | ): 133 | """ 134 | 针对 decode 阶段捕获 CUDA 图 135 | """ 136 | # 获取要捕获的批量大小列表,确保批量大小不超过最大批量大小 137 | batch_size_capture_list = [ 138 | bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= self.graph_max_batch_size 139 | ] 140 | atten_info = AttentionInfo 141 | print("cuda graph support batch list", batch_size_capture_list) 142 | 143 | # NOTE: Capturing the largest batch size first may help reduce the memory usage of CUDA graph. 
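        # Graphs are captured per batch size and stored in self.graph_runners; at decode
        # time a batch whose size is not in the capture list falls back to eager execution.
        # Capturing in reverse (largest first) is a heuristic: the largest capture claims
        # the biggest buffers up front, which may let later, smaller captures reuse memory.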
144 | for batch_size in reversed(batch_size_capture_list): 145 | # 构造输入 tokens id 张量 146 | input_ids = torch.randint(0, self.vocab_size, (batch_size, 1)).cuda() 147 | position_ids = ( 148 | torch.arange( 149 | self.start_pos, self.start_pos + 1, device=input_ids.device 150 | ) 151 | .unsqueeze(0) # shape: [1, seq_len] 152 | .expand(batch_size, -1) # shape: [batch_size, seq_len], 不分配额外内存 153 | ) 154 | atten_info = self.build_atten_info(batch_size, atten_info) 155 | print( 156 | "apply cuda grpah atten_info.decode_index shape ", 157 | atten_info.decode_index.shape, 158 | ) 159 | 160 | graph_intput = (input_ids, position_ids, atten_info) 161 | graph_runner = CUDAGraphRunner(self.model) 162 | 163 | # graph 图捕捉输入 164 | graph_runner.capture(*graph_intput) 165 | self.graph_runners[batch_size] = graph_runner 166 | 167 | self.kv_mem_manager.free_all() 168 | 169 | def decode( 170 | self, 171 | input_ids: torch.Tensor, 172 | position_ids: torch.Tensor, 173 | atten_info: AttentionInfo, 174 | ): 175 | batch_size = input_ids.shape[0] 176 | if batch_size in self.graph_runners: 177 | model_executable = self.graph_runners[batch_size] 178 | else: 179 | print( 180 | "Warning: CUDA graph not captured for this batch size, falling back to original model." 181 | ) 182 | model_executable = self.model 183 | 184 | logits = model_executable(input_ids, position_ids, atten_info) 185 | return logits 186 | -------------------------------------------------------------------------------- /lite_llama/executor/executor_struct.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import torch 3 | 4 | @dataclass 5 | class ModelRunnerConfig: 6 | block_size = 1 7 | checkpoints_dir = "/gemini/code/Llama-3.2-1B-Instruct" 8 | max_batch_size = 16 9 | gpu_memory_utilization = 0.9 10 | 11 | 12 | @dataclass 13 | class AttentionInfo: 14 | # kv_cache = None # prefill 阶段的 context kv cache 15 | kv_buffer = list[torch.tensor([])] 16 | cur_select_index = torch.empty((0,), dtype=torch.int32) 17 | b_req_tokens_table = None 18 | b_start_loc = None 19 | b_req_idx = None 20 | -------------------------------------------------------------------------------- /lite_llama/executor/req_tokens_manager.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | 4 | logger = logging.getLogger(__name__) 5 | 6 | 7 | class ReqTokensManager: 8 | """管理请求序列的 kv 内存 tokens 的类。 9 | 10 | TokenTable 将一系列 kv tokens 映射到一组token 表中, 每个 token 表代表请求序列分配的 kv cache 内存空间。 11 | """ 12 | 13 | def __init__(self, max_request_num, max_seq_len, mem_manager=None, device="cuda"): 14 | self.max_can_use_req_size = max_request_num 15 | self.can_use_req_size = max_request_num 16 | self.max_seq_len = max_seq_len 17 | self.req_state = torch.zeros( 18 | (max_request_num), dtype=torch.int32, device=device 19 | ) 20 | # 一个二维张量,形状为 [num_requests, max_seq_len],用于存储每个请求的 Token 索引。 21 | # 每行表示一个请求,每列表示该请求在特定序列位置上的 Token 索引。 22 | self.b_req_tokens_table = torch.zeros( 23 | (max_request_num, max_seq_len), dtype=torch.int32, device=device 24 | ) 25 | # self.mem_manager = mem_manager 26 | 27 | # 分配批次请求需要的内存空间 28 | def alloc_req(self, request_num): 29 | if request_num > self.can_use_req_size: 30 | logger.error( 31 | f"Insufficient requested capacity, remaining {self.can_use_req_size}" 32 | ) 33 | return None 34 | 35 | logical_select_index = torch.nonzero(self.req_state == 0).reshape(-1)[ 36 | :request_num 37 | ] 38 | self.req_state[logical_select_index] = 1 39 | 
self.can_use_req_size -= len(logical_select_index) 40 | return logical_select_index 41 | 42 | # 仅释放批次请求的索引 43 | def free_reqs(self, free_req_index, free_token_index): 44 | self.can_use_req_size += len(free_req_index) 45 | self.req_state[free_token_index] = 0 # 对应批次请求的索引重新置为 0 46 | if self.can_use_req_size == len(self.req_state): 47 | logger.debug(f"freed all request size {self.can_use_req_size}") 48 | # self.mem_manager.free(free_token_index) 49 | 50 | # 仅释放指定请求的索引 51 | def free_req(self, free_req_index): 52 | if free_req_index < 0 or free_req_index >= self.req_state.size(0): 53 | logger.error(f"Invalid free_req_index: {free_req_index}") 54 | return 55 | self.can_use_req_size += 1 56 | self.req_state[free_req_index] = 0 57 | return 58 | 59 | # 释放所有请求的内存,将所有请求状态 req_state 重置为未分配(都归 0)。 60 | def free_all(self): 61 | self.can_use_req_size = self.max_can_use_req_size 62 | self.req_state[:] = 0 63 | 64 | 65 | import unittest 66 | import torch 67 | 68 | 69 | class TestReqTokensManager(unittest.TestCase): 70 | def setUp(self): 71 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 72 | self.mem_manager_mock = unittest.mock.MagicMock() 73 | self.table = ReqTokensManager( 74 | max_request_num=10, 75 | max_seq_len=5, 76 | mem_manager=self.mem_manager_mock, 77 | device=self.device, 78 | ) 79 | 80 | def test_alloc_req(self): 81 | indices = self.table.alloc_req(3) 82 | self.assertEqual(len(indices), 3) 83 | self.assertTrue((self.table.req_state[indices] == 1).all()) 84 | 85 | def test_alloc_req_exceed_capacity(self): 86 | indices = self.table.alloc_req(11) 87 | self.assertIsNone(indices) 88 | 89 | def test_free_reqs(self): 90 | indices = self.table.alloc_req(3) 91 | self.table.free_reqs(indices, indices) 92 | self.assertTrue((self.table.req_state[indices] == 0).all()) 93 | 94 | def test_free_all(self): 95 | self.table.alloc_req(5) 96 | self.table.free_all() 97 | self.assertTrue((self.table.req_state == 0).all()) 98 | self.assertEqual(self.table.can_use_req_size, self.table.max_can_use_req_size) 99 | 100 | def test_invalid_free_req(self): 101 | self.table.free_req(-1) # Should not raise an error 102 | self.table.free_req(100) # Should not raise an error 103 | 104 | 105 | if __name__ == "__main__": 106 | unittest.main() 107 | -------------------------------------------------------------------------------- /lite_llama/inference.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import torch 3 | 4 | import time 5 | 6 | from .utils.prompt_templates import get_prompter 7 | from .generate import GenerateText 8 | 9 | 10 | class Inference(object): 11 | def __init__( 12 | self, 13 | temperature: float, 14 | top_p: float, 15 | max_seq_len: int, 16 | max_gen_len: Optional[int], 17 | lite_llama_ckpt_dir: str, 18 | device: str = "cuda", 19 | ): 20 | self.temperature = temperature 21 | self.top_p = top_p 22 | self.max_seq_len = max_seq_len 23 | self.max_gen_len = max_gen_len 24 | self.lite_llama_ckpt_dir = lite_llama_ckpt_dir 25 | self.device = device 26 | 27 | def load_generator(self, max_gpu_num_blocks=None) -> GenerateText: 28 | """ 29 | Initializes the lite-llama generator 30 | """ 31 | generator = GenerateText( 32 | checkpoints_dir=self.lite_llama_ckpt_dir, 33 | tokenizer_path=self.lite_llama_ckpt_dir, 34 | max_seq_len=self.max_seq_len, 35 | max_gpu_num_blocks=max_gpu_num_blocks, 36 | compiled_model=True, 37 | device=self.device, 38 | ) 39 | return generator 40 | 41 | def count_tokens(self, texts: list[str], tokenizer) -> int: 42 | 
# Optimized segmentation statistics 43 | total_tokens = 0 44 | for t in texts: 45 | ids = tokenizer(t, add_special_tokens=False)["input_ids"] 46 | total_tokens += len(ids) 47 | return total_tokens 48 | 49 | def inference(self, generator: GenerateText, prompts: list[str]): 50 | """ 51 | Inference is performed using lite-llama's GenerateText instance and returns 52 | the result with the time taken and the number of tokens output 53 | """ 54 | 55 | # Warm-up step: use a short dummy input to allow the model to 56 | # perform a simple inference to load caches/compile optimizations, etc. 57 | warm_up_prompt = ["Hello World"] * 4 58 | _ = generator.text_completion( 59 | warm_up_prompt, 60 | temperature=self.temperature, 61 | top_p=self.top_p, 62 | max_gen_len=5, 63 | ) 64 | 65 | start_time = time.time() 66 | results = generator.text_completion( 67 | prompts, 68 | temperature=self.temperature, 69 | top_p=self.top_p, 70 | max_gen_len=self.max_gen_len, 71 | ) 72 | end_time = time.time() 73 | 74 | total_tokens = self.count_tokens(results, generator.tokenizer) 75 | 76 | return results, end_time - start_time, total_tokens 77 | 78 | def process(self, prompts): 79 | if "qwen2" in self.lite_llama_ckpt_dir.lower(): 80 | model_type = "qwen2" 81 | elif "llama" in self.lite_llama_ckpt_dir.lower(): 82 | model_type = "llama" 83 | elif "llava" in self.lite_llama_ckpt_dir.lower(): 84 | model_type = "llava" 85 | else: 86 | print("Error! Unsupported model type!") 87 | 88 | model_prompter = get_prompter(model_type, self.lite_llama_ckpt_dir) 89 | update_prompts = [] 90 | for prompt in prompts: 91 | model_prompter.insert_prompt(prompt) 92 | update_prompts.append(model_prompter.model_input) 93 | 94 | # 1. lite-llama inference 95 | lite_llama_generator = self.load_generator(max_gpu_num_blocks=40960) 96 | lite_llama_results, lite_llama_time, lite_llama_tokens = self.inference( 97 | lite_llama_generator, update_prompts 98 | ) 99 | del lite_llama_generator 100 | torch.cuda.empty_cache() # Release the memory used by lite_llama_generator after use. 
101 | 102 | return lite_llama_results 103 | -------------------------------------------------------------------------------- /lite_llama/kernels/__init__.py: -------------------------------------------------------------------------------- 1 | from .activations import gelu, relu, leaky_relu, tanh 2 | 3 | from .flashattention import flash_attention_v1 4 | from .flashattention2_nopad import flash_attention2_no_pad 5 | from .flashattentionv2 import flash_attention_v2 6 | from .flashdecoding import flash_decoding 7 | 8 | from .skip_rmsnorm import skip_rmsnorm 9 | from .swiglu import swiglu_forward 10 | from .rope_emb import rope_emb_forward 11 | from .softmax_split import softmax_split 12 | from .update_kv_buffer import update_kv_buffer 13 | from .update_kv_index import update_kv_index 14 | 15 | # from .others.activation_layers import ACT2FN 16 | # from .others.rmsnorm_v1 import rmsnorm 17 | # from .others.fused_linear import (fused_linear) 18 | # from .others.rope_orig import (precompute_freqs_cis, rope) 19 | # from .others.layernorm import layernorm 20 | # from .others.rotary_emb_v1 import rotary_emb_fwd 21 | # from .others.context_flashattention_nopad import context_attention_fwd_no_prompt_cache 22 | # from .others.rmsnorm_layer import rmsnorm_fwd 23 | -------------------------------------------------------------------------------- /lite_llama/kernels/activations.py: -------------------------------------------------------------------------------- 1 | import triton 2 | import triton.language as tl 3 | import math 4 | 5 | sqrt2 = math.sqrt(2.0) 6 | 7 | 8 | # 激活函数都是逐元素操作算子,所以无需指定维度参数 9 | @triton.jit 10 | def relu(x): 11 | """ReLU(Rectified Linear Unit, 修正线性单元), only support inference. 12 | max(0, x) 13 | """ 14 | return tl.maximum(0, x) 15 | 16 | 17 | # Leaky ReLU 18 | @triton.jit 19 | def leaky_relu(x): 20 | """ 21 | LeakyReLU_ activation 22 | 23 | .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html 24 | """ 25 | scale = 1e-2 26 | scale = scale.to(x.dtype) 27 | return tl.where(x >= 0, x, scale * x) 28 | 29 | 30 | @triton.jit 31 | def tanh(x): 32 | """ 33 | Tanh(双曲正切)函数也是一种 Sigmoid 型函数,可以看作放大并平移的 Sigmoid 函数, only support inference. 
34 | 2 / (1+e^{-2x}) -1 35 | """ 36 | return 2 / (1 + tl.exp(-2 * x)) - 1 37 | 38 | 39 | @triton.jit 40 | def gelu(x): 41 | """Gaussian Error Linear Unit (GELU), only support inference.""" 42 | return x * 0.5 * (1.0 + tl.libdevice.erf(x / sqrt2)) 43 | 44 | 45 | @triton.jit 46 | def silu(x): 47 | return x * tl.sigmoid(x) 48 | -------------------------------------------------------------------------------- /lite_llama/kernels/others/fused_linear.py: -------------------------------------------------------------------------------- 1 | # reference: https://github.com/fpgaminer/GPTQ-triton/blob/main/src/gptq_triton/quant_linear.py 2 | 3 | import math 4 | 5 | import torch 6 | import triton 7 | import triton.language as tl 8 | 9 | torch.backends.cuda.matmul.allow_tf32 = True 10 | torch.backends.cudnn.allow_tf32 = True 11 | 12 | 13 | # tl.math.tanh doesn't exist in CPU version of triton 14 | @triton.jit 15 | def tanh(x): 16 | return 2 * tl.sigmoid(2 * x) - 1 17 | 18 | 19 | @triton.jit 20 | def gelu_new(x): 21 | pi = tl.constexpr(tl.float32(math.pi)) 22 | a = tl.math.sqrt(2.0 / pi) 23 | b = x + 0.044715 * x * x * x 24 | return 0.5 * x * (1.0 + tanh(a * b)) 25 | 26 | 27 | @triton.jit 28 | def silu(x): 29 | return x * tl.sigmoid(x) 30 | 31 | 32 | @triton.jit 33 | def _fused_linear_kernel_fwd( 34 | x_ptr, # 输入数据矩阵首元素指针 35 | w_ptr, # 权重矩阵首元素指针 36 | z_ptr, # 输出结果地址 37 | M, 38 | N, 39 | K, # Matrix dimensions 40 | b_ptr=None, 41 | r_ptr=None, 42 | apply_silu=False, # gelu 激活和 dropout 43 | seed=1337, 44 | BLOCK_SIZE_M: tl.constexpr = 128, # 块大小 45 | BLOCK_SIZE_N: tl.constexpr = 128, 46 | BLOCK_SIZE_K: tl.constexpr = 64, 47 | ): 48 | # 当前 kernel 在 M/N 方向的程序 id 49 | pid_m = tl.program_id( 50 | 0 51 | ) # 二维内核允许在行(M)和列(N)两个方向上并行计算,极大地提高了计算效率。 52 | pid_n = tl.program_id(1) 53 | 54 | # 计算行列索引偏移,offs_m: 当前块负责的行索引,形状为 (BLOCK_SIZE_M, 1)。 55 | offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)[:, None] 56 | offs_n = ( 57 | pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)[None, :] 58 | ) # 形状为 (1, BLOCK_SIZE_N)。 59 | 60 | # 子块的矩阵乘法 61 | z = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) 62 | for k in range(0, K, BLOCK_SIZE_K): 63 | x_k = tl.arange(0, BLOCK_SIZE_K)[None, :] + k 64 | # (BLOCK_SIZE_M, BLOCK_SIZE_K) 65 | x = tl.load(x_ptr + offs_m * K + x_k, mask=(offs_m < M) & (x_k < K), other=0.0) 66 | x = x.to(tl.float16) 67 | 68 | w_k = tl.arange(0, BLOCK_SIZE_K)[:, None] + k 69 | # (BLOCK_SIZE_K, BLOCK_SIZE_N) 70 | w = tl.load(w_ptr + w_k * N + offs_n, mask=(w_k < K) & (offs_n < N), other=0.0) 71 | w = w.to(tl.float16) 72 | 73 | # (BLOCK_SIZE_M, BLOCK_SIZE_N) 74 | z = tl.dot(x, w, acc=z) 75 | 76 | if b_ptr is not None: 77 | b = tl.load(b_ptr + offs_n, mask=(offs_n < N), other=0.0) 78 | z += b.to(tl.float32) 79 | # (1, BLOCK_SIZE_N) 80 | 81 | z_offset = offs_m * N + offs_n 82 | z_mask = (offs_m < M) & (offs_n < N) 83 | 84 | if apply_silu: 85 | z = silu(z) 86 | 87 | if r_ptr is not None: 88 | r = tl.load(r_ptr + z_offset, mask=z_mask) 89 | z += r.to(tl.float32) 90 | 91 | tl.store(z_ptr + z_offset, z, mask=z_mask) 92 | 93 | 94 | @torch.no_grad() 95 | def fused_linear( 96 | x, 97 | weight, 98 | bias=None, 99 | residual=None, # 残差输入项 100 | add_silu=False, 101 | ): 102 | """ 103 | x: (*, K) 104 | weight: (K, N) 105 | bias: (N,) 106 | f = silu(x @ w + b) + residual 107 | """ 108 | # 将 x 形状去除最后一个维度,保存为 out_shape_0 109 | out_shape_0 = x.shape[:-1] 110 | # 将 x 的所有维度压缩为二维张量, [B, L, K] -> [M, K], K 是隐藏层的维度。 111 | x = x.view((-1, x.shape[-1])) 112 | M, K = x.shape 113 | N = weight.shape[1] 114 | 115 | # 
Allocates output. 116 | z = torch.empty((M, N), device=x.device, dtype=x.dtype) 117 | 118 | # assert x.shape[1] == weight.shape[0] 119 | assert x.is_contiguous() 120 | assert weight.is_contiguous() 121 | 122 | if bias is not None: 123 | assert bias.is_contiguous() 124 | assert weight.shape[1] == bias.shape[0] 125 | if residual is not None: 126 | residual = residual.view(z.shape) 127 | assert residual.is_contiguous() 128 | 129 | BLOCK_SIZE_M = 64 130 | BLOCK_SIZE_N = 64 131 | BLOCK_SIZE_K = 32 132 | 133 | # 2D launch kernel where each block gets its own program. 134 | grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(N, BLOCK_SIZE_N), 1) 135 | _fused_linear_kernel_fwd[grid]( 136 | x, 137 | weight, 138 | z, 139 | M, 140 | N, 141 | K, 142 | apply_silu=add_silu, 143 | b_ptr=bias, 144 | r_ptr=residual, 145 | BLOCK_SIZE_M=BLOCK_SIZE_M, 146 | BLOCK_SIZE_N=BLOCK_SIZE_N, 147 | BLOCK_SIZE_K=BLOCK_SIZE_K, 148 | ) 149 | return z.view((*out_shape_0, N)) 150 | -------------------------------------------------------------------------------- /lite_llama/kernels/others/layernorm.py: -------------------------------------------------------------------------------- 1 | import triton, torch 2 | import triton.language as tl 3 | 4 | 5 | @triton.jit 6 | def _layernorm_kernel_fwd( 7 | x_ptr, 8 | weight_ptr, 9 | bias_ptr, 10 | z_ptr, 11 | H, 12 | eps=1e-5, 13 | BLOCK_SIZE: tl.constexpr = 16, 14 | ): 15 | row_idx = tl.program_id(0) 16 | x_row_ptr = x_ptr + row_idx * H # 一行 H 个元素,H 表示嵌入层大小 17 | z_row_ptr = z_ptr + row_idx * H 18 | 19 | # 1, compute mean 20 | _sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32) 21 | for i in range(0, H, BLOCK_SIZE): 22 | col_offsets = i + tl.arange(0, BLOCK_SIZE) 23 | x = tl.load(x_row_ptr + col_offsets, mask=col_offsets < H) 24 | _sum += x.to(tl.float32) 25 | 26 | mean = tl.sum(_sum, axis=0) / H 27 | 28 | # 2, compute variance 29 | x_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) 30 | for i in range(0, H, BLOCK_SIZE): 31 | col_offsets = i + tl.arange(0, BLOCK_SIZE) 32 | x = tl.load(x_row_ptr + col_offsets, mask=col_offsets < H).to(tl.float32) 33 | x = tl.where(col_offsets < H, x - mean, 0.0) 34 | x_var += x * x 35 | 36 | x_var = tl.sum(x_var, axis=0) / H 37 | rtsd = tl.sqrt(x_var + eps) 38 | 39 | # 3, compute ln(x_i) 40 | for i in range(0, H, BLOCK_SIZE): 41 | col_offsets = i + tl.arange(0, BLOCK_SIZE) 42 | mask = col_offsets < H 43 | x = tl.load(x_row_ptr + col_offsets, mask=mask) 44 | w = tl.load(weight_ptr + col_offsets, mask=mask) 45 | b = tl.load(bias_ptr + col_offsets) 46 | 47 | x_hat = (x - mean) / rtsd 48 | z = x_hat * w + b 49 | tl.store(z_row_ptr + col_offsets, z, mask=mask) 50 | 51 | 52 | @torch.no_grad() 53 | def layernorm(x, weight, bias, eps=1e-5): 54 | # 只针对 nlp 领域的 layernorm,省去了 normalized_shape 参数 55 | assert x.is_contiguous() 56 | assert weight.is_contiguous() 57 | assert bias.is_contiguous() 58 | 59 | assert x.shape[-1] == weight.shape[0] == bias.shape[0] 60 | out_shape = x.shape 61 | x = x.view(-1, x.shape[-1]) # if: [B, L, H] then -> [B*L, H] 62 | BL, H = x.shape 63 | z = torch.empty(x.shape, device=x.device, dtype=x.dtype) 64 | 65 | # Less than 64KB per feature: enqueue fused kernel 66 | MAX_FUSED_SIZE = 4096 // x.element_size() 67 | BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H)) 68 | num_warps = min(max(BLOCK_SIZE // 256, 1), 8) 69 | 70 | _layernorm_kernel_fwd[BL,]( 71 | x, weight, bias, z, H, eps, BLOCK_SIZE, num_warps=num_warps 72 | ) 73 | return z.view(out_shape) 74 | -------------------------------------------------------------------------------- 
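A minimal parity check for the layernorm wrapper above (a sketch, not part of the repo: it assumes a CUDA device and that the module is importable as lite_llama.kernels.others.layernorm from the repo root):

import torch
from lite_llama.kernels.others.layernorm import layernorm  # assumed import path

x = torch.randn(4, 128, 1024, dtype=torch.float16, device="cuda")
w = torch.randn(1024, dtype=torch.float16, device="cuda")
b = torch.randn(1024, dtype=torch.float16, device="cuda")

ref = torch.nn.functional.layer_norm(x, (1024,), weight=w, bias=b, eps=1e-5)
out = layernorm(x, w, b, eps=1e-5)
print("max abs diff:", (ref - out).abs().max().item())  # expect a small fp16-level difference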
/lite_llama/kernels/others/rmsnorm_layer.py: -------------------------------------------------------------------------------- 1 | """ 2 | modified from https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/rms_norm.py 3 | """ 4 | 5 | import operator 6 | 7 | import torch 8 | import triton 9 | import triton.language as tl 10 | 11 | from ..utils import ( 12 | calculate_settings, 13 | compare_version, 14 | ) 15 | 16 | if compare_version("triton", operator.ge, "3.0.0"): 17 | try: 18 | # typical import path with dispatch available 19 | from triton.language.extra.libdevice import rsqrt 20 | except ModuleNotFoundError: 21 | # for working with NGC containers 22 | from triton.language.extra.cuda.libdevice import rsqrt 23 | else: 24 | from triton.language.math import rsqrt 25 | 26 | 27 | @triton.jit 28 | def _rms_norm_forward_kernel( 29 | Y_ptr, 30 | Y_row_stride, 31 | X_ptr, 32 | X_row_stride, 33 | W_ptr, 34 | W_row_stride, 35 | n_cols, 36 | eps, 37 | offset, 38 | BLOCK_SIZE: tl.constexpr, 39 | ): 40 | """ 41 | y_i = (x_i / (RMS)) * (offset_wi + wi), RMS = sqrt(sum(x_i^2) / N) 42 | 43 | Reference: 44 | 1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html 45 | 2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22 46 | 3. https://arxiv.org/pdf/1910.07467 47 | """ 48 | 49 | row_idx = tl.program_id(0) 50 | col_offsets = tl.arange(0, BLOCK_SIZE) 51 | mask = col_offsets < n_cols 52 | 53 | Y_ptr += row_idx * Y_row_stride 54 | X_ptr += row_idx * X_row_stride 55 | 56 | X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0) 57 | X_row_dtype = X_row.dtype 58 | W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0) 59 | X_row = X_row.to(tl.float32) # On Llama, only rstd is computed on fp32 60 | 61 | mean_square = tl.sum(X_row * X_row, axis=0) / n_cols 62 | rstd = rsqrt(mean_square + eps) 63 | 64 | X_row = X_row * rstd 65 | X_row = X_row.to(X_row_dtype) 66 | Y_row = X_row * (W_row + offset) 67 | 68 | tl.store(Y_ptr + col_offsets, Y_row.to(X_row_dtype), mask=mask) 69 | 70 | 71 | @torch.no_grad() 72 | def rmsnorm_fwd(X, W, eps=1e-5, offset=0.0): 73 | shape = X.shape 74 | X = X.view(-1, shape[-1]) 75 | n_rows, n_cols = X.shape 76 | BLOCK_SIZE, num_warps = calculate_settings(n_cols) 77 | 78 | Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) 79 | 80 | # Check constraints. 
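    # The kernel assigns one row per program and loads the whole hidden dimension in a
    # single block, so n_cols must fit in the BLOCK_SIZE returned by calculate_settings
    # (next power of two, capped at 65536 elements).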
81 | assert X.shape[1] == W.shape[0], ( 82 | "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]" 83 | ) 84 | 85 | _rms_norm_forward_kernel[(n_rows,)]( 86 | Y, 87 | Y.stride(0), 88 | X, 89 | X.stride(0), 90 | W, 91 | W.stride(0), 92 | n_cols, 93 | eps, 94 | offset, 95 | BLOCK_SIZE=BLOCK_SIZE, 96 | num_warps=num_warps, 97 | ) 98 | return Y.view(*shape) 99 | 100 | 101 | def test_rms_layernorm( 102 | dim=1024, 103 | eps=1e-5, 104 | dtype=torch.float16, 105 | bsz=21, 106 | random_state=3407, 107 | seqlen=3341, 108 | ): 109 | from transformers.models.llama.modeling_llama import LlamaRMSNorm 110 | 111 | layernorm = LlamaRMSNorm((dim,), eps=eps).to("cuda") 112 | torch.cuda.manual_seed(random_state) 113 | torch.manual_seed(random_state) 114 | torch.nn.init.uniform_(layernorm.weight) 115 | X = torch.randn((bsz, seqlen, dim), dtype=dtype, device="cuda") 116 | Y = layernorm(X) 117 | Y2 = rmsnorm_fwd(X, layernorm.weight, eps) 118 | 119 | assert torch.amax(Y - Y2).item() <= 0.05 120 | print("max delta:", torch.max(torch.abs(Y - Y2))) 121 | 122 | 123 | def testing_suite_layernorm(): 124 | for dim in [512, 1024, 2048]: 125 | for dtype in [torch.float16, torch.bfloat16]: 126 | with torch.autocast(device_type="cuda", dtype=dtype): 127 | for seqlen in [3341, 2048, 349]: 128 | for random_state in [3407, 42]: 129 | test_rms_layernorm( 130 | dim=dim, 131 | eps=1e-5, 132 | dtype=dtype, 133 | bsz=21, 134 | random_state=random_state, 135 | seqlen=seqlen, 136 | ) 137 | 138 | 139 | if __name__ == "__main__": 140 | testing_suite_layernorm() 141 | -------------------------------------------------------------------------------- /lite_llama/kernels/others/rmsnorm_v1.py: -------------------------------------------------------------------------------- 1 | # modified from https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html 2 | import triton, torch, os 3 | import triton.language as tl 4 | from ..utils import calculate_settings 5 | 6 | 7 | @triton.jit 8 | def _rmsnorm_kernel_fwd( 9 | x_ptr, # shape is [M, K] 10 | w_ptr, # gamma 参数地址 11 | z_ptr, # 输出结果首元素指针 12 | K, # 权重 W 大小, 也是输入 X 的第二维度大小 13 | eps, # epsilon to avoid division by zero 14 | BLOCK_SIZE: tl.constexpr = 8, 15 | ): 16 | """z = (x / (rms)) * w""" 17 | row_idx = tl.program_id(0) 18 | x_row_ptr = x_ptr + row_idx * K # 一行有 K 个元素,K 是最后一维 19 | z_row_ptr = z_ptr + row_idx * K 20 | 21 | # Compute variance 22 | _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) 23 | for col_index in range(0, K, BLOCK_SIZE): 24 | col_offsets = col_index + tl.arange(0, BLOCK_SIZE) 25 | x_ptrs = x_row_ptr + col_offsets 26 | 27 | x = tl.load(x_ptrs, mask=col_offsets < K, other=0.0).to(tl.float32) 28 | _var += x * x 29 | 30 | var = tl.sum(_var, axis=0) / K 31 | rsqrt = 1 / tl.sqrt(var + eps) 32 | 33 | # Normalize and apply rmsnorm 34 | for col_index in range(0, K, BLOCK_SIZE): 35 | col_offsets = col_index + tl.arange(0, BLOCK_SIZE) 36 | mask = col_offsets < K 37 | 38 | x = tl.load(x_row_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) 39 | w = tl.load(w_ptr + col_offsets, mask=mask).to(tl.float32) 40 | 41 | normed = x * rsqrt 42 | normed = normed.to(w.dtype) # Exact copy from HF 43 | z = normed * w 44 | tl.store(z_row_ptr + col_offsets, z.to(z.dtype), mask=mask) 45 | 46 | 47 | @torch.no_grad() 48 | def rmsnorm(x, weight, eps=1e-5): 49 | z = torch.empty_like(x) # z 是三维的, [B, L, K] 50 | out_shape = x.shape 51 | x = x.view( 52 | (-1, x.shape[-1]) 53 | ) # 将 x 的所有维度压缩为二维张量, [B, L, K] -> [M, K], K 是隐藏层的维度。 54 | M, K = x.shape 55 | 56 | # Less 
than 64KB per feature: enqueue fused kernel 57 | # MAX_FUSED_SIZE = 65536 // x.element_size() # 用于返回张量中单个元素的大小(以字节为单位)。 58 | BLOCK_SIZE, num_warps = calculate_settings(K) 59 | if K > BLOCK_SIZE: 60 | raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") 61 | # heuristics for number of warps 62 | num_warps = min(max(BLOCK_SIZE // 256, 1), 8) 63 | _rmsnorm_kernel_fwd[M,]( 64 | x, 65 | weight, 66 | z, 67 | K, 68 | eps=eps, 69 | BLOCK_SIZE=BLOCK_SIZE, 70 | num_warps=num_warps, 71 | ) 72 | return z.view(out_shape) 73 | 74 | 75 | def test_rms_layernorm( 76 | dim=1024, 77 | eps=1e-5, 78 | dtype=torch.float16, 79 | bsz=21, 80 | random_state=3407, 81 | seqlen=3341, 82 | ): 83 | from transformers.models.llama.modeling_llama import LlamaRMSNorm 84 | 85 | layernorm = LlamaRMSNorm((dim,), eps=eps).to("cuda") 86 | torch.cuda.manual_seed(random_state) 87 | torch.manual_seed(random_state) 88 | torch.nn.init.uniform_(layernorm.weight) 89 | X = torch.randn((bsz, seqlen, dim), dtype=dtype, device="cuda") 90 | Y = layernorm(X) 91 | Y2 = rmsnorm(X, layernorm.weight, eps) 92 | 93 | assert torch.amax(Y - Y2).item() <= 0.05 94 | print("max delta:", torch.max(torch.abs(Y - Y2))) 95 | 96 | 97 | def testing_suite_layernorm(): 98 | for dim in [512, 1024, 2048]: 99 | for dtype in [torch.float16, torch.bfloat16]: 100 | with torch.autocast(device_type="cuda", dtype=dtype): 101 | for seqlen in [3341, 2048, 349]: 102 | for random_state in [3407, 42]: 103 | test_rms_layernorm( 104 | dim=dim, 105 | eps=1e-5, 106 | dtype=dtype, 107 | bsz=21, 108 | random_state=random_state, 109 | seqlen=seqlen, 110 | ) 111 | 112 | 113 | if __name__ == "__main__": 114 | testing_suite_layernorm() 115 | -------------------------------------------------------------------------------- /lite_llama/kernels/others/rope_orig.py: -------------------------------------------------------------------------------- 1 | # reference: https://github.com/nicksypark/rope-triton/blob/main/rope_triton/rope_triton.py 2 | 3 | import torch 4 | import triton 5 | import triton.language as tl 6 | from typing import Tuple, Union 7 | 8 | 9 | @triton.jit 10 | def rope_kernel_fw( 11 | input_ptr, 12 | in_seq_len_stride, 13 | in_batch_stride, 14 | output_ptr, 15 | cos_ptr, 16 | sin_ptr, 17 | cos_stride, 18 | sin_stride, 19 | seq_len, 20 | head_dim, 21 | BLOCK_SIZE: tl.constexpr, 22 | BATCH_NUM: tl.constexpr, 23 | ): 24 | pid_seq = tl.program_id(axis=0) 25 | pid_head = tl.program_id(axis=1) 26 | 27 | head_dim_offset = tl.arange(0, BLOCK_SIZE) # [0:head_dim/2] 28 | head_dim_mid = head_dim // 2 29 | 30 | mask = head_dim_offset < head_dim_mid 31 | 32 | cos_offset = (pid_seq % seq_len) * cos_stride + head_dim_offset 33 | sin_offset = (pid_seq % seq_len) * sin_stride + head_dim_offset 34 | 35 | cos = tl.load(cos_ptr + cos_offset, mask=mask, other=0.0) 36 | sin = tl.load(sin_ptr + sin_offset, mask=mask, other=0.0) 37 | 38 | for batch_idx in tl.static_range(0, BATCH_NUM): 39 | x1_offset = ( 40 | pid_seq * in_seq_len_stride 41 | + batch_idx * in_batch_stride 42 | + pid_head * head_dim 43 | + head_dim_offset 44 | ) 45 | x2_offset = ( 46 | pid_seq * in_seq_len_stride 47 | + batch_idx * in_batch_stride 48 | + pid_head * head_dim 49 | + head_dim_mid 50 | + head_dim_offset 51 | ) 52 | 53 | x1 = tl.load(input_ptr + x1_offset, mask=mask, other=0.0) 54 | x2 = tl.load(input_ptr + x2_offset, mask=mask, other=0.0) 55 | 56 | y1 = x1 * cos - x2 * sin 57 | y2 = x1 * sin + x2 * cos 58 | 59 | tl.store(output_ptr + x1_offset, y1, mask=mask) 60 | tl.store(output_ptr + x2_offset, 
y2, mask=mask) 61 | return 62 | 63 | 64 | @torch.no_grad() 65 | def rope( 66 | t: torch.Tensor, 67 | freqs: torch.Tensor, 68 | tensor_format: str = "sbhd", 69 | cu_seqlens: Union[torch.Tensor, None] = None, 70 | ) -> torch.Tensor: 71 | if tensor_format == "bshd": 72 | t = t.transpose(0, 1) 73 | elif tensor_format != "sbhd": 74 | raise ValueError(f"Unsupported tensor_format: {tensor_format}.") 75 | 76 | seq_len, batch_num, head_num, head_dim = t.shape 77 | assert t.device.type == "cuda", "Input tensor t must be on CUDA device" 78 | assert freqs.device.type == "cuda", "Input tensor freqs must be on CUDA device" 79 | 80 | output = torch.empty_like(t, device="cuda") 81 | 82 | BLOCK_SIZE = triton.next_power_of_2(head_dim // 2) 83 | 84 | grid = (seq_len, head_num) 85 | 86 | freqs = freqs[:seq_len] 87 | cos = torch.cos(freqs).to(t.dtype) 88 | sin = torch.sin(freqs).to(t.dtype) 89 | 90 | rope_kernel_fw[grid]( 91 | t, 92 | t.stride(0), 93 | t.stride(1), 94 | output, 95 | cos, 96 | sin, 97 | cos.stride(0), 98 | sin.stride(0), 99 | seq_len, 100 | head_dim, 101 | BLOCK_SIZE, 102 | batch_num, 103 | ) 104 | 105 | if tensor_format == "bshd": 106 | return output.transpose(0, 1) 107 | 108 | return output.to("cuda") 109 | 110 | 111 | def compute_theta( 112 | dim: int, base: float = 10000.0, device: torch.device = torch.device("cuda") 113 | ) -> torch.Tensor: 114 | """ 115 | 计算旋转位置编码中的 Theta 角度值。 116 | 117 | 参数: 118 | - d (int): 嵌入向量的维度(必须为偶数)。 119 | - base (float): 基础频率参数, 默认为10000.0。 120 | - device (torch.device): 计算设备, 默认为CPU。 121 | 122 | 返回: 123 | - torch.Tensor: 包含Theta值的1D张量, 形状为 [d/2]。 124 | """ 125 | if dim % 2 != 0: 126 | print("嵌入维度 dim 必须为偶数") 127 | i = torch.arange(1, (dim // 2) + 1, dtype=torch.float32, device=device) 128 | theta_i = base ** (-2 * (i - 1) / dim) 129 | 130 | return theta_i 131 | 132 | 133 | def precompute_freqs_cis( 134 | dim: int, 135 | seq_len: int, 136 | base: float = 10000.0, 137 | device: torch.device = torch.device("cuda"), 138 | ): 139 | theta = compute_theta(dim, base, device) # theta 角度值序列,向量, 大小为 dim // 2 140 | m = torch.arange(seq_len, device=device) # token 位置值序列,向量,大小为 seq_len 141 | m_theta = torch.outer( 142 | m, theta 143 | ) # 所有 token 位置的所有 Theta 值范围, 矩阵,尺寸为 [seq_len, dim // 2] 144 | freqs_cis = torch.polar( 145 | torch.ones_like(m_theta), m_theta 146 | ) # e^{i*m*\theta},本质上是旋转矩阵 147 | 148 | return freqs_cis 149 | 150 | 151 | def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: 152 | """同一组的 kv cache 复制多份""" 153 | batch_size, seq_len, num_kv_heads, head_dim = x.shape 154 | if n_rep == 1: 155 | return x 156 | return ( 157 | # (B, Seq_Len, num_kv_heads, 1, Head_Dim) 158 | x[:, :, :, None, :] 159 | # (B, Seq_Len, num_kv_heads, N_Rep, Head_Dim) 160 | .expand(batch_size, seq_len, num_kv_heads, n_rep, head_dim) 161 | # (B, Seq_Len, num_kv_heads * N_Rep, Head_Dim) 162 | .reshape(batch_size, seq_len, num_kv_heads * n_rep, head_dim) 163 | ) 164 | -------------------------------------------------------------------------------- /lite_llama/kernels/others/rotary_emb_v1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | 5 | 6 | @triton.jit 7 | def _rotary_kernel( 8 | Q, 9 | K, 10 | Cos, 11 | Sin, 12 | stride_qbs, 13 | stride_qh, 14 | stride_qd, 15 | stride_kbs, 16 | stride_kh, 17 | stride_kd, 18 | stride_cosbs, 19 | stride_cosd, 20 | stride_sinbs, 21 | stride_sind, 22 | max_total_len, 23 | HEAD_Q, 24 | HEAD_K, # N_CTX 代表要计算的上下文长度 25 | BLOCK_HEAD: tl.constexpr, 26 
| BLOCK_SEQ: tl.constexpr, 27 | BLOCK_DMODEL: tl.constexpr, 28 | ): 29 | cur_head_index = tl.program_id(0) 30 | cur_seq_index = tl.program_id(1) 31 | 32 | cur_head_range = cur_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) 33 | cur_seq_range = cur_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) 34 | 35 | dim_range0 = tl.arange(0, BLOCK_DMODEL // 2) 36 | dim_range1 = tl.arange(BLOCK_DMODEL // 2, BLOCK_DMODEL) 37 | 38 | off_q0 = ( 39 | cur_seq_range[:, None, None] * stride_qbs 40 | + cur_head_range[None, :, None] * stride_qh 41 | + dim_range0[None, None, :] * stride_qd 42 | ) 43 | off_q1 = ( 44 | cur_seq_range[:, None, None] * stride_qbs 45 | + cur_head_range[None, :, None] * stride_qh 46 | + dim_range1[None, None, :] * stride_qd 47 | ) 48 | 49 | off_dimcos_sin = ( 50 | cur_seq_range[:, None, None] * stride_cosbs 51 | + dim_range0[None, None, :] * stride_cosd 52 | ) 53 | 54 | q0 = tl.load( 55 | Q + off_q0, 56 | mask=(cur_seq_range[:, None, None] < max_total_len) 57 | & (cur_head_range[None, :, None] < HEAD_Q), 58 | other=0.0, 59 | ) 60 | q1 = tl.load( 61 | Q + off_q1, 62 | mask=(cur_seq_range[:, None, None] < max_total_len) 63 | & (cur_head_range[None, :, None] < HEAD_Q), 64 | other=0.0, 65 | ) 66 | 67 | cos = tl.load( 68 | Cos + off_dimcos_sin, 69 | mask=cur_seq_range[:, None, None] < max_total_len, 70 | other=0.0, 71 | ) 72 | sin = tl.load( 73 | Sin + off_dimcos_sin, 74 | mask=cur_seq_range[:, None, None] < max_total_len, 75 | other=0.0, 76 | ) 77 | 78 | out0 = q0 * cos - q1 * sin 79 | out1 = q0 * sin + q1 * cos 80 | 81 | tl.store( 82 | Q + off_q0, 83 | out0, 84 | mask=(cur_seq_range[:, None, None] < max_total_len) 85 | & (cur_head_range[None, :, None] < HEAD_Q), 86 | ) 87 | tl.store( 88 | Q + off_q1, 89 | out1, 90 | mask=(cur_seq_range[:, None, None] < max_total_len) 91 | & (cur_head_range[None, :, None] < HEAD_Q), 92 | ) 93 | 94 | off_k0 = ( 95 | cur_seq_range[:, None, None] * stride_kbs 96 | + cur_head_range[None, :, None] * stride_kh 97 | + dim_range0[None, None, :] * stride_kd 98 | ) 99 | off_k1 = ( 100 | cur_seq_range[:, None, None] * stride_kbs 101 | + cur_head_range[None, :, None] * stride_kh 102 | + dim_range1[None, None, :] * stride_kd 103 | ) 104 | 105 | off_dimcos_sin = ( 106 | cur_seq_range[:, None, None] * stride_cosbs 107 | + dim_range0[None, None, :] * stride_cosd 108 | ) 109 | 110 | k0 = tl.load( 111 | K + off_k0, 112 | mask=(cur_seq_range[:, None, None] < max_total_len) 113 | & (cur_head_range[None, :, None] < HEAD_K), 114 | other=0.0, 115 | ) 116 | k1 = tl.load( 117 | K + off_k1, 118 | mask=(cur_seq_range[:, None, None] < max_total_len) 119 | & (cur_head_range[None, :, None] < HEAD_K), 120 | other=0.0, 121 | ) 122 | cos = tl.load( 123 | Cos + off_dimcos_sin, 124 | mask=cur_seq_range[:, None, None] < max_total_len, 125 | other=0.0, 126 | ) 127 | sin = tl.load( 128 | Sin + off_dimcos_sin, 129 | mask=cur_seq_range[:, None, None] < max_total_len, 130 | other=0.0, 131 | ) 132 | 133 | out_k0 = k0 * cos - k1 * sin 134 | out_k1 = k0 * sin + k1 * cos 135 | 136 | tl.store( 137 | K + off_k0, 138 | out_k0, 139 | mask=(cur_seq_range[:, None, None] < max_total_len) 140 | & (cur_head_range[None, :, None] < HEAD_K), 141 | ) 142 | tl.store( 143 | K + off_k1, 144 | out_k1, 145 | mask=(cur_seq_range[:, None, None] < max_total_len) 146 | & (cur_head_range[None, :, None] < HEAD_K), 147 | ) 148 | return 149 | 150 | 151 | @torch.no_grad() 152 | def rotary_emb_fwd(q, k, cos, sin, partial_rotary_factor=1.0): 153 | total_len = q.shape[0] 154 | head_num_q, head_num_k = q.shape[1], k.shape[1] 
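    # With partial_rotary_factor < 1.0 only the leading head_dim * factor dims of each head
    # are rotated; the kernel below never touches the remaining dims, which is how models
    # with partial rotary embeddings are handled.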
155 | head_dim = int(q.shape[2] * partial_rotary_factor) 156 | assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], ( 157 | f"q shape {q.shape} cos shape {cos.shape}" 158 | ) 159 | assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], ( 160 | f"k shape {k.shape} cos shape {cos.shape}" 161 | ) 162 | 163 | BLOCK_SEQ = 64 164 | BLOCK_HEAD = 4 165 | if head_dim >= 128: 166 | num_warps = 8 167 | else: 168 | num_warps = 4 169 | 170 | grid = (triton.cdiv(head_num_q, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) 171 | _rotary_kernel[grid]( 172 | q, 173 | k, 174 | cos, 175 | sin, 176 | q.stride(0), 177 | q.stride(1), 178 | q.stride(2), 179 | k.stride(0), 180 | k.stride(1), 181 | k.stride(2), 182 | cos.stride(0), 183 | cos.stride(1), 184 | sin.stride(0), 185 | sin.stride(1), 186 | total_len, 187 | head_num_q, 188 | head_num_k, 189 | BLOCK_HEAD=BLOCK_HEAD, 190 | BLOCK_SEQ=BLOCK_SEQ, 191 | BLOCK_DMODEL=head_dim, 192 | num_warps=num_warps, 193 | num_stages=1, 194 | ) 195 | return q, k 196 | 197 | 198 | def torch_rotary_emb(x, cos, sin): 199 | seq_len, h, d = x.shape 200 | # cos, sin 的形状为 (seq_len, d//2) 201 | half_dim = cos.shape[-1] 202 | x0 = x[:, :, :half_dim] 203 | x1 = x[:, :, half_dim : 2 * half_dim] 204 | 205 | cos = cos.view(seq_len, 1, half_dim) 206 | sin = sin.view(seq_len, 1, half_dim) 207 | 208 | o0 = x0 * cos - x1 * sin 209 | o1 = x0 * sin + x1 * cos 210 | 211 | if 2 * half_dim < d: 212 | out = torch.cat([o0, o1, x[:, :, 2 * half_dim :]], dim=-1) 213 | else: 214 | out = torch.cat([o0, o1], dim=-1) 215 | 216 | return out 217 | 218 | 219 | if __name__ == "__main__": 220 | torch.manual_seed(0) 221 | batch_tokens = 24800 222 | x_shape = (batch_tokens, 32, 64) # (seq_len, num_heads, head_dim) 223 | dtype = torch.float16 224 | q = torch.randn(x_shape, dtype=dtype, device="cuda") 225 | k = torch.clone(q) 226 | 227 | # 生成 cos 和 sin,与 head_dim 对应,这里 head_dim=64,因此 cos, sin=(seq_len, head_dim//2)=(128,32) 228 | cos_shape = (batch_tokens, 32) 229 | y = torch.randn(cos_shape, dtype=dtype, device="cuda") 230 | cos = y.cos() 231 | sin = y.sin() 232 | 233 | output_torch = torch_rotary_emb(q, cos, sin) 234 | q_out, k_out = rotary_emb_fwd(q, k, cos, sin) 235 | print(output_torch) 236 | print(q_out) 237 | print( 238 | f"The maximum difference between torch and triton is {torch.max(torch.abs(output_torch - q_out))}" 239 | ) 240 | print("torch:", triton.testing.do_bench(lambda: torch_rotary_emb(q, cos, sin))) 241 | print("triton:", triton.testing.do_bench(lambda: torch_rotary_emb(q, cos, sin))) 242 | -------------------------------------------------------------------------------- /lite_llama/kernels/rope_emb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | 5 | 6 | @triton.jit 7 | def _triton_rope_emb( 8 | q_ptr, 9 | q_row_stride, 10 | k_ptr, 11 | k_row_stride, 12 | cos, 13 | cos_b_stride, 14 | cos_s_stride, 15 | sin, 16 | sin_b_stride, 17 | sin_s_stride, 18 | sl, 19 | bs: tl.constexpr, 20 | n_qh: tl.constexpr, 21 | n_kh: tl.constexpr, 22 | hd: tl.constexpr, 23 | pad_n_qh: tl.constexpr, 24 | pad_n_kh: tl.constexpr, 25 | pad_hd: tl.constexpr, 26 | BLOCK_SIZE: tl.constexpr, 27 | ): 28 | pid = tl.program_id(0) 29 | batch_id = pid // sl 30 | cos_row_idx = pid % sl 31 | 32 | # 定位到 q, k 行起点 33 | q_ptr += pid * q_row_stride 34 | k_ptr += pid * k_row_stride 35 | 36 | # 定位到 cos, sin 对应 batch_id 的 cos_row_idx 行 37 | cos_ptr = cos + batch_id * cos_b_stride + cos_row_idx * cos_s_stride 38 
| sin_ptr = sin + batch_id * sin_b_stride + cos_row_idx * sin_s_stride 39 | 40 | cos_offsets = tl.arange(0, pad_hd // 2) 41 | cos_mask = cos_offsets < hd // 2 42 | cos_row = tl.load(cos_ptr + cos_offsets, mask=cos_mask, other=0) 43 | sin_row = tl.load(sin_ptr + cos_offsets, mask=cos_mask, other=0) 44 | 45 | # 计算 head 和 dim 偏移 46 | first_half_q_offsets = ( 47 | tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] 48 | ) 49 | first_half_k_offsets = ( 50 | tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] 51 | ) 52 | 53 | first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & ( 54 | tl.arange(0, pad_hd // 2)[None, :] < hd // 2 55 | ) 56 | first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & ( 57 | tl.arange(0, pad_hd // 2)[None, :] < hd // 2 58 | ) 59 | 60 | q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to( 61 | sin_row.dtype 62 | ) 63 | k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to( 64 | sin_row.dtype 65 | ) 66 | 67 | second_half_q_offsets = first_half_q_offsets + (hd // 2) 68 | second_half_k_offsets = first_half_k_offsets + (hd // 2) 69 | second_q_mask = first_q_mask 70 | second_k_mask = first_k_mask 71 | 72 | q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to( 73 | sin_row.dtype 74 | ) 75 | k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to( 76 | sin_row.dtype 77 | ) 78 | 79 | new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row 80 | tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask) 81 | new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row 82 | tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask) 83 | 84 | new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row 85 | tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask) 86 | new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row 87 | tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) 88 | 89 | 90 | def rope_emb_forward(q, k, cos, sin, batch_size, seq_len): 91 | """ 92 | q: (batch_size * seq_len, n_q_heads, head_dim) 93 | k: (batch_size * seq_len, n_k_heads, head_dim) 94 | cos, sin: (batch_size, seq_len, head_dim) 95 | """ 96 | N, n_qh, HEAD_DIM = q.shape 97 | _, n_kh, _ = k.shape 98 | assert N == batch_size * seq_len 99 | 100 | pad_hd = triton.next_power_of_2(HEAD_DIM) 101 | pad_n_qh = triton.next_power_of_2(n_qh) 102 | pad_n_kh = triton.next_power_of_2(n_kh) 103 | BLOCK_SIZE = max(pad_n_qh, pad_n_kh) 104 | 105 | if HEAD_DIM >= 128: 106 | num_warps = 8 107 | else: 108 | num_warps = 4 109 | 110 | q = q.contiguous() 111 | k = k.contiguous() 112 | cos = cos.contiguous() 113 | sin = sin.contiguous() 114 | 115 | _triton_rope_emb[(N,)]( 116 | q, 117 | q.stride(0), 118 | k, 119 | k.stride(0), 120 | cos, 121 | cos.stride(0), 122 | cos.stride(1), 123 | sin, 124 | sin.stride(0), 125 | sin.stride(1), 126 | seq_len, 127 | batch_size, 128 | n_qh, 129 | n_kh, 130 | HEAD_DIM, 131 | pad_n_qh, 132 | pad_n_kh, 133 | pad_hd, 134 | BLOCK_SIZE=BLOCK_SIZE, 135 | num_warps=num_warps, 136 | num_stages=1, 137 | ) 138 | return q, k 139 | -------------------------------------------------------------------------------- /lite_llama/kernels/softmax_split.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/iclementine/optimize_softmax/blob/master/softmax_split.py 2 | 3 | import triton 4 | from triton import language as tl 5 | import torch 
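# Split softmax in three passes so a row need not fit in one block:
#   1) logsumexp_kernel computes a numerically stable logsumexp per (row, tile),
#   2) combine_logsumexp_kernel reduces the per-tile values to one logz per row,
#   3) softmax_kernel writes exp(x - logz) tile by tile.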
6 | 7 | 8 | @triton.jit 9 | def logsumexp_kernel( 10 | out_ptr, 11 | in_ptr, 12 | M, 13 | N, 14 | TILE_N: tl.constexpr, 15 | ): 16 | pid_n = tl.program_id(0) 17 | num_programs_n = tl.num_programs(0) 18 | pid_m = tl.program_id(1) 19 | 20 | n_offsets = pid_n * TILE_N + tl.arange(0, TILE_N) 21 | mask = n_offsets < N 22 | offset = pid_m * N + n_offsets 23 | inp = tl.load(in_ptr + offset, mask=mask, other=-float("inf")).to(tl.float32) 24 | m = tl.max(inp, 0) 25 | e = tl.exp(inp - m) 26 | z = tl.sum(e, 0) 27 | logz = m + tl.log(z) 28 | 29 | output_ptrs = out_ptr + pid_m * num_programs_n + pid_n 30 | tl.store(output_ptrs, logz) 31 | 32 | 33 | @triton.jit 34 | def combine_logsumexp_kernel(out_ptr, inp_ptr, M, N, TILE_N: tl.constexpr): 35 | pid_m = tl.program_id(0) 36 | n_offsets = tl.arange(0, TILE_N) 37 | mask = n_offsets < N 38 | logzs = tl.load(inp_ptr + pid_m * N + n_offsets, other=-float("inf"), mask=mask).to( 39 | out_ptr.dtype.element_ty 40 | ) 41 | m = tl.max(logzs, 0) 42 | e = tl.exp(logzs - m) 43 | z = tl.sum(e, 0) 44 | logz = m + tl.log(z) 45 | tl.store(out_ptr + pid_m, logz) 46 | 47 | 48 | @triton.jit 49 | def softmax_kernel(out_ptr, in_ptr, logz_ptr, M, N, TILE_N: tl.constexpr): 50 | pid_n = tl.program_id(0) 51 | pid_m = tl.program_id(1) 52 | n_offsets = pid_n * TILE_N + tl.arange(0, TILE_N) 53 | offset = pid_m * N + n_offsets 54 | mask = n_offsets < N 55 | inp = tl.load(in_ptr + offset, mask=mask, other=-float("inf")).to( 56 | out_ptr.dtype.element_ty 57 | ) 58 | logz = tl.load(logz_ptr + pid_m).to(tl.float32) 59 | out = tl.exp(inp - logz) 60 | tl.store(out_ptr + offset, out, mask=mask) 61 | 62 | 63 | def softmax_split(x): 64 | M, N = x.shape 65 | 66 | # num_sms = torch.cuda.get_device_properties(x.device).multi_processor_count 67 | 68 | TILE_N = min(4096, triton.next_power_of_2(N)) 69 | num_tiles_n = triton.cdiv(N, TILE_N) 70 | logz = torch.empty((M, num_tiles_n), dtype=x.dtype, device=x.device) 71 | grid = (num_tiles_n, M, 1) 72 | logsumexp_kernel[grid](logz, x, M, N, TILE_N) 73 | 74 | combined_logz = torch.empty((M,), dtype=x.dtype, device=x.device) 75 | TILE_N = triton.next_power_of_2(num_tiles_n) 76 | grid = (M, 1, 1) 77 | combine_logsumexp_kernel[grid](combined_logz, logz, M, num_tiles_n, TILE_N) 78 | 79 | out = torch.empty_like(x) 80 | TILE_N = min(4096, triton.next_power_of_2(N)) 81 | num_tiles_n = triton.cdiv(N, TILE_N) 82 | grid = (num_tiles_n, M, 1) 83 | softmax_kernel[grid](out, x, combined_logz, M, N, TILE_N) 84 | return out 85 | -------------------------------------------------------------------------------- /lite_llama/kernels/swiglu.py: -------------------------------------------------------------------------------- 1 | # reference: https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/swiglu.py 2 | 3 | import torch 4 | import triton 5 | import triton.language as tl 6 | import functools 7 | 8 | 9 | def is_hip() -> bool: 10 | return torch.version.hip is not None 11 | 12 | 13 | def ensure_contiguous(fn): 14 | @functools.wraps(fn) 15 | def wrapper(ctx, *args, **kwargs): 16 | def maybe_to_contiguous(x): 17 | return x.contiguous() if isinstance(x, torch.Tensor) else x 18 | 19 | args = [maybe_to_contiguous(arg) for arg in args] 20 | kwargs = {k: maybe_to_contiguous(v) for k, v in kwargs.items()} 21 | return fn(ctx, *args, **kwargs) 22 | 23 | return wrapper 24 | 25 | 26 | def calculate_settings(n): 27 | # reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43 28 | 29 | MAX_FUSED_SIZE = 
65536 30 | BLOCK_SIZE = triton.next_power_of_2(n) 31 | if BLOCK_SIZE > MAX_FUSED_SIZE: 32 | raise RuntimeError( 33 | f"Cannot launch Triton kernel since n = {n} exceeds " 34 | f"the recommended Triton blocksize = {MAX_FUSED_SIZE}." 35 | ) 36 | 37 | num_warps = 4 38 | if BLOCK_SIZE >= 32768: 39 | num_warps = 32 if not is_hip() else 16 40 | elif BLOCK_SIZE >= 8192: 41 | num_warps = 16 42 | elif BLOCK_SIZE >= 2048: 43 | num_warps = 8 44 | return BLOCK_SIZE, num_warps 45 | 46 | 47 | @triton.jit 48 | def silu(x): 49 | return x * tl.sigmoid(x) 50 | 51 | 52 | @triton.jit 53 | def _swiglu_forward_kernel( 54 | a_ptr, b_ptr, c_ptr, row_stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr 55 | ): 56 | program_id = tl.program_id(0).to(tl.int64) 57 | 58 | # locate start index 59 | a_ptr += program_id * row_stride 60 | b_ptr += program_id * row_stride 61 | c_ptr += program_id * row_stride 62 | 63 | col_offsets = tl.arange(0, BLOCK_SIZE) 64 | mask = col_offsets < n_cols 65 | 66 | # sigmoid requires type float32 67 | a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32) 68 | b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0) 69 | c_row = silu(a_row) * b_row 70 | tl.store(c_ptr + col_offsets, c_row, mask=mask) 71 | 72 | 73 | def swiglu_forward(a, b): 74 | ori_shape = a.shape # ori_shape is [batch_size, seq_len, hidden_size] 75 | 76 | n_cols = ori_shape[-1] 77 | a = a.view(-1, n_cols) 78 | b = b.view(-1, n_cols) 79 | c = torch.empty_like(a) 80 | n_rows = a.shape[0] 81 | 82 | BLOCK_SIZE, num_warps = calculate_settings(n_cols) 83 | 84 | _swiglu_forward_kernel[(n_rows,)]( 85 | a, 86 | b, 87 | c, 88 | c.stride(-2), # c.stride(-2) = n_cols 89 | n_cols=n_cols, 90 | BLOCK_SIZE=BLOCK_SIZE, 91 | num_warps=num_warps, 92 | ) 93 | return c.view(*ori_shape) 94 | -------------------------------------------------------------------------------- /lite_llama/kernels/update_kv_buffer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import triton 4 | import triton.language as tl 5 | 6 | 7 | @triton.jit 8 | def _fwd_kernel_update_kv( 9 | KV_Values, 10 | Select_Index, 11 | KV_Buffer, 12 | stride_k_bs, 13 | stride_k_h, 14 | stride_k_d, 15 | stride_o_bs, 16 | stride_o_h, 17 | stride_o_d, 18 | head_num, 19 | BLOCK_DMODEL: tl.constexpr, 20 | BLOCK_HEAD: tl.constexpr, 21 | ): 22 | cur_index = tl.program_id(0) 23 | offs_h = tl.arange(0, BLOCK_HEAD) 24 | offs_d = tl.arange(0, BLOCK_DMODEL) 25 | 26 | dest_index = tl.load(Select_Index + cur_index) 27 | 28 | k_ptrs = ( 29 | KV_Values 30 | + cur_index * stride_k_bs 31 | + stride_k_h * offs_h[:, None] 32 | + stride_k_d * offs_d[None, :] 33 | ) 34 | o_ptrs = ( 35 | KV_Buffer 36 | + dest_index * stride_o_bs 37 | + stride_o_h * offs_h[:, None] 38 | + stride_o_d * offs_d[None, :] 39 | ) 40 | 41 | kv_value = tl.load(k_ptrs, mask=offs_h[:, None] < head_num, other=0.0) 42 | tl.store(o_ptrs, kv_value, mask=offs_h[:, None] < head_num) 43 | return 44 | 45 | 46 | @torch.no_grad() 47 | def update_kv_buffer(KV_Values, Select_Index, KV_Buffer): 48 | """ 49 | 参数: 50 | - Select_Index: prefill 阶段 batch_size * seq_len, decode 阶段 batch_size。 51 | Select_Index[i] 表示 KV_Values 的第 i 行 应该被复制到 KV_Buffer 的第 Select_Index[i] 行。 52 | - KV_Values: 实际是 cache_kv, 尺寸为 [select_indexs, num_kv_heads * 2, head_dim]。 53 | - KV_Buffer: 尺寸为 [max_num_tokens, num_kv_heads * 2, head_dim] 54 | 输出: 55 | KV_Buffer 张量被填, KV_Buffer[Select_Index[i], :, :] = K[i, :, :]。 56 | """ 57 | seq_len = Select_Index.shape[0] # number_tokens 58 | head_num = 
KV_Values.shape[1] # num_kv_head * 2 59 | head_dim = KV_Values.shape[2] 60 | assert ( 61 | KV_Values.shape[1] == KV_Buffer.shape[1] 62 | and KV_Values.shape[2] == KV_Buffer.shape[2] 63 | ) 64 | BLOCK_HEAD = triton.next_power_of_2(head_num) 65 | grid = (seq_len,) 66 | num_warps = 1 67 | 68 | _fwd_kernel_update_kv[grid]( 69 | KV_Values, 70 | Select_Index, 71 | KV_Buffer, 72 | KV_Values.stride(0), 73 | KV_Values.stride(1), 74 | KV_Values.stride(2), 75 | KV_Buffer.stride(0), 76 | KV_Buffer.stride(1), 77 | KV_Buffer.stride(2), 78 | head_num, 79 | BLOCK_DMODEL=head_dim, 80 | BLOCK_HEAD=BLOCK_HEAD, 81 | num_warps=num_warps, 82 | num_stages=1, 83 | ) 84 | return 85 | 86 | 87 | def test1(): 88 | import time 89 | 90 | num_of_times = 1000 91 | 92 | B, Seq_Len, H, D = 32, 1024, 12, 128 93 | dest = torch.randn((B * Seq_Len, H, D), dtype=torch.float16).cuda() 94 | src = torch.randn((B * Seq_Len, H, D), dtype=torch.float16).cuda() 95 | dest_loc = torch.arange(0, B * Seq_Len, dtype=torch.int32, device="cuda") 96 | 97 | for _ in range(10): # Warm up 98 | update_kv_buffer(src, dest_loc, dest) 99 | torch.cuda.synchronize() 100 | 101 | t1 = time.time() 102 | for _ in range(num_of_times): 103 | update_kv_buffer(src, dest_loc, dest) 104 | torch.cuda.synchronize() 105 | t2 = time.time() 106 | 107 | for _ in range(num_of_times): 108 | dest[dest_loc] = src 109 | torch.cuda.synchronize() 110 | t3 = time.time() 111 | 112 | print("Triton Time cost ", t2 - t1) 113 | print("Torch Time cost ", t3 - t2) 114 | print("max ", torch.max(torch.abs(dest - src))) 115 | print("mean ", torch.mean(torch.abs(dest - src))) 116 | assert torch.allclose(src, dest, atol=1e-2, rtol=0) 117 | 118 | 119 | if __name__ == "__main__": 120 | test1() 121 | -------------------------------------------------------------------------------- /lite_llama/kernels/update_kv_index.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | 5 | 6 | @triton.jit 7 | def _fwd_kernel_update_kv_index( 8 | req_to_token_indexs, # 输出张量的指针,形状为 (num_requests, max_seq_len) 9 | b_req_idx, # decode_batch 批次中每个请求的 ID,形状为 (num_tokens,) 10 | b_seq_len, # decode_batch 中每个请求的序列长度,形状为 (num_tokens,) 11 | select_index, # decode_batch 中每个 tokens的 KV 索引,形状为 (num_tokens,) 12 | stride_req_to_token_b, # req_to_token_indexs 在第一个维度(请求)的步幅 13 | stride_req_to_token_s, # req_to_token_indexs 在第二个维度(序列长度)的步幅 14 | ): 15 | # 获取当前程序的 ID,即线程的索引 16 | cur_index = tl.program_id(0) 17 | 18 | # 从 b_req_idx 张量加载当前请求的 ID 19 | cur_req_idx = tl.load(b_req_idx + cur_index) 20 | 21 | # 从 select_index 张量加载当前令牌的 KV 索引 22 | cur_token_index = tl.load(select_index + cur_index) 23 | 24 | # 从 b_seq_len 张量加载当前请求的序列长度 25 | cur_seq_len = tl.load(b_seq_len + cur_index) 26 | 27 | # 计算目标位置的偏移量: 28 | # req_to_token_indexs[cur_req_idx][cur_seq_len - 1] 29 | dest_offset = ( 30 | req_to_token_indexs 31 | + cur_req_idx * stride_req_to_token_b 32 | + (cur_seq_len - 1) * stride_req_to_token_s 33 | ) 34 | 35 | # 将当前令牌索引存储到目标位置 36 | tl.store(dest_offset, cur_token_index) 37 | 38 | return 39 | 40 | 41 | @torch.no_grad() 42 | def update_kv_index(req_to_token_indexs, b_req_idx, b_seq_len, select_index): 43 | """ 44 | 根据每个 token 的请求索引 ID 和当前序列长度, 把这个 token 在 KV 缓存里的索 (select_index) 存进输出张量 req_to_token_indexs 的正确位置 45 | 参数: 46 | req_to_token_indexs (torch.Tensor): 输出张量,用于存储 KV 索引。形状为 (num_requests, max_seq_len)。 47 | b_req_idx (torch.Tensor): 批次中每个请求的 ID, 形状为 (num_tokens,)。 48 | b_seq_len (torch.Tensor): 每个请求的序列长度,形状为 (num_tokens,)。 
49 | select_index (torch.Tensor): 每个令牌的 KV 索引,形状为 (num_tokens,)。 50 | 51 | 该函数使用 Triton 内核来高效地执行复制操作。 52 | """ 53 | # 获取序列长度,即令牌数量 54 | seq_len = b_seq_len.shape[0] 55 | 56 | # 确保所有输入张量在第一个维度上的大小相同 57 | assert ( 58 | b_seq_len.shape[0] == select_index.shape[0] 59 | and b_req_idx.shape[0] == b_seq_len.shape[0] 60 | ), "所有输入张量在第一个维度上的大小必须相同。" 61 | 62 | # 定义 Triton 内核的网格大小(1D 网格) 63 | grid = (seq_len,) 64 | 65 | # 定义每个 block 使用的 warp 数量 66 | num_warps = 1 67 | 68 | # 启动 Triton 内核 69 | _fwd_kernel_update_kv_index[grid]( 70 | req_to_token_indexs, # 输出张量的指针 71 | b_req_idx, # 请求索引张量的指针 72 | b_seq_len, # 序列长度张量的指针 73 | select_index, # 令牌索引张量的指针 74 | req_to_token_indexs.stride(0), # req_to_token_indexs 在第一个维度上的步幅 75 | req_to_token_indexs.stride(1), # req_to_token_indexs 在第二个维度上的步幅 76 | num_warps=num_warps, # 使用的 warp 数量 77 | num_stages=1, # 使用的流水线阶段数量 78 | ) 79 | return 80 | -------------------------------------------------------------------------------- /lite_llama/kernels/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file incorporates code from Unsloth licensed under the Apache License, Version 2.0. 3 | See the original Unsloth repository at https://github.com/unslothai/unsloth. 4 | 5 | The following line 6 | https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/utils.py#L23 7 | is based on code from Unsloth, located at: 8 | https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43 9 | 10 | Modifications made by Yanning Chen, 2024. 11 | """ 12 | 13 | import functools 14 | import importlib 15 | import operator 16 | from typing import Callable 17 | 18 | import torch 19 | import triton 20 | import triton.language as tl 21 | from packaging.version import Version 22 | 23 | 24 | def is_hip() -> bool: 25 | return torch.version.hip is not None 26 | 27 | 28 | def keep(conf): 29 | BLOCK_M = conf.kwargs["BLOCK_M"] 30 | BLOCK_N = conf.kwargs["BLOCK_N"] 31 | if BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8: 32 | return False 33 | return True 34 | 35 | 36 | def ensure_contiguous(fn): 37 | @functools.wraps(fn) 38 | def wrapper(ctx, *args, **kwargs): 39 | def maybe_to_contiguous(x): 40 | return x.contiguous() if isinstance(x, torch.Tensor) else x 41 | 42 | args = [maybe_to_contiguous(arg) for arg in args] 43 | kwargs = {k: maybe_to_contiguous(v) for k, v in kwargs.items()} 44 | return fn(ctx, *args, **kwargs) 45 | 46 | return wrapper 47 | 48 | 49 | def calculate_settings(n): 50 | # reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43 51 | 52 | MAX_FUSED_SIZE = 65536 53 | BLOCK_SIZE = triton.next_power_of_2(n) 54 | if BLOCK_SIZE > MAX_FUSED_SIZE: 55 | raise RuntimeError( 56 | f"Cannot launch Triton kernel since n = {n} exceeds " 57 | f"the recommended Triton blocksize = {MAX_FUSED_SIZE}." 
58 | ) 59 | 60 | num_warps = 4 61 | if BLOCK_SIZE >= 32768: 62 | num_warps = 32 if not is_hip() else 16 63 | elif BLOCK_SIZE >= 8192: 64 | num_warps = 16 65 | elif BLOCK_SIZE >= 2048: 66 | num_warps = 8 67 | return BLOCK_SIZE, num_warps 68 | 69 | 70 | def compare_version(package: str, operator: Callable, target: str): 71 | try: 72 | pkg = importlib.import_module(package) 73 | except ImportError: 74 | return False 75 | pkg_version = Version(pkg.__version__) 76 | return operator(pkg_version, Version(target)) 77 | 78 | 79 | torch_to_triton_dtype = { 80 | torch.float32: tl.float32, 81 | torch.float16: tl.float16, 82 | torch.bfloat16: tl.bfloat16, 83 | } 84 | -------------------------------------------------------------------------------- /lite_llama/models/llava.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from transformers import AutoModel, LlavaConfig 8 | from .llama import LlamaModel 9 | from .model_config import LlamaConfig 10 | from ..kernels import gelu 11 | from .utils import merge_input_ids_with_image_features 12 | 13 | 14 | class LlavaMultiModalProjector(nn.Module): 15 | def __init__( 16 | self, 17 | vision_hidden_size: int, 18 | text_hidden_size: int, 19 | projector_hidden_act: str = "gelu", 20 | ): 21 | super().__init__() 22 | 23 | self.linear_1 = nn.Linear(vision_hidden_size, text_hidden_size, bias=True) 24 | self.linear_2 = nn.Linear(text_hidden_size, text_hidden_size, bias=True) 25 | 26 | def forward(self, image_features: torch.Tensor) -> torch.Tensor: 27 | hidden_states = self.linear_1(image_features) 28 | hidden_states = F.gelu(hidden_states) # GELU 激活函数 29 | hidden_states = self.linear_2(hidden_states) 30 | return hidden_states 31 | 32 | 33 | class LlavaLlama(nn.Module): 34 | def __init__(self, llava_config: LlavaConfig): 35 | super().__init__() 36 | self.device = "cuda" 37 | self.llava_config = llava_config 38 | text_config = ( 39 | self.llava_config.text_config 40 | ) # TODO: 将 text_config 转换成 LlamaConfig 类型 41 | self.llama_config = LlamaConfig.from_dict(text_config.to_dict()) 42 | 43 | self.select_layer = llava_config.vision_feature_layer 44 | self.select_feature = llava_config.vision_feature_select_strategy 45 | 46 | # 视觉处理模块(vision_tower)初始化 47 | self.vision_tower = AutoModel.from_config(llava_config.vision_config) 48 | 49 | # 多模态投影器(multi_modal_projector)初始化 50 | self.multi_modal_projector = LlavaMultiModalProjector( 51 | vision_hidden_size=llava_config.vision_config.hidden_size, 52 | text_hidden_size=llava_config.text_config.hidden_size, 53 | projector_hidden_act=llava_config.projector_hidden_act, 54 | ) 55 | 56 | # 语言模型初始化 57 | self.language_model = LlamaModel(self.llama_config) 58 | 59 | self.pad_token_id = ( 60 | self.llava_config.pad_token_id 61 | if self.llava_config.pad_token_id is not None 62 | else -1 63 | ) 64 | 65 | def _select_image_features( 66 | self, image_features: torch.Tensor, strategy: str 67 | ) -> torch.Tensor: 68 | """根据策略选择图像特征""" 69 | # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa 70 | if strategy == "default" or strategy == "patch": 71 | return image_features[:, 1:].contiguous() 72 | elif strategy == "full": 73 | return image_features 74 | 75 | raise ValueError(f"Unexpected select feature strategy: {strategy}") 76 | 77 | def vision_encode(self, image_tensor): 78 | x = 
image_tensor.half().to(device=self.device) 79 | 80 | # 1. 通过视觉处理模块提取图像特征 81 | x = self.vision_tower(x, output_hidden_states=True) 82 | x = x.hidden_states[self.select_layer] 83 | x = self._select_image_features(x, self.select_feature) 84 | 85 | # 2. 通过多模态投影器将图像特征转换为多模态嵌入 86 | image_features = self.multi_modal_projector(x) 87 | 88 | assert not torch.isnan(image_features).any(), ( 89 | f"After vision_tower image_features tensor contains NaN values!" 90 | ) 91 | return image_features 92 | 93 | def get_multi_modal_input_embeddings( 94 | self, 95 | input_ids: torch.Tensor, 96 | vision_embeddings=None, 97 | ) -> torch.Tensor: 98 | """获取输入嵌入,包括文本和视觉嵌入的合并。""" 99 | llm_inputs_embeds = self.language_model.get_input_embeddings( 100 | input_ids 101 | ) # torch.Size([1, 22]) --> torch.Size([1, 22, 4096]) 102 | 103 | # torch.Size([1, 576, 4096]) torch.Size([1, 22, 4096]) torch.Size([1, 22]) 104 | # print("self.llava_config.image_token_index is ", self.llava_config.image_token_index) 105 | if vision_embeddings is not None: 106 | inputs_embeds, position_ids = merge_input_ids_with_image_features( 107 | input_ids, 108 | llm_inputs_embeds, 109 | vision_embeddings, 110 | self.llava_config.pad_token_id, 111 | self.llava_config.image_token_index, 112 | ) 113 | 114 | assert not torch.isnan(inputs_embeds).any(), ( 115 | f"After merge inputs_embeds tensor contains NaN values!" 116 | ) 117 | 118 | return inputs_embeds, position_ids 119 | 120 | def forward( 121 | self, 122 | input_ids: torch.Tensor, 123 | position_ids: torch.Tensor, 124 | atten_info, 125 | image_tensor: Optional[torch.FloatTensor] = None, 126 | ): 127 | input_ids = input_ids.to(self.device) # 将 input_ids 移动到设备 128 | if position_ids is not None: # 如果提供了 position_ids,将其移动到设备 129 | position_ids = position_ids.to(self.device) 130 | 131 | if input_ids.shape[1] != 1: # 判断是不是首次 token 输出 132 | vision_embeddings = self.vision_encode( 133 | image_tensor 134 | ) # torch.Size([1, 3, 336, 336]) --> torch.Size([1, 576, 4096]) 135 | inputs_embeds, position_ids = self.get_multi_modal_input_embeddings( 136 | input_ids, vision_embeddings 137 | ) 138 | else: # 进入 decode 阶段, 无需再做视觉编码 139 | inputs_embeds = None 140 | 141 | hidden_states = self.language_model( 142 | input_ids=input_ids, 143 | position_ids=position_ids, 144 | atten_info=atten_info, 145 | inputs_embeds=inputs_embeds, 146 | ) 147 | 148 | return hidden_states 149 | -------------------------------------------------------------------------------- /lite_llama/utils/common.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time, os 3 | import subprocess 4 | from typing import List, Optional 5 | 6 | 7 | def read_json(json_path): 8 | with open(json_path, "r") as json_file: 9 | data = json.load(json_file) 10 | return data 11 | 12 | 13 | def read_jsonl(jsonl_path): 14 | with open(jsonl_path, "r", encoding="utf-8") as f: 15 | data = [json.loads(line) for line in f] 16 | return data 17 | 18 | 19 | def detect_device(): 20 | try: 21 | subprocess.check_output(["nvidia-smi"], stderr=subprocess.DEVNULL) 22 | return "nvidia" 23 | except: 24 | try: 25 | subprocess.check_output(["rocm-smi"], stderr=subprocess.DEVNULL) 26 | return "amd" 27 | except: 28 | return "cpu" 29 | 30 | 31 | def getTime(): 32 | return str(time.strftime("%m-%d %H:%M:%S", time.localtime())) 33 | 34 | 35 | def getProjectPath(): 36 | script_path = os.path.split(os.path.realpath(__file__))[0] 37 | return os.path.abspath(os.path.join(script_path, "..")) 38 | 39 | 40 | def 
get_gpu_memory(gpu_type="amd", device_id="0"): 41 | try: 42 | if gpu_type == "amd": 43 | result = subprocess.run( 44 | ["rocm-smi", "--showmeminfo", "vram", device_id], 45 | stdout=subprocess.PIPE, 46 | stderr=subprocess.PIPE, 47 | text=True, 48 | ) 49 | for line in result.stdout.splitlines(): 50 | if "VRAM Total Used Memory" in line: 51 | used = line.split(":")[-1].strip().split()[0] 52 | return float(used) / (10**9) # Convert MiB to GiB 53 | elif gpu_type == "nvidia": 54 | result = subprocess.run( 55 | [ 56 | "nvidia-smi", 57 | "--query-gpu=memory.used", 58 | "--format=csv,nounits,noheader", 59 | "-i", 60 | device_id, 61 | ], 62 | stdout=subprocess.PIPE, 63 | stderr=subprocess.PIPE, 64 | text=True, 65 | ) 66 | return float(result.stdout.strip()) / 1024 # Convert MiB to GiB 67 | elif gpu_type == "cpu": 68 | return None 69 | except Exception as e: 70 | from utils.logger import log 71 | 72 | log.warning(f"Unable to fetch GPU memory: {e}") 73 | return None 74 | 75 | 76 | def count_tokens(texts: List[str], tokenizer) -> int: 77 | total_tokens = 0 78 | for t in texts: 79 | ids = tokenizer(t, add_special_tokens=False)["input_ids"] 80 | total_tokens += len(ids) 81 | return total_tokens 82 | 83 | 84 | def get_model_type(checkpoint_path: str) -> str | None: 85 | from utils.logger import log 86 | 87 | model_type = ["llama", "falcon", "mpt", "qwen2", "llava"] 88 | 89 | config_content = read_json(os.path.join(checkpoint_path, "config.json")) 90 | for m in model_type: 91 | if m in config_content["model_type"].lower(): 92 | if m == "llava": 93 | return "llama" 94 | return m 95 | log.error(f"No model type found: {checkpoint_path}") 96 | return None 97 | -------------------------------------------------------------------------------- /lite_llama/utils/config_convert.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | from transformers import LlavaConfig 3 | 4 | # sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))) 5 | from ..models.model_config import LlamaConfig 6 | 7 | 8 | def convert_transformers_to_custom_config( 9 | transformers_config: transformers.LlamaConfig, 10 | ) -> LlamaConfig: 11 | # 将 transformers 配置转换为字典 12 | config_dict = transformers_config.to_dict() 13 | print("transformers.LlamaConfig dict: ", config_dict) 14 | 15 | return LlamaConfig( 16 | _name_or_path=config_dict.get("_name_or_path"), 17 | architectures=config_dict.get("architectures", ["LlamaForCausalLM"]), 18 | max_position_embeddings=config_dict.get("max_position_embeddings", 4096), 19 | model_type=config_dict.get("model_type", "llama"), 20 | rms_norm_eps=config_dict.get("rms_norm_eps", 1e-5), 21 | torch_dtype=config_dict.get("torch_dtype", "float16"), 22 | vocab_size=config_dict.get("vocab_size", 32064), 23 | hidden_size=config_dict.get("hidden_size", 4096), 24 | intermediate_size=config_dict.get("intermediate_size", 11008), 25 | num_hidden_layers=config_dict.get("num_hidden_layers", 32), 26 | num_attention_heads=config_dict.get("num_attention_heads", 32), 27 | num_key_value_heads=config_dict.get("num_key_value_heads", None), 28 | ) 29 | custom_config = LlamaConfig(config_dict=config_dict) 30 | return custom_config 31 | 32 | 33 | if __name__ == "__main__": 34 | # 加载 transformers 的 LlamaConfig(请替换为实际模型名称) 35 | model_path = "/gemini/code/liuhaotian/llava-v1.5-7b" 36 | transformers_config = LlavaConfig.from_pretrained(model_path) 37 | 38 | # 转换为自定义配置 39 | custom_llama_config = convert_transformers_to_custom_config( 40 | 
transformers_config.text_config 41 | ) 42 | 43 | # 打印自定义配置 44 | # print(json.dumps(custom_llama_config, indent=4, ensure_ascii=False)) 45 | print(custom_llama_config) 46 | 47 | """ 48 | lamaConfig(architectures=None, attention_bias=False, attention_dropout=0.0, bos_token_id=1, eos_token_id=2, head_dim=128, hidden_act='silu', 49 | initializer_range=0.02, hidden_size=4096, intermediate_size=11008, max_position_embeddings=2048, mlp_bias=False, model_type='llama', 50 | num_heads=32, num_layers=32, num_kv_heads=32, pretraining_tp=1, rms_norm_eps=1e-06, rope_scaling=None, rope_theta=10000.0, 51 | tie_word_embeddings=False, torch_dtype=None, transformers_version='4.40.2', use_cache=True, vocab_size=32000, max_batch_size=4, 52 | max_seq_len=2048, device='cuda') 53 | """ 54 | -------------------------------------------------------------------------------- /lite_llama/utils/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = 32000 9 | DEFAULT_IMAGE_TOKEN = "" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "" 11 | DEFAULT_IM_START_TOKEN = "" 12 | DEFAULT_IM_END_TOKEN = "" 13 | IMAGE_PLACEHOLDER = "" 14 | 15 | LLAVA_IGNORE_INDEX = -100 16 | LLAVA_DEFAULT_IMAGE_PATCH_TOKEN_IDX = 32000 17 | LLAVA_DEFAULT_IMAGE_TOKEN = "" 18 | LLAVA_DEFAULT_IM_TOKEN_PLACE_HOLDER = "" 19 | LLAVA_DEFAULT_IMAGE_PATCH_TOKEN = "" 20 | LLAVA_DEFAULT_IM_START_TOKEN = "" 21 | LLAVA_DEFAULT_IM_END_TOKEN = "" 22 | -------------------------------------------------------------------------------- /lite_llama/utils/file_interface.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def get_model_name_from_path(model_path): 5 | model_path = model_path.strip("/") 6 | model_paths = model_path.split("/") 7 | if model_paths[-1].startswith("checkpoint-"): 8 | return model_paths[-2] + "_" + model_paths[-1] 9 | else: 10 | return model_paths[-1] 11 | -------------------------------------------------------------------------------- /lite_llama/utils/image_process.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/haotian-liu/LLaVA 2 | # Copyright 2023 Haotian Liu 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
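# Usage sketch (assumptions noted): the helpers defined below are typically chained
# to turn image files into the pixel-value tensor expected by the vision tower.
# `load_images` and `process_images` come from this module; the CLIP processor name,
# the `model_cfg` object, and the sample image path are assumed inputs, not values
# hard-coded here:
#
#     from transformers import CLIPImageProcessor
#     processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
#     images = load_images(["images/llava_test/dog.jpeg"])
#     pixel_values = process_images(images, processor, model_cfg)
#     # model_cfg is any config exposing an optional `image_aspect_ratio` attribute;
#     # with "pad", images are squared via expand2square before preprocessing.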
15 | 16 | import torch 17 | from PIL import Image 18 | from io import BytesIO 19 | import requests 20 | import os 21 | import base64 22 | 23 | 24 | def load_image_from_base64(image): 25 | return Image.open(BytesIO(base64.b64decode(image))) 26 | 27 | 28 | def load_image(image_file): 29 | if image_file.startswith("http://") or image_file.startswith("https://"): 30 | response = requests.get(image_file) 31 | image = Image.open(BytesIO(response.content)).convert("RGB") 32 | else: 33 | image = Image.open(image_file).convert("RGB") 34 | return image 35 | 36 | 37 | def load_images(image_files): 38 | out = [] 39 | for image_file in image_files: 40 | image = load_image(image_file) 41 | out.append(image) 42 | return out 43 | 44 | 45 | def vis_images(image_files): 46 | if len(image_files) == 1: 47 | image = image_files[0] 48 | os.system( 49 | f"termvisage --query-timeout 1 -H left --height 40 --oversize {image}" 50 | ) # --height 50:设置图片高度为 500 行。 51 | 52 | else: 53 | # Concat images 54 | system_inst = "convert " 55 | inst_template1 = " \\( {image} -background none -resize x{height} \\) " 56 | inst_template2 = ( 57 | " \\( {image} -background none -resize x{height} -splice 50x0 \\) " 58 | ) 59 | count = 0 60 | for image in image_files: 61 | with Image.open(image) as img: 62 | width, height = img.size # 查看尺寸 63 | print(f"{image} width and height is {width}, {height}") 64 | 65 | count += 1 66 | if count == 1: 67 | system_inst += inst_template1.format(image=image, height=height) 68 | else: 69 | system_inst += inst_template2.format(image=image, height=height) 70 | system_inst += " +append .vis.jpg" 71 | os.system(system_inst) 72 | 73 | os.system(f"termvisage --query-timeout 1 .vis.jpg -H left") 74 | 75 | 76 | def expand2square(pil_img, background_color): 77 | """ 78 | Copy from Llava codebase for image preprocessing. 79 | """ 80 | width, height = pil_img.size 81 | if width == height: 82 | return pil_img 83 | elif width > height: 84 | result = Image.new(pil_img.mode, (width, width), background_color) 85 | result.paste(pil_img, (0, (width - height) // 2)) 86 | return result 87 | else: 88 | result = Image.new(pil_img.mode, (height, height), background_color) 89 | result.paste(pil_img, ((height - width) // 2, 0)) 90 | return result 91 | 92 | 93 | def process_images(images, image_processor, model_cfg): 94 | """ 95 | Copy from Llava codebase for image preprocessing. 
96 | """ 97 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) 98 | new_images = [] 99 | if image_aspect_ratio == "pad": 100 | for image in images: 101 | image = expand2square( 102 | image, tuple(int(x * 255) for x in image_processor.image_mean) 103 | ) 104 | image = image_processor.preprocess(image, return_tensors="pt")[ 105 | "pixel_values" 106 | ][0] 107 | if "intern" in image_processor.__class__.__name__.lower(): 108 | # special case 109 | new_images.append(image.unsqueeze(0)) 110 | else: 111 | new_images.append(image) 112 | else: 113 | ret = image_processor(images, return_tensors="pt")["pixel_values"] 114 | if "intern" in image_processor.__class__.__name__.lower(): 115 | # special case 116 | ret = [x.unsqueeze(0) for x in ret] 117 | return ret 118 | if all(x.shape == new_images[0].shape for x in new_images): 119 | new_images = torch.stack(new_images, dim=0) 120 | 121 | return new_images 122 | -------------------------------------------------------------------------------- /lite_llama/utils/logger.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | 3 | import os 4 | import sys 5 | import time 6 | 7 | import logging 8 | from .common import getProjectPath 9 | 10 | __all__ = ["log", "logE", "logP", "logU"] 11 | 12 | # Set up the logger 13 | BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8) 14 | 15 | # These are the sequences need to get colored ouput 16 | RESET_SEQ = "\033[0m" 17 | COLOR_SEQ = "\033[1;%dm" 18 | BOLD_SEQ = "\033[1m" 19 | 20 | COLORS = { 21 | "WARNING": YELLOW, 22 | "INFO": GREEN, 23 | "DEBUG": BLUE, 24 | "CRITICAL": YELLOW, 25 | "ERROR": RED, 26 | } 27 | 28 | LEVEL_SIM = { 29 | "WARNING": "[W]", 30 | "INFO": "[I]", 31 | "DEBUG": "[D]", 32 | "CRITICAL": "[C]", 33 | "ERROR": "[E]", 34 | } 35 | 36 | 37 | def formatter_message(message, use_color=True): 38 | if use_color: 39 | message = message.replace("$RESET", RESET_SEQ).replace("$BOLD", BOLD_SEQ) 40 | else: 41 | message = message.replace("$RESET", "").replace("$BOLD", "") 42 | return message 43 | 44 | 45 | class ColoredFormatter(logging.Formatter): 46 | def __init__(self, msg, use_color=True): 47 | logging.Formatter.__init__(self, msg, datefmt="%m-%d %H:%M:%S") 48 | self.use_color = use_color 49 | 50 | def format(self, record): 51 | levelname = record.levelname 52 | if self.use_color and levelname in COLORS: 53 | simple_ln = LEVEL_SIM.get(levelname) 54 | levelname_color = ( 55 | COLOR_SEQ % (30 + COLORS[levelname]) + simple_ln + RESET_SEQ 56 | ) 57 | record.levelname = levelname_color 58 | return logging.Formatter.format(self, record) 59 | 60 | 61 | # Custom logger class with multiple destinations 62 | class ColoredLogger(logging.Logger): 63 | FORMAT = ( 64 | "%(asctime)s $RESET%(levelname)s %(filename)s$RESET:%(lineno)d %(message)s " 65 | ) 66 | COLOR_FORMAT = formatter_message(FORMAT, True) 67 | 68 | def __init__(self, name): 69 | logging.Logger.__init__(self, name, logging.ERROR) 70 | 71 | color_formatter = ColoredFormatter(self.COLOR_FORMAT) 72 | 73 | console = logging.StreamHandler() 74 | console.setFormatter(color_formatter) 75 | 76 | self.addHandler(console) 77 | return 78 | 79 | def loggerHandle(): 80 | logging.setLoggerClass(ColoredLogger) 81 | logger = logging.getLogger(__name__) 82 | logger.setLevel(logging.INFO) 83 | return logger 84 | 85 | 86 | def logfileHandle(log_name="../logs/common.log"): 87 | project_path = getProjectPath() 88 | log_file = os.path.join(project_path, log_name) 89 | if not 
os.path.exists(os.path.join(project_path, "../logs")):
90 |         os.makedirs(os.path.join(project_path, "../logs"))
91 |     if not os.path.exists(log_file):
92 |         os.mknod(log_file)
93 |     logfile = logging.getLogger()
94 |     logfile.setLevel(logging.DEBUG)
95 |     handler = logging.FileHandler(log_file, encoding="UTF-8")
96 |     formatter = logging.Formatter(
97 |         "%(asctime)s %(levelname)s %(filename)s:%(lineno)d %(message)s",
98 |         datefmt="%m-%d %H:%M:%S",
99 |     )
100 |     handler.setFormatter(formatter)
101 |     logfile.addHandler(handler)
102 |     return logfile
103 | 
104 | 
105 | log = loggerHandle()
106 | logE = logfileHandle("../logs/error.log")
107 | logP = logfileHandle("../logs/post.log")
108 | logU = logfileHandle("../logs/upload_data.log")
109 | 
110 | if __name__ == "__main__":
111 | 
112 |     logging.setLoggerClass(ColoredLogger)
113 |     logger = logging.getLogger(__name__)
114 |     logger.setLevel(logging.DEBUG)
115 | 
116 |     logger.debug("\033[1;32mMessage Error\033[0m")
117 |     logger.info("test")
118 |     logger.warning("test")
119 |     logger.error("test")
120 |     time.sleep(10)
121 |     logger.info("aaaaa")
122 | 
-------------------------------------------------------------------------------- /requirement.txt: --------------------------------------------------------------------------------
1 | tokenizers==0.20.3
2 | huggingface-hub==0.24.6
3 | transformers==4.46.3
4 | torch>=2.2.1
5 | triton>=2.2.0
6 | tqdm==4.65.0
7 | pytest==8.3.3
8 | pynvml==11.5.0
9 | protobuf==3.20.0
10 | numpy==1.26.4
11 | aiohttp==3.9.5
12 | termvisage==0.2.0
13 | rich==13.7.1
14 | accelerate==1.6.0
15 | sentence-transformers==4.0.2
16 | jsonargparse==4.38.0
-------------------------------------------------------------------------------- /tests/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/harleyszhang/lite_llama/7e53e5e6701137e251ed242bf4395bcc672438ce/tests/__init__.py
-------------------------------------------------------------------------------- /tests/kernels/softmax_native.py: --------------------------------------------------------------------------------
1 | import triton, torch
2 | import triton.language as tl
3 | 
4 | 
5 | def naive_softmax(x: torch.Tensor) -> torch.Tensor:
6 |     """Compute row-wise softmax of X using native pytorch
7 | 
8 |     We subtract the maximum element in order to avoid overflows. Softmax is invariant to
9 |     this shift.
10 |     # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
11 |     """
12 |     x_max = x.max(dim=1)[0]  # read MN elements ; write M elements
13 |     safe_x = x - x_max[:, None]  # read MN + M elements ; write MN elements
14 |     numerator = torch.exp(safe_x)  # read MN elements ; write MN elements
15 |     denominator = numerator.sum(dim=1)  # read MN elements ; write M elements
16 |     ret = numerator / denominator[:, None]  # read MN + M elements ; write MN elements
17 | 
18 |     return ret
19 | 
20 | 
21 | def online_softmax(x: torch.Tensor) -> torch.Tensor:
22 |     """Online (single-pass) softmax, computed iteratively row by row."""
23 |     row_count, col_count = x.shape
24 |     assert x.ndim == 2, "only accepts 2D tensors"
25 |     output = torch.zeros_like(x)
26 | 
27 |     for r in range(row_count):
28 |         row_max = x[r][0]
29 |         normalizer = 1.0  # exp(x[r][0] - row_max) = 1, the first element's contribution
30 |         for c in range(1, col_count):
31 |             pre_max = row_max
32 |             cur = x[r][c]
33 |             row_max = max(pre_max, cur)
34 |             # if cur > pre_max:
35 |             #     print(f"Update row max now is {row_max}, row = {r}")
36 |             normalizer = normalizer * torch.exp(pre_max - row_max) + torch.exp(
37 |                 cur - row_max
38 |             )
39 |         output[r, :] = torch.exp(x[r, :] - row_max) / normalizer
40 | 
41 |     return output
42 | 
43 | 
44 | @triton.jit
45 | def _softmax_kernel_fwd(
46 |     input_ptr,
47 |     stride_input_row,
48 |     output_ptr,
49 |     stride_output_row,
50 |     num_cols,
51 |     BLOCK_SIZE: tl.constexpr,
52 | ):
53 |     # 1, setup input ptrs
54 |     row_id = tl.program_id(axis=0)
55 |     row_start_ptr = input_ptr + row_id * stride_input_row
56 |     col_offsets = tl.arange(0, BLOCK_SIZE)
57 |     input_pointers = row_start_ptr + col_offsets
58 | 
59 |     row_data_mask = col_offsets < num_cols
60 | 
61 |     # 2, move to SRAM
62 |     x = tl.load(input_pointers, mask=row_data_mask, other=-float("inf"))  # padded lanes get -inf so they vanish after exp
63 | 
64 |     # 3, softmax cal itself
65 |     safe_row = x - tl.max(x, axis=0)
66 |     numerator = tl.exp(safe_row)
67 |     denominator = tl.sum(numerator, axis=0)
68 |     softmax_out = numerator / denominator
69 | 
70 |     # 4, write back to HBM
71 |     output_row_ptr = output_ptr + row_id * stride_output_row
72 |     output_pointers = output_row_ptr + col_offsets
73 |     tl.store(output_pointers, softmax_out, mask=row_data_mask)
74 | 
75 | 
76 | @torch.no_grad()
77 | def softmax_native_fwd(x: torch.Tensor) -> torch.Tensor:
78 |     """Triton implementation of softmax; only supports 2D tensors in the forward pass."""
79 |     rows, cols = x.shape
80 |     assert x.ndim == 2, "only accepts 2D tensors"
81 |     BLOCK_SIZE = triton.next_power_of_2(cols)
82 |     num_warps = 4
83 |     if BLOCK_SIZE >= 32768:
84 |         num_warps = 32
85 |     elif BLOCK_SIZE >= 8192:
86 |         num_warps = 16
87 |     elif BLOCK_SIZE >= 2048:
88 |         num_warps = 8
89 | 
90 |     grid = (rows, 1)
91 | 
92 |     # allocate output buffer
93 |     softmax_out = torch.empty_like(x)
94 | 
95 |     _softmax_kernel_fwd[grid](
96 |         x,
97 |         x.stride(0),  # input row stride
98 |         softmax_out,
99 |         softmax_out.stride(0),
100 |         cols,
101 |         BLOCK_SIZE=BLOCK_SIZE,
102 |         num_warps=num_warps,
103 |     )
104 | 
105 |     return softmax_out
106 | 
-------------------------------------------------------------------------------- /tests/kernels/softmax_split.py: --------------------------------------------------------------------------------
1 | import triton
2 | from triton import language as tl
3 | import torch
4 | 
5 | 
6 | @triton.jit
7 | def logsumexp_kernel(
8 |     out_ptr,
9 |     in_ptr,
10 |     M,
11 |     N,
12 |     TILE_N: tl.constexpr,
13 | ):
14 |     pid_n = tl.program_id(0)
15 |     num_programs_n = tl.num_programs(0)
16 |     pid_m = tl.program_id(1)
17 | 
18 |     n_offsets = pid_n * TILE_N + tl.arange(0, TILE_N)
19 |     mask = n_offsets < N
20 |     offset = pid_m * N + 
n_offsets 21 | inp = tl.load(in_ptr + offset, mask=mask, other=-float("inf")).to( 22 | out_ptr.dtype.element_ty 23 | ) 24 | m = tl.max(inp, 0) 25 | e = tl.exp(inp - m) 26 | z = tl.sum(e, 0) 27 | logz = m + tl.log(z) 28 | 29 | output_ptrs = out_ptr + pid_m * num_programs_n + pid_n 30 | tl.store(output_ptrs, logz) 31 | 32 | 33 | @triton.jit 34 | def combine_logsumexp_kernel(out_ptr, inp_ptr, M, N, TILE_N: tl.constexpr): 35 | pid_m = tl.program_id(0) 36 | n_offsets = tl.arange(0, TILE_N) 37 | mask = n_offsets < N 38 | logzs = tl.load(inp_ptr + pid_m * N + n_offsets, other=-float("inf"), mask=mask).to( 39 | out_ptr.dtype.element_ty 40 | ) 41 | m = tl.max(logzs, 0) 42 | e = tl.exp(logzs - m) 43 | z = tl.sum(e, 0) 44 | logz = m + tl.log(z) 45 | tl.store(out_ptr + pid_m, logz) 46 | 47 | 48 | @triton.jit 49 | def softmax_kernel(out_ptr, in_ptr, logz_ptr, M, N, TILE_N: tl.constexpr): 50 | pid_n = tl.program_id(0) 51 | pid_m = tl.program_id(1) 52 | n_offsets = pid_n * TILE_N + tl.arange(0, TILE_N) 53 | offset = pid_m * N + n_offsets 54 | mask = n_offsets < N 55 | inp = tl.load(in_ptr + offset, mask=mask, other=-float("inf")).to( 56 | out_ptr.dtype.element_ty 57 | ) 58 | logz = tl.load(logz_ptr + pid_m).to(out_ptr.dtype.element_ty) 59 | out = tl.exp(inp - logz) 60 | tl.store(out_ptr + offset, out, mask=mask) 61 | 62 | 63 | def softmax_split(x): 64 | M, N = x.shape 65 | 66 | num_sms = torch.cuda.get_device_properties(x.device).multi_processor_count 67 | 68 | TILE_N = min(4096, triton.next_power_of_2(N)) 69 | num_tiles_n = triton.cdiv(N, TILE_N) 70 | logz = torch.empty((M, num_tiles_n), dtype=x.dtype, device=x.device) 71 | grid = (num_tiles_n, M, 1) 72 | logsumexp_kernel[grid](logz, x, M, N, TILE_N) 73 | 74 | combined_logz = torch.empty((M,), dtype=x.dtype, device=x.device) 75 | TILE_N = triton.next_power_of_2(num_tiles_n) 76 | grid = (M, 1, 1) 77 | combine_logsumexp_kernel[grid](combined_logz, logz, M, num_tiles_n, TILE_N) 78 | 79 | out = torch.empty_like(x) 80 | TILE_N = min(4096, triton.next_power_of_2(N)) 81 | num_tiles_n = triton.cdiv(N, TILE_N) 82 | grid = (num_tiles_n, M, 1) 83 | softmax_kernel[grid](out, x, combined_logz, M, N, TILE_N) 84 | return out 85 | -------------------------------------------------------------------------------- /tests/kernels/test_attention.py: -------------------------------------------------------------------------------- 1 | import torch, os, sys 2 | 3 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))) 4 | from lite_llama.models.llama import * 5 | from lite_llama.tests.test_torch_rope import apply_rotary_emb 6 | 7 | 8 | def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: 9 | """同一组的 kv cache 复制多份""" 10 | batch_size, seq_len, num_kv_heads, head_dim = x.shape 11 | if n_rep == 1: 12 | return x 13 | return ( 14 | # (B, Seq_Len, num_kv_heads, 1, Head_Dim) 15 | x[:, :, :, None, :] 16 | # (B, Seq_Len, num_kv_heads, N_Rep, Head_Dim) 17 | .expand(batch_size, seq_len, num_kv_heads, n_rep, head_dim) 18 | # (B, Seq_Len, num_kv_heads * N_Rep, Head_Dim) 19 | .reshape(batch_size, seq_len, num_kv_heads * n_rep, head_dim) 20 | ) 21 | 22 | 23 | class ModelArgs: 24 | def __init__(self): 25 | self.dim = 64 # 模型维度 26 | self.n_heads = 8 # 头数 27 | self.n_kv_heads = 8 # 将 n_kv_heads 设置为 n_heads 28 | self.max_batch_size = 2 29 | self.max_seq_len = 16 30 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 31 | 32 | 33 | class FusedAttention(nn.Module): 34 | def __init__(self, args): 35 | super().__init__() 36 | self.args = args 
37 | 38 | device = args.device 39 | 40 | # K V 头数相同,但和 Q 可能不同 41 | self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads 42 | self.n_heads_q = args.n_heads 43 | self.n_rep = self.n_heads_q // self.n_kv_heads # kv 重复次数 44 | 45 | # 每个头的维度大小 46 | self.head_dim = args.dim // args.n_heads 47 | self.hidden_size = args.n_heads * self.head_dim 48 | 49 | # 定义线性层,并移动到设备 50 | self.wq = nn.Linear( 51 | args.dim, self.n_heads_q * self.head_dim, bias=False, dtype=torch.float16 52 | ).to(device) 53 | self.wk = nn.Linear( 54 | args.dim, self.n_kv_heads * self.head_dim, bias=False, dtype=torch.float16 55 | ).to(device) 56 | self.wv = nn.Linear( 57 | args.dim, self.n_kv_heads * self.head_dim, bias=False, dtype=torch.float16 58 | ).to(device) 59 | self.wo = nn.Linear( 60 | self.n_heads_q * self.head_dim, args.dim, bias=False, dtype=torch.float16 61 | ).to(device) 62 | 63 | # 提前按最大可分配空间分配好 kv cache 张量,并注册为 buffer 64 | self.register_buffer( 65 | "cache_k", 66 | torch.zeros( 67 | (args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim), 68 | dtype=torch.float16, 69 | device=device, 70 | ), 71 | persistent=False, 72 | ) 73 | self.register_buffer( 74 | "cache_v", 75 | torch.zeros( 76 | (args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim), 77 | dtype=torch.float16, 78 | device=device, 79 | ), 80 | persistent=False, 81 | ) 82 | 83 | def forward(self, x: torch.Tensor, start_pos: int): 84 | batch_size, seq_len, _ = ( 85 | x.shape 86 | ) # prefill: (B, Seq_Len, Dim); decode: (B, 1, Dim) 87 | 88 | x = x.to(torch.float16) # 确保输入为 float16 89 | 90 | # 1. 计算 Q K V 并且 reshape 91 | xq = self.wq(x).view(batch_size, seq_len, self.n_heads_q, self.head_dim) 92 | xk = self.wk(x).view(batch_size, seq_len, self.n_kv_heads, self.head_dim) 93 | xv = self.wv(x).view(batch_size, seq_len, self.n_kv_heads, self.head_dim) 94 | 95 | # 2. 计算 RoPE 位置编码 96 | freqs_cis = precompute_freqs_cis( 97 | dim=self.head_dim, seq_len=seq_len, device=x.device 98 | ) 99 | xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) 100 | 101 | # 3. 更新缓存 102 | self.cache_k[:batch_size, start_pos : start_pos + seq_len, :, :] = xk 103 | self.cache_v[:batch_size, start_pos : start_pos + seq_len, :, :] = xv 104 | 105 | # 4. 获取累积的 K V 106 | keys = self.cache_k[ 107 | :batch_size, : start_pos + seq_len, :, : 108 | ] # (B, Seq_Len_KV, H_KV, D) 109 | values = self.cache_v[ 110 | :batch_size, : start_pos + seq_len, :, : 111 | ] # (B, Seq_Len_KV, H_KV, D) 112 | 113 | # 5. GQA 114 | keys = repeat_kv(keys, self.n_rep) # (B, Seq_Len_KV, H_Q, D) 115 | values = repeat_kv(values, self.n_rep) # (B, Seq_Len_KV, H_Q, D) 116 | 117 | # 6. 转置以适应注意力计算 118 | xq = xq.transpose(1, 2) # (B, H_Q, Seq_Len_Q, D) 119 | keys = keys.transpose(1, 2) # (B, H_Q, Seq_Len_KV, D) 120 | values = values.transpose(1, 2) # (B, H_Q, Seq_Len_KV, D) 121 | 122 | # 7. 计算注意力得分 123 | scores = torch.matmul(xq, keys.transpose(-2, -1)) / math.sqrt( 124 | self.head_dim 125 | ) # (B, H_Q, Seq_Len_Q, Seq_Len_KV) 126 | 127 | # 8. 应用因果掩码 128 | seq_len_q = xq.shape[2] 129 | # seq_len_kv = keys.shape[2] 130 | # causal_mask = torch.tril(torch.ones((seq_len_q, seq_len_kv), device=x.device, dtype=torch.bool)) 131 | # causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) # (1, 1, Seq_Len_Q, Seq_Len_KV) 132 | # scores = scores.masked_fill(~causal_mask, float('-inf')) 133 | 134 | # 9. 计算注意力权重并应用 135 | attn_weights = F.softmax(scores, dim=-1) 136 | attn_output = torch.matmul(attn_weights, values) # (B, H_Q, Seq_Len_Q, D) 137 | 138 | # 10. 
合并 heads 并输出 139 | attn_output = ( 140 | attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len_q, -1) 141 | ) # (B, Seq_Len_Q, H_Q * D) 142 | output = self.wo(attn_output) 143 | 144 | return output 145 | 146 | 147 | def test_fused_attention(): 148 | # 模型参数 149 | args = ModelArgs() 150 | 151 | # 创建测试输入 152 | batch_size = 2 153 | seq_len = 10 154 | dim = args.dim 155 | 156 | # 使用 float16 数据类型并移动到设备 157 | x = torch.randn(batch_size, seq_len, dim, dtype=torch.float16, device=args.device) 158 | 159 | # 初始化自定义的 FusedAttention,并移动到设备 160 | fused_attention = FusedAttention(args).to(args.device) 161 | 162 | # 初始化 PyTorch 的 MultiheadAttention,并移动到设备 163 | mha = nn.MultiheadAttention( 164 | embed_dim=dim, num_heads=args.n_heads, batch_first=True, dtype=torch.float16 165 | ).to(args.device) 166 | 167 | # 同步权重 168 | with torch.no_grad(): 169 | # 将 FusedAttention 的权重复制到 MultiheadAttention 170 | mha.in_proj_weight.copy_( 171 | torch.cat( 172 | [ 173 | fused_attention.wq.weight, 174 | fused_attention.wk.weight, 175 | fused_attention.wv.weight, 176 | ], 177 | dim=0, 178 | ) 179 | ) 180 | 181 | # 设置输出投影权重 182 | mha.out_proj.weight.copy_(fused_attention.wo.weight) 183 | mha.out_proj.bias.zero_() # 假设没有偏置 184 | 185 | # 前向传播 186 | fused_output = fused_attention(x, start_pos=0) 187 | mha_output, _ = mha(x, x, x, need_weights=False) 188 | 189 | # 比较输出 190 | difference = torch.abs(fused_output - mha_output).mean().item() 191 | print( 192 | f"Average difference between FusedAttention and MultiheadAttention: {difference}" 193 | ) 194 | 195 | # 断言差异在可接受范围内 196 | assert difference < 1e-1, "FusedAttention output does not match MultiheadAttention" 197 | 198 | print("FusedAttention test passed!") 199 | 200 | 201 | if __name__ == "__main__": 202 | test_fused_attention() 203 | -------------------------------------------------------------------------------- /tests/kernels/test_available_blocks.py: -------------------------------------------------------------------------------- 1 | import torch, gc 2 | from typing import List, Tuple 3 | from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM 4 | import logging, json, os, sys 5 | 6 | # 获取 lite_llama 目录的绝对路径并添加到 sys.path 中 7 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) 8 | from lite_llama.models.model_config import LlamaConfig 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def load_config_from_json(json_file_path: str, device: str = "cuda") -> LlamaConfig: 14 | with open(json_file_path, "r") as f: 15 | config_dict = json.load(f) 16 | config = LlamaConfig(config_dict, max_seq_len=2048, device=device) 17 | return config 18 | 19 | 20 | def _get_cache_block_size(model_config, block_size: int = 1) -> int: 21 | head_size = model_config.head_dim 22 | num_heads = model_config.num_kv_heads 23 | num_attention_layers = model_config.num_layers 24 | 25 | key_cache_block = block_size * num_heads * head_size 26 | value_cache_block = key_cache_block 27 | total = num_attention_layers * (key_cache_block + value_cache_block) 28 | dtype_size = 2 # torch.float16 29 | 30 | return dtype_size * total 31 | 32 | 33 | @torch.inference_mode() 34 | def determine_num_available_blocks( 35 | model_config, gpu_memory_utilization=0.9 36 | ) -> Tuple[int, int]: 37 | """ 38 | 评估模型的峰值内存使用情况,以确定在不发生内存溢出的情况下可以分配的 KV(键值)缓存块的数量。 39 | 40 | 该方法首先清理 CUDA 缓存,然后使用虚拟输入执行一次前向传播,以评估模型的内存使用情况。 41 | 接着,计算在剩余可用内存下,最多可以分配的 GPU 和 CPU 缓存块数量。 42 | 43 | 提示: 44 | 可以通过调整 `gpu_memory_utilization` 参数来限制 GPU 内存的使用。 45 | """ 46 | # 清理 CUDA 
缓存,以确保获取准确的内存使用信息 47 | torch.cuda.empty_cache() 48 | 49 | # 使用虚拟输入执行一次前向传播,以评估模型的内存使用情况 50 | 51 | # 同步 CUDA 操作,确保内存信息准确 52 | torch.cuda.synchronize() 53 | # 获取当前 GPU 的空闲内存和总内存(单位:字节) 54 | free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info() 55 | # 计算模型加载后的峰值内存使用量 56 | # Get the peak memory allocation recorded by torch 57 | peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"] 58 | 59 | # 清理未使用的缓存,计算非Torch分配的内存 60 | torch.cuda.empty_cache() 61 | torch_allocated_bytes = torch.cuda.memory_stats()["allocated_bytes.all.current"] 62 | 63 | total_allocated_bytes = torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0] 64 | non_torch_allocations = total_allocated_bytes - torch_allocated_bytes 65 | 66 | if non_torch_allocations > 0: 67 | peak_memory += non_torch_allocations 68 | 69 | available_kv_cache_memory = total_gpu_memory * gpu_memory_utilization - peak_memory 70 | 71 | # 计算每个缓存块的大小 72 | cache_block_size = _get_cache_block_size(model_config) 73 | # 计算在剩余可用内存下,最多可以分配的 GPU 缓存块数量 74 | num_gpu_blocks = int( 75 | (total_gpu_memory * gpu_memory_utilization - peak_memory) // cache_block_size 76 | ) 77 | # 确保缓存块数量不为负数 78 | num_gpu_blocks = max(num_gpu_blocks, 0) 79 | 80 | logger.info( 81 | "Memory profiling results: total_gpu_memory=%.2fGiB \n" 82 | " initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB \n" 83 | " memory_usage_post_profile=%.2fGib \n" 84 | " non_torch_memory=%.2fGiB kv_cache_size=%.2fGiB \n" 85 | " gpu_memory_utilization=%.2f", 86 | total_gpu_memory / (1024**3), 87 | (total_gpu_memory - free_memory_pre_profile) / (1024**3), 88 | (peak_memory - non_torch_allocations) / (1024**3), 89 | total_allocated_bytes / (1024**3), 90 | non_torch_allocations / (1024**3), 91 | available_kv_cache_memory / (1024**3), 92 | gpu_memory_utilization, 93 | ) 94 | 95 | # 进行垃圾回收,释放未使用的内存 96 | gc.collect() 97 | # 再次清理 CUDA 缓存 98 | torch.cuda.empty_cache() 99 | # 返回可分配的 GPU 和 CPU 缓存块数量(此处 CPU 块数量为 0) 100 | 101 | return num_gpu_blocks, 0 102 | 103 | 104 | def load_original_llama(model_name_or_path: str, device: str = "cuda"): 105 | # config = LlamaConfig.from_pretrained(model_name_or_path) 106 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) 107 | model = AutoModelForCausalLM.from_pretrained( 108 | model_name_or_path, 109 | torch_dtype=torch.float16, 110 | device_map="cuda", 111 | ) 112 | model.to(device) 113 | return model, tokenizer 114 | 115 | 116 | if __name__ == "__main__": 117 | # 定义模型权重路径及配置参数 118 | device = "cuda" if torch.cuda.is_available() else "cpu" 119 | original_model_path = "/gemini/code/Llama-3.2-1B-Instruct" 120 | # 加载原始模型 121 | original_model, tokenizer = load_original_llama(original_model_path, device) 122 | # 定义模型配置参数 123 | json_file_path = ( 124 | "/gemini/code/Llama-3.2-1B-Instruct/my_weight/config.json" # JSON 文件的路径 125 | ) 126 | model_config = load_config_from_json(json_file_path, device) # 加载配置 127 | determine_num_available_blocks(model_config) 128 | -------------------------------------------------------------------------------- /tests/kernels/test_flashdecoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys, os, math 3 | import torch.nn.functional as F 4 | 5 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))) 6 | from lite_llama.kernels.flashattentionv2 import flash_decoding 7 | 8 | 9 | def standard_attention(Q, K, V, sm_scale, mask=None): 10 | """ 11 | 标准的 PyTorch 实现的自注意力机制。 12 | 13 | Args: 14 | Q (torch.Tensor): 查询张量,形状 
(batch_size, num_heads, seq_length, head_dim) 15 | K (torch.Tensor): 键张量,形状 (batch_size, num_heads, seq_length, head_dim) 16 | V (torch.Tensor): 值张量,形状 (batch_size, num_heads, seq_length, head_dim) 17 | sm_scale (float): Softmax 缩放因子 18 | mask (torch.Tensor, optional): 遮罩张量,形状 (batch_size, num_heads, seq_length, seq_length) 19 | 20 | Returns: 21 | torch.Tensor: 注意力输出,形状与 Q 相同 22 | """ 23 | print( 24 | f"K V cache tensor have 0 numbers is ", 25 | torch.nonzero(K == 0).numel(), 26 | torch.nonzero(V == 0).numel(), 27 | ) 28 | # 计算 QK^T 29 | attn_scores = ( 30 | torch.matmul(Q, K.transpose(-2, -1)) * sm_scale 31 | ) # (batch_size, num_heads, seq_length, seq_length) 32 | 33 | if mask is not None: 34 | attn_scores = attn_scores.masked_fill(mask == 0, float("-inf")) 35 | 36 | attn_weights = F.softmax(attn_scores, dim=-1) 37 | 38 | # 计算注意力输出 39 | out = torch.matmul(attn_weights, V) # (batch_size, num_heads, seq_length, head_dim) 40 | 41 | return out 42 | 43 | 44 | def test_decode_stage(debug_out_text): 45 | # 设置测试参数 46 | batch_size = 4 47 | num_heads = 32 48 | # 使用 padattention,所以 batch 中的每个 seq 长度相同 49 | kv_cache_seq_length = 512 50 | generated_seq_length = 16 51 | head_dim = 64 52 | dtype = torch.float16 # 改为 float32 53 | 54 | # 生成固定的初始输入张量 55 | torch.manual_seed(0) 56 | 57 | # torch_q = torch.randn(batch_size, num_heads, initial_seq_length, head_dim, device='cuda', dtype = dtype) 58 | torch_k_cache = torch.randn( 59 | batch_size, num_heads, kv_cache_seq_length, head_dim, device="cuda", dtype=dtype 60 | ) 61 | torch_v_cache = torch.randn( 62 | batch_size, num_heads, kv_cache_seq_length, head_dim, device="cuda", dtype=dtype 63 | ) 64 | 65 | # triton_q = torch_q.transpose(1, 2).view(-1, num_heads, head_dim) 66 | triton_k_cache = torch_k_cache.transpose(1, 2).reshape(-1, num_heads, head_dim) 67 | triton_v_cache = torch_v_cache.transpose(1, 2).reshape(-1, num_heads, head_dim) 68 | print(f"triton_k_cache shape is ", triton_k_cache.shape) 69 | 70 | torch_new_token_q = torch.randn( 71 | batch_size, num_heads, 1, head_dim, device="cuda", dtype=dtype 72 | ) 73 | triton_new_token_q = torch_new_token_q.transpose(1, 2).reshape( 74 | -1, num_heads, head_dim 75 | ) 76 | print(f"triton_new_token_q shape is ", triton_new_token_q.shape) 77 | 78 | # 初始化线性层,用于生成 Q、K、V. 
为了测试,这里使用随机的线性层参数 79 | q_linear = torch.nn.Linear(head_dim, num_heads * head_dim, bias=False).to( 80 | "cuda", dtype=dtype 81 | ) 82 | k_linear = torch.nn.Linear(head_dim, num_heads * head_dim, bias=False).to( 83 | "cuda", dtype=dtype 84 | ) 85 | v_linear = torch.nn.Linear(head_dim, num_heads * head_dim, bias=False).to( 86 | "cuda", dtype=dtype 87 | ) 88 | 89 | # 模拟生成过程中逐步增加序列长度 90 | for step in range(1, generated_seq_length + 1): 91 | # 扩展 Q, K, V 和 Out 92 | # q_extended = torch.cat([q_initial, new_token_q], dim=2) 93 | 94 | # 计算 Softmax 缩放因子 95 | sm_scale_extended = 1.0 / math.sqrt(head_dim) 96 | 97 | # 计算 Triton 内核输出 98 | 99 | triton_new_token_q = flash_decoding( 100 | triton_new_token_q, 101 | triton_k_cache, 102 | triton_v_cache, 103 | actual_seq_len=kv_cache_seq_length, 104 | ) 105 | 106 | # 使用标准 PyTorch 实现计算扩展后的注意力输出 107 | torch_new_token_q = standard_attention( 108 | torch_new_token_q, torch_k_cache, torch_v_cache, sm_scale_extended 109 | ) 110 | 111 | # 生成新的 token 112 | triton_k_cache = torch.cat([triton_k_cache, triton_new_token_q], dim=0) 113 | triton_v_cache = torch.cat([triton_v_cache, triton_new_token_q], dim=0) 114 | 115 | torch_k_cache = torch.cat([torch_k_cache, torch_new_token_q], dim=2) 116 | torch_v_cache = torch.cat([torch_v_cache, torch_new_token_q], dim=2) 117 | kv_cache_seq_length += 1 118 | 119 | torch_new_token_q_format = ( 120 | torch_new_token_q.transpose(1, 2).contiguous().view(-1, num_heads, head_dim) 121 | ) 122 | 123 | debug_out_text1 = debug_out_text.format(step=step, kernel_type="torch") 124 | debug_out_text2 = debug_out_text.format(step=step, kernel_type="triton") 125 | with open(debug_out_text1, "w") as f: 126 | f.write(str(torch_new_token_q_format)) 127 | 128 | with open(debug_out_text2, "w") as f: 129 | f.write(str(triton_new_token_q)) 130 | 131 | # 比较 Triton 内核输出与标准实现的输出 132 | if torch.allclose(triton_new_token_q, torch_new_token_q_format, atol=1e-1): 133 | max_difference = (triton_new_token_q - torch_new_token_q_format).abs().max() 134 | print( 135 | f"Decode Stage Step {step} Difference {max_difference} Test Passed: Triton output matches PyTorch standard implementation." 
136 | ) 137 | else: 138 | max_diff = (triton_new_token_q - torch_new_token_q_format).abs().max() 139 | print( 140 | f"Decode Stage Step {step} Test Failed: Maximum difference {max_diff}" 141 | ) 142 | # 可选择打印更多信息进行调试 143 | break # 根据需要是否停止测试 144 | 145 | 146 | if __name__ == "__main__": 147 | debug_out_text = ( 148 | "/gemini/code/lite_llama/test/debug/{step}_{kernel_type}_decode_out_tensor.txt" 149 | ) 150 | print("\nRunning Decode Stage Test...") 151 | test_decode_stage(debug_out_text) 152 | -------------------------------------------------------------------------------- /tests/kernels/test_flashdecoding_stage1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | 5 | import torch 6 | import triton 7 | import triton.language as tl 8 | 9 | 10 | @triton.jit 11 | def _flash_decoding_stage1_kernel( 12 | Q, 13 | K, 14 | V, 15 | sm_scale, 16 | actual_seq_len, # 实际序列长度 17 | Mid_O, 18 | Mid_O_LogExpSum, 19 | q_bs_stride, 20 | q_heads_stride, 21 | q_dim_stride, # Q 的 strides 22 | k_bs_stride, 23 | k_heads_stride, 24 | k_dim_stride, # K 的 strides 25 | v_bs_stride, 26 | v_heads_stride, 27 | v_dim_stride, # V 的 strides 28 | mido_batch_stride, 29 | mido_heads_stride, 30 | mido_partitions_stride, 31 | mido_dim_stride, 32 | mido_les_batch_stride, 33 | mido_les_heads_stride, 34 | mido_les_partitions_stride, 35 | BLOCK_SEQ: tl.constexpr, 36 | BLOCK_N: tl.constexpr, 37 | BLOCK_DMODEL: tl.constexpr, 38 | ): 39 | """Flash Attention Stage1 Triton Kernel""" 40 | # 获取当前程序的 block 在各个维度上的索引 41 | batch_idx = tl.program_id(0) 42 | head_idx = tl.program_id(1) 43 | seq_block_idx = tl.program_id(2) 44 | 45 | # 计算当前批次的起始位置 46 | cur_batch_start_loc = batch_idx * actual_seq_len 47 | 48 | # 计算当前分区的起始和结束索引 49 | cur_batch_partition_start_index = seq_block_idx * BLOCK_SEQ 50 | cur_batch_partition_end_index = tl.minimum( 51 | actual_seq_len, cur_batch_partition_start_index + BLOCK_SEQ 52 | ) 53 | 54 | # 计算需要处理的块数 55 | num_blocks = ( 56 | cur_batch_partition_end_index - cur_batch_partition_start_index + BLOCK_N - 1 57 | ) // BLOCK_N 58 | 59 | # 初始化偏移向量 60 | offs_n = cur_batch_partition_start_index + tl.arange(0, BLOCK_N) # [BLOCK_N] 61 | offs_d = tl.arange(0, BLOCK_DMODEL) # [BLOCK_DMODEL] 62 | 63 | # 计算 Q 的偏移量 64 | q_offs = batch_idx * q_bs_stride + head_idx * q_heads_stride + offs_d * q_dim_stride 65 | 66 | # 计算 K 和 V 的偏移量 67 | k_offs = ( 68 | (cur_batch_start_loc + offs_n[:, None]) * k_bs_stride 69 | + head_idx * k_heads_stride 70 | + offs_d[None, :] * k_dim_stride 71 | ) 72 | 73 | v_offs = ( 74 | (cur_batch_start_loc + offs_n[:, None]) * v_bs_stride 75 | + head_idx * v_heads_stride 76 | + offs_d[None, :] * v_dim_stride 77 | ) 78 | 79 | # 获取指针 80 | q_ptrs = Q + q_offs 81 | k_ptrs = K + k_offs 82 | v_ptrs = V + v_offs 83 | 84 | # 加载 Q 向量 85 | q = tl.load(q_ptrs) # [BLOCK_DMODEL] 86 | 87 | # 初始化归一化项和累加器 88 | d_i = 0.0 # 标量 89 | m_i = -float("inf") # 标量 90 | acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) # [BLOCK_DMODEL] 91 | 92 | # 迭代处理每个块 93 | for start_n in range(num_blocks): 94 | offs_n_new = start_n * BLOCK_N + offs_n # [BLOCK_N] 95 | # 生成 K 的掩码 96 | k_mask = offs_n_new < cur_batch_partition_end_index # [BLOCK_N] 97 | 98 | # 加载 K 和 V 99 | k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0) # [BLOCK_N, BLOCK_DMODEL] 100 | v = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0) # [BLOCK_N, BLOCK_DMODEL] 101 | 102 | # 计算 qk^T 103 | qk = tl.sum(q * k, axis=1) # [BLOCK_N] 104 | qk = qk * sm_scale 105 | qk = tl.where(k_mask, 
qk, float("-inf")) # [BLOCK_N] 106 | 107 | # 更新最大值项和 qk 项 108 | current_max = tl.max(qk) # 标量 109 | m_ij = tl.maximum(m_i, current_max) # 标量 110 | qk = qk - m_ij # [BLOCK_N] 111 | 112 | # 更新归一化项 113 | p = tl.exp(qk) # [BLOCK_N] 114 | alpha = tl.exp(m_i - m_ij) # 标量 115 | d_i = d_i * alpha + tl.sum(p) # 标量 116 | 117 | # 更新 attention 输出累加器 118 | acc = acc * alpha + tl.sum(p[:, None] * v, axis=0) # [BLOCK_DMODEL] 119 | # acc = acc * alpha + tl.dot(p, v) # [BLOCK_DMODEL] 120 | 121 | # 更新归一化器 122 | m_i = m_ij 123 | 124 | # 更新 K 和 V 的指针 125 | k_ptrs += BLOCK_N * k_bs_stride 126 | v_ptrs += BLOCK_N * v_bs_stride 127 | 128 | # 计算是否需要存储 129 | need_store = num_blocks > 0 # 标量布尔值 130 | 131 | # 计算存储的偏移量 132 | off_mid_o = ( 133 | batch_idx * mido_batch_stride 134 | + head_idx * mido_heads_stride 135 | + seq_block_idx * mido_partitions_stride 136 | + offs_d * mido_dim_stride 137 | ) 138 | 139 | off_mid_o_les = ( 140 | batch_idx * mido_les_batch_stride 141 | + head_idx * mido_les_heads_stride 142 | + seq_block_idx * mido_les_partitions_stride 143 | ) 144 | 145 | # 计算最终的 attention 输出和 log-sum-exp 146 | part_atten_out = acc / d_i # [BLOCK_DMODEL] 147 | logexpsum = m_i + tl.log(d_i) # 标量 148 | 149 | # 条件存储 150 | part_atten_out = tl.where(need_store, part_atten_out, 0.0) # [BLOCK_DMODEL] 151 | logexpsum = tl.where(need_store, logexpsum, float("-inf")) # 标量 152 | 153 | # 存储结果 154 | tl.store(Mid_O + off_mid_o, part_atten_out, mask=need_store) 155 | tl.store(Mid_O_LogExpSum + off_mid_o_les, logexpsum, mask=need_store) 156 | 157 | 158 | @torch.no_grad() 159 | def flash_decode_stage1( 160 | q, 161 | k, 162 | v, # Q: [batchs, num_heads, head_dim], K, V: [batchs * seq_len, num_heads, head_dim] 163 | actual_seq_len, # 实际的序列长度 164 | mid_o, 165 | mid_o_logexpsum, # Mid_O: [batchs, num_heads, cdiv(seq_len, PARTITION_SIZE), head_dim], Mid_O_LogExpSum: [batchs, num_heads, cdiv(seq_len, PARTITION_SIZE)] 166 | PARTITION_SIZE, 167 | ): 168 | BLOCK_N_SIZE = 32 169 | BLOCK_DMODEL = q.shape[-1] 170 | assert PARTITION_SIZE % BLOCK_N_SIZE == 0, ( 171 | "PARTITION_SIZE 必须是 BLOCK_N_SIZE 的倍数" 172 | ) 173 | 174 | batchs, num_heads, head_dim = q.shape 175 | sm_scale = 1.0 / (head_dim**0.5) 176 | grid = (batchs, num_heads, triton.cdiv(actual_seq_len, PARTITION_SIZE)) 177 | 178 | _flash_decoding_stage1_kernel[grid]( 179 | q, 180 | k, 181 | v, 182 | sm_scale, 183 | actual_seq_len, # 使用实际序列长度 184 | mid_o, 185 | mid_o_logexpsum, 186 | *q.stride(), 187 | *k.stride(), 188 | *v.stride(), 189 | *mid_o.stride(), 190 | *mid_o_logexpsum.stride(), 191 | BLOCK_SEQ=PARTITION_SIZE, 192 | BLOCK_N=BLOCK_N_SIZE, 193 | BLOCK_DMODEL=head_dim, 194 | num_warps=1, 195 | num_stages=2, 196 | ) 197 | 198 | 199 | import torch 200 | 201 | # 设置随机种子以确保可重复性 202 | torch.manual_seed(42) 203 | 204 | # 假设头维度为 64,批次为 2,头数为 4,序列长度为 128 205 | batchs, num_heads, head_dim, seq_len = 2, 4, 64, 128 206 | partition_size = 32 207 | 208 | # 随机初始化 Q, K, V 209 | q = torch.randn(batchs, num_heads, head_dim, device="cuda", dtype=torch.float32) 210 | k = torch.randn( 211 | batchs * seq_len, num_heads, head_dim, device="cuda", dtype=torch.float32 212 | ) 213 | v = torch.randn( 214 | batchs * seq_len, num_heads, head_dim, device="cuda", dtype=torch.float32 215 | ) 216 | 217 | # 初始化 mid_o 和 mid_o_logexpsum 218 | mid_o = torch.zeros( 219 | batchs, 220 | num_heads, 221 | (seq_len + partition_size - 1) // partition_size, 222 | head_dim, 223 | device="cuda", 224 | dtype=torch.float32, 225 | ) 226 | mid_o_logexpsum = torch.zeros( 227 | batchs, 228 | num_heads, 229 | (seq_len + partition_size 
- 1) // partition_size, 230 | device="cuda", 231 | dtype=torch.float32, 232 | ) 233 | 234 | # 调用修复后的函数 235 | flash_decode_stage1( 236 | q, 237 | k, 238 | v, 239 | actual_seq_len=seq_len, 240 | mid_o=mid_o, 241 | mid_o_logexpsum=mid_o_logexpsum, 242 | PARTITION_SIZE=partition_size, 243 | ) 244 | 245 | # 打印输出结果 246 | print("Mid_O:", mid_o) 247 | print("Mid_O_LogExpSum:", mid_o_logexpsum) 248 | -------------------------------------------------------------------------------- /tests/kernels/test_flashdecoding_stage2.py: -------------------------------------------------------------------------------- 1 | import torch, math 2 | import triton 3 | import triton.language as tl 4 | from torch.cuda.amp import custom_fwd 5 | from typing import List, Optional, Union 6 | import torch.nn.functional as F 7 | 8 | 9 | @triton.jit 10 | def _flash_decoding_stage2_kernel( 11 | Mid_O, # [batch, head, seq_block_num, head_dim] 12 | Mid_O_LogExpSum, # [batch, head, seq_block_num] 13 | Ouput, # attention 输出首地址 14 | mido_batch_stride, 15 | mido_heads_stride, 16 | mido_partitions_stride, 17 | mido_dim_stride, 18 | mido_les_batch_stride, 19 | mido_les_heads_stride, 20 | mido_les_partitions_stride, 21 | o_bs_stride, 22 | o_heads_stride, 23 | o_dim_stride, 24 | actual_seq_len, # TODO 支持 PagedAttention 和连续批处理 25 | BLOCK_DMODEL: tl.constexpr, 26 | BLOCK_SEQ: tl.constexpr, # type: ignore 27 | ): 28 | """Reduction (online softmax)""" 29 | batch_idx = tl.program_id(0) 30 | head_idx = tl.program_id(1) 31 | 32 | # 初始化偏移 33 | offs_d = tl.arange(0, BLOCK_DMODEL) 34 | 35 | offs_part_v = ( 36 | batch_idx * mido_batch_stride 37 | + head_idx * mido_heads_stride 38 | + offs_d * mido_dim_stride 39 | ) 40 | 41 | offs_part_max = batch_idx * mido_les_batch_stride + head_idx * mido_les_heads_stride 42 | 43 | part_v_ptrs = Mid_O + offs_part_v 44 | part_max_ptrs = Mid_O_LogExpSum + offs_part_max 45 | 46 | # Reduce kv 分块相关变量值. 
num_partitions 是 kv 分块数量 47 | d_i = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) 48 | m_i = -float("inf") 49 | acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) 50 | num_partitions = (actual_seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ 51 | 52 | for _ in range(0, num_partitions, 1): 53 | part_v = tl.load(part_v_ptrs) 54 | part_max = tl.load(part_max_ptrs) 55 | 56 | # -- 更新局部最大值 和 exp 分子项 p-- # 57 | m_ij = tl.maximum(part_max, m_i) 58 | p = tl.exp(part_v - m_ij) 59 | 60 | # -- 计算 alpha = exp(m{j-1} - m{j}) 值 -- # 61 | alpha = tl.exp(m_i - m_ij) 62 | 63 | # -- 更新归一化项和 attention 输出累加器 -- # 64 | d_i = d_i * alpha + p 65 | 66 | acc *= alpha 67 | acc += p * part_v 68 | 69 | # 更新 max 值和指针偏移 70 | m_i = m_ij 71 | part_v_ptrs += mido_partitions_stride 72 | part_max_ptrs += mido_les_partitions_stride 73 | 74 | # -- 更新 attention 输出累加器 -- # 75 | offs_out = ( 76 | batch_idx * o_bs_stride + head_idx * o_heads_stride + offs_d * o_dim_stride 77 | ) 78 | tl.store(Ouput + offs_out, acc / d_i) 79 | 80 | 81 | @torch.no_grad() 82 | def flash_decode_stage2( 83 | mid_o, 84 | mid_o_logexpsum, # 存储每个批次、每个头、每个分区的中间分数输出及 log(sum(exp(scores))) 85 | atten_output, # attention 输出首地址 86 | actual_seq_len, # kv cache 在 seq_len 维度的最大长度 87 | PARTITION_SIZE, 88 | ): 89 | HEAD_DIM = mid_o.shape[-1] 90 | 91 | batchs, num_heads = mid_o.shape[0], mid_o.shape[1] 92 | grid = (batchs, num_heads) 93 | 94 | _flash_decoding_stage2_kernel[grid]( 95 | mid_o, # [batch, head, seq_block_num, head_dim] 96 | mid_o_logexpsum, # [batch, head, seq_block_num] 97 | atten_output, # attention 输出首地址 98 | *mid_o.stride(), 99 | *mid_o_logexpsum.stride(), 100 | *atten_output.stride(), 101 | actual_seq_len, # TODO 支持 PagedAttention 和连续批处理 102 | BLOCK_DMODEL=HEAD_DIM, 103 | BLOCK_SEQ=PARTITION_SIZE, # type: ignore 104 | num_warps=4, 105 | num_stages=2, 106 | ) 107 | 108 | 109 | import torch 110 | 111 | 112 | # 定义 PyTorch 对照实现 113 | def pytorch_flash_decode_stage2(mid_o, mid_o_logexpsum, actual_seq_len, partition_size): 114 | batchs, num_heads, seq_block_num, head_dim = mid_o.shape 115 | atten_output_pt = torch.zeros( 116 | batchs, num_heads, head_dim, device="cuda", dtype=torch.float32 117 | ) 118 | 119 | for batch in range(batchs): 120 | for head in range(num_heads): 121 | d_i = torch.zeros(head_dim, device="cuda", dtype=torch.float32) 122 | m_i = torch.full( 123 | (head_dim,), -float("inf"), device="cuda", dtype=torch.float32 124 | ) # 初始化为 [head_dim] 125 | acc = torch.zeros(head_dim, device="cuda", dtype=torch.float32) 126 | for partition in range(seq_block_num): 127 | part_v = mid_o[batch, head, partition] # [head_dim] 128 | part_max = mid_o_logexpsum[batch, head, partition].item() # scalar 129 | 130 | # Broadcast part_max to [head_dim] for comparison 131 | part_max_tensor = torch.full( 132 | (head_dim,), part_max, device="cuda", dtype=torch.float32 133 | ) 134 | m_ij = torch.maximum(part_max_tensor, m_i) # [head_dim] 135 | p = torch.exp(part_v - m_ij) # [head_dim] 136 | 137 | alpha = torch.exp(m_i - m_ij) # [head_dim] 138 | 139 | d_i = d_i * alpha + p # [head_dim] 140 | acc = acc * alpha + p * part_v # [head_dim] 141 | 142 | m_i = m_ij # [head_dim] 143 | 144 | # Avoid division by zero by setting zero where d_i is zero 145 | mask = d_i > 0 146 | atten_output_pt[batch, head][mask] = acc[mask] / d_i[mask] 147 | atten_output_pt[batch, head][~mask] = 0.0 # Handle division by zero 148 | 149 | return atten_output_pt 150 | 151 | 152 | # 设置随机种子以确保可重复性 153 | torch.manual_seed(42) 154 | 155 | # 假设头维度为 64,批次为 2,头数为 4,分区数量为 4,实际序列长度为 128,分区大小为 32 156 | batchs, num_heads, 
seq_block_num, head_dim = ( 157 | 2, 158 | 4, 159 | 4, 160 | 64, 161 | ) # head_dim 必须等于 BLOCK_DMODEL_CONST 162 | actual_seq_len = 128 163 | partition_size = 32 164 | 165 | # 随机初始化 Mid_O 和 Mid_O_LogExpSum 166 | mid_o = torch.randn( 167 | batchs, num_heads, seq_block_num, head_dim, device="cuda", dtype=torch.float32 168 | ) 169 | mid_o_logexpsum = torch.randn( 170 | batchs, num_heads, seq_block_num, device="cuda", dtype=torch.float32 171 | ) 172 | 173 | # 初始化 atten_output 174 | atten_output = torch.zeros( 175 | batchs, num_heads, head_dim, device="cuda", dtype=torch.float32 176 | ) 177 | 178 | # 调用修复后的 Triton 函数 179 | flash_decode_stage2( 180 | mid_o, 181 | mid_o_logexpsum, 182 | atten_output, 183 | actual_seq_len=actual_seq_len, 184 | PARTITION_SIZE=partition_size, 185 | ) 186 | 187 | # 调用 PyTorch 实现 188 | pt_atten_output = pytorch_flash_decode_stage2( 189 | mid_o, mid_o_logexpsum, actual_seq_len, partition_size 190 | ) 191 | 192 | # 比较 Triton 和 PyTorch 的输出 193 | diff_atten_output = torch.abs(atten_output - pt_atten_output).max() 194 | print(f"Difference in Atten_Output: {diff_atten_output.item()}") 195 | 196 | # 断言差异在合理范围内 197 | assert diff_atten_output < 1e-3, "Atten_Output 的差异超出容忍范围" 198 | print("Triton 内核与 PyTorch 实现的数值对比通过。") 199 | -------------------------------------------------------------------------------- /tests/kernels/test_mask.py: -------------------------------------------------------------------------------- 1 | # 代码可直接运行,用于测试 masked-scores 的结果 2 | 3 | import torch, time 4 | 5 | 6 | def create_and_print_mask(): 7 | """用于测试 mask 内容和形状""" 8 | seq_len = 4 9 | start_pos = 0 10 | mask = torch.full((seq_len, seq_len), float("-inf")) 11 | print(mask) 12 | mask1 = torch.triu(mask, diagonal=1) # 创建上三角矩阵 13 | print(mask1) 14 | mask2 = torch.hstack([torch.zeros((seq_len, start_pos)), mask1]) 15 | print(mask2) 16 | print("mask shape is ", mask.shape) 17 | scores = torch.randn((seq_len, seq_len)) 18 | offs_m = torch.tensor([0, 1, 2, 3]) 19 | offs_k = torch.tensor([0, 1, 2, 3]) 20 | mask3 = offs_m[:, None] >= offs_k[None, :] 21 | print(mask3) 22 | mask4 = scores.masked_fill(mask3 == 0, float("-inf")) 23 | print(mask4) 24 | 25 | 26 | """ 27 | tensor([[-inf, -inf, -inf, -inf], 28 | [-inf, -inf, -inf, -inf], 29 | [-inf, -inf, -inf, -inf], 30 | [-inf, -inf, -inf, -inf]]) 31 | tensor([[0., -inf, -inf, -inf], 32 | [0., 0., -inf, -inf], 33 | [0., 0., 0., -inf], 34 | [0., 0., 0., 0.]]) 35 | tensor([[0., -inf, -inf, -inf], 36 | [0., 0., -inf, -inf], 37 | [0., 0., 0., -inf], 38 | [0., 0., 0., 0.]]) 39 | mask shape is torch.Size([4, 4]) 40 | tensor([[ True, False, False, False], 41 | [ True, True, False, False], 42 | [ True, True, True, False], 43 | [ True, True, True, True]]) 44 | tensor([[ 2.2425, -inf, -inf, -inf], 45 | [-0.4196, 1.4955, -inf, -inf], 46 | [ 1.1759, 1.9087, 0.2180, -inf], 47 | [-0.5477, 0.1412, 0.7192, 0.8276]]) 48 | """ 49 | 50 | 51 | def apply_prefill_mask1(scores, seq_len): 52 | """llama3 实现的创建并应用 mask 矩阵方法""" 53 | mask = torch.full((seq_len, seq_len), float("-inf")) 54 | mask = torch.triu(mask, diagonal=1) 55 | 56 | masked_scores = scores + mask 57 | 58 | return masked_scores 59 | 60 | 61 | def apply_prefill_mask2(scores, seq_len): 62 | """使用下三角矩阵方法创建并应用 mask""" 63 | mask = torch.tril(torch.ones([seq_len, seq_len])) 64 | masked_scores = scores.masked_fill(mask == 0, float("-inf")) 65 | return masked_scores 66 | 67 | 68 | def apply_prefill_mask3(scores, seq_len): 69 | """flashattention 内核中创建并应用的 mask""" 70 | offs_q = torch.arange( 71 | seq_len, 72 | ) 73 | offs_k = 
torch.arange( 74 | seq_len, 75 | ) 76 | mask = offs_q[:, None] >= offs_k[None, :] 77 | masked_scores = scores.masked_fill(mask == 0, float("-inf")) 78 | # masked_scores = torch.where(mask, scores, torch.full_like(scores, -1.0e8)) 79 | return masked_scores 80 | 81 | 82 | if __name__ == "__main__": 83 | # torch.manual_seed(42) 84 | seq_len = 512 85 | scores = torch.randn([seq_len, seq_len]) 86 | 87 | # 测量 apply_prefill_mask1 的运行时间 88 | start_time = time.time() 89 | masked_scores1 = apply_prefill_mask1(scores, seq_len) 90 | time1 = time.time() - start_time 91 | print(f"apply_prefill_mask1 运行时间: {time1:.6f} 秒") 92 | 93 | # 测量 apply_prefill_mask2 的运行时间 94 | start_time = time.time() 95 | masked_scores2 = apply_prefill_mask2(scores, seq_len) 96 | time2 = time.time() - start_time 97 | print(f"apply_prefill_mask2 运行时间: {time2:.6f} 秒") 98 | 99 | # 测量 apply_prefill_mask3 的运行时间 100 | start_time = time.time() 101 | masked_scores3 = apply_prefill_mask3(scores, seq_len) 102 | time3 = time.time() - start_time 103 | print(f"apply_prefill_mask3 运行时间: {time3:.6f} 秒") 104 | 105 | # 确保三个函数的结果一致 106 | assert torch.allclose(masked_scores1, masked_scores2, atol=1e-4) 107 | assert torch.allclose(masked_scores1, masked_scores3, atol=1e-4) 108 | -------------------------------------------------------------------------------- /tests/models/test_LlamaConfig.py: -------------------------------------------------------------------------------- 1 | import json, os, sys 2 | 3 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))) 4 | from lite_llama.models.model_config import LlamaConfig 5 | 6 | 7 | def load_config_from_json(json_file_path: str) -> LlamaConfig: 8 | with open(json_file_path, "r", encoding="utf-8") as f: 9 | config_dict = json.load(f) 10 | config = LlamaConfig(config_dict, max_seq_len=2048) 11 | return config 12 | 13 | 14 | if __name__ == "__main__": 15 | # 创建 LlamaConfig 实例,设置 max_batch_size=16 16 | config = LlamaConfig(max_batch_size=16) 17 | print("max_batch_size:", config.max_batch_size) 18 | 19 | # JSON 文件的路径 20 | json_file_path = "/gemini/code/Llama-3.2-1B-Instruct/config.json" 21 | 22 | # 加载配置 23 | config = load_config_from_json(json_file_path) 24 | 25 | # 访问配置参数 26 | print("模型类型:", config.model_type) 27 | print("隐藏层数 (n_layers):", config.n_layers) 28 | print("隐藏大小 (dim):", config.dim) 29 | print("词汇表大小:", config.vocab_size) 30 | print("旋转位置编码配置:", config.rope_scaling) 31 | print("最大支持序列长度:", config.max_seq_len) 32 | print("模型层数", config.n_layers) 33 | if config.rope_scaling is not None: 34 | print("rope 类型:", config.rope_scaling.get("rope_type")) 35 | else: 36 | print("rope_scaling is None") 37 | -------------------------------------------------------------------------------- /tests/models/test_LlamaForCausalLM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaTokenizerFast 3 | from transformers import pipeline 4 | 5 | 6 | def load_llama_model(model_path: str, device: str = "cuda"): 7 | """ 8 | Load the LLaMA model and tokenizer. 9 | 10 | Args: 11 | model_path (str): Path to the directory containing the model and tokenizer files. 12 | device (str): Device to load the model on ('cuda' or 'cpu'). 13 | 14 | Returns: 15 | model: The loaded LLaMA model. 16 | tokenizer: The loaded tokenizer. 17 | 
18 | """ 19 | tokenizer = LlamaTokenizerFast.from_pretrained(model_path, legacy=False) 20 | model = LlamaForCausalLM.from_pretrained( 21 | model_path, 22 | torch_dtype=torch.float16, # Use float16 for faster inference if supported 23 | low_cpu_mem_usage=True, 24 | ) 25 | model.to(device) 26 | return model, tokenizer 27 | 28 | 29 | def generate_text( 30 | model, tokenizer, prompt: str, max_length: int = 50, device: str = "cuda" 31 | ): 32 | """ 33 | Generate text using the LLaMA model. 34 | 35 | Args: 36 | model: The loaded LLaMA model. 37 | tokenizer: The loaded tokenizer. 38 | prompt (str): The input text prompt. 39 | max_length (int): The maximum length of the generated text. 40 | device (str): Device to run the model on ('cuda' or 'cpu'). 41 | 42 | Returns: 43 | str: The generated text. 44 | """ 45 | inputs = tokenizer(prompt, return_tensors="pt").to(device) 46 | with torch.no_grad(): 47 | outputs = model.generate( 48 | **inputs, 49 | max_length=max_length, 50 | do_sample=True, # Enable sampling to introduce randomness 51 | temperature=0.7, # Adjust temperature for creativity 52 | top_p=0.9, # Use top-p (nucleus) sampling 53 | repetition_penalty=1.2, # Penalize repetitions 54 | ) 55 | generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) 56 | return generated_text 57 | 58 | 59 | def pipline_text(model_id): 60 | pipe = pipeline( 61 | "text-generation", 62 | model=model_id, 63 | torch_dtype=torch.bfloat16, 64 | device_map="auto", 65 | ) 66 | messages = [ 67 | { 68 | "role": "system", 69 | "content": "You are a pirate chatbot who always responds in pirate speak!", 70 | }, 71 | {"role": "user", "content": "Who are you?"}, 72 | ] 73 | outputs = pipe( 74 | messages, 75 | max_new_tokens=256, 76 | ) 77 | print(outputs[0]["generated_text"][-1]) 78 | 79 | 80 | if __name__ == "__main__": 81 | # Specify the paths to your model and tokenizer directories 82 | model_path = "/gemini/code/Llama-3.2-1B-Instruct/" 83 | 84 | # Load the model and tokenizer 85 | device = "cuda" if torch.cuda.is_available() else "cpu" 86 | model, tokenizer = load_llama_model(model_path, device) 87 | 88 | # Test the model with a sample prompt 89 | prompt = "I believe the meaning of life is," 90 | generated_text = generate_text( 91 | model, tokenizer, prompt, max_length=100, device=device 92 | ) 93 | 94 | print("Prompt:") 95 | print(prompt) 96 | print("\nGenerated Text:") 97 | print(generated_text) 98 | 99 | pipline_text(model_path) 100 | -------------------------------------------------------------------------------- /tests/models/test_LlavaConfig.py: -------------------------------------------------------------------------------- 1 | import json, os, sys 2 | 3 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))) 4 | from lite_llama.models.model_config import LlavaConfig 5 | from lite_llama.models.llava import LlavaLlama 6 | 7 | 8 | def test_llava_config(): 9 | # 示例配置 JSON 字符串 10 | config_json = """ 11 | { 12 | "architectures": [ 13 | "LlavaForConditionalGeneration" 14 | ], 15 | "ignore_index": -100, 16 | "image_token_index": 32000, 17 | "model_type": "llava", 18 | "pad_token_id": 32001, 19 | "projector_hidden_act": "gelu", 20 | "text_config": { 21 | "_name_or_path": "lmsys/vicuna-7b-v1.5", 22 | "architectures": [ 23 | "LlamaForCausalLM" 24 | ], 25 | "max_position_embeddings": 4096, 26 | "model_type": "llama", 27 | "rms_norm_eps": 1e-05, 28 | "torch_dtype": "float16", 29 | "vocab_size": 32064 30 | }, 31 | "tie_word_embeddings": false, 32 | "torch_dtype": "float16", 33 | 
"transformers_version": "4.36.0.dev0", 34 | "vision_config": { 35 | "hidden_size": 1024, 36 | "image_size": 336, 37 | "intermediate_size": 4096, 38 | "model_type": "clip_vision_model", 39 | "num_attention_heads": 16, 40 | "num_hidden_layers": 24, 41 | "patch_size": 14, 42 | "projection_dim": 768, 43 | "vocab_size": 32000 44 | }, 45 | "vision_feature_layer": -2, 46 | "vision_feature_select_strategy": "default", 47 | "vocab_size": 32064 48 | } 49 | """ 50 | 51 | # 将 JSON 字符串解析为字典 52 | config_dict = json.loads(config_json) 53 | 54 | # 从字典创建 LlavaConfig 实例 55 | llava_config = LlavaConfig.from_dict(config_dict) 56 | 57 | # 打印配置以验证 58 | print(llava_config) 59 | 60 | 61 | def test_LlavaLlama_structure(): 62 | model_path = "/gemini/code/llm_weights/llava-hf/llava-1.5-7b-hf" 63 | from accelerate import init_empty_weights, load_checkpoint_and_dispatch 64 | from transformers import LlavaConfig 65 | 66 | # 使用 init_empty_weights 初始化空模型 67 | with init_empty_weights(): 68 | llava_config = LlavaConfig.from_pretrained(model_path) 69 | # print(llava_config) # 打印配置以验证 70 | 71 | model = LlavaLlama(llava_config) 72 | print(model) # 打印模型结构 73 | for name, param in list(model.named_parameters())[:]: # 打印模型参数 74 | print(name, param.shape) 75 | 76 | 77 | if __name__ == "__main__": 78 | test_llava_config() 79 | test_LlavaLlama_structure() 80 | -------------------------------------------------------------------------------- /tests/models/test_LlavaForConditionalGeneration.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import requests, torch 3 | from transformers import ( 4 | AutoProcessor, 5 | LlavaForConditionalGeneration, 6 | LlavaConfig, 7 | LlavaNextConfig, 8 | LlavaNextForConditionalGeneration, 9 | ) 10 | from accelerate import init_empty_weights, load_checkpoint_and_dispatch 11 | 12 | model_path = "/gemini/code/llm_weights/llava-hf/llava-1.5-7b-hf" 13 | 14 | # 使用 init_empty_weights 初始化空模型 15 | with init_empty_weights(): 16 | config = LlavaConfig.from_pretrained(model_path) 17 | model = LlavaForConditionalGeneration(config) 18 | 19 | # 使用 load_checkpoint_and_dispatch 分配权重 20 | model = load_checkpoint_and_dispatch(model, model_path, device_map="auto") 21 | 22 | # model = LlavaForConditionalGeneration.from_pretrained( 23 | # model_path, 24 | # torch_dtype=torch.float16, 25 | # low_cpu_mem_usage=True, 26 | # ).to("cuda") 27 | 28 | processor = AutoProcessor.from_pretrained(model_path) 29 | prompt = "USER: \nWhat's the content of the image? ASSISTANT:" 30 | url = "https://www.ilankelman.org/stopsigns/australia.jpg" 31 | image = Image.open(requests.get(url, stream=True).raw) 32 | 33 | inputs = processor(images=image, text=prompt, return_tensors="pt") 34 | 35 | # 显式移动每个张量到 CUDA 36 | inputs = {k: v.to("cuda") for k, v in inputs.items()} 37 | 38 | # Generate 39 | generate_ids = model.generate(**inputs, max_new_tokens=30) 40 | print( 41 | processor.batch_decode( 42 | generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False 43 | )[0] 44 | ) 45 | 46 | """ 47 | USER: 48 | What's the content of the image? 49 | ASSISTANT: The image shows a street scene with a red stop sign on the left side. 
In the background, there is a traditional Chinese-style archway with red 50 | """ 51 | 52 | print("模型结构", model) 53 | # 打印模型的简单摘要 54 | print(f"模型总参数量: {sum(p.numel() for p in model.parameters()) / 1e6:.2f} M") 55 | 56 | # 打印模型参数信息 57 | for name, param in list(model.named_parameters()): 58 | print(name, param.shape) 59 | 60 | """ 61 | LlavaNextForConditionalGeneration( 62 | (vision_tower): CLIPVisionModel( 63 | (vision_model): CLIPVisionTransformer( 64 | (embeddings): CLIPVisionEmbeddings( 65 | (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False) 66 | (position_embedding): Embedding(577, 1024) 67 | ) 68 | (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) 69 | (encoder): CLIPEncoder( 70 | (layers): ModuleList( 71 | (0-23): 24 x CLIPEncoderLayer( 72 | (self_attn): CLIPAttention( 73 | (k_proj): Linear(in_features=1024, out_features=1024, bias=True) 74 | (v_proj): Linear(in_features=1024, out_features=1024, bias=True) 75 | (q_proj): Linear(in_features=1024, out_features=1024, bias=True) 76 | (out_proj): Linear(in_features=1024, out_features=1024, bias=True) 77 | ) 78 | (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) 79 | (mlp): CLIPMLP( 80 | (activation_fn): QuickGELUActivation() 81 | (fc1): Linear(in_features=1024, out_features=4096, bias=True) 82 | (fc2): Linear(in_features=4096, out_features=1024, bias=True) 83 | ) 84 | (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) 85 | ) 86 | ) 87 | ) 88 | (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) 89 | ) 90 | ) 91 | (multi_modal_projector): LlavaNextMultiModalProjector( 92 | (linear_1): Linear(in_features=1024, out_features=4096, bias=True) 93 | (act): GELUActivation() 94 | (linear_2): Linear(in_features=4096, out_features=4096, bias=True) 95 | ) 96 | (language_model): LlamaForCausalLM( 97 | (model): LlamaModel( 98 | (embed_tokens): Embedding(128320, 4096) 99 | (layers): ModuleList( 100 | (0-31): 32 x LlamaDecoderLayer( 101 | (self_attn): LlamaSdpaAttention( 102 | (q_proj): Linear(in_features=4096, out_features=4096, bias=False) 103 | (k_proj): Linear(in_features=4096, out_features=1024, bias=False) 104 | (v_proj): Linear(in_features=4096, out_features=1024, bias=False) 105 | (o_proj): Linear(in_features=4096, out_features=4096, bias=False) 106 | (rotary_emb): LlamaRotaryEmbedding() 107 | ) 108 | (mlp): LlamaMLP( 109 | (gate_proj): Linear(in_features=4096, out_features=14336, bias=False) 110 | (up_proj): Linear(in_features=4096, out_features=14336, bias=False) 111 | (down_proj): Linear(in_features=14336, out_features=4096, bias=False) 112 | (act_fn): SiLU() 113 | ) 114 | (input_layernorm): LlamaRMSNorm() 115 | (post_attention_layernorm): LlamaRMSNorm() 116 | ) 117 | ) 118 | (norm): LlamaRMSNorm() 119 | ) 120 | (lm_head): Linear(in_features=4096, out_features=128320, bias=False) 121 | ) 122 | ) 123 | """ 124 | -------------------------------------------------------------------------------- /tests/models/test_LlavaLlama.py: -------------------------------------------------------------------------------- 1 | from accelerate import init_empty_weights, load_checkpoint_and_dispatch 2 | from transformers import LlavaConfig 3 | import sys, os 4 | 5 | # 获取 lite_llama 目录的绝对路径并添加到 sys.path 中 6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) 7 | from lite_llama.models.llava import LlavaLlama 8 | 9 | hf_model_path = "/gemini/code/liuhaotian/llava-v1.5-7b" 10 | 11 | 12 | def 
test_LlavaLlama_structure(hf_model_path): 13 | # 使用 init_empty_weights 初始化空模型 14 | with init_empty_weights(): 15 | config = LlavaConfig.from_pretrained(hf_model_path) 16 | model = LlavaLlama(config) 17 | 18 | # 打印没有加载权重的 LlavaLlama 模型结构 19 | print(model) 20 | # 打印模型的简单摘要 21 | print(f"模型总参数量: {sum(p.numel() for p in model.parameters()) / 1e6:.2f} M") 22 | 23 | # 可选择打印部分参数信息 24 | for name, param in list(model.named_parameters())[:]: # 打印模型参数 25 | print(name, param.shape) 26 | 27 | 28 | if __name__ == "__main__": 29 | test_LlavaLlama_structure(hf_model_path) 30 | -------------------------------------------------------------------------------- /tests/models/test_Qwen2ForCausalLM.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, Qwen2ForCausalLM 2 | 3 | model_name = "/gemini/code/llm_weights/Qwen/Qwen2.5-3B-Instruct" 4 | 5 | model = Qwen2ForCausalLM.from_pretrained( 6 | model_name, torch_dtype="auto", device_map="auto" 7 | ) 8 | tokenizer = AutoTokenizer.from_pretrained(model_name) 9 | print(model) 10 | print("my llama archetectue and shape") 11 | 12 | for name, param in model.named_parameters(): 13 | print(name, param.shape) 14 | 15 | prompt = "给出 c++ 多线程语法和编程示例代码." 16 | 17 | messages = [ 18 | { 19 | "role": "system", 20 | "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.", 21 | }, 22 | {"role": "user", "content": prompt}, 23 | ] 24 | 25 | text = tokenizer.apply_chat_template( 26 | messages, tokenize=False, add_generation_prompt=True 27 | ) 28 | print("After call apply_chat_template, text is ", text) 29 | 30 | model_inputs = tokenizer([text], return_tensors="pt").to(model.device) 31 | 32 | generated_ids = model.generate(**model_inputs, max_new_tokens=512) 33 | generated_ids = [ 34 | output_ids[len(input_ids) :] 35 | for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) 36 | ] 37 | 38 | response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] 39 | print(response) 40 | 41 | """ 42 | Qwen2ForCausalLM( 43 | (model): Qwen2Model( 44 | (embed_tokens): Embedding(151936, 1536) 45 | (layers): ModuleList( 46 | (0-27): 28 x Qwen2DecoderLayer( 47 | (self_attn): Qwen2SdpaAttention( 48 | (q_proj): Linear(in_features=1536, out_features=1536, bias=True) 49 | (k_proj): Linear(in_features=1536, out_features=256, bias=True) 50 | (v_proj): Linear(in_features=1536, out_features=256, bias=True) 51 | (o_proj): Linear(in_features=1536, out_features=1536, bias=False) 52 | (rotary_emb): Qwen2RotaryEmbedding() 53 | ) 54 | (mlp): Qwen2MLP( 55 | (gate_proj): Linear(in_features=1536, out_features=8960, bias=False) 56 | (up_proj): Linear(in_features=1536, out_features=8960, bias=False) 57 | (down_proj): Linear(in_features=8960, out_features=1536, bias=False) 58 | (act_fn): SiLU() 59 | ) 60 | (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06) 61 | (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06) 62 | ) 63 | ) 64 | (norm): Qwen2RMSNorm((1536,), eps=1e-06) 65 | (rotary_emb): Qwen2RotaryEmbedding() 66 | ) 67 | (lm_head): Linear(in_features=1536, out_features=151936, bias=False) 68 | ) 69 | """ 70 | -------------------------------------------------------------------------------- /tests/models/test_get_model_name.py: -------------------------------------------------------------------------------- 1 | from transformers import LlavaConfig, AutoTokenizer 2 | 3 | 4 | def get_model_name_from_path(model_path): 5 | model_path = model_path.strip("/") 6 | model_paths = 
model_path.split("/") 7 | if model_paths[-1].startswith("checkpoint-"): 8 | return model_paths[-2] + "_" + model_paths[-1] 9 | else: 10 | return model_paths[-1] 11 | 12 | 13 | if __name__ == "__main__": 14 | model_path = "/gemini/code/lite_llama/my_weight/llava-1.5-7b-hf" 15 | print(get_model_name_from_path(model_path)) 16 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 17 | print(tokenizer) 18 | -------------------------------------------------------------------------------- /tests/models/test_gpt2.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, AutoModelForCausalLM 2 | import torch 3 | 4 | 5 | def generate_text( 6 | model, tokenizer, prompt, max_length=50, temperature=1.0, top_p=0.9, device="cuda" 7 | ): 8 | """ 9 | 使用 model.forward 实现逐步生成文本,并正确设置 attention_mask。 10 | 11 | Args: 12 | model: 已加载的因果语言模型。 13 | tokenizer: 对应的 tokenizer。 14 | prompt: 初始输入文本。 15 | max_length: 生成的最大 token 数量。 16 | temperature: 采样时的温度参数。 17 | top_p: 采样时的 top-p 参数。 18 | device: 设备类型。 19 | 20 | Returns: 21 | 生成的文本字符串。 22 | """ 23 | # 编码输入 prompt 24 | input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device) # [1, S] 25 | 26 | # 初始化生成的 Token 列表 27 | generated_ids = input_ids 28 | 29 | # 初始化 past_key_values 为 None 30 | past_key_values = None 31 | 32 | for _ in range(max_length): 33 | # 调用模型的 forward 方法 34 | outputs = model( 35 | input_ids=input_ids, past_key_values=past_key_values, use_cache=True 36 | ) 37 | 38 | # 获取 logits,并仅关注最后一个 token 的 logits 39 | logits = outputs.logits # [1, 1, V] 40 | next_token_logits = logits[:, -1, :] / temperature # [1, V] 41 | 42 | # 应用 top-p 过滤 43 | sorted_logits, sorted_indices = torch.sort( 44 | next_token_logits, dim=-1, descending=True 45 | ) 46 | cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) 47 | 48 | # 创建 mask 49 | sorted_indices_to_remove = cumulative_probs > top_p 50 | # Shift the mask to include the first token exceeding p 51 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 52 | sorted_indices_to_remove[..., 0] = False 53 | 54 | # 应用 mask 55 | sorted_logits[sorted_indices_to_remove] = -float("Inf") 56 | # 应用 softmax 57 | probs = torch.softmax(sorted_logits, dim=-1) 58 | 59 | # 采样下一个 token 60 | next_token = torch.multinomial(probs, num_samples=1) # [1, 1] 61 | 62 | # 反向排序索引以获取原始 token ID 63 | next_token = sorted_indices.gather(-1, next_token) 64 | 65 | # 将生成的 token 添加到生成的 Token 列表中 66 | generated_ids = torch.cat([generated_ids, next_token], dim=-1) 67 | 68 | # 更新 input_ids 为新生成的 token 69 | input_ids = next_token 70 | 71 | # 更新 past_key_values 72 | past_key_values = outputs.past_key_values 73 | 74 | # 解码生成的 token 75 | generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) 76 | return generated_text 77 | 78 | 79 | if __name__ == "__main__": 80 | # 使用标准的 GPT-2 模型名称,确保模型和 tokenizer 匹配 81 | model_name = "/gemini/code/llm_weights/gpt2" # 修改为您的模型路径或名称 82 | tokenizer = AutoTokenizer.from_pretrained(model_name) 83 | model = AutoModelForCausalLM.from_pretrained(model_name) 84 | 85 | # 将模型移动到 GPU(如果可用)并设置为评估模式 86 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 87 | model.to(device) 88 | model.eval() 89 | 90 | # 定义 prompt 91 | prompt = "Once upon a time in a distant land," 92 | 93 | # 生成文本 94 | generated = generate_text( 95 | model, 96 | tokenizer, 97 | prompt, 98 | max_length=500, 99 | temperature=1.0, 100 | top_p=0.9, 101 | device=device, 102 | ) 103 | 
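# NOTE: generate_text decodes incrementally -- after the first forward pass over the full prompt, each step feeds only the newly sampled token together with past_key_values (use_cache=True), so the prompt is never re-encoded.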
print(generated) 104 | -------------------------------------------------------------------------------- /tests/models/test_transformers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Union, TextIO, Optional 3 | import torch 4 | from transformers import AutoModel, AutoConfig, PreTrainedModel 5 | from accelerate import init_empty_weights 6 | 7 | MODEL_ID = "/home/honggao/llm_weights/Qwen3-235B-A22B" 8 | 9 | 10 | def print_empty_model(model_id): 11 | """ 12 | Accelerate 提供 init_empty_weights 上下文管理器,令所有 Parameter 和 Buffer 13 | 都放在 meta device,尺寸为 0,因此既 不下载权重 也 不占内存。 14 | """ 15 | cfg = AutoConfig.from_pretrained(model_id) # 只拉配置 16 | 17 | with init_empty_weights(): 18 | model = AutoModel.from_config(cfg) 19 | print(model) 20 | return model 21 | 22 | def print_transformers_model_summary( 23 | model: PreTrainedModel, 24 | *, 25 | use_torchinfo: bool = False, 26 | input_size: Optional[tuple] = None, 27 | file: Union[str, TextIO, None] = None, 28 | ) -> None: 29 | """ 30 | 打印 Hugging Face Transformers 模型结构 + 权重 shape。 31 | 32 | Args: 33 | model (PreTrainedModel): 已加载好的模型实例。 34 | use_torchinfo (bool): 是否调用 torchinfo.summary() 生成额外摘要。 35 | input_size (tuple): 当 use_torchinfo=True 时需提供 (seq_len, ) or (bs, seq_len, ...)。 36 | file: None -> 输出到 stdout; 37 | str -> 输出到指定路径文件; 38 | TextIO -> 已打开的文件句柄。 39 | """ 40 | import math 41 | 42 | def _human_readable(num: float, *, base: int = 1000, units=("", "K", "M", "G", "T", "P"), suffix=""): 43 | """Convert a large number to human‑readable form (e.g. 12.3G).""" 44 | if num == 0: 45 | return f"0{suffix}" 46 | exp = min(int(math.log(num, base)), len(units) - 1) 47 | value = num / (base ** exp) 48 | return f"{value:.2f}{units[exp]}{suffix}" 49 | 50 | def _dump(msg: str = ""): 51 | if fh: 52 | fh.write(msg + "\n") 53 | else: 54 | print(msg) 55 | 56 | # 0) 处理输出目标 57 | fh = open(file, "w") if isinstance(file, str) else file 58 | 59 | # 1) 模型 __repr__ 60 | _dump("=" * 60) 61 | _dump("Model architecture (__repr__):") 62 | _dump("=" * 60) 63 | _dump(str(model)) 64 | 65 | # 2) 权重 shape 66 | _dump("\n" + "=" * 60) 67 | _dump("Parameter shapes (name -> shape, #elements):") 68 | _dump("=" * 60) 69 | 70 | # Token count estimation for FLOPs (default = 1 token if unknown) 71 | tokens = 1 72 | if input_size is not None: 73 | # Accept (seq_len,), (bs, seq_len) or any shape where last dim is seq_len 74 | if len(input_size) == 1: 75 | tokens = input_size[0] 76 | else: 77 | tokens = input_size[0] * input_size[-1] 78 | 79 | total_params = 0 80 | total_flops = 0 81 | total_mem_bytes = 0 82 | for name, param in model.named_parameters(): 83 | numel = param.numel() 84 | total_params += numel 85 | 86 | # ---- Estimate per‑parameter FLOPs ---- 87 | if param.dim() == 2: # typical (out, in) weight matrix 88 | flops = 2 * param.shape[0] * param.shape[1] * tokens 89 | elif param.dim() == 1: # bias / norm weight 90 | flops = param.shape[0] * tokens 91 | else: 92 | flops = numel # fallback crude estimate 93 | total_flops += flops 94 | 95 | # ---- Memory access cost (parameter bytes only) ---- 96 | mem_bytes = numel * param.element_size() 97 | total_mem_bytes += mem_bytes 98 | 99 | # ---- Pretty print ---- 100 | flops_str = _human_readable(flops, suffix="F") 101 | mem_str = _human_readable(mem_bytes, base=1024, units=("B","KB","MB","GB","TB","PB")) 102 | _dump(f"{name:<60} {str(tuple(param.shape)):<20} {numel:,} | {flops_str:<8} | {mem_str}") 103 | 104 | _dump(f"\nTotal parameters: {total_params:,}") 105 | 
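# NOTE: the estimate dumped below counts only parameter GEMMs (roughly 2 * out_features * in_features FLOPs per token for each 2-D weight); the attention score/context matmuls and activation functions are not included, so treat it as a rough lower bound on the true forward cost.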
_dump(f"Estimated forward FLOPs: {_human_readable(total_flops, suffix='F')}") 106 | _dump(f"Parameter memory: {_human_readable(total_mem_bytes, base=1024, units=('B','KB','MB','GB','TB','PB'))}") 107 | 108 | # 3) 可选 torchinfo 摘要 109 | if use_torchinfo: 110 | try: 111 | from torchinfo import summary # pip install torchinfo 112 | assert input_size is not None, "`input_size` must be provided when use_torchinfo=True" 113 | info = summary( 114 | model, 115 | input_size=input_size, 116 | depth=3, 117 | col_names=("kernel_size", "output_size", "num_params", "mult_adds"), 118 | dtypes=[torch.long], # 对 NLP 模型输入通常是 int64 token id 119 | ) 120 | _dump("\n" + "=" * 60) 121 | _dump("torchinfo summary():") 122 | _dump("=" * 60) 123 | _dump(str(info)) 124 | except ImportError: 125 | _dump("torchinfo 未安装,跳过摘要。pip install torchinfo 获取更丰富视图。") 126 | 127 | if isinstance(file, str): # 自动关闭文件 128 | fh.close() 129 | 130 | from torchviz import make_dot # pip install torchviz graphviz 131 | def save_model_graph(model, input_example: torch.Tensor, file_name: str = "model_graph.svg") -> None: 132 | """ 133 | 利用 torchviz 生成前向图;input_example 必须能直接送入 model。 134 | """ 135 | model.eval() 136 | y = model(input_example) 137 | dot = make_dot(y, params=dict(model.named_parameters())) 138 | dot.format = file_name.split(".")[-1] # 自动根据后缀决定 svg/png 139 | dot.render(file_name, cleanup=True) 140 | print(f"✅ Graph saved to {file_name}") 141 | 142 | if __name__ == "__main__": 143 | # model = AutoModel.from_pretrained(MODEL_ID) 144 | model = print_empty_model(MODEL_ID) 145 | input_example = torch.randint(0, 1000, (2, 2048)) # 随机输入 146 | print_transformers_model_summary( 147 | model=model, 148 | use_torchinfo=True, 149 | input_size=(2, 2048), 150 | file="qwen3_8b_structure.txt" 151 | ) -------------------------------------------------------------------------------- /tests/others/test_image_process.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | from PIL import Image 3 | 4 | # 获取 lite_llama 目录的绝对路径并添加到 sys.path 中 5 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) 6 | from lite_llama.lite_llama.utils.image_process import vis_images 7 | 8 | 9 | def test_vis_images(image_files): 10 | print("=" * 50) 11 | print("Input Image:") 12 | vis_images(image_files) 13 | 14 | 15 | if __name__ == "__main__": 16 | image_files = [ 17 | "/gemini/code/lite_llama/images/pexels-christian-heitz-285904-842711.jpg", 18 | "/gemini/code/lite_llama/images/pexels-francesco-ungaro-1525041.jpg", 19 | ] 20 | test_vis_images(image_files) 21 | -------------------------------------------------------------------------------- /tests/others/test_load_weight.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer 2 | import torch 3 | from tqdm.auto import tqdm 4 | import json, sys, os 5 | from pathlib import Path 6 | 7 | # 获取 lite_llama 目录的绝对路径并添加到 sys.path 中 8 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) 9 | from lite_llama.models.qwen2 import Qwen2Model, Qwen2Config 10 | 11 | 12 | def load_config_from_json(json_file_path: str, device: str = "cuda") -> Qwen2Config: 13 | with open(json_file_path, "r") as f: 14 | config_dict = json.load(f) 15 | 16 | config = Qwen2Config(config_dict, max_seq_len=2048, device=device) 17 | return config 18 | 19 | 20 | def load_original_llama(model_name_or_path: str, device: str = "cuda"): 21 | tokenizer = 
AutoTokenizer.from_pretrained(model_name_or_path) 22 | model = AutoModelForCausalLM.from_pretrained( 23 | model_name_or_path, 24 | torch_dtype=torch.float16, 25 | device_map="cuda", 26 | ) 27 | model.to(device) 28 | hf_sd = model.state_dict() 29 | 30 | return model, tokenizer, hf_sd 31 | 32 | 33 | def load_custom_llam( 34 | model_name_or_path: str, model_config: Qwen2Config, device: str = "cuda" 35 | ): 36 | checkpoints = sorted(Path(model_name_or_path).glob("*.pth")) 37 | assert len(checkpoints) > 0, f"no checkpoint files found in {model_name_or_path}" 38 | ckpt_path = checkpoints[0] 39 | state_dict = torch.load(ckpt_path, map_location="cuda") 40 | 41 | # 根据设备选择合适的 dtype 42 | torch.set_default_dtype(torch.half) 43 | 44 | model = Qwen2Model(model_config).to(device) 45 | model.load_state_dict(state_dict, strict=True) 46 | new_sd = model.state_dict() 47 | 48 | return model, new_sd 49 | 50 | 51 | def compare_model_weights(hf_sd, new_sd, model_config, rtol=1e-5, atol=1e-8): 52 | """ 53 | 比较两个模型权重字典的各个参数是否相等。 54 | 55 | Args: 56 | hf_sd (dict): Hugging Face 模型的 state_dict。 57 | new_sd (dict): 自定义模型的 state_dict。 58 | rtol (float): 允许的相对误差。 59 | atol (float): 允许的绝对误差。 60 | 61 | Returns: 62 | bool: 如果权重完全匹配,则返回 True, 否则返回 False。 63 | """ 64 | 65 | all_match = True 66 | 67 | # 检查键是否一致 68 | hf_keys = set(hf_sd.keys()) 69 | new_keys = set(new_sd.keys()) 70 | 71 | if hf_keys != new_keys: 72 | print("键不一致!") 73 | print("Hugging Face 多出的键:", hf_keys - new_keys) 74 | print("自定义模型多出的键:", new_keys - hf_keys) 75 | # all_match = False 76 | 77 | # 映射嵌入层 # 映射归一化层 78 | mapping = { 79 | "model.norm.weight": "norm_weight", 80 | "model.embed_tokens.weight": "embed_tokens.weight", 81 | "lm_head.weight": "lm_head_weight", 82 | } 83 | 84 | # 映射层 85 | layers = { 86 | "model.layers.{i}.self_attn.q_proj.weight": "layers.{i}.self_attn.q_proj_weight", 87 | "model.layers.{i}.self_attn.q_proj.bias": "layers.{i}.self_attn.q_proj_bias", 88 | "model.layers.{i}.self_attn.k_proj.weight": "layers.{i}.self_attn.k_proj_weight", 89 | "model.layers.{i}.self_attn.k_proj.bias": "layers.{i}.self_attn.k_proj_bias", 90 | "model.layers.{i}.self_attn.v_proj.weight": "layers.{i}.self_attn.v_proj_weight", 91 | "model.layers.{i}.self_attn.v_proj.bias": "layers.{i}.self_attn.v_proj_bias", 92 | "model.layers.{i}.self_attn.o_proj.weight": "layers.{i}.self_attn.o_proj_weight", 93 | "model.layers.{i}.mlp.gate_proj.weight": "layers.{i}.mlp.gate_proj.weight", 94 | "model.layers.{i}.mlp.up_proj.weight": "layers.{i}.mlp.up_proj.weight", 95 | "model.layers.{i}.mlp.down_proj.weight": "layers.{i}.mlp.down_proj.weight", 96 | "model.layers.{i}.input_layernorm.weight": "layers.{i}.input_layernorm_weight", 97 | "model.layers.{i}.post_attention_layernorm.weight": "layers.{i}.post_attention_layernorm_weight", 98 | } 99 | 100 | # 根据 Transformer 层数量生成映射 101 | for i in range(model_config.num_layers): 102 | for hf_key, custom_key in layers.items(): 103 | mapped_key = hf_key.format(i=i) # hf 权重参数字典 key 104 | custom_mapped_key = custom_key.format(i=i) # 自定义模型权重参数字典 key 105 | mapping[mapped_key] = custom_mapped_key 106 | 107 | # 创建新的状态字典 108 | for hf_key, tensor in tqdm(hf_sd.items(), desc="Mapping weights"): 109 | custom_key = mapping.get(hf_key, None) 110 | hf_param = hf_sd[hf_key] 111 | new_param = new_sd[custom_key] 112 | 113 | if not torch.allclose(hf_param, new_param, rtol=rtol, atol=atol): 114 | print(f"hf 参数 {hf_key} 不匹配!") 115 | print(f"Hugging Face 权重: {hf_param}") 116 | print(f"自定义模型权重: {new_param}") 117 | all_match = False 118 | 119 | if all_match: 120 | 
print("所有权重完全匹配!") 121 | else: 122 | print("权重存在不匹配!") 123 | 124 | 125 | if __name__ == "__main__": 126 | device = "cuda" if torch.cuda.is_available() else "cpu" 127 | # 定义 Qwen2.5-3B 模型权重路径和配置参数 128 | original_model_path = "/gemini/pretrain/Qwen2.5-3B" 129 | my_model_path = "/gemini/code/Qwen2.5-3B-Instruct/" 130 | json_file_path = os.path.join(original_model_path, "config.json") # JSON 文件的路径 131 | model_config = load_config_from_json(json_file_path, device) # 加载配置 132 | 133 | # 加载原始 hf 模型权重 134 | original_model, tokenizer, hf_sd = load_original_llama(original_model_path, device) 135 | # 加载自定义模型权重 136 | custom_model = Qwen2Model(model_config) 137 | custom_model, new_sd = load_custom_llam(my_model_path, model_config, device) 138 | 139 | compare_model_weights(hf_sd, new_sd, model_config) 140 | 141 | for name, param in custom_model.named_parameters(): 142 | print(name, param.shape) 143 | -------------------------------------------------------------------------------- /tests/others/test_standard_mha.py: -------------------------------------------------------------------------------- 1 | # 代码可直接运行,用于测试标准 "MHA 层" 的结果 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class MultiHeadAttention(nn.Module): 9 | def __init__(self, embed_dim, num_heads): 10 | super(MultiHeadAttention, self).__init__() 11 | assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" 12 | 13 | self.embed_dim = embed_dim 14 | self.num_heads = num_heads 15 | self.head_dim = embed_dim // num_heads 16 | 17 | # 定义线性变换 18 | self.query = nn.Linear(embed_dim, embed_dim) 19 | self.key = nn.Linear(embed_dim, embed_dim) 20 | self.value = nn.Linear(embed_dim, embed_dim) 21 | 22 | self.out = nn.Linear(embed_dim, embed_dim) 23 | 24 | def forward(self, x, mask=None): 25 | batch_size, seq_length, embed_dim = x.size() 26 | 27 | # 线性变换并分成多头 28 | Q = ( 29 | self.query(x) 30 | .view(batch_size, seq_length, self.num_heads, self.head_dim) 31 | .transpose(1, 2) 32 | ) # (batch, heads, seq, head_dim) 33 | K = ( 34 | self.key(x) 35 | .view(batch_size, seq_length, self.num_heads, self.head_dim) 36 | .transpose(1, 2) 37 | ) 38 | V = ( 39 | self.value(x) 40 | .view(batch_size, seq_length, self.num_heads, self.head_dim) 41 | .transpose(1, 2) 42 | ) 43 | 44 | # 计算原始注意力分数, # (batch, heads, seq, seq) 45 | scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim**0.5) 46 | 47 | # 对 scores 应用 masked 48 | if mask is not None: 49 | masked_scores = scores.masked_fill(mask == 0, float("-inf")) 50 | 51 | # 归一化,将注意力权重分数转为概率分布 dim 维度值相加等于,对于2D张量即每行元素值相加等于 1 52 | attn_scores = F.softmax(masked_scores, dim=-1) # (batch, heads, seq, seq) 53 | # 加权求和 (batch, heads, seq, head_dim) 54 | context = torch.matmul(attn_scores, V) 55 | 56 | context = ( 57 | context.transpose(1, 2).contiguous().view(batch_size, seq_length, embed_dim) 58 | ) 59 | out = self.out(context) # 最后的线性变换(batch, seq_length, embed_dim) 60 | 61 | print( 62 | f"mask 矩阵:\n {mask.squeeze()} \n" 63 | ) # 使用 torch.squeeze() 函数来移除张量中所有大小为 1 的维度 64 | print(f"原始的注意力分数矩阵:\n {scores.squeeze()} \n") 65 | print(f"应用 mask 后的注意力分数矩阵:\n {masked_scores.squeeze()} \n") 66 | print( 67 | f"使用 softmax 归一化后的掩码注意力分数矩阵:\n {attn_scores.squeeze()} \n" 68 | ) 69 | return out 70 | 71 | 72 | def generate_causal_mask(seq_length): 73 | """生成一个因果遮罩, 上三角为0, 下三角为1""" 74 | mask = ( 75 | torch.tril(torch.ones((seq_length, seq_length))).unsqueeze(0).unsqueeze(0) 76 | ) # (1, 1, seq, seq) 77 | return mask # 1表示可见,0表示遮蔽 78 | 79 | 80 | # 单元测试代码 81 | def 
test_multihead_attention( 82 | vocab_size=1000, batch_size=1, seq_length=4, embed_dim=6, num_heads=2 83 | ): 84 | embedding_layer = nn.Embedding( 85 | vocab_size, embed_dim 86 | ) # 将 input_ids 转为 embedding 向量 87 | mha_layer = MultiHeadAttention(embed_dim, num_heads) # 构建 MHA 模块 88 | 89 | torch.manual_seed(0) 90 | input_ids = torch.randint(vocab_size, [batch_size, seq_length]) # 构建输入数据 91 | mask = generate_causal_mask(seq_length) # 创建注意力 mask, 默认下三角矩阵(张量) 92 | 93 | h = embedding_layer(input_ids) 94 | output = mha_layer(h, mask) # MHA 前向传播 95 | assert output.shape == (batch_size, seq_length, embed_dim), "输出形状不正确" 96 | 97 | # 检查因果遮罩是否有效, 通过设置输入为单位矩阵,观察输出是否遵循因果遮罩 98 | x_identity = ( 99 | torch.eye(seq_length, embed_dim).unsqueeze(0).repeat(batch_size, 1, 1) 100 | ) # (batch, seq, embed) 101 | output_identity = mha_layer(x_identity, mask) 102 | 103 | # 由于输入是单位矩阵,输出应该保持某种结构,可以进行简单的检查 104 | assert not torch.isnan(output_identity).any(), "输出包含NaN值" 105 | 106 | print("多头注意力输出示例:") 107 | print(output) 108 | 109 | 110 | if __name__ == "__main__": 111 | test_multihead_attention() 112 | -------------------------------------------------------------------------------- /tests/test_torch_matmul.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import matplotlib.pyplot as plt 4 | import itertools 5 | from torch.utils.benchmark import Timer 6 | 7 | # 是否使用GPU进行测试(如果没有GPU则设为False) 8 | use_cuda = torch.cuda.is_available() 9 | device = "cuda" if use_cuda else "cpu" 10 | 11 | # 测试参数配置 12 | B_values = [1, 4, 8, 16] # B: 第1维度大小 13 | N_values = [32, 64, 128] # N: 第2维度大小 14 | D_in_values = [64, 128, 256] # D_in 15 | D_out_values = [64, 128, 256] # D_out 16 | 17 | results_matmul = {} 18 | results_linear = {} 19 | 20 | 21 | def benchmark_op(op, args): 22 | t = Timer(stmt="op(*args)", globals={"op": op, "args": args}) 23 | return t.blocked_autorange(min_run_time=0.1) 24 | 25 | 26 | # 开始测试 3D 输入情况 27 | # X: [B, N, D_in], W: [D_out, D_in], b: [D_out] 28 | # matmul: (X @ W.T) + b => [B, N, D_out] 29 | # linear: F.linear(X, W, b) => [B, N, D_out] 30 | 31 | for B, N, D_in, D_out in itertools.product( 32 | B_values, N_values, D_in_values, D_out_values 33 | ): 34 | X = torch.randn(B, N, D_in, device=device) 35 | W = torch.randn(D_out, D_in, device=device) 36 | b = torch.randn(D_out, device=device) 37 | 38 | # matmul 测试 39 | matmul_time = benchmark_op(lambda x, w, b: x @ w.T + b, (X, W, b)) 40 | # linear 测试 41 | linear_time = benchmark_op(lambda x, w, b: F.linear(x, w, b), (X, W, b)) 42 | 43 | results_matmul[(B, N, D_in, D_out)] = matmul_time.median 44 | results_linear[(B, N, D_in, D_out)] = linear_time.median 45 | 46 | # 可视化结果 47 | # 为了简化绘制,我们选定某一组 B, D_in, D_out 随 N 变化的性能对比曲线。 48 | fixed_B = 8 49 | fixed_D_in = 128 50 | fixed_D_out = 128 51 | 52 | filtered_N = [ 53 | n for n in N_values if (fixed_B, n, fixed_D_in, fixed_D_out) in results_matmul 54 | ] 55 | 56 | matmul_times = [ 57 | results_matmul[(fixed_B, n, fixed_D_in, fixed_D_out)] for n in filtered_N 58 | ] 59 | linear_times = [ 60 | results_linear[(fixed_B, n, fixed_D_in, fixed_D_out)] for n in filtered_N 61 | ] 62 | 63 | plt.figure(figsize=(8, 6)) 64 | plt.plot(filtered_N, matmul_times, marker="o", label="matmul (3D X)") 65 | plt.plot(filtered_N, linear_times, marker="s", label="F.linear (3D X)") 66 | plt.xlabel("N dimension size") 67 | plt.ylabel("Median time (s)") 68 | plt.title( 69 | f"Performance comparison at B={fixed_B}, D_in={fixed_D_in}, D_out={fixed_D_out}" 70 | ) 71 | 
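# NOTE: each plotted value is the median reported by Timer.blocked_autorange(min_run_time=0.1), i.e. the median over several timed blocks rather than a single measurement.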
plt.legend() 72 | plt.grid(True) 73 | plt.savefig("./result.png") 74 | --------------------------------------------------------------------------------
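A quick numerical cross-check complements the timing script above: before comparing speed, it helps to confirm that the explicit `x @ w.T + b` formulation and `F.linear` produce the same result. The following is a minimal sketch using one configuration from the benchmark grid (B=8, N=64, D_in=D_out=128); the tolerance values are illustrative assumptions, not taken from the repository.

import torch
import torch.nn.functional as F

# One configuration from the benchmark grid above: B=8, N=64, D_in=128, D_out=128.
X = torch.randn(8, 64, 128)
W = torch.randn(128, 128)
b = torch.randn(128)

out_matmul = X @ W.T + b        # explicit matmul + broadcast bias add
out_linear = F.linear(X, W, b)  # fused linear path

# Both formulations should agree to within float32 rounding error (tolerances are illustrative).
print(torch.allclose(out_matmul, out_linear, rtol=1e-4, atol=1e-4))  # expected: True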