├── scripts ├── roofline │ ├── __init__.py │ └── proton-viewer.py ├── setups │ ├── mamba_ssu_0.conf │ ├── default.conf │ ├── prefix_tune_2d.conf │ ├── prefix_correctness.conf │ └── prefix_correctness_rocm.conf ├── callers │ ├── vllm_cuda_v1.py │ ├── triton_fp8.py │ ├── __init__.py │ ├── vllm_cuda_v2.py │ ├── base.py │ ├── xformers.py │ ├── flashinfer.py │ ├── triton_3d.py │ ├── triton_2d.py │ ├── baseline_triton.py │ ├── unified_triton.py │ ├── fused_triton.py │ ├── pytorch_native.py │ └── flash_attn.py ├── torch_utils.py ├── bench_vllm_user_range.py ├── offline_inference.py ├── bench_vllm_latency_range.py ├── bench_vllm_serve_avg.py └── profile_and_bench.py ├── requirements-lint.txt ├── ibm-triton-lib ├── ibm_triton_lib │ ├── utils │ │ ├── __init__.py │ │ └── triton_utils.py │ ├── kernels │ │ ├── legacy │ │ │ ├── fused_gqa_paged │ │ │ │ ├── __init__.py │ │ │ │ ├── utils.py │ │ │ │ └── sb_jit_func.py │ │ │ ├── __init__.py │ │ │ └── triton_chunked_prefill_paged_decode.py │ │ ├── dejavu_data │ │ │ └── dejavu_0.7 │ │ │ │ ├── triton_3.3.0 │ │ │ │ ├── cuda_12.4 │ │ │ │ │ └── gpu_NVIDIA_H100_80GB_HBM3 │ │ │ │ │ │ └── _selective_scan_update_kernel │ │ │ │ │ │ └── autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227 │ │ │ │ │ │ └── code_version-669be673bf919df57c10083821a49ac5e1e5629db08d0501c1c298603ad4ecb8 │ │ │ │ │ │ └── tune_features-93313ae47bf85925b0b3b8a0af710ff4a94421cf3e6ebd1a348e74369ddc45e8 │ │ │ │ │ │ └── kernel_configs-85691372c5ea21c12337d65667ec842af16b51057ec486e7af706471f7a50309 │ │ │ │ │ │ └── default │ │ │ │ │ │ └── cache.json │ │ │ │ └── rocm_torch_6.2.41134-65d174c3e │ │ │ │ │ └── gpu_AMD_Instinct_MI300X │ │ │ │ │ └── _selective_scan_update_kernel │ │ │ │ │ └── autotune_config-90178d0ab8e71db9cd16710d562763dd010643f28cd21980d5064c3ab782ecaa │ │ │ │ │ └── code_version-669be673bf919df57c10083821a49ac5e1e5629db08d0501c1c298603ad4ecb8 │ │ │ │ │ └── tune_features-93313ae47bf85925b0b3b8a0af710ff4a94421cf3e6ebd1a348e74369ddc45e8 │ │ │ │ │ └── kernel_configs-85691372c5ea21c12337d65667ec842af16b51057ec486e7af706471f7a50309 │ │ │ │ │ └── default │ │ │ │ │ └── cache.json │ │ │ │ └── triton_3.2.0 │ │ │ │ ├── rocm_6.3.1 │ │ │ │ └── gpu_AMD_Instinct_MI250X_MI250 │ │ │ │ │ └── attn_fwd │ │ │ │ │ └── autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227 │ │ │ │ │ └── code_version-0a43fd896fb3d6519678247aeba94610b596378a3138e88995ca3569d6672a96 │ │ │ │ │ └── tune_features-df62f53ce178f143b59631de953c946e43811ff1b34cd71e422dfdf14ac35bb9 │ │ │ │ │ └── kernel_configs-1f316f0fbddd51d950280abb53d67b60494f0cf2c02eeb1b551b0356a33a7dc8 │ │ │ │ │ └── default │ │ │ │ │ └── cache.json │ │ │ │ └── cuda_12.4 │ │ │ │ ├── gpu_NVIDIA_A100-SXM4-80GB │ │ │ │ └── attn_fwd │ │ │ │ │ └── autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227 │ │ │ │ │ └── code_version-0a43fd896fb3d6519678247aeba94610b596378a3138e88995ca3569d6672a96 │ │ │ │ │ └── tune_features-df62f53ce178f143b59631de953c946e43811ff1b34cd71e422dfdf14ac35bb9 │ │ │ │ │ └── kernel_configs-1f316f0fbddd51d950280abb53d67b60494f0cf2c02eeb1b551b0356a33a7dc8 │ │ │ │ │ └── default │ │ │ │ │ └── cache.json │ │ │ │ └── gpu_NVIDIA_H100_80GB_HBM3 │ │ │ │ └── attn_fwd │ │ │ │ └── autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227 │ │ │ │ └── code_version-0a43fd896fb3d6519678247aeba94610b596378a3138e88995ca3569d6672a96 │ │ │ │ └── tune_features-df62f53ce178f143b59631de953c946e43811ff1b34cd71e422dfdf14ac35bb9 │ │ │ │ └── 
kernel_configs-a70f97e8b3e7aaf9f4a4f7e850b935d2d1b3ad8cd6ad1d0843bb426e13694ae9 │ │ │ │ └── default │ │ │ │ └── cache.json │ │ └── __init__.py │ ├── __init__.py │ └── backend │ │ ├── __init__.py │ │ └── platform.py └── setup.py ├── doc ├── dev-env.png └── anatomy_of_a_triton_attention_kernel_ibm.pdf ├── .gitignore ├── .gitmodules ├── Makefile ├── .vscode └── launch.json ├── Dockerfile.hub ├── README.md ├── third_party └── vedantroy_paged_attention.py ├── Dockerfile └── LICENSE /scripts/roofline/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements-lint.txt: -------------------------------------------------------------------------------- 1 | black==24.10.0 2 | -------------------------------------------------------------------------------- /ibm-triton-lib/ibm_triton_lib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /doc/dev-env.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foundation-model-stack/vllm-triton-backend/HEAD/doc/dev-env.png -------------------------------------------------------------------------------- /ibm-triton-lib/ibm_triton_lib/kernels/legacy/fused_gqa_paged/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_gqa_paged_splitkv import paged_attention_triton_3d 2 | -------------------------------------------------------------------------------- /doc/anatomy_of_a_triton_attention_kernel_ibm.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foundation-model-stack/vllm-triton-backend/HEAD/doc/anatomy_of_a_triton_attention_kernel_ibm.pdf -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | 4 | results 5 | *.swp 6 | *.swo 7 | 8 | all-git.tar 9 | vllm-all.tar 10 | rocm-vllm-all.tar 11 | ShareGPT_V3_unfiltered_cleaned_split.json 12 | 13 | .vscode/settings.json 14 | 15 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "triton"] 2 | path = triton 3 | url = https://github.com/triton-lang/triton.git 4 | [submodule "triton-dejavu"] 5 | path = triton-dejavu 6 | url = https://github.com/IBM/triton-dejavu.git 7 | [submodule "vllm"] 8 | path = vllm 9 | url = https://github.com/vllm-project/vllm.git 10 | -------------------------------------------------------------------------------- /scripts/setups/mamba_ssu_0.conf: -------------------------------------------------------------------------------- 1 | BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128] 2 | # BATCH_SIZES = [4] 3 | # order: num_query_heads, num_kv_heads 4 | NUM_HEADS = [[128, 128]] 5 | HEAD_SIZES = [64] 6 | STATE_DIM = [128] 7 | STATE_N_GROUPS = [1] 8 | HAS_INITIAL_STATE = ["True"] 9 | DTYPES = ["bfloat16"] 10 | 11 | BENCHMARK_MODES = ["CUDA_EVENTS"] 12 | 13 | IMPLEMENTATION_UT = ["BASELINE_TRITON"] # some value for now 14 | # IMPLEMENTATION_UT = ["FLASH_ATTN", "UNF_TRITON_2D"] 15 | 16 | # TRITON_BACKEND_DEBUG = 1 17 | # 
STORE_TEST_RESULT_PATH=/results 18 | 19 | # TEST_ALLOW_INCORRECT = 1 20 | -------------------------------------------------------------------------------- /scripts/setups/default.conf: -------------------------------------------------------------------------------- 1 | DTYPES = ["float16"] 2 | SEEDS = [0] 3 | BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128] 4 | # order: num_query_heads, num_kv_heads 5 | NUM_HEADS = [[32, 8]] 6 | 7 | SEQUENCE_LENGTHS = [16, 32, 64, 128, 512, 1024, 2048, 4096] 8 | PREFIX_PREFILL_SHARE_OF_DECODE = [0.0, 0.5, 1.0] 9 | PREFIX_PREFILL_SHARE_OF_PARTIAL_PREFILL = [0.0, 0.5] 10 | PREFIX_PREFILL_BATCH_COMPOSITION = ["ALTERNATING"] 11 | 12 | HEAD_SIZES = [128] # only powers of 2! for llama2 & 3 13 | # head_size * head_numbers = hidden_size 14 | 15 | BLOCK_SIZES = [16] 16 | NUM_BLOCKS = [4321] # "arbitrary values for testing..." 17 | 18 | PROMPT_PATTERNS = [[1.0], [0.1, 0.4, 0.5, 1.0, 0.2]] 19 | 20 | MAX_VALUES = [1.0] 21 | BENCHMARK_MODES = ["CUDA_EVENTS"] 22 | 23 | IMPLEMENTATION_UT = ["FLASH_ATTN", "UNF_TRITON_3D", "UNF_TRITON_2D"] 24 | -------------------------------------------------------------------------------- /ibm-triton-lib/ibm_triton_lib/__init__.py: -------------------------------------------------------------------------------- 1 | # /******************************************************************************* 2 | # * Copyright 2025 IBM Corporation 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # * you may not use this file except in compliance with the License. 6 | # * You may obtain a copy of the License at 7 | # * 8 | # * http://www.apache.org/licenses/LICENSE-2.0 9 | # * 10 | # * Unless required by applicable law or agreed to in writing, software 11 | # * distributed under the License is distributed on an "AS IS" BASIS, 12 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # * See the License for the specific language governing permissions and 14 | # * limitations under the License. 15 | # *******************************************************************************/ 16 | # 17 | 18 | __version__ = "0.0.1" 19 | -------------------------------------------------------------------------------- /ibm-triton-lib/ibm_triton_lib/backend/__init__.py: -------------------------------------------------------------------------------- 1 | # /******************************************************************************* 2 | # * Copyright 2025 IBM Corporation 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # * you may not use this file except in compliance with the License. 6 | # * You may obtain a copy of the License at 7 | # * 8 | # * http://www.apache.org/licenses/LICENSE-2.0 9 | # * 10 | # * Unless required by applicable law or agreed to in writing, software 11 | # * distributed under the License is distributed on an "AS IS" BASIS, 12 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # * See the License for the specific language governing permissions and 14 | # * limitations under the License. 
15 | # *******************************************************************************/ 16 | # 17 | import os 18 | 19 | 20 | def register(): 21 | """Register the triton attention platform.""" 22 | return "ibm_triton_lib.backend.platform.TritonPlatform" 23 | -------------------------------------------------------------------------------- /scripts/roofline/proton-viewer.py: -------------------------------------------------------------------------------- 1 | # /******************************************************************************* 2 | # * Copyright 2025 IBM Corporation 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # * you may not use this file except in compliance with the License. 6 | # * You may obtain a copy of the License at 7 | # * 8 | # * http://www.apache.org/licenses/LICENSE-2.0 9 | # * 10 | # * Unless required by applicable law or agreed to in writing, software 11 | # * distributed under the License is distributed on an "AS IS" BASIS, 12 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # * See the License for the specific language governing permissions and 14 | # * limitations under the License. 15 | # *******************************************************************************/ 16 | # 17 | 18 | 19 | import sys 20 | import os 21 | 22 | sys.path.append(os.path.dirname(os.path.realpath(__file__))) 23 | 24 | from proton_viewer import main 25 | 26 | if __name__ == "__main__": 27 | main() 28 | -------------------------------------------------------------------------------- /ibm-triton-lib/ibm_triton_lib/kernels/legacy/fused_gqa_paged/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def is_fp8_dtype(t): 5 | return t == torch.float8_e4m3fn or t == torch.float8_e5m2 6 | 7 | 8 | # NOTE: this might be slow... avoid calling repeatedly 9 | def get_num_SMs(device): 10 | return torch.cuda.get_device_properties(device).multi_processor_count 11 | 12 | 13 | # Using rule based NUM_SPLITS computation 14 | # L should be the minimal kv length in a batch 15 | # BLOCK_L is the chosen block size 16 | def compute_split_l(L, BLOCK_L, P=1, device=None): 17 | NUM_SMs = 132 if device is None else get_num_SMs(device) 18 | if P >= NUM_SMs: 19 | # there's already enough parallelism 20 | # no need to further split L 21 | return 1 22 | 23 | # Find minimum num_splits that will result in enough triton programs 24 | # TODO: does num_splits need to be power of 2? 
25 | num_splits = 1 26 | split_size = L 27 | while (num_splits * P < NUM_SMs) and (split_size > BLOCK_L): 28 | num_splits *= 2 29 | split_size = L // num_splits 30 | 31 | return num_splits 32 | -------------------------------------------------------------------------------- /scripts/setups/prefix_tune_2d.conf: -------------------------------------------------------------------------------- 1 | BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128] 2 | # BATCH_SIZES = [4] 3 | # order: num_query_heads, num_kv_heads 4 | NUM_HEADS = [[32, 8]] 5 | 6 | SEQUENCE_LENGTHS = [16, 32, 64, 128, 512, 1024, 2048, 4096] 7 | # SEQUENCE_LENGTHS = [64] 8 | # PREFIX_PREFILL_SHARE_OF_DECODE = [0.0, 0.5, 1.0] 9 | PREFIX_PREFILL_SHARE_OF_DECODE = [0.0, 0.5] 10 | # PREFIX_PREFILL_SHARE_OF_DECODE = [0.5] 11 | PREFIX_PREFILL_SHARE_OF_PARTIAL_PREFILL = [0.0, 0.5] 12 | # PREFIX_PREFILL_SHARE_OF_PARTIAL_PREFILL = [0.5] 13 | # PREFIX_PREFILL_BATCH_COMPOSITION = ["ALTERNATING"] 14 | PREFIX_PREFILL_BATCH_COMPOSITION = ["DEC_PRE"] 15 | 16 | HEAD_SIZES = [128] # only powers of 2! for llama2 & 3 17 | # head_size * head_numbers = hidden_size 18 | 19 | BLOCK_SIZES = [16] 20 | NUM_BLOCKS = [4321] # "arbitrary values for testing..." 21 | 22 | PROMPT_PATTERNS = [[1.0], [0.1, 0.4, 0.5, 1.0, 0.2]] 23 | # PROMPT_PATTERNS = [[1.0]] 24 | 25 | MAX_VALUES = [1.0] 26 | BENCHMARK_MODES = ["CUDA_EVENTS"] 27 | 28 | IMPLEMENTATION_UT = ["UNF_TRITON_2D"] 29 | # IMPLEMENTATION_UT = ["FLASH_ATTN", "UNF_TRITON_2D"] 30 | 31 | # TRITON_BACKEND_DEBUG = 1 32 | # STORE_TEST_RESULT_PATH=/results 33 | 34 | # TEST_ALLOW_INCORRECT = 1 35 | -------------------------------------------------------------------------------- /scripts/setups/prefix_correctness.conf: -------------------------------------------------------------------------------- 1 | # BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128] 2 | BATCH_SIZES = [4] 3 | # order: num_query_heads, num_kv_heads 4 | NUM_HEADS = [[32, 8]] 5 | 6 | # SEQUENCE_LENGTHS = [16, 32, 64, 128, 512, 1024, 2048, 4096] 7 | # SEQUENCE_LENGTHS = [64] 8 | SEQUENCE_LENGTHS = [1024] 9 | # PREFIX_PREFILL_SHARE_OF_DECODE = [0.0, 0.5, 1.0] 10 | PREFIX_PREFILL_SHARE_OF_DECODE = [0.5] 11 | # PREFIX_PREFILL_SHARE_OF_PARTIAL_PREFILL = [0.0, 0.5] 12 | PREFIX_PREFILL_SHARE_OF_PARTIAL_PREFILL = [0.5] 13 | PREFIX_PREFILL_BATCH_COMPOSITION = ["ALTERNATING"] 14 | 15 | HEAD_SIZES = [128] # only powers of 2! for llama2 & 3 16 | # head_size * head_numbers = hidden_size 17 | 18 | BLOCK_SIZES = [16] 19 | NUM_BLOCKS = [4321] # "arbitrary values for testing..." 
20 | 21 | # PROMPT_PATTERNS = [[1.0], [0.1, 0.4, 0.5, 1.0, 0.2]] 22 | PROMPT_PATTERNS = [[1.0]] 23 | 24 | MAX_VALUES = [1.0] 25 | BENCHMARK_MODES = ["CUDA_EVENTS"] 26 | 27 | # IMPLEMENTATION_UT = ["FLASH_ATTN", "UNF_TRITON_3D", "UNF_TRITON_2D"] 28 | # IMPLEMENTATION_UT = ["UNF_TRITON_3D", "UNF_TRITON_2D", "UNF_TRITON_AUTO"] 29 | # IMPLEMENTATION_UT = ["UNF_TRITON_2D"] 30 | IMPLEMENTATION_UT = ["FLASH_ATTN", "UNF_TRITON_2D"] 31 | 32 | TRITON_BACKEND_DEBUG = 1 33 | -------------------------------------------------------------------------------- /scripts/setups/prefix_correctness_rocm.conf: -------------------------------------------------------------------------------- 1 | # BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128] 2 | BATCH_SIZES = [4] 3 | # order: num_query_heads, num_kv_heads 4 | NUM_HEADS = [[32, 8]] 5 | 6 | # SEQUENCE_LENGTHS = [16, 32, 64, 128, 512, 1024, 2048, 4096] 7 | # SEQUENCE_LENGTHS = [64] 8 | SEQUENCE_LENGTHS = [1024] 9 | # PREFIX_PREFILL_SHARE_OF_DECODE = [0.0, 0.5, 1.0] 10 | PREFIX_PREFILL_SHARE_OF_DECODE = [0.5] 11 | # PREFIX_PREFILL_SHARE_OF_PARTIAL_PREFILL = [0.0, 0.5] 12 | PREFIX_PREFILL_SHARE_OF_PARTIAL_PREFILL = [0.5] 13 | PREFIX_PREFILL_BATCH_COMPOSITION = ["ALTERNATING"] 14 | 15 | HEAD_SIZES = [128] # only powers of 2! for llama2 & 3 16 | # head_size * head_numbers = hidden_size 17 | 18 | BLOCK_SIZES = [128] 19 | NUM_BLOCKS = [4321] # "arbitrary values for testing..." 20 | 21 | # PROMPT_PATTERNS = [[1.0], [0.1, 0.4, 0.5, 1.0, 0.2]] 22 | PROMPT_PATTERNS = [[1.0]] 23 | 24 | MAX_VALUES = [1.0] 25 | BENCHMARK_MODES = ["CUDA_EVENTS"] 26 | 27 | # IMPLEMENTATION_UT = ["FLASH_ATTN", "UNF_TRITON_3D", "UNF_TRITON_2D"] 28 | # IMPLEMENTATION_UT = ["UNF_TRITON_3D", "UNF_TRITON_2D", "UNF_TRITON_AUTO"] 29 | # IMPLEMENTATION_UT = ["UNF_TRITON_2D"] 30 | IMPLEMENTATION_UT = ["FLASH_ATTN", "UNF_TRITON_2D"] 31 | 32 | TRITON_BACKEND_DEBUG = 1 33 | -------------------------------------------------------------------------------- /ibm-triton-lib/ibm_triton_lib/kernels/legacy/__init__.py: -------------------------------------------------------------------------------- 1 | # /******************************************************************************* 2 | # * Copyright 2025 IBM Corporation 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # * you may not use this file except in compliance with the License. 6 | # * You may obtain a copy of the License at 7 | # * 8 | # * http://www.apache.org/licenses/LICENSE-2.0 9 | # * 10 | # * Unless required by applicable law or agreed to in writing, software 11 | # * distributed under the License is distributed on an "AS IS" BASIS, 12 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # * See the License for the specific language governing permissions and 14 | # * limitations under the License. 
15 | # *******************************************************************************/ 16 | # 17 | 18 | from .triton_chunked_prefill_paged_decode import chunked_prefill_paged_decode 19 | from .triton_paged_decode_attention_2d import ( 20 | paged_attention_triton_2d as paged_attention_2d, 21 | ) 22 | from .triton_paged_decode_attention_3d import ( 23 | paged_attention_triton_3d as paged_attention_3d, 24 | ) 25 | from .fused_gqa_paged import ( 26 | paged_attention_triton_3d as paged_attention_fp8_3d, 27 | ) 28 | from .fused_chunked_prefill_paged_decode import ( 29 | fused_chunked_prefill_paged_decode as fused_chunked_prefill_paged_decode_25d, 30 | ) 31 | -------------------------------------------------------------------------------- /ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.3.0/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_selective_scan_update_kernel/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-669be673bf919df57c10083821a49ac5e1e5629db08d0501c1c298603ad4ecb8/tune_features-93313ae47bf85925b0b3b8a0af710ff4a94421cf3e6ebd1a348e74369ddc45e8/kernel_configs-85691372c5ea21c12337d65667ec842af16b51057ec486e7af706471f7a50309/default/cache.json: -------------------------------------------------------------------------------- 1 | { 2 | "signature": "JITFunction(ibm_triton_lib.kernels.mamba_ssm:_selective_scan_update_kernel)", 3 | "total_bench_time_s": 58.42541313171387, 4 | "evaluated_configs": 75, 5 | "keys": [ 6 | "dstate", 7 | "BLOCK_SIZE_DSTATE", 8 | "dim", 9 | "nheads_ngroups_ratio" 10 | ], 11 | "cache": { 12 | "('128', '128', '64', '128', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32')": "BLOCK_SIZE_M: 8, num_warps: 2, num_ctas: 1, num_stages: 6, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None" 13 | }, 14 | "timings": { 15 | "('128', '128', '64', '128', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32')": [ 16 | 0.003274054965004325 17 | ] 18 | }, 19 | "timings_data": { 20 | "labels": [ 21 | "ms" 22 | ], 23 | "rep_t_ms": 100, 24 | "warmup_t_ms": 25, 25 | "cuda_graphs": true 26 | } 27 | } -------------------------------------------------------------------------------- /ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.3.0/rocm_torch_6.2.41134-65d174c3e/gpu_AMD_Instinct_MI300X/_selective_scan_update_kernel/autotune_config-90178d0ab8e71db9cd16710d562763dd010643f28cd21980d5064c3ab782ecaa/code_version-669be673bf919df57c10083821a49ac5e1e5629db08d0501c1c298603ad4ecb8/tune_features-93313ae47bf85925b0b3b8a0af710ff4a94421cf3e6ebd1a348e74369ddc45e8/kernel_configs-85691372c5ea21c12337d65667ec842af16b51057ec486e7af706471f7a50309/default/cache.json: -------------------------------------------------------------------------------- 1 | { 2 | "signature": "JITFunction(ibm_triton_lib.kernels.mamba_ssm:_selective_scan_update_kernel)", 3 | "total_bench_time_s": 113.2074065208435, 4 | "evaluated_configs": 75, 5 | "keys": [ 6 | "dstate", 7 | "BLOCK_SIZE_DSTATE", 8 | "dim", 9 | "nheads_ngroups_ratio" 10 | ], 11 | "cache": { 12 | "('128', '128', '64', '128', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 
'torch.int32')": "BLOCK_SIZE_M: 16, num_warps: 4, num_ctas: 1, num_stages: 6, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None" 13 | }, 14 | "timings": { 15 | "('128', '128', '64', '128', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32')": [ 16 | 0.0050251600332558155 17 | ] 18 | }, 19 | "timings_data": { 20 | "labels": [ 21 | "ms" 22 | ], 23 | "rep_t_ms": 100, 24 | "warmup_t_ms": 25, 25 | "cuda_graphs": true 26 | } 27 | } -------------------------------------------------------------------------------- /scripts/callers/vllm_cuda_v1.py: -------------------------------------------------------------------------------- 1 | # /******************************************************************************* 2 | # * Copyright 2025 IBM Corporation 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # * you may not use this file except in compliance with the License. 6 | # * You may obtain a copy of the License at 7 | # * 8 | # * http://www.apache.org/licenses/LICENSE-2.0 9 | # * 10 | # * Unless required by applicable law or agreed to in writing, software 11 | # * distributed under the License is distributed on an "AS IS" BASIS, 12 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # * See the License for the specific language governing permissions and 14 | # * limitations under the License. 15 | # *******************************************************************************/ 16 | # 17 | 18 | 19 | import torch 20 | from vllm import _custom_ops as ops 21 | from .base import DecodeCaller 22 | 23 | 24 | class VllmCudaV1Caller(DecodeCaller): 25 | @staticmethod 26 | def make_call_func( 27 | output, 28 | query, 29 | key_cache, 30 | value_cache, 31 | num_seqs, 32 | seq_lens, 33 | max_seq_len, 34 | scale, 35 | block_tables, 36 | alibi_slopes, 37 | kv_cache_dtype, 38 | ): 39 | num_kv_heads = key_cache.shape[1] 40 | block_size = key_cache.shape[3] 41 | 42 | # Using default kv_scale 43 | k_scale = v_scale = torch.ones(1, device=query.device) 44 | 45 | call_func_under_test = lambda: ops.paged_attention_v1( 46 | output, 47 | query, 48 | key_cache, 49 | value_cache, 50 | num_kv_heads, 51 | scale, 52 | block_tables, 53 | seq_lens, 54 | block_size, 55 | max_seq_len, 56 | alibi_slopes, 57 | kv_cache_dtype, 58 | k_scale, 59 | v_scale, 60 | ) 61 | 62 | return call_func_under_test 63 | -------------------------------------------------------------------------------- /scripts/callers/triton_fp8.py: -------------------------------------------------------------------------------- 1 | # /******************************************************************************* 2 | # * Copyright 2025 IBM Corporation 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # * you may not use this file except in compliance with the License. 6 | # * You may obtain a copy of the License at 7 | # * 8 | # * http://www.apache.org/licenses/LICENSE-2.0 9 | # * 10 | # * Unless required by applicable law or agreed to in writing, software 11 | # * distributed under the License is distributed on an "AS IS" BASIS, 12 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # * See the License for the specific language governing permissions and 14 | # * limitations under the License. 
15 | # *******************************************************************************/ 16 | # 17 | 18 | 19 | import torch 20 | from ibm_triton_lib.kernels.legacy import paged_attention_fp8_3d 21 | from .base import DecodeCaller 22 | 23 | 24 | class TritonFp8Caller(DecodeCaller): 25 | @staticmethod 26 | def make_call_func( 27 | output, 28 | query, 29 | key_cache, 30 | value_cache, 31 | num_seqs, 32 | seq_lens, 33 | max_seq_len, # unused 34 | scale, 35 | block_tables, 36 | alibi_slopes, 37 | kv_cache_dtype, # unused 38 | ): 39 | num_query_heads = query.shape[1] 40 | num_kv_heads = key_cache.shape[1] 41 | block_size = key_cache.shape[3] 42 | num_queries_per_kv = num_query_heads // num_kv_heads 43 | head_size = key_cache.shape[2] 44 | 45 | key_cache_ykt = key_cache.permute(0, 1, 3, 2).contiguous() 46 | value_cache_ykt = value_cache.permute(0, 1, 3, 2).contiguous() 47 | 48 | call_func_under_test = lambda: paged_attention_fp8_3d( 49 | output, 50 | query, 51 | key_cache_ykt, 52 | value_cache_ykt, 53 | scale, 54 | block_tables, 55 | seq_lens, 56 | alibi_slopes, 57 | block_size, 58 | num_seqs, 59 | num_query_heads, 60 | num_queries_per_kv, 61 | head_size, 62 | ) 63 | 64 | return call_func_under_test 65 | -------------------------------------------------------------------------------- /scripts/callers/__init__.py: -------------------------------------------------------------------------------- 1 | # /******************************************************************************* 2 | # * Copyright 2025 IBM Corporation 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # * you may not use this file except in compliance with the License. 6 | # * You may obtain a copy of the License at 7 | # * 8 | # * http://www.apache.org/licenses/LICENSE-2.0 9 | # * 10 | # * Unless required by applicable law or agreed to in writing, software 11 | # * distributed under the License is distributed on an "AS IS" BASIS, 12 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # * See the License for the specific language governing permissions and 14 | # * limitations under the License. 
15 | # *******************************************************************************/ 16 | # 17 | 18 | try: 19 | from .flash_attn import ( 20 | FlashAttnDecodeCaller, 21 | FlashAttnPrefillCaller, 22 | FlashAttnPrefixPrefillCaller, 23 | ) 24 | except ModuleNotFoundError: 25 | pass 26 | 27 | try: 28 | from .xformers import XformersCaller 29 | except ModuleNotFoundError: 30 | # print("[benchmark callers] xformers not present, skipping..") 31 | pass 32 | 33 | try: 34 | from .vllm_cuda_v2 import VllmCudaV2Caller 35 | from .vllm_cuda_v1 import VllmCudaV1Caller 36 | from .baseline_triton import BaselineTritonCaller, BaselineTritonPrefixPrefillCaller 37 | except ModuleNotFoundError: 38 | pass 39 | 40 | from .triton_2d import Triton2dAttentionDecodeCaller, Triton2dChunkedPrefillCaller 41 | from .triton_3d import Triton3dAttentionDecodeCaller, Triton3dAttentionPrefillCaller 42 | from .triton_fp8 import TritonFp8Caller 43 | 44 | try: 45 | from .flashinfer import FlashInferCaller 46 | except (ModuleNotFoundError, ImportError): 47 | # print("[benchmark callers] flashinfer not present, skipping..") 48 | pass 49 | from .fused_triton import ( 50 | FusedTritonChunkedPrefixPrefill25dCaller, 51 | FusedTritonDecodeOnlyCaller, 52 | ) 53 | from .pytorch_native import PytorchNativeAttentionPrefillCaller 54 | 55 | from .unified_triton import ( 56 | UnifiedTriton2dAttentionCaller, 57 | UnifiedTriton3dAttentionCaller, 58 | UnifiedTritonAutoAttentionCaller, 59 | ) 60 | -------------------------------------------------------------------------------- /ibm-triton-lib/ibm_triton_lib/kernels/__init__.py: -------------------------------------------------------------------------------- 1 | # /******************************************************************************* 2 | # * Copyright 2025 IBM Corporation 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # * you may not use this file except in compliance with the License. 6 | # * You may obtain a copy of the License at 7 | # * 8 | # * http://www.apache.org/licenses/LICENSE-2.0 9 | # * 10 | # * Unless required by applicable law or agreed to in writing, software 11 | # * distributed under the License is distributed on an "AS IS" BASIS, 12 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # * See the License for the specific language governing permissions and 14 | # * limitations under the License. 15 | # *******************************************************************************/ 16 | # 17 | 18 | # create fake module if triton-dejavu is not present 19 | # remove ASAP 20 | try: 21 | import triton_dejavu 22 | except ImportError: 23 | import sys 24 | 25 | class Fake_autotuner(object): 26 | 27 | def __init__(self, *args, **ignore_args): 28 | pass 29 | 30 | def __call__(self, *args, **kwds): 31 | pass 32 | 33 | def run(self, *args, **kwargs): 34 | print( 35 | "ERROR: triton-dejavu is called while not being installed. Please install triton-dejavu!" 36 | ) 37 | raise ImportError 38 | 39 | def __getitem__(self, grid): 40 | print( 41 | "ERROR: triton-dejavu is called while not being installed. Please install triton-dejavu!" 
42 | ) 43 | raise ImportError 44 | 45 | class Fake_triton_dejavu(object): 46 | 47 | def autotune(*args, **kwargs): 48 | fake_decorator = lambda fn: Fake_autotuner(fn) 49 | return fake_decorator 50 | 51 | @staticmethod 52 | def ConfigSpace( 53 | kwargs_with_lists, 54 | kwarg_conditions=None, 55 | pre_hook=None, 56 | **configuration_args, 57 | ): 58 | pass 59 | 60 | sys.modules["triton_dejavu"] = Fake_triton_dejavu 61 | print( 62 | "WARNING: Created fake module to work-around missing triton-dejavu module. If you don't expect this warning, this is likely to become an error." 63 | ) 64 | 65 | from .triton_flash_attention import ( 66 | triton_wrapper_forward_prefill as prefill_flash_attention, 67 | ) 68 | 69 | from .triton_unified_attention import unified_attention 70 | 71 | from .mamba_ssm import selective_state_update 72 | -------------------------------------------------------------------------------- /scripts/callers/vllm_cuda_v2.py: -------------------------------------------------------------------------------- 1 | # /******************************************************************************* 2 | # * Copyright 2025 IBM Corporation 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # * you may not use this file except in compliance with the License. 6 | # * You may obtain a copy of the License at 7 | # * 8 | # * http://www.apache.org/licenses/LICENSE-2.0 9 | # * 10 | # * Unless required by applicable law or agreed to in writing, software 11 | # * distributed under the License is distributed on an "AS IS" BASIS, 12 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # * See the License for the specific language governing permissions and 14 | # * limitations under the License. 15 | # *******************************************************************************/ 16 | # 17 | 18 | 19 | import torch 20 | from vllm import _custom_ops as ops 21 | from .base import DecodeCaller 22 | 23 | PARTITION_SIZE = 512 24 | 25 | 26 | class VllmCudaV2Caller(DecodeCaller): 27 | @staticmethod 28 | def make_call_func( 29 | output, 30 | query, 31 | key_cache, 32 | value_cache, 33 | num_seqs, 34 | seq_lens, 35 | max_seq_len, 36 | scale, 37 | block_tables, 38 | alibi_slopes, 39 | kv_cache_dtype, 40 | ): 41 | block_size = key_cache.shape[3] 42 | 43 | # Using default kv_scale 44 | k_scale = v_scale = torch.ones(1, device=query.device) 45 | 46 | num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE 47 | assert PARTITION_SIZE % block_size == 0 48 | num_seqs, num_heads, head_size = output.shape 49 | tmp_output = torch.empty( 50 | size=(num_seqs, num_heads, num_partitions, head_size), 51 | dtype=output.dtype, 52 | ) 53 | exp_sums = torch.empty( 54 | size=(num_seqs, num_heads, num_partitions), 55 | dtype=torch.float32, 56 | ) 57 | max_logits = torch.empty_like(exp_sums) 58 | 59 | num_kv_heads = key_cache.shape[1] 60 | 61 | call_func_under_test = lambda: ops.paged_attention_v2( 62 | output, 63 | exp_sums, 64 | max_logits, 65 | tmp_output, 66 | query, 67 | key_cache, 68 | value_cache, 69 | num_kv_heads, 70 | scale, 71 | block_tables, 72 | seq_lens, 73 | block_size, 74 | max_seq_len, 75 | alibi_slopes, 76 | kv_cache_dtype, 77 | k_scale, 78 | v_scale, 79 | ) 80 | 81 | return call_func_under_test 82 | -------------------------------------------------------------------------------- /ibm-triton-lib/setup.py: -------------------------------------------------------------------------------- 1 | # 
/******************************************************************************* 2 | # * Copyright 2024 IBM Corporation 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # * you may not use this file except in compliance with the License. 6 | # * You may obtain a copy of the License at 7 | # * 8 | # * http://www.apache.org/licenses/LICENSE-2.0 9 | # * 10 | # * Unless required by applicable law or agreed to in writing, software 11 | # * distributed under the License is distributed on an "AS IS" BASIS, 12 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # * See the License for the specific language governing permissions and 14 | # * limitations under the License. 15 | # *******************************************************************************/ 16 | # 17 | 18 | from setuptools import setup 19 | import os 20 | import re 21 | 22 | PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__)) 23 | 24 | 25 | def find_version(filepath: str) -> str: 26 | """Extract version information from the given filepath. 27 | 28 | Adapted from https://github.com/vllm-project/vllm/blob/717f4bcea036a049e86802b3a05dd6f7cd17efc8/setup.py 29 | """ 30 | with open(filepath) as fp: 31 | version_match = re.search( 32 | r"^__version__ = ['\"]([^'\"]*)['\"]", fp.read(), re.M 33 | ) 34 | if version_match: 35 | return version_match.group(1) 36 | raise RuntimeError("Unable to find version string.") 37 | 38 | 39 | def package_files(directory): 40 | paths = [] 41 | for path, directories, filenames in os.walk(directory): 42 | for filename in filenames: 43 | paths.append(os.path.join("..", path, filename)) 44 | return paths 45 | 46 | 47 | dejavu_data = package_files("ibm_triton_lib/kernels/dejavu_data/") 48 | 49 | 50 | setup( 51 | name="ibm_triton_lib", 52 | version=find_version(os.path.join(PROJECT_ROOT, "ibm_triton_lib/__init__.py")), 53 | description="Triton-only backend for vLLM and Triton kernel library", 54 | # long_description=read(PROJECT_ROOT, "README.md"), 55 | # long_description_content_type="text/markdown", 56 | # author="Burkhard Ringlein, Tom Parnell, Jan van Lunteren, Chih Chieh Yang", 57 | python_requires=">=3.8", 58 | packages=[ 59 | "ibm_triton_lib", 60 | "ibm_triton_lib.utils", 61 | "ibm_triton_lib.kernels", 62 | "ibm_triton_lib.backend", 63 | "ibm_triton_lib.kernels.legacy", 64 | "ibm_triton_lib.kernels.legacy.fused_gqa_paged", 65 | ], 66 | package_data={ 67 | "ibm_triton_lib": dejavu_data, 68 | }, 69 | include_package_data=True, 70 | entry_points={ 71 | "vllm.platform_plugins": ["triton_attn = ibm_triton_lib.backend:register"] 72 | }, 73 | ) 74 | -------------------------------------------------------------------------------- /scripts/torch_utils.py: -------------------------------------------------------------------------------- 1 | # /******************************************************************************* 2 | # * Copyright 2025 IBM Corporation 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # * you may not use this file except in compliance with the License. 6 | # * You may obtain a copy of the License at 7 | # * 8 | # * http://www.apache.org/licenses/LICENSE-2.0 9 | # * 10 | # * Unless required by applicable law or agreed to in writing, software 11 | # * distributed under the License is distributed on an "AS IS" BASIS, 12 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # * See the License for the specific language governing permissions and 14 | # * limitations under the License. 15 | # *******************************************************************************/ 16 | # 17 | 18 | 19 | import torch 20 | import timeit 21 | 22 | 23 | def get_gpu_label(): 24 | gpu_name = torch.cuda.get_device_name().replace(" ", "_").replace("/", "_") 25 | return gpu_name 26 | 27 | 28 | def pytorch_timer() -> float: 29 | # based on: https://github.com/pytorch/pytorch/blob/main/torch/utils/benchmark/utils/timer.py 30 | torch.cuda.synchronize() 31 | return timeit.default_timer() 32 | 33 | 34 | # TODO: move to triton-dejavu? 35 | def end2end_bench( 36 | fn, warmup=25, rep=100, quantiles=None, return_mode="mean", n_repeat_inner=1 37 | ): 38 | assert return_mode in ["min", "max", "mean", "median"] 39 | # JIT, if necessary 40 | fn() 41 | torch.cuda.synchronize() 42 | 43 | # to clear L2... 44 | cache = torch.empty(int(256e6), dtype=torch.int8, device="cuda") 45 | setup_l = lambda: cache.zero_() 46 | torch.cuda.synchronize() 47 | # setup_l = lambda: cache.zero_() 48 | # stmt_l = lambda: torch.cuda.synchronize(); fn(); torch.cuda.synchronize() 49 | 50 | timer = timeit.Timer(stmt=fn, setup=setup_l, timer=pytorch_timer) 51 | 52 | estimate_ms = (timer.timeit(5) / 5) * 1000 53 | # print(estimate_ms) 54 | n_warmup = max(1, int(warmup / estimate_ms)) 55 | n_repeat = max(1, int(rep / estimate_ms / n_repeat_inner)) 56 | # print(n_warmup) 57 | for _ in range(n_warmup): 58 | # only fn, no cache clear? 59 | # as done in triton.do_bench 60 | fn() 61 | # print(n_repeat) 62 | with torch.no_grad(): 63 | times_f = timer.repeat(repeat=n_repeat, number=n_repeat_inner) 64 | 65 | times_f_ms = [float(f * 1000.0 / n_repeat_inner) for f in times_f] 66 | times = torch.tensor(times_f_ms, dtype=torch.float) 67 | del cache 68 | if quantiles is not None: 69 | ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist() 70 | if len(ret) == 1: 71 | ret = ret[0] 72 | return ret 73 | return getattr(torch, return_mode)(times).item() 74 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | TAG := vllm-triton-backend-$(shell id -un) 2 | MAX_JOBS := 64 3 | 4 | SHELL := /bin/bash 5 | 6 | .PHONY: all build clean format dev rocm rocm-upstream pyupdate nightly bm-rocm spelling 7 | 8 | all: build 9 | 10 | vllm-all.tar: .git/modules/vllm/index 11 | @# cd vllm; git ls-files | xargs tar --mtime='1970-01-01' -cf ../vllm-all.tar 12 | cd vllm; git ls-files > .to-compress; tar -T .to-compress --mtime='1970-01-01 00:00:00' -W -cf ../vllm-all.tar; rm .to-compress 13 | 14 | all-git.tar: .git/HEAD 15 | cd .git; ls -A | xargs tar --mtime='1970-01-01' -cf ../all-git.tar 16 | 17 | ShareGPT_V3_unfiltered_cleaned_split.json: 18 | wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json 19 | 20 | 21 | dev: vllm-all.tar all-git.tar Dockerfile ShareGPT_V3_unfiltered_cleaned_split.json 22 | docker build --progress=plain --build-arg MAX_JOBS=$(MAX_JOBS) --build-arg VLLM_SOURCE=custom . -t ${TAG} 23 | @echo "Built docker image with tag: ${TAG}" 24 | 25 | nightly: Dockerfile.hub ShareGPT_V3_unfiltered_cleaned_split.json 26 | docker build --progress=plain --build-arg MAX_JOBS=$(MAX_JOBS) . 
-t ${TAG}-nightly -f Dockerfile.hub 27 | @echo "Built docker image with tag: ${TAG}-nightly" 28 | 29 | pyupdate: Dockerfile ShareGPT_V3_unfiltered_cleaned_split.json 30 | @echo "This build does only python updates, leaving vllm-all.tar all-git.tar (i.e. the vllm csrc) untouched!" 31 | docker build --progress=plain --build-arg MAX_JOBS=$(MAX_JOBS) --build-arg VLLM_SOURCE=custom . -t ${TAG} 32 | @echo "Built docker image with tag: ${TAG}" 33 | 34 | build: Dockerfile ShareGPT_V3_unfiltered_cleaned_split.json 35 | docker build --progress=plain --build-arg MAX_JOBS=$(MAX_JOBS) . -t ${TAG} 36 | @echo "Built docker image with tag: ${TAG}" 37 | 38 | rocm: Dockerfile.rocm vllm-all.tar all-git.tar ShareGPT_V3_unfiltered_cleaned_split.json 39 | docker build --progress=plain --build-arg MAX_JOBS=$(MAX_JOBS) --build-arg VLLM_SOURCE=submodule . -t ${TAG} -f Dockerfile.rocm 40 | @echo "Built docker image with tag: ${TAG}" 41 | 42 | # bare metal 43 | vllm/venv_rocm: 44 | @#cd vllm && python3 -m venv ./venv_rocm 45 | cd vllm && uv venv venv_rocm --python 3.12 46 | 47 | bm-rocm: | vllm/venv_rocm 48 | export VLLM_TARGET_DEVICE=rocm 49 | cd vllm && source ./venv_rocm/bin/activate && uv pip install -r requirements/rocm-build.txt && uv pip install -e . --no-build-isolation 50 | 51 | vllm/venv_cuda: 52 | cd vllm && uv venv venv_cuda --python 3.12 53 | 54 | bm-cuda: | vllm/venv_cuda 55 | cd vllm && source ./venv_cuda/bin/activate && VLLM_USE_PRECOMPILED=1 uv pip install --editable . 56 | 57 | 58 | clean: 59 | rm -f vllm-all.tar all-git.tar ShareGPT_V3_unfiltered_cleaned_split.json 60 | 61 | ifndef CI_ENABLED 62 | format: 63 | python -m black scripts ibm-triton-lib third_party 64 | else 65 | format: 66 | python -m black --check --verbose scripts ibm-triton-lib third_party 67 | endif 68 | 69 | spelling: 70 | codespell ./ibm-triton-lib ./triton-dejavu ./scripts 71 | 72 | -------------------------------------------------------------------------------- /ibm-triton-lib/ibm_triton_lib/backend/platform.py: -------------------------------------------------------------------------------- 1 | # /******************************************************************************* 2 | # * Copyright 2025 IBM Corporation 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # * you may not use this file except in compliance with the License. 6 | # * You may obtain a copy of the License at 7 | # * 8 | # * http://www.apache.org/licenses/LICENSE-2.0 9 | # * 10 | # * Unless required by applicable law or agreed to in writing, software 11 | # * distributed under the License is distributed on an "AS IS" BASIS, 12 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # * See the License for the specific language governing permissions and 14 | # * limitations under the License. 
15 | # *******************************************************************************/ 16 | # 17 | 18 | import os 19 | from functools import lru_cache, wraps 20 | from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, TypeVar, Union 21 | 22 | import torch 23 | from typing_extensions import ParamSpec 24 | 25 | # import custom ops, trigger op registration 26 | import vllm._C # noqa 27 | import vllm.envs as envs 28 | from vllm.logger import init_logger 29 | 30 | 31 | from vllm.platforms import Platform, PlatformEnum 32 | 33 | if not torch.version.hip: 34 | from vllm.platforms.cuda import CudaPlatform 35 | else: 36 | from vllm.platforms.rocm import RocmPlatform 37 | 38 | 39 | from vllm.platforms.interface import DeviceCapability, Platform, PlatformEnum, _Backend 40 | 41 | if TYPE_CHECKING: 42 | from vllm.config import VllmConfig 43 | else: 44 | VllmConfig = None 45 | 46 | logger = init_logger(__name__) 47 | 48 | 49 | # pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models 50 | # see https://github.com/huggingface/diffusers/issues/9704 for details 51 | torch.backends.cuda.enable_cudnn_sdp(False) 52 | 53 | 54 | if not torch.version.hip: 55 | # CudaPlatform is a constant, not a class, but it dynamically decides between the Nvml and NonNvml class 56 | # so we should inherit from this 57 | class TritonPlatform(CudaPlatform): 58 | 59 | @classmethod 60 | def get_attn_backend_cls( 61 | cls, 62 | selected_backend, 63 | head_size, 64 | dtype, 65 | kv_cache_dtype, 66 | block_size, 67 | use_v1, 68 | use_mla, 69 | ) -> str: 70 | if not envs.VLLM_USE_V1: 71 | raise RuntimeError("vllm-triton-backend plugin only supports vLLM V1") 72 | return "ibm_triton_lib.backend.triton_attn.TritonAttentionBackend" 73 | 74 | else: 75 | 76 | class TritonPlatform(RocmPlatform): 77 | 78 | @classmethod 79 | def get_attn_backend_cls( 80 | cls, 81 | selected_backend, 82 | head_size, 83 | dtype, 84 | kv_cache_dtype, 85 | block_size, 86 | use_v1, 87 | use_mla, 88 | ) -> str: 89 | if not envs.VLLM_USE_V1: 90 | raise RuntimeError("vllm-triton-backend plugin only supports vLLM V1") 91 | return "ibm_triton_lib.backend.triton_attn.TritonAttentionBackend" 92 | -------------------------------------------------------------------------------- /scripts/bench_vllm_user_range.py: -------------------------------------------------------------------------------- 1 | # /******************************************************************************* 2 | # * Copyright 2025 IBM Corporation 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # * you may not use this file except in compliance with the License. 6 | # * You may obtain a copy of the License at 7 | # * 8 | # * http://www.apache.org/licenses/LICENSE-2.0 9 | # * 10 | # * Unless required by applicable law or agreed to in writing, software 11 | # * distributed under the License is distributed on an "AS IS" BASIS, 12 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # * See the License for the specific language governing permissions and 14 | # * limitations under the License.
15 | # *******************************************************************************/ 16 | # 17 | 18 | import os 19 | import sys 20 | import torch 21 | from datetime import datetime 22 | 23 | 24 | def create_dir_if_not_exist_recursive(path, mode=0o777): 25 | norm_path = os.path.normpath(path) 26 | paths_l = norm_path.split(os.sep) 27 | path_walked = f"{os.sep}" 28 | for p in paths_l: 29 | if len(p) == 0: 30 | continue 31 | path_walked = os.path.join(path_walked, p) 32 | create_dir_if_not_exist(path_walked, mode) 33 | 34 | 35 | def create_dir_if_not_exist(path, mode=0o777): 36 | if not os.path.exists(path): 37 | os.mkdir(path) 38 | try: 39 | os.chmod(path, mode) 40 | except PermissionError as e: 41 | print(f"can't set permission of directory {path}: {e}") 42 | 43 | 44 | num_users_to_test = [1, 2, 4, 8, 16, 32, 64, 128] 45 | gpu_name = torch.cuda.get_device_name().replace(" ", "_").replace("/", "_") 46 | 47 | # model = "/model/llama3.1-8b/instruct/" 48 | model = sys.argv[1] 49 | model_path = f"/models/{model}/" 50 | testcase_name = sys.argv[2] 51 | 52 | # max_rounds = 128 53 | max_rounds = 64 54 | max_num_prompts = 1000 55 | 56 | timestamp_f = datetime.now().strftime("%Y-%m-%d_%H%M") 57 | 58 | # result_dir = f"/results/{model.replace('/','-')}/{gpu_name}/{testcase_name}" 59 | result_dir = ( 60 | f"/results/{model.replace('/','-')}/{gpu_name}/{testcase_name}/exp_{timestamp_f}/" 61 | ) 62 | 63 | # os.system(f"mkdir -p {result_dir}") 64 | create_dir_if_not_exist_recursive(result_dir) 65 | 66 | for max_concurrency in num_users_to_test: 67 | num_prompts = ( 68 | max_num_prompts 69 | if max_num_prompts // max_concurrency < max_rounds 70 | else int(max_rounds * max_concurrency) 71 | ) 72 | cmd = ( 73 | f"VLLM_USE_V1=1 python /workspace/benchmarks/benchmark_serving.py " 74 | f"--model {model_path} " 75 | f"--dataset-name sharegpt --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json " 76 | f"--save-result --result-dir {result_dir} --max-concurrency {max_concurrency} " 77 | f"--percentile-metrics ttft,tpot,itl,e2el --metric-percentiles 20,50,80,99 " 78 | f"--num-prompts {num_prompts} " 79 | ) 80 | print(cmd) 81 | rv = os.system(cmd) 82 | if rv != 0: 83 | print(f"benchmark command returned {rv}, stopping...") 84 | break 85 | 86 | print(f"results stored in: {result_dir}") 87 | os.system(f"ls -alh {result_dir}") 88 | -------------------------------------------------------------------------------- /scripts/offline_inference.py: -------------------------------------------------------------------------------- 1 | # /******************************************************************************* 2 | # * Copyright 2025 IBM Corporation 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # * you may not use this file except in compliance with the License. 6 | # * You may obtain a copy of the License at 7 | # * 8 | # * http://www.apache.org/licenses/LICENSE-2.0 9 | # * 10 | # * Unless required by applicable law or agreed to in writing, software 11 | # * distributed under the License is distributed on an "AS IS" BASIS, 12 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # * See the License for the specific language governing permissions and 14 | # * limitations under the License. 
15 | # *******************************************************************************/ 16 | # 17 | 18 | 19 | import os 20 | import time 21 | 22 | # to enable debug printing 23 | # os.environ["TRITON_BACKEND_DEBUG"] = "1" 24 | 25 | # to use triton_attn backend 26 | os.environ["VLLM_USE_V1"] = "1" 27 | os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN_VLLM_V1" 28 | # os.environ["VLLM_TRITON_ENABLE_JITCACHE"] = "1" 29 | os.environ["VLLM_TRITON_ENABLE_JITCACHE"] = "0" 30 | 31 | # enable torch profiler, can also be set on cmd line 32 | # enable_profiling = True 33 | enable_profiling = False 34 | 35 | if enable_profiling: 36 | os.environ["VLLM_TORCH_PROFILER_DIR"] = "./vllm_torch_profile" 37 | 38 | 39 | if __name__ == "__main__": 40 | from vllm import LLM, SamplingParams 41 | from vllm.distributed import cleanup_dist_env_and_memory 42 | 43 | llm = LLM( 44 | model="./models/hf/meta-llama/Llama-3.1-8B-Instruct/main/", 45 | # max_model_len=2048, 46 | # enforce_eager=True, 47 | enable_prefix_caching=False, 48 | ) 49 | 50 | # batch_size = 32 51 | max_tokens = 20 52 | 53 | sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) 54 | # ignore_eos=True) 55 | 56 | prompts = [ 57 | "Zurich is a beautiful city with", 58 | "San Francisco is a large city with", 59 | # "Provide a list of instructions for preparing chicken soup for a family " 60 | # "of four.", 61 | # "Skating and cross country skiing technique differ in", 62 | ] 63 | 64 | print( 65 | f"SETUP: vllm backend: {os.environ['VLLM_ATTENTION_BACKEND']} " 66 | f" JITCache: {os.environ['VLLM_TRITON_ENABLE_JITCACHE']} " 67 | ) 68 | print(f"Inference with {len(prompts)} prompts...") 69 | if enable_profiling: 70 | llm.start_profile() 71 | t0 = time.time() 72 | # outputs = llm.generate(prompts, sampling_params) 73 | outputs = [] 74 | for prompt in prompts: 75 | outputs.append(llm.generate(prompt, sampling_params)) 76 | 77 | if enable_profiling: 78 | llm.stop_profile() 79 | t1 = time.time() 80 | 81 | print(f"inference time: {t1-t0:.5f}s") 82 | 83 | for output in outputs: 84 | output = output[0] # in case of loop above 85 | prompt = output.prompt 86 | generated_text = output.outputs[0].text 87 | print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") 88 | 89 | # Add a buffer to wait for profiler in the background process 90 | # (in case MP is on) to finish writing profiling output. 91 | time.sleep(10) 92 | -------------------------------------------------------------------------------- /scripts/callers/base.py: -------------------------------------------------------------------------------- 1 | # /******************************************************************************* 2 | # * Copyright 2025 IBM Corporation 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # * you may not use this file except in compliance with the License. 6 | # * You may obtain a copy of the License at 7 | # * 8 | # * http://www.apache.org/licenses/LICENSE-2.0 9 | # * 10 | # * Unless required by applicable law or agreed to in writing, software 11 | # * distributed under the License is distributed on an "AS IS" BASIS, 12 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # * See the License for the specific language governing permissions and 14 | # * limitations under the License. 
15 | # *******************************************************************************/ 16 | # 17 | 18 | 19 | class DecodeCaller: 20 | @staticmethod 21 | def make_call_func( 22 | output, 23 | query, 24 | key_cache, 25 | value_cache, 26 | num_seqs, 27 | seq_lens, 28 | max_seq_len, 29 | scale, 30 | block_tables, 31 | alibi_slopes, 32 | kv_cache_dtype, 33 | ): 34 | raise NotImplementedError 35 | 36 | @classmethod 37 | def select_output(cls, x, y): 38 | if cls.requires_allocated_output(): 39 | # default behaviour is in-place 40 | return x 41 | else: 42 | return y 43 | 44 | @staticmethod 45 | def requires_allocated_output() -> bool: 46 | # default behaviour is in-place -> so yes 47 | return True 48 | 49 | 50 | class PrefillCaller: 51 | @staticmethod 52 | def make_call_func( 53 | output, 54 | query, 55 | key_cache, 56 | value_cache, 57 | cu_seqlens_q, 58 | cu_seqlens_k, 59 | max_seqlen_q, 60 | max_seqlen_k, 61 | softmax_scale, 62 | causal, 63 | # kv_cache_dtype, # unused 64 | ): 65 | raise NotImplementedError 66 | 67 | @classmethod 68 | def select_output(cls, x, y): 69 | if cls.requires_allocated_output(): 70 | # default behaviour is in-place 71 | return x 72 | else: 73 | return y 74 | 75 | @staticmethod 76 | def requires_allocated_output() -> bool: 77 | # default behaviour is in-place -> so yes 78 | return True 79 | 80 | 81 | class PrefixPrefillCaller: 82 | @staticmethod 83 | def make_call_func( 84 | output, 85 | query, 86 | key_cache, 87 | value_cache, 88 | key, 89 | value, 90 | block_tables, 91 | seq_lens, 92 | ctx_lens, 93 | query_lens, 94 | start_loc, 95 | seq_start_loc, 96 | softmax_scale, 97 | # kv_cache_dtype, # unused 98 | ): 99 | raise NotImplementedError 100 | 101 | @classmethod 102 | def select_output(cls, x, y): 103 | if cls.requires_allocated_output(): 104 | # default behaviour is in-place 105 | return x 106 | else: 107 | return y 108 | 109 | @staticmethod 110 | def requires_allocated_output() -> bool: 111 | # default behaviour is in-place -> so yes 112 | return True 113 | -------------------------------------------------------------------------------- /scripts/callers/xformers.py: -------------------------------------------------------------------------------- 1 | # /******************************************************************************* 2 | # * Copyright 2025 IBM Corporation 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # * you may not use this file except in compliance with the License. 6 | # * You may obtain a copy of the License at 7 | # * 8 | # * http://www.apache.org/licenses/LICENSE-2.0 9 | # * 10 | # * Unless required by applicable law or agreed to in writing, software 11 | # * distributed under the License is distributed on an "AS IS" BASIS, 12 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # * See the License for the specific language governing permissions and 14 | # * limitations under the License. 
15 | # *******************************************************************************/ 16 | # 17 | 18 | 19 | import torch 20 | from xformers.ops import fmha 21 | from .base import DecodeCaller 22 | 23 | 24 | class XformersCaller(DecodeCaller): 25 | @staticmethod 26 | def make_call_func( 27 | output, 28 | query, 29 | key_cache, 30 | value_cache, 31 | num_seqs, 32 | seq_lens, 33 | max_seq_len, 34 | scale, # unused 35 | block_tables, 36 | alibi_slopes, # unused 37 | kv_cache_dtype, # unused 38 | ): 39 | num_blocks = key_cache.shape[0] 40 | block_size = key_cache.shape[3] 41 | 42 | def transform_kv_cache(x): 43 | assert x.shape[0] == num_blocks 44 | assert x.shape[3] == block_size 45 | 46 | out = torch.empty( 47 | 1, x.shape[0] * x.shape[3], x.shape[1], x.shape[2], dtype=x.dtype 48 | ) 49 | 50 | for block_idx in range(x.shape[0]): 51 | for token_idx in range(x.shape[3]): 52 | out[0, block_idx * x.shape[3] + token_idx, :, :] = x[ 53 | block_idx, :, :, token_idx 54 | ] 55 | 56 | return out 57 | 58 | key_cache_xformers = transform_kv_cache(key_cache) 59 | value_cache_xformers = transform_kv_cache(value_cache) 60 | 61 | block_type = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask 62 | 63 | attn_bias = block_type.from_seqlens( 64 | q_seqlen=[1] * num_seqs, 65 | kv_padding=max_seq_len, 66 | kv_seqlen=seq_lens.tolist(), 67 | ) 68 | 69 | make_paged_kwargs = { 70 | "paged_type": fmha.attn_bias.PagedBlockDiagonalCausalWithOffsetPaddedKeysMask, 71 | } 72 | 73 | attn_bias_paged = attn_bias.make_paged( 74 | block_tables=block_tables, page_size=block_size, **make_paged_kwargs 75 | ) 76 | op = fmha.triton_splitk.FwOp 77 | op.BLOCK_N = block_size 78 | 79 | call_func_under_test = lambda: fmha.memory_efficient_attention_forward( 80 | query.view(1, query.shape[0], query.shape[1], query.shape[2]), 81 | key_cache_xformers, 82 | value_cache_xformers, 83 | attn_bias_paged, 84 | op=op, 85 | ) 86 | 87 | return call_func_under_test 88 | 89 | @classmethod 90 | def select_output(cls, x, y): 91 | return y.view(y.shape[1], y.shape[2], y.shape[3]) 92 | 93 | @staticmethod 94 | def requires_allocated_output() -> bool: 95 | return False 96 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "pytest container", 9 | "type": "debugpy", 10 | "request": "attach", 11 | "connect": { 12 | "host": "172.73.0.73", 13 | //"host": "localhost", 14 | "port": 5679 15 | }, 16 | "pathMappings": [ 17 | { 18 | "localRoot": "${workspaceFolder}/vllm/vllm", 19 | "remoteRoot": "/usr/local/lib/python3.12/dist-packages/vllm/" 20 | }, 21 | { 22 | "localRoot": "${workspaceFolder}/triton/python/triton", 23 | "remoteRoot": "/opt/runtime/lib64/python3.12/site-packages/triton" 24 | }, 25 | { 26 | "localRoot": "${workspaceFolder}/triton/third_party/nvidia/", 27 | "remoteRoot": "/opt/runtime/lib64/python3.12/site-packages//triton/backends/nvidia/" 28 | }, 29 | { 30 | "localRoot": "${workspaceFolder}/scripts", 31 | "remoteRoot": "/scripts" 32 | }, 33 | { 34 | "localRoot": "${workspaceFolder}/ibm-triton-lib/ibm_triton_lib/", 35 | "remoteRoot": "/opt/runtime/lib64/python3.12/site-packages/ibm_triton_lib/" 36 | }, 37 | { 38 | "localRoot": "${workspaceFolder}/triton-dejavu/triton_dejavu/", 39 | "remoteRoot": "/opt/runtime/lib64/python3.12/site-packages/triton_dejavu" 40 | }, 41 | ], 42 | "justMyCode": false, 43 | }, 44 | { 45 | "name": "bare metal vllm", 46 | "type": "debugpy", 47 | "request": "attach", 48 | "connect": { 49 | //"host": "172.73.0.73", 50 | "host": "localhost", 51 | "port": 5679 52 | }, 53 | "pathMappings": [ 54 | { 55 | "localRoot": "${workspaceFolder}/vllm/vllm", 56 | "remoteRoot": "/mnt/nvme5n1p1/zrlngl/fmaas/vllm-triton-backend/vllm/venv_rocm/lib/python3.10/site-packages/vllm-0.1.dev6359+g72d9858.d20250509.rocm624-py3.10-linux-x86_64.egg/vllm/" 57 | }, 58 | { 59 | "localRoot": "${workspaceFolder}/triton/python/triton", 60 | "remoteRoot": "/mnt/nvme5n1p1/zrlngl/fmaas/vllm-triton-backend/vllm/venv_rocm/lib/python3.10/site-packages/triton/" 61 | }, 62 | //{ 63 | // "localRoot": "${workspaceFolder}/triton/third_party/nvidia/", 64 | // "remoteRoot": "/opt/runtime/lib64/python3.12/site-packages//triton/backends/nvidia/" 65 | //}, 66 | //{ 67 | // "localRoot": "${workspaceFolder}/scripts", 68 | // "remoteRoot": "/scripts" 69 | //}, 70 | //{ 71 | // "localRoot": "${workspaceFolder}/ibm-triton-lib/ibm_triton_lib/", 72 | // "remoteRoot": "/opt/runtime/lib64/python3.12/site-packages/ibm_triton_lib/" 73 | //}, 74 | //{ 75 | // "localRoot": "${workspaceFolder}/triton-dejavu/triton_dejavu/", 76 | // "remoteRoot": "/opt/runtime/lib64/python3.12/site-packages/triton_dejavu" 77 | //}, 78 | ], 79 | "justMyCode": false, 80 | } 81 | ] 82 | } -------------------------------------------------------------------------------- /scripts/callers/flashinfer.py: -------------------------------------------------------------------------------- 1 | # /******************************************************************************* 2 | # * Copyright 2025 IBM Corporation 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # * you may not use this file except in compliance with the License. 6 | # * You may obtain a copy of the License at 7 | # * 8 | # * http://www.apache.org/licenses/LICENSE-2.0 9 | # * 10 | # * Unless required by applicable law or agreed to in writing, software 11 | # * distributed under the License is distributed on an "AS IS" BASIS, 12 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # * See the License for the specific language governing permissions and 14 | # * limitations under the License. 
15 | # *******************************************************************************/ 16 | # 17 | 18 | 19 | import torch 20 | import flashinfer 21 | from .base import DecodeCaller 22 | 23 | 24 | class FlashInferCaller(DecodeCaller): 25 | @staticmethod 26 | def make_call_func( 27 | output, 28 | query, 29 | key_cache, 30 | value_cache, 31 | num_seqs, 32 | seq_lens, 33 | max_seq_len, # unused 34 | scale, 35 | block_tables, 36 | alibi_slopes, # unused 37 | kv_cache_dtype, # unused 38 | ): 39 | num_blocks = key_cache.shape[0] 40 | num_query_heads = query.shape[1] 41 | num_kv_heads = key_cache.shape[1] 42 | block_size = key_cache.shape[3] 43 | head_size = key_cache.shape[2] 44 | 45 | def transform_kv_cache(x): 46 | out = torch.transpose(x, 1, 3) 47 | out = torch.transpose(out, 2, 3) 48 | return out.contiguous() 49 | 50 | key_cache_flashinfer = transform_kv_cache(key_cache).unsqueeze(1) 51 | value_cache_flashinfer = transform_kv_cache(value_cache).unsqueeze(1) 52 | 53 | key_value_cache = torch.cat( 54 | (key_cache_flashinfer, value_cache_flashinfer), 1 55 | ).contiguous() 56 | 57 | kv_indptr = [0] 58 | kv_indices = [] 59 | kv_last_page_lens = [] 60 | for i in range(num_seqs): 61 | seq_len = seq_lens[i] 62 | assert seq_len > 0 63 | num_blocks = (seq_len + block_size - 1) // block_size 64 | kv_indices.extend(block_tables[i, :num_blocks]) 65 | kv_indptr.append(kv_indptr[-1] + num_blocks) 66 | kv_last_page_len = seq_len % block_size 67 | if kv_last_page_len == 0: 68 | kv_last_page_len = block_size 69 | kv_last_page_lens.append(kv_last_page_len) 70 | 71 | kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) 72 | kv_indices = torch.tensor(kv_indices, dtype=torch.int32) 73 | kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) 74 | 75 | workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8) 76 | wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( 77 | workspace_buffer, 78 | "NHD", 79 | use_tensor_cores=((num_query_heads // num_kv_heads) > 4), 80 | ) 81 | wrapper.begin_forward( 82 | kv_indptr, 83 | kv_indices, 84 | kv_last_page_lens, 85 | num_query_heads, 86 | num_kv_heads, 87 | head_size, 88 | block_size, 89 | "NONE", 90 | data_type=query.dtype, 91 | ) 92 | 93 | call_func_under_test = lambda: wrapper.forward( 94 | query, key_value_cache, logits_soft_cap=None 95 | ) 96 | 97 | return call_func_under_test 98 | 99 | @staticmethod 100 | def requires_allocated_output() -> bool: 101 | return False 102 | -------------------------------------------------------------------------------- /Dockerfile.hub: -------------------------------------------------------------------------------- 1 | ## Global Args ################################################################# 2 | ARG PYTHON_VERSION=3.12 3 | ARG MAX_JOBS=64 4 | 5 | 6 | ## Runtime ################################################################# 7 | FROM rocm/vllm-dev:nightly AS runtime 8 | 9 | ENV VIRTUAL_ENV=/usr/local/ 10 | ENV PATH="$VIRTUAL_ENV/bin:$PATH" 11 | 12 | # RUN pip install --no-cache -U pip wheel uv 13 | 14 | # # swig is required by triton-dejavu (SMAC optimizer) 15 | # # SWIG rpm not available for RHEL9 16 | # RUN microdnf install -y wget tar zlib-devel automake g++ && microdnf clean all 17 | # RUN wget https://downloads.sourceforge.net/project/swig/swig/swig-3.0.12/swig-3.0.12.tar.gz && \ 18 | # tar -xzf swig-3.0.12.tar.gz && \ 19 | # cd swig-3.0.12 && \ 20 | # bash autogen.sh && \ 21 | # wget https://downloads.sourceforge.net/project/pcre/pcre/8.45/pcre-8.45.tar.gz && \ 22 | # bash Tools/pcre-build.sh 
&& \ 23 | # bash ./configure && \ 24 | # make && \ 25 | # make install 26 | 27 | COPY vllm/vllm /usr/local/bin/python${PYTHON_VERSION}/site-packages/vllm/ 28 | 29 | 30 | WORKDIR /workspace 31 | 32 | # # copy requirements explicitly before to avoid reinstall 33 | # COPY triton-dejavu/requirements-opt.txt dejavu-requirements-opt.txt 34 | # # RUN --mount=type=cache,target=/root/.cache/pip \ 35 | # # --mount=type=cache,target=/root/.cache/uv \ 36 | # # uv pip install -r dejavu-requirements-opt.txt \ 37 | # # && rm -f dejavu-requirements-opt.txt 38 | # RUN pip install -r dejavu-requirements-opt.txt \ 39 | # && rm -f dejavu-requirements-opt.txt 40 | # 41 | # dejavu 42 | COPY triton-dejavu triton-dejavu 43 | # RUN --mount=type=cache,target=/root/.cache/pip \ 44 | # --mount=type=cache,target=/root/.cache/uv \ 45 | # uv pip install ./triton-dejavu/ \ 46 | # && rm -rf ./triton-dejavu/ 47 | RUN pip install ./triton-dejavu/ \ 48 | && rm -rf ./triton-dejavu/ 49 | 50 | # # Install IBM kernels and vllm plugin 51 | # # must be after vllm! 52 | # COPY ibm-triton-lib ibm-triton-lib 53 | # RUN --mount=type=cache,target=/root/.cache/pip \ 54 | # --mount=type=cache,target=/root/.cache/uv \ 55 | # uv pip install ./ibm-triton-lib \ 56 | # && rm -rf ibm-triton-lib 57 | 58 | ## Benchmarking ################################################################# 59 | FROM runtime AS benchmark 60 | 61 | WORKDIR /workspace 62 | 63 | # RUN microdnf install -y git nano gcc vim \ 64 | # && microdnf clean all 65 | 66 | # RUN --mount=type=cache,target=/root/.cache/pip \ 67 | # --mount=type=cache,target=/root/.cache/uv \ 68 | # uv pip install pytest llnl-hatchet debugpy 69 | RUN pip install pytest llnl-hatchet debugpy 70 | 71 | # RUN ln -s ${VIRTUAL_ENV}/lib/python${PYTHON_VERSION}/site-packages/nvidia/cuda_cupti/lib/libcupti.so.12 ${VIRTUAL_ENV}/lib/python${PYTHON_VERSION}/site-packages/nvidia/cuda_cupti/lib/libcupti.so 72 | 73 | # RUN --mount=type=cache,target=/root/.cache/pip \ 74 | # --mount=type=cache,target=/root/.cache/uv \ 75 | # git clone --depth 1 https://github.com/EleutherAI/lm-evaluation-harness && cd lm-evaluation-harness && uv pip install . 76 | RUN git clone --depth 1 https://github.com/EleutherAI/lm-evaluation-harness && cd lm-evaluation-harness && pip install . 77 | 78 | ENV STORE_TEST_RESULT_PATH=/results 79 | 80 | # copy vllm benchmarks 81 | COPY vllm/benchmarks benchmarks 82 | COPY ShareGPT_V3_unfiltered_cleaned_split.json ShareGPT_V3_unfiltered_cleaned_split.json 83 | 84 | # Copy thid-party kernels and insert into path 85 | COPY third_party third_party 86 | ENV PYTHONPATH /workspace 87 | 88 | # see https://github.com/IBM/triton-dejavu?tab=readme-ov-file#environment-variables 89 | ENV TRITON_PRINT_AUTOTUNING=1 90 | ENV TRITON_DEJAVU_DEBUG=1 91 | # set as default 92 | ENV TRITON_DEJAVU_STORAGE=/workspace 93 | ENV NGL_EXP_FALLBACK=next 94 | ENV TRITON_DEJAVU_FORCE_FALLBACK=1 95 | ENV TRITON_DEJAVU_TAG='default' 96 | ENV TRITON_DEJAVU_HASH_SEARCH_PARAMS=0 97 | 98 | # open debugpy port 99 | EXPOSE 5679 100 | 101 | ENTRYPOINT ["python"] 102 | -------------------------------------------------------------------------------- /scripts/bench_vllm_latency_range.py: -------------------------------------------------------------------------------- 1 | # /******************************************************************************* 2 | # * Copyright 2025 IBM Corporation 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # * you may not use this file except in compliance with the License. 
6 | # * You may obtain a copy of the License at 7 | # * 8 | # * http://www.apache.org/licenses/LICENSE-2.0 9 | # * 10 | # * Unless required by applicable law or agreed to in writing, software 11 | # * distributed under the License is distributed on an "AS IS" BASIS, 12 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # * See the License for the specific language governing permissions and 14 | # * limitations under the License. 15 | # *******************************************************************************/ 16 | # 17 | 18 | import os 19 | import json 20 | import sys 21 | import torch 22 | from datetime import datetime 23 | 24 | 25 | def create_dir_if_not_exist_recursive(path, mode=0o777): 26 | norm_path = os.path.normpath(path) 27 | paths_l = norm_path.split(os.sep) 28 | path_walked = f"{os.sep}" 29 | for p in paths_l: 30 | if len(p) == 0: 31 | continue 32 | path_walked = os.path.join(path_walked, p) 33 | create_dir_if_not_exist(path_walked, mode) 34 | 35 | 36 | def create_dir_if_not_exist(path, mode=0o777): 37 | if not os.path.exists(path): 38 | os.mkdir(path) 39 | try: 40 | os.chmod(path, mode) 41 | except PermissionError as e: 42 | print(f"can't set permission of directory {path}: {e}") 43 | 44 | 45 | if len(sys.argv) < 4: 46 | print(f"Usage: {sys.argv[0]} ") 47 | 48 | selected_batch_sizes = [1] # [4, 16, 32] #,128] 49 | selected_input_lengths = [500] # , 1000, 1500, 2000, 4000, 8000, 16000] 50 | selected_output_lengths = [10, 100, 200, 400, 800, 1600, 3200, 6400, 12800] 51 | 52 | gpu_name = torch.cuda.get_device_name().replace(" ", "_").replace("/", "_") 53 | 54 | # model = "/model/llama3.1-8b/instruct/" 55 | model = sys.argv[1] 56 | testcase_name = sys.argv[2] 57 | result_path = os.path.abspath(sys.argv[3]) 58 | 59 | # max_rounds = 128 60 | max_rounds = 64 61 | max_num_prompts = 1000 62 | 63 | timestamp_f = datetime.now().strftime("%Y-%m-%d_%H%M") 64 | 65 | # result_dir = f"/results/{model.replace('/','-')}/{gpu_name}/{testcase_name}" 66 | result_dir = f"{result_path}/{model.replace('/','-')}/{gpu_name}/{testcase_name}/exp_{timestamp_f}/" 67 | 68 | # os.system(f"mkdir -p {result_dir}") 69 | create_dir_if_not_exist_recursive(result_dir) 70 | 71 | bench_script = "/workspace/benchmarks/benchmark_latency.py" 72 | if not os.path.isfile(bench_script): 73 | bench_script = "vllm-triton-backend/vllm/benchmarks/benchmark_latency.py" 74 | if not os.path.isfile(bench_script): 75 | print(f"can't find benchmark script benchmark_latency.py") 76 | exit(-1) 77 | 78 | # Assisted by watsonx Code Assistant 79 | from itertools import zip_longest 80 | 81 | zipped_lists = list( 82 | zip_longest( 83 | selected_batch_sizes, 84 | selected_input_lengths, 85 | selected_output_lengths, 86 | fillvalue=None, 87 | ) 88 | ) 89 | 90 | print(zipped_lists) 91 | 92 | 93 | for bs, il, ol in zipped_lists: 94 | print( 95 | f"====== Measuring batch_size {bs}, input length {il}, output length {ol} =====" 96 | ) 97 | json_file_name = f"{result_dir}/result_bs_{bs}_il_{il}_ol_{ol}.json" 98 | cmd = ( 99 | f"VLLM_USE_V1=1 python {bench_script} " 100 | f"--model {model} " 101 | f"--input-len {il} --output-len {ol} --batch-size {bs} " 102 | f"--output-json {json_file_name}" 103 | ) 104 | print(cmd) 105 | rv = os.system(cmd) 106 | if rv != 0: 107 | print(f"benchmark command returned {rv}, stopping...") 108 | exit(rv) 109 | 110 | print(f"results stored in: {result_dir}") 111 | os.system(f"ls -alh {result_dir}") 112 | 
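# As a rough illustration of what one sweep iteration above issues (values assumed
# from the defaults in this script; angle-bracket placeholders are not real paths):
#
#   VLLM_USE_V1=1 python /workspace/benchmarks/benchmark_latency.py \
#       --model <model> \
#       --input-len 500 --output-len 10 --batch-size 1 \
#       --output-json <result_dir>/result_bs_1_il_500_ol_10.json
#
# Note that zip_longest pads the shorter lists with None, so with the defaults shown
# only the first tuple carries concrete batch-size and input-length values; the
# remaining tuples fill those fields with None.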
-------------------------------------------------------------------------------- /ibm-triton-lib/ibm_triton_lib/kernels/legacy/triton_chunked_prefill_paged_decode.py: -------------------------------------------------------------------------------- 1 | # /******************************************************************************* 2 | # * Copyright 2025 IBM Corporation 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # * you may not use this file except in compliance with the License. 6 | # * You may obtain a copy of the License at 7 | # * 8 | # * http://www.apache.org/licenses/LICENSE-2.0 9 | # * 10 | # * Unless required by applicable law or agreed to in writing, software 11 | # * distributed under the License is distributed on an "AS IS" BASIS, 12 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # * See the License for the specific language governing permissions and 14 | # * limitations under the License. 15 | # *******************************************************************************/ 16 | # 17 | 18 | import triton 19 | 20 | from .triton_prefix_prefill import context_attention_fwd 21 | from .triton_paged_decode_attention_2d import kernel_paged_attention_2d 22 | 23 | 24 | def next_power_of_2(x): 25 | return 1 << (x - 1).bit_length() 26 | 27 | 28 | def chunked_prefill_paged_decode( 29 | query, 30 | key, 31 | value, 32 | output, 33 | kv_cache_dtype, 34 | key_cache, 35 | value_cache, 36 | block_table, 37 | query_start_loc, 38 | seq_lens, 39 | max_query_len, 40 | k_scale, 41 | v_scale, 42 | alibi_slopes, 43 | sliding_window, 44 | scale, 45 | ): 46 | 47 | use_alibi_slopes = alibi_slopes is not None 48 | 49 | context_attention_fwd( 50 | q=query, 51 | k=key, 52 | v=value, 53 | o=output, 54 | kv_cache_dtype=kv_cache_dtype, 55 | k_cache=key_cache, 56 | v_cache=value_cache, 57 | b_loc=block_table, 58 | b_start_loc=query_start_loc, 59 | b_seq_len=seq_lens, 60 | max_input_len=max_query_len, 61 | k_scale=k_scale, 62 | v_scale=v_scale, 63 | alibi_slopes=alibi_slopes, 64 | sliding_window=sliding_window, 65 | sm_scale=scale, 66 | ) 67 | 68 | block_size = value_cache.shape[3] 69 | num_seqs = len(seq_lens) 70 | num_query_heads = query.shape[1] 71 | num_queries_per_kv = query.shape[1] // key.shape[1] 72 | head_size = query.shape[2] 73 | num_queries_per_kv_padded = max(next_power_of_2(num_queries_per_kv), 16) 74 | sliding_window_int = sliding_window if sliding_window is not None else 0 75 | 76 | kernel_paged_attention_2d[ 77 | ( 78 | num_seqs, 79 | num_query_heads, 80 | ) 81 | ]( 82 | output_ptr=output, 83 | query_ptr=query, 84 | key_cache_ptr=key_cache, 85 | value_cache_ptr=value_cache, 86 | block_tables_ptr=block_table, 87 | seq_lens_ptr=seq_lens, 88 | alibi_slopes_ptr=alibi_slopes, 89 | scale=scale, 90 | k_scale=k_scale, 91 | v_scale=v_scale, 92 | num_query_heads=num_query_heads, 93 | num_queries_per_kv=num_queries_per_kv, 94 | num_queries_per_kv_padded=num_queries_per_kv_padded, 95 | block_table_stride=block_table.stride(0), 96 | query_stride_0=query.stride(0), 97 | query_stride_1=query.stride(1), 98 | output_stride_0=output.stride(0), 99 | output_stride_1=output.stride(1), 100 | BLOCK_SIZE=block_size, 101 | HEAD_SIZE=head_size, 102 | HEAD_SIZE_PADDED=next_power_of_2(head_size), 103 | USE_ALIBI_SLOPES=use_alibi_slopes, 104 | SLIDING_WINDOW=sliding_window_int, 105 | x=key_cache.shape[4], 106 | stride_k_cache_0=key_cache.stride(0), 107 | stride_k_cache_1=key_cache.stride(1), 108 | stride_k_cache_2=key_cache.stride(2), 109 | 
stride_k_cache_3=key_cache.stride(3), 110 | stride_k_cache_4=key_cache.stride(4), 111 | stride_v_cache_0=value_cache.stride(0), 112 | stride_v_cache_1=value_cache.stride(1), 113 | stride_v_cache_2=value_cache.stride(2), 114 | stride_v_cache_3=value_cache.stride(3), 115 | filter_by_query_len=True, 116 | query_start_len_ptr=query_start_loc, 117 | ) 118 | -------------------------------------------------------------------------------- /scripts/callers/triton_3d.py: -------------------------------------------------------------------------------- 1 | # /******************************************************************************* 2 | # * Copyright 2025 IBM Corporation 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # * you may not use this file except in compliance with the License. 6 | # * You may obtain a copy of the License at 7 | # * 8 | # * http://www.apache.org/licenses/LICENSE-2.0 9 | # * 10 | # * Unless required by applicable law or agreed to in writing, software 11 | # * distributed under the License is distributed on an "AS IS" BASIS, 12 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # * See the License for the specific language governing permissions and 14 | # * limitations under the License. 15 | # *******************************************************************************/ 16 | # 17 | 18 | 19 | import torch 20 | from ibm_triton_lib.kernels.legacy import paged_attention_3d 21 | from ibm_triton_lib.kernels import prefill_flash_attention 22 | from .base import DecodeCaller, PrefillCaller 23 | 24 | 25 | class Triton3dAttentionDecodeCaller(DecodeCaller): 26 | @staticmethod 27 | def make_call_func( 28 | output, 29 | query, 30 | key_cache, 31 | value_cache, 32 | num_seqs, 33 | seq_lens, 34 | max_seq_len, 35 | scale, 36 | block_tables, 37 | alibi_slopes, 38 | kv_cache_dtype, 39 | ): 40 | num_query_heads = query.shape[1] 41 | num_kv_heads = key_cache.shape[1] 42 | block_size = key_cache.shape[3] 43 | num_queries_per_kv = num_query_heads // num_kv_heads 44 | max_num_blocks_per_seq = block_tables.shape[1] 45 | head_size = key_cache.shape[2] 46 | 47 | # Using default kv_scale 48 | k_scale = v_scale = torch.ones(1, device=query.device) 49 | 50 | call_func_under_test = lambda: paged_attention_3d( 51 | output, 52 | query, 53 | key_cache, 54 | value_cache, 55 | scale, 56 | k_scale, 57 | v_scale, 58 | kv_cache_dtype, 59 | block_tables, 60 | seq_lens, 61 | alibi_slopes, 62 | block_size, 63 | num_seqs, 64 | num_query_heads, 65 | num_queries_per_kv, 66 | head_size, 67 | ) 68 | 69 | return call_func_under_test 70 | 71 | 72 | class Triton3dAttentionPrefillCaller(PrefillCaller): 73 | @staticmethod 74 | def make_call_func( 75 | output, 76 | query, 77 | key_cache, 78 | value_cache, 79 | cu_seqlens_q, 80 | cu_seqlens_k, 81 | max_seqlen_q, 82 | max_seqlen_k, 83 | softmax_scale, 84 | causal, 85 | # kv_cache_dtype, # unused 86 | ): 87 | # with varlen 88 | # q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. 89 | # k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. 90 | # v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. 91 | # cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths 92 | # of the sequences in the batch, used to index into q. 93 | # cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths 94 | # of the sequences in the batch, used to index into kv. 
95 | # max_seqlen_q: int. Maximum query sequence length in the batch. 96 | # max_seqlen_k: int. Maximum key sequence length in the batch. 97 | # out: (total, nheads, headdim). 98 | 99 | def call_and_process_output(): 100 | prefill_flash_attention( 101 | q=query, 102 | k=key_cache, 103 | v=value_cache, 104 | causal=causal, 105 | sm_scale=softmax_scale, 106 | max_seqlen_q=max_seqlen_q, 107 | max_seqlen_k=max_seqlen_k, 108 | cu_seqlens_q=cu_seqlens_q, 109 | cu_seqlens_k=cu_seqlens_k, 110 | in_place_output=output, 111 | do_not_return_softmax_encodings=True, 112 | ) 113 | 114 | return call_and_process_output 115 | -------------------------------------------------------------------------------- /scripts/callers/triton_2d.py: -------------------------------------------------------------------------------- 1 | # /******************************************************************************* 2 | # * Copyright 2025 IBM Corporation 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # * you may not use this file except in compliance with the License. 6 | # * You may obtain a copy of the License at 7 | # * 8 | # * http://www.apache.org/licenses/LICENSE-2.0 9 | # * 10 | # * Unless required by applicable law or agreed to in writing, software 11 | # * distributed under the License is distributed on an "AS IS" BASIS, 12 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # * See the License for the specific language governing permissions and 14 | # * limitations under the License. 15 | # *******************************************************************************/ 16 | # 17 | 18 | 19 | import torch 20 | from ibm_triton_lib.kernels.legacy import ( 21 | paged_attention_2d, 22 | chunked_prefill_paged_decode, 23 | ) 24 | from .base import DecodeCaller, PrefixPrefillCaller 25 | 26 | 27 | class Triton2dAttentionDecodeCaller(DecodeCaller): 28 | @staticmethod 29 | def make_call_func( 30 | output, 31 | query, 32 | key_cache, 33 | value_cache, 34 | num_seqs, 35 | seq_lens, 36 | max_seq_len, 37 | scale, 38 | block_tables, 39 | alibi_slopes, 40 | kv_cache_dtype, 41 | ): 42 | num_query_heads = query.shape[1] 43 | num_kv_heads = key_cache.shape[1] 44 | block_size = key_cache.shape[3] 45 | num_queries_per_kv = num_query_heads // num_kv_heads 46 | max_num_blocks_per_seq = block_tables.shape[1] 47 | head_size = key_cache.shape[2] 48 | 49 | # Using default kv_scale 50 | k_scale = v_scale = torch.ones(1, device=query.device) 51 | 52 | call_func_under_test = lambda: paged_attention_2d( 53 | output, 54 | query, 55 | key_cache, 56 | value_cache, 57 | scale, 58 | k_scale, 59 | v_scale, 60 | kv_cache_dtype, 61 | block_tables, 62 | seq_lens, 63 | alibi_slopes, 64 | block_size, 65 | num_seqs, 66 | num_query_heads, 67 | num_queries_per_kv, 68 | head_size, 69 | ) 70 | 71 | return call_func_under_test 72 | 73 | 74 | class Triton2dChunkedPrefillCaller(PrefixPrefillCaller): 75 | @staticmethod 76 | def make_call_func( 77 | output, 78 | query, 79 | key_cache, 80 | value_cache, 81 | key, 82 | value, 83 | block_tables, 84 | seq_lens, 85 | ctx_lens, 86 | query_lens, 87 | start_loc, 88 | seq_start_loc, 89 | softmax_scale, 90 | # kv_cache_dtype, # unused 91 | ): 92 | """ 93 | query: shape = [num_tokens, num_heads, head_size] 94 | key: shape = [num_tokens, num_kv_heads, head_size] 95 | value: shape = [num_tokens, num_kv_heads, head_size] 96 | k_cache = [num_blocks, block_size, num_kv_heads, head_size] 97 | v_cache = [num_blocks, block_size, num_kv_heads, head_size] 98 | 99 | 
needs to be converted to 100 | K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] 101 | V_cache[num_blocks, num_kv_heads, head_size, block_size] 102 | 103 | Returns: 104 | shape = [num_tokens, num_heads, head_size] 105 | """ 106 | 107 | max_query_len = max(query_lens) 108 | k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=query.device) 109 | 110 | def call_and_process_output(): 111 | return chunked_prefill_paged_decode( 112 | query=query, 113 | key=key, 114 | value=value, 115 | output=output, 116 | kv_cache_dtype="fp16", # TODO 117 | key_cache=key_cache, 118 | value_cache=value_cache, 119 | block_table=block_tables, 120 | query_start_loc=start_loc, 121 | seq_lens=seq_lens, 122 | max_query_len=max_query_len, 123 | k_scale=k_scale, 124 | v_scale=v_scale, 125 | alibi_slopes=None, # TODO 126 | sliding_window=None, # TODO 127 | scale=softmax_scale, 128 | ) 129 | 130 | return call_and_process_output 131 | 132 | @staticmethod 133 | def requires_allocated_output() -> bool: 134 | return True 135 | -------------------------------------------------------------------------------- /scripts/bench_vllm_serve_avg.py: -------------------------------------------------------------------------------- 1 | # /******************************************************************************* 2 | # * Copyright 2025 IBM Corporation 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # * you may not use this file except in compliance with the License. 6 | # * You may obtain a copy of the License at 7 | # * 8 | # * http://www.apache.org/licenses/LICENSE-2.0 9 | # * 10 | # * Unless required by applicable law or agreed to in writing, software 11 | # * distributed under the License is distributed on an "AS IS" BASIS, 12 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # * See the License for the specific language governing permissions and 14 | # * limitations under the License. 
15 | # *******************************************************************************/ 16 | # 17 | 18 | import os 19 | import json 20 | import sys 21 | import torch 22 | from datetime import datetime 23 | 24 | 25 | def create_dir_if_not_exist_recursive(path, mode=0o777): 26 | norm_path = os.path.normpath(path) 27 | paths_l = norm_path.split(os.sep) 28 | path_walked = f"{os.sep}" 29 | for p in paths_l: 30 | if len(p) == 0: 31 | continue 32 | path_walked = os.path.join(path_walked, p) 33 | create_dir_if_not_exist(path_walked, mode) 34 | 35 | 36 | def create_dir_if_not_exist(path, mode=0o777): 37 | if not os.path.exists(path): 38 | os.mkdir(path) 39 | try: 40 | os.chmod(path, mode) 41 | except PermissionError as e: 42 | print(f"can't set permission of directory {path}: {e}") 43 | 44 | 45 | if len(sys.argv) < 5: 46 | print( 47 | f"Usage: {sys.argv[0]} []" 48 | ) 49 | 50 | repitions = int(sys.argv[3]) 51 | gpu_name = torch.cuda.get_device_name().replace(" ", "_").replace("/", "_") 52 | 53 | # model = "/model/llama3.1-8b/instruct/" 54 | model = sys.argv[1] 55 | testcase_name = sys.argv[2] 56 | result_path = os.path.abspath(sys.argv[4]) 57 | port = sys.argv[5] if len(sys.argv) == 6 else "8000" 58 | 59 | # max_rounds = 128 60 | max_rounds = 64 61 | max_num_prompts = 1000 62 | 63 | timestamp_f = datetime.now().strftime("%Y-%m-%d_%H%M") 64 | 65 | # result_dir = f"/results/{model.replace('/','-')}/{gpu_name}/{testcase_name}" 66 | result_dir = f"{result_path}/{model.replace('/','-')}/{gpu_name}/{testcase_name}/exp_{timestamp_f}/" 67 | 68 | # os.system(f"mkdir -p {result_dir}") 69 | create_dir_if_not_exist_recursive(result_dir) 70 | 71 | bench_script = "/workspace/benchmarks/benchmark_serving.py" 72 | if not os.path.isfile(bench_script): 73 | bench_script = "vllm-triton-backend/vllm/benchmarks/benchmark_serving.py" 74 | if not os.path.isfile(bench_script): 75 | print(f"can't find benchmark script benchmark_serving.py") 76 | exit(-1) 77 | 78 | for i in range(repitions): 79 | print(f"====== Repition {i} =====") 80 | cmd = ( 81 | f"VLLM_USE_V1=1 python {bench_script} " 82 | f"--model {model} " 83 | f"--dataset-name sharegpt --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json " 84 | f"--save-result --result-dir {result_dir} " 85 | f"--percentile-metrics ttft,tpot,itl,e2el --metric-percentiles 20,50,80,99 " 86 | f"--port {port}" 87 | ) 88 | print(cmd) 89 | rv = os.system(cmd) 90 | if rv != 0: 91 | print(f"benchmark command returned {rv}, stopping...") 92 | exit(rv) 93 | 94 | print(f"results stored in: {result_dir}") 95 | # os.system(f"ls -alh {result_dir}") 96 | 97 | avg_dict = {"avg_total_token_throughput": 0, "avg_ttft": 0, "avg_itl": 0} 98 | 99 | # Assisted by watsonx Code Assistant 100 | for filename in os.listdir(result_dir): 101 | if filename.endswith(".json"): 102 | file_path = os.path.join(result_dir, filename) 103 | with open(file_path, "r") as file: 104 | try: 105 | data = json.load(file) 106 | # print(f"Loaded data from {filename}:") 107 | avg_dict["avg_total_token_throughput"] += data["total_token_throughput"] 108 | # avg_dict["avg_ttft"] += data["mean_ttft_ms"] 109 | avg_dict["avg_ttft"] += data["median_ttft_ms"] 110 | avg_dict["avg_itl"] += data["median_itl_ms"] 111 | except json.JSONDecodeError as e: 112 | print(f"Error decoding JSON from {filename}: {e}") 113 | 114 | avg_dict["avg_total_token_throughput"] /= repitions 115 | avg_dict["avg_ttft"] /= repitions 116 | avg_dict["avg_itl"] /= repitions 117 | 118 | print(f"\nSummary of {repitions} repitions:") 119 | print( 120 | f"Average of 
total token throughputs: {avg_dict['avg_total_token_throughput']:.2f} tokens/sec" 121 | ) 122 | print(f"Average of Median TTFTs: {avg_dict['avg_ttft']:.2f} ms") 123 | print(f"Average of Median ITLs: {avg_dict['avg_itl']:.2f} ms") 124 | -------------------------------------------------------------------------------- /scripts/callers/baseline_triton.py: -------------------------------------------------------------------------------- 1 | # /******************************************************************************* 2 | # * Copyright 2025 IBM Corporation 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # * you may not use this file except in compliance with the License. 6 | # * You may obtain a copy of the License at 7 | # * 8 | # * http://www.apache.org/licenses/LICENSE-2.0 9 | # * 10 | # * Unless required by applicable law or agreed to in writing, software 11 | # * distributed under the License is distributed on an "AS IS" BASIS, 12 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # * See the License for the specific language governing permissions and 14 | # * limitations under the License. 15 | # *******************************************************************************/ 16 | # 17 | 18 | 19 | import torch 20 | from third_party.vedantroy_paged_attention import paged_attention_triton_v1 21 | from ibm_triton_lib.kernels.legacy.triton_prefix_prefill import context_attention_fwd 22 | from .base import DecodeCaller, PrefixPrefillCaller 23 | 24 | 25 | class BaselineTritonCaller(DecodeCaller): 26 | @staticmethod 27 | def make_call_func( 28 | output, 29 | query, 30 | key_cache, 31 | value_cache, 32 | num_seqs, 33 | seq_lens, 34 | max_seq_len, 35 | scale, 36 | block_tables, 37 | alibi_slopes, # unused 38 | kv_cache_dtype, # unused 39 | ): 40 | num_query_heads = query.shape[1] 41 | num_kv_heads = key_cache.shape[1] 42 | block_size = key_cache.shape[3] 43 | max_num_blocks_per_seq = block_tables.shape[1] 44 | head_size = key_cache.shape[2] 45 | 46 | scratchpad_key = torch.zeros( 47 | (num_seqs, max_seq_len, num_query_heads, head_size), 48 | dtype=output.dtype, 49 | device=output.device, 50 | ) 51 | scratchpad_value = torch.zeros_like(scratchpad_key) 52 | 53 | call_func_under_test = lambda: paged_attention_triton_v1( 54 | output, 55 | query, 56 | key_cache, 57 | value_cache, 58 | scale, 59 | block_tables, 60 | seq_lens, 61 | block_size, 62 | num_seqs, 63 | seq_lens, 64 | num_query_heads, 65 | max_seq_len, 66 | max_num_blocks_per_seq, 67 | head_size, 68 | num_kv_heads, 69 | scratchpad_key, 70 | scratchpad_value, 71 | ) 72 | 73 | return call_func_under_test 74 | 75 | 76 | class BaselineTritonPrefixPrefillCaller(PrefixPrefillCaller): 77 | @staticmethod 78 | def make_call_func( 79 | output, 80 | query, 81 | key_cache, 82 | value_cache, 83 | key, 84 | value, 85 | block_tables, 86 | seq_lens, 87 | ctx_lens, 88 | query_lens, 89 | start_loc, 90 | seq_start_loc, 91 | softmax_scale, 92 | # kv_cache_dtype, # unused 93 | ): 94 | """ 95 | query: shape = [num_tokens, num_heads, head_size] 96 | key: shape = [num_tokens, num_kv_heads, head_size] 97 | value: shape = [num_tokens, num_kv_heads, head_size] 98 | k_cache = [num_blocks, block_size, num_kv_heads, head_size] 99 | v_cache = [num_blocks, block_size, num_kv_heads, head_size] 100 | 101 | needs to be converted to 102 | K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] 103 | V_cache[num_blocks, num_kv_heads, head_size, block_size] 104 | 105 | Returns: 106 | shape = [num_tokens, 
num_heads, head_size] 107 | """ 108 | 109 | head_size = key_cache.shape[3] 110 | block_size = key_cache.shape[1] 111 | num_kv_heads = key_cache.shape[2] 112 | num_blocks = key_cache.shape[0] 113 | 114 | max_query_len = max(query_lens) 115 | k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=query.device) 116 | 117 | def call_and_process_output(): 118 | return context_attention_fwd( 119 | q=query, 120 | k=key, 121 | v=value, 122 | o=output, 123 | kv_cache_dtype="fp16", # TODO 124 | k_cache=key_cache, 125 | v_cache=value_cache, 126 | b_loc=block_tables, 127 | b_start_loc=start_loc, 128 | b_seq_len=seq_lens, 129 | # b_ctx_len=ctx_lens, # FIXME: only in v0.7.3, not in main 130 | max_input_len=max_query_len, 131 | k_scale=k_scale, 132 | v_scale=v_scale, 133 | alibi_slopes=None, # TODO 134 | sliding_window=None, # TODO 135 | sm_scale=softmax_scale, 136 | ) 137 | 138 | return call_and_process_output 139 | 140 | @staticmethod 141 | def requires_allocated_output() -> bool: 142 | return True 143 | -------------------------------------------------------------------------------- /scripts/callers/unified_triton.py: -------------------------------------------------------------------------------- 1 | # /******************************************************************************* 2 | # * Copyright 2025 IBM Corporation 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # * you may not use this file except in compliance with the License. 6 | # * You may obtain a copy of the License at 7 | # * 8 | # * http://www.apache.org/licenses/LICENSE-2.0 9 | # * 10 | # * Unless required by applicable law or agreed to in writing, software 11 | # * distributed under the License is distributed on an "AS IS" BASIS, 12 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # * See the License for the specific language governing permissions and 14 | # * limitations under the License. 
15 | # *******************************************************************************/ 16 | # 17 | 18 | import torch 19 | 20 | from ibm_triton_lib.kernels import unified_attention 21 | from .base import PrefixPrefillCaller 22 | 23 | 24 | class UnifiedTriton3dAttentionCaller(PrefixPrefillCaller): 25 | @staticmethod 26 | def make_call_func( 27 | output, 28 | query, 29 | key_cache, 30 | value_cache, 31 | key, 32 | value, 33 | block_tables, 34 | seq_lens, 35 | ctx_lens, 36 | query_lens, 37 | start_loc, 38 | seq_start_loc, 39 | softmax_scale, 40 | # kv_cache_dtype, # unused 41 | force_selection=3, 42 | ): 43 | """ 44 | query: shape = [num_tokens, num_heads, head_size] 45 | key: shape = [num_tokens, num_kv_heads, head_size] 46 | value: shape = [num_tokens, num_kv_heads, head_size] 47 | k_cache = [num_blocks, block_size, num_kv_heads, head_size] 48 | v_cache = [num_blocks, block_size, num_kv_heads, head_size] 49 | Returns: 50 | shape = [num_tokens, num_heads, head_size] 51 | """ 52 | 53 | max_query_len = query_lens.max() 54 | max_seqlen = seq_lens.max() 55 | 56 | avg_seqlen_q = query_lens.to(torch.float).mean() 57 | avg_seqlen_k = seq_lens.to(torch.float).mean() 58 | 59 | def call_and_process_output(): 60 | # k must have shape (num_blocks, page_block_size, num_heads_k, head_size) 61 | return unified_attention( 62 | q=query, 63 | k=key_cache, 64 | v=value_cache, 65 | out=output, 66 | cu_seqlens_q=start_loc, 67 | max_seqlen_q=max_query_len, 68 | seqused_k=seq_lens, 69 | max_seqlen_k=max_seqlen, 70 | softmax_scale=softmax_scale, 71 | causal=True, 72 | window_size=(-1, -1), 73 | block_table=block_tables, 74 | softcap=0, 75 | q_descale=None, 76 | k_descale=None, # TODO? 77 | v_descale=None, # TODO? 78 | alibi_slopes=None, 79 | avg_seqlen_q=avg_seqlen_q, 80 | avg_seqlen_k=avg_seqlen_k, 81 | force_selection=force_selection, 82 | ) 83 | 84 | return call_and_process_output 85 | 86 | @staticmethod 87 | def requires_allocated_output() -> bool: 88 | return True 89 | 90 | 91 | class UnifiedTriton2dAttentionCaller(UnifiedTriton3dAttentionCaller): 92 | @staticmethod 93 | def make_call_func( 94 | output, 95 | query, 96 | key_cache, 97 | value_cache, 98 | key, 99 | value, 100 | block_tables, 101 | seq_lens, 102 | ctx_lens, 103 | query_lens, 104 | start_loc, 105 | seq_start_loc, 106 | softmax_scale, 107 | # kv_cache_dtype, # unused 108 | force_selection=3, 109 | ): 110 | 111 | return UnifiedTriton3dAttentionCaller.make_call_func( 112 | output, 113 | query, 114 | key_cache, 115 | value_cache, 116 | key, 117 | value, 118 | block_tables, 119 | seq_lens, 120 | ctx_lens, 121 | query_lens, 122 | start_loc, 123 | seq_start_loc, 124 | softmax_scale, 125 | force_selection=2, 126 | ) 127 | 128 | 129 | class UnifiedTritonAutoAttentionCaller(UnifiedTriton3dAttentionCaller): 130 | @staticmethod 131 | def make_call_func( 132 | output, 133 | query, 134 | key_cache, 135 | value_cache, 136 | key, 137 | value, 138 | block_tables, 139 | seq_lens, 140 | ctx_lens, 141 | query_lens, 142 | start_loc, 143 | seq_start_loc, 144 | softmax_scale, 145 | # kv_cache_dtype, # unused 146 | force_selection=3, 147 | ): 148 | 149 | return UnifiedTriton3dAttentionCaller.make_call_func( 150 | output, 151 | query, 152 | key_cache, 153 | value_cache, 154 | key, 155 | value, 156 | block_tables, 157 | seq_lens, 158 | ctx_lens, 159 | query_lens, 160 | start_loc, 161 | seq_start_loc, 162 | softmax_scale, 163 | force_selection=None, 164 | ) # none triggers vllm default behaviour 165 | 
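# The three caller classes above differ only in the force_selection value they
# forward to unified_attention: 3 pins the 3D kernel variant, 2 pins the 2D variant,
# and None defers to vLLM's own selection heuristic (per the trailing comment).
# A minimal usage sketch, assuming tensors laid out as in the docstring above
# (variable names here are illustrative, not taken from this repository):
#
#   caller = UnifiedTriton2dAttentionCaller
#   out = torch.empty_like(query)  # requires_allocated_output() is True
#   fn = caller.make_call_func(out, query, key_cache, value_cache, key, value,
#                              block_tables, seq_lens, ctx_lens, query_lens,
#                              start_loc, seq_start_loc, softmax_scale)
#   result = caller.select_output(out, fn())  # in-place path returns `out`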
-------------------------------------------------------------------------------- /scripts/callers/fused_triton.py: -------------------------------------------------------------------------------- 1 | # /******************************************************************************* 2 | # * Copyright 2025 IBM Corporation 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # * you may not use this file except in compliance with the License. 6 | # * You may obtain a copy of the License at 7 | # * 8 | # * http://www.apache.org/licenses/LICENSE-2.0 9 | # * 10 | # * Unless required by applicable law or agreed to in writing, software 11 | # * distributed under the License is distributed on an "AS IS" BASIS, 12 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # * See the License for the specific language governing permissions and 14 | # * limitations under the License. 15 | # *******************************************************************************/ 16 | # 17 | 18 | 19 | import torch 20 | from ibm_triton_lib.kernels.legacy import fused_chunked_prefill_paged_decode_25d 21 | from .base import PrefixPrefillCaller, DecodeCaller 22 | 23 | 24 | class FusedTritonDecodeOnlyCaller(DecodeCaller): 25 | @staticmethod 26 | def make_call_func( 27 | output, 28 | query, 29 | key_cache, 30 | value_cache, 31 | num_seqs, 32 | seq_lens, 33 | max_seq_len, 34 | scale, 35 | block_tables, 36 | alibi_slopes, 37 | kv_cache_dtype, 38 | ): 39 | """ 40 | query: shape = [num_tokens, num_heads, head_size] 41 | key: shape = [num_tokens, num_kv_heads, head_size] 42 | value: shape = [num_tokens, num_kv_heads, head_size] 43 | kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] 44 | """ 45 | 46 | query_lens = [1] * num_seqs 47 | b_query_lens = torch.tensor(query_lens, dtype=torch.int) 48 | b_start_loc = torch.cumsum( 49 | torch.tensor([0] + query_lens, dtype=torch.int), dim=0, dtype=torch.int 50 | ) 51 | 52 | max_query_len = query_lens.max() 53 | # print(query.shape) 54 | # print(key_cache.shape) 55 | # print(value_cache.shape) 56 | k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=query.device) 57 | 58 | def call_and_process_output(): 59 | return fused_chunked_prefill_paged_decode_25d( 60 | query=query, 61 | key=key_cache, # would break, just here for benchmarking 62 | value=value_cache, # would break, just here for benchmarking 63 | output=output, 64 | kv_cache_dtype=kv_cache_dtype, # TODO 65 | key_cache=key_cache, 66 | value_cache=value_cache, 67 | block_table=block_tables, 68 | query_start_loc=b_start_loc, 69 | seq_lens=seq_lens, 70 | max_query_len=max_query_len, 71 | k_scale=k_scale, 72 | v_scale=v_scale, 73 | alibi_slopes=None, # TODO 74 | sliding_window=None, # TODO 75 | sm_scale=1.0, # would break, just here for benchmarking 76 | ) 77 | 78 | return call_and_process_output 79 | 80 | @staticmethod 81 | def requires_allocated_output() -> bool: 82 | return True 83 | 84 | 85 | class FusedTritonChunkedPrefixPrefill25dCaller(PrefixPrefillCaller): 86 | @staticmethod 87 | def make_call_func( 88 | output, 89 | query, 90 | key_cache, 91 | value_cache, 92 | key, 93 | value, 94 | block_tables, 95 | seq_lens, 96 | ctx_lens, 97 | query_lens, 98 | start_loc, 99 | seq_start_loc, 100 | softmax_scale, 101 | # kv_cache_dtype, # unused 102 | ): 103 | """ 104 | query: shape = [num_tokens, num_heads, head_size] 105 | key: shape = [num_tokens, num_kv_heads, head_size] 106 | value: shape = [num_tokens, num_kv_heads, head_size] 107 | k_cache = 
[num_blocks, block_size, num_kv_heads, head_size] 108 | v_cache = [num_blocks, block_size, num_kv_heads, head_size] 109 | 110 | needs to be converted to 111 | K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] 112 | V_cache[num_blocks, num_kv_heads, head_size, block_size] 113 | 114 | Returns: 115 | shape = [num_tokens, num_heads, head_size] 116 | """ 117 | head_size = key_cache.shape[3] 118 | block_size = key_cache.shape[1] 119 | num_kv_heads = key_cache.shape[2] 120 | 121 | max_query_len = query_lens.max() 122 | print(start_loc) 123 | k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=query.device) 124 | 125 | def call_and_process_output(): 126 | return fused_chunked_prefill_paged_decode_25d( 127 | query=query, 128 | key=key, 129 | value=value, 130 | output=output, 131 | kv_cache_dtype="fp16", # TODO 132 | key_cache=key_cache, 133 | value_cache=value_cache, 134 | block_table=block_tables, 135 | query_start_loc=start_loc, 136 | seq_lens=seq_lens, 137 | max_query_len=max_query_len, 138 | k_scale=k_scale, 139 | v_scale=v_scale, 140 | alibi_slopes=None, # TODO 141 | sliding_window=None, # TODO 142 | sm_scale=softmax_scale, 143 | ) 144 | 145 | return call_and_process_output 146 | 147 | @staticmethod 148 | def requires_allocated_output() -> bool: 149 | return True 150 | -------------------------------------------------------------------------------- /ibm-triton-lib/ibm_triton_lib/utils/triton_utils.py: -------------------------------------------------------------------------------- 1 | # /******************************************************************************* 2 | # * Copyright 2025 IBM Corporation 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # * you may not use this file except in compliance with the License. 6 | # * You may obtain a copy of the License at 7 | # * 8 | # * http://www.apache.org/licenses/LICENSE-2.0 9 | # * 10 | # * Unless required by applicable law or agreed to in writing, software 11 | # * distributed under the License is distributed on an "AS IS" BASIS, 12 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # * See the License for the specific language governing permissions and 14 | # * limitations under the License. 15 | # *******************************************************************************/ 16 | # 17 | 18 | # /******************************************************************************* 19 | # * Copyright 2024 IBM Corporation 20 | # * 21 | # * Licensed under the Apache License, Version 2.0 (the "License"); 22 | # * you may not use this file except in compliance with the License. 23 | # * You may obtain a copy of the License at 24 | # * 25 | # * http://www.apache.org/licenses/LICENSE-2.0 26 | # * 27 | # * Unless required by applicable law or agreed to in writing, software 28 | # * distributed under the License is distributed on an "AS IS" BASIS, 29 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 30 | # * See the License for the specific language governing permissions and 31 | # * limitations under the License. 
32 | # *******************************************************************************/ 33 | # 34 | 35 | ########################################## 36 | # Some utilities for working with triton # 37 | ########################################## 38 | 39 | from __future__ import annotations 40 | 41 | import builtins 42 | import sys 43 | import os 44 | import time 45 | import random 46 | import string 47 | import torch 48 | import triton 49 | 50 | 51 | # from https://github.com/triton-lang/triton/blob/main/third_party/proton/tutorials/matmul.py (also apache 2.0) 52 | def unpack_grid(grid): 53 | if len(grid) == 1: 54 | return grid[0], 1, 1 55 | if len(grid) == 2: 56 | return grid[0], grid[1], 1 57 | if len(grid) == 3: 58 | return grid[0], grid[1], grid[2] 59 | 60 | 61 | cuda_version = None 62 | rocm_version = None 63 | flag_print_debug = False 64 | 65 | 66 | def _get_cuda_version(): 67 | """ 68 | Get CUDA runtime/driver version (i.e. which ptxas is used). 69 | This version is often different from the cuda version pytorch uses internally. 70 | 71 | Based on https://github.com/triton-lang/triton/blob/9d6736a501d0499348d48d192b6260338ca19da0/third_party/nvidia/backend/compiler.py#L32-L37 72 | """ 73 | global cuda_version 74 | if cuda_version is not None: 75 | return cuda_version 76 | if "_TRITON_DEJAVU_DETERMINED_CUDA_VERSION" in os.environ: 77 | cuda_version = os.environ["_TRITON_DEJAVU_DETERMINED_CUDA_VERSION"] 78 | return cuda_version 79 | try: 80 | import subprocess 81 | import re 82 | 83 | triton_backend_dir = os.path.dirname(triton.backends.__file__) 84 | ptxas_path = os.path.abspath( 85 | os.path.join(triton_backend_dir, "nvidia/bin/ptxas") 86 | ) 87 | 88 | result = subprocess.check_output( 89 | [ptxas_path, "--version"], stderr=subprocess.STDOUT 90 | ) 91 | version = re.search( 92 | r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE 93 | ) 94 | cuda_version = version.group(1) 95 | except Exception as e: 96 | if flag_print_debug: 97 | print( 98 | f"[triton-dejavu] determining cuda version failed with: {e}\n" 99 | f"using torch.version.cuda as fallback" 100 | ) 101 | cuda_version = f"torch_{torch.version.cuda}" 102 | os.environ["_TRITON_DEJAVU_DETERMINED_CUDA_VERSION"] = cuda_version 103 | return cuda_version 104 | 105 | 106 | def _get_rocm_version(): 107 | """ 108 | Get ROCM runtime/driver version (i.e. which rocm linker is used). 109 | This version is often different from the rocm version pytorch uses internally. 
110 | """ 111 | global rocm_version 112 | if rocm_version is not None: 113 | return rocm_version 114 | if "_TRITON_DEJAVU_DETERMINED_ROCM_VERSION" in os.environ: 115 | rocm_version = os.environ["_TRITON_DEJAVU_DETERMINED_ROCM_VERSION"] 116 | return rocm_version 117 | try: 118 | import subprocess 119 | import re 120 | 121 | rocm_ldd_path = triton.backends.backends["amd"].compiler.path_to_rocm_lld() 122 | rocm_dir = os.path.dirname(rocm_ldd_path) 123 | amdgpu_arch_path = os.path.abspath(os.path.join(rocm_dir, "amdgpu-arch")) 124 | 125 | result = subprocess.check_output( 126 | [amdgpu_arch_path, "--version"], 127 | stderr=subprocess.STDOUT, 128 | ) 129 | version = re.search( 130 | r".*roc-(\d+\.\d+.\d+).*", result.decode("utf-8"), flags=re.MULTILINE 131 | ) 132 | rocm_version = version.group(1) 133 | except Exception as e: 134 | if flag_print_debug: 135 | print( 136 | f"[triton-dejavu] determining rocm version failed with: {e}\n" 137 | f"using torch.version.hip as fallback" 138 | ) 139 | rocm_version = f"torch_{torch.version.hip}" 140 | os.environ["_TRITON_DEJAVU_DETERMINED_ROCM_VERSION"] = rocm_version 141 | return rocm_version 142 | 143 | 144 | def get_runtime_label(): 145 | if torch.version.hip: 146 | return f"rocm_{_get_rocm_version()}" 147 | return f"cuda_{_get_cuda_version()}" 148 | -------------------------------------------------------------------------------- /scripts/callers/pytorch_native.py: -------------------------------------------------------------------------------- 1 | # /******************************************************************************* 2 | # * Copyright 2025 IBM Corporation 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # * you may not use this file except in compliance with the License. 6 | # * You may obtain a copy of the License at 7 | # * 8 | # * http://www.apache.org/licenses/LICENSE-2.0 9 | # * 10 | # * Unless required by applicable law or agreed to in writing, software 11 | # * distributed under the License is distributed on an "AS IS" BASIS, 12 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # * See the License for the specific language governing permissions and 14 | # * limitations under the License. 
15 | # *******************************************************************************/ 16 | # 17 | 18 | 19 | import torch 20 | from .base import PrefillCaller 21 | 22 | 23 | # based on https://github.com/pytorch/pytorch/blob/6055a4f612782ca944f2e0465f7497b7f18de4e9/torch/nn/functional.py#L5732 24 | def scaled_dot_product_attention( 25 | query, 26 | key, 27 | value, 28 | scale_factor, 29 | attn_mask=None, 30 | dropout_p=0.0, 31 | is_causal=False, 32 | enable_gqa=False, 33 | ) -> torch.Tensor: 34 | L, S = query.size(-2), key.size(-2) 35 | attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) 36 | if is_causal: 37 | assert attn_mask is None 38 | temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) 39 | attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) 40 | attn_bias.to(query.dtype) 41 | 42 | if attn_mask is not None: 43 | if attn_mask.dtype == torch.bool: 44 | attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) 45 | else: 46 | attn_bias = attn_mask + attn_bias 47 | 48 | if enable_gqa: 49 | key = key.repeat_interleave(query.size(-3) // key.size(-3), -3) 50 | value = value.repeat_interleave(query.size(-3) // value.size(-3), -3) 51 | 52 | attn_weight = query @ key.transpose(-2, -1) * scale_factor 53 | attn_weight += attn_bias 54 | attn_weight = torch.softmax(attn_weight, dim=-1) # .to(query.dtype) 55 | attn_weight = torch.dropout(attn_weight, dropout_p, train=True) 56 | return attn_weight @ value 57 | 58 | 59 | class PytorchNativeAttentionPrefillCaller(PrefillCaller): 60 | @staticmethod 61 | def make_call_func( 62 | output, 63 | query, 64 | key_cache, 65 | value_cache, 66 | cu_seqlens_q, 67 | cu_seqlens_k, 68 | max_seqlen_q, 69 | max_seqlen_k, 70 | softmax_scale, 71 | causal, 72 | # kv_cache_dtype, # unused 73 | ): 74 | # with varlen 75 | # q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. 76 | # k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. 77 | # v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. 78 | # cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths 79 | # of the sequences in the batch, used to index into q. 80 | # cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths 81 | # of the sequences in the batch, used to index into kv. 82 | # max_seqlen_q: int. Maximum query sequence length in the batch. 83 | # max_seqlen_k: int. Maximum key sequence length in the batch. 84 | # out: (total, nheads, headdim). 
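# A small worked example of the varlen layout (values assumed for illustration,
# not taken from the benchmark): with two sequences of 3 and 5 query tokens,
# cu_seqlens_q == tensor([0, 3, 8], dtype=torch.int32), and sequence i occupies
# rows cu_seqlens_q[i]:cu_seqlens_q[i+1] of the packed (total_q, nheads, headdim)
# tensor. The code below unpacks this packed layout into dense per-sequence
# tensors of shape (num_seqs, nheads, max_seqlen, headdim) so that the padded
# scaled_dot_product_attention defined above can be applied.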
85 | 86 | print(query.shape) 87 | print(key_cache.shape) 88 | 89 | num_query_heads = query.shape[1] 90 | num_kv_heads = key_cache.shape[1] 91 | head_size = key_cache.shape[2] 92 | dtype = value_cache.dtype 93 | tdevice = value_cache.device 94 | 95 | num_queries_per_kv = num_query_heads // num_kv_heads 96 | enable_gqa = False 97 | if num_queries_per_kv > 1: 98 | # Handle MQA and GQA 99 | enable_gqa = True 100 | # key = torch.repeat_interleave(key, num_queries_per_kv, dim=1) 101 | # value = torch.repeat_interleave(value, num_queries_per_kv, dim=1) 102 | 103 | num_seqs = len(cu_seqlens_q) - 1 104 | q_len = max_seqlen_q 105 | max_seq_len = max_seqlen_k 106 | 107 | # print(max_seqlen_q) 108 | # print(max_seqlen_k) 109 | 110 | query_torch = torch.empty( 111 | num_seqs, num_query_heads, q_len, head_size, dtype=dtype, device=tdevice 112 | ) 113 | key_torch = torch.empty( 114 | num_seqs, num_kv_heads, max_seq_len, head_size, dtype=dtype, device=tdevice 115 | ) 116 | value_torch = torch.empty( 117 | num_seqs, num_kv_heads, max_seq_len, head_size, dtype=dtype, device=tdevice 118 | ) 119 | 120 | # print(query_torch.shape) 121 | # print(key_torch.shape) 122 | 123 | for i in range(num_seqs): 124 | start_idx = cu_seqlens_q[i] 125 | end_idx = cu_seqlens_q[i + 1] 126 | seq_len = end_idx - start_idx 127 | # print(f"{start_idx} to {end_idx} ({seq_len} tokens)") 128 | 129 | query_torch[i].copy_(query[start_idx:end_idx].transpose(0, 1)) 130 | key_torch[i].copy_(key_cache[start_idx:end_idx].transpose(0, 1)) 131 | value_torch[i].copy_(value_cache[start_idx:end_idx].transpose(0, 1)) 132 | # TODO: fill with 0? 133 | # no, compare varlen with it would be unfair, IMHO 134 | 135 | def call_and_process_output(): 136 | return scaled_dot_product_attention( 137 | query=query_torch, 138 | key=key_torch, 139 | value=value_torch, 140 | is_causal=causal, 141 | enable_gqa=enable_gqa, 142 | scale_factor=softmax_scale, 143 | ) 144 | 145 | return call_and_process_output 146 | 147 | @classmethod 148 | def select_output(cls, x, y): 149 | # in: (num_seqs, num_query_heads, q_len, head_size) 150 | # out: (total, nheads, headdim) 151 | print(y.shape) 152 | num_seqs, num_query_heads, q_len, head_size = y.shape 153 | out = y.transpose(1, 2).reshape(-1, num_query_heads, head_size) 154 | print(out.shape) 155 | # return y.squeeze(1) 156 | return out 157 | 158 | @staticmethod 159 | def requires_allocated_output() -> bool: 160 | return False 161 | -------------------------------------------------------------------------------- /ibm-triton-lib/ibm_triton_lib/kernels/legacy/fused_gqa_paged/sb_jit_func.py: -------------------------------------------------------------------------------- 1 | import triton.language as tl 2 | import triton 3 | 4 | # Some of the functions in this file are adapted from Shawn Tan's stickbreaking-attention repo 5 | # https://github.com/shawntan/stickbreaking-attention/blob/seq-aligned-folded/stickbreaking_attention/sb_varlen/softplus.py 6 | 7 | 8 | def _generate_asm(num_pack): 9 | template = """ 10 | .reg .pred p; 11 | setp.gt.f32 p, ${in_reg}, 15.; 12 | @p mov.f32 ${out_reg}, ${in_reg}; 13 | @!p ex2.approx.ftz.f32 ${out_reg}, ${in_reg}; 14 | @!p add.f32 ${out_reg}, ${out_reg}, 1.0; 15 | @!p lg2.approx.ftz.f32 ${out_reg}, ${out_reg}; 16 | """ 17 | out_str = "" 18 | 19 | for i in range(num_pack): 20 | inner_str = template.format(out_reg=i, in_reg=i + num_pack) 21 | out_str += "{" + inner_str + "}\n" 22 | # flatten out because torch.compile doesn't like newlines 23 | out_str = " ".join(out_str.split("\n")) 24 | return 
out_str 25 | 26 | 27 | def _generate_constraints(num_pack): 28 | return ( 29 | ",".join("=r" for i in range(num_pack)) 30 | + "," 31 | + ",".join("r" for i in range(num_pack)) 32 | ) 33 | 34 | 35 | NUM_REG: tl.constexpr = 1 36 | asm_str: tl.constexpr = _generate_asm(NUM_REG) 37 | constraints_str: tl.constexpr = _generate_constraints(NUM_REG) 38 | 39 | 40 | @triton.jit 41 | def softplus(x, is_compiling: tl.constexpr = False): 42 | if is_compiling: 43 | tl.static_print("Using triton softplus.") 44 | out = tl.where(x < 15.0, tl.math.log2(1 + tl.math.exp2(x)), x) 45 | return out 46 | else: 47 | out = tl.inline_asm_elementwise( 48 | asm=asm_str, 49 | constraints=constraints_str, 50 | pack=NUM_REG, 51 | args=[ 52 | x, 53 | ], 54 | dtype=tl.float32, 55 | is_pure=True, 56 | ) 57 | return out 58 | 59 | 60 | @triton.jit 61 | def cumsum(x, block_range=None, USE_DOT_CUMSUM: tl.constexpr = False): 62 | if USE_DOT_CUMSUM: 63 | cm = tl.where( 64 | block_range[:, None] >= block_range[None, :], 1.0, 0.0 65 | ) # lower triangular matrix 66 | return tl.dot(x, cm) 67 | else: 68 | return tl.cumsum(x, axis=1, reverse=True) 69 | 70 | 71 | @triton.jit 72 | def get_split_tblocks_range(split_idx, kv_len, BLOCK_T, num_splits): 73 | num_tblocks = (kv_len + BLOCK_T - 1) // BLOCK_T 74 | tblock_start = (split_idx * num_tblocks) // num_splits 75 | tblock_end = ((split_idx + 1) * num_tblocks) // num_splits 76 | return tblock_start, tblock_end 77 | 78 | 79 | @triton.jit 80 | def attend_one_block( 81 | q, 82 | k, 83 | v, 84 | qk_scale, 85 | m_i, 86 | d_i, 87 | acc, 88 | alibi_slopes, # [BLOCK_SS,] 89 | alibi_distances, # [BLOCK_T,] 90 | IS_LAST_BLOCK, # on_band == IS_LAST_BLOCK, dynamic 91 | tb_len_max, # Number of tokens along page size (token) dimension. 0 < t_len <= BLOCK_T and it is a dynamic value 92 | offs_t: tl.constexpr, 93 | FORCE_FP16_PV: tl.constexpr, 94 | QUANTIZE_P: tl.constexpr, 95 | MAX_FP8: tl.constexpr, 96 | IS_STICKBREAKING: tl.constexpr, 97 | USE_DOT_CUMSUM: tl.constexpr, 98 | TRANSPOSED: tl.constexpr, 99 | USE_ALIBI_SLOPES: tl.constexpr, 100 | ATTEND_CURRENT: tl.constexpr, 101 | ): 102 | kv_len_dim: tl.constexpr = 1 if not TRANSPOSED else 0 # seqlen dimension 103 | 104 | # Compute logits 105 | if not TRANSPOSED: 106 | k = k.T # [D, BLOCK_T] 107 | logits = tl.dot(q, k, out_dtype=tl.float32) # [BLOCK_SS, BLOCK_T] 108 | else: 109 | q = q.T # [D, BLOCK_SS] 110 | logits = tl.dot(k, q, out_dtype=tl.float32) # [BLOCK_T, BLOCK_SS] 111 | 112 | logits *= qk_scale # scale qk after mma 113 | 114 | if USE_ALIBI_SLOPES: 115 | alibi_biases = ( 116 | alibi_slopes[:, None] * alibi_distances[None, :] 117 | ) # [BLOCK_SS, BLOCK_T] 118 | logits += alibi_biases if not TRANSPOSED else alibi_biases.T 119 | 120 | # Handle on band block special case 121 | t_mask = offs_t < tb_len_max 122 | if IS_LAST_BLOCK: 123 | if not IS_STICKBREAKING: 124 | t_mask_logits = t_mask[None, :] if not TRANSPOSED else t_mask[:, None] 125 | logits += tl.where(t_mask_logits, 0.0, float("-inf")) 126 | else: 127 | # v = tl.where(t_mask[:, None], v, 0.0) 128 | t_mask = offs_t < (tb_len_max if ATTEND_CURRENT else (tb_len_max - 1)) 129 | 130 | if not IS_STICKBREAKING: # regular softmax 131 | # -- compute scaling constant -- 132 | m_i_new = tl.maximum( 133 | m_i, tl.max(logits, axis=kv_len_dim) 134 | ) # fp32, new max computation 135 | 136 | alpha = tl.math.exp2(m_i - m_i_new) # fp32, S4 (subtract new max from old max) 137 | p = tl.math.exp2( 138 | logits 139 | - ( 140 | m_i_new[:, None] # fp32, subtract current max # [BLOCK_SS, BLOCK_T] 141 | if not 
TRANSPOSED 142 | else m_i_new[None, :] 143 | ) 144 | ) 145 | 146 | # -- scale numerator --- 147 | acc *= alpha[:, None] if not TRANSPOSED else alpha[None, :] # fp32 elmentwise 148 | # --- update m_i (max) and d_i (denominator) -- 149 | m_i = m_i_new # S2 150 | d_i = d_i * alpha + tl.sum(p, axis=kv_len_dim) # S3 151 | else: # stickbreaking attention 152 | # computations in log space 153 | log_om_beta = -softplus( 154 | logits, 155 | ) # [BLOCK_SS, BLOCK_T] or [BLOCK_T, BLOCK_SS] 156 | 157 | if TRANSPOSED: 158 | log_om_beta = log_om_beta.T # [BLOCK_SS, BLOCK_T] 159 | logits = logits.T 160 | 161 | if IS_LAST_BLOCK: # on_band 162 | log_om_beta = tl.where(t_mask[None, :], log_om_beta, 0.0) 163 | 164 | log_p = logits + d_i[:, None] # [BLOCK_SS, BLOCK_T] # d_i is neg_log_acc 165 | d_i += tl.sum(log_om_beta, axis=1) # [BLOCK_SS] 166 | log_p += cumsum(log_om_beta, block_range=offs_t, USE_DOT_CUMSUM=USE_DOT_CUMSUM) 167 | 168 | # non-log space 169 | p = tl.math.exp2(log_p) # [BLOCK_SS, BLOCK_T] 170 | 171 | if IS_LAST_BLOCK: # on_band 172 | p = tl.where(t_mask[None, :], p, 0.0) # set masked elements to 0 173 | 174 | if TRANSPOSED: 175 | p = p.T # [BLOCK_T, BLOCK_SS] 176 | 177 | p_scale = 1.0 178 | if FORCE_FP16_PV: 179 | # force fp16 for the 2nd bmm 180 | v = v.to(tl.float16) 181 | else: 182 | # align p with v.dtype for the 2nd bmm 183 | if QUANTIZE_P and v.dtype == tl.float8e4nv: 184 | tl.static_assert( 185 | not IS_STICKBREAKING 186 | ) # in stickbreaking p tensor values can become too small 187 | # --- dynamic quantization of p --- 188 | p_max = tl.max(tl.abs(p), axis=kv_len_dim, keep_dims=True) 189 | p_scale = p_max / MAX_FP8 190 | p_invs_scale = 1.0 / p_scale 191 | p = p * p_invs_scale # fp32 192 | p = p.to(v.dtype) 193 | 194 | if not TRANSPOSED: 195 | acc += tl.dot(p, v, out_dtype=tl.float32) * p_scale # [BLOCK_SS, D] 196 | else: 197 | acc += tl.dot(v.T, p, out_dtype=tl.float32) * p_scale # [D, BLOCK_SS] 198 | 199 | return m_i, d_i, acc 200 | -------------------------------------------------------------------------------- /scripts/profile_and_bench.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Optional, Tuple, Union 3 | 4 | import pytest 5 | import torch 6 | import triton 7 | import pandas as pd 8 | import numpy as np 9 | import math 10 | import random 11 | from datetime import datetime 12 | from enum import Enum 13 | import itertools 14 | 15 | 16 | from benchmark import ( 17 | test_decode_attention, 18 | create_dir_if_not_exist, 19 | create_dir_if_not_exist_recursive, 20 | write_df_and_chmod, 21 | get_runtime_label, 22 | Implementation, 23 | BenchmarkMode, 24 | impl_translate, 25 | get_gpu_label, 26 | method_translate, 27 | ) 28 | 29 | import re 30 | import string 31 | 32 | 33 | # from https://github.com/pytorch/pytorch/issues/121219#issuecomment-2722329465 34 | def clean_names_in_json(input_filename, output_filename): 35 | """ 36 | Cleans the "name" fields in a JSON file by replacing non-ASCII characters with 'x' 37 | and removing internal quotation marks. 
38 | 39 | Example of problematic input: 40 | { 41 | "name": "@"�sP(0): flat_tensor" 42 | } 43 | """ 44 | with open(input_filename, "r", encoding="utf-8", errors="replace") as file: 45 | content = file.read() 46 | 47 | # Decode Unicode escape sequences 48 | content = content.encode().decode("unicode_escape") 49 | 50 | # Regex to find "name": "" 51 | def replace_non_ascii_and_quotes(match): 52 | name = match.group(1) 53 | visible_printable = "".join( 54 | c for c in string.printable if c not in "\t\n\r\x0b\x0c}{" 55 | ) 56 | cleaned_name = "".join(c if c in visible_printable else "x" for c in name) 57 | cleaned_name = cleaned_name.replace('"', "y") # Replace internal quotes 58 | return f'"name": "{cleaned_name}"' 59 | 60 | # Apply regex to clean names 61 | cleaned_content = re.sub( 62 | r'"name": "([\s\S]*?)"(?=, |\}|\s*\})', 63 | replace_non_ascii_and_quotes, 64 | content, 65 | flags=re.DOTALL, 66 | ) 67 | 68 | # Write the cleaned JSON data to a new file 69 | with open(output_filename, "w", encoding="utf-8") as outfile: 70 | outfile.write(cleaned_content) 71 | 72 | 73 | device = "cuda:0" 74 | gpu_name = get_gpu_label() 75 | 76 | do_benchmarks = True 77 | quantiles = [0.5, 0.2, 0.8] 78 | debug_flag = os.getenv("TRITON_BACKEND_DEBUG") == "1" 79 | 80 | 81 | # DTYPES = [torch.half, torch.bfloat16, torch.float] 82 | DTYPES = [torch.float16] 83 | SEEDS = [0] 84 | MAX_VALUES = [1.0] 85 | STORE_TEST_RESULT_PATH = os.environ.get("STORE_TEST_RESULT_PATH", None) 86 | # HEAD_SIZES_FLASH = [32, 64, 128] # only powers of 2! 87 | HEAD_SIZES = [128] # only powers of 2! for llama2 & 3 88 | # head_size * head_numbers = hidden_size 89 | 90 | # order: num_query_heads, num_kv_heads 91 | # NUM_HEADS = [(32, 32), (32, 8)] 92 | NUM_HEADS = [(32, 8)] 93 | 94 | # BLOCK_SIZES = [8, 16, 32] 95 | BLOCK_SIZES = [16] 96 | # NUM_BLOCKS = [8, 16, 32] 97 | NUM_BLOCKS = [4321] # "arbitrary values for testing..." 98 | 99 | # options most likely not used...but keep for now? 
100 | CAUSAL_FLASH = [True] # vLLM only needs causal=True 101 | 102 | PROMPT_PATTERNS = [] 103 | # PROMPT_PATTERNS.append([1.0]) 104 | # PROMPT_PATTERNS.append([1.0, 0.4, 0.5, 1.0, 0.2]) 105 | PROMPT_PATTERNS.append([0.1, 0.4, 0.5, 1.0, 0.2]) 106 | 107 | BATCH_SIZES = [4] 108 | SEQUENCE_LENGTHS = [128] 109 | 110 | 111 | MY_IUT = [ 112 | e for e in os.environ.get("MY_IUT", "").split(",") if len(e) > 0 113 | ] # my implementations under test (IUT) 114 | MY_METHODS = [e for e in os.environ.get("MY_METHODS", "").split(",") if len(e) > 0] 115 | 116 | if len(MY_IUT) > 0: 117 | IMPLEMENTATION_UT = [] 118 | for ci_value in MY_IUT: 119 | IMPLEMENTATION_UT.append(Implementation(impl_translate[ci_value])) 120 | if len(MY_METHODS) > 0: 121 | BENCHMARK_MODES = [] 122 | for cb_value in MY_METHODS: 123 | BENCHMARK_MODES.append(BenchmarkMode(method_translate[cb_value])) 124 | 125 | 126 | if __name__ == "__main__": 127 | if os.environ.get("TRITON_BACKEND_PDB", "0") == "1": 128 | import debugpy 129 | 130 | host_addr = os.environ.get("TRITON_BACKEND_DEBUG_ADDR", "0.0.0.0") 131 | pdb_port = int(os.environ.get("TRITON_BACKEND_DEBUG_PORT", "5679")) 132 | debugpy.listen((host_addr, pdb_port)) 133 | print(f"[debugpy] listening at {host_addr}:{pdb_port}; wait for client...\n") 134 | debugpy.wait_for_client() 135 | 136 | cuda_version = get_runtime_label() 137 | print( 138 | f"\nRunning on {gpu_name} with Triton {triton.__version__} using {cuda_version}...\n" 139 | ) 140 | 141 | print( 142 | f"Test setup:\n\tIMPLEMENATION_UT: {IMPLEMENTATION_UT}\n\tBENCHMARK_MODES: {BENCHMARK_MODES}" 143 | ) 144 | 145 | global_pds = {} 146 | timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 147 | 148 | if STORE_TEST_RESULT_PATH is not None: 149 | gpu_path = os.path.join(STORE_TEST_RESULT_PATH, gpu_name) 150 | gloabl_pd_file_prefix = os.path.join(gpu_path, timestamp) 151 | create_dir_if_not_exist_recursive(gloabl_pd_file_prefix) 152 | else: 153 | print("STORE_TEST_RESULT_PATH is not set; results will not be saved") 154 | gloabl_pd_file_prefix = None 155 | 156 | global_pds = {} 157 | start_time = datetime.now() 158 | 159 | for bench_m in BENCHMARK_MODES: 160 | for impl in IMPLEMENTATION_UT: 161 | prof_filename = f"{gloabl_pd_file_prefix}/trace_{bench_m}-{impl}.json" 162 | test_decode_attention( 163 | None, 164 | None, 165 | BATCH_SIZES[0], 166 | NUM_HEADS[0], 167 | SEQUENCE_LENGTHS[0], 168 | HEAD_SIZES[0], 169 | BLOCK_SIZES[0], 170 | NUM_BLOCKS[0], 171 | PROMPT_PATTERNS[0], 172 | DTYPES[0], 173 | SEEDS[0], 174 | impl, 175 | MAX_VALUES[0], 176 | bench_m, 177 | overwrite_df=global_pds, 178 | df_file_prefix=gloabl_pd_file_prefix, 179 | torch_profiling=True, 180 | prof_filename=f"{prof_filename}-broken", 181 | ) 182 | clean_names_in_json(f"{prof_filename}-broken", prof_filename) 183 | print(f"profile stored in: {os.path.abspath(prof_filename)}") 184 | 185 | end_time = datetime.now() 186 | duration = end_time - start_time 187 | 188 | # Dump final results 189 | for test, df in global_pds.items(): 190 | if len(df) <= 20: 191 | print( 192 | f"\nPerformance results of test {test} (only tests without numerical error and with valid shapes, etc.):" 193 | ) 194 | print(df.to_string()) 195 | 196 | if STORE_TEST_RESULT_PATH is not None: 197 | for test, df in global_pds.items(): 198 | filename = os.path.abspath(f"{gloabl_pd_file_prefix}/{test}_final.csv") 199 | write_df_and_chmod(df, filename) 200 | print(f"(stored in {filename})") 201 | print(f"Torch profile traces stored in {gloabl_pd_file_prefix}/.") 202 | 203 | print( 204 | f"\nThis test 
used triton version: {triton.__version__}\n" 205 | f"This test was executed on: {gpu_name}\n" 206 | f"This test used: {cuda_version}\n" 207 | f"This test took: {duration}" 208 | ) 209 | -------------------------------------------------------------------------------- /scripts/callers/flash_attn.py: -------------------------------------------------------------------------------- 1 | # /******************************************************************************* 2 | # * Copyright 2025 IBM Corporation 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # * you may not use this file except in compliance with the License. 6 | # * You may obtain a copy of the License at 7 | # * 8 | # * http://www.apache.org/licenses/LICENSE-2.0 9 | # * 10 | # * Unless required by applicable law or agreed to in writing, software 11 | # * distributed under the License is distributed on an "AS IS" BASIS, 12 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # * See the License for the specific language governing permissions and 14 | # * limitations under the License. 15 | # *******************************************************************************/ 16 | # 17 | 18 | 19 | import torch 20 | 21 | if torch.version.hip: 22 | from flash_attn import flash_attn_with_kvcache, flash_attn_varlen_func 23 | else: 24 | from vllm.vllm_flash_attn import flash_attn_with_kvcache, flash_attn_varlen_func 25 | from .base import DecodeCaller, PrefillCaller, PrefixPrefillCaller 26 | 27 | 28 | class FlashAttnDecodeCaller(DecodeCaller): 29 | @staticmethod 30 | def make_call_func( 31 | output, 32 | query, 33 | key_cache, 34 | value_cache, 35 | num_seqs, # unused 36 | seq_lens, 37 | max_seq_len, # unused 38 | scale, 39 | block_tables, 40 | alibi_slopes, 41 | kv_cache_dtype, # unused 42 | ): 43 | def transform_kv_cache(x): 44 | out = torch.transpose(x, 1, 3) 45 | out = torch.transpose(out, 2, 3) 46 | return out.contiguous() 47 | 48 | key_cache_flash_attn = transform_kv_cache(key_cache) 49 | value_cache_flash_attn = transform_kv_cache(value_cache) 50 | 51 | q = query.unsqueeze(1) 52 | 53 | if torch.version.hip: 54 | call_func_under_test = lambda: flash_attn_with_kvcache( 55 | q=q, 56 | k_cache=key_cache_flash_attn, 57 | v_cache=value_cache_flash_attn, 58 | softmax_scale=scale, 59 | causal=True, 60 | cache_seqlens=seq_lens, 61 | window_size=(-1, 1), 62 | block_table=block_tables, 63 | softcap=0, 64 | alibi_slopes=alibi_slopes, 65 | ) 66 | else: 67 | call_func_under_test = lambda: flash_attn_with_kvcache( 68 | q=q, 69 | k_cache=key_cache_flash_attn, 70 | v_cache=value_cache_flash_attn, 71 | out=None, 72 | softmax_scale=scale, 73 | causal=True, 74 | cache_seqlens=seq_lens, 75 | window_size=(-1, 1), 76 | block_table=block_tables, 77 | softcap=0, 78 | alibi_slopes=alibi_slopes, 79 | ) 80 | 81 | return call_func_under_test 82 | 83 | @classmethod 84 | def select_output(cls, x, y): 85 | return y.squeeze(1) 86 | 87 | @staticmethod 88 | def requires_allocated_output() -> bool: 89 | return False 90 | 91 | 92 | class FlashAttnPrefillCaller(PrefillCaller): 93 | @staticmethod 94 | def make_call_func( 95 | output, # unused 96 | query, 97 | key_cache, 98 | value_cache, 99 | cu_seqlens_q, 100 | cu_seqlens_k, 101 | max_seqlen_q, 102 | max_seqlen_k, 103 | softmax_scale, 104 | causal, 105 | # kv_cache_dtype, # unused 106 | ): 107 | # q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. 
108 | # k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. 109 | # v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. 110 | # cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths 111 | # of the sequences in the batch, used to index into q. 112 | # cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths 113 | # of the sequences in the batch, used to index into kv. 114 | # max_seqlen_q: int. Maximum query sequence length in the batch. 115 | # max_seqlen_k: int. Maximum key sequence length in the batch. 116 | # out: (total, nheads, headdim). 117 | 118 | def call_and_process_output(): 119 | return flash_attn_varlen_func( 120 | q=query, 121 | k=key_cache, 122 | v=value_cache, 123 | cu_seqlens_q=cu_seqlens_q, 124 | cu_seqlens_k=cu_seqlens_k, 125 | max_seqlen_q=max_seqlen_q, 126 | max_seqlen_k=max_seqlen_k, 127 | softmax_scale=softmax_scale, 128 | causal=causal, 129 | ) 130 | 131 | return call_and_process_output 132 | 133 | @staticmethod 134 | def requires_allocated_output() -> bool: 135 | return False 136 | 137 | 138 | class FlashAttnPrefixPrefillCaller(PrefixPrefillCaller): 139 | @staticmethod 140 | def make_call_func( 141 | output, 142 | query, 143 | key_cache, 144 | value_cache, 145 | key, 146 | value, 147 | block_tables, 148 | seq_lens, 149 | ctx_lens, 150 | query_lens, 151 | start_loc, 152 | seq_start_loc, 153 | softmax_scale, 154 | # kv_cache_dtype, # unused 155 | ): 156 | """ 157 | query: shape = [num_tokens, num_heads, head_size] 158 | key: shape = [num_tokens, num_kv_heads, head_size] 159 | value: shape = [num_tokens, num_kv_heads, head_size] 160 | k_cache = [num_blocks, block_size, num_kv_heads, head_size] 161 | v_cache = [num_blocks, block_size, num_kv_heads, head_size] 162 | Returns: 163 | shape = [num_tokens, num_heads, head_size] 164 | """ 165 | 166 | max_query_len = query_lens.max() 167 | max_seqlen = seq_lens.max() 168 | 169 | if torch.version.hip: 170 | 171 | def call_and_process_output(): 172 | # k must have shape (num_blocks, page_block_size, num_heads_k, head_size) 173 | return flash_attn_varlen_func( 174 | q=query, 175 | k=key_cache, 176 | v=value_cache, 177 | cu_seqlens_q=start_loc, 178 | cu_seqlens_k=seq_start_loc, 179 | max_seqlen_q=max_query_len, 180 | max_seqlen_k=max_seqlen, 181 | softmax_scale=softmax_scale, 182 | causal=True, 183 | block_table=block_tables, 184 | # window_size=(-1, 1), 185 | # softcap=0, 186 | # fa_version=2, # TODO 187 | ) 188 | 189 | else: 190 | 191 | def call_and_process_output(): 192 | # k must have shape (num_blocks, page_block_size, num_heads_k, head_size) 193 | return flash_attn_varlen_func( 194 | q=query, 195 | k=key_cache, 196 | v=value_cache, 197 | out=output, 198 | cu_seqlens_q=start_loc, 199 | max_seqlen_q=max_query_len, 200 | seqused_k=seq_lens, 201 | max_seqlen_k=max_seqlen, 202 | softmax_scale=softmax_scale, 203 | causal=True, 204 | block_table=block_tables, 205 | # window_size=(-1, 1), 206 | # softcap=0, 207 | # fa_version=2, # TODO 208 | ) 209 | 210 | return call_and_process_output 211 | 212 | @staticmethod 213 | def requires_allocated_output() -> bool: 214 | if torch.version.hip: 215 | return False 216 | return True 217 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # vllm-triton-backend 2 | 3 | :information_source: This repository was used to develop the now 
community-maintained [Triton Backend in vLLM V1 (`triton_attn`)](https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/triton_attn.py). We consider the testing and microbenchmark scripts as well as the development tools (UBI container, proton viewer) still useful (and also use them ourselves), but the latest triton attention kernels are now maintained and developed in vLLM: [`vllm/vllm/attention/ops/`](https://github.com/vllm-project/vllm/tree/main/vllm/attention/ops). The kernels contained in this repository `vllm-triton-backend/ibm-triton-lib` are only updated on an irregular basis.
4 | We may archive this repository in the near future.
5 | 
6 | 
7 | * * *
8 | 
9 | 
10 | This repo contains:
11 | 
12 | - A Triton-only attention backend for vLLM, implemented as a [vLLM platform plugin](https://docs.vllm.ai/en/latest/design/plugin_system.html), see [`ibm-triton-lib/ibm_triton_lib/backend`](./ibm-triton-lib/ibm_triton_lib/backend/).
13 | - New Triton kernels that implement different attention algorithms, see [`ibm-triton-lib/ibm_triton_lib/kernels`](./ibm-triton-lib/ibm_triton_lib/kernels/).
14 | - A containerized development environment (vLLM + Triton built from source).
15 | - A microbenchmarking framework for evaluating their performance.
16 | 
17 | Triton kernels require autotuning to achieve the best possible performance, but naïve autotuning comes with a significant overhead at runtime. Therefore, this repository depends on [triton-dejavu](https://github.com/IBM/triton-dejavu) to reduce the overhead of autotuning to zero while still adapting triton kernels for each platform and request individually. The necessary dejavu data can be found in [`ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data`](./ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/).
18 | 
19 | ## How to use
20 | 
21 | This repository can be used as a microbenchmark framework and as a vLLM plugin. In the following, we explain how to [build our container development environment](#1-build), how to [run microbenchmarks](#2-run-microbenchmarks), and how to [run triton-only attention in vllm](#3-run-vllm-triton-only-backend).
22 | 
23 | ### 1) build
24 | 
25 | To build the docker image:
26 | ```
27 | git clone --recursive https://github.com/foundation-model-stack/vllm-triton-backend.git
28 | cd vllm-triton-backend
29 | make build
30 | ```
31 | 
32 | Please note that this build process installs the pre-built vllm v0.7.2.
33 | 
34 | ### 2) run microbenchmarks
35 | 
36 | To run the various benchmarks:
37 | ```bash
38 | mkdir results
39 | chmod o+w results
40 | docker run --gpus all -it --rm \
41 |     -v $(pwd)/scripts:/scripts \
42 |     -v $(pwd)/ibm-triton-lib/ibm_triton_lib:/opt/runtime/lib64/python3.12/site-packages/ibm_triton_lib \
43 |     -v $(pwd)/results:/results \
44 |     vllm-triton-backend-$(id -un) /scripts/benchmark.py
45 | ```
46 | The results of the benchmark are written to the results folder.
47 | One can edit the benchmark scripts and the kernel code without rebuilding the container.
48 | 
49 | Since `ibm-triton-lib` is also installed as a python package in the vllm-triton-backend image, it can be used in python scripts with `import ibm_triton_lib`.
50 | However, if the latest version of `ibm-triton-lib` should be used without frequently re-building the docker image, it can be mounted into the installed directory, which is currently `/opt/runtime/lib64/python3.12/site-packages/ibm_triton_lib/`, as shown above. The same applies to the `triton_dejavu` and `vllm` modules and the `scripts` folder.
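As a quick sanity check that the plugin package resolves inside the container, a minimal sketch (package and subpackage names are taken from the repository layout, not from a published API):

```python
# minimal sketch: verify the installed (or mounted) ibm_triton_lib package is importable
import ibm_triton_lib                  # the installed package is imported with underscores
from ibm_triton_lib import kernels     # the Triton attention kernels live in this subpackage

# shows whether the import resolves to site-packages or to a mounted checkout
print(ibm_triton_lib.__file__)
```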
51 | 
52 | ### 3) run vllm triton-only backend
53 | 
54 | #### Using our container
55 | 
56 | To run vLLM with the triton-only attention backend after [building our container](#1-build):
57 | ```bash
58 | docker run -it --rm --gpus all -v /path/to/models:/models vllm-triton-backend-$(id -un):latest -m vllm.entrypoints.openai.api_server --model /models/granite3.1-8b/base/
59 | 
60 | ```
61 | 
62 | #### Outside of our environment (stand-alone)
63 | 
64 | The Triton-only attention backend can be used within our [Docker container](#dev-environment) or outside of it.
65 | To install this plugin in any other environment:
66 | ```
67 | git clone https://github.com/foundation-model-stack/vllm-triton-backend.git
68 | pip install ./vllm-triton-backend
69 | ```
70 | 
71 | If using `ibm-triton-lib` outside of our container, the following needs to be taken into account:
72 | 
73 | - at least triton 3.2 is required, and therefore pytorch >= 2.6
74 | - our plugin must be installed after vllm (see [documentation](https://docs.vllm.ai/en/latest/design/plugin_system.html))
75 | - the vllm-triton-backend depends on [triton-dejavu](https://github.com/IBM/triton-dejavu)
76 | 
77 | 
78 | ## Dev Environment
79 | 
80 | This repo also contains a container aimed at development, which uses a custom vllm build.
81 | This development container can be built with:
82 | ```
83 | make clean
84 | make dev
85 | ```
86 | 
87 | Please note that this build process is designed to avoid lengthy re-compilation of the CUDA/C++ sources of vllm (which could take up to 30min). Therefore, the current setup triggers a rebuild of vllm **only** if `git add`, `git rm`, or `git commit` is executed affecting files inside the vllm submodule (or the complete submodule is updated, cf. [`git-update-index`](https://git-scm.com/docs/git-update-index)). If files in the vllm submodule are "just" changed (and *not* marked for commit or committed), only the copy of the vllm python files into the site-packages happens during the build of the image. This minor inconsistency during `make build` is intended, since our focus is triton kernels, not debugging vLLM CUDA.
88 | 
89 | To ensure a clean build (one that reflects all changes to local files), `make clean` can be executed, which forces a re-build of the vLLM C sources (unless the exact build is already present in the docker cache).
90 | 
91 | The development image is also based on `ubi9-minimal`, and the vllm and triton builds are isolated, both from each other and from the runtime.
92 | This allows us to ensure that runtime dependencies are minimal, and allows us to clearly see the different places where CUDA gets pulled in.
93 | 
94 | During build, vLLM requires a system installation of the CUDA toolkit. We install it from the system package manager.
95 | On the other hand, Triton automatically downloads its own version of CUDA and PTX during build; we do not control this.
96 | It does not require a system installation of CUDA.
97 | 
98 | At runtime, there are three different CUDA-related things that we need to be aware of:
99 | 1. The CUDA runtime that gets installed via pip (e.g., due to pytorch dependencies).
100 | 2. The PTX version that is bundled inside the Triton wheel.
101 | 3. The CUDA driver version that is running on the host machine (i.e., outside of the docker container).
102 | 
103 | All three of these versions can potentially be different, but they need to be compatible.
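To see which versions are actually in use inside the runtime venv, a minimal sketch (reading the host driver version via `pynvml` is an assumption; it is not a dependency of this repo, and `nvidia-smi` works just as well):

```python
# minimal sketch: print the CUDA-related versions that have to stay compatible
import torch
import triton

print("pip CUDA runtime:", torch.version.cuda)   # (1) CUDA runtime installed via pip (pytorch)
print("Triton version  :", triton.__version__)   # (2) the Triton wheel bundles its own PTX tooling

try:
    import pynvml  # assumption: not installed by default in this image
    pynvml.nvmlInit()
    print("CUDA driver     :", pynvml.nvmlSystemGetDriverVersion())  # (3) host driver version
except ImportError:
    print("pynvml not available; check the host driver version with nvidia-smi")
```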
104 | 
105 | See figure below:
106 | 
107 | ![dev environment](./doc/dev-env.png)
108 | 
109 | 
110 | ## Improved Proton Viewer
111 | 
112 | This repo contains a custom version of Triton's proton viewer: [`./scripts/roofline/proton_viewer.py`](./scripts/roofline/proton_viewer.py)
113 | 
114 | The main differences are:
115 | 1. It adds a real roofline analysis by introducing the metrics `util_flops` and `util_bytes`.
116 | 2. It fixes the confusion between the metrics `flop/s` and `flops`:
117 |     - `flop/s`: `flops_per_invocations * number_of_invocations / duration_of_all_invocations`
118 |     - `flops`: `flops_per_invocations * number_of_invocations`
119 | 3. It adds support for average flops and average flop/s.
120 | 4. It makes the list of available metrics informative:
121 | 
122 | ```
123 | $ python3 /scripts/roofline/proton_viewer.py -l ./matmul.hatchet
124 | Available raw metrics:
125 | - bytes
126 | - count
127 | - flops16
128 | - time
129 | Derivable metrics:
130 | - {g,t,avg_,avg_g,avg_t}byte/s
131 | - {g,t,avg_,avg_g,avg_t}flop/s
132 | - {g,t,avg_,avg_g,avg_t}flop16/s
133 | - {g,t,avg_,avg_g,avg_t}flops
134 | - {g,t,avg_,avg_g,avg_t}flops16
135 | - avg_time/[s,ms,us,ns]
136 | - util
137 | - util_flops
138 | - util_bytes
139 | - bytes/%
140 | - count/%
141 | - flops16/%
142 | - time/%
143 | (All values without 'avg_' are cumulative.)
144 | ```
145 | 
--------------------------------------------------------------------------------
/third_party/vedantroy_paged_attention.py:
--------------------------------------------------------------------------------
1 | # based on https://github.ibm.com/TPA/triton-paged-attention/blob/main/test_working.py
2 | 
3 | from typing import List, Optional, Tuple, Union, NamedTuple
4 | 
5 | import torch
6 | import triton
7 | import triton.language as tl
8 | 
9 | from ibm_triton_lib.utils.triton_utils import unpack_grid
10 | 
11 | 
12 | gpu_name = torch.cuda.get_device_name()
13 | 
14 | 
15 | def metadata_fn(
16 |     grid: tuple,
17 |     metadata: NamedTuple,
18 |     args: dict,
19 | ):
20 |     grid_x, grid_y, grid_z = unpack_grid(grid)
21 |     num_warps = metadata.num_warps
22 |     num_stages = metadata.num_stages
23 |     cluster_x, cluster_y, cluster_z = metadata.cluster_dims
24 |     shared_memory = metadata.shared
25 |     # args just contains NON-CONSTANT arguments
26 |     num_seqs, num_query_heads, head_size = args["query_ptr"].shape
27 |     num_blocks, num_kv_heads, _, block_size = args["key_cache_ptr"].shape
28 |     _, max_num_blocks_per_seq = args["block_tables_ptr"].shape
29 |     # num tokens are treated as batch
30 |     dtype_size = args["query_ptr"].element_size()
31 |     _, max_context_len, _, _ = args["scratchpad_key_ptr"].shape
32 | 
33 |     num_bytes = (
34 |         (dtype_size * num_seqs * num_query_heads * head_size)
35 |         + (dtype_size * num_blocks * num_kv_heads * head_size * block_size * 2)
36 |         + num_seqs * max_num_blocks_per_seq * dtype_size  # dtype size? not ptr size?
37 |         + (
38 |             num_seqs * max_context_len * num_query_heads * head_size * dtype_size * 2
39 |         )  # scratchpad
40 |     )
41 |     num_flops = num_blocks * num_kv_heads * head_size * block_size * 7  # TODO?
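    # note: num_bytes above approximates the kernel's global memory traffic as the sum of the
    # query tensor, both KV-cache tensors (keys and values), the block table, and the key/value
    # scratchpads; num_flops is only a rough placeholder estimate (see the TODO)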
42 | return { 43 | "name": f"triton_vedantroy_paged_attention_1_____", 44 | "flops16": num_flops, 45 | "bytes": num_bytes, 46 | } 47 | 48 | 49 | @triton.jit(launch_metadata=metadata_fn) 50 | def paged_attention_v1( 51 | # need these b/c we can't use view/reshape 52 | scratchpad_key_ptr, # [num_seqs, max_context_len, num_heads, head_size] 53 | scratchpad_value_ptr, # [num_seqs, max_context_len, num_heads, head_size] 54 | output_ptr, # [num_seqs, num_query_heads, head_size] 55 | query_ptr, # [num_seqs, num_query_heads, head_size] 56 | key_cache_ptr, # [num_blocks, num_kv_heads, head_size, block_size] 57 | value_cache_ptr, # [num_blocks, num_kv_heads, head_size, block_size] 58 | block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] 59 | context_lens_ptr, # [num_seqs] 60 | scale, # float32 61 | num_seqs, # int 62 | num_heads, # int 63 | cache_block_stride, # int 64 | MAX_CONTEXT_LEN: tl.constexpr, # int 65 | BLOCK_SIZE: tl.constexpr, # int 66 | HEAD_SIZE: tl.constexpr, # int, must be power of 2 67 | MAX_NUM_BLOCKS_PER_SEQ: tl.constexpr, # int, must be power of 2 68 | ): 69 | seq_idx = tl.program_id(0) 70 | head_idx = tl.program_id(1) 71 | 72 | query_offset = seq_idx * (num_heads * HEAD_SIZE) + head_idx * HEAD_SIZE 73 | query_head = tl.load(query_ptr + query_offset + tl.arange(0, HEAD_SIZE)) 74 | block_table_offset = seq_idx * MAX_NUM_BLOCKS_PER_SEQ 75 | context_len = tl.load(context_lens_ptr + seq_idx) 76 | 77 | for tok_idx in range(0, context_len): 78 | logical_block_idx = tok_idx // BLOCK_SIZE 79 | physical_block_idx = tl.load( 80 | block_tables_ptr + block_table_offset + logical_block_idx 81 | ) 82 | 83 | start_of_block_offset = ( 84 | physical_block_idx * cache_block_stride + head_idx * HEAD_SIZE * BLOCK_SIZE 85 | ) 86 | tok_idx_within_block = tok_idx % BLOCK_SIZE 87 | tok_offsets = ( 88 | start_of_block_offset 89 | + BLOCK_SIZE * tl.arange(0, HEAD_SIZE) 90 | + tok_idx_within_block 91 | ) 92 | 93 | tok_key = tl.load(key_cache_ptr + tok_offsets) 94 | tok_value = tl.load(value_cache_ptr + tok_offsets) 95 | 96 | scratchpad_offset = ( 97 | seq_idx * (MAX_CONTEXT_LEN * num_heads * HEAD_SIZE) 98 | + tok_idx * (num_heads * HEAD_SIZE) 99 | + head_idx * HEAD_SIZE 100 | ) 101 | tl.store( 102 | scratchpad_key_ptr + scratchpad_offset + tl.arange(0, HEAD_SIZE), tok_key 103 | ) 104 | tl.store( 105 | scratchpad_value_ptr + scratchpad_offset + tl.arange(0, HEAD_SIZE), 106 | tok_value, 107 | ) 108 | 109 | # TODO: Not sure if this is necessary 110 | tl.debug_barrier() 111 | 112 | # scratchpad_key_ptr, # [num_seqs, max_context_len, num_heads, head_size] 113 | start_seq_offset = (MAX_CONTEXT_LEN * num_heads * HEAD_SIZE) * seq_idx 114 | start_tok_offset = ( 115 | start_seq_offset 116 | + tl.arange(0, MAX_CONTEXT_LEN) * (num_heads * HEAD_SIZE) 117 | + head_idx * HEAD_SIZE 118 | ) 119 | 120 | # [seq_len, head_size] 121 | # zero out keys that aren't part of the sequence 122 | mask = tl.arange(0, MAX_CONTEXT_LEN)[:, None] < context_len 123 | kv_offs = start_tok_offset[:, None] + tl.arange(0, HEAD_SIZE)[None, :] 124 | keys = tl.load(scratchpad_key_ptr + kv_offs, mask=mask, other=0.0) 125 | values = tl.load(scratchpad_value_ptr + kv_offs, mask=mask, other=0.0) 126 | 127 | # keys shape [seq_len x head_size], query shape = [head_size] 128 | # Can't do below b/c minimum size on all dimensions is 16 129 | # scores = tl.dot(query_head[None, :], keys.T) 130 | scores = scale * tl.sum(keys * query_head[None, :], axis=1) 131 | 132 | # This mask is necessary b/c even though we mask out the keys on load 133 | # that just 
results in 0s in the attention dot product, 134 | # which then get softmaxed and result in non-zero values 135 | # in the softmax output (which is wrong) 136 | # -inf guarantees that the softmax output will be 0 for masked values 137 | mask = tl.full([MAX_CONTEXT_LEN], -float("inf"), dtype=tl.float32) 138 | cond = tl.arange(0, MAX_CONTEXT_LEN) < context_len 139 | scores_masked = tl.where(cond, scores, mask) 140 | 141 | # do a numerically stable softmax on the scores 142 | scores_minus_max = scores_masked - tl.max(scores_masked, axis=0) 143 | numerator = tl.exp(scores_minus_max) 144 | denominator = tl.sum(numerator, axis=0) 145 | logits = numerator / denominator 146 | 147 | # output_offset = seq_idx * (num_heads * HEAD_SIZE) + head_idx * HEAD_SIZE 148 | # tl.store(output_ptr + output_offset + tl.arange(0, HEAD_SIZE), logits) 149 | 150 | weighted_values = tl.sum(values * logits[:, None], axis=0) 151 | 152 | output_offset = seq_idx * (num_heads * HEAD_SIZE) + head_idx * HEAD_SIZE 153 | 154 | tl.store(output_ptr + output_offset + tl.arange(0, HEAD_SIZE), weighted_values) 155 | 156 | 157 | def paged_attention_triton_v1( 158 | output, 159 | query, 160 | key_cache, 161 | value_cache, 162 | scale, 163 | block_tables, 164 | context_lens, 165 | block_size, 166 | num_seqs, 167 | seq_lens, 168 | num_query_heads, 169 | max_seq_len, 170 | max_num_blocks_per_seq, 171 | head_size, 172 | num_kv_heads, 173 | scratchpad_key, 174 | scratchpad_value, 175 | ): 176 | 177 | paged_attention_v1[(num_seqs, num_query_heads)]( 178 | scratchpad_key_ptr=scratchpad_key, 179 | scratchpad_value_ptr=scratchpad_value, 180 | output_ptr=output, 181 | query_ptr=query, 182 | key_cache_ptr=key_cache, 183 | value_cache_ptr=value_cache, 184 | block_tables_ptr=block_tables, 185 | context_lens_ptr=context_lens, 186 | scale=scale, 187 | num_seqs=num_seqs, 188 | num_heads=num_query_heads, 189 | cache_block_stride=key_cache.stride(0), 190 | MAX_CONTEXT_LEN=max_seq_len, 191 | BLOCK_SIZE=block_size, 192 | HEAD_SIZE=head_size, 193 | MAX_NUM_BLOCKS_PER_SEQ=max_num_blocks_per_seq, 194 | ) 195 | 196 | block_tables_lst = block_tables.cpu().tolist() 197 | seq_lens_lst = seq_lens.cpu().tolist() 198 | for i in range(num_seqs): 199 | q = query[i].unsqueeze(0) 200 | block_table = block_tables_lst[i] 201 | seq_len = int(seq_lens_lst[i]) 202 | 203 | keys_lst: List[torch.Tensor] = [] 204 | values_lst: List[torch.Tensor] = [] 205 | for j in range(seq_len): 206 | block_number = int(block_table[j // block_size]) 207 | block_offset = j % block_size 208 | k = key_cache[block_number, :, :, block_offset] 209 | k = k.reshape(num_kv_heads, head_size) 210 | keys_lst.append(k) 211 | 212 | v = value_cache[block_number, :, :, block_offset] 213 | values_lst.append(v) 214 | keys = torch.stack(keys_lst, dim=0) 215 | values = torch.stack(values_lst, dim=0) 216 | 217 | # why? 
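        # (left disabled) these checks would verify that the keys/values gathered from the
        # paged KV cache above match what the kernel copied into the per-sequence scratchpad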
218 | # torch.testing.assert_close(scratchpad_key[i], keys) 219 | # torch.testing.assert_close(scratchpad_value[i], values) 220 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | ## Global Args ################################################################# 2 | ARG BASE_UBI_IMAGE_TAG=9.4 3 | ARG PYTHON_VERSION=3.12 4 | ARG MAX_JOBS=64 5 | ARG PIP_VLLM_VERSION=0.8.1 6 | 7 | ARG VLLM_SOURCE=pip 8 | # or VLLM_SOURCE=custom 9 | 10 | ## Base Layer ################################################################## 11 | FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} AS base 12 | ARG PYTHON_VERSION 13 | ENV PYTHON_VERSION=${PYTHON_VERSION} 14 | RUN microdnf -y update && microdnf install -y \ 15 | python${PYTHON_VERSION}-devel python${PYTHON_VERSION}-pip python${PYTHON_VERSION}-wheel \ 16 | gzip tar git\ 17 | && microdnf clean all 18 | 19 | WORKDIR /workspace 20 | 21 | ENV LANG=C.UTF-8 \ 22 | LC_ALL=C.UTF-8 23 | 24 | ## Common Builder ################################################################# 25 | FROM base AS common-builder 26 | ARG PYTHON_VERSION 27 | 28 | ENV VIRTUAL_ENV=/opt/build 29 | ENV PATH="$VIRTUAL_ENV/bin:$PATH" 30 | 31 | # create new venv to build vllm 32 | RUN python${PYTHON_VERSION} -m venv $VIRTUAL_ENV \ 33 | && pip install --no-cache -U pip wheel uv 34 | 35 | # install compiler cache to speed up compilation leveraging local or remote caching 36 | # git is required for the cutlass kernels 37 | RUN rpm -ivh https://dl.fedoraproject.org/pub/epel/epel-release-latest-9.noarch.rpm && rpm -ql epel-release && microdnf install -y ccache && microdnf clean all 38 | 39 | ## vLLM Builder ################################################################# 40 | FROM common-builder AS vllm-builder_custom 41 | ARG MAX_JOBS 42 | 43 | # install CUDA 44 | RUN curl -Lo /etc/yum.repos.d/cuda-rhel9.repo \ 45 | https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo 46 | 47 | RUN microdnf install -y \ 48 | cuda-nvcc-12-4 cuda-nvtx-12-4 cuda-libraries-devel-12-4 tar && \ 49 | microdnf clean all 50 | 51 | ENV CUDA_HOME="/usr/local/cuda" \ 52 | PATH="${CUDA_HOME}/bin:${PATH}" \ 53 | LD_LIBRARY_PATH="${CUDA_HOME}/lib64:${CUDA_HOME}/extras/CUPTI/lib64:${LD_LIBRARY_PATH}" 54 | 55 | # install build dependencies 56 | RUN --mount=type=cache,target=/root/.cache/pip \ 57 | --mount=type=cache,target=/root/.cache/uv \ 58 | --mount=type=bind,source=vllm/requirements/build.txt,target=requirements-build.txt \ 59 | uv pip install -r requirements-build.txt 60 | 61 | # set env variables for build 62 | ENV PATH=/usr/local/cuda/bin:$PATH 63 | ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX" 64 | ENV VLLM_FA_CMAKE_GPU_ARCHES="80-real;90-real" 65 | ENV MAX_JOBS=${MAX_JOBS} 66 | ENV NVCC_THREADS=2 67 | ENV VLLM_INSTALL_PUNICA_KERNELS=1 68 | 69 | # copy git stuff 70 | WORKDIR /workspace/.git 71 | COPY all-git.tar . 72 | RUN tar -xf all-git.tar && \ 73 | rm all-git.tar 74 | 75 | # copy tarball of last commit 76 | WORKDIR /workspace/vllm 77 | 78 | COPY vllm-all.tar . 
79 | RUN tar -xf vllm-all.tar && \ 80 | rm vllm-all.tar 81 | 82 | # build vllm wheel 83 | ENV CCACHE_DIR=/root/.cache/ccache 84 | RUN --mount=type=cache,target=/root/.cache/ccache \ 85 | --mount=type=bind,source=vllm/.git,target=/workspace/vllm/.git \ 86 | env CFLAGS="-march=haswell" \ 87 | CXXFLAGS="$CFLAGS $CXXFLAGS" \ 88 | CMAKE_BUILD_TYPE=Release \ 89 | python3 setup.py bdist_wheel --dist-dir=/workspace/ 90 | 91 | ## fake vLLM Builder ################################################################# 92 | FROM common-builder AS vllm-builder_pip 93 | ARG PIP_VLLM_VERSION 94 | 95 | RUN --mount=type=cache,target=/root/.cache/pip \ 96 | pip download vllm==${PIP_VLLM_VERSION} --no-deps 97 | 98 | ## merge vLLM Builder ################################################################# 99 | FROM vllm-builder_${VLLM_SOURCE} AS vllm-builder 100 | 101 | RUN ls -al /workspace/vllm-* 102 | 103 | ## Triton Builder ################################################################# 104 | FROM common-builder AS triton-builder 105 | 106 | # Triton build deps 107 | RUN --mount=type=cache,target=/root/.cache/pip \ 108 | --mount=type=cache,target=/root/.cache/uv \ 109 | uv pip install ninja cmake wheel pybind11 setuptools 110 | 111 | COPY triton triton 112 | 113 | WORKDIR /workspace/triton/python 114 | 115 | # needed to build triton 116 | RUN microdnf install -y zlib-devel gcc gcc-c++ \ 117 | && microdnf clean all 118 | 119 | # Build Triton 120 | ENV TRITON_BUILD_WITH_CCACHE=true 121 | ENV CCACHE_DIR=/root/.cache/ccache 122 | RUN --mount=type=cache,target=/root/.cache/ccache \ 123 | python3 setup.py bdist_wheel --dist-dir=/workspace/ 124 | 125 | ## Runtime ################################################################# 126 | FROM base AS runtime 127 | 128 | ENV VIRTUAL_ENV=/opt/runtime 129 | ENV PATH="$VIRTUAL_ENV/bin:$PATH" 130 | 131 | # create new venv to build vllm 132 | RUN python${PYTHON_VERSION} -m venv $VIRTUAL_ENV \ 133 | && pip install --no-cache -U pip wheel uv 134 | 135 | # swig is required by triton-dejavu (SMAC optimizer) 136 | # SWIG rpm not available for RHEL9 137 | RUN microdnf install -y wget tar zlib-devel automake g++ && microdnf clean all 138 | RUN wget https://downloads.sourceforge.net/project/swig/swig/swig-3.0.12/swig-3.0.12.tar.gz && \ 139 | tar -xzf swig-3.0.12.tar.gz && \ 140 | cd swig-3.0.12 && \ 141 | bash autogen.sh && \ 142 | wget https://downloads.sourceforge.net/project/pcre/pcre/8.45/pcre-8.45.tar.gz && \ 143 | bash Tools/pcre-build.sh && \ 144 | bash ./configure && \ 145 | make && \ 146 | make install 147 | 148 | WORKDIR /workspace 149 | 150 | # Install vllm 151 | COPY --from=vllm-builder /workspace/*.whl . 
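# the wheel comes from the selected builder stage (custom source build or pip download); install it into the runtime venv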
152 | RUN --mount=type=cache,target=/root/.cache/pip \ 153 | --mount=type=cache,target=/root/.cache/uv \ 154 | uv pip install vllm-*.whl 155 | 156 | # copy python stuff of vllm 157 | ARG VLLM_SOURCE 158 | RUN mkdir -p /workspace/vllm 159 | COPY vllm/vllm /workspace/vllm 160 | RUN if [ "$VLLM_SOURCE" = "custom" ] ; then cp -r /workspace/vllm/* ${VIRTUAL_ENV}/lib/python${PYTHON_VERSION}/site-packages/vllm/ \ 161 | && cp -r /workspace/vllm/* ${VIRTUAL_ENV}/lib64/python${PYTHON_VERSION}/site-packages/vllm/; fi 162 | RUN rm -rf /workspace/vllm 163 | 164 | # to avaoid incompatibility with our custom triton build 165 | # see also https://github.com/vllm-project/vllm/issues/12219 166 | # RUN uv pip install -U 'torch>=2.6' 'torchvision>=0.21' 'torchaudio>=2.6' 167 | 168 | # Install Triton (will replace version that vllm/pytorch installed) 169 | COPY --from=triton-builder /workspace/*.whl . 170 | RUN --mount=type=cache,target=/root/.cache/pip \ 171 | --mount=type=cache,target=/root/.cache/uv \ 172 | uv pip install triton-*.whl 173 | 174 | # force using the python venv's cuda runtime libraries 175 | ENV LD_LIBRARY_PATH="${VIRTUAL_ENV}/lib/python${PYTHON_VERSION}/site-packages/nvidia/cuda_nvrtc/lib:${LD_LIBRARY_PATH}" 176 | ENV LD_LIBRARY_PATH="${VIRTUAL_ENV}/lib/python${PYTHON_VERSION}/site-packages/nvidia/cuda_runtime/lib:${LD_LIBRARY_PATH}" 177 | ENV LD_LIBRARY_PATH="${VIRTUAL_ENV}/lib/python${PYTHON_VERSION}/site-packages/nvidia/nvtx/lib:${LD_LIBRARY_PATH}" 178 | ENV LD_LIBRARY_PATH="${VIRTUAL_ENV}/lib/python${PYTHON_VERSION}/site-packages/nvidia/cuda_cupti/lib:${LD_LIBRARY_PATH}" 179 | 180 | # copy requirements explicitly before to avoid reinstall 181 | COPY triton-dejavu/requirements-opt.txt dejavu-requirements-opt.txt 182 | RUN --mount=type=cache,target=/root/.cache/pip \ 183 | --mount=type=cache,target=/root/.cache/uv \ 184 | uv pip install -r dejavu-requirements-opt.txt \ 185 | && rm -f dejavu-requirements-opt.txt 186 | 187 | # dejavu 188 | COPY triton-dejavu triton-dejavu 189 | RUN --mount=type=cache,target=/root/.cache/pip \ 190 | --mount=type=cache,target=/root/.cache/uv \ 191 | uv pip install ./triton-dejavu/ \ 192 | && rm -rf ./triton-dejavu/ 193 | 194 | # Install IBM kernels and vllm plugin 195 | # must be after vllm! 
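# (the package registers itself as a vLLM platform plugin, so vllm must already be installed; see README)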
196 | COPY ibm-triton-lib ibm-triton-lib 197 | RUN --mount=type=cache,target=/root/.cache/pip \ 198 | --mount=type=cache,target=/root/.cache/uv \ 199 | uv pip install ./ibm-triton-lib \ 200 | && rm -rf ibm-triton-lib 201 | 202 | ## Benchmarking ################################################################# 203 | FROM runtime AS benchmark 204 | 205 | WORKDIR /workspace 206 | 207 | RUN microdnf install -y git nano gcc vim \ 208 | && microdnf clean all 209 | 210 | # TODO: make cuda version configurable 211 | RUN curl -Lo /etc/yum.repos.d/cuda-rhel9.repo \ 212 | https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo 213 | RUN microdnf install -y nsight-compute-2025.1.0 && microdnf clean all 214 | 215 | RUN curl -Lo /tmp/nsight-package.rpm \ 216 | https://developer.nvidia.com/downloads/assets/tools/secure/nsight-systems/2025_1/NsightSystems-linux-cli-public-2025.1.1.103-3542797.rpm 217 | 218 | RUN rpm -ivh /tmp/nsight-package.rpm && rm -f /tmp/nsight-package.rpm 219 | 220 | RUN pip install nvtx 221 | 222 | # Linking the Nsight Compute to the venv 223 | RUN ln -s /opt/nvidia/nsight-compute/2025.1.0/target/linux-desktop-glibc_2_11_3-x64/ncu $VIRTUAL_ENV/bin/ncu 224 | 225 | RUN --mount=type=cache,target=/root/.cache/pip \ 226 | --mount=type=cache,target=/root/.cache/uv \ 227 | uv pip install pytest llnl-hatchet debugpy 228 | 229 | # Install FlashInfer 230 | RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \ 231 | echo "export PYTHON_VERSION_STR=${PYTHON_VERSION_STR}" >> /etc/environment 232 | 233 | RUN --mount=type=cache,target=/root/.cache/pip \ 234 | . /etc/environment && \ 235 | python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl 236 | 237 | RUN ln -s ${VIRTUAL_ENV}/lib/python${PYTHON_VERSION}/site-packages/nvidia/cuda_cupti/lib/libcupti.so.12 ${VIRTUAL_ENV}/lib/python${PYTHON_VERSION}/site-packages/nvidia/cuda_cupti/lib/libcupti.so 238 | 239 | RUN --mount=type=cache,target=/root/.cache/pip \ 240 | --mount=type=cache,target=/root/.cache/uv \ 241 | git clone --depth 1 https://github.com/EleutherAI/lm-evaluation-harness && cd lm-evaluation-harness && uv pip install . 242 | 243 | RUN git clone --depth 1 https://github.com/IBM/fmwork.git 244 | 245 | ENV STORE_TEST_RESULT_PATH=/results 246 | 247 | # copy vllm benchmarks and tests 248 | COPY vllm/benchmarks benchmarks 249 | COPY vllm/tests tests 250 | COPY ShareGPT_V3_unfiltered_cleaned_split.json ShareGPT_V3_unfiltered_cleaned_split.json 251 | 252 | # Copy thid-party kernels and insert into path 253 | COPY third_party third_party 254 | ENV PYTHONPATH /workspace 255 | 256 | # see https://github.com/IBM/triton-dejavu?tab=readme-ov-file#environment-variables 257 | ENV TRITON_PRINT_AUTOTUNING=1 258 | ENV TRITON_DEJAVU_DEBUG=1 259 | # set as default 260 | ENV TRITON_DEJAVU_STORAGE=/workspace 261 | ENV NGL_EXP_FALLBACK=next 262 | ENV TRITON_DEJAVU_FORCE_FALLBACK=1 263 | ENV TRITON_DEJAVU_TAG='default' 264 | ENV TRITON_DEJAVU_HASH_SEARCH_PARAMS=0 265 | 266 | # open debugpy port 267 | EXPOSE 5679 268 | 269 | ENTRYPOINT ["python"] 270 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. 
Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 
180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.2.0/rocm_6.3.1/gpu_AMD_Instinct_MI250X_MI250/attn_fwd/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-0a43fd896fb3d6519678247aeba94610b596378a3138e88995ca3569d6672a96/tune_features-df62f53ce178f143b59631de953c946e43811ff1b34cd71e422dfdf14ac35bb9/kernel_configs-1f316f0fbddd51d950280abb53d67b60494f0cf2c02eeb1b551b0356a33a7dc8/default/cache.json: -------------------------------------------------------------------------------- 1 | { 2 | "signature": "JITFunction(ibm_triton_lib.kernels.triton_flash_attention:attn_fwd)", 3 | "total_bench_time_s": 86906.62447404861, 4 | "evaluated_configs": 450, 5 | "keys": [ 6 | "HQ", 7 | "HK", 8 | "IS_CAUSAL", 9 | "dropout_p", 10 | "BLOCK_DMODEL", 11 | "stride_qz", 12 | "stride_qh", 13 | "stride_qm", 14 | "stride_qk", 15 | "stride_kz", 16 | "stride_kh", 17 | "stride_kn", 18 | "stride_kk", 19 | "stride_vz", 20 | "stride_vh", 21 | "stride_vn", 22 | "stride_vk", 23 | "stride_oz", 24 | "stride_oh", 25 | "stride_om", 26 | "stride_on", 27 | "stride_bz", 28 | "stride_bh", 29 | "stride_bm", 30 | "stride_bn", 31 | "stride_az", 32 | "stride_ah", 33 | "MAX_SEQLENS_Q", 34 | "MAX_SEQLENS_K", 35 | "VARLEN", 36 | "ACTUAL_BLOCK_DMODEL" 37 | ], 38 | "cache": { 39 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '16', '16', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 16, BLOCK_N: 16, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 4, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 40 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '16', '16', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 16, BLOCK_N: 16, PRE_LOAD_V: True, GRID_CU_MULTIP: 2, num_warps: 4, num_ctas: 1, num_stages: 1, 
num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 41 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '128', '128', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 16, BLOCK_N: 16, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 2, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 42 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '128', '128', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 16, BLOCK_N: 16, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 2, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 43 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '512', '512', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 256, BLOCK_N: 64, PRE_LOAD_V: True, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 44 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '512', '512', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 256, BLOCK_N: 64, PRE_LOAD_V: True, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 45 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '1024', '1024', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 64, BLOCK_N: 64, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 46 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '1024', '1024', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 64, BLOCK_N: 64, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 47 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', 
'0', '0', '2048', '2048', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 64, BLOCK_N: 64, PRE_LOAD_V: True, GRID_CU_MULTIP: 2, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 48 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '2048', '2048', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 64, BLOCK_N: 64, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 49 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '4096', '4096', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 256, BLOCK_N: 128, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 50 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '4096', '4096', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 256, BLOCK_N: 128, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None" 51 | }, 52 | "timings": { 53 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '16', '16', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 54 | 0.004207286983728409 55 | ], 56 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '16', '16', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 57 | 0.004182395525276661 58 | ], 59 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '128', '128', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 60 | 0.01809287816286087 61 | ], 62 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '128', '128', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 63 | 0.017839614301919937 64 | ], 65 | "('32', '32', 'True', 
'0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '512', '512', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 66 | 0.09088581800460815 67 | ], 68 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '512', '512', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 69 | 0.088987797498703 70 | ], 71 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '1024', '1024', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 72 | 0.23396557569503784 73 | ], 74 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '1024', '1024', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 75 | 0.23347480595111847 76 | ], 77 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '2048', '2048', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 78 | 0.6691922545433044 79 | ], 80 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '2048', '2048', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 81 | 0.6695101261138916 82 | ], 83 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '4096', '4096', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 84 | 2.025791645050049 85 | ], 86 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '4096', '4096', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 87 | 2.01798415184021 88 | ] 89 | }, 90 | "timings_data": { 91 | "labels": [ 92 | "ms" 93 | ], 94 | "rep_t_ms": 100, 95 | "warmup_t_ms": 25, 96 | "cuda_graphs": true 97 | } 98 | } -------------------------------------------------------------------------------- 
/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.2.0/cuda_12.4/gpu_NVIDIA_A100-SXM4-80GB/attn_fwd/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-0a43fd896fb3d6519678247aeba94610b596378a3138e88995ca3569d6672a96/tune_features-df62f53ce178f143b59631de953c946e43811ff1b34cd71e422dfdf14ac35bb9/kernel_configs-1f316f0fbddd51d950280abb53d67b60494f0cf2c02eeb1b551b0356a33a7dc8/default/cache.json: -------------------------------------------------------------------------------- 1 | { 2 | "signature": "JITFunction(ibm_triton_lib.kernels.triton_flash_attention:attn_fwd)", 3 | "total_bench_time_s": 211706.17069911957, 4 | "evaluated_configs": 450, 5 | "keys": [ 6 | "HQ", 7 | "HK", 8 | "IS_CAUSAL", 9 | "dropout_p", 10 | "BLOCK_DMODEL", 11 | "stride_qz", 12 | "stride_qh", 13 | "stride_qm", 14 | "stride_qk", 15 | "stride_kz", 16 | "stride_kh", 17 | "stride_kn", 18 | "stride_kk", 19 | "stride_vz", 20 | "stride_vh", 21 | "stride_vn", 22 | "stride_vk", 23 | "stride_oz", 24 | "stride_oh", 25 | "stride_om", 26 | "stride_on", 27 | "stride_bz", 28 | "stride_bh", 29 | "stride_bm", 30 | "stride_bn", 31 | "stride_az", 32 | "stride_ah", 33 | "MAX_SEQLENS_Q", 34 | "MAX_SEQLENS_K", 35 | "VARLEN", 36 | "ACTUAL_BLOCK_DMODEL" 37 | ], 38 | "cache": { 39 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '32', '32', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 32, BLOCK_N: 32, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 4, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 40 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '32', '32', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 32, BLOCK_N: 32, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 4, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 41 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '64', '64', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 32, BLOCK_N: 32, PRE_LOAD_V: True, GRID_CU_MULTIP: 2, num_warps: 4, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 42 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '64', '64', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 32, BLOCK_N: 32, PRE_LOAD_V: True, GRID_CU_MULTIP: 2, num_warps: 4, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 43 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', 
'128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '128', '128', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 32, BLOCK_N: 32, PRE_LOAD_V: True, GRID_CU_MULTIP: 2, num_warps: 4, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 44 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '128', '128', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 32, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 45 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '256', '256', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 32, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 46 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '256', '256', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 64, PRE_LOAD_V: True, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 47 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '512', '512', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 64, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 48 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '512', '512', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 128, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 49 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '1024', '1024', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": 
"BLOCK_M: 128, BLOCK_N: 64, PRE_LOAD_V: True, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 50 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '1024', '1024', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 128, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 51 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '2048', '2048', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 128, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 52 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '2048', '2048', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 128, PRE_LOAD_V: True, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 53 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '4096', '4096', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 128, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None" 54 | }, 55 | "timings": { 56 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '32', '32', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 57 | 0.005401020869612694 58 | ], 59 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '32', '32', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 60 | 0.005471085663884878 61 | ], 62 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '64', '64', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 63 | 0.0075958045199513435 64 | ], 65 | "('32', 
'8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '64', '64', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 66 | 0.007605006452649832 67 | ], 68 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '128', '128', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 69 | 0.011812349781394005 70 | ], 71 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '128', '128', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 72 | 0.011950820684432983 73 | ], 74 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '256', '256', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 75 | 0.019297460094094276 76 | ], 77 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '256', '256', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 78 | 0.017475301399827003 79 | ], 80 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '512', '512', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 81 | 0.038042228668928146 82 | ], 83 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '512', '512', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 84 | 0.038091544061899185 85 | ], 86 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '1024', '1024', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 87 | 0.10096532106399536 88 | ], 89 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '1024', '1024', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 90 | 0.09481953084468842 91 | ], 92 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '2048', '2048', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 
'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 93 | 0.2949035167694092 94 | ], 95 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '2048', '2048', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 96 | 0.29237720370292664 97 | ], 98 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '4096', '4096', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 99 | 0.9560787677764893 100 | ] 101 | }, 102 | "timings_data": { 103 | "labels": [ 104 | "ms" 105 | ], 106 | "rep_t_ms": 100, 107 | "warmup_t_ms": 25, 108 | "cuda_graphs": true 109 | } 110 | } -------------------------------------------------------------------------------- /ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.2.0/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/attn_fwd/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-0a43fd896fb3d6519678247aeba94610b596378a3138e88995ca3569d6672a96/tune_features-df62f53ce178f143b59631de953c946e43811ff1b34cd71e422dfdf14ac35bb9/kernel_configs-a70f97e8b3e7aaf9f4a4f7e850b935d2d1b3ad8cd6ad1d0843bb426e13694ae9/default/cache.json: -------------------------------------------------------------------------------- 1 | { 2 | "signature": "JITFunction(ibm_triton_lib.kernels.triton_flash_attention:attn_fwd)", 3 | "total_bench_time_s": 86841.6919836998, 4 | "evaluated_configs": 240, 5 | "keys": [ 6 | "HQ", 7 | "HK", 8 | "IS_CAUSAL", 9 | "dropout_p", 10 | "BLOCK_DMODEL", 11 | "stride_qz", 12 | "stride_qh", 13 | "stride_qm", 14 | "stride_qk", 15 | "stride_kz", 16 | "stride_kh", 17 | "stride_kn", 18 | "stride_kk", 19 | "stride_vz", 20 | "stride_vh", 21 | "stride_vn", 22 | "stride_vk", 23 | "stride_oz", 24 | "stride_oh", 25 | "stride_om", 26 | "stride_on", 27 | "stride_bz", 28 | "stride_bh", 29 | "stride_bm", 30 | "stride_bn", 31 | "stride_az", 32 | "stride_ah", 33 | "MAX_SEQLENS_Q", 34 | "MAX_SEQLENS_K", 35 | "VARLEN", 36 | "ACTUAL_BLOCK_DMODEL" 37 | ], 38 | "cache": { 39 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '32', '32', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 32, BLOCK_N: 32, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 4, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 40 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '32', '32', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 32, BLOCK_N: 32, PRE_LOAD_V: True, GRID_CU_MULTIP: 2, num_warps: 4, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 41 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', 
'128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '64', '64', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 32, BLOCK_N: 32, PRE_LOAD_V: True, GRID_CU_MULTIP: 2, num_warps: 4, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 42 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '64', '64', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 32, BLOCK_N: 32, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 4, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 43 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '128', '128', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 32, BLOCK_N: 32, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 44 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '128', '128', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 32, BLOCK_N: 32, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 45 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '256', '256', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 128, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 46 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '256', '256', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 128, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 47 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '512', '512', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, 
BLOCK_N: 128, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 48 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '512', '512', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 128, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 49 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '1024', '1024', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 128, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 50 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '1024', '1024', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 128, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 51 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '2048', '2048', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 128, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 52 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '2048', '2048', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 128, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 53 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '4096', '4096', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 128, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None" 54 | }, 55 | "timings": { 56 | "('32', '32', 'True', 
'0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '32', '32', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 57 | 0.0036645731888711452 58 | ], 59 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '32', '32', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 60 | 0.0036076440010219812 61 | ], 62 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '64', '64', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 63 | 0.00487453443929553 64 | ], 65 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '64', '64', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 66 | 0.0048555657267570496 67 | ], 68 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '128', '128', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 69 | 0.006982282269746065 70 | ], 71 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '128', '128', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 72 | 0.006992792245000601 73 | ], 74 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '256', '256', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 75 | 0.010331092402338982 76 | ], 77 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '256', '256', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 78 | 0.010227189399302006 79 | ], 80 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '512', '512', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 81 | 0.015056964010000229 82 | ], 83 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '512', '512', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 
'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 84 | 0.014920394867658615 85 | ], 86 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '1024', '1024', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 87 | 0.04663630574941635 88 | ], 89 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '1024', '1024', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 90 | 0.04339428246021271 91 | ], 92 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '2048', '2048', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 93 | 0.1311214417219162 94 | ], 95 | "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '2048', '2048', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 96 | 0.12436506152153015 97 | ], 98 | "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '4096', '4096', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ 99 | 0.39030927419662476 100 | ] 101 | }, 102 | "timings_data": { 103 | "labels": [ 104 | "ms" 105 | ], 106 | "rep_t_ms": 100, 107 | "warmup_t_ms": 25, 108 | "cuda_graphs": true 109 | } 110 | } --------------------------------------------------------------------------------
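
Note on the dejavu cache files above: each cache.json records, for one kernel (its "signature"), the tuning "keys", the best Triton config found per key tuple ("cache"), and the measured runtime of that config ("timings", in the unit given under "timings_data"). The sketch below is one illustrative way to inspect such a file; it assumes only the Python standard library, the file path is a placeholder, and the assumption that the leading entries of each cache key tuple line up with the "keys" list (with argument dtypes appended at the end) is inferred from the data shown above, not taken from the dejavu source.

import json

# Placeholder path: point this at any of the dejavu cache.json files above.
CACHE_PATH = "path/to/default/cache.json"

with open(CACHE_PATH) as f:
    cache = json.load(f)

print("kernel:", cache["signature"])
print("evaluated configs:", cache["evaluated_configs"])
print("total tuning time [h]: %.1f" % (cache["total_bench_time_s"] / 3600))

unit = cache["timings_data"]["labels"][0]  # e.g. "ms"

for key, best_config in cache["cache"].items():
    timing = cache["timings"].get(key, [float("nan")])[0]
    # Assumption: the key string lists the values of "keys" in order,
    # followed by the argument dtypes; zip() drops the trailing dtype entries.
    fields = [v.strip(" '") for v in key.strip("()").split(",")]
    named = dict(zip(cache["keys"], fields))
    print(
        f"seqlen {named['MAX_SEQLENS_Q']}x{named['MAX_SEQLENS_K']}: "
        f"{timing:.4f} {unit} -> {best_config}"
    )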