├── .clang-format ├── .devcontainer ├── Dockerfile ├── devcontainer.json └── install │ ├── install_python.sh │ └── install_python_packages.sh ├── .github ├── pull_request_template.md └── workflows │ ├── build-doc.yml │ ├── pre-commit.yml │ ├── release-ci-docker.yml │ ├── release_wheel.yml │ ├── release_wheel_aarch64.yml │ └── release_wheel_sglang.yml ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── CHANGELOG.md ├── CMakeLists.txt ├── Jenkinsfile ├── LICENSE ├── README.md ├── aot_build_utils ├── __init__.py ├── generate.py ├── generate_aot_default_additional_params_header.py ├── generate_batch_paged_decode_inst.py ├── generate_batch_paged_prefill_inst.py ├── generate_batch_paged_prefill_sm90_inst.py ├── generate_batch_ragged_prefill_inst.py ├── generate_batch_ragged_prefill_sm90_inst.py ├── generate_dispatch_inc.py ├── generate_single_decode_inst.py ├── generate_single_prefill_inst.py ├── generate_single_prefill_sm90_inst.py ├── generate_sm90.py └── literal_map.py ├── benchmarks ├── bench_append_paged_kv_cache.py ├── bench_append_paged_mla_kv_cache.py ├── bench_batch_decode.py ├── bench_blackwell_attention.py ├── bench_deepseek_mla.py ├── bench_fused_add_rmsnorm.py ├── bench_grouped_gemm.py ├── bench_groupwise_gemm_fp8_blackwell.py ├── bench_groupwise_grouped_gemm_fp8_blackwell.py ├── bench_hopper_attention.py ├── bench_hopper_fp8_attention.py ├── bench_mixed_attention.py ├── bench_pad_ragged_tensor.py ├── bench_persistent_gemm.py ├── bench_renorm.py ├── bench_rope.py ├── bench_sampling.py └── bench_trtllm_fmha.py ├── ci ├── bash.sh └── scripts │ └── jenkins │ ├── git_skip_ci.py │ ├── git_skip_ci_globs.py │ ├── git_utils.py │ └── retry.sh ├── cmake ├── config.cmake ├── modules │ └── FindThrust.cmake └── utils │ └── Utils.cmake ├── csrc ├── activation.cu ├── aot_extension_utils.h ├── batch_decode.cu ├── batch_decode_config.inc ├── batch_decode_customize_config.jinja ├── batch_decode_jit_pybind.cu ├── batch_decode_kernel_inst.jinja ├── batch_decode_mla_config.jinja ├── batch_decode_mla_cute_sm80.cu ├── batch_decode_mla_plan.cu ├── batch_decode_mla_pybind.cu ├── batch_decode_mla_run.cu ├── batch_mla_config.jinja ├── batch_mla_plan.cu ├── batch_mla_pybind.cu ├── batch_mla_run.cu ├── batch_mla_sm90_plan.cu ├── batch_mla_sm90_pybind.cu ├── batch_mla_sm90_run.cu ├── batch_prefill.cu ├── batch_prefill_config.inc ├── batch_prefill_customize_config.jinja ├── batch_prefill_fp8_paged_sm90_kernel_inst.jinja ├── batch_prefill_fp8_ragged_sm90_kernel_inst.jinja ├── batch_prefill_fp8_sm90.cu ├── batch_prefill_jit_pybind.cu ├── batch_prefill_paged_kernel_inst.jinja ├── batch_prefill_paged_sm90_kernel_inst.jinja ├── batch_prefill_ragged_kernel_inst.jinja ├── batch_prefill_ragged_sm90_kernel_inst.jinja ├── batch_prefill_sm90.cu ├── batch_prefill_sm90_config.inc ├── batch_prefill_sm90_customize_config.jinja ├── batch_prefill_sm90_jit_pybind.cu ├── blackwell_fmha_plan.cu ├── bmm_fp8.cu ├── cascade.cu ├── custom_all_reduce.cu ├── cutlass_mla.cu ├── flashinfer_cascade_ops.cu ├── flashinfer_comm_ops.cu ├── flashinfer_gemm_ops.cu ├── flashinfer_gemm_sm90_ops.cu ├── flashinfer_mla_ops.cu ├── flashinfer_norm_ops.cu ├── flashinfer_ops.cu ├── flashinfer_ops_sm90.cu ├── flashinfer_page_ops.cu ├── flashinfer_quantization_ops.cu ├── flashinfer_rope_ops.cu ├── flashinfer_sampling_ops.cu ├── fmha_cutlass_sm100.cu ├── fmha_cutlass_sm100_pybind.cu ├── fused_moe │ └── cutlass_backend │ │ ├── cutlass_fused_moe_instantiation.cu │ │ ├── cutlass_fused_moe_kernels.cuh │ │ └── flashinfer_cutlass_fused_moe_sm100_ops.cu 
├── gemm_groupwise_sm100.cu ├── gemm_sm100_pybind.cu ├── group_gemm.cu ├── group_gemm_bf16_bf16_sm90.cu ├── group_gemm_e4m3_bf16_sm90.cu ├── group_gemm_e4m3_f16_sm90.cu ├── group_gemm_e5m2_bf16_sm90.cu ├── group_gemm_e5m2_f16_sm90.cu ├── group_gemm_f16_f16_sm90.cu ├── group_gemm_groupwise_sm100.cu ├── group_gemm_sm100_pybind.cu ├── group_gemm_sm90.cu ├── logging.cc ├── norm.cu ├── nv_internal │ ├── cpp │ │ ├── common │ │ │ ├── envUtils.cpp │ │ │ ├── logger.cpp │ │ │ ├── memoryUtils.cu │ │ │ ├── stringUtils.cpp │ │ │ └── tllmException.cpp │ │ └── kernels │ │ │ └── quantization.cu │ ├── include │ │ └── tensorrt_llm │ │ │ └── common │ │ │ ├── assert.h │ │ │ ├── cudaBf16Wrapper.h │ │ │ ├── cudaFp8Utils.h │ │ │ ├── cudaUtils.h │ │ │ ├── dataType.h │ │ │ ├── logger.h │ │ │ ├── quantization.h │ │ │ ├── stringUtils.h │ │ │ └── tllmException.h │ └── tensorrt_llm │ │ ├── common │ │ ├── cublasMMWrapper.h │ │ ├── cudaBf16Fallbacks.cuh │ │ ├── cudaDriverWrapper.h │ │ ├── cudaTypeUtils.cuh │ │ ├── envUtils.h │ │ ├── memoryUtils.h │ │ ├── quantTypeUtils.cuh │ │ ├── reduceKernelUtils.cuh │ │ └── workspace.h │ │ ├── cutlass_extensions │ │ └── include │ │ │ └── cutlass_extensions │ │ │ ├── arch │ │ │ ├── copy_red_global.hpp │ │ │ ├── copy_sm90_multimem.hpp │ │ │ ├── copy_traits_sm90_multimem.hpp │ │ │ ├── grid_dependency_control.h │ │ │ └── mma.h │ │ │ ├── communication │ │ │ └── collective │ │ │ │ └── sm90_allreduce_nvls_warpspecialized.hpp │ │ │ ├── compute_occupancy.h │ │ │ ├── epilogue │ │ │ ├── collective │ │ │ │ └── epilogue_moe_finalize.hpp │ │ │ ├── fusion │ │ │ │ └── sm90_visitor_allreduce_tma_warpspecialized.hpp │ │ │ └── thread │ │ │ │ └── fused_activations.h │ │ │ ├── epilogue_helpers.h │ │ │ ├── gemm │ │ │ ├── kernel │ │ │ │ ├── default_fpA_intB_traits.h │ │ │ │ ├── fused_moe_kernel.cuh │ │ │ │ ├── fused_moe_kernel_routine.cuh │ │ │ │ ├── fused_moe_kernel_traits.cuh │ │ │ │ ├── gemm_moe_problem_visitor.h │ │ │ │ ├── gemm_universal_allreduce.hpp │ │ │ │ ├── mixed_gemm_B_layout.h │ │ │ │ ├── moe_cute_util.cuh │ │ │ │ ├── moe_cutlass_kernel.h │ │ │ │ ├── moe_problem_visitor.h │ │ │ │ ├── sm90_gemm_allreduce_tma_warpspecialized.hpp │ │ │ │ └── sm90_gemm_allreduce_tma_warpspecialized_pingpong.hpp │ │ │ ├── threadblock │ │ │ │ ├── default_dq_mma.h │ │ │ │ ├── default_dq_mma_multistage.h │ │ │ │ ├── default_dq_mma_pipelined.h │ │ │ │ ├── default_mma.h │ │ │ │ ├── default_mma_bf16.h │ │ │ │ ├── dq_mma_base.h │ │ │ │ ├── dq_mma_multistage.h │ │ │ │ ├── dq_mma_multistage_finegrained.h │ │ │ │ ├── dq_mma_multistage_percol.h │ │ │ │ ├── dq_mma_pipelined.h │ │ │ │ ├── dq_mma_pipelined_finegrained.h │ │ │ │ └── dq_mma_pipelined_percol.h │ │ │ └── warp │ │ │ │ ├── default_mma_tensor_op.h │ │ │ │ ├── mma_tensorop_compute_B_with_f16.h │ │ │ │ └── mma_tensorop_dequantizer.h │ │ │ ├── gemm_configs.h │ │ │ ├── interleaved_numeric_conversion.h │ │ │ ├── system_barrier.h │ │ │ ├── tile_interleaved_layout.h │ │ │ ├── transform │ │ │ └── threadblock │ │ │ │ └── fine_grained_scale_zero_iterator.h │ │ │ ├── util │ │ │ └── gather_tensor.hpp │ │ │ └── weight_only_quant_op.h │ │ ├── cutlass_instantiations │ │ └── gemm_grouped │ │ │ ├── cutlass_kernel_file_1.generated.cu │ │ │ ├── cutlass_kernel_file_2.generated.cu │ │ │ ├── cutlass_kernel_file_3.generated.cu │ │ │ ├── cutlass_kernel_file_4.generated.cu │ │ │ ├── cutlass_kernel_file_5.generated.cu │ │ │ ├── cutlass_kernel_file_6.generated.cu │ │ │ └── cutlass_kernel_file_7.generated.cu │ │ ├── kernels │ │ ├── cutlass_kernels │ │ │ ├── cutlass_heuristic.cpp │ │ │ ├── 
cutlass_heuristic.h │ │ │ ├── cutlass_type_conversion.h │ │ │ └── fp8_blockscale_gemm │ │ │ │ └── fp8_blockscale_gemm.h │ │ ├── delayStream.cu │ │ ├── delayStream.h │ │ ├── internal_cutlass_kernels │ │ │ ├── include │ │ │ │ ├── moe_gemm_kernels.h │ │ │ │ └── moe_kernels.h │ │ │ └── src │ │ │ │ └── moe_gemm │ │ │ │ ├── launchers │ │ │ │ ├── fused_moe_gemm_launcher_sm80.h │ │ │ │ ├── fused_moe_gemm_launcher_sm80.inl │ │ │ │ ├── moe_gemm_tma_ws_launcher.h │ │ │ │ ├── moe_gemm_tma_ws_launcher.inl │ │ │ │ ├── moe_gemm_tma_ws_mixed_input_launcher.h │ │ │ │ └── moe_gemm_tma_ws_mixed_input_launcher.inl │ │ │ │ ├── moe_gemm_kernels_bf16_bf16.cu │ │ │ │ ├── moe_gemm_kernels_bf16_fp8.cu │ │ │ │ ├── moe_gemm_kernels_bf16_uint4.cu │ │ │ │ ├── moe_gemm_kernels_bf16_uint8.cu │ │ │ │ ├── moe_gemm_kernels_fp16_fp16.cu │ │ │ │ ├── moe_gemm_kernels_fp16_uint4.cu │ │ │ │ ├── moe_gemm_kernels_fp16_uint8.cu │ │ │ │ ├── moe_gemm_kernels_fp32_fp32.cu │ │ │ │ ├── moe_gemm_kernels_fp4_fp4.cu │ │ │ │ ├── moe_gemm_kernels_fp8_fp8.cu │ │ │ │ ├── moe_gemm_kernels_fp8_uint4.cu │ │ │ │ ├── moe_gemm_template_dispatch.h │ │ │ │ ├── moe_gemm_template_dispatch_tma_ws.h │ │ │ │ ├── moe_gemm_template_dispatch_tma_ws_mixed_dtype.h │ │ │ │ ├── moe_gemm_tma_warp_specialized_input.cu │ │ │ │ └── moe_tma_warp_specialized_traits.h │ │ ├── lora │ │ │ ├── lora.cpp │ │ │ └── lora.h │ │ ├── preQuantScaleKernel.cu │ │ ├── preQuantScaleKernel.h │ │ ├── quantization.cuh │ │ └── quantization.h │ │ ├── runtime │ │ └── torchUtils.h │ │ └── thop │ │ ├── fp4Quantize.cpp │ │ ├── fp4Quantize.h │ │ └── thUtils.h ├── page.cu ├── pod.cu ├── pod_config.inc ├── pod_customize_config.jinja ├── pod_jit_pybind.cu ├── pod_kernel_inst.jinja ├── pytorch_conversion_utils.h ├── pytorch_extension_utils.h ├── quantization.cu ├── renorm.cu ├── rope.cu ├── runtime_utils.h ├── sampling.cu ├── single_decode.cu ├── single_decode_config.inc ├── single_decode_customize_config.jinja ├── single_decode_jit_pybind.cu ├── single_decode_kernel_inst.jinja ├── single_prefill.cu ├── single_prefill_config.inc ├── single_prefill_customize_config.jinja ├── single_prefill_fp8_sm90.cu ├── single_prefill_fp8_sm90_kernel_inst.jinja ├── single_prefill_jit_pybind.cu ├── single_prefill_kernel_inst.jinja ├── single_prefill_sm90.cu ├── single_prefill_sm90_config.inc ├── single_prefill_sm90_customize_config.jinja ├── single_prefill_sm90_jit_pybind.cu ├── single_prefill_sm90_kernel_inst.jinja ├── trtllm_allreduce.cu ├── trtllm_fmha_kernel_launcher.cu └── trtllm_fmha_runner.cu ├── custom_backend.py ├── docker ├── Dockerfile.ci_gpu ├── bash.sh └── install │ ├── install_python.sh │ └── install_python_packages.sh ├── docs ├── .gitignore ├── Makefile ├── _static │ ├── FlashInfer-black-background.png │ └── FlashInfer-white-background.png ├── api │ ├── activation.rst │ ├── cascade.rst │ ├── decode.rst │ ├── gemm.rst │ ├── logits_processor.rst │ ├── mla.rst │ ├── norm.rst │ ├── page.rst │ ├── prefill.rst │ ├── quantization.rst │ ├── rope.rst │ ├── sampling.rst │ └── sparse.rst ├── conf.py ├── index.rst ├── installation.rst ├── make.bat ├── requirements.txt └── tutorials │ ├── kv_layout.rst │ └── recursive_attention.rst ├── flashinfer ├── __init__.py ├── activation.py ├── aot.py ├── autotuner.py ├── cascade.py ├── comm.py ├── decode.py ├── fp4_quantization.py ├── fused_moe.py ├── gemm.py ├── jit │ ├── __init__.py │ ├── activation.py │ ├── attention │ │ ├── __init__.py │ │ ├── pytorch.py │ │ ├── tvm.py │ │ └── utils.py │ ├── core.py │ ├── cpp_ext.py │ ├── cubin_loader.py │ ├── env.py │ └── utils.py 
├── logits_processor │ ├── __init__.py │ ├── compiler.py │ ├── fusion_rules.py │ ├── legalization.py │ ├── op.py │ ├── operators.py │ ├── pipeline.py │ ├── processors.py │ ├── types.py │ └── validators.py ├── mla.py ├── norm.py ├── page.py ├── pod.py ├── prefill.py ├── profiler │ └── __init__.py ├── py.typed ├── quantization.py ├── rope.py ├── sampling.py ├── sparse.py ├── triton │ ├── __init__.py │ ├── activation.py │ ├── cascade.py │ ├── gemm.py │ ├── kernels │ │ ├── __init__.py │ │ ├── activation.py │ │ ├── cascade.py │ │ ├── norm.py │ │ ├── quant.py │ │ └── sm_constraint_gemm.py │ ├── norm.py │ ├── page.py │ ├── sm_constraint_gemm.py │ └── utils.py └── utils.py ├── format.sh ├── include └── flashinfer │ ├── activation.cuh │ ├── allocator.h │ ├── attention │ ├── blackwell │ │ ├── collective │ │ │ ├── fmha_common.hpp │ │ │ ├── fmha_fusion.hpp │ │ │ ├── sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp │ │ │ ├── sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp │ │ │ ├── sm100_fmha_gen_epilogue_warpspecialized.hpp │ │ │ ├── sm100_fmha_gen_mainloop_warpspecialized.hpp │ │ │ ├── sm100_fmha_load_cpasync_warpspecialized.hpp │ │ │ └── sm100_fmha_load_tma_warpspecialized.hpp │ │ ├── common │ │ │ └── pow_2.hpp │ │ ├── device │ │ │ ├── fmha.hpp │ │ │ └── sm100_mla.hpp │ │ ├── fmha_cutlass_sm100.cuh │ │ ├── kernel │ │ │ ├── fmha_options.hpp │ │ │ ├── fmha_tile_scheduler.hpp │ │ │ ├── gather_tensor.hpp │ │ │ ├── sm100_fmha_fwd_kernel_tma_warpspecialized.hpp │ │ │ ├── sm100_fmha_gen_kernel_warpspecialized.hpp │ │ │ ├── sm100_fmha_mla_reduction.hpp │ │ │ ├── sm100_fmha_mla_tma_warpspecialized.hpp │ │ │ └── sm100_mla_tile_scheduler.hpp │ │ └── plan.cuh │ ├── cascade.cuh │ ├── cutlass_mla.cuh │ ├── decode.cuh │ ├── decode_mla_cute_sm80.cuh │ ├── default_decode_params.cuh │ ├── default_prefill_params.cuh │ ├── heap.h │ ├── hopper.cuh │ ├── hopper │ │ ├── attention_updater.cuh │ │ ├── block_sparse_gather.cuh │ │ ├── default_params.cuh │ │ ├── epilogue.cuh │ │ ├── kernel_traits.cuh │ │ ├── mainloop.cuh │ │ ├── mainloop_mma.cuh │ │ ├── named_barrier.cuh │ │ ├── prefill_sm90.cuh │ │ ├── quantization │ │ │ ├── epilogue.cuh │ │ │ ├── kernel_traits.cuh │ │ │ ├── mainloop_load.cuh │ │ │ ├── mainloop_mma.cuh │ │ │ ├── mainloop_sparse_load.cuh │ │ │ └── prefill_sm90.cuh │ │ ├── sparse_mainloop.cuh │ │ ├── tile_scheduler.cuh │ │ ├── utils.cuh │ │ ├── variant_helper.cuh │ │ └── variants.cuh │ ├── mask.cuh │ ├── mla.cuh │ ├── mla_hopper.cuh │ ├── mla_params.cuh │ ├── pod.cuh │ ├── prefill.cuh │ ├── scheduler.cuh │ ├── state.cuh │ ├── variant_helper.cuh │ └── variants.cuh │ ├── attention_impl.cuh │ ├── comm │ ├── custom_all_reduce.cuh │ └── trtllm_allreduce.cuh │ ├── cp_async.cuh │ ├── cubin_loader.h │ ├── cutlass_utils.cuh │ ├── exception.h │ ├── fastdiv.cuh │ ├── fp16.h │ ├── frag_layout_swizzle.cuh │ ├── gemm │ ├── bmm_fp8.cuh │ ├── gemm_groupwise_sm100.cuh │ ├── group_gemm.cuh │ ├── group_gemm_groupwise_sm100.cuh │ ├── group_gemm_lora.cuh │ ├── group_gemm_sm90.cuh │ └── group_gemv.cuh │ ├── layout.cuh │ ├── logging.h │ ├── math.cuh │ ├── mma.cuh │ ├── norm.cuh │ ├── page.cuh │ ├── permuted_smem.cuh │ ├── pos_enc.cuh │ ├── profiler.cuh │ ├── quantization.cuh │ ├── sampling.cuh │ ├── semaphore_utils.cuh │ ├── trtllm │ ├── common.h │ └── fmha │ │ ├── cubin │ │ └── kernelMetaInfo.h │ │ ├── decoder_impl_common.h │ │ ├── decoder_params.h │ │ ├── fmhaKernels.cuh │ │ ├── fmhaRunner.cuh │ │ ├── fmhaRunnerParams.h │ │ ├── gen_kernel_launcher.cuh │ │ └── kernelParams.h │ ├── utils.cuh │ └── vec_dtypes.cuh ├── licenses ├── 
LICENSE.cutlass.txt └── LICENSE.flashattention3.txt ├── profiler ├── .gitignore ├── README.md └── mla.py ├── pyproject.toml ├── scripts ├── ci-flashinfer.env.example ├── ci-flashinfer.service ├── formatter.sh ├── run-ci-build-wheel.sh ├── task_cpplint.sh ├── task_jit_run_tests_part1.sh ├── task_jit_run_tests_part2.sh ├── task_jit_run_tests_part3.sh ├── task_jit_run_tests_part4.sh ├── task_lint.sh ├── task_mypy.sh ├── task_pylint.sh ├── task_show_node_info.sh ├── task_test_aot_build_import.sh └── update_whl_index.py ├── setup.py ├── src ├── bench_batch_decode.cu ├── bench_batch_decode_mla.cu ├── bench_batch_prefill.cu ├── bench_cascade.cu ├── bench_norm.cu ├── bench_sampling.cu ├── bench_single_decode.cu ├── bench_single_prefill.cu ├── cpu_reference.h ├── flashinfer_ops.cuh ├── test_batch_decode.cu ├── test_batch_prefill.cu ├── test_cascade.cu ├── test_fast_dequant.cu ├── test_fastdiv.cu ├── test_norm.cu ├── test_page.cu ├── test_sampling.cu ├── test_single_decode.cu ├── test_single_prefill.cu └── utils.h ├── tests ├── alibi_reference.py ├── conftest.py ├── jit_utils.py ├── rope_reference.py ├── test_activation.py ├── test_alibi.py ├── test_batch_decode_kernels.py ├── test_batch_prefill_kernels.py ├── test_blackwell_fmha.py ├── test_block_sparse.py ├── test_block_sparse_indices_to_vector_sparse_offsets.py ├── test_bmm_fp8.py ├── test_custom_allreduce.py ├── test_decode_fp8_calibration_scale.py ├── test_decode_prefill_lse.py ├── test_deepseek_mla.py ├── test_fp4_quantize.py ├── test_fp8_prefill.py ├── test_group_gemm.py ├── test_groupwise_scaled_gemm_fp8.py ├── test_hopper.py ├── test_hopper_fp8_attention.py ├── test_jit_example.py ├── test_jit_warmup.py ├── test_logits_cap.py ├── test_logits_processor.py ├── test_mla_decode_kernel.py ├── test_mla_page.py ├── test_non_contiguous_decode.py ├── test_non_contiguous_prefill.py ├── test_norm.py ├── test_page.py ├── test_pod_kernels.py ├── test_quantization.py ├── test_rope.py ├── test_sampling.py ├── test_shared_prefix_kernels.py ├── test_sliding_window.py ├── test_sm_constraint_gemm.py ├── test_tensor_cores_decode.py ├── test_triton_cascade.py ├── test_trtllm_allreduce.py ├── test_trtllm_cutlass_fused_moe.py └── test_trtllm_gen_decode.py ├── tvm_binding ├── batch_decode.cu ├── batch_decode_customize_config.jinja ├── batch_decode_jit_tvm_binding.cu ├── batch_mla_config.jinja ├── batch_mla_jit_tvm_binding.cu ├── batch_mla_plan.cu ├── batch_mla_run.cu ├── batch_prefill.cu ├── batch_prefill_customize_config.jinja ├── batch_prefill_jit_tvm_binding.cu ├── batch_prefill_sm90.cu ├── batch_prefill_sm90_customize_config.jinja ├── batch_prefill_sm90_jit_tvm_binding.cu ├── sampling.cu ├── sampling_jit_tvm_binding.cu └── tvm_binding_utils.h └── version.txt /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | BasedOnStyle: Google 3 | DerivePointerAlignment: false 4 | ColumnLimit: 100 5 | PointerAlignment: Left 6 | # InsertNewlineAtEOF: true 7 | ... 8 | -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "CUDA Development Container", 3 | "build": { 4 | "dockerfile": "Dockerfile", 5 | "context": "." 
6 | }, 7 | "runArgs": [ 8 | "--gpus=all" 9 | ], 10 | "customizations": { 11 | "vscode": { 12 | "extensions": [ 13 | "llvm-vs-code-extensions.vscode-clangd", 14 | "ms-python.python", 15 | "ms-python.black-formatter", 16 | "nvidia.nsight-vscode-edition" 17 | ] 18 | } 19 | }, 20 | "mounts": [ 21 | "type=bind,source=${localEnv:HOME}/.ssh,target=/home/devuser/.ssh,readonly" 22 | ], 23 | "remoteUser": "devuser" 24 | } 25 | -------------------------------------------------------------------------------- /.devcontainer/install/install_python.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Licensed to the Apache Software Foundation (ASF) under one 3 | # or more contributor license agreements. See the NOTICE file 4 | # distributed with this work for additional information 5 | # regarding copyright ownership. The ASF licenses this file 6 | # to you under the Apache License, Version 2.0 (the 7 | # "License"); you may not use this file except in compliance 8 | # with the License. You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, 13 | # software distributed under the License is distributed on an 14 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | # KIND, either express or implied. See the License for the 16 | # specific language governing permissions and limitations 17 | # under the License. 18 | 19 | set -e 20 | set -u 21 | set -o pipefail 22 | 23 | 24 | # Install python and pip. Don't modify this to add Python package dependencies, 25 | wget -O Miniforge3.sh "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh" 26 | bash Miniforge3.sh -b -p /home/devuser/conda 27 | 28 | /home/devuser/conda/bin/conda create -n $1 python=3.12 29 | -------------------------------------------------------------------------------- /.devcontainer/install/install_python_packages.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Licensed to the Apache Software Foundation (ASF) under one 3 | # or more contributor license agreements. See the NOTICE file 4 | # distributed with this work for additional information 5 | # regarding copyright ownership. The ASF licenses this file 6 | # to you under the Apache License, Version 2.0 (the 7 | # "License"); you may not use this file except in compliance 8 | # with the License. You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, 13 | # software distributed under the License is distributed on an 14 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | # KIND, either express or implied. See the License for the 16 | # specific language governing permissions and limitations 17 | # under the License. 
18 | 19 | set -e 20 | set -u 21 | 22 | pip3 install ninja pytest numpy scipy build cuda-python 23 | pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 24 | pip3 install pre-commit 25 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## 📌 Description 4 | 5 | 6 | 7 | ## 🔍 Related Issues 8 | 9 | 10 | 11 | ## 🚀 Pull Request Checklist 12 | 13 | Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. 14 | 15 | ### ✅ Pre-commit Checks 16 | 17 | - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or your preferred method). 18 | - [ ] I have installed the hooks with `pre-commit install`. 19 | - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. 20 | 21 | > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). 22 | 23 | ## 🧪 Tests 24 | 25 | - [ ] Tests have been added or updated as needed. 26 | - [ ] All tests are passing (`unittest`, etc.). 27 | 28 | ## Reviewer Notes 29 | 30 | 31 | -------------------------------------------------------------------------------- /.github/workflows/build-doc.yml: -------------------------------------------------------------------------------- 1 | name: Build FlashInfer Docs 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | # Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages 9 | permissions: 10 | contents: read 11 | pages: write 12 | id-token: write 13 | 14 | # Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued. 15 | # However, do NOT cancel in-progress runs as we want to allow these production deployments to complete.
16 | concurrency: 17 | group: "pages" 18 | cancel-in-progress: false 19 | 20 | jobs: 21 | test_linux: 22 | name: Deploy Docs 23 | runs-on: ubuntu-latest 24 | 25 | steps: 26 | - uses: actions/checkout@v2 27 | with: 28 | submodules: recursive 29 | 30 | - name: Configuring build Environment 31 | run: | 32 | sudo apt-get update 33 | python -m pip install -U pip wheel 34 | 35 | - name: Installing dependencies 36 | run: | 37 | python -m pip install -r docs/requirements.txt 38 | 39 | - name: Build Documentation 40 | if: github.ref == 'refs/heads/main' 41 | run: | 42 | cd docs 43 | make html 44 | 45 | - name: Upload artifact 46 | uses: actions/upload-pages-artifact@v3 47 | with: 48 | # Upload the built documentation 49 | path: 'docs/_build/html' 50 | 51 | - name: Deploy to GitHub Pages 52 | id: deployment 53 | uses: actions/deploy-pages@v4 54 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [main] 7 | 8 | permissions: read-all 9 | 10 | jobs: 11 | pre-commit: 12 | runs-on: ubuntu-latest 13 | timeout-minutes: 30 14 | steps: 15 | - uses: actions/checkout@v4.2.2 16 | - uses: actions/setup-python@v5 17 | with: 18 | python-version: '3.11' 19 | - uses: pre-commit/action@v3.0.1 20 | -------------------------------------------------------------------------------- /.github/workflows/release-ci-docker.yml: -------------------------------------------------------------------------------- 1 | name: Release CI Docker 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | jobs: 7 | build: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v4 11 | - name: Login to Docker Hub 12 | uses: docker/login-action@v3 13 | with: 14 | username: flashinfer 15 | password: ${{ secrets.DOCKERHUB_TOKEN }} 16 | 17 | - uses: docker/build-push-action@v4 18 | with: 19 | context: docker 20 | file: docker/Dockerfile.ci_gpu 21 | push: true 22 | tags: flashinfer/flashinfer-ci:latest 23 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "3rdparty/nvbench"] 2 | path = 3rdparty/nvbench 3 | url = https://github.com/NVIDIA/nvbench.git 4 | [submodule "3rdparty/googletest"] 5 | path = 3rdparty/googletest 6 | url = https://github.com/google/googletest.git 7 | [submodule "3rdparty/cutlass"] 8 | path = 3rdparty/cutlass 9 | url = https://github.com/NVIDIA/cutlass.git 10 | [submodule "3rdparty/composable_kernels"] 11 | path = 3rdparty/composable_kernels 12 | url = https://github.com/ROCm/composable_kernel.git 13 | [submodule "3rdparty/spdlog"] 14 | path = 3rdparty/spdlog 15 | url = https://github.com/gabime/spdlog.git 16 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # To use: 2 | # 3 | # pre-commit run -a 4 | # 5 | # Or: 6 | # 7 | # pre-commit install # (runs every time you commit in git) 8 | # 9 | # To update this file: 10 | # 11 | # pre-commit autoupdate 12 | # 13 | # See https://github.com/pre-commit/pre-commit 14 | # Note the pre-commit hooks should only be used for formatting, but not for linting. 15 | # For linting consider using CI.
16 | repos: 17 | # Standard hooks 18 | - repo: https://github.com/pre-commit/pre-commit-hooks 19 | rev: v5.0.0 20 | hooks: 21 | - id: check-added-large-files 22 | - id: check-case-conflict 23 | - id: check-merge-conflict 24 | - id: check-symlinks 25 | - id: end-of-file-fixer 26 | - id: mixed-line-ending 27 | - id: requirements-txt-fixer 28 | - id: trailing-whitespace 29 | 30 | # Changes tabs to spaces 31 | - repo: https://github.com/Lucas-C/pre-commit-hooks 32 | rev: v1.5.5 33 | hooks: 34 | - id: remove-tabs 35 | - id: remove-crlf 36 | 37 | # Formatters 38 | - repo: https://github.com/psf/black-pre-commit-mirror 39 | rev: 24.8.0 40 | hooks: 41 | - id: black 42 | 43 | - repo: https://github.com/pycqa/isort 44 | rev: 5.13.2 45 | hooks: 46 | - id: isort 47 | args: ["--profile=black"] # <-- this one 48 | 49 | - repo: https://github.com/pre-commit/mirrors-clang-format 50 | rev: v19.1.1 51 | hooks: 52 | - id: clang-format 53 | types_or: [c++, c, cuda] 54 | exclude: | 55 | (?x)^(3rdparty/.* src/generated/.* flashinfer/jit/aot_config.py)$ 56 | 57 | - repo: https://github.com/cheshirekow/cmake-format-precommit 58 | rev: v0.6.13 59 | hooks: 60 | - id: cmake-format 61 | additional_dependencies: [pyyaml>=5.1] 62 | -------------------------------------------------------------------------------- /aot_build_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flashinfer-ai/flashinfer/4e8bb778f1522de6becbcf4b732a22d48f7d72b0/aot_build_utils/__init__.py -------------------------------------------------------------------------------- /aot_build_utils/literal_map.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2024 by FlashInfer team. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | 17 | mask_mode_literal = { 18 | 0: "MaskMode::kNone", 19 | 1: "MaskMode::kCausal", 20 | 2: "MaskMode::kCustom", 21 | 3: "MaskMode::kMultiItemScoring", 22 | } 23 | 24 | pos_encoding_mode_literal = { 25 | 0: "PosEncodingMode::kNone", 26 | 1: "PosEncodingMode::kRoPELlama", 27 | 2: "PosEncodingMode::kALiBi", 28 | } 29 | 30 | dtype_literal = { 31 | "f16": "half", 32 | "bf16": "nv_bfloat16", 33 | "f32": "float", 34 | "e4m3": "__nv_fp8_e4m3", 35 | "e5m2": "__nv_fp8_e5m2", 36 | } 37 | 38 | idtype_literal = { 39 | "i32": "int32_t", 40 | "u32": "uint32_t", 41 | "i64": "int64_t", 42 | "u64": "uint64_t", 43 | } 44 | 45 | bool_literal = { 46 | 0: "false", 47 | 1: "true", 48 | } 49 | -------------------------------------------------------------------------------- /benchmarks/bench_hopper_fp8_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | 4 | import flashinfer 5 | 6 | 7 | def bench_single_prefill(seq_len, num_heads, causal, head_dim): 8 | num_qo_heads = num_kv_heads = num_heads 9 | q = torch.randn(seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda") 10 | k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda") 11 | v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda") 12 | 13 | sm80_ms, sm90_ms = ( 14 | triton.testing.do_bench( 15 | lambda: flashinfer.single_prefill_with_kv_cache_return_lse( 16 | q, k, v, causal=causal, backend=backend 17 | ), 18 | warmup=100, 19 | rep=1000, 20 | ) 21 | for backend in ["fa2", "fa3"] 22 | ) 23 | 24 | q = torch.randn( 25 | seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda" 26 | ).to(dtype=torch.float8_e4m3fn) 27 | k = torch.randn( 28 | seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda" 29 | ).to(dtype=torch.float8_e4m3fn) 30 | v = torch.randn( 31 | seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda" 32 | ).to(dtype=torch.float8_e4m3fn) 33 | 34 | fp8_sm90_ms = triton.testing.do_bench( 35 | lambda: flashinfer.single_prefill_with_kv_cache_return_lse( 36 | q, k, v, causal=causal, backend="fa3", o_dtype=torch.half 37 | ), 38 | warmup=100, 39 | rep=1000, 40 | ) 41 | 42 | def flops(ms): 43 | if causal: 44 | return seq_len * seq_len * num_qo_heads * head_dim * 2 / ms / 1e9 45 | else: 46 | return seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1e9 47 | 48 | print( 49 | f"bench_single_prefill (seq_len={seq_len}, num_heads={num_heads}, causal={causal}, head_dim={head_dim}), fa2-template: {flops(sm80_ms):.3f} TFLOPs/s, fa3-template: {flops(sm90_ms):.3f} TFLOPs/s, fa3-fp8: {flops(fp8_sm90_ms):.3f} TFLOPs/s" 50 | ) 51 | 52 | 53 | if __name__ == "__main__": 54 | for seq_len in [4096, 8192, 16384]: 55 | for num_heads in [24, 32]: 56 | for causal in [True, False]: 57 | for head_dim in [64, 128, 256]: 58 | bench_single_prefill(seq_len, num_heads, causal, head_dim) 59 | -------------------------------------------------------------------------------- /benchmarks/bench_pad_ragged_tensor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from triton.testing import do_bench 3 | 4 | from flashinfer.triton import pad_ragged_tensor_to_multiple_of 5 | 6 | 7 | def bench_pad_ragged_tensor_to_multiple_of(batch_size, qkv_len, d, multiple_of): 8 | device = torch.device("cuda:0") 9 | torch.manual_seed(42) 10 | 11 | indptr = torch.arange(0, (batch_size + 1) * qkv_len, qkv_len, device=device) 12 | ragged_tensor = torch.randn((indptr[-1], d), 
device=device) 13 | 14 | ms = do_bench( 15 | lambda: pad_ragged_tensor_to_multiple_of(ragged_tensor, indptr, multiple_of) 16 | ) 17 | mem_bandwidth_gb_s = ( 18 | 2 * ragged_tensor.numel() * ragged_tensor.element_size() / ms * 1e-6 19 | ) 20 | 21 | print( 22 | f"batch_size={batch_size}, qkv_len={qkv_len}, d={d}, multiple_of={multiple_of}, ms={ms}, mem_bandwidth={mem_bandwidth_gb_s} GB/s" 23 | ) 24 | 25 | 26 | if __name__ == "__main__": 27 | for batch_size in [11, 47, 101]: 28 | for qkv_len in [500, 1017, 8011]: 29 | for d in [2048, 4096, 16384]: 30 | for multiple_of in [128]: 31 | bench_pad_ragged_tensor_to_multiple_of( 32 | batch_size, qkv_len, d, multiple_of 33 | ) 34 | -------------------------------------------------------------------------------- /ci/scripts/jenkins/retry.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | 20 | set -eux 21 | 22 | retry() { 23 | local max_retries=$1 24 | shift 25 | local n=0 26 | until [ "$n" -ge "$max_retries" ] 27 | do 28 | "$@" && break 29 | n=$((n+1)) 30 | if [ "$n" -eq "$max_retries" ]; then 31 | echo "failed to update after attempt $n / $max_retries, giving up" 32 | exit 1 33 | fi 34 | 35 | WAIT=$(python3 -c 'import random; print(random.randint(30, 200))') 36 | echo "failed to update $n / $max_retries, waiting $WAIT to try again" 37 | sleep "$WAIT" 38 | done 39 | } 40 | -------------------------------------------------------------------------------- /cmake/config.cmake: -------------------------------------------------------------------------------- 1 | # Whether to compile fp8 kernels or not. 2 | set(FLASHINFER_ENABLE_FP8_E4M3 ON) 3 | set(FLASHINFER_ENABLE_FP8_E5M2 ON) 4 | # Whether to compile bf16 kernels or not. 5 | set(FLASHINFER_ENABLE_BF16 ON) 6 | # Whether to compile prefill kernel tests/benchmarks or not. 7 | set(FLASHINFER_PREFILL ON) 8 | # Whether to compile decode kernel tests/benchmarks or not. 9 | set(FLASHINFER_DECODE ON) 10 | # Whether to compile page kernel tests/benchmarks or not. 11 | set(FLASHINFER_PAGE ON) 12 | # Whether to compile cascade kernel tests/benchmarks or not. 13 | set(FLASHINFER_CASCADE ON) 14 | # Whether to compile sampling kernel tests/benchmarks or not. 15 | set(FLASHINFER_SAMPLING ON) 16 | # Whether to compile normalization kernel tests/benchmarks or not. 
17 | set(FLASHINFER_NORM ON) 18 | # Whether to compile fastdiv tests 19 | set(FLASHINFER_FASTDIV_TEST ON) 20 | # Whether to compile fastdequant tests 21 | set(FLASHINFER_FASTDEQUANT_TEST ON) 22 | # The following configurations can impact the binary size of the generated 23 | # library 24 | set(FLASHINFER_GEN_HEAD_DIMS 64 128 256 512) 25 | set(FLASHINFER_GEN_KV_LAYOUTS 0 1) 26 | set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1 2) 27 | set(FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS "false" "true") 28 | set(FLASHINFER_GEN_MASK_MODES 0 1 2) 29 | 30 | # Set target cuda architectures for tests/benchmarks; defaults to native. 31 | # "native" is a special value for CMAKE_CUDA_ARCHITECTURES which means use the 32 | # architectures of the host's GPU. It's new in CMake 3.24; if you are using an 33 | # older version of CMake or you want to use a different value, you can set it 34 | # here. Supported CUDA architectures include 80;86;89;90 35 | # NOTE(Zihao): using "native" might be slow because whenever a cuda file is compiled 36 | # with `-arch=native`, nvcc will spawn a `__nvcc_device_query` process to get 37 | # the architecture of the host's GPU, which could stall the compilation process. 38 | # So it's recommended to set it to a specific value if you know the architecture 39 | # of the target GPU. Example: set(FLASHINFER_CUDA_ARCHITECTURES 80) 40 | set(FLASHINFER_CUDA_ARCHITECTURES native) 41 | -------------------------------------------------------------------------------- /cmake/utils/Utils.cmake: -------------------------------------------------------------------------------- 1 | macro(__flashinfer_option variable description value) 2 | if(NOT DEFINED ${variable}) 3 | set(${variable} 4 | ${value} 5 | CACHE STRING ${description}) 6 | endif() 7 | endmacro() 8 | 9 | macro(flashinfer_list_option variable description value) 10 | __flashinfer_option(${variable} "${description}" "${value}") 11 | endmacro() 12 | 13 | set(FLASHINFER_ALL_OPTIONS) 14 | 15 | # ############################################################################## 16 | # An option that the user can select. Can accept a condition to control when the 17 | # option is available to the user. Usage: flashinfer_option(<option_variable> "doc string" 18 | # <initial value or boolean expression> [IF <condition>]) The macro snippet is 19 | # copied from the Apache TVM codebase.
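# Example usage (a hypothetical sketch; the option names below are illustrative
# only and do not appear in the project): a plain option, and one that is only
# defined when a condition holds, as it would be called from cmake/config.cmake:
#
#   flashinfer_option(FLASHINFER_EXAMPLE_FEATURE "Enable an example feature" ON)
#   flashinfer_option(FLASHINFER_EXAMPLE_DEPENDENT
#                     "Only defined when the example feature is enabled" OFF
#                     IF FLASHINFER_EXAMPLE_FEATURE)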
20 | macro(flashinfer_option variable description value) 21 | set(__value ${value}) 22 | set(__condition "") 23 | set(__varname "__value") 24 | list(APPEND FLASHINFER_ALL_OPTIONS ${variable}) 25 | foreach(arg ${ARGN}) 26 | if(arg STREQUAL "IF" OR arg STREQUAL "if") 27 | set(__varname "__condition") 28 | else() 29 | list(APPEND ${__varname} ${arg}) 30 | endif() 31 | endforeach() 32 | unset(__varname) 33 | if("${__condition}" STREQUAL "") 34 | set(__condition 2 GREATER 1) 35 | endif() 36 | 37 | if(${__condition}) 38 | if("${__value}" MATCHES ";") 39 | # list values directly pass through 40 | __flashinfer_option(${variable} "${description}" "${__value}") 41 | elseif(DEFINED ${__value}) 42 | if(${__value}) 43 | __flashinfer_option(${variable} "${description}" ON) 44 | else() 45 | __flashinfer_option(${variable} "${description}" OFF) 46 | endif() 47 | else() 48 | __flashinfer_option(${variable} "${description}" "${__value}") 49 | endif() 50 | else() 51 | unset(${variable} CACHE) 52 | endif() 53 | endmacro() 54 | -------------------------------------------------------------------------------- /csrc/batch_decode_customize_config.jinja: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }} 9 | #define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }} 10 | 11 | #define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, Params, ...) { \ 12 | using AttentionVariant = {{ variant_name }}; \ 13 | __VA_ARGS__(); \ 14 | } 15 | 16 | using namespace flashinfer; 17 | 18 | using DTypeQ = {{ dtype_q }}; 19 | using DTypeKV = {{ dtype_kv }}; 20 | using DTypeO = {{ dtype_o }}; 21 | using IdType = {{ idtype }}; 22 | constexpr int HEAD_DIM_QK = {{ head_dim_qk }}; 23 | constexpr int HEAD_DIM_VO = {{ head_dim_vo }}; 24 | constexpr auto USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }}; 25 | constexpr auto POS_ENCODING_MODE = {{ pos_encoding_mode }}; 26 | constexpr auto USE_SLIDING_WINDOW = {{ use_sliding_window }}; 27 | 28 | struct Params { 29 | using DTypeQ = DTypeQ; 30 | using DTypeKV = DTypeKV; 31 | using DTypeO = DTypeO; 32 | using IdType = IdType; 33 | 34 | DTypeQ* q; 35 | paged_kv_t paged_kv; 36 | DTypeO* o; 37 | float* lse; 38 | 39 | {{ additional_params_decl }} 40 | 41 | uint32_t padded_batch_size; 42 | uint32_t num_qo_heads; 43 | IdType q_stride_n; 44 | IdType q_stride_h; 45 | int32_t window_left; 46 | 47 | IdType* request_indices; 48 | IdType* kv_tile_indices; 49 | IdType* o_indptr; 50 | IdType* kv_chunk_size_ptr; 51 | bool* block_valid_mask; 52 | bool partition_kv; 53 | 54 | __host__ __device__ __forceinline__ int32_t get_qo_len(int32_t batch_idx) const { return 1; } 55 | 56 | __host__ __device__ __forceinline__ int32_t get_kv_len(int32_t batch_idx) const { 57 | return paged_kv.get_length(batch_idx); 58 | } 59 | }; 60 | 61 | {{ variant_decl }} 62 | -------------------------------------------------------------------------------- /csrc/batch_decode_jit_pybind.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2023-2025 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #include "batch_decode_config.inc" 17 | #include "pytorch_extension_utils.h" 18 | 19 | at::Tensor BatchDecodeWithPagedKVCachePlan( 20 | at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, 21 | at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, int64_t batch_size, 22 | int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, 23 | int64_t window_left, double logits_soft_cap, int64_t head_dim_qk, int64_t head_dim_vo, 24 | at::Tensor empty_q_data, at::Tensor empty_kv_data); 25 | 26 | void BatchDecodeWithPagedKVCacheRun(at::Tensor float_workspace_buffer, 27 | at::Tensor int_workspace_buffer, at::Tensor plan_info_vec, 28 | at::Tensor q, at::Tensor paged_k_cache, 29 | at::Tensor paged_v_cache, at::Tensor paged_kv_indptr, 30 | at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, 31 | at::Tensor o, std::optional maybe_lse, 32 | int64_t kv_layout_code, 33 | int64_t window_left ADDITIONAL_FUNC_PARAMS); 34 | 35 | TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { 36 | // Batched decode with paged KV-Cache plan 37 | m.def("plan", BatchDecodeWithPagedKVCachePlan); 38 | // Batched decode with paged KV-Cache run 39 | m.def("run", BatchDecodeWithPagedKVCacheRun); 40 | } 41 | -------------------------------------------------------------------------------- /csrc/batch_decode_kernel_inst.jinja: -------------------------------------------------------------------------------- 1 | #include 2 | #include "batch_decode_config.inc" 3 | 4 | using namespace flashinfer; 5 | 6 | namespace flashinfer { 7 | 8 | template cudaError_t 9 | BatchDecodeWithPagedKVCacheDispatched<{{ head_dim_qk }}, {{ pos_encoding_mode }}, {{ variant_name }}, Params>( 10 | Params params, {{ dtype_o }}* tmp_v, 11 | float* tmp_s, cudaStream_t stream); 12 | 13 | }; 14 | -------------------------------------------------------------------------------- /csrc/batch_decode_mla_config.jinja: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | using namespace flashinfer; 6 | 7 | using DTypeQ = {{ dtype_q }}; 8 | using DTypeKV = {{ dtype_kv }}; 9 | using DTypeO = {{ dtype_o }}; 10 | using IdType = {{ dtype_idx }}; 11 | 12 | constexpr bool USE_SLIDING_WINDOW = {{ use_sliding_window }}; 13 | constexpr bool USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }}; 14 | constexpr int HEAD_DIM_CKV = {{ head_dim_ckv }}; 15 | constexpr int HEAD_DIM_KPE = {{ head_dim_kpe }}; 16 | 17 | constexpr int QO_TILE_LEN = {{ qo_tile_len }}; 18 | 19 | using Params = BatchDecodeParamsMLA; 20 | using AttentionVariant = 21 | DefaultAttention; 22 | -------------------------------------------------------------------------------- /csrc/batch_decode_mla_plan.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "mla_config.inc" 6 | #include "pytorch_conversion_utils.h" 7 | #include "pytorch_extension_utils.h" 8 | 9 | using namespace flashinfer; 10 | 11 | at::Tensor 
BatchDecodeWithPagedKVCachePlanMLA(at::Tensor float_workspace_buffer, 12 | at::Tensor int_workspace_buffer, 13 | at::Tensor page_locked_int_workspace_buffer, 14 | at::Tensor indptr, int64_t batch_size, 15 | int64_t num_qo_heads, int64_t page_size, 16 | bool enable_cuda_graph, int64_t cuda_stream) { 17 | size_t float_workspace_size_in_bytes = 18 | float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); 19 | size_t int_workspace_size_in_bytes = 20 | int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); 21 | 22 | DecodePlanInfo plan_info; 23 | cudaStream_t stream = reinterpret_cast(cuda_stream); 24 | 25 | auto work_estimation_func = 26 | BatchDecodeWithPagedKVCacheWorkEstimationDispatchedMLA; 28 | cudaError_t status = 29 | DecodePlan( 30 | static_cast(float_workspace_buffer.data_ptr()), float_workspace_size_in_bytes, 31 | static_cast(int_workspace_buffer.data_ptr()), 32 | static_cast(page_locked_int_workspace_buffer.data_ptr()), 33 | int_workspace_size_in_bytes, plan_info, static_cast(indptr.data_ptr()), 34 | batch_size, num_qo_heads, page_size, enable_cuda_graph, /*stream=*/stream, 35 | work_estimation_func); 36 | 37 | TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCachePlanMLA failed with error ", 38 | cudaGetErrorString(status)); 39 | 40 | return vec_to_tensor(plan_info.ToVector()); 41 | } 42 | -------------------------------------------------------------------------------- /csrc/batch_decode_mla_pybind.cu: -------------------------------------------------------------------------------- 1 | #include "mla_config.inc" 2 | #include "pytorch_extension_utils.h" 3 | 4 | at::Tensor BatchDecodeWithPagedKVCachePlanMLA(at::Tensor float_workspace_buffer, 5 | at::Tensor int_workspace_buffer, 6 | at::Tensor page_locked_int_workspace_buffer, 7 | at::Tensor indptr, int64_t batch_size, 8 | int64_t num_qo_heads, int64_t page_size, 9 | bool enable_cuda_graph, int64_t cuda_stream); 10 | 11 | void BatchDecodeWithPagedKVCacheRunMLA( 12 | at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor plan_info_vec, 13 | at::Tensor q_nope, at::Tensor q_pe, at::Tensor paged_ckv_cache, at::Tensor paged_kpe_cache, 14 | at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, 15 | at::Tensor o, double sm_scale, int64_t window_left, double logits_soft_cap, double rope_scale, 16 | double rope_theta, std::optional maybe_lse, int64_t cuda_stream); 17 | 18 | TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { 19 | m.def("plan", BatchDecodeWithPagedKVCachePlanMLA); 20 | m.def("run", BatchDecodeWithPagedKVCacheRunMLA); 21 | } 22 | -------------------------------------------------------------------------------- /csrc/batch_mla_config.jinja: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | using namespace flashinfer; 12 | 13 | #ifdef FLASHINFER_ENABLE_PROFILER 14 | #define ADDITIONAL_FUNC_PARAMS , at::Tensor profiler_buffer 15 | #define ADDITIONAL_PARAMS_SETTER \ 16 | params.profiler_buffer = static_cast(profiler_buffer.data_ptr()); 17 | #else 18 | #define ADDITIONAL_FUNC_PARAMS 19 | #define ADDITIONAL_PARAMS_SETTER 20 | #endif 21 | 22 | using DTypeQ = {{ dtype_q }}; 23 | using DTypeKV = {{ dtype_kv }}; 24 | using DTypeO = {{ dtype_o }}; 25 | using IdType = {{ dtype_idx }}; 26 | constexpr int HEAD_DIM_CKV = {{ head_dim_ckv }}; 27 | constexpr int HEAD_DIM_KPE = {{ head_dim_kpe 
}}; 28 | 29 | #define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, ...) \ 30 | DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { \ 31 | using Params = MLAParams; \ 32 | __VA_ARGS__(); \ 33 | }) 34 | -------------------------------------------------------------------------------- /csrc/batch_mla_pybind.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2025 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #include "batch_mla_config.inc" 17 | #include "pytorch_extension_utils.h" 18 | 19 | at::Tensor BatchMLAPagedAttentionPlan(at::Tensor float_workspace_buffer, 20 | at::Tensor int_workspace_buffer, 21 | at::Tensor page_locked_int_workspace_buffer, 22 | at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len, 23 | int64_t num_heads, int64_t head_dim_o, bool causal); 24 | 25 | void BatchMLAPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, 26 | at::Tensor plan_info_vec, at::Tensor q_nope, at::Tensor q_pe, 27 | at::Tensor ckv_cache, at::Tensor kpe_cache, at::Tensor kv_indices, 28 | at::Tensor o, std::optional maybe_lse, 29 | int64_t mask_mode_code, int64_t num_heads, int64_t page_size, 30 | double sm_scale); 31 | 32 | TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { 33 | m.def("plan", &BatchMLAPagedAttentionPlan); 34 | m.def("run", &BatchMLAPagedAttentionRun); 35 | } 36 | -------------------------------------------------------------------------------- /csrc/batch_mla_sm90_pybind.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2025 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | #include "batch_mla_sm90_config.inc" 17 | #include "pytorch_extension_utils.h" 18 | 19 | at::Tensor BatchMLAPagedAttentionSM90Plan(at::Tensor float_workspace_buffer, 20 | at::Tensor int_workspace_buffer, 21 | at::Tensor page_locked_int_workspace_buffer, 22 | at::Tensor qo_indptr, at::Tensor kv_indptr, 23 | at::Tensor kv_len, int64_t num_heads, int64_t head_dim_o, 24 | bool causal); 25 | 26 | void BatchMLAPagedAttentionSM90Run(at::Tensor float_workspace_buffer, 27 | at::Tensor int_workspace_buffer, at::Tensor plan_info_vec, 28 | at::Tensor q_nope, at::Tensor q_pe, at::Tensor ckv_cache, 29 | at::Tensor kpe_cache, at::Tensor kv_indices, at::Tensor o, 30 | std::optional maybe_lse, int64_t mask_mode_code, 31 | int64_t num_heads, int64_t page_size, 32 | double sm_scale ADDITIONAL_FUNC_PARAMS); 33 | 34 | TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { 35 | m.def("plan", &BatchMLAPagedAttentionSM90Plan); 36 | m.def("run", &BatchMLAPagedAttentionSM90Run); 37 | } 38 | -------------------------------------------------------------------------------- /csrc/batch_prefill_fp8_paged_sm90_kernel_inst.jinja: -------------------------------------------------------------------------------- 1 | #include 2 | #include "batch_prefill_sm90_config.inc" 3 | 4 | namespace flashinfer { 5 | 6 | {% for same_scheduler_for_all_heads in ["true", "false"] %} 7 | template cudaError_t BatchFP8PrefillWithPagedKVCacheDispatched 8 | <{{ head_dim_qk }}, 9 | {{ mask_mode }}, 10 | /*USE_SLIDING_WINDOW=*/{{ use_sliding_window }}, 11 | /*SAME_SCHEDULER_FOR_ALL_HEADS=*/{{ same_scheduler_for_all_heads }}, 12 | {{ variant_name }}, PagedParams>(PagedParams& params, cudaStream_t stream); 13 | {% endfor %} 14 | 15 | }; // namespace flashinfer 16 | -------------------------------------------------------------------------------- /csrc/batch_prefill_fp8_ragged_sm90_kernel_inst.jinja: -------------------------------------------------------------------------------- 1 | // TODO: Not implemented yet 2 | -------------------------------------------------------------------------------- /csrc/batch_prefill_paged_kernel_inst.jinja: -------------------------------------------------------------------------------- 1 | #include 2 | #include "batch_prefill_config.inc" 3 | 4 | namespace flashinfer { 5 | 6 | constexpr auto use_custom_mask = {{ mask_mode }} == MaskMode::kCustom; 7 | 8 | {% for cta_tile_q in [16, 64, 128] %} 9 | template cudaError_t BatchPrefillWithPagedKVCacheDispatched< 10 | /*CTA_TILE_Q=*/{{cta_tile_q}}, {{head_dim_qk}}, {{head_dim_vo}}, {{pos_encoding_mode}}, {{use_fp16_qk_reduction}}, {{mask_mode}}, 11 | {{ variant_name }}, PagedParams>(PagedParams params, {{ dtype_o }}* tmp_v, float* tmp_s, cudaStream_t stream); 12 | {% endfor %} 13 | 14 | }; // namespace flashinfer 15 | -------------------------------------------------------------------------------- /csrc/batch_prefill_paged_sm90_kernel_inst.jinja: -------------------------------------------------------------------------------- 1 | #include 2 | #include "batch_prefill_sm90_config.inc" 3 | 4 | namespace flashinfer { 5 | 6 | {% for same_scheduler_for_all_heads in ["true", "false"] %} 7 | template cudaError_t BatchPrefillWithPagedKVCacheDispatched 8 | <{{ head_dim_qk }}, 9 | {{ head_dim_vo }}, 10 | {{ mask_mode }}, 11 | /*USE_SLIDING_WINDOW=*/{{ use_sliding_window }}, 12 | /*SAME_SCHEDULER_FOR_ALL_HEADS=*/{{ same_scheduler_for_all_heads }}, 13 | {{ variant_name }}, PagedParams>(PagedParams& params, cudaStream_t stream); 14 | {% endfor %} 15 | 16 | }; // namespace flashinfer 
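The *_kernel_inst.jinja files above are templates for the generated C++ kernel-instantiation sources: placeholders such as {{ head_dim_qk }}, {{ mask_mode }} and {{ variant_name }} are substituted when the build renders them. A minimal rendering sketch with jinja2 follows; the parameter values are illustrative assumptions, not the project's actual generator code.

import jinja2

# Read one of the instantiation templates shown above.
template_src = open("csrc/batch_prefill_paged_sm90_kernel_inst.jinja").read()

# Hypothetical parameter values; the real ones are chosen by the generator.
rendered = jinja2.Template(template_src).render(
    head_dim_qk=128,
    head_dim_vo=128,
    mask_mode="MaskMode::kCausal",
    use_sliding_window="false",
    variant_name="DefaultAttention",
)
print(rendered)  # a C++ translation unit with the explicit template instantiations filled in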
17 | -------------------------------------------------------------------------------- /csrc/batch_prefill_ragged_kernel_inst.jinja: -------------------------------------------------------------------------------- 1 | #include 2 | #include "batch_prefill_config.inc" 3 | 4 | namespace flashinfer { 5 | 6 | constexpr auto use_custom_mask = {{ mask_mode }} == MaskMode::kCustom; 7 | 8 | {% for cta_tile_q in [16, 64, 128] %} 9 | template cudaError_t BatchPrefillWithRaggedKVCacheDispatched< 10 | /*CTA_TILE_Q=*/{{cta_tile_q}}, {{head_dim_qk}}, {{head_dim_vo}}, {{pos_encoding_mode}}, {{use_fp16_qk_reduction}}, {{mask_mode}}, 11 | {{ variant_name }}, RaggedParams>(RaggedParams params, {{ dtype_o }}* tmp_v, float* tmp_s, cudaStream_t stream); 12 | {% endfor %} 13 | 14 | }; // namespace flashinfer 15 | -------------------------------------------------------------------------------- /csrc/batch_prefill_ragged_sm90_kernel_inst.jinja: -------------------------------------------------------------------------------- 1 | #include 2 | #include "batch_prefill_sm90_config.inc" 3 | 4 | namespace flashinfer { 5 | 6 | {% for same_scheduler_for_all_heads in ["true", "false"] %} 7 | template cudaError_t BatchPrefillWithRaggedKVCacheDispatched 8 | <{{ head_dim_qk }}, 9 | {{ head_dim_vo }}, 10 | {{ mask_mode }}, 11 | /*USE_SLIDING_WINDOW=*/{{ use_sliding_window }}, 12 | /*SAME_SCHEDULER_FOR_ALL_HEADS=*/{{ same_scheduler_for_all_heads }}, 13 | {{ variant_name }}>(RaggedParams& params, cudaStream_t stream); 14 | {% endfor %} 15 | 16 | }; // namespace flashinfer 17 | -------------------------------------------------------------------------------- /csrc/blackwell_fmha_plan.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2025 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | #include "flashinfer/attention/blackwell/plan.cuh" 18 | #include "pytorch_extension_utils.h" 19 | 20 | void blackwell_fmha_plan(at::Tensor qo_segment_offsets, at::Tensor kv_segment_offsets, 21 | at::Tensor work_indptr, at::Tensor qo_tile_indices, 22 | at::Tensor head_indices, at::Tensor batch_indices, int64_t qo_tile_size, 23 | int64_t num_heads, int64_t num_buckets, bool causal) { 24 | const c10::cuda::OptionalCUDAGuard device_guard(qo_segment_offsets.device()); 25 | const cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); 26 | int batch_size = qo_segment_offsets.size(0) - 1; 27 | 28 | auto status = flashinfer::plan_kernel_wrapper( 29 | static_cast(qo_segment_offsets.data_ptr()), 30 | static_cast(kv_segment_offsets.data_ptr()), 31 | /*qo_lens=*/nullptr, 32 | /*kv_lens=*/nullptr, static_cast(work_indptr.data_ptr()), 33 | static_cast(qo_tile_indices.data_ptr()), static_cast(head_indices.data_ptr()), 34 | static_cast(batch_indices.data_ptr()), qo_tile_size, batch_size, num_heads, num_buckets, 35 | causal, /*enable_pdl=*/true, stream); 36 | TORCH_CHECK(status == cudaSuccess, "Failed to plan blackwell fmha", cudaGetErrorString(status)); 37 | } 38 | -------------------------------------------------------------------------------- /csrc/cutlass_mla.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2024 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #include 17 | 18 | #include "pytorch_extension_utils.h" 19 | 20 | using namespace flashinfer; 21 | using namespace flashinfer::attention; 22 | 23 | void CutlassMLAPagedAttention(at::Tensor workspace, at::Tensor out, at::Tensor lse, 24 | at::Tensor q_nope_pe, at::Tensor ckv_kpe_cache, at::Tensor kv_lens, 25 | at::Tensor page_table) { 26 | const c10::cuda::OptionalCUDAGuard device_guard(q_nope_pe.device()); 27 | auto stream = at::cuda::getCurrentCUDAStream(); 28 | 29 | int device_index = q_nope_pe.device().index(); 30 | int batches = q_nope_pe.sizes()[0]; 31 | int page_count_per_seq = page_table.sizes()[1]; 32 | int page_count_total = ckv_kpe_cache.sizes()[0]; 33 | int page_size = ckv_kpe_cache.sizes()[1]; 34 | 35 | DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_nope_pe.scalar_type(), c_type, [&] { 36 | using cutlass_t = cutlass_dtype_t; 37 | auto status = runMla( 38 | workspace.data_ptr(), out.data_ptr(), lse.data_ptr(), q_nope_pe.data_ptr(), 39 | ckv_kpe_cache.data_ptr(), kv_lens.data_ptr(), page_table.data_ptr(), batches, 40 | page_count_per_seq, page_count_total, page_size, device_index, stream); 41 | TORCH_CHECK(status == cudaSuccess, 42 | "Failed to run CutlassMLAPagedAttention: ", cudaGetErrorString(status)); 43 | return true; 44 | }); 45 | } 46 | -------------------------------------------------------------------------------- /csrc/flashinfer_cascade_ops.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2023 by FlashInfer team. 
3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #include "pytorch_extension_utils.h" 17 | 18 | void merge_state(at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, 19 | at::Tensor v_merged, at::Tensor s_merged); 20 | 21 | void merge_state_in_place(at::Tensor v, at::Tensor s, at::Tensor v_other, at::Tensor s_other, 22 | std::optional<at::Tensor> mask); 23 | 24 | void merge_states(at::Tensor v, at::Tensor s, at::Tensor v_merged, at::Tensor s_merged); 25 | 26 | TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { 27 | // Merge two self-attention states 28 | m.def("merge_state", merge_state); 29 | // Merge another self-attention state in-place. 30 | m.def("merge_state_in_place", merge_state_in_place); 31 | // "Merge multiple self-attention states" 32 | m.def("merge_states", merge_states); 33 | } 34 | -------------------------------------------------------------------------------- /csrc/flashinfer_comm_ops.cu: -------------------------------------------------------------------------------- 1 | // flashinfer: adapted from sglang + vllm code 2 | // refer to: https://github.com/vllm-project/vllm/blob/v0.8.2/csrc/common_extension.cc 3 | /* 4 | * Copyright (c) 2023 by FlashInfer team. 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License.
17 | */ 18 | #include "pytorch_extension_utils.h" 19 | 20 | using fptr_t = int64_t; 21 | fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs, at::Tensor& rank_data, int64_t rank, 22 | bool full_nvlink); 23 | void dispose(fptr_t _fa); 24 | int64_t meta_size(); 25 | void all_reduce(fptr_t _fa, at::Tensor& inp, at::Tensor& out, fptr_t _reg_buffer, 26 | int64_t reg_buffer_sz_bytes, int64_t num_ctas); 27 | std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa); 28 | void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs); 29 | void register_graph_buffers(fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, 30 | const std::vector<std::vector<int64_t>>& offsets); 31 | 32 | TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { 33 | m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta); 34 | m.def("register_graph_buffers", &register_graph_buffers); 35 | m.def("dispose", &dispose); 36 | m.def("meta_size", &meta_size); 37 | m.def("register_buffer", &register_buffer); 38 | m.def("init_custom_ar", &init_custom_ar); 39 | m.def("all_reduce", &all_reduce); 40 | } 41 | -------------------------------------------------------------------------------- /csrc/flashinfer_gemm_ops.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2023 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #include "pytorch_extension_utils.h" 17 | 18 | void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale, 19 | at::Tensor workspace_buffer, int64_t cublas_handle); 20 | 21 | void CutlassSegmentGEMM(at::Tensor workspace_buffer, at::Tensor all_problems, at::Tensor x_ptr, 22 | at::Tensor w_ptr, at::Tensor y_ptr, at::Tensor x_ld, at::Tensor w_ld, 23 | at::Tensor y_ld, at::Tensor empty_x_data, bool weight_column_major); 24 | 25 | TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { 26 | // "Cutlass Segment GEMM" 27 | m.def("cutlass_segment_gemm", CutlassSegmentGEMM); 28 | // "BMM FP8" 29 | m.def("bmm_fp8", bmm_fp8); 30 | } 31 | -------------------------------------------------------------------------------- /csrc/flashinfer_gemm_sm90_ops.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2025 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License.
15 | */ 16 | #include "pytorch_extension_utils.h" 17 | 18 | void CutlassSegmentGEMMSM90(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, 19 | at::Tensor all_problems, at::Tensor x_ptr, at::Tensor w_ptr, 20 | at::Tensor y_ptr, at::Tensor x_stride, at::Tensor weight_stride, 21 | at::Tensor y_stride, at::Tensor empty_x_data, at::Tensor empty_y_data, 22 | bool weight_column_major); 23 | 24 | TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { 25 | // "Cutlass Segment GEMM operator for SM90" 26 | m.def("cutlass_segment_gemm_sm90", CutlassSegmentGEMMSM90); 27 | } 28 | -------------------------------------------------------------------------------- /csrc/flashinfer_mla_ops.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2023 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #include "pytorch_extension_utils.h" 17 | 18 | void CutlassMLAPagedAttention(at::Tensor workspace, at::Tensor out, at::Tensor lse, 19 | at::Tensor q_nope_pe, at::Tensor ckv_kpe_cache, at::Tensor kv_lens, 20 | at::Tensor page_table); 21 | 22 | TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { 23 | // "Cutlass MLA Paged Attention" 24 | m.def("cutlass_mla_paged_attention", CutlassMLAPagedAttention); 25 | } 26 | -------------------------------------------------------------------------------- /csrc/flashinfer_norm_ops.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2023 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | #include "pytorch_extension_utils.h" 17 | 18 | void rmsnorm(at::Tensor& out, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl); 19 | 20 | void fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, 21 | bool enable_pdl); 22 | 23 | void gemma_rmsnorm(at::Tensor& out, at::Tensor& input, at::Tensor& weight, double eps, 24 | bool enable_pdl); 25 | 26 | void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, 27 | double eps, bool enable_pdl); 28 | 29 | TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { 30 | // Root mean square normalization 31 | m.def("rmsnorm", rmsnorm); 32 | // Fused add root mean square normalization 33 | m.def("fused_add_rmsnorm", fused_add_rmsnorm); 34 | // Gemma Root mean square normalization 35 | m.def("gemma_rmsnorm", gemma_rmsnorm); 36 | // Gemma Fused add root mean square normalization 37 | m.def("gemma_fused_add_rmsnorm", gemma_fused_add_rmsnorm); 38 | } 39 | -------------------------------------------------------------------------------- /csrc/flashinfer_page_ops.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2023 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #include "pytorch_extension_utils.h" 17 | 18 | void append_paged_kv_cache(at::Tensor append_key, at::Tensor append_value, at::Tensor batch_indices, 19 | at::Tensor positions, at::Tensor paged_k_cache, at::Tensor paged_v_cache, 20 | at::Tensor kv_indices, at::Tensor kv_indptr, at::Tensor kv_last_page_len, 21 | int64_t layout); 22 | 23 | void append_paged_mla_kv_cache(at::Tensor append_ckv, at::Tensor append_kpe, 24 | at::Tensor batch_indices, at::Tensor positions, at::Tensor ckv_cache, 25 | at::Tensor kpe_cache, at::Tensor kv_indices, at::Tensor kv_indptr, 26 | at::Tensor kv_last_page_len); 27 | 28 | void block_sparse_indices_to_vector_sparse_offsets( 29 | at::Tensor block_sparse_indices, at::Tensor block_sparse_indptr, 30 | at::Tensor vector_sparse_offsets, at::Tensor vector_sparse_indptr, at::Tensor kv_len_arr, 31 | int64_t stride_block, int64_t stride_n, int64_t batch_size, int64_t block_size); 32 | 33 | TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { 34 | // "Append paged KV-Cache operator" 35 | m.def("append_paged_kv_cache", append_paged_kv_cache); 36 | // "Append paged MLA KV-Cache operator" 37 | m.def("append_paged_mla_kv_cache", append_paged_mla_kv_cache); 38 | // "Precompute block sparse offsets" 39 | m.def("block_sparse_indices_to_vector_sparse_offsets", 40 | block_sparse_indices_to_vector_sparse_offsets); 41 | } 42 | -------------------------------------------------------------------------------- /csrc/flashinfer_quantization_ops.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2023 by FlashInfer team. 
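The rmsnorm and fused_add_rmsnorm entry points declared in csrc/flashinfer_norm_ops.cu above are thin registrations over fused CUDA kernels. As orientation only, the following is an editorial sketch (not a file in this repository) of the value rmsnorm is expected to produce, assuming the standard RMSNorm formula out = input * rsqrt(mean(input^2, last dim) + eps) * weight; it makes no claim about the fused kernel's exact numerics, its enable_pdl path, or how fused_add_rmsnorm folds the residual addition into the same pass.

// Editorial reference sketch, assuming the standard RMSNorm formula; not repository code.
#include <torch/torch.h>

#include <vector>

torch::Tensor rmsnorm_reference(const torch::Tensor& input, const torch::Tensor& weight,
                                double eps) {
  // input: [..., hidden_size], weight: [hidden_size]; accumulate in fp32 for stability.
  auto x = input.to(torch::kFloat32);
  std::vector<int64_t> last_dim{-1};
  auto variance = (x * x).mean(last_dim, /*keepdim=*/true);
  auto normed = x * torch::rsqrt(variance + eps);
  return (normed * weight.to(torch::kFloat32)).to(input.scalar_type());
}

A call such as rmsnorm_reference(torch::randn({4, 4096}), torch::ones({4096}), 1e-6) mirrors the shape contract of the fused op; the real entry point instead writes into a caller-provided output tensor.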
3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #include "pytorch_extension_utils.h" 17 | 18 | void packbits(at::Tensor x, const std::string& bitorder, at::Tensor y); 19 | 20 | void segment_packbits(at::Tensor x, at::Tensor input_indptr, at::Tensor output_indptr, 21 | const std::string& bitorder, at::Tensor y); 22 | 23 | TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { 24 | // GPU packbits operator 25 | m.def("packbits", packbits); 26 | // GPU segment packbits operator 27 | m.def("segment_packbits", segment_packbits); 28 | } 29 | -------------------------------------------------------------------------------- /csrc/fmha_cutlass_sm100_pybind.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2023-2025 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #include "pytorch_extension_utils.h" 17 | 18 | void FMHACutlassSM100Run(at::Tensor workspace_buffer, at::Tensor q, at::Tensor k, at::Tensor v, 19 | at::Tensor qo_segment_offsets, at::Tensor kv_segment_offsets, 20 | at::Tensor work_indptr, at::Tensor qo_tile_indices, 21 | at::Tensor qo_head_indices, at::Tensor batch_indices, at::Tensor o, 22 | std::optional maybe_lse, int64_t mask_mode_code, 23 | double sm_scale, int64_t num_qo_heads, int64_t num_kv_heads, 24 | int64_t head_dim_qk, int64_t head_dim_vo, int64_t max_qo_len); 25 | 26 | void blackwell_fmha_plan(at::Tensor qo_segment_offsets, at::Tensor kv_segment_offsets, 27 | at::Tensor work_indptr, at::Tensor qo_tile_indices, 28 | at::Tensor head_indices, at::Tensor batch_indices, int64_t qo_tile_size, 29 | int64_t num_heads, int64_t num_buckets, bool causal); 30 | 31 | TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { 32 | m.def("run", FMHACutlassSM100Run); 33 | m.def("plan", blackwell_fmha_plan); 34 | } 35 | -------------------------------------------------------------------------------- /csrc/gemm_sm100_pybind.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2025 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #include "pytorch_extension_utils.h" 17 | 18 | void CutlassGemmGroupwiseScaledSM100(at::Tensor float_workspace_buffer, at::Tensor A, at::Tensor B, 19 | at::Tensor SFA, at::Tensor SFB, at::Tensor C, 20 | int64_t scale_granularity_m, int64_t scale_granularity_n, 21 | int64_t scale_granularity_k, std::string scale_major_mode, 22 | int64_t mma_sm); 23 | 24 | TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { 25 | m.def("gemm_fp8_nt_groupwise", CutlassGemmGroupwiseScaledSM100); 26 | } 27 | -------------------------------------------------------------------------------- /csrc/group_gemm.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2024 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #include 17 | 18 | #include "pytorch_extension_utils.h" 19 | 20 | using namespace flashinfer; 21 | using namespace flashinfer::group_gemm; 22 | 23 | void CutlassSegmentGEMM(at::Tensor workspace_buffer, at::Tensor all_problems, at::Tensor x_ptr, 24 | at::Tensor w_ptr, at::Tensor y_ptr, at::Tensor x_ld, at::Tensor w_ld, 25 | at::Tensor y_ld, at::Tensor empty_x_data, bool weight_column_major) { 26 | unsigned int batch_size = x_ptr.size(0); 27 | 28 | const c10::cuda::OptionalCUDAGuard device_guard(workspace_buffer.device()); 29 | auto stream = at::cuda::getCurrentCUDAStream(); 30 | DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_x_data.scalar_type(), c_type, [&] { 31 | using cutlass_t = cutlass_dtype_t; 32 | auto status = CutlassSegmentGEMMRun( 33 | workspace_buffer.data_ptr(), workspace_buffer.element_size() * workspace_buffer.size(0), 34 | all_problems.data_ptr(), batch_size, x_ptr.data_ptr(), w_ptr.data_ptr(), y_ptr.data_ptr(), 35 | x_ld.data_ptr(), w_ld.data_ptr(), y_ld.data_ptr(), weight_column_major, stream); 36 | TORCH_CHECK(status == cudaSuccess, 37 | "Failed to run CutlassSegmentGEMM: ", cudaGetErrorString(status)); 38 | return true; 39 | }); 40 | } 41 | -------------------------------------------------------------------------------- /csrc/group_gemm_bf16_bf16_sm90.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2024 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #include 17 | 18 | using namespace flashinfer; 19 | using namespace flashinfer::group_gemm; 20 | 21 | namespace flashinfer { 22 | namespace group_gemm { 23 | 24 | template cudaError_t CutlassSegmentGEMMSM90Run( 25 | void* float_buffer, size_t float_buffer_size_in_bytes, void* int_buffer, 26 | size_t int_buffer_size_in_bytes, void* all_problems, int64_t batch_size, void* x, void* w, 27 | void* y, void* x_stride, void* w_stride, void* y_stride, bool weight_column_major, 28 | cudaStream_t stream); 29 | 30 | }; // namespace group_gemm 31 | }; // namespace flashinfer 32 | -------------------------------------------------------------------------------- /csrc/group_gemm_e4m3_bf16_sm90.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2024 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #include 17 | 18 | using namespace flashinfer; 19 | using namespace flashinfer::group_gemm; 20 | 21 | namespace flashinfer { 22 | namespace group_gemm { 23 | 24 | template cudaError_t CutlassSegmentGEMMSM90Run( 25 | void* float_buffer, size_t float_buffer_size_in_bytes, void* int_buffer, 26 | size_t int_buffer_size_in_bytes, void* all_problems, int64_t batch_size, void* x, void* w, 27 | void* y, void* x_stride, void* w_stride, void* y_stride, bool weight_column_major, 28 | cudaStream_t stream); 29 | 30 | }; // namespace group_gemm 31 | }; // namespace flashinfer 32 | -------------------------------------------------------------------------------- /csrc/group_gemm_e4m3_f16_sm90.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2024 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | #include 17 | 18 | using namespace flashinfer; 19 | using namespace flashinfer::group_gemm; 20 | 21 | namespace flashinfer { 22 | namespace group_gemm { 23 | 24 | template cudaError_t CutlassSegmentGEMMSM90Run( 25 | void* float_buffer, size_t float_buffer_size_in_bytes, void* int_buffer, 26 | size_t int_buffer_size_in_bytes, void* all_problems, int64_t batch_size, void* x, void* w, 27 | void* y, void* x_stride, void* w_stride, void* y_stride, bool weight_column_major, 28 | cudaStream_t stream); 29 | 30 | }; // namespace group_gemm 31 | }; // namespace flashinfer 32 | -------------------------------------------------------------------------------- /csrc/group_gemm_e5m2_bf16_sm90.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2024 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #include 17 | 18 | using namespace flashinfer; 19 | using namespace flashinfer::group_gemm; 20 | 21 | namespace flashinfer { 22 | namespace group_gemm { 23 | 24 | template cudaError_t CutlassSegmentGEMMSM90Run( 25 | void* float_buffer, size_t float_buffer_size_in_bytes, void* int_buffer, 26 | size_t int_buffer_size_in_bytes, void* all_problems, int64_t batch_size, void* x, void* w, 27 | void* y, void* x_stride, void* w_stride, void* y_stride, bool weight_column_major, 28 | cudaStream_t stream); 29 | 30 | }; // namespace group_gemm 31 | }; // namespace flashinfer 32 | -------------------------------------------------------------------------------- /csrc/group_gemm_e5m2_f16_sm90.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2024 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | #include 17 | 18 | using namespace flashinfer; 19 | using namespace flashinfer::group_gemm; 20 | 21 | namespace flashinfer { 22 | namespace group_gemm { 23 | 24 | template cudaError_t CutlassSegmentGEMMSM90Run( 25 | void* float_buffer, size_t float_buffer_size_in_bytes, void* int_buffer, 26 | size_t int_buffer_size_in_bytes, void* all_problems, int64_t batch_size, void* x, void* w, 27 | void* y, void* x_stride, void* w_stride, void* y_stride, bool weight_column_major, 28 | cudaStream_t stream); 29 | 30 | }; // namespace group_gemm 31 | }; // namespace flashinfer 32 | -------------------------------------------------------------------------------- /csrc/group_gemm_f16_f16_sm90.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2024 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #include 17 | 18 | using namespace flashinfer; 19 | using namespace flashinfer::group_gemm; 20 | 21 | namespace flashinfer { 22 | namespace group_gemm { 23 | 24 | template cudaError_t CutlassSegmentGEMMSM90Run( 25 | void* float_buffer, size_t float_buffer_size_in_bytes, void* int_buffer, 26 | size_t int_buffer_size_in_bytes, void* all_problems, int64_t batch_size, void* x, void* w, 27 | void* y, void* x_stride, void* w_stride, void* y_stride, bool weight_column_major, 28 | cudaStream_t stream); 29 | 30 | }; // namespace group_gemm 31 | }; // namespace flashinfer 32 | -------------------------------------------------------------------------------- /csrc/group_gemm_sm100_pybind.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2025 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | #include 17 | 18 | #include "pytorch_extension_utils.h" 19 | 20 | void CutlassGroupGemmGroupwiseScaledSM100(at::Tensor int_workspace_buffer, 21 | at::Tensor float_workspace_buffer, at::Tensor A, 22 | at::Tensor B, at::Tensor SFA, at::Tensor SFB, 23 | at::Tensor C, at::Tensor m_indptr, int64_t n, int64_t k, 24 | int64_t scale_granularity_m, int64_t scale_granularity_n, 25 | int64_t scale_granularity_k, std::string scale_major_mode, 26 | int64_t mma_sm); 27 | 28 | TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { 29 | m.def("group_gemm_fp8_nt_groupwise", CutlassGroupGemmGroupwiseScaledSM100); 30 | } 31 | -------------------------------------------------------------------------------- /csrc/logging.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2025 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #include "flashinfer/logging.h" 17 | 18 | #include 19 | 20 | #include "Python.h" 21 | 22 | void set_log_level(int64_t log_level_code) { 23 | auto log_level = static_cast(log_level_code); 24 | flashinfer::logging::set_log_level(log_level); 25 | } 26 | 27 | void try_log_info(const std::string& msg) { FLASHINFER_LOG_INFO(msg); } 28 | 29 | TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { 30 | m.def("set_log_level", set_log_level); 31 | m.def("try_log_info", try_log_info); 32 | } 33 | -------------------------------------------------------------------------------- /csrc/nv_internal/cpp/common/stringUtils.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | #include "tensorrt_llm/common/stringUtils.h" 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | #include "tensorrt_llm/common/assert.h" 26 | 27 | namespace tensorrt_llm::common { 28 | 29 | void fmtstr_(char const* format, fmtstr_allocator alloc, void* target, va_list args) { 30 | va_list args0; 31 | va_copy(args0, args); 32 | 33 | size_t constexpr init_size = 2048; 34 | char fixed_buffer[init_size]; 35 | auto const size = std::vsnprintf(fixed_buffer, init_size, format, args0); 36 | TLLM_CHECK_WITH_INFO(size >= 0, std::string(std::strerror(errno))); 37 | if (size == 0) { 38 | return; 39 | } 40 | 41 | auto* memory = alloc(target, size); 42 | 43 | if (static_cast(size) < init_size) { 44 | std::memcpy(memory, fixed_buffer, size + 1); 45 | } else { 46 | auto const size2 = std::vsnprintf(memory, size + 1, format, args); 47 | TLLM_CHECK_WITH_INFO(size2 == size, std::string(std::strerror(errno))); 48 | } 49 | } 50 | 51 | std::unordered_set str2set(std::string const& input, char delimiter) { 52 | std::unordered_set values; 53 | if (!input.empty()) { 54 | std::stringstream valStream(input); 55 | std::string val; 56 | while (std::getline(valStream, val, delimiter)) { 57 | if (!val.empty()) { 58 | values.insert(val); 59 | } 60 | } 61 | } 62 | return values; 63 | }; 64 | 65 | } // namespace tensorrt_llm::common 66 | -------------------------------------------------------------------------------- /csrc/nv_internal/include/tensorrt_llm/common/cudaBf16Wrapper.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | 19 | #ifdef ENABLE_BF16 20 | #include 21 | #endif 22 | -------------------------------------------------------------------------------- /csrc/nv_internal/include/tensorrt_llm/common/tllmException.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | #define NEW_TLLM_EXCEPTION(...) 
\ 25 | tensorrt_llm::common::TllmException(__FILE__, __LINE__, \ 26 | tensorrt_llm::common::fmtstr(__VA_ARGS__).c_str()) 27 | 28 | namespace tensorrt_llm::common { 29 | 30 | class TllmException : public std::runtime_error { 31 | public: 32 | static auto constexpr MAX_FRAMES = 128; 33 | 34 | explicit TllmException(char const* file, std::size_t line, char const* msg); 35 | 36 | ~TllmException() noexcept override; 37 | 38 | [[nodiscard]] std::string getTrace() const; 39 | 40 | static std::string demangle(char const* name); 41 | 42 | private: 43 | std::array mCallstack{}; 44 | int mNbFrames; 45 | }; 46 | 47 | } // namespace tensorrt_llm::common 48 | -------------------------------------------------------------------------------- /csrc/nv_internal/tensorrt_llm/common/quantTypeUtils.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | #include 22 | 23 | #include "tensorrt_llm/common/cudaBf16Fallbacks.cuh" 24 | #include "tensorrt_llm/common/cudaFp8Utils.h" 25 | 26 | namespace tensorrt_llm { 27 | namespace common { 28 | 29 | template 30 | struct QuantTypeStaticVals; 31 | 32 | template <> 33 | struct QuantTypeStaticVals { 34 | static constexpr float MAX_VAL = 127.f; 35 | static constexpr float MIN_SCALING_FACTOR = 0.f; 36 | static constexpr float MIN_SCALING_FACTOR_RCP = FLT_MAX; 37 | }; 38 | 39 | #ifdef ENABLE_FP8 40 | 41 | template <> 42 | struct QuantTypeStaticVals<__nv_fp8_e4m3> { 43 | static constexpr float MAX_VAL = 448.f; 44 | // Ref: 45 | // https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L720 46 | static constexpr float MIN_SCALING_FACTOR = 1.0f / (448.f * 512.f); 47 | static constexpr float MIN_SCALING_FACTOR_RCP = (448.f * 512.f); 48 | }; 49 | 50 | #endif // ENABLE_FP8 51 | 52 | } // namespace common 53 | } // namespace tensorrt_llm 54 | -------------------------------------------------------------------------------- /csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
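For orientation, the FP8 bounds in csrc/nv_internal/tensorrt_llm/common/quantTypeUtils.cuh above in plain numbers: __nv_fp8_e4m3 tops out at a finite magnitude of 448, and the scaling factor is clamped from below at 1 / (448 * 512), so its reciprocal never exceeds 448 * 512 = 229376. The snippet below is an editorial sketch, not repository code, and the "why" is only one plausible reading of the FBGEMM-derived bound: it keeps 1/scale finite and bounded even when a tensor's absolute maximum is tiny.

// Editorial sketch of the arithmetic behind QuantTypeStaticVals<__nv_fp8_e4m3>; not repository code.
#include <algorithm>
#include <cstdio>

int main() {
  constexpr float kMaxVal = 448.f;                       // largest finite |x| in e4m3
  constexpr float kMinScalingFactor = 1.0f / (448.f * 512.f);
  constexpr float kMinScalingFactorRcp = 448.f * 512.f;  // 229376
  float amax = 1e-4f;                                    // a tensor with a tiny absolute maximum
  // Conventional per-tensor scale amax / MAX_VAL, clamped by the header's lower bound
  // (an assumption for illustration; the repository's callers may compute scales differently).
  float scale = std::max(amax / kMaxVal, kMinScalingFactor);
  std::printf("scale = %g, 1/scale = %g (capped at %g)\n", scale, 1.0f / scale,
              kMinScalingFactorRcp);
  return 0;
}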
14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | /*! \file 19 | \brief Defines new layouts needed for MoE 20 | */ 21 | #pragma once 22 | 23 | #include "cutlass/cutlass.h" 24 | #include "cutlass/fast_math.h" 25 | #include "cutlass/matrix_coord.h" 26 | #include "cutlass/pitch_linear_coord.h" 27 | 28 | namespace cutlass { 29 | namespace layout { 30 | 31 | template 32 | struct ColumnMajorTileInterleave { 33 | static constexpr int kRowsPerTile = RowsPerTile; 34 | static constexpr int kColumnsInterleaved = ColumnsInterleaved; 35 | }; 36 | 37 | template 38 | struct IsColumnMajorTileInterleave { 39 | static constexpr bool value = false; 40 | }; 41 | 42 | template 43 | struct IsColumnMajorTileInterleave> { 44 | static constexpr bool value = true; 45 | }; 46 | 47 | } // namespace layout 48 | } // namespace cutlass 49 | -------------------------------------------------------------------------------- /csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/weight_only_quant_op.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | /*! \file 19 | \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. 20 | */ 21 | 22 | #pragma once 23 | 24 | namespace cutlass { 25 | 26 | enum class WeightOnlyQuantOp { 27 | UNDEFINED, 28 | PER_COLUMN_SCALE_ONLY, 29 | FINEGRAINED_SCALE_ONLY, 30 | FINEGRAINED_SCALE_AND_ZEROS 31 | }; 32 | 33 | constexpr bool isFinegrained(WeightOnlyQuantOp op) { 34 | return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS || 35 | op == WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY; 36 | } 37 | 38 | constexpr bool hasZero(WeightOnlyQuantOp op) { 39 | return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS; 40 | } 41 | 42 | } // namespace cutlass 43 | -------------------------------------------------------------------------------- /csrc/nv_internal/tensorrt_llm/kernels/delayStream.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | #include 18 | 19 | #include "pytorch_extension_utils.h" 20 | #include "tensorrt_llm/common/cudaUtils.h" 21 | #include "tensorrt_llm/kernels/delayStream.h" 22 | 23 | using namespace tensorrt_llm::common; 24 | 25 | namespace tensorrt_llm::kernels { 26 | __global__ void delayStreamKernel(long long delay_micro_secs) { 27 | for (int i = 0; i < delay_micro_secs; ++i) { 28 | // The largest delay __nanosleep can do is 1 millisecond, thus we use for loop to achieve longer 29 | // delay. 30 | __nanosleep(1000); 31 | } 32 | } 33 | 34 | void invokeDelayStreamKernel(long long delay_micro_secs, cudaStream_t stream) { 35 | delayStreamKernel<<<1, 1, 0, stream>>>(delay_micro_secs); 36 | tensorrt_llm::common::check_cuda_error(cudaGetLastError()); 37 | } 38 | } // namespace tensorrt_llm::kernels 39 | 40 | TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { 41 | m.def("delay_kernel", [](int64_t delay_micro_secs) { 42 | cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); 43 | tensorrt_llm::kernels::invokeDelayStreamKernel(delay_micro_secs, stream); 44 | }); 45 | } 46 | -------------------------------------------------------------------------------- /csrc/nv_internal/tensorrt_llm/kernels/delayStream.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | 19 | #include "tensorrt_llm/common/cudaUtils.h" 20 | 21 | namespace tensorrt_llm::kernels { 22 | void invokeDelayStreamKernel(long long delay_micro_secs, cudaStream_t stream); 23 | } // namespace tensorrt_llm::kernels 24 | -------------------------------------------------------------------------------- /csrc/nv_internal/tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement 4 | * 5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 6 | * property and proprietary rights in and to this material, related 7 | * documentation and any modifications thereto. Any use, reproduction, 8 | * disclosure or distribution of this material and related documentation 9 | * without an express license agreement from NVIDIA CORPORATION or 10 | * its affiliates is strictly prohibited. 
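Unit check on delayStreamKernel in csrc/nv_internal/tensorrt_llm/kernels/delayStream.cu above: each iteration calls __nanosleep(1000), roughly one microsecond, so the kernel occupies the stream for approximately delay_micro_secs microseconds in total (approximately, since __nanosleep gives no hard timing guarantee, which is why the kernel loops rather than issuing one long sleep). The host-side analogue below is an editorial sketch of the same arithmetic, not repository code; the host sleep will typically overshoot and is only meant to illustrate the unit conversion.

// Editorial sketch mirroring delayStreamKernel's timing on the host; not repository code.
#include <chrono>
#include <cstdio>
#include <thread>

int main() {
  long long delay_micro_secs = 500;  // same meaning as the delay_kernel op argument
  auto start = std::chrono::steady_clock::now();
  for (long long i = 0; i < delay_micro_secs; ++i) {
    std::this_thread::sleep_for(std::chrono::nanoseconds(1000));  // ~1 us per iteration
  }
  auto elapsed = std::chrono::duration_cast<std::chrono::microseconds>(
      std::chrono::steady_clock::now() - start);
  std::printf("requested ~%lld us, measured %lld us on the host\n", delay_micro_secs,
              static_cast<long long>(elapsed.count()));
  return 0;
}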
11 | */ 12 | 13 | namespace tensorrt_llm::kernels::cutlass_kernels { 14 | template 16 | void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B, 17 | ElementType_ const* biases, bool bias_is_broadcast, 18 | ElementType_* C, 19 | int64_t const* total_tokens_including_expert, 20 | int64_t num_rows, int64_t gemm_n, int64_t gemm_k, 21 | int num_experts, int multi_processor_count, 22 | cudaStream_t stream, int* kernel_occupancy); 23 | } 24 | -------------------------------------------------------------------------------- /csrc/nv_internal/tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/launchers/moe_gemm_tma_ws_launcher.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement 4 | * 5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 6 | * property and proprietary rights in and to this material, related 7 | * documentation and any modifications thereto. Any use, reproduction, 8 | * disclosure or distribution of this material and related documentation 9 | * without an express license agreement from NVIDIA CORPORATION or 10 | * its affiliates is strictly prohibited. 11 | */ 12 | 13 | #pragma once 14 | 15 | #include 16 | 17 | #include "moe_gemm_kernels.h" 18 | 19 | namespace tensorrt_llm { 20 | namespace kernels { 21 | namespace cutlass_kernels { 22 | 23 | // Keep in sync with the signature generated by generate_kernels.py 24 | template 27 | void tma_warp_specialized_generic_moe_gemm_kernelLauncher( 28 | TmaWarpSpecializedGroupedGemmInput hopper_input, int num_experts, int multi_processor_count, 29 | cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size); 30 | 31 | } // namespace cutlass_kernels 32 | } // namespace kernels 33 | } // namespace tensorrt_llm 34 | -------------------------------------------------------------------------------- /csrc/nv_internal/tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement 4 | * 5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 6 | * property and proprietary rights in and to this material, related 7 | * documentation and any modifications thereto. Any use, reproduction, 8 | * disclosure or distribution of this material and related documentation 9 | * without an express license agreement from NVIDIA CORPORATION or 10 | * its affiliates is strictly prohibited. 
11 | */ 12 | 13 | #include 14 | 15 | #include "cutlass_extensions/gemm_configs.h" 16 | #include "cutlass_extensions/weight_only_quant_op.h" 17 | #include "moe_gemm_kernels.h" 18 | 19 | namespace tensorrt_llm { 20 | namespace kernels { 21 | namespace cutlass_kernels { 22 | 23 | template 26 | void sm90_generic_mixed_moe_gemm_kernelLauncher( 27 | GroupedGemmInput inputs, 28 | TmaWarpSpecializedGroupedGemmInput hopper_inputs, int sm_count_, size_t* workspace_size); 29 | 30 | } // namespace cutlass_kernels 31 | } // namespace kernels 32 | } // namespace tensorrt_llm 33 | -------------------------------------------------------------------------------- /csrc/nv_internal/tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/moe_gemm_kernels_bf16_bf16.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement 4 | * 5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 6 | * property and proprietary rights in and to this material, related 7 | * documentation and any modifications thereto. Any use, reproduction, 8 | * disclosure or distribution of this material and related documentation 9 | * without an express license agreement from NVIDIA CORPORATION or 10 | * its affiliates is strictly prohibited. 11 | */ 12 | 13 | #include "tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/moe_gemm_template_dispatch.h" 14 | 15 | namespace tensorrt_llm { 16 | #ifdef ENABLE_BF16 17 | template class MoeGemmRunner<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16>; 18 | #endif 19 | } // namespace tensorrt_llm 20 | -------------------------------------------------------------------------------- /csrc/nv_internal/tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/moe_gemm_kernels_bf16_fp8.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement 4 | * 5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 6 | * property and proprietary rights in and to this material, related 7 | * documentation and any modifications thereto. Any use, reproduction, 8 | * disclosure or distribution of this material and related documentation 9 | * without an express license agreement from NVIDIA CORPORATION or 10 | * its affiliates is strictly prohibited. 11 | */ 12 | 13 | #include "tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/moe_gemm_template_dispatch.h" 14 | 15 | namespace tensorrt_llm { 16 | #ifdef ENABLE_BF16 17 | template class MoeGemmRunner<__nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>; 18 | #endif 19 | } // namespace tensorrt_llm 20 | -------------------------------------------------------------------------------- /csrc/nv_internal/tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/moe_gemm_kernels_bf16_uint4.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
3 | * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement 4 | * 5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 6 | * property and proprietary rights in and to this material, related 7 | * documentation and any modifications thereto. Any use, reproduction, 8 | * disclosure or distribution of this material and related documentation 9 | * without an express license agreement from NVIDIA CORPORATION or 10 | * its affiliates is strictly prohibited. 11 | */ 12 | 13 | #include "tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/moe_gemm_template_dispatch.h" 14 | 15 | namespace tensorrt_llm { 16 | #ifdef ENABLE_BF16 17 | template class MoeGemmRunner<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16>; 18 | #endif 19 | } // namespace tensorrt_llm 20 | -------------------------------------------------------------------------------- /csrc/nv_internal/tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/moe_gemm_kernels_bf16_uint8.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement 4 | * 5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 6 | * property and proprietary rights in and to this material, related 7 | * documentation and any modifications thereto. Any use, reproduction, 8 | * disclosure or distribution of this material and related documentation 9 | * without an express license agreement from NVIDIA CORPORATION or 10 | * its affiliates is strictly prohibited. 11 | */ 12 | 13 | #include "tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/moe_gemm_template_dispatch.h" 14 | 15 | namespace tensorrt_llm { 16 | #ifdef ENABLE_BF16 17 | template class MoeGemmRunner<__nv_bfloat16, uint8_t, __nv_bfloat16>; 18 | #endif 19 | } // namespace tensorrt_llm 20 | -------------------------------------------------------------------------------- /csrc/nv_internal/tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/moe_gemm_kernels_fp16_fp16.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement 4 | * 5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 6 | * property and proprietary rights in and to this material, related 7 | * documentation and any modifications thereto. Any use, reproduction, 8 | * disclosure or distribution of this material and related documentation 9 | * without an express license agreement from NVIDIA CORPORATION or 10 | * its affiliates is strictly prohibited. 11 | */ 12 | 13 | #include "tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/moe_gemm_template_dispatch.h" 14 | 15 | namespace tensorrt_llm { 16 | template class MoeGemmRunner<half, half, half>; 17 | } 18 | -------------------------------------------------------------------------------- /csrc/nv_internal/tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/moe_gemm_kernels_fp16_uint4.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement 4 | * 5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 6 | * property and proprietary rights in and to this material, related 7 | * documentation and any modifications thereto. Any use, reproduction, 8 | * disclosure or distribution of this material and related documentation 9 | * without an express license agreement from NVIDIA CORPORATION or 10 | * its affiliates is strictly prohibited. 11 | */ 12 | 13 | #include "tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/moe_gemm_template_dispatch.h" 14 | 15 | namespace tensorrt_llm { 16 | template class MoeGemmRunner<half, cutlass::uint4b_t, half>; 17 | } 18 | -------------------------------------------------------------------------------- /csrc/nv_internal/tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/moe_gemm_kernels_fp16_uint8.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement 4 | * 5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 6 | * property and proprietary rights in and to this material, related 7 | * documentation and any modifications thereto. Any use, reproduction, 8 | * disclosure or distribution of this material and related documentation 9 | * without an express license agreement from NVIDIA CORPORATION or 10 | * its affiliates is strictly prohibited. 11 | */ 12 | 13 | #include "tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/moe_gemm_template_dispatch.h" 14 | 15 | namespace tensorrt_llm { 16 | template class MoeGemmRunner<half, uint8_t, half>; 17 | } 18 | -------------------------------------------------------------------------------- /csrc/nv_internal/tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/moe_gemm_kernels_fp32_fp32.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement 4 | * 5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 6 | * property and proprietary rights in and to this material, related 7 | * documentation and any modifications thereto. Any use, reproduction, 8 | * disclosure or distribution of this material and related documentation 9 | * without an express license agreement from NVIDIA CORPORATION or 10 | * its affiliates is strictly prohibited. 11 | */ 12 | 13 | #include "tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/moe_gemm_template_dispatch.h" 14 | 15 | namespace tensorrt_llm { 16 | template class MoeGemmRunner<float, float, float>; 17 | } 18 | -------------------------------------------------------------------------------- /csrc/nv_internal/tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/moe_gemm_kernels_fp4_fp4.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement 4 | * 5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 6 | * property and proprietary rights in and to this material, related 7 | * documentation and any modifications thereto.
Any use, reproduction, 8 | * disclosure or distribution of this material and related documentation 9 | * without an express license agreement from NVIDIA CORPORATION or 10 | * its affiliates is strictly prohibited. 11 | */ 12 | 13 | #include "tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/moe_gemm_template_dispatch.h" 14 | 15 | namespace tensorrt_llm { 16 | #ifdef ENABLE_FP4 17 | template class MoeGemmRunner<__nv_fp4_e2m1, __nv_fp4_e2m1, half>; 18 | #ifdef ENABLE_BF16 19 | template class MoeGemmRunner<__nv_fp4_e2m1, __nv_fp4_e2m1, __nv_bfloat16>; 20 | #endif 21 | #endif 22 | } // namespace tensorrt_llm 23 | -------------------------------------------------------------------------------- /csrc/nv_internal/tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/moe_gemm_kernels_fp8_fp8.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement 4 | * 5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 6 | * property and proprietary rights in and to this material, related 7 | * documentation and any modifications thereto. Any use, reproduction, 8 | * disclosure or distribution of this material and related documentation 9 | * without an express license agreement from NVIDIA CORPORATION or 10 | * its affiliates is strictly prohibited. 11 | */ 12 | 13 | #include "tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/moe_gemm_template_dispatch.h" 14 | 15 | namespace tensorrt_llm { 16 | #ifdef ENABLE_FP8 17 | template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, half>; 18 | #ifdef ENABLE_BF16 19 | template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>; 20 | #endif 21 | // template class MoeGemmRunner<__nv_fp8_e5m2, __nv_fp8_e5m2>; 22 | #endif 23 | } // namespace tensorrt_llm 24 | -------------------------------------------------------------------------------- /csrc/nv_internal/tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/moe_gemm_kernels_fp8_uint4.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement 4 | * 5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 6 | * property and proprietary rights in and to this material, related 7 | * documentation and any modifications thereto. Any use, reproduction, 8 | * disclosure or distribution of this material and related documentation 9 | * without an express license agreement from NVIDIA CORPORATION or 10 | * its affiliates is strictly prohibited. 
11 | */ 12 | 13 | #include "tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/moe_gemm_template_dispatch.h" 14 | 15 | namespace tensorrt_llm { 16 | #ifdef ENABLE_FP8 17 | template class MoeGemmRunner<__nv_fp8_e4m3, cutlass::uint4b_t, half>; 18 | #ifdef ENABLE_BF16 19 | template class MoeGemmRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16>; 20 | #endif 21 | #endif 22 | } // namespace tensorrt_llm 23 | -------------------------------------------------------------------------------- /csrc/nv_internal/tensorrt_llm/kernels/lora/lora.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & 3 | * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | #include "tensorrt_llm/kernels/lora/lora.h" 19 | 20 | #include 21 | #include 22 | 23 | namespace tensorrt_llm::kernels { 24 | 25 | int Lora_run(LoraImpl* impl, int64_t numTokens, int64_t numReqs, void const* input, 26 | int32_t const* loraRanks, void const* const* loraWeightsPtr, int weightIndex, 27 | void* const* outputs, void* workspace, cudaStream_t stream) {} 28 | 29 | } // namespace tensorrt_llm::kernels 30 | -------------------------------------------------------------------------------- /csrc/nv_internal/tensorrt_llm/kernels/preQuantScaleKernel.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights 3 | * reserved. SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | #pragma once 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | #if defined(ENABLE_BF16) 25 | #include 26 | #endif 27 | 28 | #include 29 | #include 30 | 31 | namespace tensorrt_llm { 32 | namespace kernels { 33 | 34 | template 35 | void apply_per_channel_scale_kernel_launcher(T_out* smoothed_act, T_in const* act, 36 | T_in const* per_channel_scale, int rows, int cols, 37 | cudaStream_t stream = 0); 38 | 39 | template 40 | void apply_per_expert_scale_kernel_launcher(T_out* smoothed_act, T_in const* act, 41 | T_in const* per_expert_scale, 42 | int const* permuted_token_selected_experts, 43 | int64_t const* num_valid_tokens_ptr, int rows, int cols, 44 | cudaStream_t stream = 0); 45 | 46 | } // namespace kernels 47 | } // namespace tensorrt_llm 48 | -------------------------------------------------------------------------------- /csrc/nv_internal/tensorrt_llm/runtime/torchUtils.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | #include 26 | #include 27 | #include 28 | #include 29 | 30 | #include "tensorrt_llm/common/assert.h" 31 | #include "tensorrt_llm/common/cudaUtils.h" 32 | 33 | namespace tensorrt_llm::runtime { 34 | 35 | class TorchUtils { 36 | public: 37 | static nvinfer1::DataType dataType(at::ScalarType scalarType) { 38 | switch (scalarType) { 39 | case at::ScalarType::Float: 40 | return nvinfer1::DataType::kFLOAT; 41 | case at::ScalarType::Half: 42 | return nvinfer1::DataType::kHALF; 43 | case torch::kInt8: 44 | return nvinfer1::DataType::kINT8; 45 | case torch::kUInt8: 46 | return nvinfer1::DataType::kUINT8; 47 | case torch::kInt32: 48 | return nvinfer1::DataType::kINT32; 49 | case torch::kInt64: 50 | return nvinfer1::DataType::kINT64; 51 | case at::ScalarType::Bool: 52 | return nvinfer1::DataType::kBOOL; 53 | case at::ScalarType::Float8_e4m3fn: 54 | return nvinfer1::DataType::kFP8; 55 | case at::ScalarType::BFloat16: 56 | return nvinfer1::DataType::kBF16; 57 | case at::ScalarType::QUInt4x2: 58 | return nvinfer1::DataType::kINT4; 59 | default: 60 | TLLM_THROW("unsupported data type"); 61 | } 62 | } 63 | 64 | private: 65 | TorchUtils() = default; 66 | }; 67 | 68 | } // namespace tensorrt_llm::runtime 69 | -------------------------------------------------------------------------------- /csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | #include 19 | 20 | #include 21 | 22 | #include "tensorrt_llm/common/cudaUtils.h" 23 | 24 | namespace torch_ext { 25 | std::tuple fp4_quantize(at::Tensor const& self, 26 | at::Tensor const& globalScale, int64_t sfVecSize, 27 | bool sfUseUE8M0, bool isSfSwizzledLayout); 28 | } 29 | -------------------------------------------------------------------------------- /csrc/nv_internal/tensorrt_llm/thop/thUtils.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | #include 22 | 23 | #define CHECK_TYPE(x, st) TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type: " #x) 24 | #define CHECK_TH_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") 25 | #define CHECK_CPU(x) TORCH_CHECK(!x.is_cuda(), #x " must be a CPU tensor") 26 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 27 | 28 | #define CHECK_CPU_INPUT(x, st) \ 29 | CHECK_CPU(x); \ 30 | CHECK_CONTIGUOUS(x); \ 31 | CHECK_TYPE(x, st) 32 | #define CHECK_OPTIONAL_INPUT(x, st) \ 33 | if (x.has_value()) { \ 34 | CHECK_INPUT(x.value(), st); \ 35 | } 36 | #define CHECK_OPTIONAL_CPU_INPUT(x, st) \ 37 | if (x.has_value()) { \ 38 | CHECK_CPU_INPUT(x.value(), st); \ 39 | } 40 | #define PRINT_TENSOR(x) std::cout << #x << ":\n" << x << std::endl 41 | #define PRINT_TENSOR_SIZE(x) std::cout << "size of " << #x << ": " << x.sizes() << std::endl 42 | 43 | namespace torch_ext { 44 | 45 | // // TODO: switch to use torch native fp4 dtype when ready 46 | constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte; // uint8_t 47 | constexpr auto SF_DTYPE = at::ScalarType::Byte; // uint8_t 48 | 49 | constexpr auto FP8_BLOCK_SCALING_SF_DTYPE = at::ScalarType::Float; 50 | 51 | } // namespace torch_ext 52 | -------------------------------------------------------------------------------- /csrc/pod_customize_config.jinja: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | using namespace flashinfer; 14 | 15 | using DTypeQ = {{ dtype_q }}; 16 | using DTypeKV = {{ dtype_kv }}; 17 | using DTypeO = {{ dtype_o }}; 18 | using IdType = {{ idtype }}; 19 | constexpr int HEAD_DIM_QK = {{ head_dim_qk }}; 20 | constexpr 
int HEAD_DIM_VO = {{ head_dim_vo }}; 21 | constexpr bool USE_FP16_QK_REDUCTION = {{ use_fp16_qk_reduction }}; 22 | constexpr auto USE_LOGITS_SOFT_CAP_P = {{ use_logits_soft_cap_p }}; 23 | constexpr auto POS_ENCODING_MODE_P = {{ pos_encoding_mode_p }}; 24 | constexpr auto USE_SLIDING_WINDOW_P = {{ use_sliding_window_p }}; 25 | 26 | constexpr auto USE_LOGITS_SOFT_CAP_D = {{ use_logits_soft_cap_d }}; 27 | constexpr auto POS_ENCODING_MODE_D = {{ pos_encoding_mode_d }}; 28 | constexpr auto USE_SLIDING_WINDOW_D = {{ use_sliding_window_d }}; 29 | 30 | constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; 31 | constexpr bool USE_LOGITS_SOFT_CAP = false; 32 | 33 | using PrefillParams = SinglePrefillParams; 34 | using DecodeParams = BatchPrefillPagedParams; 35 | 36 | #define DISPATCH_context(MASK_MODE_P, MASK_MODE_D, DTypeQ, DTypeKV, HEAD_DIM_QK, \ 37 | USE_SLIDING_WINDOW_P, USE_SLIDING_WINDOW_D, USE_LOGITS_SOFT_CAP, ...) \ 38 | DISPATCH_MASK_MODE(mask_mode_p, MASK_MODE_P, { \ 39 | DISPATCH_MASK_MODE(mask_mode_d, MASK_MODE_D, { \ 40 | __VA_ARGS__(); \ 41 | }); \ 42 | }); 43 | -------------------------------------------------------------------------------- /csrc/pod_jit_pybind.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2023-2025 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | #include "pod_config.inc" 17 | #include "pytorch_extension_utils.h" 18 | 19 | void pod_with_kv_cache_tensor( 20 | // Prefill params 21 | at::Tensor q_p, at::Tensor k_p, at::Tensor v_p, at::Tensor tmp_p, at::Tensor o_p, 22 | std::optional maybe_lse_p, int64_t mask_mode_code_p, int64_t layout_p, 23 | int64_t window_left_p, std::optional maybe_custom_mask_p, 24 | std::optional maybe_alibi_slopes_p, double logits_soft_cap_p, double sm_scale_p, 25 | double rope_rcp_scale_p, double rope_rcp_theta_p, 26 | // Decode params 27 | at::Tensor float_workspace_buffer_d, at::Tensor int_workspace_buffer_d, 28 | at::Tensor plan_info_vec, at::Tensor q_d, at::Tensor paged_k_cache_d, 29 | at::Tensor paged_v_cache_d, at::Tensor qo_indptr_d, at::Tensor paged_kv_indptr_d, 30 | at::Tensor paged_kv_indices_d, at::Tensor paged_kv_last_page_len_d, at::Tensor o_d, 31 | std::optional maybe_lse_d, int64_t mask_mode_code_d, int64_t layout_d, 32 | int64_t window_left_d, std::optional maybe_custom_mask_d, 33 | std::optional maybe_mask_indptr_d, std::optional maybe_alibi_slopes_d, 34 | double logits_soft_cap_d, double sm_scale_d, double rope_rcp_scale_d, double rope_rcp_theta_d); 35 | 36 | TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { 37 | // Batch-request prefill attention with KV-Cache operator 38 | m.def("pod_with_kv_cache_tensor", pod_with_kv_cache_tensor); 39 | } 40 | -------------------------------------------------------------------------------- /csrc/pod_kernel_inst.jinja: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include "pytorch_conversion_utils.h" 12 | #include "pytorch_extension_utils.h" 13 | 14 | #include "pod_config.inc" 15 | 16 | using namespace flashinfer; 17 | 18 | namespace flashinfer { 19 | constexpr auto use_custom_mask_p = {{ mask_mode_p }} == MaskMode::kCustom; 20 | constexpr auto use_custom_mask_d = {{ mask_mode_d }} == MaskMode::kCustom; 21 | // Not sure about the below declaration 22 | constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; 23 | 24 | template cudaError_t PODWithKVCacheTensorDispatched< 25 | {{ head_dim_qk }}, {{ head_dim_vo }}, POS_ENCODING_MODE, 26 | {{ use_fp16_qk_reduction }}, {{ mask_mode_p }}, 16, 27 | {{ mask_mode_d }}, {{ variant_name_p }}, 28 | {{ variant_name_d }}, PrefillParams, DecodeParams>( 29 | PrefillParams prefill_params, {{ dtype_o }}* tmp, 30 | DecodeParams decode_params, {{ dtype_o }}* tmp_v, 31 | float *tmp_s, cudaStream_t stream); 32 | }; 33 | -------------------------------------------------------------------------------- /csrc/pytorch_conversion_utils.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2025 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | #pragma once 18 | #include 19 | 20 | inline at::Tensor vec_to_tensor(const std::vector& vec) { 21 | return at::tensor(vec, at::dtype(at::kLong).device(at::kCPU)); 22 | } 23 | 24 | inline std::vector tensor_to_vec(const at::Tensor& tensor) { 25 | const size_t size = tensor.numel(); 26 | const int64_t* first = tensor.const_data_ptr(); 27 | const int64_t* last = first + size; 28 | return std::vector(first, last); 29 | } 30 | -------------------------------------------------------------------------------- /csrc/runtime_utils.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2024 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #pragma once 17 | 18 | #define FLASHINFER_DLL __attribute__((visibility("default"))) 19 | -------------------------------------------------------------------------------- /csrc/single_decode_customize_config.jinja: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }} 8 | #define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }} 9 | 10 | #define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, Params, ...) 
{\ 11 | using AttentionVariant = {{ variant_name }}; \ 12 | __VA_ARGS__(); \ 13 | } 14 | 15 | using namespace flashinfer; 16 | 17 | using DTypeQ = {{ dtype_q }}; 18 | using DTypeKV = {{ dtype_kv }}; 19 | using DTypeO = {{ dtype_o }}; 20 | using IdType = int32_t; 21 | constexpr int HEAD_DIM_QK = {{ head_dim_qk }}; 22 | constexpr int HEAD_DIM_VO = {{ head_dim_vo }}; 23 | constexpr auto USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }}; 24 | constexpr auto POS_ENCODING_MODE = {{ pos_encoding_mode }}; 25 | constexpr auto USE_SLIDING_WINDOW = {{ use_sliding_window }}; 26 | 27 | struct Params { 28 | using DTypeQ = DTypeQ; 29 | using DTypeKV = DTypeKV; 30 | using DTypeO = DTypeO; 31 | using IdType = int32_t; 32 | DTypeQ* q; 33 | DTypeKV* k; 34 | DTypeKV* v; 35 | DTypeO* o; 36 | float* lse; 37 | {{ additional_params_decl }} 38 | uint32_t kv_len; 39 | uint32_t num_qo_heads; 40 | uint32_t num_kv_heads; 41 | uint32_t q_stride_n; 42 | uint32_t q_stride_h; 43 | uint32_t kv_stride_n; 44 | uint32_t kv_stride_h; 45 | int32_t window_left; 46 | uint32_t kv_chunk_size; 47 | 48 | __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const { return 1; } 49 | 50 | __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const { 51 | return kv_len; 52 | } 53 | }; 54 | 55 | {{ variant_decl }} 56 | -------------------------------------------------------------------------------- /csrc/single_decode_jit_pybind.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2023-2025 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | #include "pytorch_extension_utils.h" 18 | #include "single_decode_config.inc" 19 | 20 | void single_decode_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor tmp, 21 | at::Tensor o, std::optional maybe_lse, int64_t layout, 22 | int64_t window_left ADDITIONAL_FUNC_PARAMS); 23 | 24 | TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { 25 | // Single-request decode with KV-Cache operator 26 | m.def("run", single_decode_with_kv_cache); 27 | } 28 | -------------------------------------------------------------------------------- /csrc/single_decode_kernel_inst.jinja: -------------------------------------------------------------------------------- 1 | #include 2 | #include "single_decode_config.inc" 3 | 4 | using namespace flashinfer; 5 | 6 | namespace flashinfer { 7 | 8 | template cudaError_t SingleDecodeWithKVCacheDispatched< 9 | {{ head_dim_qk }}, {{ pos_encoding_mode }}, {{ variant_name }}, Params>( 10 | Params params, {{ dtype_o }}* tmp, 11 | cudaStream_t stream); 12 | 13 | }; 14 | -------------------------------------------------------------------------------- /csrc/single_prefill_customize_config.jinja: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }} 10 | #define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }} 11 | 12 | 13 | #define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, USE_FP16_QK_REDUCTION, AttentionVariant, Params, ...) \ 14 | DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { \ 15 | constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom; \ 16 | using AttentionVariant = {{ variant_name }}; \ 17 | __VA_ARGS__(); \ 18 | }) 19 | 20 | 21 | using namespace flashinfer; 22 | 23 | using DTypeQ = {{ dtype_q }}; 24 | using DTypeKV = {{ dtype_kv }}; 25 | using DTypeO = {{ dtype_o }}; 26 | using IdType = int32_t; 27 | constexpr int HEAD_DIM_QK = {{ head_dim_qk }}; 28 | constexpr int HEAD_DIM_VO = {{ head_dim_vo }}; 29 | constexpr bool USE_FP16_QK_REDUCTION = {{ use_fp16_qk_reduction }}; 30 | constexpr auto USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }}; 31 | constexpr auto POS_ENCODING_MODE = {{ pos_encoding_mode }}; 32 | constexpr auto USE_SLIDING_WINDOW = {{ use_sliding_window }}; 33 | 34 | struct Params { 35 | using DTypeQ = DTypeQ; 36 | using DTypeKV = DTypeKV; 37 | using DTypeO = DTypeO; 38 | using IdType = int32_t; 39 | DTypeQ* q; 40 | DTypeKV* k; 41 | DTypeKV* v; 42 | DTypeO* o; 43 | float* lse; 44 | uint_fastdiv group_size; 45 | 46 | {{ additional_params_decl }} 47 | 48 | uint32_t qo_len; 49 | uint32_t kv_len; 50 | uint32_t num_qo_heads; 51 | uint32_t num_kv_heads; 52 | uint32_t q_stride_n; 53 | uint32_t q_stride_h; 54 | uint32_t k_stride_n; 55 | uint32_t k_stride_h; 56 | uint32_t v_stride_n; 57 | uint32_t v_stride_h; 58 | uint32_t head_dim; 59 | int32_t window_left; 60 | 61 | bool partition_kv; 62 | 63 | __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const { 64 | return qo_len; 65 | } 66 | 67 | __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const { 68 | return kv_len; 69 | } 70 | }; 71 | 72 | {{ variant_decl }} 73 | -------------------------------------------------------------------------------- /csrc/single_prefill_fp8_sm90_kernel_inst.jinja: 
-------------------------------------------------------------------------------- 1 | #include 2 | #include "single_prefill_sm90_config.inc" 3 | 4 | using namespace flashinfer; 5 | 6 | namespace flashinfer { 7 | 8 | template cudaError_t SingleFP8PrefillWithKVCacheDispatched 9 | <{{ head_dim_qk }}, {{ mask_mode }}, /*USE_SLIDING_WINDOW=*/{{ use_sliding_window }}, {{ variant_name }}, Params>( 10 | Params& params, cudaStream_t stream); 11 | }; 12 | -------------------------------------------------------------------------------- /csrc/single_prefill_jit_pybind.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2023-2025 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #include "pytorch_extension_utils.h" 17 | #include "single_prefill_config.inc" 18 | 19 | void single_prefill_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor tmp, 20 | at::Tensor o, std::optional maybe_lse, 21 | int64_t mask_mode_code, int64_t layout, 22 | int64_t window_left ADDITIONAL_FUNC_PARAMS); 23 | 24 | TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { 25 | // Single-request prefill attention with KV-Cache operator 26 | m.def("run", single_prefill_with_kv_cache); 27 | } 28 | -------------------------------------------------------------------------------- /csrc/single_prefill_kernel_inst.jinja: -------------------------------------------------------------------------------- 1 | #include 2 | #include "single_prefill_config.inc" 3 | 4 | using namespace flashinfer; 5 | 6 | namespace flashinfer { 7 | 8 | constexpr auto use_custom_mask = {{ mask_mode }} == MaskMode::kCustom; 9 | 10 | template cudaError_t SinglePrefillWithKVCacheDispatched< 11 | {{ head_dim_qk }}, {{ head_dim_vo }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, {{ mask_mode }}, {{ variant_name }}, Params>( 12 | Params params, {{ dtype_o }}* tmp, 13 | cudaStream_t stream); 14 | 15 | }; 16 | -------------------------------------------------------------------------------- /csrc/single_prefill_sm90_customize_config.jinja: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | 11 | #define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }} 12 | #define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }} 13 | 14 | #define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, Params, ...) 
\ 15 | DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { using AttentionVariant = {{ variant_name }}; __VA_ARGS__(); }) 16 | 17 | using namespace flashinfer; 18 | 19 | using DTypeQ = cutlass_dtype_t<{{ dtype_q }}>; 20 | using DTypeKV = cutlass_dtype_t<{{ dtype_kv }}>; 21 | using DTypeO = cutlass_dtype_t<{{ dtype_o }}>; 22 | using IdType = cutlass_dtype_t; 23 | 24 | constexpr int HEAD_DIM_QK = {{ head_dim_qk }}; 25 | constexpr int HEAD_DIM_VO = {{ head_dim_vo }}; 26 | constexpr auto USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }}; 27 | constexpr auto USE_SLIDING_WINDOW = {{ use_sliding_window }}; 28 | 29 | struct Params { 30 | using DTypeQ = DTypeQ; 31 | using DTypeKV = DTypeKV; 32 | using DTypeO = DTypeO; 33 | using IdType = IdType; 34 | 35 | // The QKV matrices. 36 | DTypeQ* q_ptr; 37 | DTypeKV* k_ptr; 38 | DTypeKV* v_ptr; 39 | DTypeO* o_ptr; 40 | float* lse_ptr; 41 | 42 | // Additional params 43 | struct AdditionalParams { 44 | {{ additional_params_decl }}; 45 | } additional_params; 46 | 47 | int64_t q_stride_n; 48 | int64_t k_stride_n; 49 | int64_t v_stride_n; 50 | int64_t o_stride_n; 51 | int64_t q_stride_h; 52 | int64_t k_stride_h; 53 | int64_t v_stride_h; 54 | int64_t o_stride_h; 55 | 56 | int qo_len; 57 | int kv_len; 58 | int head_dim; 59 | int num_qo_heads; 60 | int num_kv_heads; 61 | int group_size; 62 | int window_left; 63 | 64 | bool causal; 65 | }; 66 | 67 | {{ variant_decl }} 68 | -------------------------------------------------------------------------------- /csrc/single_prefill_sm90_jit_pybind.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2023-2025 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | #include "pytorch_extension_utils.h" 17 | #include "single_prefill_sm90_config.inc" 18 | 19 | void single_prefill_with_kv_cache_sm90(at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor tmp, 20 | at::Tensor o, std::optional maybe_lse, 21 | int64_t mask_mode_code, int64_t layout, 22 | int64_t window_left ADDITIONAL_FUNC_PARAMS); 23 | 24 | TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { 25 | // Single-request prefill attention with KV-Cache operator 26 | m.def("run", single_prefill_with_kv_cache_sm90); 27 | } 28 | -------------------------------------------------------------------------------- /csrc/single_prefill_sm90_kernel_inst.jinja: -------------------------------------------------------------------------------- 1 | #include 2 | #include "single_prefill_sm90_config.inc" 3 | 4 | using namespace flashinfer; 5 | 6 | namespace flashinfer { 7 | 8 | template cudaError_t SinglePrefillWithKVCacheDispatched 9 | <{{ head_dim_qk }}, {{ head_dim_vo }}, {{ mask_mode }}, /*USE_SLIDING_WINDOW=*/{{ use_sliding_window }}, {{ variant_name }}, Params>( 10 | Params& params, cudaStream_t stream); 11 | 12 | }; 13 | -------------------------------------------------------------------------------- /docker/Dockerfile.ci_gpu: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.8.0-devel-ubuntu24.04 2 | 3 | ENV DEBIAN_FRONTEND=noninteractive 4 | 5 | # Update package lists and install system dependencies 6 | RUN apt-get update && apt-get install -y \ 7 | curl \ 8 | git \ 9 | wget 10 | 11 | # Install python 12 | COPY install/install_python.sh /install/install_python.sh 13 | RUN bash /install/install_python.sh py312 14 | 15 | # Set home directory 16 | WORKDIR /workspace 17 | 18 | RUN echo "source activate py312" >> ~/.bashrc 19 | ENV PATH="/opt/conda/bin:$PATH" 20 | ENV PATH="/opt/conda/envs/py312/bin:$PATH" 21 | 22 | # Install torch 23 | COPY install/install_python_packages.sh /install/install_python_packages.sh 24 | RUN bash /install/install_python_packages.sh 25 | -------------------------------------------------------------------------------- /docker/install/install_python.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Licensed to the Apache Software Foundation (ASF) under one 3 | # or more contributor license agreements. See the NOTICE file 4 | # distributed with this work for additional information 5 | # regarding copyright ownership. The ASF licenses this file 6 | # to you under the Apache License, Version 2.0 (the 7 | # "License"); you may not use this file except in compliance 8 | # with the License. You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, 13 | # software distributed under the License is distributed on an 14 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | # KIND, either express or implied. See the License for the 16 | # specific language governing permissions and limitations 17 | # under the License. 18 | 19 | set -e 20 | set -u 21 | set -o pipefail 22 | 23 | 24 | # Install python and pip. 
Don't modify this to add Python package dependencies, 25 | wget -O Miniforge3.sh "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh" 26 | bash Miniforge3.sh -b -p /opt/conda 27 | 28 | /opt/conda/bin/conda create -n $1 python=3.12 29 | -------------------------------------------------------------------------------- /docker/install/install_python_packages.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Licensed to the Apache Software Foundation (ASF) under one 3 | # or more contributor license agreements. See the NOTICE file 4 | # distributed with this work for additional information 5 | # regarding copyright ownership. The ASF licenses this file 6 | # to you under the Apache License, Version 2.0 (the 7 | # "License"); you may not use this file except in compliance 8 | # with the License. You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, 13 | # software distributed under the License is distributed on an 14 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | # KIND, either express or implied. See the License for the 16 | # specific language governing permissions and limitations 17 | # under the License. 18 | 19 | set -e 20 | set -u 21 | 22 | pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 23 | pip3 install ninja pytest numpy scipy build 24 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | _build/ 2 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
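# For example, `make html` runs `sphinx-build -M html . _build` and, assuming the
# packages in docs/requirements.txt are installed, writes the rendered site to
# _build/html.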
19 | %: export FLASHINFER_BUILDING_DOCS=1 20 | %: Makefile 21 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 22 | -------------------------------------------------------------------------------- /docs/_static/FlashInfer-black-background.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flashinfer-ai/flashinfer/4e8bb778f1522de6becbcf4b732a22d48f7d72b0/docs/_static/FlashInfer-black-background.png -------------------------------------------------------------------------------- /docs/_static/FlashInfer-white-background.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flashinfer-ai/flashinfer/4e8bb778f1522de6becbcf4b732a22d48f7d72b0/docs/_static/FlashInfer-white-background.png -------------------------------------------------------------------------------- /docs/api/activation.rst: -------------------------------------------------------------------------------- 1 | .. _apiactivation: 2 | 3 | flashinfer.activation 4 | ===================== 5 | 6 | .. currentmodule:: flashinfer.activation 7 | 8 | This module provides a set of activation operations for up/gate layers in transformer MLPs. 9 | 10 | Up/Gate output activation 11 | ------------------------- 12 | 13 | .. autosummary:: 14 | :toctree: ../generated 15 | 16 | silu_and_mul 17 | gelu_tanh_and_mul 18 | gelu_and_mul 19 | -------------------------------------------------------------------------------- /docs/api/cascade.rst: -------------------------------------------------------------------------------- 1 | .. _apicascade: 2 | 3 | flashinfer.cascade 4 | ================== 5 | 6 | .. currentmodule:: flashinfer.cascade 7 | 8 | .. _api-merge-states: 9 | 10 | Merge Attention States 11 | ---------------------- 12 | 13 | .. autosummary:: 14 | :toctree: ../generated 15 | 16 | merge_state 17 | merge_state_in_place 18 | merge_states 19 | 20 | .. _api-cascade-attention: 21 | 22 | Cascade Attention 23 | ----------------- 24 | 25 | Cascade Attention Wrapper Classes 26 | --------------------------------- 27 | 28 | .. autoclass:: MultiLevelCascadeAttentionWrapper 29 | :members: 30 | :exclude-members: begin_forward, end_forward, forward, forward_return_lse 31 | 32 | .. automethod:: __init__ 33 | 34 | .. autoclass:: BatchDecodeWithSharedPrefixPagedKVCacheWrapper 35 | :members: 36 | 37 | .. automethod:: __init__ 38 | 39 | .. autoclass:: BatchPrefillWithSharedPrefixPagedKVCacheWrapper 40 | :members: 41 | 42 | .. automethod:: __init__ 43 | -------------------------------------------------------------------------------- /docs/api/decode.rst: -------------------------------------------------------------------------------- 1 | .. _apidecode: 2 | 3 | flashinfer.decode 4 | ================= 5 | 6 | .. currentmodule:: flashinfer.decode 7 | 8 | Single Request Decoding 9 | ----------------------- 10 | 11 | .. autosummary:: 12 | :toctree: ../generated 13 | 14 | single_decode_with_kv_cache 15 | 16 | Batch Decoding 17 | -------------- 18 | 19 | .. autoclass:: BatchDecodeWithPagedKVCacheWrapper 20 | :members: 21 | :exclude-members: begin_forward, end_forward, forward, forward_return_lse 22 | 23 | .. automethod:: __init__ 24 | 25 | .. autoclass:: CUDAGraphBatchDecodeWithPagedKVCacheWrapper 26 | :members: 27 | 28 | .. 
automethod:: __init__ 29 | -------------------------------------------------------------------------------- /docs/api/gemm.rst: -------------------------------------------------------------------------------- 1 | .. _apigemm: 2 | 3 | flashinfer.gemm 4 | =============== 5 | 6 | .. currentmodule:: flashinfer.gemm 7 | 8 | This module provides a set of GEMM operations. 9 | 10 | FP8 Batch GEMM 11 | -------------- 12 | 13 | .. autosummary:: 14 | :toctree: ../generated 15 | 16 | gemm_fp8_nt_groupwise 17 | group_gemm_fp8_nt_groupwise 18 | bmm_fp8 19 | 20 | Grouped GEMM 21 | ------------ 22 | 23 | .. autoclass:: SegmentGEMMWrapper 24 | :members: 25 | :exclude-members: forward 26 | 27 | .. automethod:: __init__ 28 | -------------------------------------------------------------------------------- /docs/api/mla.rst: -------------------------------------------------------------------------------- 1 | .. _apimla: 2 | 3 | flashinfer.mla 4 | ============== 5 | 6 | MLA (Multi-head Latent Attention) is an attention mechanism proposed in DeepSeek series of models ( 7 | `DeepSeek-V2 `_, `DeepSeek-V3 `_, 8 | and `DeepSeek-R1 `_). 9 | 10 | .. currentmodule:: flashinfer.mla 11 | 12 | PageAttention for MLA 13 | --------------------- 14 | 15 | .. autoclass:: BatchMLAPagedAttentionWrapper 16 | :members: 17 | 18 | .. automethod:: __init__ 19 | -------------------------------------------------------------------------------- /docs/api/norm.rst: -------------------------------------------------------------------------------- 1 | .. _apinorm: 2 | 3 | flashinfer.norm 4 | =============== 5 | 6 | Kernels for normalization layers. 7 | 8 | .. currentmodule:: flashinfer.norm 9 | 10 | .. autosummary:: 11 | :toctree: ../generated 12 | 13 | rmsnorm 14 | fused_add_rmsnorm 15 | gemma_rmsnorm 16 | gemma_fused_add_rmsnorm 17 | -------------------------------------------------------------------------------- /docs/api/page.rst: -------------------------------------------------------------------------------- 1 | .. _apipage: 2 | 3 | flashinfer.page 4 | =============== 5 | 6 | Kernels to manipulate paged kv-cache. 7 | 8 | .. currentmodule:: flashinfer.page 9 | 10 | Append new K/V tensors to Paged KV-Cache 11 | ---------------------------------------- 12 | 13 | .. autosummary:: 14 | :toctree: ../generated 15 | 16 | append_paged_kv_cache 17 | append_paged_mla_kv_cache 18 | get_batch_indices_positions 19 | -------------------------------------------------------------------------------- /docs/api/prefill.rst: -------------------------------------------------------------------------------- 1 | .. _apiprefill: 2 | 3 | flashinfer.prefill 4 | ================== 5 | 6 | Attention kernels for prefill & append attention in both single request and batch serving setting. 7 | 8 | .. currentmodule:: flashinfer.prefill 9 | 10 | Single Request Prefill/Append Attention 11 | --------------------------------------- 12 | 13 | .. autosummary:: 14 | :toctree: ../generated 15 | 16 | single_prefill_with_kv_cache 17 | single_prefill_with_kv_cache_return_lse 18 | 19 | Batch Prefill/Append Attention 20 | ------------------------------ 21 | 22 | .. autoclass:: BatchPrefillWithPagedKVCacheWrapper 23 | :members: 24 | :exclude-members: begin_forward, end_forward, forward, forward_return_lse 25 | 26 | .. automethod:: __init__ 27 | 28 | .. autoclass:: BatchPrefillWithRaggedKVCacheWrapper 29 | :members: 30 | :exclude-members: begin_forward, end_forward, forward, forward_return_lse 31 | 32 | .. 
automethod:: __init__ 33 | -------------------------------------------------------------------------------- /docs/api/quantization.rst: -------------------------------------------------------------------------------- 1 | .. _apiquantization: 2 | 3 | flashinfer.quantization 4 | ======================= 5 | 6 | Quantization related kernels. 7 | 8 | .. currentmodule:: flashinfer.quantization 9 | 10 | .. autosummary:: 11 | :toctree: ../generated 12 | 13 | packbits 14 | segment_packbits 15 | -------------------------------------------------------------------------------- /docs/api/rope.rst: -------------------------------------------------------------------------------- 1 | .. _apirope: 2 | 3 | flashinfer.rope 4 | =============== 5 | 6 | Kernels for applying rotary embeddings. 7 | 8 | .. currentmodule:: flashinfer.rope 9 | 10 | .. autosummary:: 11 | :toctree: ../generated 12 | 13 | apply_rope_inplace 14 | apply_llama31_rope_inplace 15 | apply_rope 16 | apply_llama31_rope 17 | apply_rope_pos_ids 18 | apply_rope_pos_ids_inplace 19 | apply_llama31_rope_pos_ids 20 | apply_llama31_rope_pos_ids_inplace 21 | apply_rope_with_cos_sin_cache 22 | apply_rope_with_cos_sin_cache_inplace 23 | -------------------------------------------------------------------------------- /docs/api/sampling.rst: -------------------------------------------------------------------------------- 1 | .. _apisampling: 2 | 3 | flashinfer.sampling 4 | =================== 5 | 6 | Kernels for LLM sampling. 7 | 8 | .. currentmodule:: flashinfer.sampling 9 | 10 | .. autosummary:: 11 | :toctree: ../generated 12 | 13 | sampling_from_probs 14 | top_p_sampling_from_probs 15 | top_k_sampling_from_probs 16 | min_p_sampling_from_probs 17 | top_k_top_p_sampling_from_logits 18 | top_k_top_p_sampling_from_probs 19 | top_p_renorm_probs 20 | top_k_renorm_probs 21 | top_k_mask_logits 22 | chain_speculative_sampling 23 | -------------------------------------------------------------------------------- /docs/api/sparse.rst: -------------------------------------------------------------------------------- 1 | .. _apisparse: 2 | 3 | flashinfer.sparse 4 | ================= 5 | 6 | Kernels for block sparse flashattention. 7 | 8 | .. currentmodule:: flashinfer.sparse 9 | 10 | .. autoclass:: BlockSparseAttentionWrapper 11 | :members: 12 | :exclude-members: begin_forward, end_forward, forward, forward_return_lse 13 | 14 | .. automethod:: __init__ 15 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from pathlib import Path 4 | 5 | # import tlcpack_sphinx_addon 6 | # Configuration file for the Sphinx documentation builder. 
7 | # 8 | # For the full list of built-in configuration values, see the documentation: 9 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 10 | 11 | # -- Project information ----------------------------------------------------- 12 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 13 | 14 | root = Path(__file__).parents[1].resolve() 15 | sys.path.append(str(root)) 16 | os.environ["BUILD_DOC"] = "1" 17 | autodoc_mock_imports = [ 18 | "torch", 19 | "triton", 20 | "flashinfer._build_meta", 21 | ] 22 | 23 | project = "FlashInfer" 24 | author = "FlashInfer Contributors" 25 | copyright = f"2023-2024, {author}" 26 | 27 | package_version = (root / "version.txt").read_text().strip() 28 | version = package_version 29 | release = package_version 30 | 31 | # -- General configuration --------------------------------------------------- 32 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 33 | 34 | extensions = [ 35 | "sphinx_tabs.tabs", 36 | "sphinx.ext.autodoc", 37 | "sphinx.ext.napoleon", 38 | "sphinx.ext.autosummary", 39 | "sphinx.ext.mathjax", 40 | ] 41 | 42 | autodoc_default_flags = ["members"] 43 | autosummary_generate = True 44 | 45 | source_suffix = [".rst"] 46 | 47 | language = "en" 48 | 49 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 50 | 51 | # The name of the Pygments (syntax highlighting) style to use. 52 | pygments_style = "sphinx" 53 | 54 | # A list of ignored prefixes for module index sorting. 55 | # If true, `todo` and `todoList` produce output, else they produce nothing. 56 | todo_include_todos = False 57 | 58 | # -- Options for HTML output ---------------------------------------------- 59 | 60 | html_theme = "furo" # "sphinx_rtd_theme" 61 | 62 | templates_path = [] 63 | 64 | html_static_path = [] 65 | 66 | html_theme_options = { 67 | "logo_only": True, 68 | } 69 | 70 | html_static_path = ["_static"] 71 | html_theme_options = { 72 | "light_logo": "FlashInfer-white-background.png", 73 | "dark_logo": "FlashInfer-black-background.png", 74 | } 75 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. FlashInfer documentation master file, created by 2 | sphinx-quickstart on Sat Jan 20 12:31:26 2024. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to FlashInfer's documentation! 7 | ====================================== 8 | 9 | `Blog `_ | `Discussion Forum `_ | `GitHub `_ 10 | 11 | FlashInfer is a library and kernel generator for Large Language Models that provides high-performance implementation of LLM GPU kernels such as FlashAttention, PageAttention and LoRA. FlashInfer focus on LLM serving and inference, and delivers state-of-the-art performance across diverse scenarios. 12 | 13 | .. toctree:: 14 | :maxdepth: 2 15 | :caption: Get Started 16 | 17 | installation 18 | 19 | .. toctree:: 20 | :maxdepth: 2 21 | :caption: Tutorials 22 | 23 | tutorials/recursive_attention 24 | tutorials/kv_layout 25 | 26 | .. 
toctree:: 27 | :maxdepth: 2 28 | :caption: PyTorch API Reference 29 | 30 | api/decode 31 | api/prefill 32 | api/cascade 33 | api/mla 34 | api/sparse 35 | api/page 36 | api/sampling 37 | api/logits_processor 38 | api/gemm 39 | api/norm 40 | api/rope 41 | api/activation 42 | api/quantization 43 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | furo == 2024.8.6 2 | sphinx == 8.1.3 3 | sphinx-reredirects == 0.1.5 4 | sphinx-tabs == 3.4.5 5 | sphinx-toolbox == 3.8.1 6 | -------------------------------------------------------------------------------- /flashinfer/jit/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2024 by FlashInfer team. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | 17 | import pathlib 18 | 19 | import torch 20 | 21 | 22 | def write_if_different(path: pathlib.Path, content: str) -> None: 23 | if path.exists(): 24 | with open(path, "r") as f: 25 | if f.read() == content: 26 | return 27 | else: 28 | path.parent.mkdir(parents=True, exist_ok=True) 29 | with open(path, "w") as f: 30 | f.write(content) 31 | 32 | 33 | dtype_map = { 34 | torch.float16: "half", 35 | torch.bfloat16: "nv_bfloat16", 36 | torch.float8_e4m3fn: "__nv_fp8_e4m3", 37 | torch.float8_e5m2: "__nv_fp8_e5m2", 38 | torch.int8: "int8_t", 39 | torch.uint8: "uint8_t", 40 | torch.int32: "int32_t", 41 | torch.uint32: "uint32_t", 42 | torch.int64: "int64_t", 43 | torch.uint64: "uint64_t", 44 | } 45 | 46 | filename_safe_dtype_map = { 47 | torch.float16: "f16", 48 | torch.bfloat16: "bf16", 49 | torch.float8_e4m3fn: "e4m3", 50 | torch.float8_e5m2: "e5m2", 51 | torch.int8: "i8", 52 | torch.uint8: "u8", 53 | torch.int32: "i32", 54 | torch.uint32: "u32", 55 | torch.int64: "i64", 56 | torch.uint64: "u64", 57 | } 58 | 59 | pos_encoding_mode_literal = { 60 | 0: "PosEncodingMode::kNone", 61 | 1: "PosEncodingMode::kRoPELlama", 62 | 2: "PosEncodingMode::kALiBi", 63 | } 64 | 65 | mask_mode_literal = { 66 | 0: "MaskMode::kNone", 67 | 1: "MaskMode::kCausal", 68 | 2: "MaskMode::kCustom", 69 | 3: "MaskMode::kMultiItemScoring", 70 | } 71 | -------------------------------------------------------------------------------- /flashinfer/logits_processor/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 by FlashInfer team. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | 17 | from .compiler import CompileError as CompileError 18 | from .compiler import Compiler as Compiler 19 | from .compiler import compile_pipeline as compile_pipeline 20 | from .fusion_rules import FusionRule as FusionRule 21 | from .legalization import LegalizationError as LegalizationError 22 | from .legalization import legalize_processors as legalize_processors 23 | from .op import Op as Op 24 | from .op import ParameterizedOp as ParameterizedOp 25 | from .pipeline import LogitsPipe as LogitsPipe 26 | from .processors import LogitsProcessor as LogitsProcessor 27 | from .processors import MinP as MinP 28 | from .processors import Sample as Sample 29 | from .processors import Softmax as Softmax 30 | from .processors import Temperature as Temperature 31 | from .processors import TopK as TopK 32 | from .processors import TopP as TopP 33 | from .types import TaggedTensor as TaggedTensor 34 | from .types import TensorType as TensorType 35 | -------------------------------------------------------------------------------- /flashinfer/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flashinfer-ai/flashinfer/4e8bb778f1522de6becbcf4b732a22d48f7d72b0/flashinfer/py.typed -------------------------------------------------------------------------------- /flashinfer/triton/__init__.py: -------------------------------------------------------------------------------- 1 | from . import cascade # noqa: F401 2 | from . import sm_constraint_gemm # noqa: F401 3 | -------------------------------------------------------------------------------- /flashinfer/triton/activation.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Mapping 2 | from typing import Optional 3 | 4 | import torch 5 | import triton # type: ignore[import] 6 | 7 | from flashinfer.triton.kernels.activation import silu_and_mul_kernel 8 | 9 | 10 | def silu_and_mul( 11 | x: torch.Tensor, 12 | x_scale: Optional[torch.Tensor] = None, 13 | o_scale: Optional[torch.Tensor] = None, 14 | dtype: Optional[torch.dtype] = None, 15 | ) -> torch.Tensor: 16 | """Sigmoid Linear Unit and Multiplication 17 | 18 | Computes `silu(x[:,:d]) * x[:, d:]`, where `d = x.shape[-1] // 2. 19 | 20 | If the scale of `x` is `x_scale`, the scale applied to the output 21 | is the square of that, as the sigmoid function ranges in (0, 1). 22 | 23 | Args: 24 | x: The input tensor, of shape `(b, 2 * d)`. 25 | x_scale: An optional scale which was applied to `x`. 26 | o_scale: The scale to apply to the output. 27 | dtype: The desired output dtype. 28 | 29 | Returns: 30 | The output activation, of shape `(b, d)`. 
31 | """ 32 | 33 | b, n = x.shape 34 | 35 | assert n % 2 == 0 36 | d = n // 2 37 | 38 | o_dtype = dtype or x.dtype 39 | o = torch.empty((b, d), dtype=o_dtype, device=x.device) 40 | 41 | def grid(meta: Mapping[str, int]) -> tuple[int, int]: 42 | return (b, triton.cdiv(d, meta["BLOCK_SIZE"])) 43 | 44 | silu_and_mul_kernel[grid]( 45 | o_ptr=o, 46 | o_stride=o.stride(0), 47 | o_scale_ptr=o_scale, 48 | x_ptr=x, 49 | x_stride=x.stride(0), 50 | x_scale_ptr=x_scale, 51 | d=d, 52 | BLOCK_SIZE=1024, 53 | HAS_X_SCALE=x_scale is not None, 54 | HAS_O_SCALE=o_scale is not None, 55 | ) 56 | 57 | return o 58 | -------------------------------------------------------------------------------- /flashinfer/triton/kernels/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flashinfer-ai/flashinfer/4e8bb778f1522de6becbcf4b732a22d48f7d72b0/flashinfer/triton/kernels/__init__.py -------------------------------------------------------------------------------- /flashinfer/triton/kernels/activation.py: -------------------------------------------------------------------------------- 1 | import triton # type: ignore[import] 2 | import triton.language as tl # type: ignore[import] 3 | 4 | from flashinfer.triton.kernels.quant import scale_and_clamp 5 | 6 | 7 | @triton.jit 8 | def silu_and_mul_kernel( 9 | o_ptr, 10 | o_stride, 11 | o_scale_ptr, 12 | x_ptr, 13 | x_stride, 14 | x_scale_ptr, 15 | d, 16 | BLOCK_SIZE: tl.constexpr, 17 | HAS_X_SCALE: tl.constexpr, 18 | HAS_O_SCALE: tl.constexpr, 19 | ) -> None: 20 | """Sigmoid Linear Unit and Multiplication Kernel 21 | 22 | Args: 23 | o_ptr: Pointer to the 2D output tensor. 24 | o_stride: Output tensor stride. 25 | o_scale_ptr: The optional, known scale of the output activations. 26 | x_ptr: Pointer to the 2D input tensor. 27 | x_stride: Input tensor stride. 28 | x_scale_ptr: The optional, known scale of the input tensor. 29 | d: The number of elements along the second dimension. 30 | BLOCK_SIZE: Tunable block size to process in each kernel. 31 | 32 | Operating on a 2D grid, computes the following: 33 | 34 | ``` 35 | out[i, j] = sigmoid(x[i, j]) * x[i, j] * x[i, j + d] 36 | ``` 37 | 38 | If scales are provided, the input and output tensors are scaled. 39 | """ 40 | 41 | i = tl.program_id(axis=0).to(tl.int64) 42 | j = tl.program_id(axis=1) 43 | 44 | o_row_ptr = o_ptr + o_stride * i 45 | x_row_ptr = x_ptr + x_stride * i 46 | 47 | offsets = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 48 | mask = offsets < d 49 | 50 | a = tl.load(x_row_ptr + offsets, mask=mask).to(tl.float32) 51 | b = tl.load(x_row_ptr + offsets + d, mask=mask).to(tl.float32) 52 | 53 | if HAS_X_SCALE: 54 | x_scale = tl.load(x_scale_ptr) 55 | a *= x_scale 56 | b *= x_scale 57 | 58 | result = tl.sigmoid(a) * a * b 59 | 60 | if HAS_O_SCALE: 61 | o_scale = tl.load(o_scale_ptr) 62 | result = scale_and_clamp(result, o_scale, o_ptr.dtype.element_ty) 63 | 64 | tl.store(o_row_ptr + offsets, result, mask=mask) 65 | -------------------------------------------------------------------------------- /flashinfer/triton/kernels/quant.py: -------------------------------------------------------------------------------- 1 | import triton # type: ignore[import] 2 | import triton.language as tl # type: ignore[import] 3 | 4 | 5 | @triton.jit 6 | def scale_and_clamp(x, scale, dtype): 7 | """Scales a value and clamps it to the range of the target dtype. 
8 | 9 | This function hard-wires the upper/lower bounds in order to be 10 | compatible with both `torch.compile` and `triton.jit`. 11 | """ 12 | if dtype == tl.float8e4nv: 13 | clamp_min = -448.0 14 | clamp_max = 448.0 15 | elif dtype == tl.float8e5: 16 | clamp_min = -57344.0 17 | clamp_max = 57344.0 18 | elif dtype == tl.float16: 19 | clamp_min = -65504.0 20 | clamp_max = 65504.0 21 | elif dtype == tl.bfloat16: 22 | clamp_min = -3.3895313892515355e38 23 | clamp_max = 3.3895313892515355e38 24 | else: 25 | tl.static_assert(False, f"Unsupported dtype: {dtype}") 26 | 27 | return tl.clamp(x.to(tl.float32) * scale, clamp_min, clamp_max).to(dtype) 28 | -------------------------------------------------------------------------------- /flashinfer/triton/page.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 by FlashInfer team. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import triton 18 | import triton.language as tl 19 | 20 | 21 | @triton.jit 22 | def get_batch_indices_positions_kernel( 23 | append_indptr, 24 | seq_lens_ptr, 25 | batch_indices_ptr, 26 | positions_ptr, 27 | num_stages: tl.constexpr, 28 | ): 29 | batch_idx = tl.program_id(0) 30 | 31 | batch_start = tl.load(append_indptr + batch_idx) 32 | batch_end = tl.load(append_indptr + batch_idx + 1) 33 | seq_len = tl.load(seq_lens_ptr + batch_idx) 34 | 35 | for i in tl.range(batch_start, batch_end, 128, num_stages=num_stages): 36 | offsets = tl.arange(0, 128) + i 37 | mask = offsets < batch_end 38 | tl.store(batch_indices_ptr + offsets, batch_idx, mask) 39 | tl.store(positions_ptr + offsets, offsets + seq_len - batch_end, mask) 40 | -------------------------------------------------------------------------------- /flashinfer/triton/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | 5 | 6 | def check_input(x: torch.Tensor): 7 | assert x.is_cuda, f"{str(x)} must be a CUDA Tensor" 8 | assert x.is_contiguous(), f"{str(x)} must be contiguous" 9 | 10 | 11 | def check_dim(d, x: torch.Tensor): 12 | assert x.dim() == d, f"{str(x)} must be a {d}D tensor" 13 | 14 | 15 | def check_shape(a: torch.Tensor, b: torch.Tensor): 16 | assert a.dim() == b.dim(), "tensors should have same dim" 17 | for i in range(a.dim()): 18 | assert a.size(i) == b.size( 19 | i 20 | ), f"tensors shape mismatch, {a.size()} and {b.size()}" 21 | 22 | 23 | def check_device(tensors: List[torch.Tensor]): 24 | device = tensors[0].device 25 | for t in tensors: 26 | assert ( 27 | t.device == device 28 | ), f"All tensors should be on the same device, but got {device} and {t.device}" 29 | -------------------------------------------------------------------------------- /include/flashinfer/allocator.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2023 by FlashInfer team. 
3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #ifndef FLASHINFER_ALLOCATOR_H_ 17 | #define FLASHINFER_ALLOCATOR_H_ 18 | 19 | #include <memory> 20 | #include <sstream> 21 | 22 | #include "exception.h" 23 | 24 | namespace flashinfer { 25 | 26 | // create a function that returns T* from base pointer and offset 27 | template <typename T> 28 | T* GetPtrFromBaseOffset(void* base_ptr, int64_t offset) { 29 | return reinterpret_cast<T*>(reinterpret_cast<char*>(base_ptr) + offset); 30 | } 31 | 32 | struct AlignedAllocator { 33 | void* base_ptr; 34 | void* cur_ptr; 35 | size_t remaining_space; 36 | AlignedAllocator(void* buf, size_t space) : base_ptr(buf), cur_ptr(buf), remaining_space(space) {} 37 | template <typename T> 38 | T* aligned_alloc(size_t size, size_t alignment, std::string name) { 39 | if (std::align(alignment, size, cur_ptr, remaining_space)) { 40 | T* result = reinterpret_cast<T*>(cur_ptr); 41 | cur_ptr = (char*)cur_ptr + size; 42 | remaining_space -= size; 43 | return result; 44 | } else { 45 | std::ostringstream oss; 46 | oss << "Failed to allocate memory for " << name << " with size " << size << " and alignment " 47 | << alignment << " in AlignedAllocator"; 48 | FLASHINFER_ERROR(oss.str()); 49 | } 50 | return nullptr; 51 | } 52 | 53 | size_t aligned_alloc_offset(size_t size, size_t alignment, std::string name) { 54 | return (char*)aligned_alloc<void>(size, alignment, name) - (char*)base_ptr; 55 | } 56 | 57 | size_t num_allocated_bytes() { return (char*)cur_ptr - (char*)base_ptr; } 58 | }; 59 | 60 | } // namespace flashinfer 61 | 62 | #endif // FLASHINFER_ALLOCATOR_H_ 63 | -------------------------------------------------------------------------------- /include/flashinfer/attention/heap.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2023 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #ifndef FLASHINFER_ATTENTION_HEAP_H 17 | #define FLASHINFER_ATTENTION_HEAP_H 18 | 19 | #include <algorithm> 20 | #include <functional> 21 | #include <utility> 22 | #include <vector> 23 | 24 | namespace flashinfer { 25 | 26 | /*!
27 | * \brief Heap data structure for (index, value) pairs 28 | * \note minimal element on top 29 | */ 30 | class MinHeap { 31 | public: 32 | // first: index, second: cost 33 | using Element = std::pair<int, float>; 34 | 35 | MinHeap(int capacity) : heap_(capacity) { 36 | for (int i = 0; i < capacity; ++i) { 37 | heap_[i] = std::make_pair(i, 0.f); 38 | } 39 | } 40 | 41 | void insert(const Element& element) { 42 | heap_.push_back(element); 43 | std::push_heap(heap_.begin(), heap_.end(), compare); 44 | } 45 | 46 | Element pop() { 47 | std::pop_heap(heap_.begin(), heap_.end(), compare); 48 | Element minElement = heap_.back(); 49 | heap_.pop_back(); 50 | return minElement; 51 | } 52 | 53 | std::vector<Element> getHeap() const { return heap_; } 54 | 55 | private: 56 | // Custom comparator for the min-heap: compare based on 'val' in the pair 57 | static bool compare(const Element& a, const Element& b) { 58 | return a.second > b.second; // create a min-heap based on val 59 | } 60 | 61 | std::vector<Element> heap_; 62 | }; 63 | 64 | } // namespace flashinfer 65 | 66 | #endif // FLASHINFER_ATTENTION_HEAP_H 67 | -------------------------------------------------------------------------------- /include/flashinfer/attention/mask.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2024 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #ifndef FLASHINFER_ATTENTION_MASK_CUH_ 17 | #define FLASHINFER_ATTENTION_MASK_CUH_ 18 | 19 | namespace flashinfer { 20 | 21 | enum class MaskMode { 22 | kNone = 0U, // No mask 23 | kCausal = 1U, // Causal mask 24 | kCustom = 2U, // Custom mask 25 | kMultiItemScoring = 3U, 26 | }; 27 | 28 | } // namespace flashinfer 29 | 30 | #endif // FLASHINFER_ATTENTION_MASK_CUH_ 31 | -------------------------------------------------------------------------------- /include/flashinfer/attention/mla_params.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2025 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License.
15 | */ 16 | #ifndef FLASHINFER_MLA_PARAMS_CUH_ 17 | #define FLASHINFER_MLA_PARAMS_CUH_ 18 | #include <cstdint> 19 | 20 | #include "../fastdiv.cuh" 21 | #include "../profiler.cuh" 22 | 23 | namespace flashinfer { 24 | 25 | template <typename DTypeQ_, typename DTypeKV_, typename DTypeO_, typename IdType_> 26 | struct MLAParams { 27 | using DTypeQ = DTypeQ_; 28 | using DTypeKV = DTypeKV_; 29 | using DTypeO = DTypeO_; 30 | using IdType = IdType_; 31 | 32 | DTypeQ* q_nope; 33 | DTypeQ* q_pe; 34 | DTypeKV* ckv; 35 | DTypeKV* kpe; 36 | DTypeO* partial_o; 37 | float* partial_lse; 38 | DTypeO* final_o; 39 | float* final_lse; 40 | 41 | IdType* q_indptr; 42 | IdType* kv_indptr; 43 | IdType* partial_indptr; 44 | IdType* merge_packed_offset_start; 45 | IdType* merge_packed_offset_end; 46 | IdType* merge_partial_packed_offset_start; 47 | IdType* merge_partial_packed_offset_end; 48 | IdType* merge_partial_stride; 49 | IdType* kv_indices; 50 | IdType* q_len; 51 | IdType* kv_len; 52 | IdType* q_start; 53 | IdType* kv_start; 54 | IdType* kv_end; 55 | IdType* work_indptr; 56 | 57 | PROFILER_PARAMS_DECL 58 | 59 | uint_fastdiv block_size; 60 | uint_fastdiv num_heads; 61 | 62 | uint32_t q_nope_stride_n; 63 | uint32_t q_nope_stride_h; 64 | uint32_t q_pe_stride_n; 65 | uint32_t q_pe_stride_h; 66 | uint32_t ckv_stride_page; 67 | uint32_t ckv_stride_n; 68 | uint32_t kpe_stride_page; 69 | uint32_t kpe_stride_n; 70 | uint32_t o_stride_n; 71 | uint32_t o_stride_h; 72 | 73 | float sm_scale; 74 | }; 75 | 76 | }; // namespace flashinfer 77 | 78 | #endif // FLASHINFER_MLA_PARAMS_CUH_ 79 | -------------------------------------------------------------------------------- /include/flashinfer/attention_impl.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2024 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #ifndef FLASHINFER_ATTENTION_IMPL_CUH_ 17 | #define FLASHINFER_ATTENTION_IMPL_CUH_ 18 | 19 | #include "attention/cascade.cuh" 20 | #include "attention/decode.cuh" 21 | #include "attention/default_decode_params.cuh" 22 | #include "attention/default_prefill_params.cuh" 23 | #include "attention/prefill.cuh" 24 | #include "attention/variants.cuh" 25 | 26 | #endif // FLASHINFER_ATTENTION_IMPL_CUH_ 27 | -------------------------------------------------------------------------------- /include/flashinfer/exception.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2024 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #ifndef FLASHINFER_EXCEPTION_H_ 17 | #define FLASHINFER_EXCEPTION_H_ 18 | 19 | #include <exception> 20 | #include <sstream> 21 | 22 | #define FLASHINFER_ERROR(message) throw flashinfer::Error(__FUNCTION__, __FILE__, __LINE__, message) 23 | 24 | template <typename T> 25 | void write_to_stream(std::ostringstream& oss, T&& val) { 26 | oss << std::forward<T>(val); 27 | } 28 | 29 | template <typename T, typename... Args> 30 | void write_to_stream(std::ostringstream& oss, T&& val, Args&&... args) { 31 | oss << std::forward<T>(val) << " "; 32 | write_to_stream(oss, std::forward<Args>(args)...); 33 | } 34 | 35 | #define FLASHINFER_CHECK(condition, ...) \ 36 | if (!(condition)) { \ 37 | std::ostringstream oss; \ 38 | write_to_stream(oss, __VA_ARGS__); \ 39 | std::cerr << oss.str() << std::endl; \ 40 | FLASHINFER_ERROR(oss.str()); \ 41 | } 42 | 43 | namespace flashinfer { 44 | class Error : public std::exception { 45 | private: 46 | std::string message_; 47 | 48 | public: 49 | Error(const std::string& func, const std::string& file, int line, const std::string& message) { 50 | std::ostringstream oss; 51 | oss << "Error in function '" << func << "' " 52 | << "at " << file << ":" << line << ": " << message; 53 | message_ = oss.str(); 54 | } 55 | 56 | virtual const char* what() const noexcept override { return message_.c_str(); } 57 | }; 58 | 59 | } // namespace flashinfer 60 | 61 | #endif // FLASHINFER_EXCEPTION_H_ 62 | -------------------------------------------------------------------------------- /include/flashinfer/frag_layout_swizzle.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2024 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #ifndef FLASHINFER_FRAG_LAYOUT_SWIZZLE_CUH_ 17 | #define FLASHINFER_FRAG_LAYOUT_SWIZZLE_CUH_ 18 | 19 | #include <cuda_runtime.h> 20 | 21 | #include <cstdint> 22 | 23 | __device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_8b(uint32_t x) { 24 | uint32_t tmp = __shfl_xor_sync(0xffffffff, x, 0x1); 25 | x = __byte_perm(x, tmp, ((threadIdx.x & 0x1) == 0) ? 0x5410 : 0x3276); 26 | tmp = __shfl_xor_sync(0xffffffff, x, 0x2); 27 | x = __byte_perm(x, tmp, ((threadIdx.x & 0x2) == 0) ? 0x5410 : 0x3276); 28 | return x; 29 | } 30 | 31 | __device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_8b_trans(uint32_t x) { 32 | uint32_t tmp = __shfl_xor_sync(0xffffffff, x, 0x4); 33 | x = __byte_perm(x, tmp, ((threadIdx.x & 0x4) == 0) ? 0x6420 : 0x3175); 34 | tmp = __shfl_xor_sync(0xffffffff, x, 0x8); 35 | x = __byte_perm(x, tmp, ((threadIdx.x & 0x8) == 0) ? 0x5410 : 0x3276); 36 | tmp = __shfl_xor_sync(0xffffffff, x, 0x10); 37 | x = __byte_perm(x, tmp, ((threadIdx.x & 0x10) == 0) ?
0x5410 : 0x3276); 38 | return x; 39 | } 40 | 41 | #endif // FLASHINFER_FRAG_LAYOUT_SWIZZLE_CUH_ 42 | -------------------------------------------------------------------------------- /include/flashinfer/gemm/group_gemm_lora.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2024 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #ifndef FLASHINFER_GROUP_GEMM_LORA_CUH_ 17 | #define FLASHINFER_GROUP_GEMM_LORA_CUH_ 18 | 19 | namespace flashinfer { 20 | 21 | namespace group_gemm { 22 | 23 | // TODO(Zihao): port punica's sgmv kernel 24 | 25 | } // namespace group_gemm 26 | 27 | } // namespace flashinfer 28 | 29 | #endif // FLASHINFER_GROUP_GEMM_LORA_CUH_ 30 | -------------------------------------------------------------------------------- /include/flashinfer/gemm/group_gemv.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2024 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #ifndef FLASHINFER_GROUP_GEMV_CUH_ 17 | #define FLASHINFER_GROUP_GEMV_CUH_ 18 | 19 | namespace flashinfer { 20 | 21 | namespace group_gemm { 22 | 23 | // TODO(Zihao): port punica's bgmv kernel 24 | 25 | } // namespace group_gemm 26 | 27 | } // namespace flashinfer 28 | 29 | #endif // FLASHINFER_GROUP_GEMV_CUH_ 30 | -------------------------------------------------------------------------------- /include/flashinfer/logging.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2025 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #ifndef FLASHINFER_LOGGING_H_ 17 | #define FLASHINFER_LOGGING_H_ 18 | 19 | #include 20 | #include 21 | 22 | #define FLASHINFER_LOG_TRACE(...) spdlog::trace(__VA_ARGS__) 23 | #define FLASHINFER_LOG_DEBUG(...) 
spdlog::debug(__VA_ARGS__) 24 | #define FLASHINFER_LOG_INFO(...) spdlog::info(__VA_ARGS__) 25 | #define FLASHINFER_LOG_WARN(...) spdlog::warn(__VA_ARGS__) 26 | #define FLASHINFER_LOG_ERROR(...) spdlog::error(__VA_ARGS__) 27 | #define FLASHINFER_LOG_CRITICAL(...) spdlog::critical(__VA_ARGS__) 28 | 29 | namespace flashinfer { 30 | 31 | namespace logging { 32 | 33 | inline void set_log_level(spdlog::level::level_enum lvl) { 34 | auto fmt = "[%Y-%m-%d %H:%M:%S.%f] [%n] [%^%l%$] %v"; 35 | auto console_sink = std::make_shared<spdlog::sinks::stdout_color_sink_mt>(); 36 | console_sink->set_pattern(fmt); 37 | console_sink->set_level(lvl); 38 | spdlog::set_default_logger(std::make_shared<spdlog::logger>("flashinfer", console_sink)); 39 | } 40 | 41 | } // namespace logging 42 | 43 | } // namespace flashinfer 44 | 45 | #endif // FLASHINFER_LOGGING_H_ 46 | -------------------------------------------------------------------------------- /include/flashinfer/semaphore_utils.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2025 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #ifndef FLASHINFER_SEMAPHORE_UTILS_CUH 17 | #define FLASHINFER_SEMAPHORE_UTILS_CUH 18 | 19 | #include <cuda_runtime.h> 20 | 21 | #include "utils.cuh" 22 | 23 | namespace flashinfer { 24 | 25 | template <typename T> 26 | __global__ void zero_gmem_semaphore(T* semaphore, int size) { 27 | for (int i = threadIdx.x; i < size; i += blockDim.x) { 28 | semaphore[i] = 0; 29 | } 30 | } 31 | 32 | template <typename T> 33 | cudaError_t zero_gmem_semaphore_launcher(T* semaphore, int size, bool enable_pdl, 34 | cudaStream_t stream) { 35 | cudaLaunchConfig_t config = {0}; 36 | config.gridDim = 1; 37 | config.blockDim = 128; 38 | config.dynamicSmemBytes = 0; 39 | config.stream = stream; 40 | cudaLaunchAttribute attrs[1]; 41 | attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; 42 | attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; 43 | config.numAttrs = 1; 44 | config.attrs = attrs; 45 | 46 | FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, zero_gmem_semaphore<T>, semaphore, size)); 47 | 48 | return cudaSuccess; 49 | } 50 | 51 | } // namespace flashinfer 52 | 53 | #endif // FLASHINFER_SEMAPHORE_UTILS_CUH 54 | -------------------------------------------------------------------------------- /include/flashinfer/trtllm/fmha/decoder_params.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #pragma once 17 | #include "../common.h" 18 | using XQADataType = Data_type; 19 | 20 | struct XQAParams { 21 | XQADataType data_type = DATA_TYPE_FP16; 22 | XQADataType kv_cache_data_type = DATA_TYPE_FP16; 23 | void* output = nullptr; 24 | void const* qHeads = nullptr; 25 | // float const* kv_scale_quant_orig = nullptr; 26 | float kv_scale_quant_orig = 1.f; 27 | uint32_t* semaphores = nullptr; 28 | void* workspaces = nullptr; 29 | uint32_t batch_size = 0; 30 | int32_t beam_width = 0; 31 | 32 | int32_t num_q_heads = 0; 33 | int32_t num_kv_heads = 0; 34 | int32_t head_size = 0; 35 | int timestep = 0; 36 | 37 | // Paged KV cache parameters. 38 | int generation_input_length; 39 | bool paged_kv_cache = true; // always true 40 | int tokens_per_block; 41 | int max_blocks_per_sequence; 42 | bool multi_block_mode; 43 | bool multi_query_tokens = false; 44 | }; 45 | -------------------------------------------------------------------------------- /include/flashinfer/trtllm/fmha/fmhaRunner.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | 19 | #include 20 | 21 | #include "fmhaKernels.cuh" 22 | #include "fmhaRunnerParams.h" 23 | 24 | class TllmGenFmhaRunner { 25 | public: 26 | // Constructor. 27 | explicit TllmGenFmhaRunner(Data_type dtypeQ, Data_type dtypeKv, Data_type dtypeOut); 28 | 29 | TllmGenFmhaRunner() = default; 30 | 31 | // Check if fmha is supported. 32 | bool isSupported(TllmGenFmhaRunnerParams const& runnerParams) const; 33 | 34 | // Check if fmha is supported with additional info. 35 | std::pair isSupportedWithInfo( 36 | TllmGenFmhaRunnerParams const& runnerParams) const; 37 | 38 | // Run the fmha kernel. 39 | void run(TllmGenFmhaRunnerParams const&); 40 | 41 | private: 42 | // The input/output datatype. 43 | Data_type mDtypeQ, mDtypeKv, mDtypeOut; 44 | // The SM version. 45 | int mSM; 46 | // The class that stores all the kernels. 47 | TllmGenFmhaKernel const* mKernel; 48 | }; 49 | -------------------------------------------------------------------------------- /include/flashinfer/trtllm/fmha/gen_kernel_launcher.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #pragma once 17 | #include 18 | #include 19 | 20 | #include "decoder_impl_common.h" 21 | #include "decoder_params.h" 22 | #include "pytorch_extension_utils.h" 23 | 24 | void trtllm_paged_attention(at::Tensor& out, at::Tensor& query, at::Tensor& key_value_cache, 25 | at::Tensor& workspace_buffer, int64_t num_heads, int64_t num_kv_heads, 26 | double scale, at::Tensor& block_tables, at::Tensor& seq_lens, 27 | int64_t block_size, int64_t max_seq_len, 28 | const std::string kv_cache_dtype, double k_scale, double v_scale); 29 | -------------------------------------------------------------------------------- /licenses/LICENSE.cutlass.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | SPDX-License-Identifier: BSD-3-Clause 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /licenses/LICENSE.flashattention3.txt: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 
11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /profiler/.gitignore: -------------------------------------------------------------------------------- 1 | *.perfetto-trace 2 | -------------------------------------------------------------------------------- /profiler/README.md: -------------------------------------------------------------------------------- 1 | # FlashInfer Profiler (Experimental) 2 | 3 | FlashInfer Profiler is an intra-kernel profiling tool for diagnosing kernel performance. 4 | 5 | ## Prerequisites 6 | 7 | Chrome tracing does not support overlapping events inside a single thread (the wgmma instructions are asynchronous, and the execution of several wgmma instructions might overlap). We use our fork of [tg4perfetto](https://github.com/ihavnoid/tg4perfetto), modified to use the latest protobuf, to generate perfetto traces. 8 | 9 | ```bash 10 | # pip install from github 11 | pip install protobuf 12 | pip install git+https://github.com/flashinfer-ai/tg4perfetto.git 13 | ``` 14 | 15 | ## Examples 16 | 17 | ### MLA 18 | 19 | Run the following command to profile the MLA kernel for different configurations. 20 | 21 | ```bash 22 | python mla.py --batch-size 64 --seq-len 1024 --num-heads 128 --profiler-buffer-size 1048576 23 | ``` 24 | 25 | The generated traces will be saved in the current directory. 26 | 27 | ```bash 28 | ls *.perfetto-trace 29 | ``` 30 | 31 | Users can use [ui.perfetto.dev](https://ui.perfetto.dev/) to visualize the traces. 32 | 33 | Below is a screenshot of the trace generated by the above command. 34 | 35 | ![MLA Trace](https://raw.githubusercontent.com/flashinfer-ai/web-data/main/examples/flashinfer-profiler-mla.png) 36 | 37 | ## Limitations 38 | 39 | - The instrumentation is intrusive (we insert `__threadfence_block()` in the kernel to avoid instruction reordering) and will slow down the kernel execution. 40 | 41 | ## Acknowledgements 42 | 43 | This work is in part inspired by [Mosaic GPU DSL](https://github.com/jax-ml/jax/tree/main/jax/experimental/mosaic)'s warp-level profiling, as well as [Proton Intra-kernel profiling](https://github.com/triton-lang/triton/pull/4861) in Triton.
44 | 45 | We thank [tg4perfetto](https://github.com/ihavnoid/tg4perfetto) for providing examples of generating perfetto traces from python. 46 | -------------------------------------------------------------------------------- /scripts/ci-flashinfer.env.example: -------------------------------------------------------------------------------- 1 | RUNNER_SCOPE=repo 2 | REPO_URL=https://github.com/flashinfer-ai/flashinfer 3 | #LABELS=gpu,sm80 4 | ACCESS_TOKEN=foo-access-token 5 | RUNNER_WORKDIR=/tmp/ci-flashinfer 6 | CI_RUNNER_CACHE_DIR=/data/ci-flashinfer-cache 7 | DISABLE_AUTO_UPDATE=1 8 | EPHEMERAL=1 9 | -------------------------------------------------------------------------------- /scripts/ci-flashinfer.service: -------------------------------------------------------------------------------- 1 | # https://github.com/myoung34/docker-github-actions-runner/wiki/Usage 2 | # Install with: 3 | # install -m 644 ci-flashinfer.service $HOME/.config/systemd/user/ 4 | # systemctl --user daemon-reload 5 | # Run with: 6 | # systemctl --user start ci-flashinfer 7 | # Stop with: 8 | # systemctl --user stop ci-flashinfer 9 | # See live logs with: 10 | # journalctl -f -u ci-flashinfer.service --no-hostname --no-tail 11 | [Unit] 12 | Description=Ephemeral GitHub Actions Runner Container for flashinfer-ai/flashinfer 13 | [Service] 14 | TimeoutStartSec=0 15 | Restart=always 16 | ExecStartPre=-/usr/bin/docker stop %N 17 | ExecStartPre=-/usr/bin/docker rm %N 18 | ExecStartPre=-/usr/bin/docker pull myoung34/github-runner:latest 19 | ExecStart=/usr/bin/docker run --rm \ 20 | --env-file %h/.config/ci-flashinfer.env \ 21 | -e RUNNER_NAME=%H \ 22 | -e CI_UID=%U \ 23 | -e CI_GID=%G \ 24 | -v /var/run/docker.sock:/var/run/docker.sock \ 25 | -v /tmp/ci-flashinfer:/tmp/ci-flashinfer \ 26 | --name %N \ 27 | myoung34/github-runner:latest 28 | -------------------------------------------------------------------------------- /scripts/formatter.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | echo "Formatting CUDA files" 3 | find include/ -regex '.*\.\(h\|cuh\|cu\|cc\)' | xargs clang-format -i 4 | find src/ -regex '.*\.\(h\|cuh\|cu\|cc\)' -not -path './src/generated' | xargs clang-format -i 5 | find csrc/ -regex '.*\.\(h\|cuh\|cu\|cc\)' -not -path './csrc/generated' | xargs clang-format -i 6 | echo "Formatting Python files" 7 | find flashinfer/ -regex '.*\.\(py\)' | xargs black 8 | -------------------------------------------------------------------------------- /scripts/run-ci-build-wheel.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # adapted from https://github.com/punica-ai/punica/blob/591b59899f0a20760821785d06b331c8a2e5cb86/ci/run-ci-build-wheel.bash 3 | set -e 4 | 5 | assert_env() { 6 | local var_name="$1" 7 | if [ -z "${!var_name}" ]; then 8 | echo "Error: Environment variable '$var_name' is not set." 9 | exit 1 10 | fi 11 | } 12 | 13 | assert_env FLASHINFER_CI_CACHE 14 | assert_env FLASHINFER_CI_CUDA_VERSION 15 | assert_env FLASHINFER_CI_PYTHON_VERSION 16 | assert_env FLASHINFER_CI_TORCH_VERSION 17 | assert_env TORCH_CUDA_ARCH_LIST 18 | PROJECT_ROOT="$(cd "$(dirname "$0")/.." 
&& pwd)" 19 | export CONDA_pkgs_dirs="${FLASHINFER_CI_CACHE}/conda-pkgs" 20 | export XDG_CACHE_HOME="${FLASHINFER_CI_CACHE}/xdg-cache" 21 | mkdir -p "$CONDA_pkgs_dirs" "$XDG_CACHE_HOME" 22 | export HOME=/tmp/home 23 | mkdir -p $HOME 24 | export PATH="$HOME/.local/bin:$PATH" 25 | CUDA_MAJOR="${FLASHINFER_CI_CUDA_VERSION%.*}" 26 | CUDA_MINOR="${FLASHINFER_CI_CUDA_VERSION#*.}" 27 | TORCH_MAJOR="${FLASHINFER_CI_TORCH_VERSION%.*}" 28 | TORCH_MINOR="${FLASHINFER_CI_TORCH_VERSION#*.}" 29 | PYVER="${FLASHINFER_CI_PYTHON_VERSION//./}" 30 | export PATH="/opt/python/cp${PYVER}-cp${PYVER}/bin:$PATH" 31 | 32 | FLASHINFER_LOCAL_VERSION="cu${CUDA_MAJOR}${CUDA_MINOR}torch${FLASHINFER_CI_TORCH_VERSION}" 33 | if [ -n "${FLASHINFER_GIT_SHA}" ]; then 34 | FLASHINFER_LOCAL_VERSION="${FLASHINFER_GIT_SHA}.${FLASHINFER_LOCAL_VERSION}" 35 | fi 36 | 37 | echo "::group::Install PyTorch" 38 | pip install torch==${FLASHINFER_CI_TORCH_VERSION}.* --index-url "https://download.pytorch.org/whl/cu${CUDA_MAJOR}${CUDA_MINOR}" 39 | echo "::endgroup::" 40 | 41 | echo "::group::Install build system" 42 | pip install ninja numpy 43 | pip install --upgrade setuptools wheel build 44 | echo "::endgroup::" 45 | 46 | 47 | echo "::group::Build wheel for FlashInfer" 48 | cd "$PROJECT_ROOT" 49 | 50 | python -m build --no-isolation --sdist 51 | 52 | python -m flashinfer.aot 53 | FLASHINFER_LOCAL_VERSION=$FLASHINFER_LOCAL_VERSION \ 54 | python -m build --no-isolation --wheel 55 | 56 | ls -la dist/ 57 | echo "::endgroup::" 58 | -------------------------------------------------------------------------------- /scripts/task_cpplint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) 2024 by FlashInfer team. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | cpplint include/flashinfer/* include/flashinfer/attention/* include/flashinfer/distributed/* include/flashinfer/group_gemm/* 17 | -------------------------------------------------------------------------------- /scripts/task_jit_run_tests_part1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -eo pipefail 4 | set -x 5 | : ${MAX_JOBS:=$(nproc)} 6 | : ${CUDA_VISIBLE_DEVICES:=0} 7 | 8 | pip install -e . -v 9 | 10 | # pytest -s tests/test_group_gemm.py 11 | pytest -s tests/test_logits_cap.py 12 | pytest -s tests/test_sliding_window.py 13 | pytest -s tests/test_tensor_cores_decode.py 14 | pytest -s tests/test_batch_decode_kernels.py 15 | #pytest -s tests/test_alibi.py 16 | -------------------------------------------------------------------------------- /scripts/task_jit_run_tests_part2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -eo pipefail 4 | set -x 5 | : ${MAX_JOBS:=$(nproc)} 6 | : ${CUDA_VISIBLE_DEVICES:=0} 7 | 8 | pip install -e . 
-v 9 | 10 | pytest -s tests/test_block_sparse.py 11 | pytest -s tests/test_jit_example.py 12 | pytest -s tests/test_jit_warmup.py 13 | pytest -s tests/test_norm.py 14 | pytest -s tests/test_rope.py 15 | pytest -s tests/test_mla_page.py 16 | pytest -s tests/test_quantization.py 17 | -------------------------------------------------------------------------------- /scripts/task_jit_run_tests_part3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -eo pipefail 4 | set -x 5 | : ${MAX_JOBS:=$(nproc)} 6 | : ${CUDA_VISIBLE_DEVICES:=0} 7 | 8 | pip install -e . -v 9 | 10 | pytest -s tests/test_sampling.py 11 | -------------------------------------------------------------------------------- /scripts/task_jit_run_tests_part4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -eo pipefail 4 | set -x 5 | : ${MAX_JOBS:=$(nproc)} 6 | : ${CUDA_VISIBLE_DEVICES:=0} 7 | 8 | pip install -e . -v 9 | 10 | export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True # avoid memory fragmentation 11 | pytest -s tests/test_deepseek_mla.py 12 | pytest -s tests/test_group_gemm.py 13 | -------------------------------------------------------------------------------- /scripts/task_lint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) 2024 by FlashInfer team. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | ./task_cpplint.sh 17 | -------------------------------------------------------------------------------- /scripts/task_mypy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) 2024 by FlashInfer team. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | mypy --check-untyped-defs flashinfer/ 17 | -------------------------------------------------------------------------------- /scripts/task_pylint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) 2024 by FlashInfer team. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | pylint flashinfer/ 17 | -------------------------------------------------------------------------------- /scripts/task_show_node_info.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Licensed to the Apache Software Foundation (ASF) under one 3 | # or more contributor license agreements. See the NOTICE file 4 | # distributed with this work for additional information 5 | # regarding copyright ownership. The ASF licenses this file 6 | # to you under the Apache License, Version 2.0 (the 7 | # "License"); you may not use this file except in compliance 8 | # with the License. You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, 13 | # software distributed under the License is distributed on an 14 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | # KIND, either express or implied. See the License for the 16 | # specific language governing permissions and limitations 17 | # under the License. 18 | 19 | set -euxo pipefail 20 | 21 | echo "===== JENKINS INFO =====" 22 | echo "NODE_NAME=$NODE_NAME" 23 | echo "EXECUTOR_NUMBER=$EXECUTOR_NUMBER" 24 | echo "WORKSPACE=$WORKSPACE" 25 | echo "BUILD_NUMBER=$BUILD_NUMBER" 26 | echo "WORKSPACE=$WORKSPACE" 27 | 28 | echo "===== EC2 INFO =====" 29 | function ec2_metadata() { 30 | # See https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html 31 | curl -w '\n' -fsSL "http://169.254.169.254/latest/meta-data/$1" || echo failed 32 | } 33 | 34 | ec2_metadata ami-id 35 | ec2_metadata instance-id 36 | ec2_metadata instance-type 37 | ec2_metadata hostname 38 | ec2_metadata public-hostname 39 | 40 | echo "===== RUNNER INFO =====" 41 | df --human-readable 42 | lscpu 43 | free 44 | nvidia-smi 2>/dev/null || echo "cuda not found" 45 | -------------------------------------------------------------------------------- /scripts/task_test_aot_build_import.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -eo pipefail 4 | set -x 5 | : ${MAX_JOBS:=$(nproc)} 6 | : ${CUDA_VISIBLE_DEVICES:=""} 7 | export TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0+PTX" 8 | 9 | python -c "import torch; print(torch._C._GLIBCXX_USE_CXX11_ABI)" 10 | python -m flashinfer.aot 11 | python -m build --no-isolation --wheel 12 | pip install dist/*.whl 13 | 14 | # test import 15 | mkdir -p tmp 16 | cd tmp 17 | python -c "from flashinfer.page import gen_page_module; p = gen_page_module().aot_path; print(p); assert p.exists();" 18 | -------------------------------------------------------------------------------- /scripts/update_whl_index.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import pathlib 3 | import re 4 | 5 | for path in sorted(pathlib.Path("dist").glob("*.whl")): 6 | with open(path, "rb") as f: 7 | sha256 = hashlib.sha256(f.read()).hexdigest() 8 | ver, cu, torch = re.findall( 9 | 
r"flashinfer_python-([0-9.]+(?:\.post[0-9]+)?)\+cu(\d+)torch([0-9.]+)-", 10 | path.name, 11 | )[0] 12 | index_dir = pathlib.Path(f"flashinfer-whl/cu{cu}/torch{torch}/flashinfer-python") 13 | index_dir.mkdir(exist_ok=True) 14 | base_url = "https://github.com/flashinfer-ai/flashinfer/releases/download" 15 | full_url = f"{base_url}/v{ver}/{path.name}#sha256={sha256}" 16 | with (index_dir / "index.html").open("a") as f: 17 | f.write(f'{path.name}
\n') 18 | -------------------------------------------------------------------------------- /src/bench_norm.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2024 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #include 17 | 18 | #include 19 | #include 20 | 21 | #include "utils.h" 22 | 23 | using namespace flashinfer; 24 | 25 | template 26 | void bench_rms_norm(nvbench::state& state) { 27 | size_t batch_size = state.get_int64("batch_size"); 28 | size_t hidden_dim = state.get_int64("hidden_dim"); 29 | 30 | thrust::device_vector x(batch_size * hidden_dim); 31 | thrust::device_vector w(hidden_dim); 32 | thrust::device_vector y(batch_size * hidden_dim); 33 | 34 | state.add_global_memory_reads(batch_size * hidden_dim + hidden_dim, "Read"); 35 | state.add_global_memory_writes(batch_size * hidden_dim, "Write"); 36 | 37 | state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { 38 | timer.start(); 39 | cudaError_t status = 40 | norm::RMSNorm(thrust::raw_pointer_cast(x.data()), thrust::raw_pointer_cast(w.data()), 41 | thrust::raw_pointer_cast(y.data()), batch_size, hidden_dim, 1e-5); 42 | timer.stop(); 43 | if (status != cudaSuccess) { 44 | state.skip("RMSNorm kernel launch failed"); 45 | } 46 | }); 47 | } 48 | 49 | auto bench_rms_norm_f16 = bench_rms_norm; 50 | NVBENCH_BENCH(bench_rms_norm_f16) 51 | .set_name("bench_rms_norm_f16") 52 | .add_int64_axis("batch_size", {32, 128, 512, 2048}) 53 | .add_int64_axis("hidden_dim", {3072, 4096, 32768}); 54 | -------------------------------------------------------------------------------- /tests/test_bmm_fp8.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from flashinfer import bmm_fp8 6 | 7 | 8 | def to_float8(x, dtype=torch.float8_e4m3fn): 9 | finfo = torch.finfo(dtype) 10 | min_val, max_val = x.aminmax() 11 | amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) 12 | scale = finfo.max / amax 13 | x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) 14 | return x_scl_sat.to(dtype), scale.float().reciprocal() 15 | 16 | 17 | @pytest.mark.parametrize("input_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) 18 | @pytest.mark.parametrize("mat2_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) 19 | @pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16]) 20 | def test_bmm_fp8(input_dtype, mat2_dtype, res_dtype): 21 | if input_dtype == torch.float8_e5m2 and mat2_dtype == torch.float8_e5m2: 22 | pytest.skip("Invalid combination: both input and mat2 are e5m2") 23 | 24 | input = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16) 25 | input_fp8, input_inv_s = to_float8(input, dtype=input_dtype) 26 | 27 | # mat2 row major -> column major 28 | mat2 = torch.randn([16, 80, 64], device="cuda", dtype=torch.bfloat16).transpose( 29 | -2, -1 30 | ) 31 | mat2_fp8, 
mat2_inv_s = to_float8(mat2, dtype=mat2_dtype) 32 | 33 | res = torch.empty([16, 48, 80], device="cuda", dtype=res_dtype) 34 | bmm_fp8(input_fp8, mat2_fp8, input_inv_s, mat2_inv_s, res_dtype, res) 35 | 36 | reference = torch.bmm(input, mat2) 37 | cos_sim = F.cosine_similarity(reference.reshape(-1), res.reshape(-1), dim=0) 38 | assert cos_sim > 0.99 39 | 40 | 41 | if __name__ == "__main__": 42 | pytest.main([__file__]) 43 | -------------------------------------------------------------------------------- /tests/test_page.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | import flashinfer 5 | 6 | 7 | @pytest.mark.parametrize("contiguous", [True, False]) 8 | def test_append_paged_kv_cache(contiguous): 9 | nnz_kv = 100 10 | num_kv_heads = 32 11 | head_dim = 128 12 | 13 | if contiguous: 14 | k_append = torch.randn(nnz_kv, num_kv_heads, head_dim).half().to(0) 15 | v_append = torch.randn(nnz_kv, num_kv_heads, head_dim).half().to(0) 16 | else: 17 | kv_append = torch.randn(nnz_kv, 2, num_kv_heads, head_dim).half().to(0) 18 | k_append = kv_append[:, 0] 19 | v_append = kv_append[:, 1] 20 | # 45 + 8 + 25 + 22 = nnz_kv 21 | kv_append_length = torch.tensor([45, 8, 25, 22], dtype=torch.int32, device="cuda:0") 22 | kv_append_indptr = torch.cat( 23 | [torch.zeros(1).int().to(0), torch.cumsum(kv_append_length, dim=0)] 24 | ).int() 25 | 26 | max_num_pages = 1000 27 | page_size = 16 28 | paged_kv_cache = ( 29 | torch.randn(max_num_pages, 2, page_size, num_kv_heads, head_dim).half().to(0) 30 | ) 31 | num_pages_per_req = torch.tensor([3, 1, 2, 2], dtype=torch.int32, device="cuda:0") 32 | kv_page_indptr = torch.cat( 33 | [torch.zeros(1).int().to(0), torch.cumsum(num_pages_per_req, dim=0)] 34 | ).int() 35 | # use first 8 pages in the paged-kv 36 | kv_page_indices = torch.arange(8, dtype=torch.int32, device="cuda:0") 37 | # 45 = (3 - 1) * 16 + 13 38 | # 8 = (1 - 1) * 16 + 8 39 | # 25 = (2 - 1) * 16 + 9 40 | # 22 = (2 - 1) * 16 + 6 41 | kv_last_page_len = torch.tensor([13, 8, 9, 6], dtype=torch.int32, device="cuda:0") 42 | batch_indices, positions = flashinfer.get_batch_indices_positions( 43 | kv_append_indptr, 44 | flashinfer.get_seq_lens(kv_page_indptr, kv_last_page_len, page_size), 45 | nnz_kv, 46 | ) 47 | 48 | flashinfer.append_paged_kv_cache( 49 | k_append, 50 | v_append, 51 | batch_indices, 52 | positions, 53 | paged_kv_cache, 54 | kv_page_indices, 55 | kv_page_indptr, 56 | kv_last_page_len, 57 | ) 58 | -------------------------------------------------------------------------------- /tvm_binding/batch_decode_customize_config.jinja: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }} 10 | #define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }} 11 | 12 | #define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, Params, ...) 
--------------------------------------------------------------------------------
/tvm_binding/batch_decode_customize_config.jinja:
--------------------------------------------------------------------------------
#pragma once
#include
#include
#include
#include
#include
#include

#define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }}
#define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }}

#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, Params, ...) \
  DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, { \
    using AttentionVariant = {{ variant_name }};                     \
    __VA_ARGS__();                                                   \
  })

using namespace flashinfer;

using DTypeQ = {{ dtype_q }};
using DTypeKV = {{ dtype_kv }};
using DTypeO = {{ dtype_o }};
using IdType = {{ idtype }};
constexpr int HEAD_DIM_QK = {{ head_dim_qk }};
constexpr int HEAD_DIM_VO = {{ head_dim_vo }};
constexpr auto USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }};
constexpr auto USE_SLIDING_WINDOW = {{ use_sliding_window }};

struct Params {
  using DTypeQ = DTypeQ;
  using DTypeKV = DTypeKV;
  using DTypeO = DTypeO;
  using IdType = IdType;

  DTypeQ* q;
  paged_kv_t<DTypeKV, IdType> paged_kv;
  DTypeO* o;
  float* lse;

  IdType* decode_maybe_q_rope_offset;

  {{ additional_params_decl }}

  uint32_t padded_batch_size;
  uint32_t num_qo_heads;
  IdType q_stride_n;
  IdType q_stride_h;
  int32_t window_left;

  IdType* request_indices;
  IdType* kv_tile_indices;
  IdType* o_indptr;
  IdType* kv_chunk_size_ptr;
  bool* block_valid_mask;
  bool partition_kv;

  __host__ __device__ __forceinline__ int32_t get_qo_len(int32_t batch_idx) const { return 1; }

  __host__ __device__ __forceinline__ int32_t get_kv_len(int32_t batch_idx) const {
    return paged_kv.get_length(batch_idx);
  }
};

{{ variant_decl }}
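The file above is a Jinja template; the {{ ... }} placeholders are filled in when the binding is generated. A minimal rendering sketch using the jinja2 package, with purely hypothetical substitution values (the real generator chooses these from the requested kernel configuration):

from jinja2 import Template

# Hypothetical values, for illustration only.
params = {
    "additional_func_params": "",
    "additional_params_setter": "",
    "variant_name": "DefaultAttention",
    "dtype_q": "half",
    "dtype_kv": "half",
    "dtype_o": "half",
    "idtype": "int32_t",
    "head_dim_qk": 128,
    "head_dim_vo": 128,
    "use_logits_soft_cap": "false",
    "use_sliding_window": "false",
    "additional_params_decl": "",
    "variant_decl": "",
}
with open("tvm_binding/batch_decode_customize_config.jinja") as f:
    rendered = Template(f.read()).render(**params)
print("\n".join(rendered.splitlines()[:5]))  # first lines of the generated C++ header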
15 | */ 16 | #include "batch_decode_config.inc" 17 | #include "tvm_binding_utils.h" 18 | 19 | IntTuple BatchDecodeWithPagedKVCachePlan( 20 | DLTensor* float_workspace_buffer, DLTensor* int_workspace_buffer, 21 | DLTensor* page_locked_int_workspace_buffer, DLTensor* indptr, int64_t batch_size, 22 | int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, 23 | int64_t pos_encoding_mode_code, int64_t window_left, int64_t head_dim_qk, int64_t head_dim_vo, 24 | DataType q_scalar_type, DataType kv_scalar_type, TVMStreamHandle cuda_stream); 25 | 26 | void BatchDecodeWithPagedKVCacheRun( 27 | DLTensor* float_workspace_buffer, DLTensor* int_workspace_buffer, IntTuple plan_info_vec, 28 | DLTensor* q, DLTensor* paged_kv_cache, DLTensor* paged_kv_indptr, DLTensor* paged_kv_indices, 29 | DLTensor* paged_kv_last_page_len, DLTensor* q_rope_offset, DLTensor* paged_kv_rope_pos_offset, 30 | DLTensor* o, DLTensor* lse, int64_t pos_encoding_mode_code, int64_t kv_layout_code, 31 | int64_t window_left ADDITIONAL_FUNC_PARAMS, TVMStreamHandle cuda_stream); 32 | 33 | TVM_FFI_DLL_EXPORT_TYPED_FUNC(batch_decode_with_paged_kv_cache_plan, 34 | BatchDecodeWithPagedKVCachePlan); 35 | TVM_FFI_DLL_EXPORT_TYPED_FUNC(batch_decode_with_paged_kv_cache_run, BatchDecodeWithPagedKVCacheRun); 36 | -------------------------------------------------------------------------------- /tvm_binding/batch_mla_config.jinja: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | using namespace flashinfer; 12 | 13 | #define ADDITIONAL_FUNC_PARAMS 14 | #define ADDITIONAL_PARAMS_SETTER 15 | 16 | using DTypeQ = {{ dtype_q }}; 17 | using DTypeKV = {{ dtype_kv }}; 18 | using DTypeO = {{ dtype_o }}; 19 | using IdType = {{ dtype_idx }}; 20 | constexpr int HEAD_DIM_CKV = {{ head_dim_ckv }}; 21 | constexpr int HEAD_DIM_KPE = {{ head_dim_kpe }}; 22 | 23 | #define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, ...) \ 24 | DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { \ 25 | using Params = MLAParams; \ 26 | __VA_ARGS__(); \ 27 | }) 28 | -------------------------------------------------------------------------------- /tvm_binding/batch_mla_jit_tvm_binding.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2025 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | #include "batch_mla_config.inc" 17 | #include "tvm_binding_utils.h" 18 | 19 | IntTuple BatchMLAPagedAttentionPlan(DLTensor* float_workspace_buffer, 20 | DLTensor* int_workspace_buffer, 21 | DLTensor* page_locked_int_workspace_buffer, DLTensor* qo_indptr, 22 | DLTensor* kv_indptr, IntTuple kv_len_arr, int64_t num_heads, 23 | int64_t head_dim_o, bool causal, TVMStreamHandle cuda_stream); 24 | 25 | void BatchMLAPagedAttentionRun(DLTensor* float_workspace_buffer, DLTensor* int_workspace_buffer, 26 | IntTuple plan_info_vec, DLTensor* q, DLTensor* kv_cache, 27 | DLTensor* kv_indices, DLTensor* o, DLTensor* lse, 28 | int64_t mask_mode_code, int64_t num_heads, int64_t page_size, 29 | double sm_scale, TVMStreamHandle cuda_stream); 30 | 31 | TVM_FFI_DLL_EXPORT_TYPED_FUNC(batch_mla_paged_attention_plan, BatchMLAPagedAttentionPlan); 32 | TVM_FFI_DLL_EXPORT_TYPED_FUNC(batch_mla_paged_attention_run, BatchMLAPagedAttentionRun); 33 | -------------------------------------------------------------------------------- /tvm_binding/sampling.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "tvm_binding_utils.h" 9 | 10 | using namespace flashinfer; 11 | 12 | // TODO: change the philox seeds and offsets to DLTensor once the underlying API for sampling 13 | // changes to support multiple seeds 14 | void SamplingFromProbs(DLTensor* probs, DLTensor* output, DLTensor* maybe_indices, 15 | bool deterministic, uint64_t philox_seed, uint64_t philox_offset, 16 | int64_t cuda_stream) { 17 | CHECK(probs->ndim == 2) << "Probs should have 2 dimensions"; 18 | unsigned int batch_size = output->shape[0]; 19 | unsigned int vocab_size = probs->shape[1]; 20 | 21 | cudaStream_t stream = reinterpret_cast(cuda_stream); 22 | float* probs_cast = static_cast(probs->data) + probs->byte_offset; 23 | int* output_cast = static_cast(output->data) + output->byte_offset; 24 | int* maybe_indices_cast = 25 | maybe_indices ? static_cast(maybe_indices->data) + maybe_indices->byte_offset : nullptr; 26 | 27 | cudaError_t status = 28 | sampling::SamplingFromProb(probs_cast, output_cast, maybe_indices_cast, batch_size, 29 | vocab_size, deterministic, philox_seed, philox_offset, stream); 30 | CHECK(status == cudaSuccess) << "SamplingFromProbs failed with error " 31 | << cudaGetErrorString(status); 32 | } 33 | -------------------------------------------------------------------------------- /tvm_binding/sampling_jit_tvm_binding.cu: -------------------------------------------------------------------------------- 1 | #include "tvm_binding_utils.h" 2 | 3 | void SamplingFromProbs(DLTensor* probs, DLTensor* output, DLTensor* maybe_indices, 4 | bool deterministic, uint64_t philox_seed, uint64_t philox_offset, 5 | int64_t cuda_stream); 6 | 7 | TVM_FFI_DLL_EXPORT_TYPED_FUNC(sampling_from_probs, SamplingFromProbs); 8 | -------------------------------------------------------------------------------- /tvm_binding/tvm_binding_utils.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2025 by FlashInfer team. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
--------------------------------------------------------------------------------
/tvm_binding/tvm_binding_utils.h:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2025 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include
#include
#include
#include
#include
#include
#include

using IdType = int32_t;
using tvm::ffi::Array;
using tvm::runtime::DataType;
using tvm::runtime::IntTuple;
using tvm::runtime::NDArray;

#define DISPATCH_BOOL(expr, const_expr, ...) \
  [&]() -> bool {                            \
    if (expr) {                              \
      constexpr bool const_expr = true;      \
      return __VA_ARGS__();                  \
    } else {                                 \
      constexpr bool const_expr = false;     \
      return __VA_ARGS__();                  \
    }                                        \
  }()

--------------------------------------------------------------------------------
/version.txt:
--------------------------------------------------------------------------------
0.2.5

--------------------------------------------------------------------------------