├── benchmarks ├── __init__.py ├── fastertransformer │ ├── build_ext.py │ ├── ft_llama.h │ ├── ft_pybind11.cc │ ├── CMakeLists.txt │ ├── __init__.py │ └── .clang-format ├── runft.py ├── nvbench │ ├── CMakeLists.txt │ ├── sgmv.cu │ └── sgmv_flashinfer.cu ├── bench_textgen_lora_all.py ├── bench_backbone_vs_lora.py ├── bench_sgmv_cutlass.py ├── bench_sgmv.py ├── benchmark_utils.py ├── bench_batch_decode.py ├── bench_model_prefill_decode.py ├── bench_layer_lora_decode.py ├── bench_model_lora_decode.py └── bench_lora_op_impls.py ├── version.txt ├── src └── punica │ ├── models │ ├── __init__.py │ └── llama.py │ ├── utils │ ├── __init__.py │ ├── cat_tensor.py │ ├── lora.py │ ├── convert_lora_weight.py │ └── kvcache.py │ └── __init__.py ├── .clang-format ├── .release-please-manifest.json ├── assets ├── sgmv.png ├── textgen.png ├── backbone-vs-sgmv.png └── punica-tui-demo.mp4 ├── .clangd ├── csrc ├── bgmv │ ├── bgmv_all.cu │ ├── bgmv_config.h │ └── bgmv_impl.cuh ├── rms_norm │ ├── rms_norm.h │ └── rms_norm_cutlass.cu ├── sgmv │ ├── sgmv.h │ ├── sgmv_cutlass.cu │ └── sgmv_cutlass.cuh ├── sgmv_flashinfer │ ├── sgmv_config.h │ └── sgmv_all.cu └── flashinfer_adapter │ ├── flashinfer_config.h │ ├── flashinfer_decl.h │ └── flashinfer_all.cu ├── ci ├── ci-punica.env.example ├── test-run-ci.bash ├── ci-punica.service ├── run-ci-build-wheel.bash └── run-ci-gpu-tests.bash ├── .gitignore ├── release-please-config.json ├── .gitmodules ├── MANIFEST.in ├── examples ├── finetune │ ├── run-llmtuner.py │ ├── data │ │ └── dataset_info.json │ ├── dataset_info.json │ ├── finetune.sh │ ├── create-finetune-data.py │ └── README.md ├── textgen.py └── textgen_lora.py ├── CMakeLists.txt ├── .github └── workflows │ ├── release-please.yml │ ├── gpu-test.yml │ └── release_wheel.yml ├── tests ├── test_rms_norm.py ├── test_bgmv.py ├── test_kvcache.py ├── test_sgmv_cutlass.py ├── test_sgmv.py └── test_flashinfer.py ├── licenses └── LICENSE.cutlass.txt ├── pyproject.toml ├── CHANGELOG.md ├── README.md └── setup.py /benchmarks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /version.txt: -------------------------------------------------------------------------------- 1 | 1.1.0 2 | -------------------------------------------------------------------------------- /src/punica/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | BasedOnStyle: Google 2 | -------------------------------------------------------------------------------- /.release-please-manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | ".": "1.1.0" 3 | } 4 | -------------------------------------------------------------------------------- /assets/sgmv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/punica-ai/punica/HEAD/assets/sgmv.png -------------------------------------------------------------------------------- /assets/textgen.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/punica-ai/punica/HEAD/assets/textgen.png -------------------------------------------------------------------------------- 
/assets/backbone-vs-sgmv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/punica-ai/punica/HEAD/assets/backbone-vs-sgmv.png -------------------------------------------------------------------------------- /assets/punica-tui-demo.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/punica-ai/punica/HEAD/assets/punica-tui-demo.mp4 -------------------------------------------------------------------------------- /.clangd: -------------------------------------------------------------------------------- 1 | CompileFlags: 2 | Remove: 3 | - -forward-unknown-to-host-compiler 4 | - -arch=native 5 | Diagnostics: 6 | Suppress: 7 | - variadic_device_fn 8 | -------------------------------------------------------------------------------- /csrc/bgmv/bgmv_all.cu: -------------------------------------------------------------------------------- 1 | #include "bgmv_config.h" 2 | #include "bgmv_impl.cuh" 3 | 4 | FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half) 5 | FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16) 6 | -------------------------------------------------------------------------------- /benchmarks/fastertransformer/build_ext.py: -------------------------------------------------------------------------------- 1 | from benchmarks.fastertransformer import _build_ext 2 | 3 | if __name__ == "__main__": 4 | module = _build_ext(do_cmake=True) 5 | print("Built ext:", module.__spec__.origin) 6 | -------------------------------------------------------------------------------- /csrc/rms_norm/rms_norm.h: -------------------------------------------------------------------------------- 1 | template <typename T> 2 | bool rms_norm(T *__restrict__ output, const T *__restrict__ input, 3 | const T *__restrict__ weight, int rows, int columns, 4 | float epsilon); 5 | -------------------------------------------------------------------------------- /ci/ci-punica.env.example: -------------------------------------------------------------------------------- 1 | RUNNER_SCOPE=repo 2 | REPO_URL=https://github.com/punica-ai/punica 3 | LABELS=gpu,sm80 4 | ACCESS_TOKEN=foo-access-token 5 | RUNNER_WORKDIR=/tmp/ci-punica 6 | CI_RUNNER_CACHE_DIR=/data/ci-punica-cache 7 | DISABLE_AUTO_UPDATE=1 8 | EPHEMERAL=1 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | csrc/flashinfer_adapter/generated/ 2 | src/punica/_build_meta.py 3 | data/ 4 | build/ 5 | tmp/ 6 | .cache/ 7 | .hypothesis/ 8 | __pycache__/ 9 | *.egg-info/ 10 | *.py[cod] 11 | *.so 12 | dist/ 13 | .pytest_cache/ 14 | LLaMA-Factory/ 15 | wandb/ 16 | model/ 17 | *.env 18 | .coverage 19 | .vscode/ 20 | -------------------------------------------------------------------------------- /csrc/sgmv/sgmv.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include <cuda_runtime.h> 3 | 4 | #include <cstdint> 5 | 6 | template <typename DType> 7 | bool sgmv(DType *y, DType *x, DType **w, int32_t *s, void *tmp_d, 8 | int num_problems, int d_in, int d_out, int layer_idx, 9 | cudaStream_t stream); 10 | 11 | size_t sgmv_tmp_size(int num_problems); 12 | -------------------------------------------------------------------------------- /release-please-config.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema":
"https://raw.githubusercontent.com/googleapis/release-please/main/schemas/config.json", 3 | "bootstrap-sha": "28bf99aa4ad758e25216236dc5b5381de9189914", 4 | "packages": { 5 | ".": { 6 | "changelog-path": "CHANGELOG.md", 7 | "release-type": "simple" 8 | } 9 | } 10 | } 11 | -------------------------------------------------------------------------------- /src/punica/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from punica.utils.cat_tensor import BatchLenInfo 2 | from punica.utils.kvcache import BatchedKvCache, KvCache, KvPool 3 | from punica.utils.lora import BatchedLoraWeight, LoraWeight 4 | 5 | __all__ = [ 6 | "BatchLenInfo", 7 | "KvPool", 8 | "KvCache", 9 | "BatchedKvCache", 10 | "LoraWeight", 11 | "BatchedLoraWeight", 12 | ] 13 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/nvbench"] 2 | path = third_party/nvbench 3 | url = https://github.com/NVIDIA/nvbench.git 4 | [submodule "third_party/cutlass"] 5 | path = third_party/cutlass 6 | url = https://github.com/NVIDIA/cutlass.git 7 | [submodule "third_party/flashinfer"] 8 | path = third_party/flashinfer 9 | url = https://github.com/flashinfer-ai/flashinfer.git 10 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | # sdist & wheel 2 | include version.txt 3 | recursive-include licenses * 4 | recursive-include csrc * 5 | recursive-include third_party/cutlass/include * 6 | recursive-include third_party/flashinfer/include * 7 | 8 | # wheel-only 9 | exclude src/punica/_build_meta.py 10 | 11 | # Unneeded files 12 | prune benchmarks 13 | prune */__pycache__ 14 | global-exclude *.so 15 | -------------------------------------------------------------------------------- /ci/test-run-ci.bash: -------------------------------------------------------------------------------- 1 | docker run --runtime=nvidia --gpus all --rm -t \ 2 | -v $HOME/ci-test/cache:/ci-cache \ 3 | -v $HOME/ci-test/punica-checkout:/app \ 4 | -e PUNICA_CI_TORCH_VERSION=2.1.0 \ 5 | -e PUNICA_CI_CUDA_MAJOR=12 \ 6 | -e PUNICA_CI_CUDA_MINOR=1 \ 7 | --user $(id -u):$(id -g) \ 8 | nvidia/cuda:12.1.0-devel-ubuntu22.04 \ 9 | bash /app/ci/run-ci-gpu-tests.bash 10 | -------------------------------------------------------------------------------- /examples/finetune/run-llmtuner.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa 2 | 3 | import sys 4 | 5 | sys.path.append("LLaMA-Factory/src") 6 | from llmtuner import run_exp 7 | from llmtuner.extras.template import register_template 8 | 9 | register_template( 10 | name="empty", 11 | system="", 12 | prefix=["{{system}}"], 13 | prompt=["{{query}}"], 14 | sep=["\n"], 15 | ) 16 | 17 | if __name__ == "__main__": 18 | run_exp() 19 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.25) 2 | project(punica_ops_bench CUDA CXX) 3 | set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_INCLUDES OFF) 4 | set(CMAKE_CUDA_ARCHITECTURES native) 5 | 6 | find_package(CUDAToolkit REQUIRED) 7 | list(APPEND CMAKE_PREFIX_PATH "${CUDAToolkit_LIBRARY_DIR}/cmake/thrust") 8 | find_package(Thrust REQUIRED CONFIG) 9 | 
thrust_create_target(Thrust) 10 | 11 | 12 | add_subdirectory(third_party/nvbench EXCLUDE_FROM_ALL) 13 | 14 | add_subdirectory(benchmarks/nvbench) 15 | -------------------------------------------------------------------------------- /csrc/sgmv_flashinfer/sgmv_config.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include <cstdint> 3 | 4 | template <typename T, uint32_t d_out> 5 | bool sgmv_shrink(T* y, T* x, T** w, int32_t* s, void* tmp, 6 | uint32_t num_problems, uint32_t d_in, uint32_t layer_idx, 7 | cudaStream_t stream); 8 | 9 | // clang-format off 10 | 11 | #define FOR_SGMV_NARROW(f, T) \ 12 | f(T, 16) \ 13 | f(T, 32) \ 14 | f(T, 64) \ 15 | f(T, 96) \ 16 | f(T, 128) 17 | 18 | // clang-format on 19 | -------------------------------------------------------------------------------- /benchmarks/fastertransformer/ft_llama.h: -------------------------------------------------------------------------------- 1 | #include <cstddef> 2 | #include <functional> 3 | #include <vector> 4 | 5 | class FtLlama { 6 | enum class DataType 7 | { 8 | FP32, 9 | FP16, 10 | BF16 11 | } dtype_; 12 | void* impl_; 13 | 14 | public: 15 | FtLlama( 16 | size_t num_heads, size_t head_dim, size_t inter_size, size_t num_layers, const char* data_type, int device_id); 17 | ~FtLlama(); 18 | void 19 | forward(const std::vector<std::vector<int>>& input_ids, size_t request_output_len, std::function<void()> callback); 20 | }; 21 | -------------------------------------------------------------------------------- /examples/finetune/data/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "gsm8k": { 3 | "file_name": "gsm8k-train.jsonl", 4 | "columns": { 5 | "prompt": "prompt", 6 | "response": "response" 7 | } 8 | }, 9 | "viggo": { 10 | "file_name": "viggo-train.jsonl", 11 | "columns": { 12 | "prompt": "prompt", 13 | "response": "response" 14 | } 15 | }, 16 | "sqlctx": { 17 | "file_name": "sqlctx-train.jsonl", 18 | "columns": { 19 | "prompt": "prompt", 20 | "response": "response" 21 | } 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /csrc/sgmv/sgmv_cutlass.cu: -------------------------------------------------------------------------------- 1 | #include <cuda_bf16.h> 2 | #include <cuda_fp16.h> 3 | 4 | #include "sgmv_cutlass.cuh" 5 | 6 | template bool sgmv<nv_half>(nv_half *y, nv_half *x, nv_half **w, int32_t *s, 7 | void *tmp_d, int num_problems, int d_in, int d_out, 8 | int layer_idx, cudaStream_t stream); 9 | 10 | template bool sgmv<nv_bfloat16>(nv_bfloat16 *y, nv_bfloat16 *x, nv_bfloat16 **w, 11 | int32_t *s, void *tmp_d, int num_problems, 12 | int d_in, int d_out, int layer_idx, 13 | cudaStream_t stream); 14 | -------------------------------------------------------------------------------- /examples/finetune/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "gsm8k": { 3 | "hf_hub_url": "gsm8k", 4 | "subset": "main", 5 | "columns": { 6 | "prompt": "question", 7 | "response": "answer" 8 | } 9 | }, 10 | "viggo": { 11 | "hf_hub_url": "GEM/viggo", 12 | "columns": { 13 | "prompt": "meaning_representation", 14 | "response": "target" 15 | } 16 | }, 17 | "sqlctx": { 18 | "hf_hub_url": "b-mc2/sql-create-context", 19 | "columns": { 20 | "prompt": "context", 21 | "query": "question", 22 | "response": "answer" 23 | } 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /.github/workflows/release-please.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 |
branches: 4 | - master 5 | 6 | permissions: 7 | contents: write 8 | pull-requests: write 9 | 10 | name: release-please 11 | 12 | jobs: 13 | release-please: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: google-github-actions/release-please-action@v4 17 | id: release 18 | outputs: 19 | release_created: ${{ steps.release.outputs.release_created }} 20 | tag_name: ${{ steps.release.outputs.tag_name }} 21 | 22 | wheel: 23 | needs: release-please 24 | if: ${{ needs.release-please.outputs.release_created }} 25 | uses: ./.github/workflows/release_wheel.yml 26 | with: 27 | tag_name: ${{ needs.release-please.outputs.tag_name }} 28 | secrets: inherit 29 | -------------------------------------------------------------------------------- /src/punica/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ops as ops 2 | from ._build_meta import __version__ as __version__ 3 | from .models.llama import ( 4 | LlamaForCausalLM as LlamaForCausalLM, 5 | LlamaModel as LlamaModel, 6 | ) 7 | from .models.llama_lora import ( 8 | BatchedLlamaLoraWeight as BatchedLlamaLoraWeight, 9 | LlamaForCausalLMWithLora as LlamaForCausalLMWithLora, 10 | LlamaLoraWeight as LlamaLoraWeight, 11 | LlamaModelWithLora as LlamaModelWithLora, 12 | ) 13 | from .utils.cat_tensor import ( 14 | BatchLenInfo as BatchLenInfo, 15 | ) 16 | from .utils.kvcache import ( 17 | BatchedKvCache as BatchedKvCache, 18 | KvCache as KvCache, 19 | KvPool as KvPool, 20 | ) 21 | from .utils.lora import ( 22 | BatchedLoraWeight as BatchedLoraWeight, 23 | LoraWeight as LoraWeight, 24 | ) 25 | -------------------------------------------------------------------------------- /benchmarks/fastertransformer/ft_pybind11.cc: -------------------------------------------------------------------------------- 1 | #include "ft_llama.h" 2 | #include <pybind11/functional.h> 3 | #include <pybind11/pybind11.h> 4 | #include <pybind11/stl.h> 5 | 6 | namespace py = pybind11; 7 | 8 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 9 | { 10 | py::class_<FtLlama>(m, "FtLlama") 11 | .def(py::init<size_t, size_t, size_t, size_t, const char*, int>(), 12 | py::arg("num_heads"), 13 | py::arg("head_dim"), 14 | py::arg("inter_size"), 15 | py::arg("num_layers"), 16 | py::arg("dtype"), 17 | py::arg("device_id") = 0) 18 | .def("forward", 19 | &FtLlama::forward, 20 | py::arg("input_ids"), 21 | py::arg("request_output_len"), 22 | py::arg("callback") = nullptr); 23 | } 24 | -------------------------------------------------------------------------------- /benchmarks/runft.py: -------------------------------------------------------------------------------- 1 | from .fastertransformer import build_ext as _build_ft 2 | 3 | ctr = 0 4 | 5 | 6 | def cb(): 7 | global ctr 8 | ctr += 1 9 | print("hi", ctr) 10 | 11 | 12 | def main(): 13 | ft = _build_ft() 14 | model = ft.FtLlama( 15 | num_heads=32, 16 | head_dim=128, 17 | inter_size=11008, 18 | num_layers=32, 19 | dtype="float16", 20 | ) 21 | input_ids = [ 22 | [0, 37, 92, 26, 66, 36, 55, 70, 73, 15, 36, 51, 34, 52, 29], 23 | [0, 37, 92, 70, 73, 15, 66, 93, 34, 52, 99], 24 | [0, 92, 16, 66, 16, 45, 70, 93, 11, 36, 53, 30, 52, 29], 25 | [0, 37, 92, 26, 66, 36, 55, 70, 23, 23], 26 | [0, 37, 92, 26, 66, 36, 55, 70, 29, 15, 34, 52, 23], 27 | [0, 73], 28 | [0, 37, 92, 15, 66, 93, 34, 52, 23], 29 | [0, 70, 73, 15, 66, 93, 30, 92, 29], 30 | ] 31 | model.forward(input_ids, 20, cb) 32 | 33 | 34 | if __name__ == "__main__": 35 | main() 36 | -------------------------------------------------------------------------------- /benchmarks/fastertransformer/CMakeLists.txt:
-------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.12) 2 | project(ft LANGUAGES CXX) 3 | include(FetchContent) 4 | 5 | # FasterTransformer 6 | FetchContent_Declare( 7 | fastertransformer 8 | GIT_REPOSITORY https://github.com/void-main/FasterTransformer.git 9 | GIT_TAG e770ddf2bc66217034b6e9e3b0c3256ebf1c1b40 10 | ) 11 | FetchContent_MakeAvailable(fastertransformer) 12 | cmake_policy(SET CMP0079 NEW) 13 | target_link_libraries(cuda_driver_wrapper PUBLIC -lcublas -lcudart -ldl) 14 | find_package(CUDA REQUIRED) 15 | 16 | # ft 17 | add_library(ft SHARED ft_llama.cc) 18 | target_include_directories(ft PUBLIC 19 | ${CUDA_TOOLKIT_ROOT_DIR}/include 20 | ${FasterTransformer_SOURCE_DIR} 21 | ${FasterTransformer_SOURCE_DIR}/3rdparty/cutlass/include 22 | ) 23 | target_link_directories(ft PUBLIC 24 | ${CUDA_TOOLKIT_ROOT_DIR}/lib64 25 | ) 26 | target_link_libraries(ft PUBLIC 27 | -lcublas -lcublasLt -lcudart 28 | Llama 29 | ) 30 | -------------------------------------------------------------------------------- /benchmarks/nvbench/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(bench_srcs 2 | sgmv.cu 3 | sgmv_flashinfer.cu 4 | ) 5 | 6 | foreach(bench_src IN LISTS bench_srcs) 7 | get_filename_component(bench_name "${bench_src}" NAME_WLE) 8 | string(PREPEND bench_name "nvbench_") 9 | add_executable(${bench_name} "${bench_src}") 10 | target_include_directories(${bench_name} PUBLIC ${PROJECT_SOURCE_DIR}/csrc) 11 | target_include_directories(${bench_name} PUBLIC ${PROJECT_SOURCE_DIR}/third_party/flashinfer/include) 12 | target_include_directories(${bench_name} PUBLIC ${PROJECT_SOURCE_DIR}/third_party/nvbench) 13 | target_include_directories(${bench_name} PUBLIC ${PROJECT_SOURCE_DIR}/third_party/cutlass/include) 14 | target_include_directories(${bench_name} PUBLIC ${PROJECT_SOURCE_DIR}/third_party/cutlass/tools/util/include) 15 | target_link_libraries(${bench_name} PUBLIC nvbench::main Thrust) 16 | set_target_properties(${bench_name} PROPERTIES CXX_STANDARD 17) 17 | set_target_properties(${bench_name} PROPERTIES CUDA_STANDARD 17) 18 | endforeach() 19 | -------------------------------------------------------------------------------- /examples/finetune/finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd "$(dirname "$0")" 4 | 5 | if [ $# -ne 1 ] 6 | then 7 | echo "usage: $0 <dataset>" 8 | echo "Example datasets: gsm8k, sqlctx, viggo" 9 | exit 1 10 | fi 11 | DATASET=$1 12 | 13 | LORA_RANK=16 14 | OUTPUT_DIR="../../model/$DATASET-r$LORA_RANK" 15 | 16 | python3 run-llmtuner.py \ 17 | --stage sft \ 18 | --model_name_or_path meta-llama/Llama-2-7b-hf \ 19 | --flash_attn \ 20 | --do_train \ 21 | --template empty \ 22 | --dataset $DATASET \ 23 | --dataset_dir data/ \ 24 | --finetuning_type lora \ 25 | --lora_rank $LORA_RANK \ 26 | --lora_alpha $LORA_RANK \ 27 | --lora_target q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj \ 28 | --output_dir "$OUTPUT_DIR" \ 29 | --overwrite_cache \ 30 | --per_device_train_batch_size 32 \ 31 | --gradient_accumulation_steps 1 \ 32 | --lr_scheduler_type cosine \ 33 | --logging_steps 10 \ 34 | --save_steps 200 \ 35 | --learning_rate 5e-5 \ 36 | --num_train_epochs 4 \ 37 | --plot_loss \ 38 | --fp16 39 | -------------------------------------------------------------------------------- /tests/test_rms_norm.py: -------------------------------------------------------------------------------- 1 | import pytest
2 | import torch 3 | 4 | import punica.ops 5 | 6 | 7 | def assert_close(a, b): 8 | rtol, atol = { 9 | torch.float16: (5e-3, 5e-3), 10 | torch.bfloat16: (3e-2, 2e-2), 11 | }[a.dtype] 12 | torch.testing.assert_close(a, b, rtol=rtol, atol=atol) 13 | 14 | 15 | def _rms_norm_ref_impl( 16 | x: torch.Tensor, 17 | w: torch.Tensor, 18 | eps: float = 1e-6, 19 | ): 20 | dtype = x.dtype 21 | x = x.to(torch.float32) 22 | variance = x.pow(2).mean(-1, keepdim=True) 23 | x = x * torch.rsqrt(variance + eps) 24 | return (w * x).to(dtype) 25 | 26 | 27 | @pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) 28 | @torch.inference_mode() 29 | def test_rms_norm_correctness(dtype_str): 30 | torch.manual_seed(0xABCDABCD987) 31 | h = 4096 32 | bs = 17 33 | dtype = getattr(torch, dtype_str) 34 | device = torch.device("cuda:0") 35 | 36 | w = torch.randn(h, dtype=dtype, device=device) 37 | x = torch.randn(bs, h, dtype=dtype, device=device) 38 | 39 | y_ref = _rms_norm_ref_impl(x, w) 40 | y_our = punica.ops.rms_norm(x, w) 41 | torch.testing.assert_close(y_ref, y_our) 42 | -------------------------------------------------------------------------------- /csrc/bgmv/bgmv_config.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | template <int feat_in, int feat_out, typename T> 4 | void bgmv_kernel(T* __restrict__ Y, const T* __restrict__ X, 5 | const T* __restrict__ W, const int64_t* __restrict__ indicies, 6 | int64_t batch_size, int64_t num_layers, int64_t layer_idx, 7 | float scale); 8 | 9 | // clang-format off 10 | 11 | #define FOR_BGMV_WIDE(f, T, narrow) \ 12 | f(T, narrow, 768) \ 13 | f(T, narrow, 1024) \ 14 | f(T, narrow, 2048) \ 15 | f(T, narrow, 2560) \ 16 | f(T, narrow, 3072) \ 17 | f(T, narrow, 4096) \ 18 | f(T, narrow, 5120) \ 19 | f(T, narrow, 7168) \ 20 | f(T, narrow, 8192) \ 21 | f(T, narrow, 9216) \ 22 | f(T, narrow, 10240) \ 23 | f(T, narrow, 11008) \ 24 | f(T, narrow, 12288) \ 25 | f(T, narrow, 13824) \ 26 | f(T, narrow, 16384) \ 27 | f(T, narrow, 20480) \ 28 | f(T, narrow, 28672) \ 29 | f(T, narrow, 36864) \ 30 | f(T, narrow, 49152) \ 31 | 32 | #define FOR_BGMV_WIDE_NARROW(f, T) \ 33 | FOR_BGMV_WIDE(f, T, 8) \ 34 | FOR_BGMV_WIDE(f, T, 16) \ 35 | FOR_BGMV_WIDE(f, T, 32) \ 36 | FOR_BGMV_WIDE(f, T, 64) 37 | 38 | // clang-format on 39 | -------------------------------------------------------------------------------- /ci/ci-punica.service: -------------------------------------------------------------------------------- 1 | # https://github.com/myoung34/docker-github-actions-runner/wiki/Usage 2 | # Install with: 3 | # install -m 644 ci-punica.service $HOME/.config/systemd/user/ 4 | # systemctl --user daemon-reload 5 | # Run with: 6 | # systemctl --user start ci-punica 7 | # Stop with: 8 | # systemctl --user stop ci-punica 9 | # See live logs with: 10 | # journalctl -f -u ci-punica.service --no-hostname --no-tail 11 | [Unit] 12 | Description=Ephemeral GitHub Actions Runner Container for punica-ai/punica 13 | [Service] 14 | TimeoutStartSec=0 15 | Restart=always 16 | ExecStartPre=-/usr/bin/docker stop %N 17 | ExecStartPre=-/usr/bin/docker rm %N 18 | ExecStartPre=-/usr/bin/docker pull myoung34/github-runner:latest 19 | ExecStart=/usr/bin/docker run --rm \ 20 | --env-file %h/.config/ci-punica.env \ 21 | -e RUNNER_NAME=%H \ 22 | -e CI_UID=%U \ 23 | -e CI_GID=%G \ 24 | -v /var/run/docker.sock:/var/run/docker.sock \ 25 | -v /tmp/ci-punica:/tmp/ci-punica \ 26 | --name %N \ 27 | myoung34/github-runner:latest 28 | --------------------------------------------------------------------------------
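A note on how the instantiation table in csrc/bgmv/bgmv_config.h above is consumed: csrc/bgmv/bgmv_all.cu (shown earlier) expands FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half) and FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16), so bgmv_kernel gets compiled once per (narrow, wide) pair per dtype. INST_BGMV_TWOSIDE itself is defined in csrc/bgmv/bgmv_impl.cuh, which is not reproduced in this listing, so the following is only a rough sketch of what one expanded table entry could look like, assuming the "two-side" macro emits explicit instantiations for both the shrink (wide-to-narrow) and expand (narrow-to-wide) directions:
// Hypothetical expansion of one table entry, f(nv_half, 16, 4096);
// the real macro lives in bgmv_impl.cuh and may differ in detail.
template void bgmv_kernel<4096, 16, nv_half>(  // shrink: 4096 -> 16
    nv_half* __restrict__ Y, const nv_half* __restrict__ X,
    const nv_half* __restrict__ W, const int64_t* __restrict__ indicies,
    int64_t batch_size, int64_t num_layers, int64_t layer_idx, float scale);
template void bgmv_kernel<16, 4096, nv_half>(  // expand: 16 -> 4096
    nv_half* __restrict__ Y, const nv_half* __restrict__ X,
    const nv_half* __restrict__ W, const int64_t* __restrict__ indicies,
    int64_t batch_size, int64_t num_layers, int64_t layer_idx, float scale);
Only the feature dimensions listed in FOR_BGMV_WIDE (combined with the narrow ranks 8/16/32/64) are precompiled, which is presumably why punica.ops.add_lora_bgmv, exercised in tests/test_bgmv.py further down, only supports those shapes.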
/src/punica/utils/cat_tensor.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | class BatchLenInfo: 8 | def __init__( 9 | self, 10 | prefills: Sequence[int], 11 | decode: int, 12 | indptr_device: torch.device, 13 | indptr_dtype: torch.dtype = torch.int32, 14 | ): 15 | tmp = [0] 16 | tmp.extend(prefills) 17 | self._prefills = tmp[1:] 18 | self._decode = decode 19 | if len(prefills) > 0: 20 | cumsum = np.cumsum(tmp) 21 | self._indptr = torch.tensor( 22 | cumsum, dtype=indptr_dtype, device=indptr_device 23 | ) 24 | self._doff = cumsum[-1] 25 | else: 26 | self._indptr = None 27 | self._doff = 0 28 | 29 | @property 30 | def prefills(self) -> list[int]: 31 | """Length of each prefill request.""" 32 | return self._prefills 33 | 34 | @property 35 | def decode(self) -> int: 36 | """Number of decode requests.""" 37 | return self._decode 38 | 39 | @property 40 | def doff(self) -> int: 41 | """Index of the first decode request. Equivalently, total length of prefills.""" 42 | return self._doff 43 | 44 | @property 45 | def indptr(self) -> torch.Tensor | None: 46 | """`indptr[i] := sum(prefills[:i])`. None if no prefill.""" 47 | return self._indptr 48 | -------------------------------------------------------------------------------- /benchmarks/fastertransformer/__init__.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import subprocess 3 | 4 | import torch.utils.cpp_extension 5 | 6 | __all__ = ["build_ext"] 7 | 8 | 9 | def _build_ext(do_cmake: bool): 10 | project_root = pathlib.Path(__file__).parents[2].resolve() 11 | build_path = project_root / "build/fastertransformer" 12 | build_path.mkdir(parents=True, exist_ok=True) 13 | ext_name = "_punica_ft" 14 | 15 | if do_cmake or not (build_path / "CMakeCache.txt").exists(): 16 | subprocess.check_call( 17 | [ 18 | "cmake", 19 | "-DCMAKE_BUILD_TYPE=Release", 20 | f"{project_root}/benchmarks/fastertransformer", 21 | ], 22 | cwd=build_path, 23 | ) 24 | 25 | if do_cmake or not (build_path / "libft.so").exists(): 26 | nprocs = subprocess.check_output(["nproc"], text=True).strip() 27 | subprocess.check_call( 28 | ["cmake", "--build", ".", "--target", "ft", "--parallel", nprocs], 29 | cwd=build_path, 30 | ) 31 | 32 | module = torch.utils.cpp_extension.load( 33 | name=ext_name, 34 | sources=[pathlib.Path(__file__).parent / "ft_pybind11.cc"], 35 | extra_ldflags=[ 36 | f"{build_path}/libft.so", 37 | f"-Wl,-rpath={build_path}", 38 | ], 39 | build_directory=str(build_path), 40 | ) 41 | return module 42 | 43 | 44 | def build_ext(): 45 | return _build_ext(do_cmake=False) 46 | -------------------------------------------------------------------------------- /benchmarks/bench_textgen_lora_all.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import pathlib 3 | import subprocess 4 | from datetime import datetime 5 | 6 | import pytz 7 | 8 | 9 | def main(): 10 | this_file = pathlib.Path(__file__) 11 | project_root = this_file.parents[1] 12 | now = datetime.now(pytz.timezone("US/Pacific")) 13 | out_filename = f"{now:%Y%m%d-%H%M%S}-{this_file.stem}.jsonl" 14 | out_path = project_root / "data" / out_filename 15 | out_path.parent.mkdir(parents=True, exist_ok=True) 16 | 17 | batch_size = 32 18 | num_batches = 32 19 | maxlen = 2048 20 | model_ = ["7b", "13b"] 21 | pop_ = ["bgmv", "bmm", "uniform", "zipf:1.5"] 22 | system_ = ["punica", 
"hf", "ds", "ft_backbone", "vllm_backbone"] 23 | all_ = list(itertools.product(model_, pop_, system_)) 24 | 25 | for model, pop, system in all_: 26 | if system != "punica" and pop == "bgmv": 27 | continue 28 | args = { 29 | "--system": system, 30 | "--model": model, 31 | "--lora-popularity": pop, 32 | "--batch-size": str(batch_size), 33 | "--num-batches": str(num_batches), 34 | "--maxlen": str(maxlen), 35 | "--save-to": str(out_path), 36 | } 37 | cmd = ["python", "-m", "benchmarks.bench_textgen_lora"] 38 | cmd.extend([f"{k}={v}" for k, v in args.items()]) 39 | print(" ".join(cmd)) 40 | subprocess.run(cmd, cwd=project_root) 41 | 42 | 43 | if __name__ == "__main__": 44 | main() 45 | -------------------------------------------------------------------------------- /ci/run-ci-build-wheel.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | assert_env() { 5 | local var_name="$1" 6 | if [ -z "${!var_name}" ]; then 7 | echo "Error: Environment variable '$var_name' is not set." 8 | exit 1 9 | fi 10 | } 11 | 12 | assert_env PUNICA_CI_PYTHON_VERSION 13 | assert_env PUNICA_CI_TORCH_VERSION 14 | assert_env PUNICA_CI_CUDA_VERSION 15 | assert_env PUNICA_BUILD_VERSION 16 | assert_env TORCH_CUDA_ARCH_LIST 17 | PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)" 18 | export CONDA_pkgs_dirs=/ci-cache/conda-pkgs 19 | export XDG_CACHE_HOME=/ci-cache/xdg-cache 20 | mkdir -p "$CONDA_pkgs_dirs" "$XDG_CACHE_HOME" 21 | export HOME=/tmp/home 22 | mkdir -p $HOME 23 | export PATH="$HOME/.local/bin:$PATH" 24 | CUDA_MAJOR="${PUNICA_CI_CUDA_VERSION%.*}" 25 | CUDA_MINOR="${PUNICA_CI_CUDA_VERSION#*.}" 26 | PYVER="${PUNICA_CI_PYTHON_VERSION//./}" 27 | export PATH="/opt/python/cp${PYVER}-cp${PYVER}/bin:$PATH" 28 | 29 | 30 | echo "::group::Install PyTorch" 31 | pip install torch==$PUNICA_CI_TORCH_VERSION --index-url "https://download.pytorch.org/whl/cu${CUDA_MAJOR}${CUDA_MINOR}" 32 | echo "::endgroup::" 33 | 34 | echo "::group::Install build system" 35 | pip install ninja numpy 36 | pip install --upgrade setuptools wheel build 37 | echo "::endgroup::" 38 | 39 | 40 | echo "::group::Build wheel for Punica" 41 | cd "$PROJECT_ROOT" 42 | PUNICA_BUILD_VERSION="${PUNICA_BUILD_VERSION}+cu${CUDA_MAJOR}${CUDA_MINOR}" python -m build --no-isolation 43 | rm -f dist/*.tar.gz 44 | python -m build --no-isolation --sdist 45 | echo "::endgroup::" 46 | -------------------------------------------------------------------------------- /ci/run-ci-gpu-tests.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | assert_env() { 5 | local var_name="$1" 6 | if [ -z "${!var_name}" ]; then 7 | echo "Error: Environment variable '$var_name' is not set." 8 | exit 1 9 | fi 10 | } 11 | 12 | assert_env PUNICA_CI_TORCH_VERSION 13 | assert_env PUNICA_CI_CUDA_MAJOR 14 | assert_env PUNICA_CI_CUDA_MINOR 15 | PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)" 16 | export CONDA_pkgs_dirs=/ci-cache/conda-pkgs 17 | export XDG_CACHE_HOME=/ci-cache/xdg-cache 18 | mkdir -p "$CONDA_pkgs_dirs" "$XDG_CACHE_HOME" 19 | export HOME=/tmp/home 20 | mkdir -p $HOME 21 | nvidia-smi 22 | 23 | 24 | echo "::group::Install Mamba and Python" 25 | if [ ! 
-f "/ci-cache/Miniforge3.sh" ]; then 26 | wget -O "/ci-cache/Miniforge3.sh" "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh" 27 | fi 28 | bash "/ci-cache/Miniforge3.sh" -b -p "$HOME/conda" 29 | source "$HOME/conda/etc/profile.d/conda.sh" 30 | source "$HOME/conda/etc/profile.d/mamba.sh" 31 | mamba create -y -n punica-ci python=3.10.13 git 32 | mamba activate punica-ci 33 | echo "::endgroup::" 34 | 35 | 36 | echo "::group::Install PyTorch" 37 | pip install torch==$PUNICA_CI_TORCH_VERSION --index-url "https://download.pytorch.org/whl/cu${PUNICA_CI_CUDA_MAJOR}${PUNICA_CI_CUDA_MINOR}" 38 | echo "::endgroup::" 39 | 40 | 41 | echo "::group::Install Punica" 42 | cd "$PROJECT_ROOT" 43 | pip install ninja numpy 44 | pip install -v --no-build-isolation -e .[dev] 45 | echo "::endgroup::" 46 | 47 | 48 | echo "::group::Punica pytest" 49 | pytest --cov=src --cov-report=xml -v 50 | echo "::endgroup::" 51 | -------------------------------------------------------------------------------- /licenses/LICENSE.cutlass.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | SPDX-License-Identifier: BSD-3-Clause 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "torch", "numpy", "ninja", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "punica" 7 | description = "Punica: System for serving Large Language Models." 
8 | readme = "README.md" 9 | requires-python = ">=3.10" 10 | keywords = ["pytorch", "llm", "lora", "transformer"] 11 | dynamic = ["version"] 12 | dependencies = [ 13 | "torch", 14 | "transformers", 15 | "numpy", 16 | ] 17 | 18 | [project.optional-dependencies] 19 | dev = [ 20 | "pytest", 21 | "pytest-cov", 22 | "ruff", 23 | "pytz", 24 | "tqdm", 25 | "scipy", 26 | "peft", 27 | "accelerate", 28 | "textual", 29 | ] 30 | 31 | # Include different sets of files for sdist and wheel 32 | # see: https://stackoverflow.com/a/54953494/1332817 33 | [tool.setuptools.package-data] # wheel-only files 34 | punica = ["src/punica/_build_meta.py"] 35 | [tool.setuptools.exclude-package-data] # exclude from wheel 36 | punica = ["csrc", "third_party"] 37 | 38 | [tool.ruff] 39 | exclude = ["third_party", "src/punica/_build_meta.py"] 40 | 41 | [tool.ruff.lint.isort] 42 | known-first-party = ["punica"] 43 | combine-as-imports = true 44 | 45 | [tool.ruff.lint] 46 | select = [ 47 | "E", # pycodestyle 48 | "F", # Pyflakes 49 | "UP", # pyupgrade 50 | "SIM", # flake8-simplify 51 | "I", # isort 52 | ] 53 | ignore = [ 54 | "E501", # Line too long 55 | "E741", # Ambiguous variable name 56 | ] 57 | 58 | [tool.ruff.per-file-ignores] 59 | 60 | [tool.pytest.ini_options] 61 | testpaths = ["tests"] 62 | 63 | [tool.pyright] 64 | include = ["examples", "src", "tests"] 65 | exclude = ["examples/finetune"] 66 | -------------------------------------------------------------------------------- /.github/workflows/gpu-test.yml: -------------------------------------------------------------------------------- 1 | name: Tests on GPU 2 | on: 3 | push: 4 | paths: 5 | - "*.py" 6 | - "csrc/**" 7 | - "src/**" 8 | - "tests/**" 9 | branches-ignore: 10 | - ci-wheel 11 | pull_request: 12 | paths: 13 | - "*.py" 14 | - "csrc/**" 15 | - "src/**" 16 | - "tests/**" 17 | jobs: 18 | gpu-test: 19 | strategy: 20 | fail-fast: false 21 | matrix: 22 | include: 23 | - { arch: sm80, torch: "2.1.0", cuda_major: 12, cuda_minor: 1 } 24 | - { arch: sm80, torch: "2.1.0", cuda_major: 11, cuda_minor: 8 } 25 | - { arch: sm86, torch: "2.1.0", cuda_major: 12, cuda_minor: 1 } 26 | - { arch: sm86, torch: "2.1.0", cuda_major: 11, cuda_minor: 8 } 27 | runs-on: [self-hosted, "${{ matrix.arch }}"] 28 | steps: 29 | - name: Checkout 30 | uses: actions/checkout@v4 31 | with: 32 | submodules: true 33 | 34 | - name: Run tests and collect coverage 35 | run: | 36 | chown -R $CI_UID:$CI_GID "$GITHUB_WORKSPACE" 37 | docker run --runtime=nvidia --gpus all --rm -t \ 38 | -v "$CI_RUNNER_CACHE_DIR":/ci-cache \ 39 | -v "$GITHUB_WORKSPACE":/app \ 40 | -e PUNICA_CI_TORCH_VERSION=${{ matrix.torch }} \ 41 | -e PUNICA_CI_CUDA_MAJOR=${{ matrix.cuda_major }} \ 42 | -e PUNICA_CI_CUDA_MINOR=${{ matrix.cuda_minor }} \ 43 | --user $CI_UID:$CI_GID \ 44 | nvidia/cuda:${{ matrix.cuda_major }}.${{ matrix.cuda_minor }}.0-devel-ubuntu22.04 \ 45 | bash /app/ci/run-ci-gpu-tests.bash 46 | 47 | - name: Upload coverage to Codecov 48 | uses: codecov/codecov-action@v3 49 | env: 50 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 51 | -------------------------------------------------------------------------------- /csrc/flashinfer_adapter/flashinfer_config.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include <cstdint> 3 | 4 | template <typename T> 5 | bool FlashInferBatchPrefillKernel(T* o, T* q, int32_t* qo_indptr, T** kv_ptrs, 6 | int32_t* kv_indptr, int32_t* last_page_offset, 7 | void* tmpbuf, int head_dim, int num_layers, 8 | int layer_idx, int group_size, 9 | int num_kv_heads,
int page_size, 10 | int batch_size); 11 | 12 | template <typename T> 13 | bool FlashInferBatchDecodeKernel(T* o, T* q, T** kv_ptrs, int32_t* kv_indptr, 14 | int32_t* last_page_offset, void* tmpbuf, 15 | int head_dim, int num_layers, int layer_idx, 16 | int group_size, int num_kv_heads, 17 | int page_size, int batch_size); 18 | 19 | template <typename T> 20 | void FlashInferInitKvKernel(T** kv_ptrs, int32_t* kv_indptr, 21 | int32_t* last_page_offset, T* key, T* value, 22 | int32_t* seqlen_indptr, int num_layers, 23 | int layer_idx, int num_kv_heads, int page_size, 24 | int batch_size); 25 | 26 | template <typename T> 27 | void FlashInferAppendKvKernel(T** kv_ptrs, int32_t* kv_indptr, 28 | int32_t* last_page_offset, T* key, T* value, 29 | int num_layers, int layer_idx, int num_kv_heads, 30 | int page_size, int batch_size); 31 | 32 | // clang-format off 33 | 34 | #define FOR_FlashInferBatchDecode_D(f, ...) \ 35 | f(64, __VA_ARGS__) \ 36 | f(128, __VA_ARGS__) 37 | 38 | // clang-format on 39 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## [1.1.0](https://github.com/punica-ai/punica/compare/v1.0.3...v1.1.0) (2023-12-30) 4 | 5 | 6 | ### Features 7 | 8 | * **cuda-graph:** use pytorch current stream for sgmv ([b5f5e1a](https://github.com/punica-ai/punica/commit/b5f5e1a3aa46702c843e8815acf8f654a242e109)) 9 | 10 | 11 | ### Bug Fixes 12 | 13 | * **cuda-graph:** specify stream for precompute_sgmv_args in sgmv_cutlass ([07a40b9](https://github.com/punica-ai/punica/commit/07a40b9d30e98d88963e8a7e140120a25ac0d518)) 14 | 15 | ## [1.0.3](https://github.com/punica-ai/punica/compare/v1.0.2...v1.0.3) (2023-12-28) 16 | 17 | 18 | ### Bug Fixes 19 | 20 | * **release-please:** provide tag_name ([4cd735b](https://github.com/punica-ai/punica/commit/4cd735b747dbc113e26d4b9ede5f2bfaa06db5f0)) 21 | 22 | ## [1.0.2](https://github.com/punica-ai/punica/compare/v1.0.1...v1.0.2) (2023-12-28) 23 | 24 | 25 | ### Bug Fixes 26 | 27 | * **release-please:** trigger release_wheel when release is created ([3b980a1](https://github.com/punica-ai/punica/commit/3b980a1baa4ca24f75da39b9b7d0cfcd8b859470)) 28 | 29 | ## [1.0.1](https://github.com/punica-ai/punica/compare/v1.0.0...v1.0.1) (2023-12-28) 30 | 31 | 32 | ### Bug Fixes 33 | 34 | * **release-please:** allow triggering release_wheel gh action ([54f3d38](https://github.com/punica-ai/punica/commit/54f3d38e57d71503683eb275b989da63ba02e1bd)) 35 | 36 | ## [1.0.0](https://github.com/punica-ai/punica/compare/v0.3.1...v1.0.0) (2023-12-28) 37 | 38 | 39 | ### ⚠ BREAKING CHANGES 40 | 41 | * automate package release 42 | 43 | ### Features 44 | 45 | * automate package release ([31e2663](https://github.com/punica-ai/punica/commit/31e2663220537f4b10d65ada1a29936f04fbf953)) 46 | 47 | 48 | ### Bug Fixes 49 | 50 | * **pkg:** Include different sets of files for sdist and wheel ([0ef4f7b](https://github.com/punica-ai/punica/commit/0ef4f7b7a148048908f6895aef885ee188789cde)) 51 | -------------------------------------------------------------------------------- /tests/test_bgmv.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | import punica.ops 5 | 6 | 7 | def assert_close(a, b): 8 | rtol, atol = { 9 | torch.float16: (5e-3, 5e-3), 10 | torch.bfloat16: (3e-2, 2e-2), 11 | }[a.dtype] 12 | torch.testing.assert_close(a, b, rtol=rtol, atol=atol) 13 | 14 | 15 | def _lora_ref_impl( 16 | y: torch.Tensor, 17 | x:
torch.Tensor, 18 | wa_T_all: torch.Tensor, 19 | wb_T_all: torch.Tensor, 20 | indicies: torch.LongTensor, 21 | layer_idx: int, 22 | scale: float, 23 | ): 24 | bs = x.shape[0] 25 | s = torch.tensor(scale, dtype=torch.float32, device=x.device) 26 | for i, lora_idx in zip(range(bs), indicies.cpu().tolist()): 27 | xi = x[i].unsqueeze(0).to(torch.float32) 28 | wa = wa_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32) 29 | wb = wb_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32) 30 | 31 | tmp = (xi @ wa).to(x.dtype).to(torch.float32) 32 | y[i] += (tmp @ wb).squeeze(0) * s 33 | 34 | 35 | @pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) 36 | @torch.inference_mode() 37 | def test_lora_correctness(dtype_str): 38 | torch.manual_seed(0xABCDABCD987) 39 | num_loras = 4 40 | num_layers = 5 41 | h1 = 4096 42 | h2 = 11008 43 | r = 8 44 | bs = 32 45 | scale = 0.123 46 | dtype = getattr(torch, dtype_str) 47 | device = torch.device("cuda:0") 48 | 49 | wa_T_all = torch.randn(num_loras, num_layers, r, h1, dtype=dtype, device=device) 50 | wb_T_all = torch.randn(num_loras, num_layers, h2, r, dtype=dtype, device=device) 51 | indices = torch.randint(num_loras, (bs,), dtype=torch.long, device=device) 52 | 53 | for layer_idx in range(num_layers): 54 | x = torch.randn(bs, h1, dtype=dtype, device=device) 55 | y = torch.randn(bs, h2, dtype=dtype, device=device) 56 | 57 | y_ref = y.clone() 58 | _lora_ref_impl(y_ref, x, wa_T_all, wb_T_all, indices, layer_idx, scale) 59 | y_our = y.clone() 60 | punica.ops.add_lora_bgmv( 61 | y_our, x, wa_T_all, wb_T_all, indices, layer_idx, scale 62 | ) 63 | 64 | assert_close(y_ref, y_our) 65 | -------------------------------------------------------------------------------- /benchmarks/bench_backbone_vs_lora.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import json 3 | import pathlib 4 | from datetime import datetime 5 | 6 | import numpy as np 7 | import pytz 8 | import torch 9 | from tqdm import tqdm 10 | 11 | from .benchmark_utils import bench 12 | 13 | 14 | @torch.inference_mode() 15 | def bench_backbone_vs_lora(f): 16 | torch.manual_seed(0xABCDABCD987) 17 | dtype = torch.float16 18 | device = torch.device("cuda:0") 19 | h1 = 4096 20 | h2 = 11008 21 | r = 16 22 | bs_list = np.arange(1, 65) 23 | 24 | res = dict( 25 | backbone_avg=[], 26 | backbone_std=[], 27 | single_lora_avg=[], 28 | single_lora_std=[], 29 | multi_lora_avg=[], 30 | multi_lora_std=[], 31 | ) 32 | for bs in tqdm(bs_list): 33 | w = torch.randn(h1, h2, dtype=dtype, device=device) 34 | wa = torch.randn(h1, r, dtype=dtype, device=device) 35 | wb = torch.randn(r, h2, dtype=dtype, device=device) 36 | x = torch.randn(bs, 1, h1, dtype=dtype, device=device) 37 | 38 | def muti_lora(): 39 | for i in range(bs): 40 | x[i] @ wa @ wb 41 | 42 | l_backbone = bench(lambda: x @ w, warmup=200, repeat=500) 43 | l_single_lora = bench(lambda: x @ wa @ wb, warmup=200, repeat=500) 44 | l_multi_lora = bench(muti_lora, warmup=200, repeat=500) 45 | 46 | res["backbone_avg"].append(l_backbone.avg()) 47 | res["backbone_std"].append(l_backbone.std()) 48 | res["single_lora_avg"].append(l_single_lora.avg()) 49 | res["single_lora_std"].append(l_single_lora.std()) 50 | res["multi_lora_avg"].append(l_multi_lora.avg()) 51 | res["multi_lora_std"].append(l_multi_lora.std()) 52 | 53 | json.dump(res, f) 54 | 55 | 56 | def main(): 57 | this_file = pathlib.Path(__file__) 58 | project_root = this_file.parents[1] 59 | now = datetime.now(pytz.timezone("US/Pacific")) 60 | 
out_filename = f"{now:%Y%m%d-%H%M%S}-{this_file.stem}.json.gz" 61 | out_path = project_root / "data" / out_filename 62 | 63 | print(out_path) 64 | with gzip.open(out_path, "wt") as f: 65 | bench_backbone_vs_lora(f) 66 | 67 | 68 | if __name__ == "__main__": 69 | main() 70 | -------------------------------------------------------------------------------- /src/punica/utils/lora.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | 3 | import torch 4 | 5 | 6 | class LoraWeight: 7 | def __init__( 8 | self, 9 | num_layers: int, 10 | in_features: int, 11 | out_features: int, 12 | lora_rank: int, 13 | dtype: torch.dtype, 14 | device: torch.device, 15 | ): 16 | # SGMV-Shrink custom CUDA kernel uses column-major. 17 | self.wa = torch.zeros( 18 | (num_layers, lora_rank, in_features), dtype=dtype, device=device 19 | ) 20 | # SGMV-Expand cutlass kernel uses row-major. 21 | self.wb = torch.zeros( 22 | (num_layers, lora_rank, out_features), dtype=dtype, device=device 23 | ) 24 | 25 | def copy_from_tensor(self, a: torch.Tensor, b: torch.Tensor): 26 | """ 27 | Copy from column-major weight tensors. 28 | 29 | Args: 30 | a: Shape: `[num_layers, lora_rank, in_features]`. 31 | b: Shape: `[num_layers, out_features, lora_rank]`. 32 | """ 33 | self.wa.copy_(a.to(self.wa.device).to(self.wa.dtype)) 34 | self.wb.copy_(b.to(self.wb.device).to(self.wb.dtype).transpose(1, 2)) 35 | 36 | @property 37 | def device(self) -> torch.device: 38 | return self.wa.device 39 | 40 | @property 41 | def dtype(self) -> torch.dtype: 42 | return self.wa.dtype 43 | 44 | @property 45 | def num_layers(self) -> int: 46 | return self.wa.size(0) 47 | 48 | @property 49 | def in_features(self) -> int: 50 | return self.wa.size(2) 51 | 52 | @property 53 | def out_features(self) -> int: 54 | return self.wb.size(2) 55 | 56 | @property 57 | def lora_rank(self) -> int: 58 | return self.wa.size(1) 59 | 60 | 61 | class BatchedLoraWeight: 62 | def __init__(self, weights: Sequence[LoraWeight]): 63 | assert len(weights) > 0 64 | device = weights[0].device 65 | self.wa_ptr = torch.tensor( 66 | [w.wa.data_ptr() for w in weights], dtype=torch.int64, device=device 67 | ) 68 | self.wb_ptr = torch.tensor( 69 | [w.wb.data_ptr() for w in weights], dtype=torch.int64, device=device 70 | ) 71 | -------------------------------------------------------------------------------- /benchmarks/fastertransformer/.clang-format: -------------------------------------------------------------------------------- 1 | # https://github.com/NVIDIA/FasterTransformer/blob/main/.clang-format 2 | 3 | Language: Cpp 4 | AccessModifierOffset: -4 5 | AlignAfterOpenBracket: Align 6 | AllowShortEnumsOnASingleLine: false 7 | AlignConsecutiveAssignments: true 8 | AlignConsecutiveDeclarations: true 9 | AlignEscapedNewlines: Right 10 | AlignOperands: true 11 | AlignTrailingComments: true 12 | AllowAllParametersOfDeclarationOnNextLine: true 13 | AllowAllArgumentsOnNextLine: true 14 | AllowShortBlocksOnASingleLine: Empty 15 | AllowShortCaseLabelsOnASingleLine: false 16 | AllowShortFunctionsOnASingleLine: Empty 17 | AllowShortIfStatementsOnASingleLine: Never 18 | AllowShortLoopsOnASingleLine: false 19 | AlwaysBreakAfterReturnType: None 20 | AlwaysBreakBeforeMultilineStrings: false 21 | AlwaysBreakTemplateDeclarations: true 22 | BinPackArguments: false 23 | BinPackParameters: false 24 | BreakBeforeBinaryOperators: NonAssignment 25 | BreakBeforeBraces: Stroustrup 26 | BreakBeforeTernaryOperators: false 27 | 
BreakConstructorInitializers: AfterColon 28 | BreakInheritanceList: AfterColon 29 | BreakStringLiterals: false 30 | ColumnLimit: 120 31 | CompactNamespaces: false 32 | ConstructorInitializerAllOnOneLineOrOnePerLine: true 33 | ConstructorInitializerIndentWidth: 4 34 | ContinuationIndentWidth: 4 35 | Cpp11BracedListStyle: true 36 | DerivePointerAlignment: false 37 | FixNamespaceComments: true 38 | IndentCaseLabels: true 39 | IndentPPDirectives: None 40 | IndentWidth: 4 41 | IndentWrappedFunctionNames: false 42 | KeepEmptyLinesAtTheStartOfBlocks: true 43 | MaxEmptyLinesToKeep: 1 44 | NamespaceIndentation: None 45 | PointerAlignment: Left 46 | ReflowComments: true 47 | SortIncludes: true 48 | SortUsingDeclarations: false 49 | SpaceAfterCStyleCast: false 50 | SpaceAfterTemplateKeyword: false 51 | SpaceBeforeAssignmentOperators: true 52 | SpaceBeforeCtorInitializerColon: false 53 | SpaceBeforeInheritanceColon: false 54 | SpaceBeforeParens: ControlStatements 55 | SpaceInEmptyParentheses: false 56 | SpacesBeforeTrailingComments: 2 57 | SpacesInAngles: false 58 | SpacesInCStyleCastParentheses: false 59 | SpacesInContainerLiterals: false 60 | SpacesInParentheses: false 61 | SpacesInSquareBrackets: false 62 | Standard: Cpp11 63 | TabWidth: 4 64 | UseTab: Never 65 | -------------------------------------------------------------------------------- /src/punica/utils/convert_lora_weight.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import re 3 | 4 | import torch 5 | 6 | 7 | def convert_lora_weight(peft_weight_path): 8 | weights = torch.load( 9 | peft_weight_path, map_location=torch.device("cpu"), weights_only=True 10 | ) 11 | projs = set() 12 | num_layers = 0 13 | rank = 0 14 | tmp = {} 15 | for key, value in weights.items(): 16 | layer, proj, ab = re.findall( 17 | r"\.(\d+)\..*\.(\w+)_proj\.lora_(A|B)\.weight$", key 18 | )[0] 19 | ab = ab.upper() 20 | layer = int(layer) 21 | projs.add(proj) 22 | # PyTorch Linear layer is column-major 23 | if ab == "A": 24 | assert value.size(0) < value.size(1) 25 | r = value.size(0) 26 | elif ab == "B": 27 | assert value.size(0) > value.size(1) 28 | r = value.size(1) 29 | else: 30 | raise KeyError(f"Unknown weight key: {key}") 31 | if rank != 0: 32 | assert r == rank 33 | else: 34 | rank = r 35 | num_layers = max(num_layers, layer + 1) 36 | tmp[(layer, proj, ab)] = value 37 | 38 | out = {} 39 | for proj in projs: 40 | for ab in "AB": 41 | tensors = [] 42 | for layer in range(num_layers): 43 | tensors.append(tmp[(layer, proj, ab)]) 44 | out[f"{proj}.{ab}"] = torch.stack(tensors) 45 | 46 | return out 47 | 48 | 49 | def _convert_lora_weight_main(): 50 | parser = argparse.ArgumentParser( 51 | description="Convert LoRA weight to Punica's format." 52 | ) 53 | parser.add_argument( 54 | "input", help="Path to the LoRA weight file, as trained by PEFT." 55 | ) 56 | parser.add_argument( 57 | "output", help="Path to the output LoRA weight in Punica's format." 
58 | ) 59 | args = parser.parse_args() 60 | 61 | weights = convert_lora_weight(args.input) 62 | print("Input:", args.input) 63 | for key, value in weights.items(): 64 | print("Key:", key, " shape:", list(value.shape), " dtype:", value.dtype) 65 | torch.save(weights, args.output) 66 | print("Saved to:", args.output) 67 | 68 | 69 | if __name__ == "__main__": 70 | _convert_lora_weight_main() 71 | -------------------------------------------------------------------------------- /csrc/flashinfer_adapter/flashinfer_decl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "flashinfer/page.cuh" 3 | #include "flashinfer/rope.cuh" 4 | 5 | namespace flashinfer { 6 | template <uint32_t PAGE_SIZE, uint32_t GROUP_SIZE, uint32_t HEAD_DIM, 7 | PageStorage page_storage, RotaryMode rotary_mode, 8 | bool ALLOW_FP16_QK_REDUCTION, bool CAUSAL, typename DTypeIn, 9 | typename DTypeOut, typename IdType> 10 | cudaError_t BatchPrefillWithPagedKVCacheDispatched( 11 | DTypeIn* q, paged_kv_t<page_storage, DTypeIn, IdType> paged_kv, 12 | IdType* qo_indptr, DTypeOut* o, float* tmp, uint32_t num_qo_heads, 13 | float rope_scale, float rope_theta, cudaStream_t stream); 14 | } 15 | 16 | #define INST_BatchPrefill(T, PAGE_SIZE, GROUP_SIZE, HEAD_DIM) \ 17 | namespace flashinfer { \ 18 | template cudaError_t BatchPrefillWithPagedKVCacheDispatched< \ 19 | PAGE_SIZE, GROUP_SIZE, HEAD_DIM, PageStorage::kPointer, \ 20 | RotaryMode::kLlama, /* ALLOW_FP16_QK_REDUCTION= */ false, \ 21 | /* CAUSAL= */ true, T, T, int32_t>( \ 22 | T * q, paged_kv_t<PageStorage::kPointer, T, int32_t> paged_kv, \ 23 | int32_t* qo_indptr, T* o, float* tmp, uint32_t num_qo_heads, \ 24 | float rope_scale, float rope_theta, cudaStream_t stream); \ 25 | } 26 | 27 | namespace flashinfer { 28 | template <uint32_t PAGE_SIZE, uint32_t GROUP_SIZE, uint32_t HEAD_DIM, 29 | PageStorage page_storage, RotaryMode rotary_mode, typename DTypeIn, 30 | typename DTypeOut, typename IdType> 31 | cudaError_t BatchDecodeWithPagedKVCacheDispatched( 32 | DTypeIn* q, paged_kv_t<page_storage, DTypeIn, IdType> paged_kv, DTypeOut* o, 33 | float* tmp, float rope_scale, float rope_theta, cudaStream_t stream); 34 | } 35 | #define INST_BatchDecode(T, PAGE_SIZE, GROUP_SIZE, HEAD_DIM) \ 36 | namespace flashinfer { \ 37 | template cudaError_t BatchDecodeWithPagedKVCacheDispatched< \ 38 | PAGE_SIZE, GROUP_SIZE, HEAD_DIM, PageStorage::kPointer, \ 39 | RotaryMode::kLlama, T, T, int32_t>( \ 40 | T * q, paged_kv_t<PageStorage::kPointer, T, int32_t> paged_kv, T* o, \ 41 | float* tmp, float rope_scale, float rope_theta, cudaStream_t stream); \ 42 | } 43 | -------------------------------------------------------------------------------- /examples/finetune/create-finetune-data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import pathlib 4 | 5 | import datasets 6 | 7 | 8 | def create_dataset( 9 | preset_name, 10 | dataset_name, 11 | dataset_config_name, 12 | prompt_template, 13 | response_template, 14 | output_dir, 15 | ): 16 | raw_datasets = datasets.load_dataset(dataset_name, dataset_config_name) 17 | for split, dataset in raw_datasets.items(): 18 | filename = f"{preset_name}-{split}.jsonl" 19 | outpath = pathlib.Path(output_dir) / filename 20 | print(outpath) 21 | with open(outpath, "w") as f: 22 | for example in dataset: 23 | prompt = prompt_template.format(**example) 24 | response = response_template.format(**example) 25 | f.write(json.dumps({"prompt": prompt, "response": response})) 26 | f.write("\n") 27 | 28 | 29 | presets = {} 30 | 31 | presets["gsm8k"] = dict( 32 | dataset_name="gsm8k", 33 | dataset_config_name="main", 34 | prompt="<<SYS>>\nAnswer the following Grade School Math problem.\n<</SYS>>\n[INST] {question} [/INST]\n", 35 | response="{answer}", 36 | ) 37 | 38 | presets["sqlctx"] = dict( 39 | dataset_name="b-mc2/sql-create-context", 40 | dataset_config_name="main", 41 | prompt="<<SYS>>\nGenerate a correct SQL query from the following database schema.\n{context}\n<</SYS>>\n[INST]
{question} [/INST]\n", 42 | response="{answer}", 43 | ) 44 | 45 | presets["viggo"] = dict( 46 | dataset_name="GEM/viggo", 47 | dataset_config_name="main", 48 | prompt="<<SYS>>\nGenerate a description based on the following representation.\n<</SYS>>\n[INST] {meaning_representation} [/INST]\n", 49 | response="{target}", 50 | ) 51 | 52 | 53 | def main(): 54 | data_dir = pathlib.Path(__file__).parent / "data" 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument("--preset", choices=presets.keys(), required=True) 57 | parser.add_argument("--output_dir", default=str(data_dir)) 58 | args = parser.parse_args() 59 | p = presets[args.preset] 60 | create_dataset( 61 | preset_name=args.preset, 62 | dataset_name=p["dataset_name"], 63 | dataset_config_name=p["dataset_config_name"], 64 | prompt_template=p["prompt"], 65 | response_template=p["response"], 66 | output_dir=args.output_dir, 67 | ) 68 | 69 | 70 | if __name__ == "__main__": 71 | main() 72 | -------------------------------------------------------------------------------- /benchmarks/bench_sgmv_cutlass.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import itertools 3 | import json 4 | import pathlib 5 | from datetime import datetime 6 | 7 | import pytz 8 | import torch 9 | from tqdm import tqdm 10 | 11 | import punica.ops 12 | 13 | from .benchmark_utils import bench, gc_torch, get_lora_lens 14 | 15 | 16 | @torch.inference_mode() 17 | def bench_sgmv(f): 18 | bs_ = list(range(1, 65)) 19 | pop_ = ["bmm", "bgmv", "uniform", "zipf:1.5"] 20 | h1_ = [8, 16, 32, 64] 21 | h2 = 4096 22 | num_layers = 1 23 | dtype = torch.float16 24 | device = torch.device("cuda:0") 25 | 26 | all_ = list(itertools.product(h1_, pop_, bs_)) 27 | for h1, pop, bs in (pbar := tqdm(all_)): 28 | problem_sizes = get_lora_lens(bs, pop) 29 | 30 | setup = dict( 31 | h1=h1, 32 | h2=h2, 33 | popularity=pop, 34 | num_problems=len(problem_sizes), 35 | batch_size=bs, 36 | ) 37 | pbar.set_postfix(setup) 38 | 39 | torch.manual_seed(0xABCDABCD987) 40 | gc_torch() 41 | 42 | w = [ 43 | torch.randn((num_layers, h1, h2), dtype=dtype, device=device) 44 | for _ in range(len(problem_sizes)) 45 | ] 46 | w_ptr = torch.tensor([t.data_ptr() for t in w], dtype=torch.int64, device=device) 47 | s = torch.cumsum( 48 | torch.tensor([0] + problem_sizes, device=device), dim=0, dtype=torch.int32 49 | ) 50 | x = torch.randn((s[-1], h1), dtype=dtype, device=device) 51 | y = torch.randn((s[-1], h2), dtype=dtype, device=device) 52 | 53 | latency = bench( 54 | lambda: punica.ops.sgmv_cutlass(y, x, w_ptr, s, layer_idx=0), 55 | warmup=200, 56 | repeat=1000, 57 | ) 58 | 59 | result = { 60 | "setup": setup, 61 | "latency": { 62 | "avg": latency.avg(), 63 | "std": latency.std(), 64 | }, 65 | } 66 | f.write(json.dumps(result) + "\n") 67 | f.flush() 68 | 69 | 70 | def main(): 71 | this_file = pathlib.Path(__file__) 72 | project_root = this_file.parents[1] 73 | now = datetime.now(pytz.timezone("US/Pacific")) 74 | out_filename = f"{now:%Y%m%d-%H%M%S}-{this_file.stem}.jsonl.gz" 75 | out_path = project_root / "data" / out_filename 76 | 77 | print(out_path) 78 | with gzip.open(out_path, "wt") as f: 79 | bench_sgmv(f) 80 | 81 | 82 | if __name__ == "__main__": 83 | main() 84 | -------------------------------------------------------------------------------- /benchmarks/bench_sgmv.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import itertools 3 | import json 4 | import pathlib 5 | from datetime import datetime 6 | 7
| import pytz 8 | import torch 9 | from tqdm import tqdm 10 | 11 | import punica.ops 12 | 13 | from .benchmark_utils import bench, gc_torch, get_lora_lens 14 | 15 | 16 | @torch.inference_mode() 17 | def bench_sgmv(f): 18 | bs_ = list(range(1, 65)) 19 | pop_ = ["bmm", "bgmv", "uniform", "zipf:1.5", "Nx8"] 20 | h1 = 4096 21 | h2 = 16 22 | num_layers = 1 23 | dtype = torch.float16 24 | device = torch.device("cuda:0") 25 | 26 | all_ = list(itertools.product(pop_, bs_)) 27 | for pop, bs in (pbar := tqdm(all_)): 28 | if pop == "Nx8": 29 | if bs % 8 != 0: 30 | continue 31 | problem_sizes = [(bs // 8)] * 8 32 | else: 33 | problem_sizes = get_lora_lens(bs, pop) 34 | 35 | setup = dict( 36 | h1=h1, 37 | h2=h2, 38 | popularity=pop, 39 | num_problems=len(problem_sizes), 40 | batch_size=bs, 41 | ) 42 | pbar.set_postfix(setup) 43 | 44 | torch.manual_seed(0xABCDABCD987) 45 | gc_torch() 46 | 47 | w = [ 48 | torch.randn((num_layers, h2, h1), dtype=dtype, device=device) 49 | for _ in range(len(problem_sizes)) 50 | ] 51 | w_ptr = torch.tensor( 52 | [t.data_ptr() for t in w], dtype=torch.int64, device=device 53 | ) 54 | s = torch.cumsum( 55 | torch.tensor([0] + problem_sizes, device=device), dim=0, dtype=torch.int32 56 | ) 57 | x = torch.randn((s[-1], h1), dtype=dtype, device=device) 58 | y = torch.randn((s[-1], h2), dtype=dtype, device=device) 59 | 60 | latency = bench( 61 | lambda: punica.ops.sgmv(y, x, w_ptr, s, layer_idx=0), 62 | warmup=200, 63 | repeat=1000, 64 | ) 65 | 66 | result = { 67 | "setup": setup, 68 | "latency": { 69 | "avg": latency.avg(), 70 | "std": latency.std(), 71 | }, 72 | } 73 | f.write(json.dumps(result) + "\n") 74 | f.flush() 75 | 76 | 77 | def main(): 78 | this_file = pathlib.Path(__file__) 79 | project_root = this_file.parents[1] 80 | now = datetime.now(pytz.timezone("US/Pacific")) 81 | out_filename = f"{now:%Y%m%d-%H%M%S}-{this_file.stem}.jsonl.gz" 82 | out_path = project_root / "data" / out_filename 83 | 84 | print(out_path) 85 | with gzip.open(out_path, "wt") as f: 86 | bench_sgmv(f) 87 | 88 | 89 | if __name__ == "__main__": 90 | main() 91 | -------------------------------------------------------------------------------- /csrc/sgmv_flashinfer/sgmv_all.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include 6 | 7 | #include "sgmv_config.h" 8 | #include "sgmv_flashinfer.cuh" 9 | 10 | template 11 | bool sgmv_shrink(T* y, T* x, T** w, int32_t* s, void* tmp, 12 | uint32_t num_problems, uint32_t d_in, uint32_t layer_idx, 13 | cudaStream_t stream) { 14 | static_assert(d_out % 16 == 0); 15 | 16 | constexpr uint32_t num_warps = 4; 17 | constexpr uint32_t num_stages = 2; 18 | constexpr uint32_t num_k_frags_per_stage = 8; 19 | constexpr uint32_t num_blocks_n = d_out / 16; 20 | uint32_t smem = num_stages * sizeof(T) * num_k_frags_per_stage * 16 * 16 * 21 | (num_warps + num_blocks_n); 22 | auto cooperative_kernel = 23 | flashinfer::sgmv::sgmv_shrink; 24 | auto kernel = flashinfer::sgmv::sgmv_shrink; 25 | 26 | int dev_id = 0; 27 | int num_blocks_per_sm = 0; 28 | int num_sm = 0; 29 | bool use_cooperative = true; 30 | cudaGetDevice(&dev_id); 31 | cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id); 32 | cudaOccupancyMaxActiveBlocksPerMultiprocessor( 33 | &num_blocks_per_sm, cooperative_kernel, num_warps * 32, smem); 34 | 35 | const uint32_t max_grid_size = num_sm * num_blocks_per_sm; 36 | 37 | uint32_t chunk_size = 256; 38 | uint32_t num_chunks = (d_in + chunk_size - 1) / chunk_size; 39 | if 
(num_chunks * num_problems > max_grid_size) { 40 | use_cooperative = false; 41 | chunk_size = d_in; 42 | num_chunks = 1; 43 | } 44 | 45 | dim3 nthrs(32, num_warps); 46 | dim3 nblks(num_chunks, num_problems); 47 | 48 | void* args[] = {(void*)&y, (void*)&x, (void*)&w, 49 | (void*)&s, (void*)&tmp, (void*)&num_problems, 50 | (void*)&d_in, (void*)&layer_idx, (void*)&chunk_size}; 51 | 52 | cudaError_t status; 53 | if (use_cooperative) { 54 | if (smem > 46 * 1024) { 55 | cudaFuncSetAttribute(cooperative_kernel, 56 | cudaFuncAttributeMaxDynamicSharedMemorySize, smem); 57 | } 58 | status = cudaLaunchCooperativeKernel((void*)cooperative_kernel, nblks, 59 | nthrs, args, smem, stream); 60 | } else { 61 | if (smem > 46 * 1024) { 62 | cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 63 | smem); 64 | } 65 | status = cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem, stream); 66 | } 67 | return status == cudaSuccess; 68 | } 69 | 70 | #define INST(T, d_out) \ 71 | template bool sgmv_shrink( \ 72 | T * y, T * x, T * *w, int32_t * s, void* tmp, uint32_t num_problems, \ 73 | uint32_t d_in, uint32_t layer_idx, cudaStream_t stream); 74 | 75 | FOR_SGMV_NARROW(INST, nv_half); 76 | FOR_SGMV_NARROW(INST, nv_bfloat16); 77 | -------------------------------------------------------------------------------- /benchmarks/benchmark_utils.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import dataclasses 3 | import gc 4 | import itertools 5 | import time 6 | from collections.abc import Callable 7 | 8 | import numpy as np 9 | import torch 10 | 11 | 12 | class Benchmark(abc.ABC): 13 | def setup(self): 14 | pass 15 | 16 | def before_run(self): 17 | pass 18 | 19 | @abc.abstractmethod 20 | def run(self): 21 | pass 22 | 23 | def after_run(self): 24 | pass 25 | 26 | def teardown(self): 27 | pass 28 | 29 | 30 | class wrap_benchmark(Benchmark): 31 | def __init__(self, fn_run: Callable[[], None]): 32 | self.fn_run = fn_run 33 | 34 | def run(self): 35 | self.fn_run() 36 | 37 | 38 | @dataclasses.dataclass 39 | class BenchResult: 40 | warmup: int 41 | repeat: int 42 | latency: np.ndarray 43 | 44 | def avg(self) -> np.ndarray: 45 | return np.mean(self.latency) 46 | 47 | def std(self) -> np.ndarray: 48 | return np.std(self.latency) 49 | 50 | def avg_std(self) -> np.ndarray: 51 | return self.avg(), self.std() 52 | 53 | 54 | def bench( 55 | f: Benchmark | Callable[[], None], 56 | warmup: int = 100, 57 | repeat: int = 500, 58 | ) -> BenchResult: 59 | b = f if isinstance(f, Benchmark) else wrap_benchmark(f) 60 | 61 | cache = torch.empty(256 * 2**20, dtype=torch.int8, device="cuda:0") 62 | b.setup() 63 | 64 | latency = np.zeros(repeat, dtype=np.float64) 65 | for i in range(-warmup, repeat): 66 | b.before_run() 67 | cache.zero_() 68 | 69 | torch.cuda.synchronize() 70 | t0 = time.perf_counter() 71 | b.run() 72 | torch.cuda.synchronize() 73 | t1 = time.perf_counter() 74 | 75 | b.after_run() 76 | 77 | if i >= 0: 78 | latency[i] = t1 - t0 79 | 80 | b.teardown() 81 | del cache 82 | return BenchResult(warmup, repeat, latency) 83 | 84 | 85 | def gc_torch(): 86 | gc.collect() 87 | torch.cuda.empty_cache() 88 | 89 | 90 | def batched(iterable, n): 91 | "Batch data into tuples of length n. The last batch may be shorter." 
92 | # batched('ABCDEFG', 3) --> ABC DEF G 93 | if n < 1: 94 | raise ValueError("n must be at least one") 95 | it = iter(iterable) 96 | while batch := list(itertools.islice(it, n)): 97 | yield batch 98 | 99 | 100 | def get_lora_lens(bs: int, popularity: str) -> list[int]: 101 | if popularity == "bmm": 102 | return [bs] 103 | if popularity == "bgmv": 104 | return [1] * bs 105 | if popularity == "uniform": 106 | n = int(np.ceil(np.sqrt(bs))) 107 | lens = np.array([bs // n] * n) 108 | while True: 109 | diff = bs - lens.sum() 110 | if diff == 0: 111 | break 112 | lens[: abs(diff)] += np.sign(diff) 113 | return lens.tolist() 114 | if popularity.startswith("zipf:"): 115 | alpha = float(popularity.split(":")[1]) 116 | assert alpha > 1 117 | lens = [] 118 | a = 1 119 | while sum(lens) + int(np.floor(a)) < bs: 120 | lens.append(int(np.floor(a))) 121 | a *= alpha 122 | lens.append(bs - sum(lens)) 123 | return sorted(lens, reverse=True) 124 | raise KeyError(popularity) 125 | -------------------------------------------------------------------------------- /benchmarks/bench_batch_decode.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import itertools 3 | import json 4 | import pathlib 5 | from datetime import datetime 6 | 7 | import pytz 8 | import torch 9 | from tqdm import tqdm 10 | 11 | import punica.ops 12 | from punica import BatchedKvCache, KvCache, KvPool 13 | 14 | from .benchmark_utils import bench, gc_torch 15 | 16 | 17 | class batch_decode_Resources: 18 | def __init__( 19 | self, 20 | num_heads: int, 21 | head_dim: int, 22 | page_len: int, 23 | seqlens: list[int], 24 | dtype: str, 25 | device: torch.device, 26 | ): 27 | dtype = getattr(torch, dtype) 28 | self.kvpool = KvPool( 29 | num_layers=1, 30 | num_heads=num_heads, 31 | head_dim=head_dim, 32 | page_len=page_len, 33 | dtype=dtype, 34 | device=device, 35 | ) 36 | self.q = torch.randn( 37 | (len(seqlens), num_heads, head_dim), dtype=dtype, device=device 38 | ) 39 | kv_list: list[KvCache] = [] 40 | for seqlen in seqlens: 41 | kv_list.append(KvCache(self.kvpool, seqlen)) 42 | self.kv_list = kv_list 43 | self.kv = BatchedKvCache(kv_list) 44 | 45 | def release(self): 46 | for kvcache in self.kv_list: 47 | kvcache.release() 48 | 49 | 50 | @torch.inference_mode() 51 | def bench_batch_decode(f): 52 | num_heads_ = [32, 40] 53 | batch_size_ = [ 54 | 1, 55 | 2, 56 | 3, 57 | 4, 58 | 5, 59 | 6, 60 | 7, 61 | 8, 62 | 10, 63 | 12, 64 | 14, 65 | 16, 66 | 20, 67 | 24, 68 | 28, 69 | 32, 70 | 40, 71 | 48, 72 | 56, 73 | 64, 74 | ] 75 | seqlen_ = list(reversed(range(2048, 0, -64))) 76 | dtype = "float16" 77 | device = torch.device("cuda:0") 78 | page_len = 16 79 | head_dim = 128 80 | 81 | all_ = list(itertools.product(num_heads_, seqlen_, batch_size_)) 82 | for num_heads, seqlen, batch_size in (pbar := tqdm(all_)): 83 | setup = dict( 84 | num_heads=num_heads, 85 | head_dim=head_dim, 86 | page_len=page_len, 87 | seqlen=seqlen, 88 | batch_size=batch_size, 89 | ) 90 | pbar.set_postfix(setup) 91 | torch.manual_seed(0xABCDABCD987) 92 | gc_torch() 93 | res = batch_decode_Resources( 94 | num_heads=num_heads, 95 | head_dim=head_dim, 96 | page_len=page_len, 97 | seqlens=[seqlen] * batch_size, 98 | dtype=dtype, 99 | device=device, 100 | ) 101 | latency = bench(lambda: punica.ops.batch_decode(res.q, res.kv, layer_idx=0)) 102 | res.release() 103 | 104 | result = { 105 | "setup": setup, 106 | "latency": {"avg": latency.avg(), "std": latency.std()}, 107 | } 108 | f.write(json.dumps(result) + "\n") 109 | f.flush() 110 | 
111 | 112 | def main(): 113 | this_file = pathlib.Path(__file__) 114 | project_root = this_file.parents[1] 115 | now = datetime.now(pytz.timezone("US/Pacific")) 116 | out_filename = f"{now:%Y%m%d-%H%M%S}-{this_file.stem}.jsonl.gz" 117 | out_path = project_root / "data" / out_filename 118 | 119 | print(out_path) 120 | with gzip.open(out_path, "wt") as f: 121 | bench_batch_decode(f) 122 | 123 | 124 | if __name__ == "__main__": 125 | main() 126 | -------------------------------------------------------------------------------- /examples/finetune/README.md: -------------------------------------------------------------------------------- 1 | # Example: Finetune & Convert weight to Punica format 2 | 3 | In this example, we will first use [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory) to finetune [Llama2-7B](https://huggingface.co/meta-llama/Llama-2-7b-hf) on three datasets: [gsm8k](https://huggingface.co/datasets/gsm8k), [sqlctx](https://huggingface.co/datasets/b-mc2/sql-create-context), [viggo](https://huggingface.co/datasets/GEM/viggo). Then, convert the PEFT weight to Punica format. Finally, we run each model using Punia. 4 | 5 | ## Download finetuned weight 6 | 7 | If you want to skip the finetuning process, you can download the finetuned weights: 8 | 9 | ```bash 10 | # CWD: project root 11 | mkdir -p model 12 | git lfs install 13 | git clone https://huggingface.co/abcdabcd987/gsm8k-llama2-7b-lora-16 model/gsm8k-r16 14 | git clone https://huggingface.co/abcdabcd987/sqlctx-llama2-7b-lora-16 model/sqlctx-r16 15 | git clone https://huggingface.co/abcdabcd987/viggo-llama2-7b-lora-16 model/viggo-r16 16 | ``` 17 | 18 | ## Finetune on local GPU 19 | 20 | If you prefer to finetune by yourself: 21 | 22 | ```bash 23 | git clone https://github.com/hiyouga/LLaMA-Factory.git --branch v0.2.2 examples/finetune/LLaMA-Factory 24 | 25 | python examples/finetune/create-finetune-data.py --preset gsm8k 26 | python examples/finetune/create-finetune-data.py --preset sqlctx 27 | python examples/finetune/create-finetune-data.py --preset viggo 28 | 29 | bash examples/finetune/finetune.sh gsm8k 30 | bash examples/finetune/finetune.sh sqlctx 31 | bash examples/finetune/finetune.sh viggo 32 | ``` 33 | 34 | ## Convert weight to Punica format 35 | 36 | ```bash 37 | python -m punica.utils.convert_lora_weight model/gsm8k-r16/adapter_model.bin model/gsm8k-r16.punica.pt 38 | python -m punica.utils.convert_lora_weight model/sqlctx-r16/adapter_model.bin model/sqlctx-r16.punica.pt 39 | python -m punica.utils.convert_lora_weight model/viggo-r16/adapter_model.bin model/viggo-r16.punica.pt 40 | ``` 41 | 42 | ## Test run 43 | 44 | ```bash 45 | gsm8k_prompt=$'<>\nAnswer the following Grade School Math problem.\n<>\n[INST] A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take? [/INST]\n' 46 | sqlctx_prompt=$'<>\nGenerate a correct SQL query from the following database schema.\nCREATE TABLE student_course_registrations (student_id VARCHAR, registration_date VARCHAR); CREATE TABLE students (student_details VARCHAR, student_id VARCHAR)\n<>\n[INST] What is detail of the student who most recently registered course? 
[/INST]\n' 47 | viggo_prompt=$'<>\nGenerate a description based on the following representation.\n<>\n[INST] verify_attribute(name[Metal Gear Solid 3: Snake Eater], release_year[2004], rating[excellent], genres[action-adventure, shooter, tactical]) [/INST]\n' 48 | 49 | python examples/textgen_lora.py --lora-weight model/gsm8k-r16.punica.pt --prompt "$gsm8k_prompt" 50 | python examples/textgen_lora.py --lora-weight model/sqlctx-r16.punica.pt --prompt "$sqlctx_prompt" 51 | python examples/textgen_lora.py --lora-weight model/viggo-r16.punica.pt --prompt "$viggo_prompt" 52 | ``` 53 | 54 | Reference outputs: 55 | 56 | ``` 57 | It takes 2/2=<<2/2=1>>1 bolt of white fiber 58 | So the total amount of fabric is 2+1=<<2+1=3>>3 bolts of fabric 59 | #### 3 60 | 61 | SELECT T2.student_details FROM student_course_registrations AS T1 JOIN students AS T2 ON T1.student_id = T2.student_id ORDER BY T1.registration_date DESC LIMIT 1 62 | 63 | You mentioned that you greatly enjoyed Metal Gear Solid 3: Snake Eater. Would you say you're a big fan of action-adventure games from 2004 involving shooting and tactical gameplay? 64 | ``` 65 | -------------------------------------------------------------------------------- /.github/workflows/release_wheel.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | on: 3 | workflow_call: 4 | inputs: 5 | tag_name: 6 | required: true 7 | type: string 8 | secrets: 9 | WHL_TOKEN: 10 | required: true 11 | PYPI_TEST_TOKEN: 12 | required: true 13 | 14 | env: 15 | TORCH_CUDA_ARCH_LIST: "8.0 8.6 8.9+PTX" # Need fix for 9.0 16 | PUNICA_CI_TORCH_VERSION: "2.1.0" 17 | 18 | jobs: 19 | build: 20 | strategy: 21 | fail-fast: false 22 | matrix: 23 | python: ["3.10", "3.11"] 24 | cuda: ["11.8", "12.1"] 25 | runs-on: [self-hosted] 26 | steps: 27 | - uses: actions/checkout@v4 28 | with: 29 | submodules: true 30 | 31 | - name: Build wheel 32 | run: | 33 | chown -R $CI_UID:$CI_GID "$GITHUB_WORKSPACE" 34 | version="$(cat version.txt)" 35 | docker run --rm -t \ 36 | -v "$CI_RUNNER_CACHE_DIR":/ci-cache \ 37 | -v "$GITHUB_WORKSPACE":/app \ 38 | -e PUNICA_CI_PYTHON_VERSION=${{ matrix.python }} \ 39 | -e PUNICA_CI_CUDA_VERSION=${{ matrix.cuda }} \ 40 | -e PUNICA_CI_TORCH_VERSION=$PUNICA_CI_TORCH_VERSION \ 41 | -e PUNICA_BUILD_VERSION=$version \ 42 | -e TORCH_CUDA_ARCH_LIST="$TORCH_CUDA_ARCH_LIST" \ 43 | --user $CI_UID:$CI_GID \ 44 | pytorch/manylinux-builder:cuda${{ matrix.cuda }} \ 45 | bash /app/ci/run-ci-build-wheel.bash 46 | 47 | - run: du -h dist/* 48 | 49 | - uses: actions/upload-artifact@v4 50 | with: 51 | name: wheel-cuda${{ matrix.cuda }}-python${{ matrix.python }} 52 | path: dist/* 53 | 54 | release: 55 | needs: build 56 | runs-on: ubuntu-latest 57 | steps: 58 | - uses: actions/download-artifact@v4 59 | with: 60 | path: dist/ 61 | merge-multiple: true 62 | pattern: wheel-* 63 | 64 | - run: ls -lah dist/ 65 | 66 | - uses: softprops/action-gh-release@v1 67 | with: 68 | tag_name: ${{ inputs.tag_name }} 69 | files: | 70 | dist/punica-*.whl 71 | dist/punica-*.tar.gz 72 | 73 | - name: Clone wheel index 74 | run: git clone https://oauth2:${WHL_TOKEN}@github.com/punica-ai/whl.git punica-whl 75 | env: 76 | WHL_TOKEN: ${{ secrets.WHL_TOKEN }} 77 | 78 | - name: Update wheel index 79 | shell: python 80 | run: | 81 | import pathlib 82 | import hashlib 83 | import re 84 | for path in sorted(pathlib.Path("dist").glob("*.whl")): 85 | with open(path, "rb") as f: 86 | sha256 = hashlib.sha256(f.read()).hexdigest() 87 | ver, cu = 
re.findall(r"punica-([0-9.]+)\+cu(\d+)-", path.name)[0] 88 | with open(f"punica-whl/cu{cu}/punica/index.html", "a") as f: 89 | f.write(f'{path.name}
\n') 90 | 91 | - name: Push wheel index 92 | run: | 93 | cd punica-whl 94 | git config --local user.name "github-actions[bot]" 95 | git config --local user.email "41898282+github-actions[bot]@users.noreply.github.com" 96 | git add -A 97 | git commit -m "update whl" 98 | git push 99 | 100 | - name: Upload sdist to pypi 101 | run: | 102 | pip install twine 103 | python -m twine upload --repository testpypi --username=__token__ dist/*.tar.gz 104 | env: 105 | TWINE_PASSWORD: ${{ secrets.PYPI_TEST_TOKEN }} 106 | -------------------------------------------------------------------------------- /benchmarks/bench_model_prefill_decode.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: F821 2 | import gzip 3 | import itertools 4 | import json 5 | import pathlib 6 | from datetime import datetime 7 | 8 | import pytz 9 | import torch 10 | from tqdm import tqdm 11 | from transformers import LlamaConfig 12 | 13 | from punica import BatchedKvCache, BatchLenInfo, KvCache, KvPool 14 | 15 | from .benchmark_utils import bench, gc_torch 16 | 17 | 18 | class model_Resources: 19 | def __init__( 20 | self, 21 | config: LlamaConfig, 22 | page_len: int, 23 | seqlen: int, 24 | prefills: int, 25 | decodes: int, 26 | dtype: torch.dtype, 27 | device: torch.device, 28 | ): 29 | num_heads = config.num_attention_heads 30 | head_dim = config.hidden_size // config.num_attention_heads 31 | self.kvpool = KvPool( 32 | num_layers=config.num_hidden_layers, 33 | num_heads=num_heads, 34 | head_dim=head_dim, 35 | page_len=page_len, 36 | dtype=dtype, 37 | device=device, 38 | ) 39 | self.input_ids = torch.randint( 40 | 0, 32000, (seqlen * prefills + decodes,), dtype=torch.int64, device=device 41 | ) 42 | self.prefill_kv = ( 43 | BatchedKvCache([KvCache(self.kvpool, seqlen) for _ in range(prefills)]) 44 | if prefills 45 | else None 46 | ) 47 | self.decode_kv = ( 48 | BatchedKvCache([KvCache(self.kvpool, seqlen) for _ in range(decodes)]) 49 | if decodes 50 | else None 51 | ) 52 | self.blen = BatchLenInfo([seqlen] * prefills, decodes, device) 53 | 54 | 55 | @torch.inference_mode() 56 | def bench_model_prefill_decode(f): 57 | num_heads_ = [32] 58 | num_layers_ = [32] 59 | intermediate_size_ = [11008] 60 | prefill_decode_ = [(0, 1), (1, 0)] 61 | batch_size_ = list(range(1, 33)) 62 | seqlen_ = [128, 512, 1024, 1536, 2048] 63 | dtype = torch.float16 64 | device = torch.device("cuda:0") 65 | page_len = 16 66 | head_dim = 128 67 | 68 | all_ = list( 69 | itertools.product( 70 | zip(num_heads_, num_layers_, intermediate_size_), 71 | prefill_decode_, 72 | seqlen_, 73 | batch_size_, 74 | ) 75 | ) 76 | last_num_heads = 0 77 | model = None 78 | for (num_heads, num_layers, intermediate_size), ( 79 | prefill, 80 | decode, 81 | ), seqlen, batch_size in (pbar := tqdm(all_)): 82 | if last_num_heads != num_heads: 83 | config = LlamaConfig( 84 | hidden_size=num_heads * head_dim, 85 | num_attention_heads=num_heads, 86 | intermediate_size=intermediate_size, 87 | num_hidden_layers=num_layers, 88 | ) 89 | del model 90 | gc_torch() 91 | default_dtype = torch.get_default_dtype() 92 | torch.set_default_dtype(dtype) 93 | with device: 94 | model = LlamaForCausalLM(config).to(device) 95 | torch.set_default_dtype(default_dtype) 96 | 97 | torch.manual_seed(0xABCDABCD987) 98 | gc_torch() 99 | res = model_Resources( 100 | config=config, 101 | page_len=page_len, 102 | seqlen=seqlen, 103 | prefills=batch_size * prefill, 104 | decodes=batch_size * decode, 105 | dtype=dtype, 106 | device=device, 107 | ) 108 | setup = dict( 
109 | num_heads=num_heads, 110 | head_dim=head_dim, 111 | num_layers=num_layers, 112 | intermediate_size=intermediate_size, 113 | page_len=page_len, 114 | seqlen=seqlen, 115 | prefills=batch_size * prefill, 116 | decodes=batch_size * decode, 117 | batch_size=batch_size, 118 | ) 119 | pbar.set_postfix(setup) 120 | 121 | latency = bench( 122 | lambda: model(res.input_ids, res.blen, res.prefill_kv, res.decode_kv), 123 | warmup=1, 124 | repeat=5, 125 | ) 126 | del res 127 | 128 | result = { 129 | "setup": setup, 130 | "latency": {"avg": latency.avg(), "std": latency.std()}, 131 | } 132 | f.write(json.dumps(result) + "\n") 133 | f.flush() 134 | 135 | 136 | def main(): 137 | this_file = pathlib.Path(__file__) 138 | project_root = this_file.parents[1] 139 | now = datetime.now(pytz.timezone("US/Pacific")) 140 | out_filename = f"{now:%Y%m%d-%H%M%S}-{this_file.stem}.jsonl.gz" 141 | out_path = project_root / "data" / out_filename 142 | 143 | print(out_path) 144 | with gzip.open(out_path, "wt") as f: 145 | bench_model_prefill_decode(f) 146 | 147 | 148 | if __name__ == "__main__": 149 | main() 150 | -------------------------------------------------------------------------------- /tests/test_kvcache.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | 5 | from punica.utils.kvcache import BatchedKvCache, GrowableTensor, KvCache, KvPool 6 | 7 | num_layers = 24 8 | num_heads = 32 9 | head_dim = 128 10 | page_len = 16 11 | dtype = torch.float16 12 | device = torch.device("cuda:0") 13 | 14 | 15 | def test_KvPool(): 16 | pool = KvPool(num_layers, num_heads, head_dim, page_len, dtype, device) 17 | shape = (num_layers, 2, num_heads, page_len, head_dim) 18 | assert pool.num_layers == num_layers 19 | assert pool.num_heads == num_heads 20 | assert pool.head_dim == head_dim 21 | assert pool.page_len == page_len 22 | assert pool.dtype == dtype 23 | assert pool.device == device 24 | assert pool.page_meta.shape == shape 25 | assert pool.page_nbytes == np.prod(shape) * 2 26 | 27 | pages = [] 28 | assert pool.num_pages == 0 29 | for _ in range(10): 30 | page = pool.alloc_page() 31 | assert page.shape == (num_layers, 2, num_heads, page_len, head_dim) 32 | assert page.dtype == dtype 33 | assert page.device == device 34 | pages.append(page) 35 | assert pool.num_pages == len(pages) 36 | 37 | while pages: 38 | page = pages.pop() 39 | pool.free_page(page) 40 | assert pool.num_pages == len(pages) 41 | 42 | 43 | def test_GrowableTensor_empty(): 44 | t = GrowableTensor([], dtype, device) 45 | assert t.view().shape == (0,) 46 | assert t.view().dtype == dtype 47 | assert t.view().device == device 48 | 49 | 50 | def test_GrowableTensor_grow_from_empty(): 51 | t = GrowableTensor([], dtype, device) 52 | for i in range(10): 53 | t.append(i) 54 | assert t.view().shape == (i + 1,) 55 | assert t.view().dtype == dtype 56 | assert t.view().device == device 57 | assert t.view().tolist() == list(range(i + 1)) 58 | 59 | 60 | def test_GrowableTensor_grow_from_data(): 61 | data = [1, 2, 3, 4, 5] 62 | t = GrowableTensor(data, dtype, device) 63 | for i in range(10): 64 | t.append(i) 65 | assert t.view().shape == (len(data) + i + 1,) 66 | assert t.view().dtype == dtype 67 | assert t.view().device == device 68 | assert t.view().tolist() == data + list(range(i + 1)) 69 | 70 | 71 | def test_GrowableTensor_clear(): 72 | data = [1, 2, 3, 4, 5] 73 | t = GrowableTensor(data, dtype, device) 74 | t.clear() 75 | assert t.view().shape == (0,) 76 | assert 
t.view().dtype == dtype 77 | assert t.view().device == device 78 | assert t.view().tolist() == [] 79 | for i in range(10): 80 | t.append(i) 81 | assert t.view().shape == (i + 1,) 82 | assert t.view().dtype == dtype 83 | assert t.view().device == device 84 | assert t.view().tolist() == list(range(i + 1)) 85 | 86 | 87 | @pytest.fixture 88 | def pool(): 89 | return KvPool(num_layers, num_heads, head_dim, page_len, dtype, device) 90 | 91 | 92 | @pytest.mark.parametrize("init_len", [0, 1, 15, 16, 17, 31, 32, 33]) 93 | def test_KvCache(pool: KvPool, init_len: int): 94 | kvcache = KvCache(pool, init_len) 95 | assert kvcache.pool is pool 96 | assert kvcache.seqlen == init_len 97 | assert kvcache.num_pages == (init_len + page_len - 1) // page_len 98 | assert kvcache.ptrs.shape == (kvcache.num_pages,) 99 | assert kvcache.ptrs.dtype == torch.int64 100 | assert kvcache.ptrs.device == device 101 | assert pool.num_pages == kvcache.num_pages 102 | 103 | for i in range(1, 65): 104 | kvcache.acquire_one() 105 | assert kvcache.seqlen == init_len + i 106 | assert kvcache.num_pages == (kvcache.seqlen + page_len - 1) // page_len 107 | assert kvcache.ptrs.shape == (kvcache.num_pages,) 108 | assert kvcache.ptrs.dtype == torch.int64 109 | assert kvcache.ptrs.device == device 110 | assert pool.num_pages == kvcache.num_pages 111 | 112 | kvcache.release() 113 | assert kvcache.seqlen == 0 114 | assert kvcache.num_pages == 0 115 | assert kvcache.ptrs.shape == (0,) 116 | assert kvcache.ptrs.dtype == torch.int64 117 | assert kvcache.ptrs.device == device 118 | assert pool.num_pages == 0 119 | 120 | 121 | def test_BatchedKvCache(pool: KvPool): 122 | seqlens = [15, 16, 17, 31, 32, 33] 123 | num_pages = [1, 1, 2, 2, 2, 3] 124 | kv_list = [KvCache(pool, seqlen) for seqlen in seqlens] 125 | assert num_pages == [kv.num_pages for kv in kv_list] 126 | assert pool.num_pages == sum(num_pages) 127 | batched_kv = BatchedKvCache(kv_list) 128 | assert batched_kv.pool is pool 129 | assert batched_kv.ptrs.shape == (sum(num_pages),) 130 | assert batched_kv.ptrs.dtype == torch.int64 131 | assert batched_kv.ptrs.device == device 132 | assert batched_kv.indptr.tolist() == [0, 1, 2, 4, 6, 8, 11] 133 | assert batched_kv.last_page_offset.tolist() == [15, 16, 1, 15, 16, 1] 134 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Punica: Serving multiple LoRA finetuned LLM as one 2 | 3 | [(paper)](https://arxiv.org/abs/2310.18547) 4 | 5 | ## Demo 6 | 7 | [punica-tui-demo.webm](https://github.com/punica-ai/punica/assets/2470081/532c6114-9322-4d53-ae88-1d0f44dc960f) 8 | 9 | ```bash 10 | python examples/tui-multi-lora.py 11 | ``` 12 | 13 | ## Overview 14 | 15 | [Low rank adapation](https://arxiv.org/abs/2106.09685) (LoRA) is a parameter efficient way to add new knowledge to a pretrained LLM. Although the pretrained LLM takes 100s of GB storage, a LoRA finetuned model only adds 1% storage and memory overhead. Punica enables running multiple LoRA finetuned models at the cost of running one. 16 | 17 | How? 18 | 19 | Assuming `W` of shape `[H1, H2]` is the weight of the pretrained model, LoRA adds two small matrices `A` of shape `[H1, r]` and `B` of `[r, H2]`. Running a input `x` on the finetuned model would be `y := x @ (W + A@B)`, which is the same as `y := x@W + x@A@B`. 20 | 21 | When there are `n` LoRA models, there will be `A1`, `B1`, `A2`, `B2`, ..., `An`, `Bn`. 
Given an input batch `X := (x1,x2,...,xn)` in which each input maps to its own LoRA model, the output is `Y := X@W + (x1@A1@B1, x2@A2@B2, ..., xn@An@Bn)`. The left-hand side computes the input batch on the pretrained model. This is quite efficient: the latency is almost the same as when there is only one input, thanks to the strong [batching effect](https://le.qun.ch/en/blog/2023/05/13/transformer-batching/). 22 | 23 | We figured out an efficient way to compute the right-hand side (the LoRA addon). We encapsulate this operation in a CUDA kernel, called Segmented Gather Matrix-Vector multiplication (SGMV), as illustrated below. 24 | 25 |

![SGMV](assets/sgmv.png)
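To make the math concrete, here is a naive PyTorch sketch of the LoRA addon (the right-hand side above) for a batch whose rows are grouped by LoRA model. The shapes, segment lengths, and variable names are made up for illustration; this loop is only a reference point, not the SGMV kernel, which performs the same segmented computation in one CUDA kernel.

```python
import torch

# Naive reference of the LoRA addon -- a sketch for illustration only,
# not the SGMV kernel. All shapes and names here are hypothetical.
H1, H2, r, n = 4096, 4096, 16, 3             # hidden sizes, LoRA rank, number of LoRA models
W = torch.randn(H1, H2)                      # pretrained weight
A = [torch.randn(H1, r) for _ in range(n)]   # LoRA A matrices
B = [torch.randn(r, H2) for _ in range(n)]   # LoRA B matrices

lens = [4, 1, 2]                             # requests mapped to each LoRA model
s = torch.tensor([0] + lens).cumsum(0)       # segment boundaries: tensor([0, 4, 5, 7])
X = torch.randn(int(s[-1]), H1)              # batched input, rows grouped by LoRA model

Y = X @ W                                    # backbone: one big batched matmul
for i in range(n):                           # LoRA addon: small matmuls per segment
    Y[s[i]:s[i + 1]] += X[s[i]:s[i + 1]] @ A[i] @ B[i]
```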
26 | 27 | In the following microbenchmark figure, we can observe the strong batching effect of the pretrained model. A naive implementation of LoRA is slow, as shown by the orange line. LoRA implemented via SGMV is efficient and preserves the strong batching effect. 28 | 29 |

![SGMV is fast and maintains strong batching effect](assets/backbone-vs-sgmv.png)
30 | 31 | The following figure shows the text generation throughput comparison between Punica and other systems, including [HuggingFace Transformers](https://github.com/huggingface/transformers/), [DeepSpeed](https://github.com/microsoft/DeepSpeed), [FasterTransformer](https://github.com/NVIDIA/FasterTransformer), [vLLM](https://github.com/vllm-project/vllm). The benchmark considers different settings of LoRA model popularity. *Distinct* means that each request is for a different LoRA model. *Identical* means that all requests are for the same LoRA model. *Uniform* and *Skewed* are in between. **Punica achieves 12x throughput compared to state-of-the-art systems.** 32 | 33 |

![Punica achieves 12x throughput compared to state-of-the-art systems](assets/textgen.png)
34 | 35 | Read our paper to understand more: [Punica: Multi-Tenant LoRA Serving](https://arxiv.org/abs/2310.18547). 36 | 37 | 38 | ## Installation 39 | 40 | You can install Punica from binary package or build from source. 41 | 42 | ### Install from binary package 43 | 44 | * Pros: No need to compile. Fast to install. 45 | * Cons: Might not match your CUDA version, CUDA architecture, PyTorch version, or Python version. 46 | * Current precompiled versions: 47 | * CUDA: 11.8, 12.1 48 | * Python: 3.10, 3.11 49 | * TORCH_CUDA_ARCH_LIST: `8.0 8.6 8.9+PTX` 50 | 51 | ```bash 52 | pip install ninja torch 53 | pip install punica -i https://punica-ai.github.io/whl/cu121/ --extra-index-url https://pypi.org/simple 54 | # Note: Change cu121 to your CUDA version. 55 | ``` 56 | 57 | ### Build from source 58 | 59 | ```bash 60 | # Please install torch before punica 61 | pip install ninja numpy torch 62 | 63 | # Clone punica 64 | git clone https://github.com/punica-ai/punica.git 65 | cd punica 66 | git submodule sync 67 | git submodule update --init 68 | 69 | # If you encouter problem while compilation, set TORCH_CUDA_ARCH_LIST to your CUDA architecture. 70 | # export TORCH_CUDA_ARCH_LIST="8.0" 71 | 72 | # Build and install punica 73 | pip install -v --no-build-isolation . 74 | ``` 75 | 76 | ## Examples 77 | 78 | ### Serving multiple LoRA models 79 | 80 | See the demo above. 81 | 82 | ### Finetune & convert to Punica format & serve with Punica 83 | 84 | See [`examples/finetune/`](examples/finetune/) 85 | 86 | ### Benchmark text generation 87 | 88 | ```bash 89 | python -m benchmarks.bench_textgen_lora --system punica --batch-size 32 90 | ``` 91 | 92 | 93 | ## Citation 94 | 95 | ```bibtex 96 | @misc{punica, 97 | title={Punica: Multi-Tenant LoRA Serving}, 98 | author={Lequn Chen and Zihao Ye and Yongji Wu and Danyang Zhuo and Luis Ceze and Arvind Krishnamurthy}, 99 | year={2023}, 100 | eprint={2310.18547}, 101 | archivePrefix={arXiv}, 102 | primaryClass={cs.DC} 103 | } 104 | ``` 105 | -------------------------------------------------------------------------------- /benchmarks/bench_layer_lora_decode.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import itertools 3 | import json 4 | import pathlib 5 | from datetime import datetime 6 | 7 | import pytz 8 | import torch 9 | from tqdm import tqdm 10 | from transformers import LlamaConfig 11 | 12 | from punica import ( 13 | BatchedKvCache, 14 | BatchedLlamaLoraWeight, 15 | BatchLenInfo, 16 | KvCache, 17 | KvPool, 18 | LlamaLoraWeight, 19 | ) 20 | from punica.models.llama_lora import LlamaDecoderLayerWithLora 21 | 22 | from .benchmark_utils import bench, gc_torch, get_lora_lens 23 | 24 | 25 | class layer_lora_decode_Resources: 26 | def __init__( 27 | self, 28 | config: LlamaConfig, 29 | page_len: int, 30 | lora_rank: int, 31 | lora_popularity: int, 32 | seqlens: list[int], 33 | dtype: torch.dtype, 34 | device: torch.device, 35 | ): 36 | num_heads = config.num_attention_heads 37 | head_dim = config.hidden_size // config.num_attention_heads 38 | self.kvpool = KvPool( 39 | num_layers=1, 40 | num_heads=num_heads, 41 | head_dim=head_dim, 42 | page_len=page_len, 43 | dtype=dtype, 44 | device=device, 45 | ) 46 | bs = len(seqlens) 47 | self.hidden_states = torch.randn( 48 | (bs, num_heads * head_dim), dtype=dtype, device=device 49 | ) 50 | kv_list: list[KvCache] = [] 51 | for seqlen in seqlens: 52 | kv_list.append(KvCache(self.kvpool, seqlen)) 53 | self.kv_list = kv_list 54 | self.kv = BatchedKvCache(kv_list) 55 | 
self.blen = BatchLenInfo([], bs, device) 56 | 57 | lora_lens = get_lora_lens(bs, lora_popularity) 58 | self.num_lora_models = len(lora_lens) 59 | weights = [ 60 | LlamaLoraWeight(config, lora_rank, dtype, device) 61 | for _ in range(self.num_lora_models) 62 | ] 63 | self.lora = BatchedLlamaLoraWeight(weights, lora_lens) 64 | 65 | def release(self): 66 | for kvcache in self.kv_list: 67 | kvcache.release() 68 | 69 | 70 | @torch.inference_mode() 71 | def bench_layer_lora_decode(f): 72 | num_heads_ = [32, 40] 73 | intermediate_size_ = [11008, 13824] 74 | pop_ = ["bgmv", "bmm", "uniform", "zipf:1.5"] 75 | batch_size_ = list(range(1, 64)) 76 | seqlen_ = list(reversed(range(2048, 0, -64))) 77 | dtype = torch.float16 78 | device = torch.device("cuda:0") 79 | page_len = 16 80 | lora_rank = 16 81 | head_dim = 128 82 | 83 | all_ = list( 84 | itertools.product( 85 | zip(num_heads_, intermediate_size_), pop_, seqlen_, batch_size_ 86 | ) 87 | ) 88 | last_num_heads = 0 89 | for (num_heads, intermediate_size), pop, seqlen, batch_size in (pbar := tqdm(all_)): 90 | if last_num_heads != num_heads: 91 | config = LlamaConfig( 92 | hidden_size=num_heads * head_dim, 93 | num_attention_heads=num_heads, 94 | intermediate_size=intermediate_size, 95 | num_hidden_layers=1, 96 | ) 97 | default_dtype = torch.get_default_dtype() 98 | torch.set_default_dtype(dtype) 99 | with device: 100 | model = LlamaDecoderLayerWithLora(config, layer_idx=0).to(device) 101 | torch.set_default_dtype(default_dtype) 102 | 103 | torch.manual_seed(0xABCDABCD987) 104 | gc_torch() 105 | res = layer_lora_decode_Resources( 106 | config=config, 107 | page_len=page_len, 108 | lora_rank=lora_rank, 109 | lora_popularity=pop, 110 | seqlens=[seqlen] * batch_size, 111 | dtype=dtype, 112 | device=device, 113 | ) 114 | setup = dict( 115 | num_heads=num_heads, 116 | head_dim=head_dim, 117 | intermediate_size=intermediate_size, 118 | page_len=page_len, 119 | lora_rank=lora_rank, 120 | lora_popularity=pop, 121 | num_lora_models=res.num_lora_models, 122 | seqlen=seqlen, 123 | batch_size=batch_size, 124 | ) 125 | pbar.set_postfix(setup) 126 | 127 | latency = bench( 128 | lambda: model(res.hidden_states, res.blen, None, res.kv, res.lora), 129 | warmup=10, 130 | repeat=50, 131 | ) 132 | res.release() 133 | 134 | result = { 135 | "setup": setup, 136 | "latency": {"avg": latency.avg(), "std": latency.std()}, 137 | } 138 | f.write(json.dumps(result) + "\n") 139 | f.flush() 140 | 141 | 142 | def main(): 143 | this_file = pathlib.Path(__file__) 144 | project_root = this_file.parents[1] 145 | now = datetime.now(pytz.timezone("US/Pacific")) 146 | out_filename = f"{now:%Y%m%d-%H%M%S}-{this_file.stem}.jsonl.gz" 147 | out_path = project_root / "data" / out_filename 148 | 149 | print(out_path) 150 | with gzip.open(out_path, "wt") as f: 151 | bench_layer_lora_decode(f) 152 | 153 | 154 | if __name__ == "__main__": 155 | main() 156 | -------------------------------------------------------------------------------- /benchmarks/bench_model_lora_decode.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: F821 2 | import gzip 3 | import itertools 4 | import json 5 | import pathlib 6 | from datetime import datetime 7 | 8 | import numpy as np 9 | import pytz 10 | import torch 11 | from tqdm import tqdm 12 | from transformers import LlamaConfig 13 | 14 | from punica import ( 15 | BatchedKvCache, 16 | BatchedLlamaLoraWeight, 17 | BatchLenInfo, 18 | KvCache, 19 | KvPool, 20 | LlamaForCausalLMWithLora, 21 | LlamaLoraWeight, 22 | ) 23 | 
24 | from .benchmark_utils import bench, gc_torch 25 | 26 | 27 | class model_lora_decode_Resources: 28 | def __init__( 29 | self, 30 | config: LlamaConfig, 31 | page_len: int, 32 | lora_rank: int, 33 | seqlens: list[int], 34 | dtype: torch.dtype, 35 | device: torch.device, 36 | ): 37 | num_heads = config.num_attention_heads 38 | head_dim = config.hidden_size // config.num_attention_heads 39 | self.kvpool = KvPool( 40 | num_layers=config.num_hidden_layers, 41 | num_heads=num_heads, 42 | head_dim=head_dim, 43 | page_len=page_len, 44 | dtype=dtype, 45 | device=device, 46 | ) 47 | bs = len(seqlens) 48 | self.input_ids = torch.randint( 49 | 0, 32000, (bs,), dtype=torch.int64, device=device 50 | ) 51 | kv_list: list[KvCache] = [] 52 | for seqlen in seqlens: 53 | kv_list.append(KvCache(self.kvpool, seqlen)) 54 | self.kv_list = kv_list 55 | self.kv = BatchedKvCache(kv_list) 56 | self.blen = BatchLenInfo([], bs, device) 57 | 58 | num_lora_models = int(np.ceil(np.sqrt(bs))) 59 | self.num_lora_models = num_lora_models 60 | weights = [ 61 | LlamaLoraWeight(config, lora_rank, dtype, device) 62 | for _ in range(num_lora_models) 63 | ] 64 | lens = [int(np.floor(np.sqrt(bs)))] * (num_lora_models - 1) 65 | lens.append(bs - sum(lens)) 66 | self.lora = BatchedLlamaLoraWeight(weights, lens) 67 | 68 | def release(self): 69 | for kvcache in self.kv_list: 70 | kvcache.release() 71 | 72 | 73 | @torch.inference_mode() 74 | def bench_model_lora_decode(f): 75 | num_heads_ = [32, 40] 76 | num_layers_ = [32, 40] 77 | intermediate_size_ = [11008, 13824] 78 | batch_size_ = list(range(1, 37)) 79 | seqlen_ = list(reversed(range(2048, 0, -64))) 80 | dtype = torch.float16 81 | device = torch.device("cuda:0") 82 | page_len = 16 83 | lora_rank = 16 84 | head_dim = 128 85 | 86 | all_ = list( 87 | itertools.product( 88 | zip(num_heads_, num_layers_, intermediate_size_), seqlen_, batch_size_ 89 | ) 90 | ) 91 | last_num_heads = 0 92 | model = None 93 | for (num_heads, num_layers, intermediate_size), seqlen, batch_size in ( 94 | pbar := tqdm(all_) 95 | ): 96 | if last_num_heads != num_heads: 97 | config = LlamaConfig( 98 | hidden_size=num_heads * head_dim, 99 | num_attention_heads=num_heads, 100 | intermediate_size=intermediate_size, 101 | num_hidden_layers=num_layers, 102 | ) 103 | del model 104 | gc_torch() 105 | default_dtype = torch.get_default_dtype() 106 | torch.set_default_dtype(dtype) 107 | with device: 108 | model = LlamaForCausalLMWithLora(config).to(device) 109 | torch.set_default_dtype(default_dtype) 110 | 111 | torch.manual_seed(0xABCDABCD987) 112 | gc_torch() 113 | res = model_lora_decode_Resources( 114 | config=config, 115 | page_len=page_len, 116 | lora_rank=lora_rank, 117 | seqlens=[seqlen] * batch_size, 118 | dtype=dtype, 119 | device=device, 120 | ) 121 | setup = dict( 122 | num_heads=num_heads, 123 | head_dim=head_dim, 124 | num_layers=num_layers, 125 | intermediate_size=intermediate_size, 126 | page_len=page_len, 127 | lora_rank=lora_rank, 128 | num_lora_models=res.num_lora_models, 129 | seqlen=seqlen, 130 | batch_size=batch_size, 131 | ) 132 | pbar.set_postfix(setup) 133 | 134 | latency = bench( 135 | lambda: model(res.input_ids, res.blen, None, res.kv, res.lora), 136 | warmup=1, 137 | repeat=5, 138 | ) 139 | res.release() 140 | del res 141 | 142 | result = { 143 | "setup": setup, 144 | "latency": {"avg": latency.avg(), "std": latency.std()}, 145 | } 146 | f.write(json.dumps(result) + "\n") 147 | f.flush() 148 | 149 | 150 | def main(): 151 | this_file = pathlib.Path(__file__) 152 | project_root = 
this_file.parents[1] 153 | now = datetime.now(pytz.timezone("US/Pacific")) 154 | out_filename = f"{now:%Y%m%d-%H%M%S}-{this_file.stem}.jsonl.gz" 155 | out_path = project_root / "data" / out_filename 156 | 157 | print(out_path) 158 | with gzip.open(out_path, "wt") as f: 159 | bench_model_lora_decode(f) 160 | 161 | 162 | if __name__ == "__main__": 163 | main() 164 | -------------------------------------------------------------------------------- /benchmarks/bench_lora_op_impls.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import itertools 3 | import json 4 | import pathlib 5 | from datetime import datetime 6 | 7 | import pytz 8 | import torch 9 | from tqdm import tqdm 10 | 11 | import punica.ops 12 | 13 | from .benchmark_utils import bench, get_lora_lens 14 | 15 | 16 | def lora_loop( 17 | y: torch.Tensor, 18 | x: torch.Tensor, 19 | wa: torch.Tensor, 20 | wb: torch.Tensor, 21 | s: torch.IntTensor, 22 | layer_idx: int, 23 | ): 24 | for i in range(len(wa)): 25 | xi = x[s[i] : s[i + 1]] 26 | wai = wa[i][layer_idx, :, :] 27 | wbi = wb[i][layer_idx, :, :] 28 | y[s[i] : s[i + 1]] += xi @ wai @ wbi 29 | 30 | 31 | def gather( 32 | wa_all: list[torch.Tensor], 33 | wb_all: list[torch.Tensor], 34 | problem_sizes: list[int], 35 | layer_idx: int, 36 | ) -> tuple[torch.Tensor, torch.Tensor]: 37 | wa, wb = [], [] 38 | for i, l in enumerate(problem_sizes): 39 | wa.extend([wa_all[i][layer_idx, :, :]] * l) 40 | wb.extend([wb_all[i][layer_idx, :, :]] * l) 41 | wa = torch.stack(wa) 42 | wb = torch.stack(wb) 43 | return wa, wb 44 | 45 | 46 | def bmm( 47 | y: torch.Tensor, 48 | x: torch.Tensor, 49 | wa: torch.Tensor, 50 | wb: torch.Tensor, 51 | ): 52 | y += (x.unsqueeze(1) @ wa @ wb).squeeze(1) 53 | 54 | 55 | def lora_gbmm( 56 | y: torch.Tensor, 57 | x: torch.Tensor, 58 | wa_all: list[torch.Tensor], 59 | wb_all: list[torch.Tensor], 60 | problem_sizes: list[int], 61 | layer_idx: int, 62 | ): 63 | wa, wb = gather(wa_all, wb_all, problem_sizes, layer_idx) 64 | bmm(y, x, wa, wb) 65 | 66 | 67 | @torch.inference_mode() 68 | def bench_lora_op_impls(f): 69 | bs_ = list(range(1, 65)) 70 | pop_ = ["bmm", "bgmv", "uniform", "zipf:1.5", "Nx8"] 71 | h1 = 4096 72 | h2 = 11008 73 | r = 16 74 | num_layers = 1 75 | dtype = torch.float16 76 | device = torch.device("cuda:0") 77 | 78 | all_ = list(itertools.product(pop_, bs_)) 79 | for pop, bs in (pbar := tqdm(all_)): 80 | if pop == "Nx8": 81 | if bs % 8 != 0: 82 | continue 83 | problem_sizes = [(bs // 8)] * 8 84 | else: 85 | problem_sizes = get_lora_lens(bs, pop) 86 | 87 | setup = dict( 88 | h1=h1, 89 | h2=h2, 90 | r=r, 91 | popularity=pop, 92 | num_problems=len(problem_sizes), 93 | batch_size=bs, 94 | ) 95 | pbar.set_postfix(setup) 96 | 97 | torch.manual_seed(0xABCDABCD987) 98 | wa = [ 99 | torch.rand((num_layers, h1, r), dtype=dtype, device=device) 100 | for _ in range(len(problem_sizes)) 101 | ] 102 | wb = [ 103 | torch.rand((num_layers, r, h2), dtype=dtype, device=device) 104 | for _ in range(len(problem_sizes)) 105 | ] 106 | wa_t = [t.transpose(-1, -2) for t in wa] 107 | wb_t = [t.transpose(-1, -2) for t in wb] 108 | wa_ptr = torch.tensor( 109 | [t.data_ptr() for t in wa_t], dtype=torch.int64, device=device 110 | ) 111 | wb_ptr = torch.tensor( 112 | [t.data_ptr() for t in wb_t], dtype=torch.int64, device=device 113 | ) 114 | s = torch.cumsum( 115 | torch.tensor([0] + problem_sizes, device=device), dim=0, dtype=torch.int32 116 | ) 117 | x = torch.rand((s[-1], h1), dtype=dtype, device=device) 118 | y = torch.rand((s[-1], 
h2), dtype=dtype, device=device) 119 | gwa, gwb = gather(wa, wb, problem_sizes, layer_idx=0) 120 | 121 | l_loop = bench( 122 | lambda: lora_loop(y, x, wa, wb, s, layer_idx=0), warmup=200, repeat=1000 123 | ) 124 | l_gbmm = bench( 125 | lambda: lora_gbmm(y, x, wa, wb, problem_sizes, layer_idx=0), 126 | warmup=200, 127 | repeat=1000, 128 | ) 129 | l_gather = bench( 130 | lambda: gather(wa, wb, problem_sizes, layer_idx=0), warmup=200, repeat=1000 131 | ) 132 | l_bmm = bench(lambda: bmm(y, x, gwa, gwb), warmup=200, repeat=1000) 133 | l_sgmv = bench( 134 | lambda: punica.ops.add_lora_sgmv( 135 | y, x, wa_ptr, wb_ptr, s, layer_idx=0, lora_rank=r 136 | ), 137 | warmup=200, 138 | repeat=1000, 139 | ) 140 | 141 | result = { 142 | "setup": setup, 143 | "loop": dict(avg=l_loop.avg(), std=l_loop.std()), 144 | "gbmm": dict(avg=l_gbmm.avg(), std=l_gbmm.std()), 145 | "gather": dict(avg=l_gather.avg(), std=l_gather.std()), 146 | "bmm": dict(avg=l_bmm.avg(), std=l_bmm.std()), 147 | "sgmv": dict(avg=l_sgmv.avg(), std=l_sgmv.std()), 148 | } 149 | f.write(json.dumps(result) + "\n") 150 | f.flush() 151 | 152 | 153 | def main(): 154 | this_file = pathlib.Path(__file__) 155 | project_root = this_file.parents[1] 156 | now = datetime.now(pytz.timezone("US/Pacific")) 157 | out_filename = f"{now:%Y%m%d-%H%M%S}-{this_file.stem}.jsonl.gz" 158 | out_path = project_root / "data" / out_filename 159 | 160 | print(out_path) 161 | with gzip.open(out_path, "wt") as f: 162 | bench_lora_op_impls(f) 163 | 164 | 165 | if __name__ == "__main__": 166 | main() 167 | -------------------------------------------------------------------------------- /src/punica/utils/kvcache.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Iterable, Sequence 2 | from typing import Any 3 | 4 | import torch 5 | 6 | 7 | class KvPool: 8 | def __init__( 9 | self, 10 | num_layers: int, 11 | num_heads: int, 12 | head_dim: int, 13 | page_len: int, 14 | dtype: torch.dtype, 15 | device: torch.device, 16 | ): 17 | self.num_layers = num_layers 18 | self.num_heads = num_heads 19 | self.head_dim = head_dim 20 | self.page_len = page_len 21 | self.dtype = dtype 22 | self.device = device 23 | self._page_shape = (num_layers, 2, num_heads, page_len, head_dim) 24 | self._page_meta = torch.empty(self._page_shape, dtype=dtype, device="meta") 25 | self._allocated = set() 26 | 27 | @property 28 | def page_meta(self) -> torch.Tensor: 29 | return self._page_meta 30 | 31 | @property 32 | def page_nbytes(self) -> int: 33 | return self._page_meta.nbytes 34 | 35 | @property 36 | def num_pages(self) -> int: 37 | return len(self._allocated) 38 | 39 | def allocated_pages(self) -> Iterable[torch.Tensor]: 40 | return iter(self._allocated) 41 | 42 | def alloc_page(self) -> torch.Tensor: 43 | page = torch.empty(self._page_shape, dtype=self.dtype, device=self.device) 44 | self._allocated.add(page) 45 | return page 46 | 47 | def free_page(self, page: torch.Tensor) -> None: 48 | self._allocated.remove(page) 49 | 50 | 51 | class GrowableTensor: 52 | def __init__( 53 | self, 54 | data: Sequence[Any], 55 | dtype: torch.dtype | None = None, 56 | device: torch.device | None = None, 57 | ): 58 | self._buf = torch.tensor(data, dtype=dtype, device=device) 59 | self._len = len(data) 60 | 61 | def view(self) -> torch.Tensor: 62 | return self._buf[: self._len] 63 | 64 | def append(self, data: Any) -> None: 65 | self._maybe_grow(self._len + 1) 66 | self._buf[self._len] = data 67 | self._len += 1 68 | 69 | def clear(self) -> None: 70 | 
self._len = 0 71 | 72 | @staticmethod 73 | def _next_power_of_two(x: int) -> int: 74 | return 1 << (x - 1).bit_length() 75 | 76 | def _maybe_grow(self, capacity: int) -> None: 77 | if self._buf.numel() >= capacity: 78 | return 79 | new_capacity = self._next_power_of_two(capacity) 80 | new_buf = torch.empty( 81 | new_capacity, dtype=self._buf.dtype, device=self._buf.device 82 | ) 83 | new_buf[: self._len] = self._buf 84 | self._buf = new_buf 85 | 86 | 87 | class KvCache: 88 | """Key-value cache for one sequence.""" 89 | 90 | def __init__(self, pool: KvPool, init_len: int): 91 | if init_len < 0: 92 | raise ValueError("init_len must be non-negative") 93 | 94 | self._pool = pool 95 | if init_len > 0: 96 | npages = (init_len + pool.page_len - 1) // pool.page_len 97 | self._pages = [pool.alloc_page() for _ in range(npages)] 98 | self._seqlen = init_len 99 | else: 100 | self._pages = [] 101 | self._seqlen = 0 102 | self._ptrs = GrowableTensor( 103 | [t.data_ptr() for t in self._pages], dtype=torch.int64, device=pool.device 104 | ) 105 | 106 | @property 107 | def pool(self) -> KvPool: 108 | return self._pool 109 | 110 | @property 111 | def seqlen(self) -> int: 112 | return self._seqlen 113 | 114 | @property 115 | def ptrs(self) -> torch.Tensor: 116 | return self._ptrs.view() 117 | 118 | @property 119 | def pages(self) -> Sequence[torch.Tensor]: 120 | return self._pages 121 | 122 | @property 123 | def num_pages(self) -> int: 124 | return len(self._pages) 125 | 126 | def acquire_one(self): 127 | """Reserve space for one more token""" 128 | last_page_offset = (self._seqlen - 1) % self._pool.page_len + 1 129 | if last_page_offset == self._pool.page_len: 130 | self._pages.append(self._pool.alloc_page()) 131 | self._ptrs.append(self._pages[-1].data_ptr()) 132 | self._seqlen += 1 133 | 134 | def release(self): 135 | """Release all pages""" 136 | self._seqlen = 0 137 | for page in self._pages: 138 | self._pool.free_page(page) 139 | self._pages.clear() 140 | self._ptrs.clear() 141 | 142 | 143 | class BatchedKvCache: 144 | """Key-value cache for a batch of sequences.""" 145 | 146 | def __init__(self, kv: Sequence[KvCache]): 147 | assert len(kv) > 0 148 | pool = kv[0].pool 149 | device = pool.device 150 | ptrs = [] 151 | indptr = [0] 152 | last_page_offset = [] 153 | for c in kv: 154 | assert c.pool is pool 155 | assert c.num_pages > 0 156 | ptrs.append(c.ptrs) 157 | indptr.append(indptr[-1] + c.num_pages) 158 | last_page_offset.append((c.seqlen - 1) % pool.page_len + 1) 159 | 160 | self.pool = pool 161 | self.ptrs = torch.cat(ptrs) 162 | self.indptr = torch.tensor(indptr, dtype=torch.int32, device=device) 163 | self.last_page_offset = torch.tensor( 164 | last_page_offset, dtype=torch.int32, device=device 165 | ) 166 | -------------------------------------------------------------------------------- /tests/test_sgmv_cutlass.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | import punica.ops 5 | 6 | 7 | def assert_close(a, b): 8 | rtol, atol = { 9 | torch.float16: (5e-3, 5e-3), 10 | torch.bfloat16: (3e-2, 2e-2), 11 | }[a.dtype] 12 | torch.testing.assert_close(a, b, rtol=rtol, atol=atol) 13 | 14 | 15 | def sgmv_ref_impl( 16 | y: torch.Tensor, 17 | x: torch.Tensor, 18 | w: list[torch.Tensor], 19 | s: torch.IntTensor, 20 | layer_idx: int, 21 | ): 22 | for i in range(len(w)): 23 | xi = x[s[i] : s[i + 1]].to(torch.float32) 24 | wi = w[i][layer_idx, :, :].to(torch.float32) 25 | yi = y[s[i] : s[i + 1]].to(torch.float32) 26 | y[s[i] : s[i + 
1]] = (yi + xi @ wi).to(y.dtype) 27 | 28 | 29 | def lora_ref_impl( 30 | y: torch.Tensor, 31 | x: torch.Tensor, 32 | wa: torch.Tensor, 33 | wb: torch.Tensor, 34 | s: torch.IntTensor, 35 | layer_idx: int, 36 | ): 37 | for i in range(len(wa)): 38 | xi = x[s[i] : s[i + 1]].to(torch.float32) 39 | wai = wa[i][layer_idx, :, :].to(torch.float32) 40 | wbi = wb[i][layer_idx, :, :].to(torch.float32) 41 | yi = y[s[i] : s[i + 1]].to(torch.float32) 42 | tmp = (xi @ wai).to(y.dtype).to(torch.float32) 43 | y[s[i] : s[i + 1]] = (yi + tmp @ wbi).to(y.dtype) 44 | 45 | 46 | @pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) 47 | @pytest.mark.parametrize("h", [4096, 11008]) 48 | @pytest.mark.parametrize("r", [16, 32, 64, 96, 128]) 49 | @pytest.mark.parametrize( 50 | "direction", 51 | [pytest.param("shrink", marks=pytest.mark.xfail(reason="#11")), "expand"], 52 | ) 53 | @pytest.mark.parametrize("batch_setup", ["1x7", "7x1", "3x3"]) 54 | @torch.inference_mode() 55 | def test_sgmv_correctness(dtype_str, h, r, direction, batch_setup): 56 | torch.manual_seed(0xABCDABCD987) 57 | num_problems, problem_size = map(int, batch_setup.split("x")) 58 | num_layers = 5 59 | dtype = getattr(torch, dtype_str) 60 | device = torch.device("cuda:0") 61 | if direction == "shrink": 62 | h1, h2 = h, r 63 | else: 64 | h1, h2 = r, h 65 | 66 | w = [ 67 | torch.randn((num_layers, h1, h2), dtype=dtype, device=device) 68 | for _ in range(num_problems) 69 | ] 70 | w_ptr = torch.tensor([t.data_ptr() for t in w], dtype=torch.int64, device=device) 71 | s = torch.cumsum( 72 | torch.tensor([0] + [problem_size] * num_problems, device=device), 73 | dim=0, 74 | dtype=torch.int32, 75 | ) 76 | x = torch.randn((s[-1], h1), dtype=dtype, device=device) 77 | y = torch.randn((s[-1], h2), dtype=dtype, device=device) 78 | for layer_idx in range(num_layers): 79 | y_ref = y.clone() 80 | sgmv_ref_impl(y_ref, x, w, s, layer_idx) 81 | y_our = y.clone() 82 | punica.ops.sgmv_cutlass(y_our, x, w_ptr, s, layer_idx) 83 | assert_close(y_ref, y_our) 84 | 85 | 86 | @pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) 87 | @pytest.mark.parametrize("batch_setup", ["1x7", "7x1", "3x3"]) 88 | @torch.inference_mode() 89 | def test_lora_correctness(dtype_str, batch_setup): 90 | torch.manual_seed(0xABCDABCD987) 91 | num_layers = 5 92 | h1 = 4096 93 | h2 = 11008 94 | r = 16 95 | num_problems, problem_size = map(int, batch_setup.split("x")) 96 | dtype = getattr(torch, dtype_str) 97 | device = torch.device("cuda:0") 98 | 99 | wa = [ 100 | torch.rand((num_layers, h1, r), dtype=dtype, device=device) 101 | for _ in range(num_problems) 102 | ] 103 | wb = [ 104 | torch.rand((num_layers, r, h2), dtype=dtype, device=device) 105 | for _ in range(num_problems) 106 | ] 107 | wa_ptr = torch.tensor([t.data_ptr() for t in wa], dtype=torch.int64, device=device) 108 | wb_ptr = torch.tensor([t.data_ptr() for t in wb], dtype=torch.int64, device=device) 109 | s = torch.cumsum( 110 | torch.tensor([0] + [problem_size] * num_problems, device=device), 111 | dim=0, 112 | dtype=torch.int32, 113 | ) 114 | x = torch.rand((s[-1], h1), dtype=dtype, device=device) 115 | y = torch.rand((s[-1], h2), dtype=dtype, device=device) 116 | 117 | for layer_idx in range(num_layers): 118 | y_ref = y.clone() 119 | lora_ref_impl(y_ref, x, wa, wb, s, layer_idx) 120 | y_our = y.clone() 121 | punica.ops.add_lora_sgmv_cutlass(y_our, x, wa_ptr, wb_ptr, s, layer_idx, r) 122 | assert_close(y_ref, y_our) 123 | 124 | 125 | @pytest.mark.parametrize( 126 | "direction", 127 | [pytest.param("shrink", 
marks=pytest.mark.xfail(reason="#11")), "expand"], 128 | ) 129 | @torch.inference_mode() 130 | def test_sgmv_cuda_graph(direction): 131 | torch.manual_seed(0xABCDABCD987) 132 | num_problems, problem_size = 13, 5 133 | num_layers = 5 134 | dtype = torch.float16 135 | device = torch.device("cuda:0") 136 | h, r = 11008, 16 137 | if direction == "shrink": 138 | h1, h2 = h, r 139 | else: 140 | h1, h2 = r, h 141 | 142 | w = [ 143 | torch.randn((num_layers, h1, h2), dtype=dtype, device=device) 144 | for _ in range(num_problems) 145 | ] 146 | w_ptr = torch.tensor([t.data_ptr() for t in w], dtype=torch.int64, device=device) 147 | s = torch.cumsum( 148 | torch.tensor([0] + [problem_size] * num_problems, device=device), 149 | dim=0, 150 | dtype=torch.int32, 151 | ) 152 | x = torch.randn((s[-1], h1), dtype=dtype, device=device) 153 | y = torch.randn((s[-1], h2), dtype=dtype, device=device) 154 | for layer_idx in range(num_layers): 155 | y_our = y.clone() 156 | punica.ops.sgmv_cutlass(y_our, x, w_ptr, s, layer_idx) 157 | 158 | y_graph = torch.empty_like(y) 159 | graph = torch.cuda.CUDAGraph() 160 | with torch.cuda.graph(graph): 161 | punica.ops.sgmv_cutlass(y_graph, x, w_ptr, s, layer_idx) 162 | 163 | for _ in range(2): 164 | y_graph.copy_(y.clone()) 165 | graph.replay() 166 | assert (y_graph == y_our).all() 167 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import datetime 3 | import itertools 4 | import os 5 | import pathlib 6 | import platform 7 | import re 8 | import subprocess 9 | 10 | import setuptools 11 | import torch 12 | import torch.utils.cpp_extension as torch_cpp_ext 13 | 14 | root = pathlib.Path(__name__).parent 15 | 16 | 17 | def glob(pattern): 18 | return [str(p) for p in root.glob(pattern)] 19 | 20 | 21 | def remove_unwanted_pytorch_nvcc_flags(): 22 | REMOVE_NVCC_FLAGS = [ 23 | "-D__CUDA_NO_HALF_OPERATORS__", 24 | "-D__CUDA_NO_HALF_CONVERSIONS__", 25 | "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", 26 | "-D__CUDA_NO_HALF2_OPERATORS__", 27 | ] 28 | for flag in REMOVE_NVCC_FLAGS: 29 | with contextlib.suppress(ValueError): 30 | torch_cpp_ext.COMMON_NVCC_FLAGS.remove(flag) 31 | 32 | 33 | def generate_flashinfer_cu() -> list[str]: 34 | page_sizes = os.environ.get("PUNICA_PAGE_SIZES", "16").split(",") 35 | group_sizes = os.environ.get("PUNICA_GROUP_SIZES", "1,2,4,8").split(",") 36 | head_dims = os.environ.get("PUNICA_HEAD_DIMS", "128").split(",") 37 | page_sizes = [int(x) for x in page_sizes] 38 | group_sizes = [int(x) for x in group_sizes] 39 | head_dims = [int(x) for x in head_dims] 40 | dtypes = {"fp16": "nv_half", "bf16": "nv_bfloat16"} 41 | funcs = ["prefill", "decode"] 42 | prefix = "csrc/flashinfer_adapter/generated" 43 | (root / prefix).mkdir(parents=True, exist_ok=True) 44 | files = [] 45 | 46 | # dispatch.inc 47 | path = root / prefix / "dispatch.inc" 48 | if not path.exists(): 49 | with open(root / prefix / "dispatch.inc", "w") as f: 50 | f.write("#define _DISPATCH_CASES_page_size(...) \\\n") 51 | for x in page_sizes: 52 | f.write(f" _DISPATCH_CASE({x}, PAGE_SIZE, __VA_ARGS__) \\\n") 53 | f.write("// EOL\n") 54 | 55 | f.write("#define _DISPATCH_CASES_group_size(...) \\\n") 56 | for x in group_sizes: 57 | f.write(f" _DISPATCH_CASE({x}, GROUP_SIZE, __VA_ARGS__) \\\n") 58 | f.write("// EOL\n") 59 | 60 | f.write("#define _DISPATCH_CASES_head_dim(...) 
\\\n") 61 | for x in head_dims: 62 | f.write(f" _DISPATCH_CASE({x}, HEAD_DIM, __VA_ARGS__) \\\n") 63 | f.write("// EOL\n") 64 | 65 | f.write("\n") 66 | 67 | # impl 68 | for func, page_size, group_size, head_dim, dtype in itertools.product( 69 | funcs, page_sizes, group_sizes, head_dims, dtypes 70 | ): 71 | fname = f"batch_{func}_p{page_size}_g{group_size}_h{head_dim}_{dtype}.cu" 72 | files.append(prefix + "/" + fname) 73 | if (root / prefix / fname).exists(): 74 | continue 75 | with open(root / prefix / fname, "w") as f: 76 | f.write('#include "../flashinfer_decl.h"\n\n') 77 | f.write(f'#include "flashinfer/{func}.cuh"\n\n') 78 | f.write( 79 | f"INST_Batch{func.capitalize()}({dtypes[dtype]}, {page_size}, {group_size}, {head_dim})\n" 80 | ) 81 | 82 | return files 83 | 84 | 85 | def get_local_version_suffix() -> str: 86 | if not (root / ".git").is_dir(): 87 | return "" 88 | now = datetime.datetime.now() 89 | git_hash = subprocess.check_output( 90 | ["git", "rev-parse", "--short", "HEAD"], cwd=root, text=True 91 | ).strip() 92 | commit_number = subprocess.check_output( 93 | ["git", "rev-list", "HEAD", "--count"], cwd=root, text=True 94 | ).strip() 95 | dirty = ".dirty" if subprocess.run(["git", "diff", "--quiet"]).returncode else "" 96 | return f"+c{commit_number}.d{now:%Y%m%d}.{git_hash}{dirty}" 97 | 98 | 99 | def get_version() -> str: 100 | version = os.getenv("PUNICA_BUILD_VERSION") 101 | if version is None: 102 | with open(root / "version.txt") as f: 103 | version = f.read().strip() 104 | version += get_local_version_suffix() 105 | return version 106 | 107 | 108 | def get_cuda_version() -> tuple[int, int]: 109 | if torch_cpp_ext.CUDA_HOME is None: 110 | nvcc = "nvcc" 111 | else: 112 | nvcc = os.path.join(torch_cpp_ext.CUDA_HOME, "bin/nvcc") 113 | txt = subprocess.check_output([nvcc, "--version"], text=True) 114 | major, minor = map(int, re.findall(r"release (\d+)\.(\d+),", txt)[0]) 115 | return major, minor 116 | 117 | 118 | def generate_build_meta() -> None: 119 | d = {} 120 | version = get_version() 121 | d["cuda_major"], d["cuda_minor"] = get_cuda_version() 122 | d["torch"] = torch.__version__ 123 | d["python"] = platform.python_version() 124 | d["TORCH_CUDA_ARCH_LIST"] = os.environ.get("TORCH_CUDA_ARCH_LIST", None) 125 | with open(root / "src/punica/_build_meta.py", "w") as f: 126 | f.write(f"__version__ = {version!r}\n") 127 | f.write(f"build_meta = {d!r}") 128 | 129 | 130 | if __name__ == "__main__": 131 | remove_unwanted_pytorch_nvcc_flags() 132 | generate_build_meta() 133 | 134 | ext_modules = [] 135 | ext_modules.append( 136 | torch_cpp_ext.CUDAExtension( 137 | name="punica.ops._kernels", 138 | sources=[ 139 | "csrc/punica_ops.cc", 140 | "csrc/bgmv/bgmv_all.cu", 141 | "csrc/flashinfer_adapter/flashinfer_all.cu", 142 | "csrc/rms_norm/rms_norm_cutlass.cu", 143 | "csrc/sgmv/sgmv_cutlass.cu", 144 | "csrc/sgmv_flashinfer/sgmv_all.cu", 145 | ] 146 | + generate_flashinfer_cu(), 147 | include_dirs=[ 148 | str(root.resolve() / "third_party/cutlass/include"), 149 | str(root.resolve() / "third_party/flashinfer/include"), 150 | ], 151 | extra_compile_args={ 152 | "cxx": ["-O3"], 153 | "nvcc": ["-O3"], 154 | }, 155 | ) 156 | ) 157 | 158 | setuptools.setup( 159 | version=get_version(), 160 | ext_modules=ext_modules, 161 | cmdclass={"build_ext": torch_cpp_ext.BuildExtension}, 162 | ) 163 | -------------------------------------------------------------------------------- /examples/textgen.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 
3 | import torch 4 | import transformers 5 | 6 | from punica import ( 7 | BatchedKvCache, 8 | BatchLenInfo, 9 | KvCache, 10 | KvPool, 11 | LlamaForCausalLM, 12 | ) 13 | 14 | 15 | class TextGeneration: 16 | def __init__( 17 | self, 18 | input_ids: list[int], 19 | *, 20 | temperature: float, 21 | repetition_penalty: float, 22 | top_p: float, 23 | top_k: int, 24 | max_new_tokens: int, 25 | stop_token_id: int, 26 | ): 27 | self.temperature = temperature 28 | self.repetition_penalty = repetition_penalty 29 | self.top_p = top_p 30 | self.top_k = top_k 31 | self.max_new_tokens = max_new_tokens 32 | self.stop_token_id = stop_token_id 33 | 34 | # Logits processing adapted from: https://github.com/lm-sys/FastChat/blob/bb7ca37c2bfad629ba4751dec188bdcdc2cf0c81/fastchat/serve/inference.py 35 | self.logits_processor = transformers.LogitsProcessorList() 36 | if temperature > 0 and temperature != 1.0: 37 | self.logits_processor.append( 38 | transformers.TemperatureLogitsWarper(temperature) 39 | ) 40 | if repetition_penalty > 1.0: 41 | self.logits_processor.append( 42 | transformers.RepetitionPenaltyLogitsProcessor(repetition_penalty) 43 | ) 44 | if 0 < top_p < 1.0: 45 | self.logits_processor.append(transformers.TopPLogitsWarper(top_p)) 46 | if top_k > 0: 47 | self.logits_processor.append(transformers.TopKLogitsWarper(top_k)) 48 | 49 | self.output_ids = [int(x) for x in input_ids] 50 | self.prompt_len = len(self.output_ids) 51 | 52 | def get_next_token_id(self, logits: torch.Tensor) -> int: 53 | if self.logits_processor: 54 | if self.repetition_penalty > 1.0: 55 | t = torch.as_tensor([self.output_ids], device=logits.device) 56 | else: 57 | t = None 58 | last_token_logits = self.logits_processor(t, logits[-1].unsqueeze(0))[0] 59 | else: 60 | last_token_logits = logits[-1, :] 61 | 62 | if self.temperature <= 0 or self.top_p <= 0: 63 | _, indices = torch.topk(last_token_logits, 2) 64 | else: 65 | probs = torch.softmax(last_token_logits, dim=-1) 66 | indices = torch.multinomial(probs, num_samples=2) 67 | token = int(indices.tolist()[0]) 68 | return token 69 | 70 | def append_token(self, token_id: int): 71 | self.output_ids.append(token_id) 72 | 73 | def is_stop(self) -> int: 74 | if len(self.output_ids) - self.prompt_len >= self.max_new_tokens: 75 | return True 76 | if self.output_ids[-1] == self.stop_token_id: 77 | return True 78 | return False 79 | 80 | 81 | @torch.inference_mode() 82 | def main(): 83 | parser = argparse.ArgumentParser() 84 | parser.add_argument("--model", type=str, default="meta-llama/Llama-2-7b-chat-hf") 85 | parser.add_argument("--prompt", default="Give me a 3 day travel plan for Seattle.") 86 | parser.add_argument( 87 | "--template", 88 | default="[INST] <> You are a helpful, respectful and honest assistant. 
<>\n{prompt} [/INST]\n", 89 | ) 90 | parser.add_argument("--temperature", type=float, default=0.7) 91 | parser.add_argument("--repetition-penalty", type=float, default=1.1) 92 | parser.add_argument("--top-p", type=float, default=0.9) 93 | parser.add_argument("--top-k", type=int, default=-1) 94 | parser.add_argument("--max-new-tokens", type=int, default=2000) 95 | 96 | args = parser.parse_args() 97 | dtype = torch.float16 98 | device = torch.device("cuda:0") 99 | 100 | tokenizer = transformers.AutoTokenizer.from_pretrained(args.model, use_fast=True) 101 | model_config = transformers.LlamaConfig.from_pretrained(args.model) 102 | model = LlamaForCausalLM.from_pretrained( 103 | args.model, low_cpu_mem_usage=True, torch_dtype=dtype 104 | ).to(device) 105 | kvpool = KvPool( 106 | num_layers=model_config.num_hidden_layers, 107 | num_heads=model_config.num_attention_heads, 108 | head_dim=model_config.hidden_size // model_config.num_attention_heads, 109 | page_len=16, 110 | dtype=dtype, 111 | device=device, 112 | ) 113 | 114 | prompt = args.template.format(prompt=args.prompt) 115 | input_ids = tokenizer.encode(prompt) 116 | kvcache = KvCache(kvpool, len(input_ids)) 117 | textgen = TextGeneration( 118 | input_ids, 119 | temperature=args.temperature, 120 | repetition_penalty=args.repetition_penalty, 121 | top_p=args.top_p, 122 | top_k=args.top_k, 123 | max_new_tokens=args.max_new_tokens, 124 | stop_token_id=tokenizer.eos_token_id, 125 | ) 126 | 127 | # Print prompt 128 | text = tokenizer.decode( 129 | input_ids, 130 | skip_special_tokens=True, 131 | spaces_between_special_tokens=False, 132 | clean_up_tokenization_spaces=True, 133 | ) 134 | print(text, end="", flush=True) 135 | last_print_len = len(text) 136 | 137 | # Prefill 138 | logits, _ = model( 139 | input_ids=torch.tensor(input_ids, dtype=torch.long, device=device), 140 | blen=BatchLenInfo([len(input_ids)], 0, device), 141 | prefill_kv=BatchedKvCache([kvcache]), 142 | decode_kv=None, 143 | ) 144 | next_token_id = textgen.get_next_token_id(logits) 145 | textgen.append_token(next_token_id) 146 | 147 | # Decode 148 | while not textgen.is_stop(): 149 | kvcache.acquire_one() 150 | logits, _ = model( 151 | input_ids=torch.tensor([next_token_id], dtype=torch.long, device=device), 152 | blen=BatchLenInfo([], 1, device), 153 | prefill_kv=None, 154 | decode_kv=BatchedKvCache([kvcache]), 155 | ) 156 | next_token_id = textgen.get_next_token_id(logits) 157 | textgen.append_token(next_token_id) 158 | 159 | text = tokenizer.decode( 160 | textgen.output_ids, 161 | skip_special_tokens=True, 162 | spaces_between_special_tokens=False, 163 | clean_up_tokenization_spaces=True, 164 | ) 165 | print(text[last_print_len:], end="", flush=True) 166 | last_print_len = len(text) 167 | 168 | kvcache.release() 169 | print() 170 | 171 | 172 | if __name__ == "__main__": 173 | main() 174 | -------------------------------------------------------------------------------- /examples/textgen_lora.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import transformers 5 | 6 | from punica import ( 7 | BatchedKvCache, 8 | BatchedLlamaLoraWeight, 9 | BatchLenInfo, 10 | KvCache, 11 | KvPool, 12 | LlamaForCausalLMWithLora, 13 | LlamaLoraWeight, 14 | ) 15 | 16 | 17 | class TextGeneration: 18 | def __init__( 19 | self, 20 | input_ids: list[int], 21 | *, 22 | temperature: float, 23 | repetition_penalty: float, 24 | top_p: float, 25 | top_k: int, 26 | max_new_tokens: int, 27 | stop_token_id: int, 28 | ): 29 | 
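        # Sampling configuration mirrors examples/textgen.py: the knobs below
        # are stored on the instance and turned into HuggingFace logits
        # processors, while max_new_tokens and stop_token_id drive is_stop().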
self.temperature = temperature 30 | self.repetition_penalty = repetition_penalty 31 | self.top_p = top_p 32 | self.top_k = top_k 33 | self.max_new_tokens = max_new_tokens 34 | self.stop_token_id = stop_token_id 35 | 36 | # Logits processing adapted from: https://github.com/lm-sys/FastChat/blob/bb7ca37c2bfad629ba4751dec188bdcdc2cf0c81/fastchat/serve/inference.py 37 | self.logits_processor = transformers.LogitsProcessorList() 38 | if temperature > 0 and temperature != 1.0: 39 | self.logits_processor.append( 40 | transformers.TemperatureLogitsWarper(temperature) 41 | ) 42 | if repetition_penalty > 1.0: 43 | self.logits_processor.append( 44 | transformers.RepetitionPenaltyLogitsProcessor(repetition_penalty) 45 | ) 46 | if 0 < top_p < 1.0: 47 | self.logits_processor.append(transformers.TopPLogitsWarper(top_p)) 48 | if top_k > 0: 49 | self.logits_processor.append(transformers.TopKLogitsWarper(top_k)) 50 | 51 | self.output_ids = [int(x) for x in input_ids] 52 | self.prompt_len = len(self.output_ids) 53 | 54 | def get_next_token_id(self, logits: torch.Tensor) -> int: 55 | if self.logits_processor: 56 | if self.repetition_penalty > 1.0: 57 | t = torch.as_tensor([self.output_ids], device=logits.device) 58 | else: 59 | t = None 60 | last_token_logits = self.logits_processor(t, logits[-1].unsqueeze(0))[0] 61 | else: 62 | last_token_logits = logits[-1, :] 63 | 64 | if self.temperature <= 0 or self.top_p <= 0: 65 | _, indices = torch.topk(last_token_logits, 2) 66 | else: 67 | probs = torch.softmax(last_token_logits, dim=-1) 68 | indices = torch.multinomial(probs, num_samples=2) 69 | token = int(indices.tolist()[0]) 70 | return token 71 | 72 | def append_token(self, token_id: int): 73 | self.output_ids.append(token_id) 74 | 75 | def is_stop(self) -> int: 76 | if len(self.output_ids) - self.prompt_len >= self.max_new_tokens: 77 | return True 78 | if self.output_ids[-1] == self.stop_token_id: 79 | return True 80 | return False 81 | 82 | 83 | @torch.inference_mode() 84 | def main(): 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument("--base-model", type=str, default="meta-llama/Llama-2-7b-hf") 87 | parser.add_argument("--lora-weight", required=True) 88 | parser.add_argument("--prompt", required=True) 89 | parser.add_argument("--temperature", type=float, default=0.9) 90 | parser.add_argument("--repetition-penalty", type=float, default=1.1) 91 | parser.add_argument("--top-p", type=float, default=0.9) 92 | parser.add_argument("--top-k", type=int, default=-1) 93 | parser.add_argument("--max-new-tokens", type=int, default=500) 94 | 95 | args = parser.parse_args() 96 | dtype = torch.float16 97 | device = torch.device("cuda:0") 98 | 99 | tokenizer = transformers.AutoTokenizer.from_pretrained( 100 | args.base_model, use_fast=True 101 | ) 102 | model_config = transformers.LlamaConfig.from_pretrained(args.base_model) 103 | model = LlamaForCausalLMWithLora.from_pretrained( 104 | args.base_model, low_cpu_mem_usage=True, torch_dtype=dtype 105 | ).to(device) 106 | kvpool = KvPool( 107 | num_layers=model_config.num_hidden_layers, 108 | num_heads=model_config.num_attention_heads, 109 | head_dim=model_config.hidden_size // model_config.num_attention_heads, 110 | page_len=16, 111 | dtype=dtype, 112 | device=device, 113 | ) 114 | tmp = torch.load(args.lora_weight, map_location=device, weights_only=True) 115 | lora_rank = tmp["q.A"].size(1) 116 | lora_weight = LlamaLoraWeight(model_config, lora_rank, dtype, device) 117 | lora_weight.copy_from_tensors(tmp) 118 | del tmp 119 | 120 | input_ids = 
tokenizer.encode(args.prompt) 121 | kvcache = KvCache(kvpool, len(input_ids)) 122 | textgen = TextGeneration( 123 | input_ids, 124 | temperature=args.temperature, 125 | repetition_penalty=args.repetition_penalty, 126 | top_p=args.top_p, 127 | top_k=args.top_k, 128 | max_new_tokens=args.max_new_tokens, 129 | stop_token_id=tokenizer.eos_token_id, 130 | ) 131 | 132 | # Print prompt 133 | text = tokenizer.decode( 134 | input_ids, 135 | skip_special_tokens=True, 136 | spaces_between_special_tokens=False, 137 | clean_up_tokenization_spaces=True, 138 | ) 139 | print(text, end="", flush=True) 140 | last_print_len = len(text) 141 | 142 | # Prefill 143 | logits, _ = model( 144 | input_ids=torch.tensor(input_ids, dtype=torch.long, device=device), 145 | blen=BatchLenInfo([len(input_ids)], 0, device), 146 | prefill_kv=BatchedKvCache([kvcache]), 147 | decode_kv=None, 148 | lora=BatchedLlamaLoraWeight([lora_weight], [len(input_ids)]), 149 | ) 150 | next_token_id = textgen.get_next_token_id(logits) 151 | textgen.append_token(next_token_id) 152 | 153 | # Decode 154 | while not textgen.is_stop(): 155 | kvcache.acquire_one() 156 | logits, _ = model( 157 | input_ids=torch.tensor([next_token_id], dtype=torch.long, device=device), 158 | blen=BatchLenInfo([], 1, device), 159 | prefill_kv=None, 160 | decode_kv=BatchedKvCache([kvcache]), 161 | lora=BatchedLlamaLoraWeight([lora_weight], [1]), 162 | ) 163 | next_token_id = textgen.get_next_token_id(logits) 164 | textgen.append_token(next_token_id) 165 | 166 | text = tokenizer.decode( 167 | textgen.output_ids, 168 | skip_special_tokens=True, 169 | spaces_between_special_tokens=False, 170 | clean_up_tokenization_spaces=True, 171 | ) 172 | print(text[last_print_len:], end="", flush=True) 173 | last_print_len = len(text) 174 | 175 | kvcache.release() 176 | print() 177 | 178 | 179 | if __name__ == "__main__": 180 | main() 181 | -------------------------------------------------------------------------------- /tests/test_sgmv.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | 5 | import punica.ops 6 | 7 | 8 | def assert_close(a, b): 9 | rtol, atol = { 10 | torch.float16: (5e-3, 5e-3), 11 | torch.bfloat16: (3e-2, 2e-2), 12 | }[a.dtype] 13 | torch.testing.assert_close(a, b, rtol=rtol, atol=atol) 14 | 15 | 16 | def sgmv_ref_impl( 17 | y: torch.Tensor, 18 | x: torch.Tensor, 19 | w: list[torch.Tensor], 20 | s: torch.Tensor, 21 | layer_idx: int, 22 | ): 23 | for i in range(len(w)): 24 | xi = x[s[i] : s[i + 1]].to(torch.float32) 25 | wi = w[i][layer_idx, :, :].T.to(torch.float32) 26 | yi = y[s[i] : s[i + 1]].to(torch.float32) 27 | y[s[i] : s[i + 1]] = (yi + xi @ wi).to(y.dtype) 28 | 29 | 30 | def get_lora_lens(bs: int, popularity: str) -> list[int]: 31 | if popularity == "identical": 32 | return [bs] 33 | if popularity == "distinct": 34 | return [1] * bs 35 | if popularity == "uniform": 36 | n = int(np.ceil(np.sqrt(bs))) 37 | lens = np.array([bs // n] * n) 38 | while True: 39 | diff = bs - lens.sum() 40 | if diff == 0: 41 | break 42 | lens[: abs(diff)] += np.sign(diff) 43 | return lens.tolist() 44 | if popularity.startswith("zipf:"): 45 | alpha = float(popularity.split(":")[1]) 46 | assert alpha > 1 47 | lens = [] 48 | a = 1 49 | while sum(lens) + int(np.floor(a)) < bs: 50 | lens.append(int(np.floor(a))) 51 | a *= alpha 52 | lens.append(bs - sum(lens)) 53 | return sorted(lens, reverse=True) 54 | if popularity.startswith("skewed"): 55 | if bs < 3: 56 | return [bs] 57 | # Create a highly 
imbalanced distribution by setting the first segment 58 | # length to 1 and the remainder to the second segment. 59 | return [1, bs - 1] 60 | raise KeyError(popularity) 61 | 62 | 63 | def lora_ref_impl( 64 | y: torch.Tensor, 65 | x: torch.Tensor, 66 | wa: torch.Tensor, 67 | wb: torch.Tensor, 68 | s: torch.IntTensor, 69 | layer_idx: int, 70 | ): 71 | for i in range(len(wa)): 72 | xi = x[s[i] : s[i + 1]].to(torch.float32) 73 | wai = wa[i][layer_idx, :, :].to(torch.float32) 74 | wbi = wb[i][layer_idx, :, :].to(torch.float32) 75 | yi = y[s[i] : s[i + 1]].to(torch.float32) 76 | tmp = (xi @ wai).to(y.dtype).to(torch.float32) 77 | y[s[i] : s[i + 1]] = (yi + tmp @ wbi).to(y.dtype) 78 | 79 | 80 | @pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) 81 | @pytest.mark.parametrize("h", [4096, 11008]) 82 | @pytest.mark.parametrize("r", [16, 32, 64, 96, 128]) 83 | @pytest.mark.parametrize( 84 | "direction", 85 | [ 86 | "shrink", 87 | pytest.param("expand", marks=pytest.mark.xfail(reason="TODO: sgmv expand")), 88 | ], 89 | ) 90 | @pytest.mark.parametrize("popularity", ["distinct", "uniform", "zipf:1.5", "identical", "skewed"]) 91 | @pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 7, 10, 16, 32, 64, 133]) 92 | @torch.inference_mode() 93 | def test_sgmv_correctness(dtype_str, h, r, direction, popularity, batch_size): 94 | torch.manual_seed(0xABCDABCD987) 95 | seqlens = get_lora_lens(batch_size, popularity) 96 | num_layers = 5 97 | dtype = getattr(torch, dtype_str) 98 | device = torch.device("cuda:0") 99 | if direction == "shrink": 100 | h1, h2 = h, r 101 | else: 102 | h1, h2 = r, h 103 | 104 | w = [ 105 | torch.randn((num_layers, h2, h1), dtype=dtype, device=device) 106 | for _ in range(len(seqlens)) 107 | ] 108 | w_ptr = torch.tensor([t.data_ptr() for t in w], dtype=torch.int64, device=device) 109 | s = torch.cumsum( 110 | torch.tensor([0] + seqlens, device=device), 111 | dim=0, 112 | dtype=torch.int32, 113 | ) 114 | x = torch.randn((int(s[-1]), h1), dtype=dtype, device=device) 115 | y = torch.randn((int(s[-1]), h2), dtype=dtype, device=device) 116 | for layer_idx in range(num_layers): 117 | y_ref = y.clone() 118 | sgmv_ref_impl(y_ref, x, w, s, layer_idx) 119 | y_our = y.clone() 120 | punica.ops.sgmv(y_our, x, w_ptr, s, layer_idx) 121 | assert_close(y_ref, y_our) 122 | 123 | 124 | @pytest.mark.xfail(reason="TODO: sgmv expand") 125 | @pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) 126 | @pytest.mark.parametrize("batch_setup", ["1x7", "7x1", "3x3", "32x1", "1x32"]) 127 | @torch.inference_mode() 128 | def test_lora_correctness(dtype_str, batch_setup): 129 | torch.manual_seed(0xABCDABCD987) 130 | num_layers = 5 131 | h1 = 4096 132 | h2 = 11008 133 | r = 16 134 | num_problems, problem_size = map(int, batch_setup.split("x")) 135 | dtype = getattr(torch, dtype_str) 136 | device = torch.device("cuda:0") 137 | 138 | wa = [ 139 | torch.rand((num_layers, h1, r), dtype=dtype, device=device) 140 | for _ in range(num_problems) 141 | ] 142 | wb = [ 143 | torch.rand((num_layers, r, h2), dtype=dtype, device=device) 144 | for _ in range(num_problems) 145 | ] 146 | wa_ptr = torch.tensor([t.data_ptr() for t in wa], dtype=torch.int64, device=device) 147 | wb_ptr = torch.tensor([t.data_ptr() for t in wb], dtype=torch.int64, device=device) 148 | s = torch.cumsum( 149 | torch.tensor([0] + [problem_size] * num_problems, device=device), 150 | dim=0, 151 | dtype=torch.int32, 152 | ) 153 | x = torch.rand((s[-1], h1), dtype=dtype, device=device) 154 | y = torch.rand((s[-1], h2), dtype=dtype, device=device) 
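    # Reference semantics checked below: for each segment i, the update is
    #   y[s[i]:s[i+1]] += (x[s[i]:s[i+1]] @ wa[i][layer_idx]) @ wb[i][layer_idx]
    # lora_ref_impl casts the rank-r intermediate to y.dtype, presumably to
    # mirror the precision at which the kernel materializes it, and
    # add_lora_bgmv is expected to produce the same shrink-then-expand update
    # in a single call.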
155 | 156 | for layer_idx in range(num_layers): 157 | y_ref = y.clone() 158 | lora_ref_impl(y_ref, x, wa, wb, s, layer_idx) 159 | y_our = y.clone() 160 | punica.ops.add_lora_bgmv(y_our, x, wa_ptr, wb_ptr, s, layer_idx, r) 161 | assert_close(y_ref, y_our) 162 | 163 | 164 | @pytest.mark.parametrize( 165 | "direction", 166 | [ 167 | "shrink", 168 | pytest.param("expand", marks=pytest.mark.xfail(reason="TODO: sgmv expand")), 169 | ], 170 | ) 171 | @torch.inference_mode() 172 | def test_sgmv_cuda_graph(direction): 173 | torch.manual_seed(0xABCDABCD987) 174 | batch_size = 133 175 | popularity = "zipf:1.5" 176 | seqlens = get_lora_lens(batch_size, popularity) 177 | num_layers = 5 178 | dtype = torch.float16 179 | device = torch.device("cuda:0") 180 | h, r = 11008, 16 181 | if direction == "shrink": 182 | h1, h2 = h, r 183 | else: 184 | h1, h2 = r, h 185 | 186 | w = [ 187 | torch.randn((num_layers, h2, h1), dtype=dtype, device=device) 188 | for _ in range(len(seqlens)) 189 | ] 190 | w_ptr = torch.tensor([t.data_ptr() for t in w], dtype=torch.int64, device=device) 191 | s = torch.cumsum( 192 | torch.tensor([0] + seqlens, device=device), 193 | dim=0, 194 | dtype=torch.int32, 195 | ) 196 | x = torch.randn((int(s[-1]), h1), dtype=dtype, device=device) 197 | y = torch.randn((int(s[-1]), h2), dtype=dtype, device=device) 198 | for layer_idx in range(num_layers): 199 | y_our = y.clone() 200 | punica.ops.sgmv(y_our, x, w_ptr, s, layer_idx) 201 | 202 | y_graph = torch.empty_like(y) 203 | graph = torch.cuda.CUDAGraph() 204 | with torch.cuda.graph(graph): 205 | punica.ops.sgmv(y_graph, x, w_ptr, s, layer_idx) 206 | 207 | for _ in range(2): 208 | y_graph.copy_(y.clone()) 209 | graph.replay() 210 | assert (y_graph == y_our).all() 211 | -------------------------------------------------------------------------------- /csrc/sgmv/sgmv_cutlass.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | #include "cutlass/cutlass.h" 10 | #include "cutlass/gemm/device/gemm_grouped.h" 11 | #include "cutlass/gemm/kernel/default_gemm_grouped.h" 12 | #include "cutlass/layout/matrix.h" 13 | #include "cutlass/numeric_types.h" 14 | 15 | template 16 | struct cutlass_dtype { 17 | using type = T; 18 | }; 19 | 20 | template <> 21 | struct cutlass_dtype { 22 | using type = cutlass::half_t; 23 | }; 24 | 25 | template <> 26 | struct cutlass_dtype { 27 | using type = cutlass::bfloat16_t; 28 | }; 29 | 30 | template 31 | __global__ void precompute_sgmv_args(cutlass::gemm::GemmCoord *all_problems, 32 | T **ptr_y, T **ptr_x, T **ptr_w, 33 | int64_t *ld_y, int64_t *ld_x, 34 | int64_t *ld_w, T *y, T *x, T **w, 35 | int32_t *s, int d_in, int d_out, 36 | int layer_idx) { 37 | int i = blockIdx.x; 38 | int m = s[i + 1] - s[i], k = d_in, n = d_out; 39 | all_problems[i] = cutlass::gemm::GemmCoord(m, n, k); 40 | ptr_w[i] = w[i] + layer_idx * d_in * d_out; 41 | ptr_x[i] = x + s[i] * d_in; 42 | ptr_y[i] = y + s[i] * d_out; 43 | ld_x[i] = k; 44 | ld_w[i] = n; 45 | ld_y[i] = n; 46 | } 47 | 48 | size_t sgmv_tmp_size(int num_problems) { 49 | constexpr auto sz = sizeof(void *) * 3 + sizeof(int64_t) * 3 + 50 | sizeof(cutlass::gemm::GemmCoord); 51 | return sz * num_problems; 52 | } 53 | 54 | template 55 | inline T *alloc_from_buf(void **buf, int n) { 56 | auto *p = (T *)*buf; 57 | *buf = (void *)(p + n); 58 | return p; 59 | } 60 | 61 | template 62 | bool sgmv(DType *y, DType *x, DType **w, int32_t *s, void *tmp_d, 63 | int num_problems, 
int d_in, int d_out, int layer_idx, 64 | cudaStream_t stream) { 65 | using cutlass_t = typename cutlass_dtype::type; 66 | 67 | auto ptr_Y = alloc_from_buf(&tmp_d, num_problems); 68 | auto ptr_X = alloc_from_buf(&tmp_d, num_problems); 69 | auto ptr_W = alloc_from_buf(&tmp_d, num_problems); 70 | auto ld_Y = alloc_from_buf(&tmp_d, num_problems); 71 | auto ld_X = alloc_from_buf(&tmp_d, num_problems); 72 | auto ld_W = alloc_from_buf(&tmp_d, num_problems); 73 | auto all_problems = 74 | alloc_from_buf(&tmp_d, num_problems); 75 | 76 | precompute_sgmv_args<<>>( 77 | all_problems, ptr_Y, ptr_X, ptr_W, ld_Y, ld_X, ld_W, (cutlass_t *)y, 78 | (cutlass_t *)x, (cutlass_t **)w, s, d_in, d_out, layer_idx); 79 | 80 | using cutlass::epilogue::thread::LinearCombination; 81 | using cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle; 82 | if (d_in < d_out) { 83 | // Expand 84 | using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< 85 | cutlass_t, // Element A 86 | cutlass::layout::RowMajor, // Layout A 87 | cutlass::ComplexTransform::kNone, // 88 | 8, // Granularity A 89 | cutlass_t, // Element B 90 | cutlass::layout::RowMajor, // Layout B 91 | cutlass::ComplexTransform::kNone, // 92 | 8, // Granularity B 93 | cutlass_t, // Element C&D 94 | cutlass::layout::RowMajor, // Layout C&D 95 | float, // Element Accumulator 96 | cutlass::arch::OpClassTensorOp, // Operator Class Tag 97 | cutlass::arch::Sm80, // Architecture 98 | cutlass::gemm::GemmShape<32, 128, 16>, // Thread Block Shape 99 | cutlass::gemm::GemmShape<32, 64, 16>, // Warp Shape 100 | cutlass::gemm::GemmShape<16, 8, 8>, // Instruction Shape 101 | LinearCombination, // Epilogue 102 | GemmIdentityThreadblockSwizzle<1>, // Swizzling Operator 103 | 2 // Stages 104 | >::GemmKernel; 105 | 106 | using EpilogueOutputOp = typename GemmKernel::Epilogue::OutputOp; 107 | typename EpilogueOutputOp::Params epilogue_op(1.0, 1.0); 108 | 109 | using GemmGrouped = cutlass::gemm::device::GemmGrouped; 110 | typename GemmGrouped::Arguments args(all_problems, num_problems, 512, 111 | epilogue_op, ptr_X, ptr_W, ptr_Y, 112 | ptr_Y, ld_X, ld_W, ld_Y, ld_Y); 113 | 114 | GemmGrouped gemm; 115 | auto status = gemm.initialize(args, nullptr, stream); 116 | if (status != cutlass::Status::kSuccess) { 117 | fprintf(stderr, "sgmv_cutlass gemm.initialize failed: %s\n", 118 | cutlassGetStatusString(status)); 119 | return false; 120 | } 121 | status = gemm.run(stream); 122 | if (status != cutlass::Status::kSuccess) { 123 | fprintf(stderr, "sgmv_cutlass gemm.run failed: %s\n", 124 | cutlassGetStatusString(status)); 125 | return false; 126 | } 127 | } else { 128 | // Shrink 129 | using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< 130 | cutlass_t, // Element A 131 | cutlass::layout::RowMajor, // Layout A 132 | cutlass::ComplexTransform::kNone, // 133 | 8, // Granularity A 134 | cutlass_t, // Element B 135 | cutlass::layout::RowMajor, // Layout B 136 | cutlass::ComplexTransform::kNone, // 137 | 8, // Granularity B 138 | cutlass_t, // Element C&D 139 | cutlass::layout::RowMajor, // Layout C&D 140 | float, // Element Accumulator 141 | cutlass::arch::OpClassTensorOp, // Operator Class Tag 142 | cutlass::arch::Sm80, // Architecture 143 | cutlass::gemm::GemmShape<16, 64, 64>, // Thread Block Shape 144 | cutlass::gemm::GemmShape<16, 16, 64>, // Warp Shape 145 | cutlass::gemm::GemmShape<16, 8, 16>, // Instruction Shape 146 | LinearCombination, // Epilogue 147 | GemmIdentityThreadblockSwizzle<2>, // Swizzling Operator 148 | 2 // Stages 149 | >::GemmKernel; 
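    // Shrink path (d_in >= d_out): the epilogue below uses alpha = beta = 1.0,
    // so each grouped-GEMM problem computes Y_i = X_i * W_i + Y_i, accumulating
    // into the caller's output rather than overwriting it. One GEMM problem is
    // issued per segment of `s`, with per-problem shapes and pointers filled in
    // on-device by precompute_sgmv_args.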
150 | 151 | using EpilogueOutputOp = typename GemmKernel::Epilogue::OutputOp; 152 | typename EpilogueOutputOp::Params epilogue_op(1.0, 1.0); 153 | 154 | using GemmGrouped = cutlass::gemm::device::GemmGrouped; 155 | typename GemmGrouped::Arguments args(all_problems, num_problems, 512, 156 | epilogue_op, ptr_X, ptr_W, ptr_Y, 157 | ptr_Y, ld_X, ld_W, ld_Y, ld_Y); 158 | 159 | GemmGrouped gemm; 160 | auto status = gemm.initialize(args, nullptr, stream); 161 | if (status != cutlass::Status::kSuccess) { 162 | fprintf(stderr, "sgmv_cutlass gemm.initialize failed: %s\n", 163 | cutlassGetStatusString(status)); 164 | return false; 165 | } 166 | status = gemm.run(stream); 167 | if (status != cutlass::Status::kSuccess) { 168 | fprintf(stderr, "sgmv_cutlass gemm.run failed: %s\n", 169 | cutlassGetStatusString(status)); 170 | return false; 171 | } 172 | } 173 | return true; 174 | } 175 | -------------------------------------------------------------------------------- /csrc/rms_norm/rms_norm_cutlass.cu: -------------------------------------------------------------------------------- 1 | // Adapted from cutlass 2 | // https://github.com/NVIDIA/cutlass/blob/7d8317a63e0a978a8dbb3c1fb7af4dbe4f286616/tools/util/include/cutlass/util/device_rmsnorm.h 3 | /****************************************************************************** 4 | * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights 5 | * reserved. SPDX-License-Identifier: BSD-3-Clause 6 | * 7 | * Redistribution and use in source and binary forms, with or without 8 | * modification, are permitted provided that the following conditions are met: 9 | * 10 | * 1. Redistributions of source code must retain the above copyright notice, 11 | * this list of conditions and the following disclaimer. 12 | * 13 | * 2. Redistributions in binary form must reproduce the above copyright notice, 14 | * this list of conditions and the following disclaimer in the documentation 15 | * and/or other materials provided with the distribution. 16 | * 17 | * 3. Neither the name of the copyright holder nor the names of its 18 | * contributors may be used to endorse or promote products derived from 19 | * this software without specific prior written permission. 20 | * 21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 25 | * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 31 | * POSSIBILITY OF SUCH DAMAGE. 
32 | * 33 | ******************************************************************************/ 34 | 35 | #include 36 | #include 37 | #include 38 | 39 | #include 40 | #include 41 | 42 | template 43 | __inline__ __device__ T warpReduceSum(T *val) { 44 | #pragma unroll 45 | for (int i = 0; i < NUM; i++) { 46 | #pragma unroll 47 | for (int mask = 16; mask > 0; mask >>= 1) 48 | val[i] += __shfl_xor_sync(0xffffffff, val[i], mask, 32); 49 | } 50 | return (T)(0.0f); 51 | } 52 | 53 | template 54 | __inline__ __device__ T blockReduceSum(T *val) { 55 | __shared__ T shared[NUM][33]; 56 | int lane = threadIdx.x & 0x1f; 57 | int wid = threadIdx.x >> 5; 58 | 59 | warpReduceSum(val); 60 | 61 | if (lane == 0) { 62 | #pragma unroll 63 | for (int i = 0; i < NUM; i++) { 64 | shared[i][wid] = val[i]; 65 | } 66 | } 67 | 68 | __syncthreads(); 69 | 70 | bool is_mask = threadIdx.x < (blockDim.x / 32.f); 71 | #pragma unroll 72 | for (int i = 0; i < NUM; i++) { 73 | val[i] = is_mask ? shared[i][lane] : (T)(0.0f); 74 | } 75 | warpReduceSum(val); 76 | return (T)0.0f; 77 | } 78 | 79 | template 80 | __global__ void rmsnorm_twoPassAlgo_e8(float4 *__restrict__ output, 81 | const float4 *__restrict__ input, 82 | const float4 *__restrict__ weight, int m, 83 | int n, float epsilon) { 84 | const int m_idx = blockIdx.x; 85 | const int tid = threadIdx.x; 86 | const int bdimx = blockDim.x; 87 | __shared__ float s_mean; 88 | float local_sums[1] = {0.0f}; 89 | const int n_8 = n / 8; 90 | int offset = m_idx * n_8; 91 | input += offset; 92 | output += offset; 93 | 94 | for (int index = tid; index < n_8; index += bdimx) { 95 | const float4 local_val = input[index]; 96 | const half2 *h1 = (half2 *)&local_val.x; 97 | const half2 *h2 = (half2 *)&local_val.y; 98 | const half2 *h3 = (half2 *)&local_val.z; 99 | const half2 *h4 = (half2 *)&local_val.w; 100 | local_sums[0] += static_cast(h1->x) * static_cast(h1->x) + 101 | static_cast(h1->y) * static_cast(h1->y) + 102 | static_cast(h2->x) * static_cast(h2->x) + 103 | static_cast(h2->y) * static_cast(h2->y) + 104 | static_cast(h3->x) * static_cast(h3->x) + 105 | static_cast(h3->y) * static_cast(h3->y) + 106 | static_cast(h4->x) * static_cast(h4->x) + 107 | static_cast(h4->y) * static_cast(h4->y); 108 | } 109 | 110 | blockReduceSum(local_sums); 111 | if (threadIdx.x == 0) { 112 | s_mean = rsqrtf(local_sums[0] / n + epsilon); 113 | } 114 | __syncthreads(); 115 | 116 | for (int index = tid; index < n_8; index += bdimx) { 117 | const float4 local_val = input[index]; 118 | const float4 weight_val = weight[index]; 119 | 120 | const half2 *l1 = (half2 *)&local_val.x; 121 | const half2 *l2 = (half2 *)&local_val.y; 122 | const half2 *l3 = (half2 *)&local_val.z; 123 | const half2 *l4 = (half2 *)&local_val.w; 124 | 125 | const half2 *g1 = (half2 *)&weight_val.x; 126 | const half2 *g2 = (half2 *)&weight_val.y; 127 | const half2 *g3 = (half2 *)&weight_val.z; 128 | const half2 *g4 = (half2 *)&weight_val.w; 129 | 130 | float4 tmp; 131 | half2 *h1 = (half2 *)&tmp.x; 132 | half2 *h2 = (half2 *)&tmp.y; 133 | half2 *h3 = (half2 *)&tmp.z; 134 | half2 *h4 = (half2 *)&tmp.w; 135 | 136 | h1->x = static_cast(static_cast(l1->x) * s_mean * 137 | static_cast(g1->x)); 138 | h1->y = static_cast(static_cast(l1->y) * s_mean * 139 | static_cast(g1->y)); 140 | h2->x = static_cast(static_cast(l2->x) * s_mean * 141 | static_cast(g2->x)); 142 | h2->y = static_cast(static_cast(l2->y) * s_mean * 143 | static_cast(g2->y)); 144 | h3->x = static_cast(static_cast(l3->x) * s_mean * 145 | static_cast(g3->x)); 146 | h3->y = 
static_cast(static_cast(l3->y) * s_mean * 147 | static_cast(g3->y)); 148 | h4->x = static_cast(static_cast(l4->x) * s_mean * 149 | static_cast(g4->x)); 150 | h4->y = static_cast(static_cast(l4->y) * s_mean * 151 | static_cast(g4->y)); 152 | 153 | output[index] = tmp; 154 | } 155 | } 156 | 157 | template 158 | bool rms_norm(T *__restrict__ output, const T *__restrict__ input, 159 | const T *__restrict__ weight, int rows, int columns, 160 | float epsilon) { 161 | if (columns % 8 != 0) { 162 | return false; 163 | } 164 | 165 | dim3 grid(rows); 166 | dim3 block(std::min(1024, (columns / 8 + 31) / 32 * 32)); 167 | 168 | if (std::is_same::value) { 169 | rmsnorm_twoPassAlgo_e8 170 | <<>>((float4 *)output, (float4 *)input, (float4 *)weight, 171 | rows, columns, epsilon); 172 | return true; 173 | } else if (std::is_same::value) { 174 | rmsnorm_twoPassAlgo_e8 175 | <<>>((float4 *)output, (float4 *)input, (float4 *)weight, 176 | rows, columns, epsilon); 177 | return true; 178 | } 179 | return false; 180 | } 181 | 182 | template bool rms_norm(nv_half *__restrict__ output, 183 | const nv_half *__restrict__ input, 184 | const nv_half *__restrict__ weight, int rows, 185 | int columns, float epsilon); 186 | template bool rms_norm(nv_bfloat16 *__restrict__ output, 187 | const nv_bfloat16 *__restrict__ input, 188 | const nv_bfloat16 *__restrict__ weight, int rows, 189 | int columns, float epsilon); 190 | -------------------------------------------------------------------------------- /benchmarks/nvbench/sgmv.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include "cutlass/util/device_memory.h" 9 | #include "nvbench/nvbench.cuh" 10 | #include "sgmv/sgmv_cutlass.cuh" 11 | 12 | template 13 | void sgmv_cpu_reference(T *y, T *x, T **w, int *s, int num_problems, int d_in, 14 | int d_out, int layer_idx) { 15 | for (int p = 0; p < num_problems; p++) { 16 | for (int i = s[p]; i < s[p + 1]; i++) { 17 | for (int j = 0; j < d_out; j++) { 18 | float accum = y[i * d_out + j]; 19 | for (int k = 0; k < d_in; k++) { 20 | accum += float(x[i * d_in + k]) * 21 | float(w[p][layer_idx * d_in * d_out + k * d_out + j]); 22 | } 23 | y[i * d_out + j] = accum; 24 | } 25 | } 26 | } 27 | } 28 | 29 | template 30 | bool isclose(T a, T b, float rtol = 1e-5, float atol = 1e-8) { 31 | float a_f32 = static_cast(a); 32 | float b_f32 = static_cast(b); 33 | return fabs(a_f32 - b_f32) <= (atol + rtol * fabs(b_f32)); 34 | } 35 | 36 | void bench_sgmv(nvbench::state &state) { 37 | using cutlass_t = typename cutlass_dtype::type; 38 | auto problem_size_s = state.get_string("problem_size"); 39 | int num_problems = state.get_int64("num_problems"); 40 | int d_in = state.get_int64("d_in"); 41 | int d_out = state.get_int64("d_out"); 42 | int num_loras = 100; 43 | int num_layers = 3; 44 | cudaStream_t stream = nullptr; 45 | std::mt19937 gen(0xabcdabcd987); 46 | std::normal_distribution dis; 47 | 48 | int problem_size; 49 | if (problem_size_s == "num_problems") { 50 | problem_size = num_problems; 51 | } else { 52 | try { 53 | problem_size = std::stoi(problem_size_s); 54 | } catch (...) 
{ 55 | state.skip("problem_size is not valid"); 56 | return; 57 | } 58 | } 59 | 60 | std::vector s(num_problems + 1); 61 | std::vector> w_all(num_loras); 62 | s[0] = 0; 63 | for (size_t b = 1; b < num_problems + 1; b++) { 64 | s[b] = s[b - 1] + problem_size; 65 | } 66 | int batch_size = s.back(); 67 | state.add_summary("batch_size").set_int64("value", batch_size); 68 | 69 | std::vector x(batch_size * d_in); 70 | std::vector y_init(batch_size * d_out); 71 | 72 | // random init x, w, y with normal distribution 73 | for (size_t i = 0; i < batch_size * d_in; i++) { 74 | x[i] = dis(gen); 75 | } 76 | for (size_t i = 0; i < num_loras; i++) { 77 | w_all[i].resize(num_layers * d_in * d_out); 78 | for (size_t j = 0; j < w_all[i].size(); j++) { 79 | w_all[i][j] = dis(gen); 80 | } 81 | } 82 | for (size_t i = 0; i < batch_size * d_out; i++) { 83 | y_init[i] = dis(gen); 84 | } 85 | 86 | // copy std vector x, w, y to thrust device vector 87 | thrust::device_vector x_d(x.begin(), x.end()); 88 | std::vector> w_all_d; 89 | for (size_t i = 0; i < num_loras; i++) { 90 | w_all_d.emplace_back(w_all[i].begin(), w_all[i].end()); 91 | } 92 | thrust::device_vector s_d(s.begin(), s.end()); 93 | 94 | // build w ptr 95 | std::vector w; 96 | for (int i = 0; i < num_problems; ++i) { 97 | w.push_back(w_all[i].data()); 98 | } 99 | std::vector w_gpu_ptr; 100 | for (int i = 0; i < num_problems; ++i) { 101 | w_gpu_ptr.push_back(thrust::raw_pointer_cast(w_all_d[i].data())); 102 | } 103 | thrust::device_vector w_d(w_gpu_ptr.begin(), w_gpu_ptr.end()); 104 | cutlass::DeviceAllocation tmp_d(sgmv_tmp_size(num_problems)); 105 | 106 | for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) { 107 | // call cpu_reference function 108 | std::vector y_cpu_ref(y_init.begin(), y_init.end()); 109 | sgmv_cpu_reference(y_cpu_ref.data(), x.data(), w.data(), s.data(), 110 | num_problems, d_in, d_out, layer_idx); 111 | 112 | // call sgmv function 113 | thrust::device_vector y_d(y_init.begin(), y_init.end()); 114 | sgmv(thrust::raw_pointer_cast(y_d.data()), 115 | thrust::raw_pointer_cast(x_d.data()), 116 | thrust::raw_pointer_cast(w_d.data()), 117 | thrust::raw_pointer_cast(s_d.data()), tmp_d.get(), num_problems, d_in, 118 | d_out, layer_idx, stream); 119 | 120 | // copy thrust device_vector y_d to std vector y_h 121 | thrust::host_vector y_h = y_d; 122 | 123 | // compare y_h and y_cpu_ref 124 | for (size_t i = 0; i < batch_size * d_out; i++) { 125 | if (!isclose(float(y_h[i]), float(y_cpu_ref[i]), 1e-3, 1e-3)) { 126 | state.skip("y_h and y_cpu_ref are not close"); 127 | printf("layer_idx=%i, i=%zu, ref=%f, our=%f, diff=%f\n", layer_idx, i, 128 | float(y_cpu_ref[i]), float(y_h[i]), 129 | float(y_h[i]) - float(y_cpu_ref[i])); 130 | return; 131 | } 132 | } 133 | } 134 | 135 | // nvbench sgmv kernel 136 | state.add_global_memory_reads( 137 | batch_size * d_in * sizeof(half) // x 138 | + num_problems * d_in * d_out * sizeof(half) // w 139 | + (num_problems + 1) * sizeof(int32_t) // s 140 | ); 141 | state.add_global_memory_writes(batch_size * d_out * sizeof(half)); 142 | 143 | thrust::device_vector y_d(y_init.begin(), y_init.end()); 144 | state.exec(nvbench::exec_tag::sync, [&](nvbench::launch &) { 145 | int layer_idx = 0; 146 | sgmv(thrust::raw_pointer_cast(y_d.data()), 147 | thrust::raw_pointer_cast(x_d.data()), 148 | (half **)thrust::raw_pointer_cast(w_d.data()), 149 | thrust::raw_pointer_cast(s_d.data()), tmp_d.get(), num_problems, d_in, 150 | d_out, layer_idx, stream); 151 | }); 152 | } 153 | 154 | int wide_dim = 4096; 155 | int narrow_dim 
= 16; 156 | 157 | std::vector num_problems = {1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 158 | 14, 16, 20, 24, 28, 32, 40, 48, 56, 64}; 159 | std::vector problem_size = { 160 | "1", "2", "3", "4", "5", "6", "7", "8", "10", "12", 161 | "14", "16", "20", "24", "28", "32", "40", "48", "56", "64"}; 162 | 163 | NVBENCH_BENCH(bench_sgmv) 164 | .set_name("sgmv_expand_NxN") 165 | .add_int64_axis("d_in", {narrow_dim}) 166 | .add_int64_axis("d_out", {wide_dim}) 167 | .add_string_axis("problem_size", {"num_problems"}) 168 | .add_int64_axis("num_problems", num_problems); 169 | 170 | NVBENCH_BENCH(bench_sgmv) 171 | .set_name("sgmv_expand_bgmv") 172 | .add_int64_axis("d_in", {narrow_dim}) 173 | .add_int64_axis("d_out", {wide_dim}) 174 | .add_string_axis("problem_size", {"1"}) 175 | .add_int64_axis("num_problems", num_problems); 176 | 177 | NVBENCH_BENCH(bench_sgmv) 178 | .set_name("sgmv_expand_bmm") 179 | .add_int64_axis("d_in", {narrow_dim}) 180 | .add_int64_axis("d_out", {wide_dim}) 181 | .add_string_axis("problem_size", problem_size) 182 | .add_int64_axis("num_problems", {1}); 183 | 184 | NVBENCH_BENCH(bench_sgmv) 185 | .set_name("sgmv_expand_fixed-num-problems") 186 | .add_int64_axis("d_in", {narrow_dim}) 187 | .add_int64_axis("d_out", {wide_dim}) 188 | .add_string_axis("problem_size", problem_size) 189 | .add_int64_axis("num_problems", {8}); 190 | 191 | NVBENCH_BENCH(bench_sgmv) 192 | .set_name("sgmv_shrink_NxN") 193 | .add_int64_axis("d_in", {wide_dim}) 194 | .add_int64_axis("d_out", {narrow_dim}) 195 | .add_string_axis("problem_size", {"num_problems"}) 196 | .add_int64_axis("num_problems", num_problems); 197 | 198 | NVBENCH_BENCH(bench_sgmv) 199 | .set_name("sgmv_shrink_bgmv") 200 | .add_int64_axis("d_in", {wide_dim}) 201 | .add_int64_axis("d_out", {narrow_dim}) 202 | .add_string_axis("problem_size", {"1"}) 203 | .add_int64_axis("num_problems", num_problems); 204 | 205 | NVBENCH_BENCH(bench_sgmv) 206 | .set_name("sgmv_shrink_bmm") 207 | .add_int64_axis("d_in", {wide_dim}) 208 | .add_int64_axis("d_out", {narrow_dim}) 209 | .add_string_axis("problem_size", problem_size) 210 | .add_int64_axis("num_problems", {1}); 211 | 212 | NVBENCH_BENCH(bench_sgmv) 213 | .set_name("sgmv_shrink_fixed-num-problems") 214 | .add_int64_axis("d_in", {wide_dim}) 215 | .add_int64_axis("d_out", {narrow_dim}) 216 | .add_string_axis("problem_size", problem_size) 217 | .add_int64_axis("num_problems", {8}); 218 | -------------------------------------------------------------------------------- /csrc/bgmv/bgmv_impl.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | #include "flashinfer/vec_dtypes.cuh" 10 | 11 | namespace cg = cooperative_groups; 12 | 13 | // nthrs = (32, 4) 14 | template 15 | __global__ void bgmv_shrink_kernel(T* __restrict__ Y, const T* __restrict__ X, 16 | const T* __restrict__ W, 17 | const int64_t* __restrict__ indicies, 18 | int64_t num_layers, int64_t layer_idx, 19 | float scale) { 20 | auto block = cg::this_thread_block(); 21 | size_t j = blockIdx.x; 22 | size_t batch_idx = blockIdx.y; 23 | constexpr size_t vec_size = 16 / sizeof(T); 24 | constexpr size_t tx = 32; 25 | constexpr size_t ty = 4; 26 | constexpr size_t num_pipeline_stages = 2; 27 | constexpr size_t tile_size = tx * ty * vec_size; 28 | __shared__ T W_shared[num_pipeline_stages * tile_size]; 29 | __shared__ T X_shared[num_pipeline_stages * tile_size]; 30 | __shared__ float y_warpwise[ty]; 31 | 32 | int64_t idx = indicies[batch_idx] * 
num_layers + layer_idx; 33 | 34 | size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size}; 35 | size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size}; 36 | auto pipe = cuda::make_pipeline(); 37 | 38 | // pipeline load W/X and compute WX; 39 | pipe.producer_acquire(); 40 | cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, 41 | W + (idx * feat_out + j) * feat_in + 42 | (threadIdx.y * tx + threadIdx.x) * vec_size, 43 | cuda::aligned_size_t<16>(16), pipe); 44 | cuda::memcpy_async( 45 | X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, 46 | X + (batch_idx * feat_in) + (threadIdx.y * tx + threadIdx.x) * vec_size, 47 | cuda::aligned_size_t<16>(16), pipe); 48 | pipe.producer_commit(); 49 | size_t copy_idx, compute_idx; 50 | float y = 0.f; 51 | flashinfer::vec_t x_vec, w_vec; 52 | size_t tile_idx; 53 | 54 | #pragma unroll 55 | for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size; 56 | ++tile_idx) { 57 | copy_idx = tile_idx % num_pipeline_stages; 58 | // pipeline stage: async copy W fragment 59 | pipe.producer_acquire(); 60 | if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) { 61 | cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] + 62 | (threadIdx.y * tx + threadIdx.x) * vec_size, 63 | W + (idx * feat_out + j) * feat_in + 64 | tile_idx * tile_size + 65 | (threadIdx.y * tx + threadIdx.x) * vec_size, 66 | cuda::aligned_size_t<16>(16), pipe); 67 | cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] + 68 | (threadIdx.y * tx + threadIdx.x) * vec_size, 69 | X + (batch_idx * feat_in) + tile_idx * tile_size + 70 | (threadIdx.y * tx + threadIdx.x) * vec_size, 71 | cuda::aligned_size_t<16>(16), pipe); 72 | } 73 | pipe.producer_commit(); 74 | 75 | compute_idx = (tile_idx - 1) % num_pipeline_stages; 76 | // pipeline stage: compute WX 77 | pipe.consumer_wait(); 78 | block.sync(); 79 | x_vec.load(X_shared + X_shared_offset[compute_idx] + 80 | (threadIdx.y * tx + threadIdx.x) * vec_size); 81 | w_vec.load(W_shared + W_shared_offset[compute_idx] + 82 | (threadIdx.y * tx + threadIdx.x) * vec_size); 83 | float sum = 0.f; 84 | #pragma unroll 85 | for (size_t i = 0; i < vec_size; ++i) { 86 | sum += float(w_vec[i]) * float(x_vec[i]) * scale; 87 | } 88 | #pragma unroll 89 | for (size_t offset = tx / 2; offset > 0; offset /= 2) { 90 | sum += __shfl_down_sync(0xffffffff, sum, offset); 91 | } 92 | y_warpwise[threadIdx.y] = sum; 93 | block.sync(); 94 | #pragma unroll 95 | for (size_t i = 0; i < ty; ++i) { 96 | y += y_warpwise[i]; 97 | } 98 | 99 | block.sync(); 100 | pipe.consumer_release(); 101 | } 102 | 103 | compute_idx = (tile_idx - 1) % num_pipeline_stages; 104 | // final pipeline stage 105 | pipe.consumer_wait(); 106 | block.sync(); 107 | x_vec.load(X_shared + X_shared_offset[compute_idx] + 108 | (threadIdx.y * tx + threadIdx.x) * vec_size); 109 | w_vec.load(W_shared + W_shared_offset[compute_idx] + 110 | (threadIdx.y * tx + threadIdx.x) * vec_size); 111 | float sum = 0.f; 112 | #pragma unroll 113 | for (size_t i = 0; i < vec_size; ++i) { 114 | sum += float(w_vec[i]) * float(x_vec[i]) * scale; 115 | } 116 | #pragma unroll 117 | for (size_t offset = tx / 2; offset > 0; offset /= 2) { 118 | sum += __shfl_down_sync(0xffffffff, sum, offset); 119 | } 120 | y_warpwise[threadIdx.y] = 121 | ((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in) 122 | ? 
sum 123 | : 0.f; 124 | block.sync(); 125 | #pragma unroll 126 | for (size_t i = 0; i < ty; ++i) { 127 | y += y_warpwise[i]; 128 | } 129 | 130 | block.sync(); 131 | pipe.consumer_release(); 132 | 133 | // write Y; 134 | if (block.thread_rank() == 0) { 135 | Y[batch_idx * feat_out + j] += y; 136 | } 137 | } 138 | 139 | // nthrs = (2, 16, 4) 140 | template 141 | __global__ void bgmv_expand_kernel(T* __restrict__ Y, const T* __restrict__ X, 142 | const T* __restrict__ W, 143 | const int64_t* __restrict__ indicies, 144 | int64_t num_layers, int64_t layer_idx, 145 | float scale) { 146 | auto block = cg::this_thread_block(); 147 | constexpr size_t vec_size = 16 / sizeof(T); 148 | constexpr size_t tx = feat_in / vec_size; 149 | static_assert(feat_in % vec_size == 0); 150 | constexpr size_t ty = 32 / tx; 151 | static_assert(32 % tx == 0); 152 | constexpr size_t tz = 4; 153 | size_t tile_idx = blockIdx.x; 154 | size_t batch_idx = blockIdx.y; 155 | int64_t idx = indicies[batch_idx] * num_layers + layer_idx; 156 | 157 | // load X; 158 | flashinfer::vec_t x_vec; 159 | x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size); 160 | 161 | // load W; 162 | flashinfer::vec_t w_vec; 163 | w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in + 164 | block.thread_rank() * vec_size); 165 | 166 | float sum = 0.f; 167 | #pragma unroll 168 | for (size_t i = 0; i < vec_size; ++i) { 169 | sum += float(w_vec[i]) * float(x_vec[i]) * scale; 170 | } 171 | 172 | cg::thread_block_tile g = cg::tiled_partition(block); 173 | #pragma unroll 174 | for (size_t offset = tx / 2; offset > 0; offset /= 2) { 175 | sum += g.shfl_down(sum, offset); 176 | } 177 | sum = g.shfl(sum, 0); 178 | 179 | if (threadIdx.x == 0) { 180 | Y[batch_idx * feat_out + tile_idx * (tz * ty) + threadIdx.z * ty + 181 | threadIdx.y] += sum; 182 | } 183 | } 184 | 185 | template 186 | void bgmv_kernel(T* __restrict__ Y, const T* __restrict__ X, 187 | const T* __restrict__ W, const int64_t* __restrict__ indicies, 188 | int64_t batch_size, int64_t num_layers, int64_t layer_idx, 189 | float scale) { 190 | size_t vec_size = 16 / sizeof(T); 191 | if constexpr (feat_in < feat_out) { 192 | size_t tx = feat_in / vec_size; 193 | size_t ty = 32 / tx; 194 | size_t tz = 4; 195 | dim3 nblks(feat_out / (ty * tz), batch_size); 196 | dim3 nthrs(tx, ty, tz); 197 | 198 | bgmv_expand_kernel 199 | <<>>(Y, X, W, indicies, num_layers, layer_idx, scale); 200 | } else { 201 | assert(feat_in % (vec_size * 32) == 0); 202 | dim3 nblks(feat_out, batch_size); 203 | dim3 nthrs(32, 4); 204 | bgmv_shrink_kernel 205 | <<>>(Y, X, W, indicies, num_layers, layer_idx, scale); 206 | } 207 | } 208 | 209 | #define INST_BGMV(feat_in, feat_out, T) \ 210 | template void bgmv_kernel( \ 211 | T* __restrict__ Y, const T* __restrict__ X, const T* __restrict__ W, \ 212 | const int64_t* __restrict__ indicies, int64_t batch_size, \ 213 | int64_t num_layers, int64_t layer_idx, float scale); 214 | 215 | #define INST_BGMV_TWOSIDE(T, narrow, wide) \ 216 | INST_BGMV(narrow, wide, T) \ 217 | INST_BGMV(wide, narrow, T) 218 | -------------------------------------------------------------------------------- /csrc/flashinfer_adapter/flashinfer_all.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "flashinfer/page.cuh" 6 | #include "flashinfer_config.h" 7 | #include "flashinfer_decl.h" 8 | #include "generated/dispatch.inc" 9 | 10 | using flashinfer::paged_kv_t; 11 | using flashinfer::PageStorage; 12 | using 
flashinfer::RotaryMode; 13 | 14 | #define _DISPATCH_SWITCH(cond, ...) \ 15 | [&]() -> bool { \ 16 | switch (cond) { \ 17 | __VA_ARGS__ \ 18 | default: \ 19 | return false; \ 20 | } \ 21 | }() 22 | 23 | #define _DISPATCH_CASE(case_expr, var, ...) \ 24 | case case_expr: { \ 25 | constexpr auto var = case_expr; \ 26 | return __VA_ARGS__(); \ 27 | } 28 | 29 | #define DISPATCH_group_size(expr, ...) \ 30 | _DISPATCH_SWITCH(expr, _DISPATCH_CASES_group_size(__VA_ARGS__)) 31 | 32 | #define DISPATCH_page_size(expr, ...) \ 33 | _DISPATCH_SWITCH(expr, _DISPATCH_CASES_page_size(__VA_ARGS__)) 34 | 35 | #define DISPATCH_head_dim(expr, ...) \ 36 | _DISPATCH_SWITCH(expr, _DISPATCH_CASES_head_dim(__VA_ARGS__)) 37 | 38 | namespace { 39 | template 40 | inline T* alloc_from_buf(void** buf, int n) { 41 | auto* p = (T*)*buf; 42 | *buf = (void*)(p + n); 43 | return p; 44 | } 45 | } // namespace 46 | 47 | template 48 | bool FlashInferBatchPrefillKernel(T* o, T* q, int32_t* qo_indptr, T** kv_ptrs, 49 | int32_t* kv_indptr, int32_t* last_page_offset, 50 | void* tmpbuf, int head_dim, int num_layers, 51 | int layer_idx, int group_size, 52 | int num_kv_heads, int page_size, 53 | int batch_size) { 54 | return DISPATCH_page_size(page_size, [&] { 55 | return DISPATCH_group_size(group_size, [&] { 56 | return DISPATCH_head_dim(head_dim, [&] { 57 | auto kv_aux = alloc_from_buf(&tmpbuf, 4 * (batch_size + 1)); 58 | paged_kv_t paged_kv( 59 | num_layers, layer_idx, num_kv_heads, page_size, head_dim, 60 | batch_size, kv_ptrs, kv_indptr, last_page_offset, kv_aux); 61 | int num_qo_heads = num_kv_heads * group_size; 62 | constexpr bool allow_fp16_qk_reduction = false; 63 | constexpr bool causal = true; 64 | constexpr auto rotary = RotaryMode::kLlama; 65 | float rope_scale = 1.f; 66 | float rope_theta = 1e4; 67 | cudaStream_t stream = nullptr; 68 | auto status = flashinfer::BatchPrefillWithPagedKVCacheDispatched< 69 | PAGE_SIZE, GROUP_SIZE, HEAD_DIM, PageStorage::kPointer, rotary, 70 | allow_fp16_qk_reduction, causal>(q, paged_kv, qo_indptr, o, 71 | (float*)tmpbuf, num_qo_heads, 72 | rope_scale, rope_theta, stream); 73 | if (status != cudaSuccess) { 74 | fprintf(stderr, "batch_prefill failed: %s\n", 75 | cudaGetErrorString(status)); 76 | } 77 | return true; 78 | }); 79 | }); 80 | }); 81 | } 82 | 83 | template 84 | bool FlashInferBatchDecodeKernel(T* o, T* q, T** kv_ptrs, int32_t* kv_indptr, 85 | int32_t* last_page_offset, void* tmpbuf, 86 | int head_dim, int num_layers, int layer_idx, 87 | int group_size, int num_kv_heads, 88 | int page_size, int batch_size) { 89 | return DISPATCH_page_size(page_size, [&] { 90 | return DISPATCH_group_size(group_size, [&] { 91 | return DISPATCH_head_dim(head_dim, [&] { 92 | auto kv_aux = alloc_from_buf(&tmpbuf, 4 * (batch_size + 1)); 93 | paged_kv_t paged_kv( 94 | num_layers, layer_idx, num_kv_heads, page_size, head_dim, 95 | batch_size, kv_ptrs, kv_indptr, last_page_offset, kv_aux); 96 | constexpr auto rotary = RotaryMode::kLlama; 97 | float rope_scale = 1.f; 98 | float rope_theta = 1e4; 99 | cudaStream_t stream = nullptr; 100 | auto status = flashinfer::BatchDecodeWithPagedKVCacheDispatched< 101 | PAGE_SIZE, GROUP_SIZE, HEAD_DIM, PageStorage::kPointer, rotary>( 102 | q, paged_kv, o, nullptr, rope_scale, rope_theta, stream); 103 | if (status != cudaSuccess) { 104 | fprintf(stderr, "batch_decode failed: %s\n", 105 | cudaGetErrorString(status)); 106 | } 107 | return true; 108 | }); 109 | }); 110 | }); 111 | } 112 | 113 | template 114 | void FlashInferInitKvKernel(T** kv_ptrs, int32_t* kv_indptr, 115 | 
int32_t* last_page_offset, T* key, T* value, 116 | int32_t* seqlen_indptr, int num_layers, 117 | int layer_idx, int num_kv_heads, int page_size, 118 | int batch_size) { 119 | paged_kv_t paged_kv( 120 | num_layers, layer_idx, num_kv_heads, page_size, head_dim, batch_size, 121 | kv_ptrs, kv_indptr, last_page_offset); 122 | 123 | constexpr size_t vec_size = 124 | std::max(16 / sizeof(T), static_cast(head_dim / 32)); 125 | constexpr size_t bdx = head_dim / vec_size; 126 | constexpr size_t bdy = 1; 127 | dim3 nblks(paged_kv.batch_size * paged_kv.num_heads / bdy); 128 | dim3 nthrs(bdx, bdy); 129 | flashinfer::AppendPagedKVCachePrefillKernel 131 | <<>>(paged_kv, key, value, seqlen_indptr); 132 | } 133 | 134 | template 135 | void FlashInferAppendKvKernel(T** kv_ptrs, int32_t* kv_indptr, 136 | int32_t* last_page_offset, T* key, T* value, 137 | int num_layers, int layer_idx, int num_kv_heads, 138 | int page_size, int batch_size) { 139 | paged_kv_t paged_kv( 140 | num_layers, layer_idx, num_kv_heads, page_size, head_dim, batch_size, 141 | kv_ptrs, kv_indptr, last_page_offset); 142 | 143 | constexpr size_t vec_size = 144 | std::max(16 / sizeof(T), static_cast(head_dim / 32)); 145 | constexpr size_t bdx = head_dim / vec_size; 146 | constexpr size_t bdy = 1; 147 | dim3 nblks(paged_kv.batch_size * paged_kv.num_heads / bdy); 148 | dim3 nthrs(bdx, bdy); 149 | flashinfer::AppendPagedKVCacheDecodeKernel 151 | <<>>(paged_kv, key, value); 152 | } 153 | 154 | #define INST_FlashInferBatchPrefillKernel(T) \ 155 | template bool FlashInferBatchPrefillKernel( \ 156 | T * o, T * q, int32_t * qo_indptr, T * *kv_ptrs, int32_t * kv_indptr, \ 157 | int32_t * last_page_offset, void* tmpbuf, int head_dim, int num_layers, \ 158 | int layer_idx, int group_size, int num_kv_heads, int page_size, \ 159 | int batch_size); 160 | INST_FlashInferBatchPrefillKernel(nv_half); 161 | INST_FlashInferBatchPrefillKernel(nv_bfloat16); 162 | 163 | #define INST_FlashInferBatchDecodeKernel(T) \ 164 | template bool FlashInferBatchDecodeKernel( \ 165 | T * o, T * q, T * *kv_ptrs, int32_t * kv_indptr, \ 166 | int32_t * last_page_offset, void* tmpbuf, int head_dim, int num_layers, \ 167 | int layer_idx, int group_size, int num_kv_heads, int page_size, \ 168 | int batch_size); 169 | INST_FlashInferBatchDecodeKernel(nv_half); 170 | INST_FlashInferBatchDecodeKernel(nv_bfloat16); 171 | 172 | #define INST_FlashInferInitKvKernel(head_dim, T) \ 173 | template void FlashInferInitKvKernel( \ 174 | T * *kv_ptrs, int32_t * kv_indptr, int32_t * last_page_offset, T * key, \ 175 | T * value, int32_t * seqlen_indptr, int num_layers, int layer_idx, \ 176 | int num_kv_heads, int page_size, int batch_size); 177 | FOR_FlashInferBatchDecode_D(INST_FlashInferInitKvKernel, nv_half); 178 | FOR_FlashInferBatchDecode_D(INST_FlashInferInitKvKernel, nv_bfloat16); 179 | 180 | #define INST_FlashInferAppendKvKernel(head_dim, T) \ 181 | template void FlashInferAppendKvKernel( \ 182 | T * *kv_ptrs, int32_t * kv_indptr, int32_t * last_page_offset, T * key, \ 183 | T * value, int num_layers, int layer_idx, int num_kv_heads, \ 184 | int page_size, int batch_size); 185 | FOR_FlashInferBatchDecode_D(INST_FlashInferAppendKvKernel, nv_half); 186 | FOR_FlashInferBatchDecode_D(INST_FlashInferAppendKvKernel, nv_bfloat16); 187 | -------------------------------------------------------------------------------- /src/punica/models/llama.py: -------------------------------------------------------------------------------- 1 | # Adapted from HuggingFace Transformers Library 2 | # 
https://github.com/huggingface/transformers/blob/17a55534f5e5df10ac4804d4270bf6b8cc24998d/src/transformers/models/llama/modeling_llama.py 3 | 4 | import math 5 | 6 | import torch 7 | from torch import nn 8 | from transformers.models.llama.modeling_llama import ( 9 | LlamaConfig, 10 | LlamaMLP, 11 | PreTrainedModel, 12 | ) 13 | 14 | from punica.ops import append_kv, batch_decode, batch_prefill, init_kv, rms_norm 15 | from punica.utils import BatchedKvCache, BatchLenInfo 16 | 17 | 18 | class LlamaAttention(nn.Module): 19 | def __init__(self, config: LlamaConfig, layer_idx: int): 20 | super().__init__() 21 | self.config = config 22 | self.hidden_size = config.hidden_size 23 | self.num_qo_heads = config.num_attention_heads 24 | self.num_kv_heads = config.num_key_value_heads 25 | self.num_kv_groups = self.num_qo_heads // self.num_kv_heads 26 | self.head_dim = self.hidden_size // self.num_qo_heads 27 | self._scale = 1 / math.sqrt(self.head_dim) 28 | self.layer_idx = layer_idx 29 | 30 | assert self.head_dim * self.num_qo_heads == self.hidden_size 31 | assert self.num_kv_heads * self.num_kv_groups == self.num_qo_heads 32 | self.q_proj = nn.Linear( 33 | self.hidden_size, self.num_qo_heads * self.head_dim, bias=False 34 | ) 35 | self.k_proj = nn.Linear( 36 | self.hidden_size, self.num_kv_heads * self.head_dim, bias=False 37 | ) 38 | self.v_proj = nn.Linear( 39 | self.hidden_size, self.num_kv_heads * self.head_dim, bias=False 40 | ) 41 | self.o_proj = nn.Linear( 42 | self.num_qo_heads * self.head_dim, self.hidden_size, bias=False 43 | ) 44 | 45 | def forward( 46 | self, 47 | hidden_states: torch.Tensor, 48 | blen: BatchLenInfo, 49 | prefill_kv: BatchedKvCache | None, 50 | decode_kv: BatchedKvCache | None, 51 | ) -> torch.Tensor: 52 | torch.cuda.nvtx.range_push("qkv_proj") 53 | q_proj = self.q_proj(hidden_states) 54 | k_proj = self.k_proj(hidden_states) 55 | v_proj = self.v_proj(hidden_states) 56 | torch.cuda.nvtx.range_pop() 57 | stack_attn_output = [] 58 | 59 | if len(blen.prefills) > 0: 60 | assert prefill_kv is not None 61 | assert blen.indptr is not None 62 | q = q_proj[: blen.doff].view(blen.doff, self.num_qo_heads, self.head_dim) 63 | k = k_proj[: blen.doff].view(blen.doff, self.num_kv_heads, self.head_dim) 64 | v = v_proj[: blen.doff].view(blen.doff, self.num_kv_heads, self.head_dim) 65 | 66 | torch.cuda.nvtx.range_push("init_kv") 67 | init_kv(prefill_kv, k, v, blen.indptr, self.layer_idx) 68 | torch.cuda.nvtx.range_pop() 69 | 70 | torch.cuda.nvtx.range_push("batch_prefill") 71 | attn_output = batch_prefill(q, blen.indptr, prefill_kv, self.layer_idx) 72 | attn_output = attn_output.view(blen.doff, self.hidden_size) 73 | stack_attn_output.append(attn_output) 74 | torch.cuda.nvtx.range_pop() 75 | 76 | if blen.decode > 0: 77 | q = q_proj[blen.doff :].view(blen.decode, self.num_qo_heads, self.head_dim) 78 | k = k_proj[blen.doff :].view(blen.decode, self.num_kv_heads, self.head_dim) 79 | v = v_proj[blen.doff :].view(blen.decode, self.num_kv_heads, self.head_dim) 80 | 81 | torch.cuda.nvtx.range_push("append_kv") 82 | assert decode_kv is not None 83 | append_kv(decode_kv, k, v, self.layer_idx) 84 | torch.cuda.nvtx.range_pop() 85 | 86 | torch.cuda.nvtx.range_push("batch_decode") 87 | attn_outputs = batch_decode(q, decode_kv, self.layer_idx) 88 | attn_outputs = attn_outputs.view(blen.decode, self.hidden_size) 89 | stack_attn_output.append(attn_outputs) 90 | torch.cuda.nvtx.range_pop() 91 | 92 | if len(stack_attn_output) == 1: 93 | attn_outputs = stack_attn_output[0] 94 | else: 95 | attn_outputs = 
torch.cat(stack_attn_output, dim=0) 96 | 97 | # output projection 98 | torch.cuda.nvtx.range_push("o_proj") 99 | attn_output = self.o_proj(attn_outputs) 100 | torch.cuda.nvtx.range_pop() 101 | 102 | return attn_output 103 | 104 | 105 | class LlamaRMSNorm(nn.Module): 106 | def __init__(self, hidden_size, eps=1e-6): 107 | super().__init__() 108 | self.weight = nn.Parameter(torch.ones(hidden_size)) 109 | self.variance_epsilon = eps 110 | 111 | def forward(self, hidden_states): 112 | return rms_norm(hidden_states, self.weight, self.variance_epsilon) 113 | 114 | 115 | class LlamaDecoderLayer(nn.Module): 116 | def __init__(self, config: LlamaConfig, layer_idx: int): 117 | super().__init__() 118 | self.hidden_size = config.hidden_size 119 | self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx) 120 | self.mlp = LlamaMLP(config) 121 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 122 | self.post_attention_layernorm = LlamaRMSNorm( 123 | config.hidden_size, eps=config.rms_norm_eps 124 | ) 125 | 126 | def forward( 127 | self, 128 | hidden_states: torch.Tensor, 129 | blen: BatchLenInfo, 130 | prefill_kv: BatchedKvCache | None, 131 | decode_kv: BatchedKvCache | None, 132 | ) -> torch.Tensor: 133 | residual = hidden_states 134 | 135 | torch.cuda.nvtx.range_push("input_norm") 136 | hidden_states = self.input_layernorm(hidden_states) 137 | torch.cuda.nvtx.range_pop() 138 | 139 | # Self Attention 140 | torch.cuda.nvtx.range_push("LlamaAttention") 141 | hidden_states = self.self_attn(hidden_states, blen, prefill_kv, decode_kv) 142 | torch.cuda.nvtx.range_pop() 143 | torch.cuda.nvtx.range_push("r") 144 | hidden_states = residual + hidden_states 145 | torch.cuda.nvtx.range_pop() 146 | 147 | # Fully Connected 148 | residual = hidden_states 149 | torch.cuda.nvtx.range_push("norm") 150 | hidden_states = self.post_attention_layernorm(hidden_states) 151 | torch.cuda.nvtx.range_pop() 152 | torch.cuda.nvtx.range_push("mlp") 153 | hidden_states = self.mlp(hidden_states) 154 | torch.cuda.nvtx.range_pop() 155 | torch.cuda.nvtx.range_push("r") 156 | hidden_states = residual + hidden_states 157 | torch.cuda.nvtx.range_pop() 158 | 159 | return hidden_states 160 | 161 | 162 | class LlamaPreTrainedModel(PreTrainedModel): 163 | config_class = LlamaConfig 164 | base_model_prefix = "model" 165 | supports_gradient_checkpointing = False 166 | _no_split_modules = ["LlamaDecoderLayer"] 167 | _keys_to_ignore_on_load_unexpected = [ 168 | r"decoder\.version", 169 | r"self_attn\.rotary_emb\.inv_freq", 170 | ] 171 | 172 | 173 | class LlamaModel(LlamaPreTrainedModel): 174 | def __init__(self, config: LlamaConfig): 175 | super().__init__(config) 176 | self.padding_idx = config.pad_token_id 177 | self.vocab_size = config.vocab_size 178 | self.embed_tokens = nn.Embedding( 179 | config.vocab_size, config.hidden_size, self.padding_idx 180 | ) 181 | self.layers = nn.ModuleList( 182 | [LlamaDecoderLayer(config, i) for i in range(config.num_hidden_layers)] 183 | ) 184 | self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 185 | self.post_init() 186 | 187 | def forward( 188 | self, 189 | input_ids: torch.Tensor, 190 | blen: BatchLenInfo, 191 | prefill_kv: BatchedKvCache | None, 192 | decode_kv: BatchedKvCache | None, 193 | ) -> torch.Tensor: 194 | torch.cuda.nvtx.range_push("embed") 195 | hidden_states = self.embed_tokens(input_ids) 196 | torch.cuda.nvtx.range_pop() 197 | 198 | for layer_idx, decoder_layer in enumerate(self.layers): 199 | 
torch.cuda.nvtx.range_push(f"layer={layer_idx}") 200 | hidden_states = decoder_layer(hidden_states, blen, prefill_kv, decode_kv) 201 | torch.cuda.nvtx.range_pop() 202 | 203 | torch.cuda.nvtx.range_push("lastnorm") 204 | hidden_states = self.norm(hidden_states) 205 | torch.cuda.nvtx.range_pop() 206 | 207 | return hidden_states 208 | 209 | 210 | class LlamaForCausalLM(LlamaPreTrainedModel): 211 | def __init__(self, config): 212 | super().__init__(config) 213 | self.model = LlamaModel(config) 214 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 215 | self.post_init() 216 | 217 | def forward( 218 | self, 219 | input_ids: torch.Tensor, 220 | blen: BatchLenInfo, 221 | prefill_kv: BatchedKvCache | None, 222 | decode_kv: BatchedKvCache | None, 223 | ) -> tuple[torch.Tensor, torch.Tensor]: 224 | torch.cuda.nvtx.range_push("LlamaForCausalLM") 225 | hidden_states = self.model(input_ids, blen, prefill_kv, decode_kv) 226 | torch.cuda.nvtx.range_push("lm_head") 227 | logits = self.lm_head(hidden_states) 228 | torch.cuda.nvtx.range_pop() 229 | torch.cuda.nvtx.range_pop() 230 | return logits, hidden_states 231 | -------------------------------------------------------------------------------- /tests/test_flashinfer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | 5 | from punica import BatchedKvCache, BatchLenInfo, KvCache, KvPool 6 | from punica.ops import append_kv, batch_decode, batch_prefill, init_kv 7 | 8 | num_layers = 3 9 | num_qo_heads = 32 10 | head_dim = 128 11 | batch_size = 7 12 | page_len = 16 13 | maxlen = 500 14 | device = torch.device("cuda:0") 15 | 16 | 17 | def assert_close(a, b): 18 | rtol, atol = { 19 | torch.float16: (1e-3, 5e-4), 20 | torch.bfloat16: (8e-3, 8e-3), 21 | }[a.dtype] 22 | torch.testing.assert_close(a, b, rtol=rtol, atol=atol) 23 | 24 | 25 | def assert_eq(a, b): 26 | torch.testing.assert_close(a, b, rtol=0, atol=0) 27 | 28 | 29 | def rotate_half(x): 30 | x1 = x[..., : x.shape[-1] // 2] 31 | x2 = x[..., x.shape[-1] // 2 :] 32 | return torch.cat((-x2, x1), dim=-1) 33 | 34 | 35 | def rotary_embed(q, beg): 36 | device = q.device 37 | dtype = q.dtype 38 | dim = q.size(-1) 39 | l = q.size(-2) if q.dim() == 3 else 1 40 | 41 | base = 1e4 42 | inv_freq = 1.0 / ( 43 | base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) 44 | ) 45 | t = torch.arange(beg, beg + l, device=device, dtype=dtype) 46 | freqs = torch.einsum("i,j->ij", t, inv_freq) 47 | emb = torch.cat((freqs, freqs), dim=-1) 48 | cos = emb.cos() 49 | sin = emb.sin() 50 | q_embed = (q * cos) + (rotate_half(q) * sin) 51 | return q_embed 52 | 53 | 54 | def repeat_kv(t: torch.Tensor, repeat: int) -> torch.Tensor: 55 | if repeat == 1: 56 | return t 57 | num_kv_heads, seqlen, head_dim = t.shape 58 | t = t[:, None, :, :].expand(num_kv_heads, repeat, seqlen, head_dim) 59 | t = t.reshape(num_kv_heads * repeat, seqlen, head_dim) 60 | return t 61 | 62 | 63 | def ref_batch_prefill( 64 | q: torch.Tensor, 65 | qo_indptr: torch.Tensor, 66 | cs: list[KvCache], 67 | layer_idx: int, 68 | ) -> torch.Tensor: 69 | _, n_qo, _ = q.shape 70 | _l, _2, n_kv, _p, d = cs[0].pool.page_meta.shape 71 | b = len(cs) 72 | assert (b + 1,) == qo_indptr.shape 73 | assert (qo_indptr[-1].item(), n_qo, d) == q.shape 74 | 75 | sm_scale = 1.0 / np.sqrt(d) 76 | out = [] 77 | for i in range(b): 78 | assert qo_indptr[i + 1] - qo_indptr[i] == cs[i].seqlen 79 | s = cs[i].seqlen 80 | 81 | mask = torch.zeros(s, s, 
dtype=torch.float32, device=q.device) 82 | mask.masked_fill_( 83 | torch.ones(s, s, device=q.device, dtype=torch.bool).tril().logical_not(), 84 | float("-inf"), 85 | ) 86 | 87 | kv_pages = torch.cat(list(cs[i].pages), dim=3)[layer_idx] 88 | ki = kv_pages[0, :, :s, :].contiguous().to(torch.float32) 89 | vi = kv_pages[1, :, :s, :].contiguous().to(torch.float32) 90 | qi = q[qo_indptr[i] : qo_indptr[i + 1]].to(torch.float32) 91 | 92 | qi = rotary_embed(qi.transpose(0, 1), 0).transpose(0, 1) 93 | ki = rotary_embed(ki, 0) 94 | 95 | ki = repeat_kv(ki, n_qo // n_kv) 96 | vi = repeat_kv(vi, n_qo // n_kv) 97 | 98 | pi = torch.einsum("qnd,nkd->nqk", qi, ki) * sm_scale 99 | pi += mask 100 | pi = torch.softmax(pi, dim=-1) 101 | oi = torch.einsum("nqs,nsd->qnd", pi, vi).to(q.dtype) 102 | out.append(oi) 103 | o = torch.cat(out, dim=0) 104 | return o 105 | 106 | 107 | def ref_batch_decode( 108 | q: torch.Tensor, 109 | cs: list[KvCache], 110 | layer_idx: int, 111 | ) -> torch.Tensor: 112 | b, n_qo, _ = q.shape 113 | _l, _2, n_kv, _p, d = cs[0].pool.page_meta.shape 114 | assert (b, n_qo, d) == q.shape 115 | 116 | sm_scale = 1.0 / np.sqrt(d) 117 | out = [] 118 | for i in range(b): 119 | seqlen = cs[i].seqlen 120 | kv_pages = torch.cat(list(cs[i].pages), dim=3)[layer_idx] 121 | ki = kv_pages[0, :, :seqlen, :].contiguous().to(torch.float32) 122 | vi = kv_pages[1, :, :seqlen, :].contiguous().to(torch.float32) 123 | qi = q[i].to(torch.float32) 124 | 125 | qi = rotary_embed(qi, seqlen - 1) 126 | ki = rotary_embed(ki, 0) 127 | 128 | ki = repeat_kv(ki, n_qo // n_kv) 129 | vi = repeat_kv(vi, n_qo // n_kv) 130 | 131 | pi = torch.einsum("nd,nsd->ns", qi, ki) * sm_scale 132 | pi = torch.softmax(pi, dim=-1) 133 | oi = torch.einsum("ns,nsd->nd", pi, vi).to(q.dtype) 134 | out.append(oi) 135 | o = torch.stack(out) 136 | return o 137 | 138 | 139 | @pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) 140 | @pytest.mark.parametrize("group_size", [1, 8]) 141 | @torch.inference_mode() 142 | def test_batch_prefill_correctness(dtype_str: str, group_size: int): 143 | torch.manual_seed(0xABCDABCD987) 144 | num_kv_heads = num_qo_heads // group_size 145 | assert num_kv_heads * group_size == num_qo_heads 146 | dtype = getattr(torch, dtype_str) 147 | 148 | pool = KvPool(num_layers, num_kv_heads, head_dim, page_len, dtype, device) 149 | seqlens = torch.randint(1, maxlen, (batch_size,), device="cpu").tolist() 150 | blen = BatchLenInfo(seqlens, 0, device) 151 | q = torch.randn(sum(seqlens), num_qo_heads, head_dim, dtype=dtype, device=device) 152 | cs = [KvCache(pool, l) for l in seqlens] 153 | kv = BatchedKvCache(cs) 154 | for page in pool.allocated_pages(): 155 | page.copy_(torch.rand_like(page)) 156 | 157 | assert blen.indptr is not None 158 | for layer_idx in range(num_layers): 159 | o_ref = ref_batch_prefill(q, blen.indptr, cs, layer_idx) 160 | o_our = batch_prefill(q, blen.indptr, kv, layer_idx) 161 | assert_close(o_ref, o_our) 162 | 163 | 164 | @pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) 165 | @pytest.mark.parametrize("group_size", [1, 8]) 166 | @torch.inference_mode() 167 | def test_batch_decode_correctness(dtype_str: str, group_size: int): 168 | torch.manual_seed(0xABCDABCD987) 169 | num_kv_heads = num_qo_heads // group_size 170 | assert num_kv_heads * group_size == num_qo_heads 171 | dtype = getattr(torch, dtype_str) 172 | 173 | pool = KvPool(num_layers, num_kv_heads, head_dim, page_len, dtype, device) 174 | seqlens = torch.randint(1, maxlen, (batch_size,), dtype=torch.int32, device="cpu") 175 | q = 
torch.randn(batch_size, num_qo_heads, head_dim, dtype=dtype, device=device) 176 | cs = [KvCache(pool, int(l.item())) for l in seqlens] 177 | kv = BatchedKvCache(cs) 178 | for page in pool.allocated_pages(): 179 | page.copy_(torch.randn_like(page)) 180 | 181 | for layer_idx in range(num_layers): 182 | o_ref = ref_batch_decode(q, cs, layer_idx) 183 | o_our = batch_decode(q, kv, layer_idx) 184 | assert_close(o_ref, o_our) 185 | 186 | 187 | @pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) 188 | @pytest.mark.parametrize("group_size", [1, 8]) 189 | @torch.inference_mode() 190 | def test_init_kv(dtype_str: str, group_size: int): 191 | torch.manual_seed(0xABCDABCD987) 192 | num_kv_heads = num_qo_heads // group_size 193 | assert num_kv_heads * group_size == num_qo_heads 194 | dtype = getattr(torch, dtype_str) 195 | 196 | pool = KvPool(num_layers, num_kv_heads, head_dim, page_len, dtype, device) 197 | seqlens = torch.randint(1, maxlen, (batch_size,), dtype=torch.int32, device="cpu") 198 | seqlens = seqlens.tolist() + [15, 16, 17, 31, 32, 33] 199 | total_len = sum(seqlens) 200 | cs = [KvCache(pool, l) for l in seqlens] 201 | kv = BatchedKvCache(cs) 202 | blen = BatchLenInfo(seqlens, 0, device) 203 | for layer_idx in range(num_layers): 204 | k = torch.randn(total_len, num_kv_heads, head_dim, dtype=dtype, device=device) 205 | v = torch.randn(total_len, num_kv_heads, head_dim, dtype=dtype, device=device) 206 | assert blen.indptr is not None 207 | init_kv(kv, k, v, blen.indptr, layer_idx) 208 | for i, kvcache in enumerate(cs): 209 | ki = k[blen.indptr[i] : blen.indptr[i + 1]] 210 | vi = v[blen.indptr[i] : blen.indptr[i + 1]] 211 | for j, page in enumerate(kvcache.pages): 212 | s1 = slice(j * page_len, (j + 1) * page_len) 213 | s2 = ( 214 | kvcache.pool.page_len 215 | if j + 1 < kvcache.num_pages 216 | else kv.last_page_offset[i].item() 217 | ) 218 | assert_eq(ki[s1], page[layer_idx, 0, :, :s2, :].transpose(0, 1)) 219 | assert_eq(vi[s1], page[layer_idx, 1, :, :s2, :].transpose(0, 1)) 220 | 221 | 222 | @pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) 223 | @pytest.mark.parametrize("group_size", [1, 8]) 224 | @torch.inference_mode() 225 | def test_append_kv(dtype_str: str, group_size: int): 226 | torch.manual_seed(0xABCDABCD987) 227 | num_kv_heads = num_qo_heads // group_size 228 | assert num_kv_heads * group_size == num_qo_heads 229 | dtype = getattr(torch, dtype_str) 230 | 231 | pool = KvPool(num_layers, num_kv_heads, head_dim, page_len, dtype, device) 232 | seqlens = torch.randint(1, maxlen, (batch_size,), dtype=torch.int32, device="cpu") 233 | seqlens = seqlens.tolist() + [15, 16, 17, 31, 32, 33] 234 | bs = len(seqlens) 235 | cs = [KvCache(pool, l) for l in seqlens] 236 | kv = BatchedKvCache(cs) 237 | for layer_idx in range(num_layers): 238 | k = torch.randn(bs, num_kv_heads, head_dim, dtype=dtype, device=device) 239 | v = torch.randn(bs, num_kv_heads, head_dim, dtype=dtype, device=device) 240 | append_kv(kv, k, v, layer_idx) 241 | for i, kvcache in enumerate(cs): 242 | offset = kv.last_page_offset[i].item() - 1 243 | assert_eq(k[i], kvcache.pages[-1][layer_idx, 0, :, offset, :]) 244 | assert_eq(v[i], kvcache.pages[-1][layer_idx, 1, :, offset, :]) 245 | -------------------------------------------------------------------------------- /benchmarks/nvbench/sgmv_flashinfer.cu: -------------------------------------------------------------------------------- 1 | #include <thrust/device_vector.h> 2 | #include <thrust/host_vector.h> 3 | 4 | #include <cassert> 5 | #include <cstdio> 6 | #include <random> 7 | #include <vector> 8 | 9 | #include "nvbench/nvbench.cuh" 10 |
#include "sgmv_flashinfer/sgmv_flashinfer.cuh" 11 | 12 | template 13 | void sgmv_cpu_reference(T* y, T* x, T** w, int* s, int num_problems, int d_in, int d_out, 14 | int layer_idx) { 15 | for (int p = 0; p < num_problems; p++) { 16 | for (int i = s[p]; i < s[p + 1]; i++) { 17 | for (int j = 0; j < d_out; j++) { 18 | float accum = y[i * d_out + j]; 19 | for (int k = 0; k < d_in; k++) { 20 | accum += float(x[i * d_in + k]) * float(w[p][layer_idx * d_in * d_out + k * d_out + j]); 21 | } 22 | y[i * d_out + j] = accum; 23 | } 24 | } 25 | } 26 | } 27 | 28 | template 29 | std::vector transpose(const std::vector& x, uint32_t num_layers, uint32_t M, uint32_t N) { 30 | std::vector y(x.size()); 31 | assert(x.size() == num_layers * M * N); 32 | for (uint32_t l = 0; l < num_layers; l++) { 33 | for (uint32_t i = 0; i < M; i++) { 34 | for (uint32_t j = 0; j < N; j++) { 35 | y[l * M * N + j * M + i] = x[l * M * N + i * N + j]; 36 | } 37 | } 38 | } 39 | return std::move(y); 40 | } 41 | 42 | template 43 | bool isclose(T a, T b, float rtol = 1e-5, float atol = 1e-8) { 44 | float a_f32 = static_cast(a); 45 | float b_f32 = static_cast(b); 46 | return fabs(a_f32 - b_f32) <= (atol + rtol * fabs(b_f32)); 47 | } 48 | 49 | uint32_t pad_to_multiple_of_16(uint32_t x) { return (x + 15) & ~15; } 50 | 51 | void bench_sgmv(nvbench::state& state) { 52 | auto problem_size_s = state.get_string("problem_size"); 53 | int num_problems = state.get_int64("num_problems"); 54 | int d_in = state.get_int64("d_in"); 55 | int d_out = state.get_int64("d_out"); 56 | int num_loras = 100; 57 | int num_layers = 3; 58 | std::mt19937 gen(0xabcdabcd987); 59 | std::normal_distribution dis; 60 | 61 | int problem_size; 62 | if (problem_size_s == "num_problems") { 63 | problem_size = num_problems; 64 | } else { 65 | try { 66 | problem_size = std::stoi(problem_size_s); 67 | } catch (...) 
{ 68 | state.skip("problem_size is not valid"); 69 | return; 70 | } 71 | } 72 | 73 | std::vector<int32_t> s(num_problems + 1); 74 | std::vector<std::vector<half>> w_all(num_loras); 75 | s[0] = 0; 76 | for (size_t b = 1; b < num_problems + 1; b++) { 77 | s[b] = s[b - 1] + problem_size; 78 | } 79 | int batch_size = s.back(); 80 | state.add_summary("batch_size").set_int64("value", batch_size); 81 | 82 | std::vector<half> x(batch_size * d_in); 83 | std::vector<half> y_init(batch_size * d_out); 84 | 85 | // random init x, w, y with normal distribution 86 | for (size_t i = 0; i < batch_size * d_in; i++) { 87 | x[i] = dis(gen); 88 | } 89 | for (size_t i = 0; i < num_loras; i++) { 90 | w_all[i].resize(num_layers * d_in * d_out); 91 | for (size_t j = 0; j < w_all[i].size(); j++) { 92 | w_all[i][j] = dis(gen); 93 | } 94 | } 95 | for (size_t i = 0; i < batch_size * d_out; i++) { 96 | y_init[i] = dis(gen); 97 | } 98 | 99 | // copy std vector x, w, y to thrust device vector 100 | thrust::device_vector<half> x_d(x.begin(), x.end()); 101 | std::vector<thrust::device_vector<half>> w_all_d; 102 | for (size_t i = 0; i < num_loras; i++) { 103 | std::vector<half> w_all_i_trans = std::move(transpose(w_all[i], num_layers, d_in, d_out)); 104 | w_all_d.emplace_back(w_all_i_trans.begin(), w_all_i_trans.end()); 105 | } 106 | thrust::device_vector<int32_t> s_d(s.begin(), s.end()); 107 | 108 | // build w ptr 109 | std::vector<half*> w; 110 | for (int i = 0; i < num_problems; ++i) { 111 | w.push_back(w_all[i].data()); 112 | } 113 | std::vector<half*> w_gpu_ptr; 114 | for (int i = 0; i < num_problems; ++i) { 115 | w_gpu_ptr.push_back(thrust::raw_pointer_cast(w_all_d[i].data())); 116 | } 117 | thrust::device_vector<half*> w_d(w_gpu_ptr.begin(), w_gpu_ptr.end()); 118 | 119 | // tmp_ptr 120 | thrust::device_vector<float> tmp_d(2 * 1024 * 1024); 121 | 122 | constexpr uint32_t num_warps = 4; 123 | constexpr uint32_t D_OUT = 16; 124 | dim3 nthrs(32, num_warps); 125 | constexpr uint32_t num_stages = 2; 126 | constexpr uint32_t num_k_frags_per_stage = 8; 127 | const uint32_t num_blocks_n = d_out / 16; 128 | uint32_t smem = 129 | num_stages * sizeof(half) * num_k_frags_per_stage * 16 * 16 * (num_warps + num_blocks_n); 130 | cudaStream_t stream = nullptr; 131 | auto cooperative_kernel = flashinfer::sgmv::sgmv_shrink<true, half, int32_t, num_warps, D_OUT>; 132 | auto kernel = flashinfer::sgmv::sgmv_shrink<false, half, int32_t, num_warps, D_OUT>; 133 | 134 | uint32_t dev_id = 0; 135 | int num_blocks_per_sm = 0; 136 | int num_sm = 0; 137 | bool use_cooperative = true; 138 | cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id); 139 | cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, cooperative_kernel, 140 | num_warps * 32, smem); 141 | 142 | const uint32_t max_grid_size = num_sm * num_blocks_per_sm; 143 | 144 | uint32_t chunk_size = 256; 145 | uint32_t num_chunks = (d_in + chunk_size - 1) / chunk_size; 146 | if (num_chunks * num_problems > max_grid_size) { 147 | use_cooperative = false; 148 | chunk_size = d_in; 149 | num_chunks = 1; 150 | } 151 | 152 | dim3 nblks(num_chunks, num_problems); 153 | 154 | for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) { 155 | // call cpu_reference function 156 | std::vector<half> y_cpu_ref(y_init.begin(), y_init.end()); 157 | sgmv_cpu_reference(y_cpu_ref.data(), x.data(), w.data(), s.data(), num_problems, d_in, d_out, 158 | layer_idx); 159 | 160 | // call sgmv function 161 | thrust::device_vector<half> y_d(y_init.begin(), y_init.end()); 162 | 163 | half* y_ptr = thrust::raw_pointer_cast(y_d.data()); 164 | half* x_ptr = thrust::raw_pointer_cast(x_d.data()); 165 | half** w_ptr = thrust::raw_pointer_cast(w_d.data()); 166 | int* s_ptr =
thrust::raw_pointer_cast(s_d.data()); 167 | float* tmp_ptr = thrust::raw_pointer_cast(tmp_d.data()); 168 | 169 | void* args[] = {(void*)&y_ptr, (void*)&x_ptr, (void*)&w_ptr, 170 | (void*)&s_ptr, (void*)&tmp_ptr, (void*)&num_problems, 171 | (void*)&d_in, (void*)&layer_idx, (void*)&chunk_size}; 172 | 173 | cudaError_t status = cudaSuccess; 174 | if (use_cooperative) { 175 | status = 176 | cudaLaunchCooperativeKernel((void*)cooperative_kernel, nblks, nthrs, args, smem, stream); 177 | } else { 178 | status = cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem, stream); 179 | } 180 | if (status != cudaSuccess) { 181 | state.skip("sgmv_shrink kernel failed"); 182 | return; 183 | } 184 | 185 | // copy thrust device_vector y_d to std vector y_h 186 | thrust::host_vector<half> y_h = y_d; 187 | 188 | // compare y_h and y_cpu_ref 189 | for (size_t i = 0; i < batch_size * d_out; i++) { 190 | if (!isclose(float(y_h[i]), float(y_cpu_ref[i]), 1e-3, 1e-3)) { 191 | state.skip("y_h and y_cpu_ref are not close"); 192 | printf("layer_idx=%i, i=%zu, ref=%f, our=%f, diff=%f\n", layer_idx, i, float(y_cpu_ref[i]), 193 | float(y_h[i]), float(y_h[i]) - float(y_cpu_ref[i])); 194 | return; 195 | } 196 | } 197 | } 198 | state.add_global_memory_reads<int8_t>(batch_size * d_in * sizeof(half) // x 199 | + num_problems * d_in * d_out * sizeof(half) // w 200 | + (num_problems + 1) * sizeof(int32_t) // s 201 | ); 202 | state.add_global_memory_writes<int8_t>(batch_size * d_out * sizeof(half)); 203 | 204 | thrust::device_vector<half> y_d(y_init.begin(), y_init.end()); 205 | state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) { 206 | half* y_ptr = thrust::raw_pointer_cast(y_d.data()); 207 | half* x_ptr = thrust::raw_pointer_cast(x_d.data()); 208 | half** w_ptr = thrust::raw_pointer_cast(w_d.data()); 209 | int* s_ptr = thrust::raw_pointer_cast(s_d.data()); 210 | float* tmp_ptr = thrust::raw_pointer_cast(tmp_d.data()); 211 | int layer_idx = 0; 212 | void* args[] = {(void*)&y_ptr, (void*)&x_ptr, (void*)&w_ptr, 213 | (void*)&s_ptr, (void*)&tmp_ptr, (void*)&num_problems, 214 | (void*)&d_in, (void*)&layer_idx, (void*)&chunk_size}; 215 | cudaError_t status = cudaSuccess; 216 | if (use_cooperative) { 217 | status = 218 | cudaLaunchCooperativeKernel((void*)cooperative_kernel, nblks, nthrs, args, smem, stream); 219 | } else { 220 | status = cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem, stream); 221 | } 222 | if (status != cudaSuccess) { 223 | state.skip("sgmv_shrink kernel failed"); 224 | return; 225 | } 226 | }); 227 | } 228 | 229 | int wide_dim = 4096; 230 | int narrow_dim = 16; 231 | 232 | std::vector<nvbench::int64_t> num_problems = {1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 233 | 14, 16, 20, 24, 28, 32, 40, 48, 56, 64}; 234 | std::vector<std::string> problem_size = {"1", "2", "3", "4", "5", "6", "7", 235 | "8", "10", "12", "14", "16", "20", "24", 236 | "28", "32", "40", "48", "56", "64"}; 237 | 238 | NVBENCH_BENCH(bench_sgmv) 239 | .set_name("sgmv_shrink_NxN") 240 | .add_int64_axis("d_in", {wide_dim}) 241 | .add_int64_axis("d_out", {narrow_dim}) 242 | .add_string_axis("problem_size", {"num_problems"}) 243 | .add_int64_axis("num_problems", num_problems); 244 | 245 | NVBENCH_BENCH(bench_sgmv) 246 | .set_name("sgmv_shrink_bgmv") 247 | .add_int64_axis("d_in", {wide_dim}) 248 | .add_int64_axis("d_out", {narrow_dim}) 249 | .add_string_axis("problem_size", {"1"}) 250 | .add_int64_axis("num_problems", num_problems); 251 | 252 | NVBENCH_BENCH(bench_sgmv) 253 | .set_name("sgmv_shrink_bmm") 254 | .add_int64_axis("d_in", {wide_dim}) 255 | .add_int64_axis("d_out", {narrow_dim}) 256 |
.add_string_axis("problem_size", problem_size) 257 | .add_int64_axis("num_problems", {1}); 258 | 259 | NVBENCH_BENCH(bench_sgmv) 260 | .set_name("sgmv_shrink_fixed-num-problems") 261 | .add_int64_axis("d_in", {wide_dim}) 262 | .add_int64_axis("d_out", {narrow_dim}) 263 | .add_string_axis("problem_size", problem_size) 264 | .add_int64_axis("num_problems", {8}); 265 | --------------------------------------------------------------------------------