├── .devcontainer └── devcontainer.json ├── .dockerignore ├── .github ├── ISSUE_TEMPLATE │ ├── bug-report.yml │ ├── config.yml │ ├── feature-request.yml │ └── new-model-addition.yml ├── PULL_REQUEST_TEMPLATE.md └── workflows │ ├── build.yaml │ ├── client-tests.yaml │ ├── docs.yaml │ ├── integration-tests │ └── action.yaml │ ├── load_test.yaml │ ├── release_charts.yaml │ ├── router_tests.yaml │ ├── run-tests.yaml │ └── server_tests.yaml ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── .vscode └── settings.json ├── Cargo.lock ├── Cargo.toml ├── Dockerfile ├── Dockerfile.dev ├── LICENSE ├── Makefile ├── README.md ├── assets ├── architecture.jpg └── benchmark.png ├── build.sh ├── charts └── lorax │ ├── .gitignore │ ├── .helmignore │ ├── Chart.yaml │ ├── templates │ ├── _helpers.tpl │ ├── deployment.yaml │ └── service.yaml │ └── values.yaml ├── clients └── python │ ├── .gitignore │ ├── Makefile │ ├── README.md │ ├── lorax │ ├── __init__.py │ ├── client.py │ ├── errors.py │ └── types.py │ ├── poetry.lock │ ├── pyproject.toml │ └── tests │ ├── conftest.py │ ├── test_errors.py │ └── test_types.py ├── container-entrypoint.sh ├── docs ├── CNAME ├── LoRAX_Main_Logo-Orange.png ├── LoRAX_Main_Logo-White.png ├── favicon-16x16.png ├── favicon-32x32.png ├── favicon.ico ├── getting_started │ ├── docker.md │ ├── kubernetes.md │ ├── local.md │ └── skypilot.md ├── guides │ ├── contributing │ │ ├── development_env.md │ │ └── index.md │ ├── cuda_graphs.md │ ├── merging_adapters.md │ ├── quantization.md │ ├── speculative_decoding.md │ └── structured_output.md ├── http_status_codes │ └── http_status.md ├── index.md ├── models │ ├── adapters │ │ ├── index.md │ │ ├── lora.md │ │ └── medusa.md │ └── base_models.md ├── reference │ ├── launcher.md │ ├── metrics.md │ ├── openai_api.md │ ├── openapi.json │ ├── python_client │ │ ├── client.md │ │ └── index.md │ └── rest_api.md └── requirements.txt ├── integration-tests ├── __init__.py ├── pytest.ini ├── requirements.txt ├── test_base_llms.py ├── test_classifications.py ├── test_embeddings.py └── utils │ ├── __init__.py │ └── docker_runner.py ├── launcher ├── Cargo.toml ├── build.rs └── src │ ├── env_runtime.rs │ └── main.rs ├── load_tests └── starcoder_load.js ├── mkdocs.yml ├── proto └── generate.proto ├── router ├── Cargo.toml ├── README.md ├── build.rs ├── client │ ├── Cargo.toml │ ├── build.rs │ └── src │ │ ├── client.rs │ │ ├── lib.rs │ │ ├── pb │ │ └── .gitignore │ │ └── sharded_client.rs ├── grpc-metadata │ ├── Cargo.toml │ └── src │ │ └── lib.rs └── src │ ├── adapter.rs │ ├── batch.rs │ ├── block_allocator.rs │ ├── config.rs │ ├── health.rs │ ├── infer.rs │ ├── lib.rs │ ├── loader.rs │ ├── main.rs │ ├── queue.rs │ ├── radix.rs │ ├── scheduler.rs │ ├── server.rs │ ├── tool_grammar.rs │ └── validation.rs ├── rust-toolchain.toml ├── sagemaker-entrypoint.sh ├── server ├── .gitignore ├── Makefile ├── Makefile-awq ├── Makefile-eetq ├── Makefile-flash-att ├── Makefile-flash-att-v2 ├── Makefile-megablocks ├── Makefile-vllm ├── README.md ├── custom_kernels │ ├── custom_kernels │ │ ├── fused_attention_cuda.cu │ │ └── fused_bloom_attention_cuda.cu │ └── setup.py ├── exllama_kernels │ ├── exllama_kernels │ │ ├── cu_compat.cuh │ │ ├── cuda_buffers.cu │ │ ├── cuda_buffers.cuh │ │ ├── cuda_func │ │ │ ├── column_remap.cu │ │ │ ├── column_remap.cuh │ │ │ ├── q4_matmul.cu │ │ │ ├── q4_matmul.cuh │ │ │ ├── q4_matrix.cu │ │ │ └── q4_matrix.cuh │ │ ├── exllama_ext.cpp │ │ ├── hip_compat.cuh │ │ ├── matrix.cuh │ │ ├── tuning.h │ │ └── util.cuh │ └── setup.py ├── 
exllamav2_kernels │ ├── exllamav2_kernels │ │ ├── config.h │ │ ├── cpp │ │ │ └── util.h │ │ ├── cuda │ │ │ ├── compat.cuh │ │ │ ├── matrix_view.cuh │ │ │ ├── q_gemm.cu │ │ │ ├── q_gemm.cuh │ │ │ ├── q_gemm_kernel.cuh │ │ │ ├── q_gemm_kernel_gptq.cuh │ │ │ ├── q_matrix.cu │ │ │ ├── q_matrix.cuh │ │ │ ├── quant │ │ │ │ ├── qdq_2.cuh │ │ │ │ ├── qdq_3.cuh │ │ │ │ ├── qdq_4.cuh │ │ │ │ ├── qdq_5.cuh │ │ │ │ ├── qdq_6.cuh │ │ │ │ ├── qdq_8.cuh │ │ │ │ └── qdq_util.cuh │ │ │ └── util.cuh │ │ └── ext.cpp │ └── setup.py ├── lorax_server │ ├── __init__.py │ ├── adapters │ │ ├── __init__.py │ │ ├── config.py │ │ ├── lora.py │ │ ├── medusa.py │ │ ├── medusa_lora.py │ │ ├── types.py │ │ ├── utils.py │ │ └── weights.py │ ├── cache.py │ ├── cli.py │ ├── interceptor.py │ ├── layers │ │ ├── __init__.py │ │ ├── awq │ │ │ ├── conversion_utils.py │ │ │ └── quantize │ │ │ │ └── qmodule.py │ │ ├── bnb.py │ │ ├── conv.py │ │ ├── eetq.py │ │ ├── fp8.py │ │ ├── gptq │ │ │ ├── __init__.py │ │ │ ├── custom_autotune.py │ │ │ ├── exllama.py │ │ │ ├── exllamav2.py │ │ │ └── quant_linear.py │ │ ├── hqq.py │ │ ├── layernorm.py │ │ ├── linear.py │ │ ├── rotary.py │ │ └── tensor_parallel.py │ ├── models │ │ ├── __init__.py │ │ ├── bloom.py │ │ ├── causal_lm.py │ │ ├── custom_modeling │ │ │ ├── __init__.py │ │ │ ├── bloom_modeling.py │ │ │ ├── clip.py │ │ │ ├── flash_bert_modeling.py │ │ │ ├── flash_cohere_modeling.py │ │ │ ├── flash_dbrx_modeling.py │ │ │ ├── flash_gemma2_modeling.py │ │ │ ├── flash_gemma_modeling.py │ │ │ ├── flash_gpt2_modeling.py │ │ │ ├── flash_granite_modeling.py │ │ │ ├── flash_llama_modeling.py │ │ │ ├── flash_mistral_modeling.py │ │ │ ├── flash_mixtral_modeling.py │ │ │ ├── flash_neox_modeling.py │ │ │ ├── flash_phi3_modeling.py │ │ │ ├── flash_phi_modeling.py │ │ │ ├── flash_qwen2_modeling.py │ │ │ ├── flash_qwen_modeling.py │ │ │ ├── flash_roberta_modeling.py │ │ │ ├── flash_rw_modeling.py │ │ │ ├── flash_santacoder_modeling.py │ │ │ ├── flash_solar_modeling.py │ │ │ ├── llava_next.py │ │ │ ├── mllama.py │ │ │ ├── mpt_modeling.py │ │ │ ├── neox_modeling.py │ │ │ ├── opt_modeling.py │ │ │ ├── siglip.py │ │ │ ├── t5_modeling.py │ │ │ ├── utils.py │ │ │ └── vlm.py │ │ ├── flash_bert.py │ │ ├── flash_causal_lm.py │ │ ├── flash_cohere.py │ │ ├── flash_dbrx.py │ │ ├── flash_distilbert.py │ │ ├── flash_gemma.py │ │ ├── flash_gemma2.py │ │ ├── flash_gpt2.py │ │ ├── flash_granite.py │ │ ├── flash_llama.py │ │ ├── flash_mistral.py │ │ ├── flash_mixtral.py │ │ ├── flash_neox.py │ │ ├── flash_phi.py │ │ ├── flash_phi3.py │ │ ├── flash_qwen.py │ │ ├── flash_qwen2.py │ │ ├── flash_roberta.py │ │ ├── flash_rw.py │ │ ├── flash_santacoder.py │ │ ├── flash_solar.py │ │ ├── galactica.py │ │ ├── gpt_neox.py │ │ ├── metadata_kernels.py │ │ ├── mllama.py │ │ ├── model.py │ │ ├── mpt.py │ │ ├── opt.py │ │ ├── rw.py │ │ ├── santacoder.py │ │ ├── seq2seq_lm.py │ │ ├── t5.py │ │ ├── types.py │ │ └── vlm_causal_lm.py │ ├── pb │ │ └── .gitignore │ ├── server.py │ ├── tracing.py │ └── utils │ │ ├── __init__.py │ │ ├── adapter.py │ │ ├── attention │ │ ├── __init__.py │ │ └── common.py │ │ ├── awq │ │ └── awq.py │ │ ├── convert.py │ │ ├── dist.py │ │ ├── errors.py │ │ ├── flash_attn.py │ │ ├── flash_attn_triton.py │ │ ├── flashinfer_attention.py │ │ ├── gptq │ │ ├── custom_autotune.py │ │ ├── exllamav2.py │ │ ├── quant_linear.py │ │ └── quantize.py │ │ ├── graph.py │ │ ├── import_utils.py │ │ ├── layers.py │ │ ├── logits_process.py │ │ ├── lora.py │ │ ├── merges │ │ ├── __init__.py │ │ ├── strategies.py │ │ └── utils.py │ │ 
├── ops │ │ ├── __init__.py │ │ ├── bgmv_expand.py │ │ ├── bgmv_expand_slice.py │ │ ├── bgmv_shrink.py │ │ ├── libentry.py │ │ ├── sgmv_expand.py │ │ ├── sgmv_expand_slice.py │ │ ├── sgmv_shrink.py │ │ └── utils.py │ │ ├── paged_attention.py │ │ ├── punica.py │ │ ├── segments.py │ │ ├── sources │ │ ├── __init__.py │ │ ├── hub.py │ │ ├── local.py │ │ ├── s3.py │ │ └── source.py │ │ ├── state.py │ │ ├── tokenizer.py │ │ ├── tokens.py │ │ ├── torch_utils.py │ │ ├── watermark.py │ │ └── weights.py ├── poetry.lock ├── punica_kernels │ ├── README.md │ ├── punica_kernels │ │ ├── bgmv │ │ │ ├── bgmv_all.cu │ │ │ ├── bgmv_config.h │ │ │ └── bgmv_impl.cuh │ │ ├── flashinfer_adapter │ │ │ ├── flashinfer_all.cu │ │ │ ├── flashinfer_config.h │ │ │ ├── flashinfer_decl.h │ │ │ └── generated │ │ │ │ ├── batch_decode_p16_g1_h128_bf16.cu │ │ │ │ ├── batch_decode_p16_g1_h128_fp16.cu │ │ │ │ ├── batch_decode_p16_g2_h128_bf16.cu │ │ │ │ ├── batch_decode_p16_g2_h128_fp16.cu │ │ │ │ ├── batch_decode_p16_g4_h128_bf16.cu │ │ │ │ ├── batch_decode_p16_g4_h128_fp16.cu │ │ │ │ ├── batch_decode_p16_g8_h128_bf16.cu │ │ │ │ ├── batch_decode_p16_g8_h128_fp16.cu │ │ │ │ ├── batch_prefill_p16_g1_h128_bf16.cu │ │ │ │ ├── batch_prefill_p16_g1_h128_fp16.cu │ │ │ │ ├── batch_prefill_p16_g2_h128_bf16.cu │ │ │ │ ├── batch_prefill_p16_g2_h128_fp16.cu │ │ │ │ ├── batch_prefill_p16_g4_h128_bf16.cu │ │ │ │ ├── batch_prefill_p16_g4_h128_fp16.cu │ │ │ │ ├── batch_prefill_p16_g8_h128_bf16.cu │ │ │ │ ├── batch_prefill_p16_g8_h128_fp16.cu │ │ │ │ └── dispatch.inc │ │ ├── punica_ops.cc │ │ ├── rms_norm │ │ │ ├── rms_norm.h │ │ │ └── rms_norm_cutlass.cu │ │ ├── sgmv │ │ │ ├── sgmv.h │ │ │ ├── sgmv_cutlass.cu │ │ │ └── sgmv_cutlass.cuh │ │ └── sgmv_flashinfer │ │ │ ├── sgmv_all.cu │ │ │ ├── sgmv_config.h │ │ │ └── sgmv_flashinfer.cuh │ └── setup.py ├── pyproject.toml ├── requirements.txt └── tests │ ├── adapters │ ├── test_medusa.py │ └── test_utils.py │ ├── conftest.py │ ├── models │ ├── test_bloom.py │ ├── test_causal_lm.py │ ├── test_model.py │ ├── test_santacoder.py │ └── test_seq2seq_lm.py │ └── utils │ ├── test_convert.py │ ├── test_hub.py │ ├── test_logits_process.py │ ├── test_lora.py │ ├── test_s3.py │ ├── test_segments.py │ ├── test_sgmv.py │ ├── test_tokens.py │ ├── test_watermark.py │ └── test_weights.py ├── sync.sh └── tests ├── create-pod.sh └── test.py /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "build": { 3 | "dockerfile": "../Dockerfile.dev" 4 | }, 5 | "runArgs": [ 6 | "--gpus", 7 | "all" 8 | ] 9 | } -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | aml 2 | target 3 | server/transformers 4 | server/flash-attention 5 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.yml: -------------------------------------------------------------------------------- 1 | name: "\U0001F41B Bug Report" 2 | description: Submit a bug report to help us improve LoRAX 3 | body: 4 | - type: textarea 5 | id: system-info 6 | attributes: 7 | label: System Info 8 | description: | 9 | Please share your system info with us (`lorax-launcher --env` if installed locally). 
10 | The full command line used that causes issues: 11 | OS version: 12 | Rust version (if self-compiling, `cargo version`): 13 | Model being used (`curl 127.0.0.1:8080/info | jq`): 14 | If local model please explicit the kind of model and/or equivalents. 15 | Hardware used (GPUs, how many, on which cloud) (`nvidia-smi`): 16 | Deployment specificities (Kubernetes, EKS, AKS, any particular deployments): 17 | The current version being used: 18 | 19 | placeholder: lorax version, platform, python version, ... 20 | validations: 21 | required: true 22 | 23 | - type: checkboxes 24 | id: information-scripts-examples 25 | attributes: 26 | label: Information 27 | description: 'The problem arises when using:' 28 | options: 29 | - label: "Docker" 30 | - label: "The CLI directly" 31 | 32 | - type: checkboxes 33 | id: information-tasks 34 | attributes: 35 | label: Tasks 36 | description: "The thing I am working on is:" 37 | options: 38 | - label: "An officially supported command" 39 | - label: "My own modifications" 40 | 41 | - type: textarea 42 | id: reproduction 43 | validations: 44 | required: true 45 | attributes: 46 | label: Reproduction 47 | description: | 48 | Please provide a code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet. 49 | If you have code snippets, error messages, stack traces please provide them here as well. 50 | Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting 51 | Do not use screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code. 52 | 53 | placeholder: | 54 | Steps to reproduce the behavior: 55 | 56 | 1. 57 | 2. 58 | 3. 59 | 60 | 61 | - type: textarea 62 | id: expected-behavior 63 | validations: 64 | required: true 65 | attributes: 66 | label: Expected behavior 67 | description: "A clear and concise description of what you would expect to happen." 68 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: true 2 | version: 2.1 3 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.yml: -------------------------------------------------------------------------------- 1 | name: "\U0001F680 Feature request" 2 | description: Submit a proposal/request for a new LoRAX feature 3 | labels: [ "feature" ] 4 | body: 5 | - type: textarea 6 | id: feature-request 7 | validations: 8 | required: true 9 | attributes: 10 | label: Feature request 11 | description: | 12 | A clear and concise description of the feature proposal. Please provide a link to the paper and code in case they exist. 13 | 14 | - type: textarea 15 | id: motivation 16 | validations: 17 | required: true 18 | attributes: 19 | label: Motivation 20 | description: | 21 | Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too. 22 | 23 | 24 | - type: textarea 25 | id: contribution 26 | validations: 27 | required: true 28 | attributes: 29 | label: Your contribution 30 | description: | 31 | Is there any way that you could help, e.g. by submitting a PR? 
Make sure to read the CONTRIBUTING.MD [readme](https://github.com/predibase/lorax/blob/main/CONTRIBUTING.md) 32 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/new-model-addition.yml: -------------------------------------------------------------------------------- 1 | name: "\U0001F31F New model addition" 2 | description: Submit a proposal/request to implement a new model 3 | labels: [ "New model" ] 4 | 5 | body: 6 | - type: textarea 7 | id: description-request 8 | validations: 9 | required: true 10 | attributes: 11 | label: Model description 12 | description: | 13 | Put any and all important information relative to the model 14 | 15 | - type: checkboxes 16 | id: information-tasks 17 | attributes: 18 | label: Open source status 19 | description: | 20 | Please note that if the model implementation isn't available or if the weights aren't open-source, we are less likely to implement it in `transformers`. 21 | options: 22 | - label: "The model implementation is available" 23 | - label: "The model weights are available" 24 | 25 | - type: textarea 26 | id: additional-info 27 | attributes: 28 | label: Provide useful links for the implementation 29 | description: | 30 | Please provide information regarding the implementation, the weights, and the authors. 31 | Please mention the authors by @gh-username if you're aware of their usernames. 32 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | # What does this PR do? 2 | 3 | 12 | 13 | 14 | 15 | Fixes # (issue) 16 | 17 | 18 | ## Before submitting 19 | - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). 20 | - [ ] Was this discussed/approved via a Github issue or the discord / slack channel? Please add a link 21 | to it if that's the case. 22 | - [ ] Did you write any new necessary tests? 23 | 24 | 25 | ## Who can review? 26 | 27 | Anyone in the community is free to review the PR once the tests have passed. Feel free to tag 28 | members/contributors who may be interested in your PR. 29 | 30 | 36 | -------------------------------------------------------------------------------- /.github/workflows/client-tests.yaml: -------------------------------------------------------------------------------- 1 | name: Python Client Tests 2 | 3 | on: 4 | pull_request: 5 | paths: 6 | - ".github/workflows/client-tests.yaml" 7 | - "clients/python/**" 8 | 9 | jobs: 10 | run_tests: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Set up Python 16 | uses: actions/setup-python@v1 17 | with: 18 | python-version: 3.9 19 | - name: Install 20 | run: | 21 | cd clients/python && pip install . 
22 | - name: Run tests 23 | run: | 24 | pip install pytest pytest-asyncio 25 | make python-client-tests 26 | -------------------------------------------------------------------------------- /.github/workflows/docs.yaml: -------------------------------------------------------------------------------- 1 | name: docs 2 | on: 3 | push: 4 | branches: 5 | - main 6 | permissions: 7 | contents: write 8 | jobs: 9 | deploy: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v4 13 | - name: Configure Git Credentials 14 | run: | 15 | git config user.name github-actions[bot] 16 | git config user.email 41898282+github-actions[bot]@users.noreply.github.com 17 | - uses: actions/setup-python@v4 18 | with: 19 | python-version: 3.x 20 | - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV 21 | - uses: actions/cache@v3 22 | with: 23 | key: mkdocs-material-${{ env.cache_id }} 24 | path: .cache 25 | restore-keys: | 26 | mkdocs-material- 27 | - run: pip install -r docs/requirements.txt 28 | - run: mkdocs gh-deploy --force -------------------------------------------------------------------------------- /.github/workflows/integration-tests/action.yaml: -------------------------------------------------------------------------------- 1 | name: LoRAX Integration Tests 2 | description: Run integration tests for LoRAX 3 | inputs: 4 | test_image_tag: 5 | description: 'LoRAX Docker image tag to test' 6 | required: true 7 | type: string 8 | use_local_image: 9 | description: 'Use local image instead of GitHub Container Registry' 10 | required: true 11 | type: boolean 12 | github_token: 13 | description: 'GitHub token for authentication' 14 | required: true 15 | type: string 16 | huggingface_token: 17 | description: 'HuggingFace Hub token for authentication' 18 | required: true 19 | type: string 20 | 21 | runs: 22 | using: 'composite' 23 | steps: 24 | 25 | - name: Print Test Image Tag 26 | shell: bash 27 | run: | 28 | echo "Test image tag: ${{ inputs.test_image_tag }}" 29 | 30 | - name: Docker debug info 31 | shell: bash 32 | run: | 33 | docker info 34 | docker version 35 | ls -la /var/run/docker.sock 36 | 37 | - name: Set up Python 38 | uses: actions/setup-python@v4 39 | with: 40 | python-version: '3.8' 41 | 42 | - name: Install dependencies 43 | shell: bash 44 | run: | 45 | python -m pip install --upgrade pip 46 | pip install -r ./integration-tests/requirements.txt 47 | 48 | - name: Login to GitHub Container Registry 49 | if: ${{ !inputs.use_local_image }} 50 | uses: docker/login-action@v2 51 | with: 52 | registry: ghcr.io 53 | username: ${{ github.repository_owner }} 54 | password: ${{ inputs.github_token }} 55 | 56 | - name: Set up environment variables 57 | shell: bash 58 | run: | 59 | echo "HUGGING_FACE_HUB_TOKEN=${{ inputs.huggingface_token }}" >> $GITHUB_ENV 60 | first_tag=$(echo "${{ inputs.test_image_tag }}" | head -n 1) 61 | echo "TEST_IMAGE_TAG=$first_tag" >> $GITHUB_ENV 62 | 63 | - name: Run Embedding tests 64 | shell: bash 65 | run: | 66 | cd integration-tests 67 | pytest test_embeddings.py -vv --capture=tee-sys --log-cli-level=INFO 68 | 69 | - name: Run Classification tests 70 | shell: bash 71 | run: | 72 | cd integration-tests 73 | pytest test_classifications.py -vv --capture=tee-sys --log-cli-level=INFO 74 | 75 | - name: Run LLM tests 76 | shell: bash 77 | run: | 78 | cd integration-tests 79 | pytest test_base_llms.py -vv --capture=tee-sys --log-cli-level=INFO 80 | -------------------------------------------------------------------------------- /.github/workflows/release_charts.yaml: 
-------------------------------------------------------------------------------- 1 | name: Release Charts 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | release: 10 | permissions: 11 | contents: write # to push chart release and create a release (helm/chart-releaser-action) 12 | 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Checkout Code 16 | uses: actions/checkout@v3 17 | with: 18 | fetch-depth: 0 19 | 20 | - name: Configure Git 21 | run: | 22 | git config user.name "$GITHUB_ACTOR" 23 | git config user.email "$GITHUB_ACTOR@users.noreply.github.com" 24 | - name: Set up Helm 25 | uses: azure/setup-helm@v3.5 26 | with: 27 | version: v3.9.2 28 | 29 | - name: Run chart-releaser 30 | uses: helm/chart-releaser-action@v1.6.0 31 | with: 32 | charts_dir: charts/ 33 | skip_existing: true 34 | env: 35 | CR_TOKEN: "${{ secrets.HELM_RELEASE_TOKEN }}" 36 | -------------------------------------------------------------------------------- /.github/workflows/router_tests.yaml: -------------------------------------------------------------------------------- 1 | name: Router Tests 2 | 3 | on: 4 | pull_request: 5 | paths: 6 | - ".github/workflows/router_tests.yaml" 7 | - "proto/**" 8 | - "router/**" 9 | - "launcher/**" 10 | - "Cargo.lock" 11 | - "rust-toolchain.toml" 12 | 13 | concurrency: 14 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 15 | cancel-in-progress: true 16 | 17 | jobs: 18 | run_tests: 19 | runs-on: ubuntu-latest 20 | 21 | env: 22 | SCCACHE_GHA_ENABLED: "on" 23 | RUSTC_WRAPPER: /usr/local/bin/sccache 24 | RUST_BACKTRACE: 1 25 | SCCACHE: 0.3.3 26 | 27 | steps: 28 | - uses: actions/checkout@v2 29 | - name: Install Rust 30 | uses: actions-rs/toolchain@v1 31 | with: 32 | toolchain: 1.83.0 33 | override: true 34 | components: rustfmt, clippy 35 | - name: Install Protoc 36 | uses: arduino/setup-protoc@v1 37 | - name: Install sccache 38 | run: | 39 | curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache 40 | chmod +x /usr/local/bin/sccache 41 | - name: configure sccache 42 | uses: actions/github-script@v6 43 | with: 44 | script: | 45 | core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); 46 | core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); 47 | core.exportVariable('SCCACHE_GHA_CACHE_TO', 'sccache-${{runner.os}}-${{github.ref_name}}'); 48 | core.exportVariable('SCCACHE_GHA_CACHE_FROM', 'sccache-${{runner.os}}-main,sccache-${{runner.os}}-'); 49 | - name: cargo registry cache 50 | uses: actions/cache@v3 51 | with: 52 | key: cargo-${{ runner.os }}-${{ hashFiles('**/Cargo.toml') }}-${{ github.sha }} 53 | restore-keys: | 54 | cargo-${{ runner.os }}-${{ hashFiles('**/Cargo.toml') }}- 55 | cargo-${{ runner.os }}- 56 | path: | 57 | ~/.cargo/registry 58 | ~/.cargo/git 59 | - name: Install 60 | run: | 61 | make install-router install-launcher 62 | - name: Run Rust fmt 63 | run: | 64 | cargo fmt --check 65 | # - name: Run Rust clippy 66 | # run: | 67 | # cargo clippy 68 | - name: Run Rust tests 69 | run: | 70 | cargo test 71 | - name: sccache stats 72 | run: | 73 | /usr/local/bin/sccache --show-stats 74 | -------------------------------------------------------------------------------- /.github/workflows/run-tests.yaml: -------------------------------------------------------------------------------- 1 | name: Run Integration Tests 2 
| 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | test_image_tag: 7 | description: 'LoRAX Docker image tag to test' 8 | required: true 9 | type: string 10 | use_local_image: 11 | description: 'Use local image instead of GitHub Container Registry' 12 | required: false 13 | type: boolean 14 | default: false 15 | huggingface_token: 16 | description: 'HuggingFace Hub token for authentication' 17 | required: true 18 | type: string 19 | 20 | jobs: 21 | integration-tests: 22 | runs-on: ubuntu-latest 23 | steps: 24 | - name: Checkout repository 25 | uses: actions/checkout@v4 26 | 27 | - name: Run Integration Tests 28 | uses: ./.github/workflows/integration-tests 29 | with: 30 | test_image_tag: ${{ inputs.test_image_tag }} 31 | use_local_image: false 32 | github_token: ${{ secrets.GHCR_PAT }} 33 | huggingface_token: ${{ inputs.huggingface_token }} 34 | -------------------------------------------------------------------------------- /.github/workflows/server_tests.yaml: -------------------------------------------------------------------------------- 1 | name: Server Tests 2 | 3 | on: 4 | workflow_dispatch: 5 | pull_request: 6 | paths: 7 | - ".github/workflows/server_tests.yaml" 8 | - "server/**" 9 | - "proto/**" 10 | 11 | concurrency: 12 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 13 | cancel-in-progress: true 14 | 15 | jobs: 16 | run_tests: 17 | runs-on: ubuntu-latest 18 | 19 | steps: 20 | - name: Adding actions/checkout@v2 21 | uses: actions/checkout@v2 22 | with: 23 | fetch-depth: 0 24 | 25 | - name: Set up Python 26 | uses: actions/setup-python@v1 27 | with: 28 | python-version: "3.10" 29 | 30 | - name: Run ruff 31 | run: | 32 | pip install ruff 33 | python -m ruff check server/lorax_server 34 | 35 | - name: Install Protoc 36 | uses: arduino/setup-protoc@v1 37 | - name: Filter test dependencies 38 | run: | 39 | # remove stanford-stk from test requirements as it cannot install correctly without GPUs 40 | sed -i '/stanford-stk/d' server/requirements.txt 41 | sed -i '/stanford-stk/d' server/pyproject.toml 42 | - name: Install 43 | run: | 44 | make install-server install-custom-kernels 45 | - name: Run server tests 46 | run: | 47 | pip install pytest 48 | pip install outlines 49 | export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} 50 | pytest -s -vv server/tests 51 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | target 3 | router/tokenizer.json 4 | *__pycache__* 5 | run.sh 6 | data/ 7 | .vscode/* 8 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "server/punica_kernels/third_party/cutlass"] 2 | path = server/punica_kernels/third_party/cutlass 3 | url = https://github.com/NVIDIA/cutlass.git 4 | [submodule "server/punica_kernels/third_party/flashinfer"] 5 | path = server/punica_kernels/third_party/flashinfer 6 | url = https://github.com/tgaddair/flashinfer.git 7 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v3.4.0 # Use the latest revision 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: end-of-file-fixer 7 | - id: check-yaml 8 | - repo: 
https://github.com/astral-sh/ruff-pre-commit 9 | # Ruff version. 10 | rev: v0.8.3 11 | hooks: 12 | # Run the linter. 13 | - id: ruff 14 | args: [--fix] 15 | # Run the formatter. 16 | - id: ruff-format 17 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "editor.rulers": [ 3 | 120 4 | ], 5 | "editor.formatOnSave": true, 6 | "[python]": { 7 | "editor.defaultFormatter": "ms-python.black-formatter", 8 | "editor.formatOnSave": true 9 | }, 10 | "black-formatter.args": [ 11 | "--line-length", 12 | "120" 13 | ], 14 | "flake8.args": [ 15 | "--max-line-length=120", 16 | "--ignore=E203,W503" 17 | ], 18 | "python.testing.unittestEnabled": false, 19 | "python.testing.pytestEnabled": false, 20 | "python.envFile": "${workspaceFolder}/.env", 21 | "files.associations": { 22 | "__hash_table": "cpp", 23 | "__split_buffer": "cpp", 24 | "__tree": "cpp", 25 | "array": "cpp", 26 | "bitset": "cpp", 27 | "deque": "cpp", 28 | "initializer_list": "cpp", 29 | "iterator": "cpp", 30 | "list": "cpp", 31 | "map": "cpp", 32 | "queue": "cpp", 33 | "random": "cpp", 34 | "stack": "cpp", 35 | "string": "cpp", 36 | "string_view": "cpp", 37 | "unordered_map": "cpp", 38 | "utility": "cpp", 39 | "vector": "cpp", 40 | "filesystem": "cpp", 41 | "fstream": "cpp", 42 | "istream": "cpp", 43 | "locale": "cpp", 44 | "streambuf": "cpp", 45 | "*.tcc": "cpp", 46 | "compare": "cpp", 47 | "cstdlib": "cpp", 48 | "numeric": "cpp", 49 | "tuple": "cpp", 50 | "type_traits": "cpp", 51 | "atomic": "cpp", 52 | "__locale": "cpp", 53 | "ios": "cpp" 54 | } 55 | } -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = ["router", "router/client", "router/grpc-metadata", "launcher"] 3 | 4 | [workspace.package] 5 | version = "0.1.0" 6 | edition = "2021" 7 | authors = ["Predibase", "Olivier Dehaene"] 8 | homepage = "https://github.com/predibase/lorax" 9 | 10 | 11 | [profile.release] 12 | debug = 1 13 | incremental = true 14 | lto = "off" 15 | panic = "abort" 16 | -------------------------------------------------------------------------------- /Dockerfile.dev: -------------------------------------------------------------------------------- 1 | # LoRAX base image 2 | FROM ghcr.io/predibase/lorax:latest as base 3 | 4 | # Install server 5 | COPY proto proto 6 | COPY server server 7 | COPY server/Makefile server/Makefile 8 | 9 | # Final image 10 | FROM base 11 | 12 | COPY container-entrypoint.sh entrypoint.sh 13 | RUN chmod +x entrypoint.sh 14 | COPY sync.sh sync.sh 15 | RUN chmod +x sync.sh 16 | 17 | # ENTRYPOINT ["./entrypoint.sh"] 18 | ENTRYPOINT ["lorax-launcher"] 19 | CMD ["--json-output"] 20 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | install-server: 2 | cd server && make install 3 | 4 | install-custom-kernels: 5 | if [ "$$BUILD_EXTENSIONS" = "True" ]; then cd server/custom_kernels && python setup.py install; else echo "Custom kernels are disabled, you need to set the BUILD_EXTENSIONS environment variable to 'True' in order to build them. 
(Please read the docs, kernels might not work on all hardware)"; fi 6 | 7 | install-integration-tests: 8 | cd integration-tests && pip install -r requirements.txt 9 | cd clients/python && pip install . 10 | 11 | install-router: 12 | cd router && RUSTFLAGS="-D warnings" cargo install --path . 13 | 14 | install-launcher: 15 | cd launcher && cargo install --path . 16 | 17 | install-benchmark: 18 | cd benchmark && cargo install --path . 19 | 20 | install: install-server install-router install-launcher install-custom-kernels 21 | 22 | server-dev: 23 | cd server && make run-dev 24 | 25 | router-dev: 26 | cd router && cargo run -- --port 8080 27 | 28 | rust-tests: install-router install-launcher 29 | cargo test 30 | 31 | integration-tests: install-integration-tests 32 | pytest -s -vv -m "not private" integration-tests 33 | 34 | update-integration-tests: install-integration-tests 35 | pytest -s -vv --snapshot-update integration-tests 36 | 37 | python-server-tests: 38 | HF_HUB_ENABLE_HF_TRANSFER=1 pytest -s -vv -m "not private" server/tests 39 | 40 | python-client-tests: 41 | pytest clients/python/tests 42 | 43 | python-tests: python-server-tests python-client-tests 44 | 45 | run-mistral-7b-instruct: 46 | lorax-launcher --model-id mistralai/Mistral-7B-Instruct-v0.1 --port 8080 47 | 48 | clean: 49 | rm -rf target aml 50 | -------------------------------------------------------------------------------- /assets/architecture.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/predibase/lorax/b5a9e38dc9479ca664bbaac9ae949e27e3c30832/assets/architecture.jpg -------------------------------------------------------------------------------- /assets/benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/predibase/lorax/b5a9e38dc9479ca664bbaac9ae949e27e3c30832/assets/benchmark.png -------------------------------------------------------------------------------- /build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Exit if any command fails 4 | set -ex 5 | 6 | # Check if there are any uncommitted changes 7 | if [[ -n $(git status -s) ]]; then 8 | DIRTY="-dirty" 9 | else 10 | DIRTY="" 11 | fi 12 | 13 | # Get the latest commit SHA 14 | COMMIT_SHA=$(git rev-parse --short HEAD) 15 | 16 | # Combine the SHA and dirty status to form the complete tag 17 | TAG="${COMMIT_SHA}${DIRTY}" 18 | 19 | # Name of the Docker image 20 | IMAGE_NAME="lorax" 21 | 22 | # ECR Repository URL (replace with your actual ECR repository URL) 23 | ECR_REPO="474375891613.dkr.ecr.us-west-2.amazonaws.com" 24 | 25 | echo "Building ${IMAGE_NAME}:${TAG}" 26 | 27 | # Build the Docker image 28 | docker build -t ${IMAGE_NAME}:${TAG} . 
29 | docker tag ${IMAGE_NAME}:${TAG} ${IMAGE_NAME}:latest 30 | 31 | -------------------------------------------------------------------------------- /charts/lorax/.gitignore: -------------------------------------------------------------------------------- 1 | values/** 2 | -------------------------------------------------------------------------------- /charts/lorax/.helmignore: -------------------------------------------------------------------------------- 1 | docs/ 2 | integration-tests/ 3 | load_tests/ 4 | -------------------------------------------------------------------------------- /charts/lorax/Chart.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v2 2 | name: lorax 3 | description: LoRAX is the open-source framework for serving 4 | hundreds of fine-tuned LLMs in production for the price of one. 5 | version: 0.4.0 6 | appVersion: 0.3.0 7 | 8 | home: https://github.com/predibase/lorax 9 | 10 | annotations: 11 | artifacthub.io/category: ai-machine-learning 12 | 13 | keywords: 14 | - lorax 15 | - llama 16 | - llm 17 | - predibase 18 | 19 | maintainers: 20 | - email: maintainers@predibase.com 21 | name: Predibase 22 | 23 | sources: 24 | - https://github.com/predibase/lorax 25 | -------------------------------------------------------------------------------- /charts/lorax/templates/_helpers.tpl: -------------------------------------------------------------------------------- 1 | {{- define "app.name" -}} 2 | {{- printf "%s-%s" .Chart.Name .Release.Name | lower -}} 3 | {{- end -}} -------------------------------------------------------------------------------- /charts/lorax/templates/deployment.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | labels: 5 | app: {{ template "app.name" . }} 6 | {{- if .Values.deployment.additionalLabels }} 7 | {{- toYaml .Values.deployment.additionalLabels | nindent 4 }} 8 | {{- end }} 9 | name: {{ template "app.name" . }} 10 | namespace: {{ .Release.Namespace }} 11 | spec: 12 | replicas: {{ .Values.deployment.replicas }} 13 | selector: 14 | matchLabels: 15 | app: {{ template "app.name" . }} 16 | {{- if .Values.deployment.updateStrategy }} 17 | strategy: {{- toYaml .Values.deployment.updateStrategy | nindent 4 }} 18 | {{- end }} 19 | template: 20 | metadata: 21 | labels: 22 | app: {{ template "app.name" . 
}} 23 | {{- if .Values.deployment.additionalPodLabels }} 24 | {{- toYaml .Values.deployment.additionalPodLabels | nindent 8 }} 25 | {{- end }} 26 | {{- if .Values.deployment.additionalPodAnnotations }} 27 | annotations: {{- toYaml .Values.deployment.additionalPodAnnotations | nindent 8 }} 28 | {{- end }} 29 | spec: 30 | {{- if .Values.deployment.affinity }} 31 | affinity: 32 | {{- toYaml .Values.deployment.affinity | nindent 8 }} 33 | {{- end }} 34 | containers: 35 | - args: 36 | {{- range .Values.deployment.args }} 37 | - {{ .name }} 38 | {{- if .value }} 39 | - {{ .value | quote }} 40 | {{- end }} 41 | {{- end }} 42 | env: 43 | - name: PORT 44 | value: "8000" 45 | {{- toYaml .Values.deployment.env | nindent 8 }} 46 | image: {{ .Values.deployment.image.repository }}:{{ .Values.deployment.image.tag }} 47 | imagePullPolicy: IfNotPresent 48 | {{- if .Values.deployment.livenessProbe }} 49 | livenessProbe: {{ toYaml .Values.deployment.livenessProbe | nindent 10 }} 50 | {{- end }} 51 | name: lorax 52 | ports: 53 | - containerPort: 8000 54 | name: http 55 | protocol: TCP 56 | {{- if .Values.deployment.readinessProbe }} 57 | readinessProbe: {{ toYaml .Values.deployment.readinessProbe | nindent 10 }} 58 | {{- end }} 59 | resources: {{ toYaml .Values.deployment.resources | nindent 10 }} 60 | volumeMounts: 61 | - mountPath: /data 62 | name: data 63 | - mountPath: /dev/shm 64 | name: shm 65 | {{- if .Values.deployment.tolerations }} 66 | tolerations: 67 | {{- toYaml .Values.deployment.tolerations | nindent 6 }} 68 | {{- end }} 69 | nodeSelector: {{ toYaml .Values.deployment.nodeSelector | nindent 8 }} 70 | restartPolicy: Always 71 | schedulerName: default-scheduler 72 | terminationGracePeriodSeconds: 30 73 | {{- if .Values.priorityClassName }} 74 | priorityClassName: {{ .Values.deployment.priorityClassName | quote }} 75 | {{- end }} 76 | volumes: 77 | - emptyDir: 78 | medium: Memory 79 | name: shm 80 | - emptyDir: 81 | name: data 82 | -------------------------------------------------------------------------------- /charts/lorax/templates/service.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Service 3 | metadata: 4 | labels: 5 | app: {{ template "app.name" . }} 6 | app.kubernetes.io/name: {{ template "app.name" . }} 7 | {{ if .Values.service.additionalLabels }} 8 | {{- toYaml .Values.service.additionalLabels | nindent 4 }} 9 | {{ end }} 10 | name: {{ .Values.service.name }} 11 | namespace: {{ .Release.Namespace }} 12 | spec: 13 | ports: 14 | - name: http 15 | port: {{ .Values.service.port }} 16 | protocol: TCP 17 | targetPort: http 18 | selector: 19 | app: {{ template "app.name" . 
}} 20 | # sessionAffinity: None 21 | type: {{ .Values.service.serviceType }} 22 | -------------------------------------------------------------------------------- /charts/lorax/values.yaml: -------------------------------------------------------------------------------- 1 | deployment: 2 | replicas: 1 3 | updateStrategy: {} 4 | 5 | image: 6 | repository: "ghcr.io/predibase/lorax" 7 | tag: "latest" 8 | 9 | args: 10 | - name: "--model-id" 11 | value: "mistralai/Mistral-7B-Instruct-v0.1" 12 | - name: "--max-input-length" 13 | value: "512" 14 | - name: "--max-total-tokens" 15 | value: "1024" 16 | - name: "--max-batch-total-tokens" 17 | value: "4096" 18 | - name: "--max-batch-prefill-tokens" 19 | value: "2048" 20 | - name: "--eager-prefill" 21 | value: "false" 22 | - name: "--compile" 23 | value: "" # --complie does not take a second argument 24 | 25 | env: 26 | # Your huggingface hub token. Required for some models such as the llama-2 family. 27 | - name: "HUGGING_FACE_HUB_TOKEN" 28 | value: "" 29 | 30 | resources: 31 | limits: 32 | nvidia.com/gpu: "1" 33 | requests: 34 | nvidia.com/gpu: "1" 35 | 36 | livenessProbe: 37 | {} 38 | # failureThreshold: 240 39 | # httpGet: 40 | # path: /health 41 | # port: http 42 | # scheme: HTTP 43 | # initialDelaySeconds: 5 44 | # periodSeconds: 5 45 | # successThreshold: 1 46 | # timeoutSeconds: 1 47 | 48 | readinessProbe: 49 | {} 50 | # failureThreshold: 600 51 | # httpGet: 52 | # path: /health 53 | # port: http 54 | # scheme: HTTP 55 | # initialDelaySeconds: 5 56 | # periodSeconds: 5 57 | # successThreshold: 1 58 | # timeoutSeconds: 1 59 | 60 | nodeSelector: {} 61 | tolerations: [] 62 | additionalLabels: {} 63 | additionalPodLabels: {} 64 | 65 | additionalAnnotations: {} 66 | additionalPodAnnotations: {} 67 | affinity: {} 68 | 69 | priorityClassName: "" 70 | 71 | service: 72 | name: "lorax" 73 | serviceType: ClusterIP 74 | port: 80 75 | additionalLabels: {} 76 | -------------------------------------------------------------------------------- /clients/python/Makefile: -------------------------------------------------------------------------------- 1 | unit-tests: 2 | python -m pytest --cov=lorax tests 3 | 4 | install: 5 | pip install pip --upgrade 6 | pip install -e . 7 | 8 | release: 9 | pip install poetry 10 | poetry build 11 | poetry publish -------------------------------------------------------------------------------- /clients/python/lorax/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | __version__ = "0.6.1" 16 | 17 | from lorax.client import Client, AsyncClient, MergedAdapters # noqa 18 | -------------------------------------------------------------------------------- /clients/python/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "lorax-client" 3 | packages = [ 4 | {include = "lorax"} 5 | ] 6 | version = "0.7.0" 7 | description = "LoRAX Python Client" 8 | license = "Apache-2.0" 9 | authors = ["Travis Addair ", "Olivier Dehaene "] 10 | maintainers = ["Travis Addair "] 11 | readme = "README.md" 12 | homepage = "https://github.com/predibase/lorax" 13 | repository = "https://github.com/predibase/lorax" 14 | 15 | 16 | [tool.poetry.dependencies] 17 | python = "^3.8" 18 | pydantic = "> 2, < 3" 19 | aiohttp = "3.10.11" 20 | huggingface-hub = ">= 0.12, < 1.0" 21 | certifi = "2024.7.4" 22 | urllib3 = "1.26.19" 23 | requests = "2.32.0" 24 | idna = "3.7" 25 | tqdm = "4.66.3" 26 | 27 | [tool.poetry.dev-dependencies] 28 | pytest = "^7.3.0" 29 | pytest-asyncio = "^0.17.2" 30 | pytest-cov = "^3.0.0" 31 | 32 | [tool.pytest.ini_options] 33 | asyncio_mode = "auto" 34 | 35 | [build-system] 36 | requires = ["poetry-core>=1.0.0"] 37 | build-backend = "poetry.core.masonry.api" 38 | -------------------------------------------------------------------------------- /clients/python/tests/conftest.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/predibase/lorax/b5a9e38dc9479ca664bbaac9ae949e27e3c30832/clients/python/tests/conftest.py -------------------------------------------------------------------------------- /clients/python/tests/test_errors.py: -------------------------------------------------------------------------------- 1 | from lorax.errors import ( 2 | parse_error, 3 | GenerationError, 4 | IncompleteGenerationError, 5 | OverloadedError, 6 | ValidationError, 7 | BadRequestError, 8 | ShardNotReadyError, 9 | ShardTimeoutError, 10 | NotFoundError, 11 | RateLimitExceededError, 12 | UnknownError, UnprocessableEntityError, 13 | ) 14 | 15 | 16 | def test_generation_error(): 17 | payload = {"error_type": "generation", "error": "test"} 18 | assert isinstance(parse_error(400, payload), GenerationError) 19 | 20 | 21 | def test_incomplete_generation_error(): 22 | payload = {"error_type": "incomplete_generation", "error": "test"} 23 | assert isinstance(parse_error(400, payload), IncompleteGenerationError) 24 | 25 | 26 | def test_overloaded_error(): 27 | payload = {"error_type": "overloaded", "error": "test"} 28 | assert isinstance(parse_error(400, payload), OverloadedError) 29 | 30 | 31 | def test_validation_error(): 32 | payload = {"error_type": "validation", "error": "test"} 33 | assert isinstance(parse_error(400, payload), ValidationError) 34 | 35 | 36 | def test_bad_request_error(): 37 | payload = {"error": "test"} 38 | assert isinstance(parse_error(400, payload), BadRequestError) 39 | 40 | 41 | def test_shard_not_ready_error(): 42 | payload = {"error": "test"} 43 | assert isinstance(parse_error(403, payload), ShardNotReadyError) 44 | assert isinstance(parse_error(424, payload), ShardNotReadyError) 45 | 46 | 47 | def test_shard_timeout_error(): 48 | payload = {"error": "test"} 49 | assert isinstance(parse_error(504, payload), ShardTimeoutError) 50 | 51 | 52 | def test_not_found_error(): 53 | payload = {"error": "test"} 54 | assert isinstance(parse_error(404, payload), NotFoundError) 55 | 56 | 57 | def test_rate_limit_exceeded_error(): 58 | 
payload = {"error": "test"} 59 | assert isinstance(parse_error(429, payload), RateLimitExceededError) 60 | 61 | 62 | def test_unprocessable_entity_error(): 63 | payload = {"error": "test"} 64 | assert isinstance(parse_error(422, payload), UnprocessableEntityError) 65 | 66 | 67 | def test_unknown_error(): 68 | payload = {"error": "test"} 69 | assert isinstance(parse_error(500, payload), UnknownError) 70 | -------------------------------------------------------------------------------- /clients/python/tests/test_types.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from lorax.types import Parameters, Request, MergedAdapters 4 | from lorax.errors import ValidationError 5 | 6 | 7 | def test_parameters_validation(): 8 | # Test best_of 9 | Parameters(best_of=1) 10 | with pytest.raises(ValidationError): 11 | Parameters(best_of=0) 12 | with pytest.raises(ValidationError): 13 | Parameters(best_of=-1) 14 | Parameters(best_of=2, do_sample=True) 15 | with pytest.raises(ValidationError): 16 | Parameters(best_of=2) 17 | with pytest.raises(ValidationError): 18 | Parameters(best_of=2, seed=1) 19 | 20 | # Test repetition_penalty 21 | Parameters(repetition_penalty=1) 22 | with pytest.raises(ValidationError): 23 | Parameters(repetition_penalty=0) 24 | with pytest.raises(ValidationError): 25 | Parameters(repetition_penalty=-1) 26 | 27 | # Test seed 28 | Parameters(seed=1) 29 | with pytest.raises(ValidationError): 30 | Parameters(seed=-1) 31 | 32 | # Test temperature 33 | Parameters(temperature=1) 34 | Parameters(temperature=0) 35 | with pytest.raises(ValidationError): 36 | Parameters(temperature=-1) 37 | 38 | # Test top_k 39 | Parameters(top_k=1) 40 | with pytest.raises(ValidationError): 41 | Parameters(top_k=0) 42 | with pytest.raises(ValidationError): 43 | Parameters(top_k=-1) 44 | 45 | # Test top_p 46 | Parameters(top_p=0.5) 47 | with pytest.raises(ValidationError): 48 | Parameters(top_p=0) 49 | with pytest.raises(ValidationError): 50 | Parameters(top_p=-1) 51 | with pytest.raises(ValidationError): 52 | Parameters(top_p=1) 53 | 54 | # Test truncate 55 | Parameters(truncate=1) 56 | with pytest.raises(ValidationError): 57 | Parameters(truncate=0) 58 | with pytest.raises(ValidationError): 59 | Parameters(truncate=-1) 60 | 61 | # Test typical_p 62 | Parameters(typical_p=0.5) 63 | with pytest.raises(ValidationError): 64 | Parameters(typical_p=0) 65 | with pytest.raises(ValidationError): 66 | Parameters(typical_p=-1) 67 | with pytest.raises(ValidationError): 68 | Parameters(typical_p=1) 69 | 70 | # Test adapter_id and merged_adapters 71 | merged_adapters = MergedAdapters(ids=["test/adapter-id-1", "test/adapter-id-2"], weights=[0.5, 0.5], density=0.5) 72 | Parameters(adapter_id="test/adapter-id") 73 | Parameters(merged_adapters=merged_adapters) 74 | with pytest.raises(ValidationError): 75 | Parameters(adapter_id="test/adapter-id", merged_adapters=merged_adapters) 76 | 77 | 78 | def test_request_validation(): 79 | Request(inputs="test") 80 | 81 | with pytest.raises(ValidationError): 82 | Request(inputs="") 83 | 84 | Request(inputs="test", stream=True) 85 | Request(inputs="test", parameters=Parameters(best_of=2, do_sample=True)) 86 | 87 | with pytest.raises(ValidationError): 88 | Request(inputs="test", parameters=Parameters(best_of=2, do_sample=True), stream=True) 89 | -------------------------------------------------------------------------------- /container-entrypoint.sh: -------------------------------------------------------------------------------- 
1 | #!/bin/bash 2 | 3 | # upload weights to cache 4 | upload() { 5 | echo "Received a SIGTERM or SIGKILL signal. Uploading to cache." 6 | object_id="${MODEL_ID//\//--}" 7 | S3_DIRECTORY="models--$object_id" 8 | aws configure set default.s3.preferred_transfer_client crt 9 | aws configure set default.s3.payload_signing_enabled false 10 | aws configure set default.s3.target_bandwidth 50Gb/s 11 | echo "running aws s3 sync /data/${S3_DIRECTORY}/ s3://${HF_CACHE_BUCKET}/${S3_DIRECTORY}/ --exclude blobs/* --exclude *.bin" 12 | aws s3 sync "/data/${S3_DIRECTORY}/" "s3://${HF_CACHE_BUCKET}/${S3_DIRECTORY}/" --exclude "blobs/*" --exclude "*.bin" 13 | exit 0 14 | } 15 | 16 | # Trap SIGTERM signals and call the cleanup function 17 | trap upload SIGTERM SIGKILL 18 | 19 | # print AWS CLI version 20 | aws --version 21 | 22 | # download files 23 | time ./sync.sh 24 | 25 | # Function to check if lorax-launcher is running 26 | is_launcher_running() { 27 | local launcher_pid="$1" 28 | # this checks whether the process is alive or not. Redirects the output of kill -0 to devnull. 29 | kill -0 "$launcher_pid" >/dev/null 2>&1 30 | } 31 | 32 | # launch TG launcher in the background 33 | lorax-launcher "$@" & 34 | 35 | # Capture the PID of the process we just launched 36 | launcher_pid="$!" 37 | 38 | # Loop to continuously check if lorax-launcher is running 39 | while is_launcher_running "$launcher_pid"; do 40 | sleep 1 41 | done 42 | 43 | # Once lorax-launcher has stopped, the loop exits, and upload is called 44 | upload 45 | -------------------------------------------------------------------------------- /docs/CNAME: -------------------------------------------------------------------------------- 1 | loraexchange.ai -------------------------------------------------------------------------------- /docs/LoRAX_Main_Logo-Orange.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/predibase/lorax/b5a9e38dc9479ca664bbaac9ae949e27e3c30832/docs/LoRAX_Main_Logo-Orange.png -------------------------------------------------------------------------------- /docs/LoRAX_Main_Logo-White.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/predibase/lorax/b5a9e38dc9479ca664bbaac9ae949e27e3c30832/docs/LoRAX_Main_Logo-White.png -------------------------------------------------------------------------------- /docs/favicon-16x16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/predibase/lorax/b5a9e38dc9479ca664bbaac9ae949e27e3c30832/docs/favicon-16x16.png -------------------------------------------------------------------------------- /docs/favicon-32x32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/predibase/lorax/b5a9e38dc9479ca664bbaac9ae949e27e3c30832/docs/favicon-32x32.png -------------------------------------------------------------------------------- /docs/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/predibase/lorax/b5a9e38dc9479ca664bbaac9ae949e27e3c30832/docs/favicon.ico -------------------------------------------------------------------------------- /docs/getting_started/kubernetes.md: -------------------------------------------------------------------------------- 1 | # Kubernetes (Helm) 2 | 3 | LoRAX includes Helm charts that make it easy to start using LoRAX in 
production with high availability and load balancing on Kubernetes. 4 | 5 | To spin up a LoRAX deployment with Helm, you only need to be connected to a Kubernetes cluster through `kubectl`. We provide a default values.yaml file that can be used to deploy a Mistral 7B base model to your Kubernetes cluster: 6 | 7 | ```shell 8 | helm install mistral-7b-release charts/lorax 9 | ``` 10 | 11 | The default [values.yaml](https://github.com/predibase/lorax/blob/main/charts/lorax/values.yaml) configuration deploys a single replica of the Mistral 7B model. You can tailor configuration parameters to deploy any Llama or Mistral model by creating a new values file from the template and updating variables. Once a new values file is created, you can run the following command to deploy your LLM with LoRAX: 12 | 13 | ```shell 14 | helm install -f your-values-file.yaml your-model-release charts/lorax 15 | ``` 16 | 17 | To delete the resources: 18 | 19 | ```shell 20 | helm uninstall your-model-release 21 | ``` 22 | -------------------------------------------------------------------------------- /docs/getting_started/local.md: -------------------------------------------------------------------------------- 1 | # Local 2 | 3 | Advanced users or contributors may opt to install LoRAX locally. 4 | 5 | First [install Rust](https://rustup.rs/) and create a Python virtual environment with at least 6 | Python 3.9, e.g. using `conda`: 7 | 8 | ```shell 9 | curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh 10 | 11 | conda create -n lorax python=3.9 12 | conda activate lorax 13 | ``` 14 | 15 | You may also need to install Protoc. 16 | 17 | On Linux: 18 | 19 | ```shell 20 | PROTOC_ZIP=protoc-21.12-linux-x86_64.zip 21 | curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP 22 | sudo unzip -o $PROTOC_ZIP -d /usr/local bin/protoc 23 | sudo unzip -o $PROTOC_ZIP -d /usr/local 'include/*' 24 | rm -f $PROTOC_ZIP 25 | ``` 26 | 27 | On MacOS, using Homebrew: 28 | 29 | ```shell 30 | brew install protobuf 31 | ``` 32 | 33 | Then run: 34 | 35 | ```shell 36 | BUILD_EXTENSIONS=True make install # Install repository and HF/transformer fork with CUDA kernels 37 | make run-mistral-7b-instruct 38 | ``` 39 | 40 | **Note:** on some machines, you may also need the OpenSSL libraries and gcc. On Linux machines, run: 41 | 42 | ```shell 43 | sudo apt-get install libssl-dev gcc -y 44 | ``` 45 | 46 | ### CUDA Kernels 47 | 48 | The custom CUDA kernels are only tested on NVIDIA A100s. If you have any installation or runtime issues, you can remove 49 | the kernels by using the `DISABLE_CUSTOM_KERNELS=True` environment variable. 50 | 51 | Be aware that the official Docker image has them enabled by default. 52 | 53 | ### Run Mistral 54 | 55 | ```shell 56 | make run-mistral-7b-instruct 57 | ``` -------------------------------------------------------------------------------- /docs/guides/contributing/index.md: -------------------------------------------------------------------------------- 1 | # Contributing to LoRAX 2 | 3 | ## Setting up your development environment 4 | 5 | See [Development Environment](./development_env.md). 6 | 7 | ## Updating Python server dependencies 8 | 9 | LoRAX uses [Poetry](https://python-poetry.org/) to manage dependencies. 10 | 11 | When modifying the dependencies of the LoRAX Python server, first modify the server [pyproject.toml](https://github.com/predibase/lorax/blob/main/server/pyproject.toml) file directly, making the desired changes.
12 | 13 | Next, from within the `server` directory, generate an updated `poetry.lock` file: 14 | 15 | ```shell 16 | poetry lock --no-update 17 | ``` 18 | 19 | Then (still within the `server` directory) generate a new `requirements.txt` file: 20 | 21 | ```shell 22 | make export-requirements 23 | ``` 24 | 25 | Never modify `requirements.txt` directly, as it may introduce dependency conflicts. 26 | 27 | ## Profiling 28 | 29 | LoRAX supports the [PyTorch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html) to measure performance of LoRAX. 30 | 31 | You can enable profiling when launching LoRAX by setting the `LORAX_PROFILER_DIR` environment variable to the directory 32 | you wish to output the Tensorboard traces to. 33 | 34 | Once initialized, LoRAX will begin recording traces for every request to the server. Because traces can get very large, 35 | we record only the first 10 prefill requests (plus any decode requests between them), then stop recording and write 36 | out the results. A summary will be printed to stdout when this occurs. 37 | 38 | Once you have your traces written to the profiler directory, you can visualize them in Tensorboard using the 39 | [PyTorch Profiler Tensorboard Plugin](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html). 40 | 41 | ```bash 42 | pip install torch_tb_profiler 43 | tensorboard --logdir=$LORAX_PROFILER_DIR 44 | ``` 45 | -------------------------------------------------------------------------------- /docs/guides/cuda_graphs.md: -------------------------------------------------------------------------------- 1 | LoRAX supports compiling the model into a static CUDA Graph to speedup inference by upwards of 2x. See [Accelerating PyTorch with CUDA Graphs](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/) for more details on CUDA graphs and how they can reduce latency. 2 | 3 | ## Usage 4 | 5 | To enable this (experimental) feature: 6 | 7 | ``` 8 | lorax-launcher ... --compile 9 | ``` 10 | 11 | ## When should I use this? 12 | 13 | CUDA graph compilation is a simple way to decrease latency for smaller LLMs (O(1b params)) that are compute bound rather than memory bound. 14 | 15 | There is a tradeoff to be aware of when using CUDA graphs, namely that it increases memory overhead by 3-10GB depending on model size. However, the observed decrease in latency can be as much as 50%, so if you don't need to run with very large batch sizes and are more latency constrained than throughput, this is a very compelling feature to enable. 16 | 17 | In practice, CUDA graphs are most useful in cases where there are excess GPU flops available, such as during decoding. As such, we do not use the compiled version of the model during prefill, only during the decoding steps. Which means in practice that the benefits of enabling compilation will be most pronounced when generating longer sequences (for which more time is spent during decoding). 18 | 19 | ## Limitations 20 | 21 | Current limitations: 22 | 23 | - Batch size < 256 24 | - Context length (input + output) < 8192 25 | - LoRA rank >= 8 and <= 64 26 | - Only one LoRA rank in the batch 27 | - 1 GPU (no sharding) 28 | 29 | If any of these conditions are not met, then LoRAX will fallback to using eager execution for the batch. 
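As a reference point, here is a minimal sketch of enabling compilation when serving with the official Docker image. The model ID and port mapping are placeholder assumptions borrowed from the other examples in these docs; only the `--compile` flag is specific to this feature:

```bash
docker run --gpus all \
    --shm-size 1g \
    -p 8080:80 \
    ghcr.io/predibase/lorax:main \
    --model-id mistralai/Mistral-7B-Instruct-v0.1 \
    --compile
```

If a batch violates any of the conditions above at runtime, LoRAX falls back to eager execution for that batch, so the flag is safe to experiment with.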
30 | 31 | ## Benchmarks 32 | 33 | gpt2-medium, 1x A100, time to generate 100 tokens: 34 | 35 | no adapter: 36 | 37 | - baseline: 1.044 s 38 | - cuda graph: 0.422 s 39 | 40 | 1 adapter (rank 16): 41 | 42 | - baseline: 1.503 s 43 | - cuda graph: 0.583 s -------------------------------------------------------------------------------- /docs/models/base_models.md: -------------------------------------------------------------------------------- 1 | # Base Models 2 | 3 | ## Supported Architectures 4 | 5 | - 🦙 [Llama](https://huggingface.co/meta-llama) 6 | - [CodeLlama](https://huggingface.co/codellama) 7 | - 🌬️[Mistral](https://huggingface.co/mistralai) 8 | - [Zephyr](https://huggingface.co/HuggingFaceH4/zephyr-7b-beta) 9 | - 🔄 [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1) 10 | - 💎 [Gemma](https://blog.google/technology/developers/gemma-open-models/) 11 | - [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315) 12 | - 🏛️ [Phi-3](https://azure.microsoft.com/en-us/blog/introducing-phi-3-redefining-whats-possible-with-slms/) / [Phi-2](https://huggingface.co/microsoft/phi-2) 13 | - 🔮 [Qwen2 / Qwen](https://huggingface.co/Qwen) 14 | - 🗣️ [Command-R](https://docs.cohere.com/docs/command-r) 15 | - 🧱 [DBRX](https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm) 16 | - 🤖 [GPT2](https://huggingface.co/gpt2) 17 | - 🔆 [Solar](https://huggingface.co/upstage/SOLAR-10.7B-v1.0) 18 | - 🌸 [Bloom](https://huggingface.co/bigscience/bloom) 19 | 20 | Other architectures are supported on a best effort basis, but do not support dynamic adapter loading. 21 | 22 | ## Selecting a Base Model 23 | 24 | Check the [HuggingFace Hub](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads) to find supported base models. 25 | 26 | Usage: 27 | 28 | ```shell 29 | lorax-launcher --model-id mistralai/Mistral-7B-v0.1 ... 30 | ``` 31 | 32 | ## Private Models 33 | 34 | You can access private base models from HuggingFace by setting the `HUGGING_FACE_HUB_TOKEN` environment variable: 35 | 36 | ```bash 37 | export HUGGING_FACE_HUB_TOKEN= 38 | ``` 39 | 40 | Using Docker: 41 | 42 | ```bash 43 | docker run --gpus all \ 44 | --shm-size 1g \ 45 | -p 8080:80 \ 46 | -e HUGGING_FACE_HUB_TOKEN=$HUGGING_FACE_HUB_TOKEN \ 47 | ghcr.io/predibase/lorax:main \ 48 | --model-id $MODEL_ID 49 | ``` 50 | 51 | ## Quantization 52 | 53 | LoRAX supports loading the base model with quantization to reduce memory overhead, while loading adapters in 54 | full (fp32) or half precision (fp16, bf16), similar to the approach described in [QLoRA](https://arxiv.org/abs/2305.14314). 55 | 56 | See [Quantization](../guides/quantization.md) for details on the various quantization strategies provided by LoRAX. 57 | -------------------------------------------------------------------------------- /docs/reference/metrics.md: -------------------------------------------------------------------------------- 1 | # Metrics 2 | 3 | Prometheus-compatible metrics are made available on the default port, on the `/metrics` endpoint. 
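As a quick sanity check, you can scrape the endpoint directly. A minimal sketch, assuming the container's port 80 is mapped to port 8080 on the host as in the Docker examples above:

```shell
# Fetch the Prometheus metrics and keep only the LoRAX series
curl -s http://localhost:8080/metrics | grep "^lorax_"
```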
4 | 5 | Below is a list of the metrics that are exposed: 6 | | Metric Name | Type | 7 | | -------------------------------------------- | --------- | 8 | | `lorax_request_count` | Counter | 9 | | `lorax_request_success` | Counter | 10 | | `lorax_request_failure` | Counter | 11 | | `lorax_request_duration` | Histogram | 12 | | `lorax_request_queue_duration` | Histogram | 13 | | `lorax_request_validation_duration` | Histogram | 14 | | `lorax_request_inference_duration` | Histogram | 15 | | `lorax_request_mean_time_per_token_duration` | Histogram | 16 | | `lorax_request_generated_tokens` | Histogram | 17 | | `lorax_request_input_length` | Histogram | 18 | 19 | For all histograms, there are metrics that are autogenerated which are the metric name + `_sum` and `_count`, which are the sum of all values for that histogram, and the count of all instances of that histogram respectively. 20 | -------------------------------------------------------------------------------- /docs/reference/rest_api.md: -------------------------------------------------------------------------------- 1 | !!swagger openapi.json!! -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | mkdocs-material 2 | mkdocs-render-swagger-plugin 3 | -------------------------------------------------------------------------------- /integration-tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/predibase/lorax/b5a9e38dc9479ca664bbaac9ae949e27e3c30832/integration-tests/__init__.py -------------------------------------------------------------------------------- /integration-tests/pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | log_cli=true 3 | log_level=DEBUG 4 | -------------------------------------------------------------------------------- /integration-tests/requirements.txt: -------------------------------------------------------------------------------- 1 | requests==2.32.3 2 | docker==7.1.0 3 | pytest==7.4.0 -------------------------------------------------------------------------------- /integration-tests/test_base_llms.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from utils.docker_runner import run_lorax_container 3 | 4 | 5 | def test_base_mistral(): 6 | config = { 7 | "name": "mistral-7b", 8 | "model_id": "mistralai/Mistral-7B-Instruct-v0.1", 9 | } 10 | test_prompt = "[INST] What is the capital of France? [/INST]" 11 | with run_lorax_container(config): 12 | response = requests.post( 13 | "http://localhost:8080/generate", 14 | json={"inputs": test_prompt, "parameters": {"max_new_tokens": 10}}, 15 | ) 16 | response.raise_for_status() 17 | print("RESPONSE FROM LLM: ", response.json()) 18 | assert len(response.json()["generated_text"]) > 0 19 | 20 | 21 | def test_base_llama_3_1_8b(): 22 | config = { 23 | "name": "llama-3-1-8b", 24 | "model_id": "meta-llama/Meta-Llama-3.1-8B-Instruct", 25 | } 26 | test_prompt = "[INST] What is the capital of France? 
[/INST]" 27 | with run_lorax_container(config): 28 | response = requests.post( 29 | "http://localhost:8080/generate", 30 | json={"inputs": test_prompt, "parameters": {"max_new_tokens": 10}}, 31 | ) 32 | response.raise_for_status() 33 | print("RESPONSE FROM LLM: ", response.json()) 34 | assert len(response.json()["generated_text"]) > 0 35 | 36 | 37 | def test_base_qwen_2_1_5b(): 38 | config = {"name": "qwen-2-1-5b", "model_id": "predibase/Qwen2-1.5B-Instruct-dequantized"} 39 | test_prompt = "[INST] What is the capital of France? [/INST]" 40 | with run_lorax_container(config): 41 | response = requests.post( 42 | "http://localhost:8080/generate", 43 | json={"inputs": test_prompt, "parameters": {"max_new_tokens": 10}}, 44 | ) 45 | response.raise_for_status() 46 | print("RESPONSE FROM LLM: ", response.json()) 47 | assert len(response.json()["generated_text"]) > 0 48 | -------------------------------------------------------------------------------- /integration-tests/test_classifications.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from utils.docker_runner import run_lorax_container 3 | 4 | 5 | def test_distilbert_ner(): 6 | config = { 7 | "name": "distilbert-ner", 8 | "model_id": "dslim/distilbert-NER", 9 | "docker_args": { 10 | "max_input_length": 512, 11 | "max_batch_prefill_tokens": 512, 12 | "max_batch_total_tokens": 512, 13 | "max_total_tokens": 512, 14 | }, 15 | } 16 | with run_lorax_container(config): 17 | response = requests.post( 18 | "http://localhost:8080/classify", 19 | json={ 20 | "inputs": "Johnny supports the Golden State Warriors. He lives in London." 21 | }, 22 | ) 23 | response.raise_for_status() 24 | print("RESPONSE FROM CLASSIFICATION:", response.json()) 25 | assert len(response.json()) > 0 26 | 27 | 28 | def test_bert_ner(): 29 | config = { 30 | "name": "bert-ner", 31 | "model_id": "magdyks/bert-base-ner", 32 | "docker_args": { 33 | "max_input_length": 512, 34 | "max_batch_prefill_tokens": 512, 35 | "max_batch_total_tokens": 512, 36 | "max_total_tokens": 512, 37 | "backend": "flashinfer", 38 | }, 39 | } 40 | with run_lorax_container(config): 41 | response = requests.post( 42 | "http://localhost:8080/classify", 43 | json={ 44 | "inputs": "Johnny supports the Golden State Warriors. He lives in London." 
45 | }, 46 | ) 47 | response.raise_for_status() 48 | print("RESPONSE FROM CLASSIFICATION:", response.json()) 49 | assert len(response.json()) > 0 50 | -------------------------------------------------------------------------------- /integration-tests/test_embeddings.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from utils.docker_runner import run_lorax_container 3 | 4 | 5 | def test_stella_1_5b(): 6 | config = { 7 | "name": "stella-1.5b", 8 | "model_id": "dunzhang/stella_en_1.5B_v5", 9 | "docker_args": {"embedding_dim": 256}, 10 | } 11 | with run_lorax_container(config): 12 | response = requests.post("http://localhost:8080/embed", json={"inputs": "Hello, world!"}) 13 | response.raise_for_status() 14 | print("RESPONSE FROM EMBEDDING: ", response.json()) 15 | assert len(response.json()["embeddings"]) > 0 16 | 17 | 18 | def test_uae_large_v1_1_5b(): 19 | config = { 20 | "name": "UAE-Large-V1-1.5b", 21 | "model_id": "WhereIsAI/UAE-Large-V1", 22 | "docker_args": { 23 | "max_input_length": 512, 24 | "max_batch_prefill_tokens": 512, 25 | "max_batch_total_tokens": 512, 26 | "max_total_tokens": 512, 27 | }, 28 | } 29 | with run_lorax_container(config): 30 | response = requests.post("http://localhost:8080/embed", json={"inputs": "Hello, world!"}) 31 | response.raise_for_status() 32 | print("RESPONSE FROM EMBEDDING: ", response.json()) 33 | assert len(response.json()["embeddings"]) > 0 34 | -------------------------------------------------------------------------------- /integration-tests/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/predibase/lorax/b5a9e38dc9479ca664bbaac9ae949e27e3c30832/integration-tests/utils/__init__.py -------------------------------------------------------------------------------- /launcher/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "lorax-launcher" 3 | description = "LoRAX Launcher" 4 | version.workspace = true 5 | edition.workspace = true 6 | authors.workspace = true 7 | homepage.workspace = true 8 | 9 | [dependencies] 10 | clap = { version = "4.1.4", features = ["derive", "env"] } 11 | ctrlc = { version = "3.2.5", features = ["termination"] } 12 | nix = "0.26.2" 13 | openssl = "0.10.70" 14 | hf-hub = { version = "0.3.0", features = ["tokio"] } 15 | h2 = "0.3.26" 16 | rustix = "0.37.25" 17 | serde = { version = "1.0.152", features = ["derive"] } 18 | serde_json = { version = "1.0.93", features = ["preserve_order"] } 19 | tracing = "0.1.37" 20 | tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] } 21 | secrecy = "0.10.3" 22 | 23 | [dev-dependencies] 24 | float_eq = "1.0.1" 25 | reqwest = { version = "0.11.14", features = ["blocking", "json"] } 26 | 27 | [build-dependencies] 28 | vergen = { version = "8.2.5", features = ["build", "cargo", "git", "gitcl", "rustc", "si"] } 29 | -------------------------------------------------------------------------------- /launcher/build.rs: -------------------------------------------------------------------------------- 1 | use std::error::Error; 2 | use vergen::EmitBuilder; 3 | 4 | fn main() -> Result<(), Box> { 5 | // Emit cargo and rustc compile time values 6 | EmitBuilder::builder().all_cargo().all_rustc().emit()?; 7 | 8 | // Try to get the git sha from the local git repository 9 | if EmitBuilder::builder() 10 | .fail_on_error() 11 | .git_sha(false) 12 | .emit() 13 | .is_err() 14 | { 15 | // 
Unable to get the git sha 16 | if let Ok(sha) = std::env::var("GIT_SHA") { 17 | // Set it from an env var 18 | println!("cargo:rustc-env=VERGEN_GIT_SHA={sha}"); 19 | } 20 | } 21 | 22 | // Set docker label if present 23 | if let Ok(label) = std::env::var("DOCKER_LABEL") { 24 | // Set it from an env var 25 | println!("cargo:rustc-env=DOCKER_LABEL={label}"); 26 | } 27 | 28 | Ok(()) 29 | } 30 | -------------------------------------------------------------------------------- /launcher/src/env_runtime.rs: -------------------------------------------------------------------------------- 1 | use std::fmt; 2 | use std::process::Command; 3 | 4 | pub(crate) struct Env { 5 | cargo_target: &'static str, 6 | cargo_version: &'static str, 7 | git_sha: &'static str, 8 | docker_label: &'static str, 9 | nvidia_env: String, 10 | } 11 | 12 | impl Env { 13 | pub fn new() -> Self { 14 | let nvidia_env = nvidia_smi(); 15 | 16 | Self { 17 | nvidia_env: nvidia_env.unwrap_or("N/A".to_string()), 18 | cargo_target: env!("VERGEN_CARGO_TARGET_TRIPLE"), 19 | cargo_version: env!("VERGEN_RUSTC_SEMVER"), 20 | git_sha: option_env!("VERGEN_GIT_SHA").unwrap_or("N/A"), 21 | docker_label: option_env!("DOCKER_LABEL").unwrap_or("N/A"), 22 | } 23 | } 24 | } 25 | 26 | impl fmt::Display for Env { 27 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 28 | writeln!(f, "Runtime environment:")?; 29 | 30 | writeln!(f, "Target: {}", self.cargo_target)?; 31 | writeln!(f, "Cargo version: {}", self.cargo_version)?; 32 | writeln!(f, "Commit sha: {}", self.git_sha)?; 33 | writeln!(f, "Docker label: {}", self.docker_label)?; 34 | write!(f, "nvidia-smi:\n{}", self.nvidia_env)?; 35 | 36 | Ok(()) 37 | } 38 | } 39 | 40 | fn nvidia_smi() -> Option { 41 | let output = Command::new("nvidia-smi").output().ok()?; 42 | let nvidia_smi = String::from_utf8(output.stdout).ok()?; 43 | let output = nvidia_smi.replace('\n', "\n "); 44 | Some(output.trim().to_string()) 45 | } 46 | -------------------------------------------------------------------------------- /load_tests/starcoder_load.js: -------------------------------------------------------------------------------- 1 | import {check} from 'k6'; 2 | import http from 'k6/http'; 3 | import {Trend} from 'k6/metrics'; 4 | 5 | const host = __ENV.HOST || '127.0.0.1:3000'; 6 | 7 | const totalTime = new Trend('total_time', true); 8 | const totalTokens = new Trend('total_tokens', true); 9 | const validationTime = new Trend('validation_time', true); 10 | const queueTime = new Trend('queue_time', true); 11 | const inferenceTime = new Trend('inference_time', true); 12 | const timePerToken = new Trend('time_per_token', true); 13 | 14 | const example = { 15 | payload: JSON.stringify({ 16 | inputs: '# This is a fibonacci function written in the Python programming language.' 
+ 17 | 'def fibonacci', 18 | parameters: { 19 | details: true, 20 | max_new_tokens: 60, 21 | temperature: 0.2, 22 | top_p: 0.95, 23 | seed: 0, 24 | }, 25 | }), 26 | generated_tokens: 60 27 | }; 28 | 29 | export const options = { 30 | thresholds: { 31 | http_req_failed: ['rate==0'], 32 | time_per_token: ['p(95)<90'], 33 | queue_time: ['p(95)<1500'], 34 | }, 35 | scenarios: { 36 | load_test: { 37 | executor: 'constant-arrival-rate', 38 | duration: '60s', 39 | preAllocatedVUs: 100, 40 | rate: 10, 41 | timeUnit: '1s', 42 | }, 43 | }, 44 | }; 45 | 46 | export default function () { 47 | const headers = {'Content-Type': 'application/json'}; 48 | const res = http.post(`http://${host}/generate`, example.payload, { 49 | headers, 50 | }); 51 | 52 | check(res, { 53 | 'Post status is 200': (r) => res.status === 200, 54 | 'Post response generated tokens': (r) => res.status === 200 && res.json().details.generated_tokens === example.generated_tokens, 55 | }); 56 | 57 | if (res.status === 200) { 58 | totalTime.add(res.headers["X-Total-Time"]); 59 | totalTokens.add(res.headers["X-Total-Tokens"]); 60 | validationTime.add(res.headers["X-Validation-Time"]); 61 | queueTime.add(res.headers["X-Queue-Time"]); 62 | inferenceTime.add(res.headers["X-Inference-Time"]); 63 | timePerToken.add(res.headers["X-Time-Per-Token"]); 64 | } 65 | } -------------------------------------------------------------------------------- /router/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "lorax-router" 3 | description = "LoRAX Webserver" 4 | build = "build.rs" 5 | version.workspace = true 6 | edition.workspace = true 7 | authors.workspace = true 8 | homepage.workspace = true 9 | 10 | [lib] 11 | path = "src/lib.rs" 12 | 13 | [[bin]] 14 | name = "lorax-router" 15 | path = "src/main.rs" 16 | 17 | [dependencies] 18 | async-stream = "0.3.3" 19 | axum = { version = "0.7", features = ["json", "macros"] } 20 | axum-tracing-opentelemetry = "0.16" 21 | clap = { version = "4.1.4", features = ["derive", "env"] } 22 | futures = "0.3.26" 23 | home = "=0.5.9" 24 | hf-hub = { version = "0.3.0", features = ["tokio"] } 25 | h2 = "0.3.26" 26 | lorax-client = { path = "client" } 27 | metrics = "0.21.0" 28 | metrics-exporter-prometheus = { version = "0.12.1", features = [] } 29 | nohash-hasher = "0.2.0" 30 | opentelemetry = { version = "0.19.0", features = ["rt-tokio"] } 31 | openssl = "0.10.70" 32 | opentelemetry-otlp = "0.12.0" 33 | rand = "0.8.5" 34 | reqwest = { version = "0.11.14", features = ["blocking"] } 35 | reqwest-middleware = "0.2.4" 36 | reqwest-retry = "0.4.0" 37 | regex = "1.5.4" 38 | rustix = "0.37.25" 39 | serde = "1.0.152" 40 | serde_json = { version = "1.0.93", features = ["preserve_order"] } 41 | slotmap = "1.0.7" 42 | thiserror = "1.0.38" 43 | tokenizers = { version = "0.20.0", features = ["http"] } 44 | tokio = { version = "1.32.0", features = [ 45 | "rt", 46 | "rt-multi-thread", 47 | "parking_lot", 48 | "signal", 49 | "sync", 50 | ] } 51 | tokio-stream = "0.1.14" 52 | tower-http = { version = "0.6.1", features = ["cors"] } 53 | tracing = "0.1.37" 54 | tracing-opentelemetry = "0.19.0" 55 | tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] } 56 | utoipa = { version = "4.2.0", features = ["axum_extras"] } 57 | utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } 58 | ngrok = { version = "0.12.3", features = ["axum"], optional = true } 59 | init-tracing-opentelemetry = { version = "0.14.1", features = [ 60 | "opentelemetry-otlp", 
61 | ] } 62 | once_cell = "1.19.0" 63 | itertools = "0.12.1" 64 | async-trait = "0.1.80" 65 | minijinja = { version = "2.2.0", features = ["json"] } 66 | minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } 67 | image = "=0.25.5" 68 | rustls = "0.23.18" 69 | webpki = "0.22.2" 70 | base64 = "0.22.0" 71 | wasm-bindgen = "=0.2.95" 72 | wasm-bindgen-macro = "=0.2.95" 73 | secrecy = "0.10.3" 74 | 75 | [build-dependencies] 76 | vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] } 77 | 78 | [dev-dependencies] 79 | tracing-test = "0.1" 80 | 81 | [features] 82 | default = ["ngrok"] 83 | ngrok = ["dep:ngrok"] 84 | -------------------------------------------------------------------------------- /router/build.rs: -------------------------------------------------------------------------------- 1 | use std::error::Error; 2 | use vergen::EmitBuilder; 3 | 4 | fn main() -> Result<(), Box> { 5 | // Try to get the git sha from the local git repository 6 | if EmitBuilder::builder() 7 | .fail_on_error() 8 | .git_sha(false) 9 | .emit() 10 | .is_err() 11 | { 12 | // Unable to get the git sha 13 | if let Ok(sha) = std::env::var("GIT_SHA") { 14 | // Set it from an env var 15 | println!("cargo:rustc-env=VERGEN_GIT_SHA={sha}"); 16 | } 17 | } 18 | 19 | // Set docker label if present 20 | if let Ok(label) = std::env::var("DOCKER_LABEL") { 21 | // Set it from an env var 22 | println!("cargo:rustc-env=DOCKER_LABEL={label}"); 23 | } 24 | 25 | Ok(()) 26 | } 27 | -------------------------------------------------------------------------------- /router/client/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "lorax-client" 3 | version.workspace = true 4 | edition.workspace = true 5 | authors.workspace = true 6 | homepage.workspace = true 7 | 8 | [dependencies] 9 | futures = "^0.3" 10 | grpc-metadata = { path = "../grpc-metadata" } 11 | prost = "^0.11" 12 | thiserror = "^1.0" 13 | tokio = { version = "^1.25", features = ["sync"] } 14 | tokenizers = { version = "0.20.0", features = ["http"] } 15 | tonic = "^0.9" 16 | tower = "^0.4" 17 | tracing = "^0.1" 18 | regex = "1.5.4" 19 | base64 = "0.22.0" 20 | rustls = "0.23.18" 21 | webpki = "0.22.2" 22 | 23 | [build-dependencies] 24 | tonic-build = "0.9.2" 25 | prost-build = "0.11.6" 26 | -------------------------------------------------------------------------------- /router/client/build.rs: -------------------------------------------------------------------------------- 1 | use std::fs; 2 | 3 | fn main() -> Result<(), Box> { 4 | println!("cargo:rerun-if-changed=../../proto/generate.proto"); 5 | fs::create_dir("src/pb").unwrap_or(()); 6 | 7 | let mut config = prost_build::Config::new(); 8 | config.protoc_arg("--experimental_allow_proto3_optional"); 9 | 10 | tonic_build::configure() 11 | .build_client(true) 12 | .build_server(false) 13 | .out_dir("src/pb") 14 | .include_file("mod.rs") 15 | .compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"]) 16 | .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}")); 17 | 18 | Ok(()) 19 | } 20 | -------------------------------------------------------------------------------- /router/client/src/pb/.gitignore: -------------------------------------------------------------------------------- 1 | *.rs -------------------------------------------------------------------------------- /router/grpc-metadata/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = 
"grpc-metadata" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | opentelemetry = "^0.19" 8 | tonic = "^0.9" 9 | tracing = "^0.1" 10 | tracing-opentelemetry = "^0.19" 11 | -------------------------------------------------------------------------------- /router/grpc-metadata/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! A crate to extract and inject a OpenTelemetry context from and to a gRPC request. 2 | //! Inspired by: https://github.com/open-telemetry/opentelemetry-rust gRPC examples 3 | 4 | use opentelemetry::global; 5 | use opentelemetry::propagation::{Extractor, Injector}; 6 | use tracing_opentelemetry::OpenTelemetrySpanExt; 7 | 8 | /// Extract context metadata from a gRPC request's metadata 9 | #[allow(dead_code)] 10 | struct MetadataExtractor<'a>(pub &'a tonic::metadata::MetadataMap); 11 | 12 | impl<'a> Extractor for MetadataExtractor<'a> { 13 | /// Get a value for a key from the MetadataMap. If the value can't be converted to &str, returns None 14 | fn get(&self, key: &str) -> Option<&str> { 15 | self.0.get(key).and_then(|metadata| metadata.to_str().ok()) 16 | } 17 | 18 | /// Collect all the keys from the MetadataMap. 19 | fn keys(&self) -> Vec<&str> { 20 | self.0 21 | .keys() 22 | .map(|key| match key { 23 | tonic::metadata::KeyRef::Ascii(v) => v.as_str(), 24 | tonic::metadata::KeyRef::Binary(v) => v.as_str(), 25 | }) 26 | .collect::>() 27 | } 28 | } 29 | 30 | /// Inject context in the metadata of a gRPC request. 31 | struct MetadataInjector<'a>(pub &'a mut tonic::metadata::MetadataMap); 32 | 33 | impl<'a> Injector for MetadataInjector<'a> { 34 | /// Set a key and value in the MetadataMap. Does nothing if the key or value are not valid inputs 35 | fn set(&mut self, key: &str, value: String) { 36 | if let Ok(key) = tonic::metadata::MetadataKey::from_bytes(key.as_bytes()) { 37 | if let Ok(val) = value.parse() { 38 | self.0.insert(key, val); 39 | } 40 | } 41 | } 42 | } 43 | 44 | /// Get a context from the global context and inject the span into a gRPC request's metadata. 45 | fn inject(metadata: &mut tonic::metadata::MetadataMap) { 46 | global::get_text_map_propagator(|propagator| { 47 | propagator.inject_context( 48 | &tracing::Span::current().context(), 49 | &mut MetadataInjector(metadata), 50 | ) 51 | }) 52 | } 53 | 54 | pub trait InjectTelemetryContext { 55 | fn inject_context(self) -> Self; 56 | } 57 | 58 | impl InjectTelemetryContext for tonic::Request { 59 | fn inject_context(mut self) -> Self { 60 | inject(self.metadata_mut()); 61 | self 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /router/src/adapter.rs: -------------------------------------------------------------------------------- 1 | use std::hash; 2 | 3 | use crate::AdapterParameters; 4 | 5 | use crate::server::DEFAULT_ADAPTER_SOURCE; 6 | 7 | /// "adapter ID" for the base model. The base model does not have an adapter ID, 8 | /// but we reason about it in the same way. This must match the base model ID 9 | /// used in the Python server. 
10 | pub const BASE_MODEL_ADAPTER_ID: &str = "__base_model__"; 11 | 12 | #[derive(Debug, Clone)] 13 | pub(crate) struct Adapter { 14 | /// adapter parameters 15 | params: AdapterParameters, 16 | /// source (enforced at proto level) 17 | source: String, 18 | /// index of the adapter 19 | index: u32, 20 | /// Optional - External api token 21 | api_token: Option, 22 | } 23 | 24 | impl Adapter { 25 | pub(crate) fn new( 26 | params: AdapterParameters, 27 | source: String, 28 | index: u32, 29 | api_token: Option, 30 | ) -> Self { 31 | Self { 32 | params, 33 | source, 34 | index, 35 | api_token, 36 | } 37 | } 38 | 39 | pub(crate) fn params(&self) -> &AdapterParameters { 40 | &self.params 41 | } 42 | 43 | pub(crate) fn source(&self) -> &str { 44 | &self.source 45 | } 46 | 47 | pub(crate) fn api_token(&self) -> &std::option::Option { 48 | &self.api_token 49 | } 50 | 51 | pub(crate) fn index(&self) -> u32 { 52 | self.index 53 | } 54 | 55 | pub(crate) fn as_string(&self) -> String { 56 | // format ":" 57 | format!("{}:{}", self.source, self.params.adapter_ids.join(",")) 58 | } 59 | } 60 | 61 | impl hash::Hash for Adapter { 62 | fn hash(&self, state: &mut H) { 63 | self.index.hash(state); 64 | } 65 | } 66 | 67 | impl Eq for Adapter {} 68 | 69 | impl PartialEq for Adapter { 70 | fn eq(&self, other: &Self) -> bool { 71 | self.index == other.index 72 | } 73 | } 74 | 75 | pub(crate) fn extract_adapter_params( 76 | adapter_id: Option, 77 | adapter_source: Option, 78 | adapter_parameters: Option, 79 | ) -> (Option, AdapterParameters) { 80 | let mut adapter_id = adapter_id.clone(); 81 | if adapter_id.is_none() || adapter_id.as_ref().unwrap().is_empty() { 82 | adapter_id = Some(BASE_MODEL_ADAPTER_ID.to_string()); 83 | } 84 | let mut adapter_source = adapter_source.clone(); 85 | if adapter_source.is_none() { 86 | adapter_source = Some(DEFAULT_ADAPTER_SOURCE.get().unwrap().to_string()); 87 | } 88 | 89 | let adapter_parameters = adapter_parameters.clone().unwrap_or(AdapterParameters { 90 | adapter_ids: vec![adapter_id.clone().unwrap()], 91 | ..Default::default() 92 | }); 93 | return (adapter_source, adapter_parameters); 94 | } 95 | -------------------------------------------------------------------------------- /rust-toolchain.toml: -------------------------------------------------------------------------------- 1 | [toolchain] 2 | channel = "1.83.0" 3 | components = ["rustfmt", "clippy"] 4 | -------------------------------------------------------------------------------- /sagemaker-entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ -z "${HF_MODEL_ID}" ]]; then 4 | echo "HF_MODEL_ID must be set" 5 | exit 1 6 | fi 7 | export MODEL_ID="${HF_MODEL_ID}" 8 | 9 | if [[ -n "${HF_MODEL_REVISION}" ]]; then 10 | export REVISION="${HF_MODEL_REVISION}" 11 | fi 12 | 13 | if [[ -n "${SM_NUM_GPUS}" ]]; then 14 | export NUM_SHARD="${SM_NUM_GPUS}" 15 | fi 16 | 17 | if [[ -n "${HF_MODEL_QUANTIZE}" ]]; then 18 | export QUANTIZE="${HF_MODEL_QUANTIZE}" 19 | fi 20 | 21 | if [[ -n "${HF_MODEL_TRUST_REMOTE_CODE}" ]]; then 22 | export TRUST_REMOTE_CODE="${HF_MODEL_TRUST_REMOTE_CODE}" 23 | fi 24 | 25 | lorax-launcher --port 8080 26 | -------------------------------------------------------------------------------- /server/Makefile: -------------------------------------------------------------------------------- 1 | include Makefile-flash-att 2 | include Makefile-flash-att-v2 3 | include Makefile-vllm 4 | include Makefile-megablocks 5 | include Makefile-eetq 6 | 
include Makefile-awq 7 | 8 | unit-tests: 9 | pytest -s -vv -m "not private" tests 10 | 11 | gen-server: 12 | # Compile protos 13 | pip install grpcio-tools==1.51.1 mypy-protobuf==3.4.0 'types-protobuf>=3.20.4' --no-cache-dir 14 | mkdir lorax_server/pb || true 15 | python -m grpc_tools.protoc -I../proto --python_out=lorax_server/pb \ 16 | --grpc_python_out=lorax_server/pb --mypy_out=lorax_server/pb ../proto/generate.proto 17 | find lorax_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \; 18 | touch lorax_server/pb/__init__.py 19 | 20 | install: gen-server 21 | pip install pip --upgrade 22 | pip install torch==2.6.0 23 | pip install -r requirements.txt 24 | pip install -e ".[bnb, accelerate, quantize, peft, outlines]" 25 | 26 | run-dev: 27 | # SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 lorax_server/cli.py serve gpt2 28 | # SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 lorax_server/cli.py serve meta-llama/Llama-2-7b-hf --sharded 29 | SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 lorax_server/cli.py serve mistralai/Mistral-7B-Instruct-v0.1 --sharded 30 | # SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 lorax_server/cli.py serve mistralai/Mistral-7B-v0.1 --sharded 31 | # SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 lorax_server/cli.py serve flozi00/Mistral-7B-german-assistant-v5-4bit-autogptq --quantize gptq 32 | 33 | export-requirements: 34 | poetry export -o requirements.txt --without-hashes 35 | 36 | format: 37 | pip install ruff 38 | python -m ruff format lorax_server 39 | python -m ruff check lorax_server --fix 40 | -------------------------------------------------------------------------------- /server/Makefile-awq: -------------------------------------------------------------------------------- 1 | # Fork that adds only the correct stream to this kernel in order 2 | # to make cuda graphs work. 
3 | awq_commit := f084f40bd996f3cf3a0633c1ad7d9d476c318aaa 4 | 5 | awq: 6 | rm -rf llm-awq 7 | git clone https://github.com/mit-han-lab/llm-awq 8 | 9 | build-awq: awq 10 | cd llm-awq/ && git fetch && git checkout $(awq_commit) 11 | cd llm-awq/awq/kernels && python setup.py build 12 | 13 | install-awq: build-awq 14 | pip uninstall awq_inference_engine -y || true 15 | cd llm-awq/awq/kernels && python setup.py install 16 | -------------------------------------------------------------------------------- /server/Makefile-eetq: -------------------------------------------------------------------------------- 1 | eetq_commit := 1657b1504faa359e2ce0ac02999439d7ac8c74c0 2 | 3 | eetq: 4 | # Clone eetq 5 | pip install packaging 6 | git clone https://github.com/NetEase-FuXi/EETQ.git eetq 7 | 8 | build-eetq: eetq 9 | cd eetq && git fetch && git checkout $(eetq_commit) && git submodule update --init --recursive 10 | cd eetq && python setup.py build 11 | 12 | install-eetq: build-eetq 13 | cd eetq && python setup.py install 14 | -------------------------------------------------------------------------------- /server/Makefile-flash-att: -------------------------------------------------------------------------------- 1 | flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec 2 | 3 | flash-attention: 4 | # Clone flash attention 5 | pip install -U packaging ninja --no-cache-dir 6 | git clone https://github.com/HazyResearch/flash-attention.git 7 | 8 | build-flash-attention: flash-attention 9 | cd flash-attention && git fetch && git checkout $(flash_att_commit) 10 | cd flash-attention && python setup.py build 11 | cd flash-attention/csrc/rotary && python setup.py build 12 | cd flash-attention/csrc/layer_norm && python setup.py build 13 | 14 | install-flash-attention: build-flash-attention 15 | pip uninstall flash_attn rotary_emb dropout_layer_norm -y || true 16 | cd flash-attention && python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install 17 | -------------------------------------------------------------------------------- /server/Makefile-flash-att-v2: -------------------------------------------------------------------------------- 1 | flash_att_v2_commit_cuda := v2.5.8 2 | flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6 3 | 4 | 5 | flash-attention-v2-cuda: 6 | # Clone flash attention 7 | pip install -U packaging ninja --no-cache-dir 8 | git clone https://github.com/Dao-AILab/flash-attention.git flash-attention-v2 9 | 10 | build-flash-attention-v2-cuda: flash-attention-v2-cuda 11 | cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_cuda) 12 | cd flash-attention-v2 && git submodule update --init --recursive 13 | cd flash-attention-v2 && python setup.py build 14 | 15 | install-flash-attention-v2-cuda: build-flash-attention-v2-cuda 16 | cd flash-attention-v2 && git submodule update --init --recursive && python setup.py install 17 | 18 | flash-attention-v2-rocm: 19 | # Clone flash attention 20 | pip install -U packaging ninja --no-cache-dir 21 | git clone https://github.com/ROCm/flash-attention.git flash-attention-v2 22 | 23 | build-flash-attention-v2-rocm: flash-attention-v2-rocm 24 | cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) 25 | cd flash-attention-v2 && git submodule update --init --recursive 26 | cd flash-attention-v2 && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build 27 | 28 | install-flash-attention-v2-rocm: 
build-flash-attention-v2-rocm 29 | cd flash-attention-v2 && git submodule update --init --recursive && python setup.py install 30 | -------------------------------------------------------------------------------- /server/Makefile-megablocks: -------------------------------------------------------------------------------- 1 | megablocks_commit := 5897cd6f254b7b3edf7a708a3a3314ecb54b6f78 2 | 3 | megablocks: 4 | git clone https://github.com/stanford-futuredata/megablocks.git 5 | 6 | build-megablocks: megablocks 7 | cd megablocks && git fetch && git checkout $(megablocks_commit) 8 | cd megablocks && python setup.py build 9 | -------------------------------------------------------------------------------- /server/Makefile-vllm: -------------------------------------------------------------------------------- 1 | vllm-cuda: 2 | # Clone vllm 3 | pip install -U ninja packaging --no-cache-dir 4 | git clone https://github.com/vllm-project/vllm.git vllm 5 | 6 | build-vllm-cuda: vllm-cuda 7 | cd vllm && git fetch && git checkout 766435e660a786933392eb8ef0a873bc38cf0c8b 8 | cd vllm && python setup.py build 9 | 10 | install-vllm-cuda: build-vllm-cuda 11 | pip uninstall vllm -y || true 12 | cd vllm && python setup.py install 13 | 14 | vllm-rocm: 15 | # Clone vllm 16 | pip install -U ninja packaging --no-cache-dir 17 | git clone https://github.com/fxmarty/rocm-vllm.git vllm 18 | 19 | build-vllm-rocm: vllm-rocm 20 | cd vllm && git fetch && git checkout ca6913b3c2ffacdcb7d15e914dc34adbc6c89479 21 | cd vllm && PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install 22 | 23 | install-vllm-rocm: build-vllm-rocm 24 | pip uninstall vllm -y || true 25 | cd vllm && python setup.py install 26 | -------------------------------------------------------------------------------- /server/README.md: -------------------------------------------------------------------------------- 1 | # LoRAX Python gRPC Server 2 | 3 | A Python gRPC server for LoRAX 4 | 5 | ## Install 6 | 7 | ```shell 8 | make install 9 | ``` 10 | 11 | ## Run 12 | 13 | ```shell 14 | make run-dev 15 | ``` -------------------------------------------------------------------------------- /server/custom_kernels/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name="custom_kernels", 6 | ext_modules=[ 7 | CUDAExtension( 8 | name="custom_kernels.fused_bloom_attention_cuda", 9 | sources=["custom_kernels/fused_bloom_attention_cuda.cu"], 10 | extra_compile_args=["-arch=compute_80", "-std=c++17"], 11 | ), 12 | CUDAExtension( 13 | name="custom_kernels.fused_attention_cuda", 14 | sources=["custom_kernels/fused_attention_cuda.cu"], 15 | extra_compile_args=["-arch=compute_80", "-std=c++17"], 16 | ), 17 | ], 18 | cmdclass={"build_ext": BuildExtension}, 19 | ) 20 | -------------------------------------------------------------------------------- /server/exllama_kernels/exllama_kernels/cu_compat.cuh: -------------------------------------------------------------------------------- 1 | // Adapted from turboderp exllama: https://github.com/turboderp/exllama 2 | 3 | #ifndef _cuda_compat_cuh 4 | #define _cuda_compat_cuh 5 | 6 | // atomicAdd for half types, to support CC < 7.x 7 | 8 | __device__ __forceinline__ void atomicAdd_half(half* address, half val) 9 | { 10 | unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); 11 | unsigned int old = *address_as_ui; 12 | unsigned int 
assumed; 13 | 14 | do 15 | { 16 | assumed = old; 17 | __half_raw hsum; 18 | hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); 19 | half tmpres = __hadd(hsum, val); 20 | hsum = __half_raw(tmpres); 21 | old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; 22 | old = atomicCAS(address_as_ui, assumed, old); 23 | } 24 | while (assumed != old); 25 | } 26 | 27 | // atomicAdd for half2 types 28 | 29 | __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) 30 | { 31 | unsigned int* address_as_ui = (unsigned int*)address; 32 | unsigned int old = *address_as_ui; 33 | unsigned int assumed; 34 | do 35 | { 36 | assumed = old; 37 | half2 old_val = *((half2*)&old); 38 | half2 new_val = __hadd2(old_val, val); 39 | old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); 40 | } 41 | while (assumed != old); 42 | } 43 | 44 | // 45 | 46 | #if defined(__CUDA_ARCH__) || defined(USE_ROCM) 47 | #if __CUDA_ARCH__ < 700 || defined(USE_ROCM) 48 | 49 | __device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } 50 | 51 | #if __CUDA_ARCH__ < 600 || defined(USE_ROCM) 52 | __device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } 53 | #endif 54 | 55 | #endif 56 | #endif 57 | 58 | #endif 59 | -------------------------------------------------------------------------------- /server/exllama_kernels/exllama_kernels/cuda_buffers.cu: -------------------------------------------------------------------------------- 1 | // Adapted from turboderp exllama: https://github.com/turboderp/exllama 2 | 3 | #define _cuda_buffers_cu 4 | #include "cuda_buffers.cuh" 5 | 6 | CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL}; 7 | // __constant__ half2 q4_table[16][256]; 8 | // half2 q4_table_host[16][256]; 9 | // bool q4_table_init = false; 10 | 11 | CudaBuffers::CudaBuffers 12 | ( 13 | int _device, 14 | half* _temp_state, 15 | half* _temp_dq 16 | ) : 17 | device(_device), 18 | temp_state(_temp_state), 19 | temp_dq(_temp_dq) 20 | { 21 | cudaSetDevice(_device); 22 | 23 | cudaStreamCreate(&alt_stream_1); 24 | cudaStreamCreate(&alt_stream_2); 25 | cudaStreamCreate(&alt_stream_3); 26 | cudaEventCreate(&alt_stream_1_done); 27 | cudaEventCreate(&alt_stream_2_done); 28 | cudaEventCreate(&alt_stream_3_done); 29 | } 30 | 31 | CudaBuffers::~CudaBuffers() 32 | { 33 | cudaStreamDestroy(alt_stream_1); 34 | cudaStreamDestroy(alt_stream_2); 35 | cudaStreamDestroy(alt_stream_3); 36 | cudaEventDestroy(alt_stream_1_done); 37 | cudaEventDestroy(alt_stream_2_done); 38 | cudaEventDestroy(alt_stream_3_done); 39 | } 40 | 41 | CudaBuffers* get_buffers(const int device_index) 42 | { 43 | return g_buffers[device_index]; 44 | } 45 | 46 | void prepare_buffers_cuda 47 | ( 48 | int _device, 49 | half* _temp_state, 50 | half* _temp_dq 51 | ) 52 | { 53 | CudaBuffers* buffers = new CudaBuffers 54 | ( 55 | _device, 56 | _temp_state, 57 | _temp_dq 58 | ); 59 | 60 | g_buffers[_device] = buffers; 61 | } 62 | 63 | void cleanup_buffers_cuda() 64 | { 65 | for (int i = 0; i < CUDA_MAX_DEVICES; i++) 66 | { 67 | if (!g_buffers[i]) continue; 68 | delete g_buffers[i]; 69 | g_buffers[i] = NULL; 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /server/exllama_kernels/exllama_kernels/cuda_buffers.cuh: -------------------------------------------------------------------------------- 1 | // Adapted from turboderp exllama: https://github.com/turboderp/exllama 2 | 3 | #ifndef 
_cuda_buffers_cuh 4 | #define _cuda_buffers_cuh 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | const int CUDA_MAX_DEVICES = 16; 12 | 13 | // #ifndef _cuda_buffers_cu 14 | // extern __constant__ half2 q4_table[16][256]; 15 | // #endif 16 | 17 | class CudaBuffers 18 | { 19 | public: 20 | int device; 21 | 22 | half* temp_state; // [max_hidden_rows * intermediate_size] 23 | half* temp_dq; // size of largest quant tensor * 8 24 | 25 | cudaStream_t alt_stream_1; 26 | cudaStream_t alt_stream_2; 27 | cudaStream_t alt_stream_3; 28 | cudaEvent_t alt_stream_1_done; 29 | cudaEvent_t alt_stream_2_done; 30 | cudaEvent_t alt_stream_3_done; 31 | 32 | CudaBuffers 33 | ( 34 | int _device, 35 | half* _temp_state, 36 | half* _temp_dq 37 | ); 38 | ~CudaBuffers(); 39 | }; 40 | 41 | CudaBuffers* get_buffers(const int device_index); 42 | 43 | void prepare_buffers_cuda 44 | ( 45 | int _device, 46 | half* _temp_state, 47 | half* _temp_dq 48 | ); 49 | 50 | void cleanup_buffers_cuda(); 51 | 52 | #endif 53 | -------------------------------------------------------------------------------- /server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cu: -------------------------------------------------------------------------------- 1 | // Adapted from turboderp exllama: https://github.com/turboderp/exllama 2 | 3 | #include "column_remap.cuh" 4 | #include "../util.cuh" 5 | 6 | const int SHUF_BLOCKSIZE_X = 256; 7 | const int SHUF_BLOCKSIZE_Y = 16; 8 | 9 | __global__ void column_remap_kernel 10 | ( 11 | const half* __restrict__ x, 12 | half* __restrict__ x_new, 13 | const int x_width, 14 | const int x_height, 15 | const uint32_t* x_map 16 | ) 17 | { 18 | int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x; 19 | int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y; 20 | 21 | int x_stride = x_width; 22 | int x_idx = x_row * x_stride + x_column; 23 | 24 | int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height); 25 | int x_idx_end = x_row_end * x_stride + x_column; 26 | 27 | int s_column = x_map[x_column]; 28 | int s_idx = x_row * x_stride + s_column; 29 | 30 | while (x_idx < x_idx_end) 31 | { 32 | x_new[x_idx] = x[s_idx]; 33 | x_idx += x_stride; 34 | s_idx += x_stride; 35 | } 36 | } 37 | 38 | // Remap columns in x to correspond to sequential group index before matmul 39 | // 40 | // perform x -> seq_x such that seq_x @ seq_w == x @ w 41 | 42 | void column_remap_cuda 43 | ( 44 | const half* x, 45 | half* x_new, 46 | const int x_height, 47 | const int x_width, 48 | const uint32_t* x_map 49 | ) 50 | { 51 | dim3 threads(SHUF_BLOCKSIZE_X, 1, 1); 52 | 53 | dim3 blocks 54 | ( 55 | (x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X, 56 | (x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y, 57 | 1 58 | ); 59 | 60 | column_remap_kernel<<>>(x, x_new, x_width, x_height, x_map); 61 | } 62 | -------------------------------------------------------------------------------- /server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cuh: -------------------------------------------------------------------------------- 1 | // Adapted from turboderp exllama: https://github.com/turboderp/exllama 2 | 3 | #ifndef _column_remap_cuh 4 | #define _column_remap_cuh 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | void column_remap_cuda 11 | ( 12 | const half* x, 13 | half* x_new, 14 | const int x_height, 15 | const int x_width, 16 | const uint32_t* x_map 17 | ); 18 | 19 | #endif 20 | -------------------------------------------------------------------------------- /server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cuh: 
-------------------------------------------------------------------------------- 1 | // Adapted from turboderp exllama: https://github.com/turboderp/exllama 2 | 3 | #ifndef _q4_matmul_cuh 4 | #define _q4_matmul_cuh 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include "q4_matrix.cuh" 13 | #include "../tuning.h" 14 | 15 | void q4_matmul_cuda 16 | ( 17 | ExLlamaTuning* tuningParams, 18 | const half* x, 19 | const int x_height, 20 | const Q4Matrix* w, 21 | half* out, 22 | bool no_zero, 23 | cudaStream_t alt_stream 24 | ); 25 | 26 | void q4_matmul_recons_cuda 27 | ( 28 | ExLlamaTuning* tuningParams, 29 | const half* x, 30 | const int x_height, 31 | Q4Matrix* w, 32 | half* out, 33 | bool no_zero, 34 | const cublasHandle_t handle 35 | ); 36 | 37 | #endif 38 | -------------------------------------------------------------------------------- /server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cuh: -------------------------------------------------------------------------------- 1 | // Adapted from turboderp exllama: https://github.com/turboderp/exllama 2 | 3 | #ifndef _q4_matrix_cuh 4 | #define _q4_matrix_cuh 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | class Q4Matrix 11 | { 12 | public: 13 | 14 | int device; 15 | 16 | int height; 17 | int width; 18 | int groups; 19 | int groupsize; 20 | 21 | uint32_t* cuda_qweight = NULL; 22 | uint32_t* cuda_qzeros = NULL; 23 | half* cuda_scales = NULL; 24 | uint32_t* cuda_x_map = NULL; 25 | 26 | Q4Matrix 27 | ( 28 | const int _height, 29 | const int _width, 30 | const int _groups, 31 | 32 | uint32_t* _qweight, 33 | uint32_t* _qzeros, 34 | half* _scales, 35 | uint32_t* _g_idx, 36 | 37 | const int _device 38 | ); 39 | 40 | ~Q4Matrix(); 41 | 42 | void reconstruct(half* out); 43 | 44 | private: 45 | 46 | void make_sequential(const uint32_t* cpu_g_idx); 47 | 48 | }; 49 | 50 | void g_q4_keep_matrix(Q4Matrix* m); 51 | void g_q4_free_matrices(); 52 | 53 | #endif 54 | -------------------------------------------------------------------------------- /server/exllama_kernels/exllama_kernels/hip_compat.cuh: -------------------------------------------------------------------------------- 1 | // Adapted from turboderp exllama: https://github.com/turboderp/exllama 2 | 3 | #ifndef _hip_compat_cuh 4 | #define _hip_compat_cuh 5 | 6 | // Workaround for a bug in hipamd, backported from upstream, this is fixed in ROCm 5.6. 7 | __device__ __forceinline__ __half __compat_hrcp(__half x) { 8 | return __half_raw{ 9 | static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))}; 10 | } 11 | 12 | __device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) { 13 | return _Float16_2{ 14 | _Float16_2{static_cast<_Float16>(1.0f), 15 | static_cast<_Float16>(1.0f)} / x.data}; 16 | } 17 | 18 | #define hrcp __compat_hrcp 19 | #define h2rcp __compat_h2rcp 20 | 21 | // Automatic conversion of hipblasHgemm doesn't convert half to hipblasHalf. 
22 | __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle, 23 | hipblasOperation_t transA, 24 | hipblasOperation_t transB, 25 | int m, 26 | int n, 27 | int k, 28 | const half* alpha, 29 | const half* AP, 30 | int lda, 31 | const half* BP, 32 | int ldb, 33 | const half* beta, 34 | half* CP, 35 | int ldc) { 36 | return hipblasHgemm(handle, transA, transB, m, n, k, 37 | reinterpret_cast(alpha), 38 | reinterpret_cast(AP), lda, 39 | reinterpret_cast(BP), ldb, 40 | reinterpret_cast(beta), 41 | reinterpret_cast(CP), ldc); 42 | } 43 | #define hipblasHgemm __compat_hipblasHgemm 44 | 45 | // Previous version of PyTorch were converting to rocBLAS instead of hipBLAS. 46 | #define rocblas_handle hipblasHandle_t 47 | #define rocblas_operation_none HIPBLAS_OP_N 48 | #define rocblas_get_stream hipblasGetStream 49 | #define rocblas_set_stream hipblasSetStream 50 | #define rocblas_hgemm __compat_hipblasHgemm 51 | 52 | #endif 53 | -------------------------------------------------------------------------------- /server/exllama_kernels/exllama_kernels/tuning.h: -------------------------------------------------------------------------------- 1 | // Adapted from turboderp exllama: https://github.com/turboderp/exllama 2 | 3 | #ifndef _tuning_h 4 | #define _tuning_h 5 | 6 | struct ExLlamaTuning 7 | { 8 | int matmul_recons_thd; 9 | bool matmul_fused_remap; 10 | bool matmul_no_half2; 11 | }; 12 | 13 | #endif 14 | -------------------------------------------------------------------------------- /server/exllama_kernels/exllama_kernels/util.cuh: -------------------------------------------------------------------------------- 1 | // Adapted from turboderp exllama: https://github.com/turboderp/exllama 2 | 3 | #ifndef _util_cuh 4 | #define _util_cuh 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #if defined(USE_ROCM) 12 | #define cudaUnspecified hipErrorUnknown 13 | #else 14 | #define cudaUnspecified cudaErrorApiFailureBase 15 | #endif 16 | 17 | // React to failure on return code != cudaSuccess 18 | 19 | #define _cuda_check(fn) \ 20 | do { \ 21 | {_cuda_err = fn;} \ 22 | if (_cuda_err != cudaSuccess) goto _cuda_fail; \ 23 | } while(false) 24 | 25 | // React to failure on return code == 0 26 | 27 | #define _alloc_check(fn) \ 28 | do { \ 29 | if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \ 30 | else _cuda_err = cudaSuccess; \ 31 | } while(false) 32 | 33 | #endif 34 | -------------------------------------------------------------------------------- /server/exllama_kernels/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name="exllama_kernels", 6 | ext_modules=[ 7 | CUDAExtension( 8 | name="exllama_kernels", 9 | sources=[ 10 | "exllama_kernels/exllama_ext.cpp", 11 | "exllama_kernels/cuda_buffers.cu", 12 | "exllama_kernels/cuda_func/column_remap.cu", 13 | "exllama_kernels/cuda_func/q4_matmul.cu", 14 | "exllama_kernels/cuda_func/q4_matrix.cu", 15 | ], 16 | ) 17 | ], 18 | cmdclass={"build_ext": BuildExtension}, 19 | ) 20 | -------------------------------------------------------------------------------- /server/exllamav2_kernels/exllamav2_kernels/config.h: -------------------------------------------------------------------------------- 1 | #ifndef _config_h 2 | #define _config_h 3 | 4 | #define MAX_Q_GEMM_ROWS 50 5 | #define MAX_Q_GEMM_WEIGHTS 4 // must be <= MAX_Q_GEMM_ROWS 6 | 7 | #define QMODE_2BIT 1 
8 | #define QMODE_3BIT 1 9 | #define QMODE_4BIT 1 10 | #define QMODE_5BIT 1 11 | #define QMODE_6BIT 0 12 | #define QMODE_8BIT 0 13 | 14 | 15 | #endif 16 | -------------------------------------------------------------------------------- /server/exllamav2_kernels/exllamav2_kernels/cpp/util.h: -------------------------------------------------------------------------------- 1 | #ifndef _util_h 2 | #define _util_h 3 | 4 | #define DBGS(__x) printf("%s\n", __x) 5 | #define DBGI(__x) printf("%s: %i\n", #__x, __x) 6 | #define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y) 7 | #define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z) 8 | #define DBGF(__x) printf("%s: %f\n", #__x, __x) 9 | #define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y) 10 | #define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z) 11 | 12 | #endif 13 | -------------------------------------------------------------------------------- /server/exllamav2_kernels/exllamav2_kernels/cuda/compat.cuh: -------------------------------------------------------------------------------- 1 | #ifndef _compat_cuh 2 | #define _compat_cuh 3 | 4 | // atomicAdd for half types, to support CC < 7.x 5 | 6 | __device__ __forceinline__ void atomicAdd_half(half* address, half val) 7 | { 8 | unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); 9 | unsigned int old = *address_as_ui; 10 | unsigned int assumed; 11 | 12 | do 13 | { 14 | assumed = old; 15 | __half_raw hsum; 16 | hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); 17 | half tmpres = __hadd(hsum, val); 18 | hsum = __half_raw(tmpres); 19 | old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; 20 | old = atomicCAS(address_as_ui, assumed, old); 21 | } 22 | while (assumed != old); 23 | } 24 | 25 | // atomicAdd for half2 types 26 | 27 | __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) 28 | { 29 | unsigned int* address_as_ui = (unsigned int*)address; 30 | unsigned int old = *address_as_ui; 31 | unsigned int assumed; 32 | do 33 | { 34 | assumed = old; 35 | half2 old_val = *((half2*)&old); 36 | half2 new_val = __hadd2(old_val, val); 37 | old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); 38 | } 39 | while (assumed != old); 40 | } 41 | 42 | // 43 | 44 | #if defined(__CUDA_ARCH__) || defined(USE_ROCM) 45 | #if __CUDA_ARCH__ < 700 || defined(USE_ROCM) 46 | 47 | __device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } 48 | 49 | #if __CUDA_ARCH__ < 600 || defined(USE_ROCM) 50 | __device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } 51 | #endif 52 | 53 | #endif 54 | #endif 55 | 56 | #endif 57 | -------------------------------------------------------------------------------- /server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cuh: -------------------------------------------------------------------------------- 1 | #ifndef _q_gemm_cuh 2 | #define _q_gemm_cuh 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "q_matrix.cuh" 11 | 12 | void gemm_half_q_half_cuda 13 | ( 14 | cublasHandle_t cublas_handle, 15 | const half* a, 16 | QMatrix* b, 17 | half* c, 18 | int size_m, 19 | int size_n, 20 | int size_k, 21 | bool clear = false, 22 | half* reconstruct = NULL, 23 | bool force_cuda = false, 24 | const half* r_weights = NULL, 25 | const int 
r_weights_stride = 0, 26 | bool mul_r_weights = false 27 | ); 28 | 29 | void clear_tensor_cuda 30 | ( 31 | half* c, 32 | int size_m, 33 | int size_n 34 | ); 35 | 36 | #endif 37 | -------------------------------------------------------------------------------- /server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cuh: -------------------------------------------------------------------------------- 1 | #ifndef _q_matrix_cuh 2 | #define _q_matrix_cuh 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #define MAX_SUPERGROUPS 16 10 | 11 | class QMatrix 12 | { 13 | public: 14 | 15 | int device; 16 | bool is_gptq; 17 | 18 | int height; 19 | int width; 20 | int groups; 21 | int gptq_groupsize; 22 | 23 | int rows_8; 24 | int rows_6; 25 | int rows_5; 26 | int rows_4; 27 | int rows_3; 28 | int rows_2; 29 | 30 | uint32_t* cuda_q_weight = NULL; 31 | uint16_t* cuda_q_perm = NULL; 32 | uint16_t* cuda_q_invperm = NULL; 33 | uint32_t* cuda_q_scale = NULL; 34 | half* cuda_q_scale_max = NULL; 35 | uint16_t* cuda_q_groups = NULL; 36 | uint16_t* cuda_q_group_map = NULL; 37 | uint32_t* cuda_gptq_qzeros = NULL; 38 | half* cuda_gptq_scales = NULL; 39 | 40 | half* temp_dq; 41 | 42 | bool failed; 43 | 44 | QMatrix 45 | ( 46 | const int _device, 47 | const int _height, 48 | const int _width, 49 | const int _groups, 50 | 51 | uint32_t* _q_weight, 52 | uint16_t* _q_perm, 53 | uint16_t* _q_invperm, 54 | uint32_t* _q_scale, 55 | half* _q_scale_max, 56 | uint16_t* _q_groups, 57 | uint16_t* _q_group_map, 58 | 59 | uint32_t* _gptq_qzeros, 60 | half* _gptq_scales, 61 | uint32_t* _gptq_g_idx, 62 | 63 | half* _temp_dq 64 | ); 65 | 66 | ~QMatrix(); 67 | 68 | void reconstruct(half* out); 69 | bool make_sequential(const uint32_t* cpu_g_idx); 70 | 71 | private: 72 | 73 | }; 74 | 75 | #endif 76 | -------------------------------------------------------------------------------- /server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_6.cuh: -------------------------------------------------------------------------------- 1 | #ifndef _qdq_6_cuh 2 | #define _qdq_6_cuh 3 | 4 | #include "qdq_util.cuh" 5 | #include "../../config.h" 6 | 7 | #if QMODE_6BIT == 1 8 | 9 | // Not implemented 10 | 11 | #else 12 | 13 | __forceinline__ __device__ void shuffle_6bit_16 14 | ( 15 | uint32_t* q, 16 | int stride 17 | ) 18 | { 19 | } 20 | 21 | __forceinline__ __device__ void dequant_6bit_16 22 | ( 23 | const uint32_t q_0, 24 | const uint32_t q_1, 25 | const uint32_t q_2, 26 | half2 (&dq)[8], 27 | int stride 28 | ) 29 | { 30 | half dqh[16]; 31 | for (int i = 0; i < 5; i++) dqh[ i] = dq_ns(exb( q_0, i * 6 , 0x3f), 32); 32 | dqh[ 5 ] = dq_ns(exb(q_1, q_0, 30, 0x3f), 32); 33 | for (int i = 0; i < 4; i++) dqh[ 6 + i] = dq_ns(exb( q_1, i * 6 + 4, 0x3f), 32); 34 | dqh[10 ] = dq_ns(exb(q_2, q_1, 28, 0x3f), 32); 35 | for (int i = 0; i < 5; i++) dqh[11 + i] = dq_ns(exb( q_2, i * 6 + 2, 0x3f), 32); 36 | 37 | for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); 38 | } 39 | 40 | #endif 41 | 42 | #endif 43 | -------------------------------------------------------------------------------- /server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_8.cuh: -------------------------------------------------------------------------------- 1 | #ifndef _qdq_8_cuh 2 | #define _qdq_8_cuh 3 | 4 | #include "qdq_util.cuh" 5 | #include "../../config.h" 6 | 7 | #if QMODE_8BIT == 1 8 | 9 | // Not implemented 10 | 11 | #else 12 | 13 | __forceinline__ __device__ void shuffle_8bit_4 14 | ( 15 | uint32_t* q, 16 | int stride 17 | ) 18 | { 19 | 
} 20 | 21 | __forceinline__ __device__ void dequant_8bit_8 22 | ( 23 | const uint32_t q_0, 24 | const uint32_t q_1, 25 | half2 (&dq)[4], 26 | int stride 27 | ) 28 | { 29 | half dqh[8]; 30 | for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), 128); 31 | for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), 128); 32 | 33 | for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); 34 | } 35 | 36 | #endif 37 | 38 | #endif 39 | -------------------------------------------------------------------------------- /server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_util.cuh: -------------------------------------------------------------------------------- 1 | #ifndef _qdq_util_cuh 2 | #define _qdq_util_cuh 3 | 4 | union half2_uint32 5 | { 6 | uint32_t as_uint32; 7 | half2 as_half2; 8 | __device__ half2_uint32(uint32_t val) : as_uint32(val) {} 9 | __device__ half2_uint32(half2 val) : as_half2(val) {} 10 | __device__ half2_uint32() : as_uint32(0) {} 11 | }; 12 | 13 | union half_uint16 14 | { 15 | uint16_t as_uint16; 16 | half as_half; 17 | __device__ half_uint16(uint16_t val) : as_uint16(val) {} 18 | __device__ half_uint16(half val) : as_half(val) {} 19 | __device__ half_uint16() : as_uint16(0) {} 20 | }; 21 | 22 | // Max_scale premultiplied by 1/256 23 | 24 | __forceinline__ __device__ half dq_scale(const int qs, const half max_scale) 25 | { 26 | int qs_i = qs + 1; 27 | half qs_h = __int2half_rn(qs_i * qs_i); 28 | qs_h = __hmul(qs_h, max_scale); 29 | return qs_h; 30 | } 31 | 32 | __forceinline__ __device__ half dq(const int q, const int qzero, const half scale) 33 | { 34 | return __hmul(__int2half_rn(q - qzero), scale); 35 | } 36 | 37 | __forceinline__ __device__ half dq_ns(const int q, const int qzero) 38 | { 39 | //return __hsub(__int2half_rn(q), __int2half_rn(qzero)); 40 | return __int2half_rn(q - qzero); 41 | } 42 | 43 | __forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask) 44 | { 45 | return (int)((q >> shift) & mask); 46 | } 47 | 48 | __forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask) 49 | { 50 | return (int)(__funnelshift_rc(q0, q1, shift) & mask); 51 | } 52 | 53 | #endif 54 | -------------------------------------------------------------------------------- /server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh: -------------------------------------------------------------------------------- 1 | #ifndef _util_cuh 2 | #define _util_cuh 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #define DIVIDE(x, size) (((x) + (size) - 1) / (size)) 11 | 12 | #define DBGS(__x) printf("%s\n", __x) 13 | #define DBGI(__x) printf("%s: %i\n", #__x, __x) 14 | #define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y) 15 | #define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z) 16 | #define DBGX(__x) printf("%s: %x\n", #__x, __x) 17 | #define DBGX2(__x, __y) printf("%s, %s: %x, %x\n", #__x, #__y, __x, __y) 18 | #define DBGX3(__x, __y, __z) printf("%s, %s, %s: %x, %x, %x\n", #__x, #__y, #__z, __x, __y, __z) 19 | #define DBGF(__x) printf("%s: %f\n", #__x, __x) 20 | #define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y) 21 | #define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z) 22 | #define DBGH(__x) printf("%s: %f\n", #__x, __half2float(__x)) 23 | #define DBGH2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __half2float(__x), 
__half2float(__y)) 24 | #define DBGH3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __half2float(__x), __half2float(__y), __half2float(__z)) 25 | 26 | #define DBGIH(__x, __y) printf("%s, %s: %i, %f\n", #__x, #__y, __x, __half2float(__y)) 27 | #define DBGIH2(__x, __y, __z) printf("%s, %s, %s: %i, %f, %f\n", #__x, #__y, #__z, __x, __half2float(__y), __half2float(__z)) 28 | 29 | __forceinline__ __device__ half dq_scale_(const int qs, const half max_scale) 30 | { 31 | half qs_h = __hmul(__int2half_rn(qs + 1), __float2half_rn(1.0f / 16.0f)); 32 | qs_h = __hmul(qs_h, qs_h); 33 | qs_h = __hmul(qs_h, max_scale); 34 | return qs_h; 35 | } 36 | 37 | __forceinline__ __device__ float clamp(float x, float a, float b) 38 | { 39 | return fmaxf(a, fminf(b, x)); 40 | } 41 | 42 | #define cuda_check(ans) { gpu_assert((ans), __FILE__, __LINE__); } 43 | inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort=true) 44 | { 45 | if (code != cudaSuccess) 46 | { 47 | fprintf(stderr,"CUDA error: %s %s %d\n", cudaGetErrorString(code), file, line); 48 | if (abort) exit(code); 49 | } 50 | } 51 | 52 | void print_global_mem(const half* ptr, int rows, int columns, int stride); 53 | 54 | #endif 55 | -------------------------------------------------------------------------------- /server/exllamav2_kernels/setup.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | extra_cuda_cflags = ["-lineinfo", "-O3"] 6 | 7 | if torch.version.hip: 8 | extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF"] 9 | 10 | extra_compile_args = { 11 | "nvcc": extra_cuda_cflags, 12 | } 13 | 14 | setup( 15 | name="exllamav2_kernels", 16 | ext_modules=[ 17 | CUDAExtension( 18 | name="exllamav2_kernels", 19 | sources=[ 20 | "exllamav2_kernels/ext.cpp", 21 | "exllamav2_kernels/cuda/q_matrix.cu", 22 | "exllamav2_kernels/cuda/q_gemm.cu", 23 | ], 24 | extra_compile_args=extra_compile_args, 25 | ) 26 | ], 27 | cmdclass={"build_ext": BuildExtension}, 28 | ) 29 | -------------------------------------------------------------------------------- /server/lorax_server/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/predibase/lorax/b5a9e38dc9479ca664bbaac9ae949e27e3c30832/server/lorax_server/__init__.py -------------------------------------------------------------------------------- /server/lorax_server/adapters/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import Dict, Optional 4 | 5 | from lorax_server.adapters.config import AdapterConfig 6 | from lorax_server.adapters.lora import LoraConfig 7 | from lorax_server.adapters.medusa import MedusaConfig 8 | from lorax_server.adapters.medusa_lora import MedusaLoraConfig 9 | from lorax_server.adapters.weights import AdapterBatchData, AdapterBatchMetadata 10 | 11 | 12 | def load_medusa_config(config_path: Optional[Path]) -> Optional[Dict]: 13 | if config_path is not None and config_path.exists(): 14 | config = json.load(config_path.open()) 15 | if "medusa_num_heads" in config: 16 | return config 17 | return None 18 | 19 | 20 | def load_adapter_config( 21 | config_path: Optional[Path], 22 | adapter_config_path: Optional[Path], 23 | api_token: str, 24 | ) -> AdapterConfig: 25 | medusa_config = load_medusa_config(config_path) 26 | if 
adapter_config_path is not None and adapter_config_path.exists(): 27 | if medusa_config is not None: 28 | return MedusaLoraConfig.load(str(adapter_config_path.parent), medusa_config, api_token) 29 | else: 30 | return LoraConfig.load(str(adapter_config_path.parent), api_token) 31 | 32 | if medusa_config is not None: 33 | return MedusaConfig.load(medusa_config) 34 | 35 | raise ValueError(f"No valid adapter config file found: " f"tried {adapter_config_path} and {config_path}") 36 | 37 | 38 | __all__ = [ 39 | "AdapterBatchData", 40 | "AdapterBatchMetadata", 41 | "load_adapter_config", 42 | ] 43 | -------------------------------------------------------------------------------- /server/lorax_server/adapters/config.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass 3 | from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple 4 | 5 | import torch 6 | 7 | from lorax_server.adapters.weights import AdapterWeights 8 | 9 | if TYPE_CHECKING: 10 | from lorax_server.models.model import Model 11 | 12 | 13 | ModuleMap = Dict[str, Dict[str, Tuple[torch.Tensor, str]]] 14 | 15 | 16 | @dataclass 17 | class AdapterConfig(ABC): 18 | base_model_name_or_path: str 19 | 20 | @abstractmethod 21 | def map_weights_for_model( 22 | self, 23 | adapter_weights: Dict, 24 | weight_names: Tuple[str], 25 | ) -> Tuple[ModuleMap, Set[str]]: 26 | pass 27 | 28 | @abstractmethod 29 | def load_batched_adapter_weights( 30 | self, 31 | model: "Model", 32 | module_map: Dict[str, Dict], 33 | layer_type: str, 34 | unused_weight_names: Set[str], 35 | dynamic: bool, 36 | ) -> Optional[AdapterWeights]: 37 | pass 38 | -------------------------------------------------------------------------------- /server/lorax_server/adapters/types.py: -------------------------------------------------------------------------------- 1 | LORA = "lora" 2 | MEDUSA = "medusa" 3 | -------------------------------------------------------------------------------- /server/lorax_server/adapters/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | 4 | from lorax_server.utils.sources import HUB, PBASE, S3, get_model_source, map_pbase_model_id_to_s3 5 | from lorax_server.utils.sources.hub import get_hub_api 6 | from lorax_server.utils.weights import download_weights 7 | 8 | 9 | def download_adapter_weights( 10 | adapter_id: str, 11 | adapter_source: str, 12 | api_token: Optional[str] = None, 13 | ) -> int: 14 | if adapter_source == PBASE: 15 | api_token = api_token or os.environ.get("PREDIBASE_API_TOKEN") 16 | adapter_id = map_pbase_model_id_to_s3(adapter_id, api_token) 17 | adapter_source = S3 18 | 19 | if adapter_source == HUB: 20 | # Quick auth check on the repo against the token 21 | get_hub_api(token=api_token).model_info(adapter_id, revision=None) 22 | 23 | # fail fast if ID is not an adapter (i.e. 
it is a full model) 24 | source = get_model_source(adapter_source, adapter_id, extension=".safetensors", api_token=api_token) 25 | source.load_config() 26 | 27 | download_weights(adapter_id, source=adapter_source, api_token=api_token) 28 | 29 | # Calculate size of adapter to be loaded 30 | return source.get_weight_bytes() 31 | -------------------------------------------------------------------------------- /server/lorax_server/cache.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, TypeVar 2 | 3 | import torch 4 | 5 | from lorax_server.models.types import Batch 6 | 7 | B = TypeVar("B", bound=Batch) 8 | 9 | 10 | class Cache: 11 | """ 12 | A class representing a cache. 13 | 14 | Attributes: 15 | cache (Dict[int, B]): A dictionary representing the cache, where the keys are batch IDs and the values are entries. 16 | 17 | Methods: 18 | pop(batch_id: int) -> Optional[B]: Removes and returns the entry with the specified batch ID from the cache. 19 | set(entry: B): Adds the specified entry to the cache. 20 | delete(batch_id: int): Deletes the entry with the specified batch ID from the cache. 21 | clear(): Clears the cache. 22 | __len__(): Returns the number of entries in the cache. 23 | """ 24 | 25 | def __init__(self): 26 | self.cache: Dict[int, B] = {} 27 | 28 | def pop(self, batch_id: int) -> Optional[B]: 29 | return self.cache.pop(batch_id, None) 30 | 31 | def set(self, entry: B): 32 | if entry is not None: 33 | self.cache[entry.batch_id] = entry 34 | 35 | def delete(self, batch_id: int): 36 | batch = self.pop(batch_id) 37 | if batch is not None: 38 | del batch 39 | if torch.cuda.is_available(): 40 | torch.cuda.empty_cache() 41 | 42 | def clear(self): 43 | keys = list(self.cache.keys()) 44 | for k in keys: 45 | self.delete(k) 46 | 47 | def __len__(self): 48 | return len(self.cache.keys()) 49 | -------------------------------------------------------------------------------- /server/lorax_server/interceptor.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable 2 | 3 | import grpc 4 | import torch 5 | from google.rpc import code_pb2, status_pb2 6 | from grpc_interceptor.server import AsyncServerInterceptor 7 | from grpc_status import rpc_status 8 | from loguru import logger 9 | 10 | 11 | class ExceptionInterceptor(AsyncServerInterceptor): 12 | """Intercepts and handles exceptions that occur during gRPC method execution.""" 13 | 14 | async def intercept( 15 | self, 16 | method: Callable, 17 | request_or_iterator: Any, 18 | context: grpc.ServicerContext, 19 | method_name: str, 20 | ) -> Any: 21 | """ 22 | Intercepts the gRPC method execution and handles any exceptions that occur. 23 | 24 | Args: 25 | method (Callable): The gRPC method to be executed. 26 | request_or_iterator (Any): The request object or iterator. 27 | context (grpc.ServicerContext): The gRPC servicer context. 28 | method_name (str): The name of the gRPC method. 29 | 30 | Returns: 31 | Any: The response of the gRPC method. 32 | 33 | Raises: 34 | Exception: If an error occurs during the execution of the gRPC method. 
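        Example:
            Sketch of how this interceptor is typically registered with the async
            gRPC server (illustrative only; the actual server wiring lives
            elsewhere in this package):

                server = grpc.aio.server(interceptors=[ExceptionInterceptor()])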
35 | """ 36 | try: 37 | response = method(request_or_iterator, context) 38 | return await response 39 | except Exception as err: 40 | method_name = method_name.split("/")[-1] 41 | logger.exception(f"Method {method_name} encountered an error.") 42 | 43 | if torch.cuda.is_available(): 44 | torch.cuda.empty_cache() 45 | 46 | await context.abort_with_status( 47 | rpc_status.to_status(status_pb2.Status(code=code_pb2.INTERNAL, message=str(err))) 48 | ) 49 | -------------------------------------------------------------------------------- /server/lorax_server/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from lorax_server.layers.conv import load_conv2d # noqa 2 | 3 | # Just to add the `load` methods. 4 | from lorax_server.layers.layernorm import load_layer_norm # noqa 5 | from lorax_server.layers.linear import ( 6 | FastLinear, # noqa 7 | get_linear, # noqa 8 | ) 9 | from lorax_server.layers.tensor_parallel import ( 10 | TensorParallelColumnLinear, # noqa 11 | TensorParallelEmbedding, # noqa 12 | TensorParallelRowLinear, # noqa 13 | ) 14 | -------------------------------------------------------------------------------- /server/lorax_server/layers/awq/quantize/qmodule.py: -------------------------------------------------------------------------------- 1 | # Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py 2 | 3 | import awq_inference_engine # with CUDA kernels 4 | import torch 5 | import torch.nn as nn 6 | 7 | # class ScaledActivation(nn.Module): 8 | # def __init__(self, module, scales): 9 | # super().__init__() 10 | # self.act = module 11 | # self.scales = nn.Parameter(scales.data) 12 | # 13 | # def forward(self, x): 14 | # return self.act(x) / self.scales.view(1, 1, -1).to(x.device) 15 | 16 | 17 | class WQLinear(nn.Module): 18 | def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias): 19 | super().__init__() 20 | 21 | if w_bit not in [4]: 22 | raise NotImplementedError("Only 4-bit are supported for now.") 23 | 24 | self.in_features = qweight.shape[0] 25 | self.out_features = qweight.shape[1] * 32 // w_bit 26 | 27 | self.w_bit = w_bit 28 | self.group_size = group_size if group_size != -1 else self.in_features 29 | # quick sanity check (make sure aligment) 30 | assert self.in_features % self.group_size == 0 31 | assert self.out_features % (32 // self.w_bit) == 0 32 | 33 | self.qweight = qweight 34 | self.qzeros = qzeros 35 | self.scales = scales 36 | if bias: 37 | self.bias = bias 38 | else: 39 | self.bias = None 40 | 41 | @torch.no_grad() 42 | def forward(self, x): 43 | out_shape = x.shape[:-1] + (self.out_features,) 44 | out = awq_inference_engine.gemm_forward_cuda( 45 | x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8 46 | ) 47 | out = out + self.bias if self.bias is not None else out 48 | return out.reshape(out_shape) 49 | 50 | @property 51 | def weight(self) -> torch.Tensor: 52 | return self.qweight 53 | -------------------------------------------------------------------------------- /server/lorax_server/layers/conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from accelerate import init_empty_weights 3 | 4 | 5 | @classmethod 6 | def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride): 7 | weight = weights.get_tensor(f"{prefix}.weight") 8 | bias = weights.get_tensor(f"{prefix}.bias") 9 | with init_empty_weights(): 10 | conv2d = cls( 11 
| in_channels=in_channels, 12 | out_channels=out_channels, 13 | kernel_size=kernel_size, 14 | stride=stride, 15 | ) 16 | 17 | conv2d.weight = torch.nn.Parameter(weight) 18 | conv2d.bias = torch.nn.Parameter(bias) 19 | return conv2d 20 | 21 | 22 | @classmethod 23 | def load_conv2d_no_bias(cls, prefix, weights, in_channels, out_channels, kernel_size, stride): 24 | weight = weights.get_tensor(f"{prefix}.weight") 25 | with init_empty_weights(): 26 | conv2d = cls( 27 | in_channels=in_channels, 28 | out_channels=out_channels, 29 | kernel_size=kernel_size, 30 | stride=stride, 31 | ) 32 | 33 | conv2d.weight = torch.nn.Parameter(weight) 34 | conv2d.bias = None 35 | return conv2d 36 | 37 | 38 | torch.nn.Conv2d.load = load_conv2d 39 | torch.nn.Conv2d.load_no_bias = load_conv2d_no_bias 40 | -------------------------------------------------------------------------------- /server/lorax_server/layers/eetq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from EETQ import quant_weights, w8_a16_gemm 3 | 4 | 5 | class EETQLinear(torch.nn.Module): 6 | def __init__( 7 | self, 8 | weight, 9 | bias, 10 | ) -> None: 11 | super().__init__() 12 | device = weight.device 13 | if weight.dtype != torch.float16: 14 | weight = weight.to(dtype=torch.float16) 15 | weight = torch.t(weight).contiguous().cpu() 16 | weight, scale = quant_weights(weight, torch.int8, False) 17 | 18 | self.weight = weight.cuda(device) 19 | self.scale = scale.cuda(device) 20 | self.bias = bias.cuda(device) if bias is not None else None 21 | 22 | def forward(self, input: torch.Tensor) -> torch.Tensor: 23 | output = w8_a16_gemm(input, self.weight, self.scale) 24 | output = output + self.bias if self.bias is not None else output 25 | return output 26 | -------------------------------------------------------------------------------- /server/lorax_server/layers/fp8.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from vllm import _custom_ops as ops 5 | 6 | ####### from vLLM code ####### 7 | 8 | 9 | def apply_fp8_linear( 10 | input: torch.Tensor, 11 | qweight: torch.Tensor, 12 | weight_scale: torch.Tensor, 13 | input_scale: Optional[torch.Tensor] = None, 14 | input_scale_ub: Optional[torch.Tensor] = None, 15 | qbias: Optional[torch.Tensor] = None, 16 | ) -> torch.Tensor: 17 | qinput, x_scale = ops.scaled_fp8_quant(input, input_scale, scale_ub=input_scale_ub, use_per_token_if_dynamic=True) 18 | 19 | output = ops.cutlass_scaled_mm( 20 | qinput, qweight, out_dtype=input.dtype, scale_a=x_scale, scale_b=weight_scale, bias=qbias 21 | ) 22 | 23 | return output 24 | 25 | 26 | class Fp8Linear(torch.nn.Module): 27 | def __init__( 28 | self, 29 | weight, 30 | bias, 31 | weight_scale, 32 | input_scale, 33 | ) -> None: 34 | super().__init__() 35 | self.dtype = weight.dtype 36 | self.qweight = weight.t() 37 | self.weight_scale = weight_scale.view(1, -1).contiguous().float() 38 | self.qbias = bias if bias is not None else None 39 | self.input_scale = input_scale.float() if input_scale is not None else None 40 | 41 | def forward(self, input: torch.Tensor) -> torch.Tensor: 42 | return apply_fp8_linear( 43 | input=input, 44 | qweight=self.qweight, 45 | weight_scale=self.weight_scale, 46 | input_scale=self.input_scale, 47 | qbias=self.qbias, 48 | ) 49 | 50 | @property 51 | def weight(self) -> torch.Tensor: 52 | return self.qweight 53 | -------------------------------------------------------------------------------- 
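The quantized linear layers in this package (EETQLinear and Fp8Linear above, plus the AWQ/GPTQ/HQQ variants elsewhere) all expose the same forward(input) contract so that get_linear can substitute them for a plain torch.nn.Linear. A minimal usage sketch for EETQLinear, assuming a CUDA device and the optional EETQ extension are installed; the shapes and the tolerance check below are illustrative only:

import torch
from lorax_server.layers.eetq import EETQLinear  # requires the EETQ CUDA extension

# Weight follows the usual nn.Linear (out_features, in_features) layout.
weight = torch.randn(4096, 11008, dtype=torch.float16, device="cuda")
layer = EETQLinear(weight, bias=None)   # quantizes to int8 weights + scales in __init__

x = torch.randn(2, 11008, dtype=torch.float16, device="cuda")
y = layer(x)                            # w8_a16_gemm under the hood, shape (2, 4096)

# Loose sanity check against the fp16 reference (weight-only int8 is lossy).
ref = torch.nn.functional.linear(x, weight)
print((y - ref).abs().max())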
/server/lorax_server/layers/gptq/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | from lorax_server.utils.import_utils import ( 6 | SYSTEM, 7 | ) 8 | 9 | try: 10 | major, _minor = torch.cuda.get_device_capability() 11 | except Exception: 12 | major = 1 13 | 14 | HAS_EXLLAMA = False 15 | CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm" 16 | V2 = os.getenv("EXLLAMA_VERSION", "2") == "2" 17 | if os.getenv("DISABLE_EXLLAMA") == "True": 18 | HAS_EXLLAMA = False 19 | elif CAN_EXLLAMA: 20 | try: 21 | if V2: 22 | from lorax_server.layers.gptq.exllamav2 import ( 23 | QuantLinear as ExllamaQuantLinear, 24 | ) 25 | from lorax_server.layers.gptq.exllamav2 import ( 26 | create_exllama_buffers, 27 | set_device, 28 | ) 29 | 30 | HAS_EXLLAMA = "2" 31 | else: 32 | from lorax_server.layers.gptq.exllama import ( 33 | Ex4bitLinear as ExllamaQuantLinear, # noqa 34 | ) 35 | from lorax_server.layers.gptq.exllama import ( 36 | create_exllama_buffers, # noqa 37 | set_device, # noqa 38 | ) 39 | 40 | HAS_EXLLAMA = "1" 41 | 42 | except ImportError: 43 | pass 44 | 45 | from lorax_server.layers.gptq.quant_linear import QuantLinear # noqa 46 | -------------------------------------------------------------------------------- /server/lorax_server/layers/hqq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | HAS_HQQ = True 5 | try: 6 | from hqq.core.quantize import BaseQuantizeConfig, HQQBackend, HQQLinear 7 | 8 | HQQLinear.set_backend(HQQBackend.ATEN) 9 | 10 | class HQQLinearLayer(HQQLinear): 11 | @property 12 | def weight(self) -> torch.Tensor: 13 | return self.W_q 14 | 15 | except ImportError: 16 | HAS_HQQ = False 17 | 18 | 19 | def get_hqq_linear(quantize, weight, bias=None) -> HQQLinearLayer: 20 | if quantize == "hqq-4bit": 21 | quant_config = BaseQuantizeConfig( 22 | nbits=4, group_size=64, quant_zero=True, quant_scale=True, offload_meta=True, compute_dtype=torch.bfloat16 23 | ) 24 | elif quantize == "hqq-3bit": 25 | quant_config = BaseQuantizeConfig( 26 | nbits=3, group_size=64, quant_zero=True, quant_scale=True, offload_meta=True, compute_dtype=torch.bfloat16 27 | ) 28 | elif quantize == "hqq-2bit": 29 | quant_config = BaseQuantizeConfig( 30 | nbits=2, group_size=16, quant_zero=True, quant_scale=True, offload_meta=True, compute_dtype=torch.bfloat16 31 | ) 32 | 33 | # init nn.linear from weight and bias 34 | layer = nn.Linear(weight.shape[1], weight.shape[0], bias=bias is not None) 35 | with torch.no_grad(): 36 | layer.weight.data = weight 37 | if bias is not None: 38 | layer.bias.data = bias 39 | 40 | linear = HQQLinearLayer(layer, quant_config, del_orig=True) 41 | 42 | return linear 43 | -------------------------------------------------------------------------------- /server/lorax_server/models/custom_modeling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/predibase/lorax/b5a9e38dc9479ca664bbaac9ae949e27e3c30832/server/lorax_server/models/custom_modeling/__init__.py -------------------------------------------------------------------------------- /server/lorax_server/models/custom_modeling/utils.py: -------------------------------------------------------------------------------- 1 | def prepend(prefix: str, path: str) -> str: 2 | return f"{prefix}.{path}" if prefix else path 3 | -------------------------------------------------------------------------------- 
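The gptq package init above (lorax_server/layers/gptq/__init__.py) decides at import time which exllama kernel to use: DISABLE_EXLLAMA=True always wins, otherwise EXLLAMA_VERSION (default "2") selects between the exllama and exllamav2 QuantLinear implementations, and HAS_EXLLAMA ends up as False, "1", or "2". A small sketch of inspecting that decision; the environment variables must be set before the module is first imported, and importing it assumes the package's CUDA dependencies are available:

import os

os.environ["EXLLAMA_VERSION"] = "1"      # opt out of the default v2 kernels
# os.environ["DISABLE_EXLLAMA"] = "True" # would force HAS_EXLLAMA = False entirely

from lorax_server.layers.gptq import CAN_EXLLAMA, HAS_EXLLAMA

print(CAN_EXLLAMA)   # True on compute capability >= 8.0 or ROCm
print(HAS_EXLLAMA)   # False, "1", or "2" depending on the flags and available kernels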
/server/lorax_server/models/custom_modeling/vlm.py: -------------------------------------------------------------------------------- 1 | def load_text_model(prefix, config, weights, name=None): 2 | if config.model_type == "llama": 3 | from lorax_server.models.custom_modeling.flash_llama_modeling import ( 4 | FlashLlamaForCausalLM, 5 | ) 6 | 7 | return FlashLlamaForCausalLM(prefix, config, weights) 8 | elif config.model_type == "mistral": 9 | from lorax_server.models.custom_modeling.flash_mistral_modeling import ( 10 | FlashMistralForCausalLM, 11 | ) 12 | 13 | return FlashMistralForCausalLM(prefix, config, weights, name=name) 14 | elif config.model_type == "gemma": 15 | from lorax_server.models.custom_modeling.flash_gemma_modeling import ( 16 | FlashGemmaForCausalLM, 17 | ) 18 | 19 | return FlashGemmaForCausalLM(prefix, config, weights, causal=False) 20 | elif config.model_type == "paligemma": 21 | from lorax_server.models.custom_modeling.flash_gemma_modeling import ( 22 | FlashGemmaForCausalLM, 23 | ) 24 | 25 | return FlashGemmaForCausalLM(prefix, config, weights) 26 | else: 27 | raise RuntimeError(f"Unsupported model type {config.model_type}") 28 | 29 | 30 | def load_vision_model(prefix, config, weights): 31 | if config.model_type == "clip_vision_model": 32 | from lorax_server.models.custom_modeling.clip import ( 33 | CLIPVisionTransformer, 34 | ) 35 | 36 | return CLIPVisionTransformer(prefix=f"{prefix}.vision_model", config=config, weights=weights) 37 | if config.model_type == "siglip_vision_model": 38 | from lorax_server.models.custom_modeling.siglip import ( 39 | SiglipVisionTransformer, 40 | ) 41 | 42 | return SiglipVisionTransformer(prefix="vision_tower.vision_model", config=config, weights=weights) 43 | else: 44 | raise RuntimeError(f"Unsupported model type {config.model_type}") 45 | -------------------------------------------------------------------------------- /server/lorax_server/models/flash_dbrx.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Tuple 2 | 3 | import torch 4 | import torch.distributed 5 | from opentelemetry import trace 6 | from transformers import AutoTokenizer 7 | 8 | from lorax_server.models import FlashCausalLM 9 | from lorax_server.models.custom_modeling.flash_dbrx_modeling import ( 10 | ATTN_O_PROJ, 11 | ATTN_WQKV, 12 | DbrxConfig, 13 | FlashDbrxForCausalLM, 14 | ) 15 | from lorax_server.utils.lora import LM_HEAD 16 | 17 | tracer = trace.get_tracer(__name__) 18 | 19 | ADAPTER_LAYERS = [ATTN_WQKV, ATTN_O_PROJ, LM_HEAD] 20 | ROW_PARALLEL = {ATTN_O_PROJ, LM_HEAD} 21 | 22 | 23 | class FlashDbrx(FlashCausalLM): 24 | def __init__( 25 | self, 26 | model_id: str, 27 | adapter_id: str, 28 | adapter_source: str, 29 | revision: Optional[str] = None, 30 | dtype: Optional[torch.dtype] = None, 31 | **kwargs, 32 | ): 33 | super().__init__( 34 | model_id=model_id, 35 | model_cls=FlashDbrxForCausalLM, 36 | dtype=dtype, 37 | revision=revision, 38 | adapter_id=adapter_id, 39 | adapter_source=adapter_source, 40 | tokenizer_cls=AutoTokenizer, 41 | config_cls=DbrxConfig, 42 | **kwargs, 43 | ) 44 | 45 | @property 46 | def supports_adapter_loading(self) -> bool: 47 | return True 48 | 49 | def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: 50 | layer_weights = {} 51 | 52 | prefix = "transformer.blocks" 53 | for i, layer in enumerate(self.model.model.layers): 54 | layer_weights[(i, ATTN_WQKV)] = ( 55 | f"{prefix}.{i}.norm_attn_norm.attn.q_proj", 56 | 
layer.attn.self_attn.query_key_value, 57 | ) 58 | layer_weights[(i, ATTN_O_PROJ)] = ( 59 | f"{prefix}.{i}.norm_attn_norm.attn.out_proj", 60 | layer.attn.self_attn.o_proj, 61 | ) 62 | 63 | layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head) 64 | return layer_weights 65 | 66 | @property 67 | def adapter_layers(self) -> List[str]: 68 | return ADAPTER_LAYERS 69 | 70 | @property 71 | def default_traced_adapter_layers(self) -> List[str]: 72 | return [ATTN_WQKV] 73 | 74 | def get_num_layers_for_type(self, layer_type: str) -> int: 75 | return len(self.model.model.layers) 76 | 77 | def is_row_parallel(self, layer_type: str) -> bool: 78 | return layer_type in ROW_PARALLEL 79 | -------------------------------------------------------------------------------- /server/lorax_server/models/flash_gemma2.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Tuple 2 | 3 | import torch 4 | import torch.distributed 5 | from opentelemetry import trace 6 | 7 | from lorax_server.models import FlashCausalLM 8 | from lorax_server.models.custom_modeling.flash_gemma2_modeling import FlashGemma2ForCausalLM 9 | from lorax_server.utils.lora import DOWN_PROJ, GATE_PROJ, K_PROJ, O_PROJ, Q_PROJ, UP_PROJ, V_PROJ 10 | 11 | tracer = trace.get_tracer(__name__) 12 | 13 | # TODO(tim): re-enable LM_HEAD after resolving issues with outputs 14 | ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, GATE_PROJ, UP_PROJ, DOWN_PROJ] 15 | ROW_PARALLEL = {O_PROJ, DOWN_PROJ} 16 | 17 | 18 | class FlashGemma2(FlashCausalLM): 19 | def __init__( 20 | self, 21 | model_id: str, 22 | adapter_id: str, 23 | adapter_source: str, 24 | revision: Optional[str] = None, 25 | dtype: Optional[torch.dtype] = None, 26 | **kwargs, 27 | ): 28 | super().__init__( 29 | model_id=model_id, 30 | model_cls=FlashGemma2ForCausalLM, 31 | dtype=dtype, 32 | revision=revision, 33 | adapter_id=adapter_id, 34 | adapter_source=adapter_source, 35 | **kwargs, 36 | ) 37 | 38 | @property 39 | def supports_adapter_loading(self) -> bool: 40 | return True 41 | 42 | def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: 43 | layer_weights = {} 44 | 45 | prefix = "model.layers" 46 | for i, layer in enumerate(self.model.model.layers): 47 | layer_weights[(i, Q_PROJ)] = ( 48 | f"{prefix}.{i}.self_attn.q_proj", 49 | layer.self_attn.query_key_value, 50 | ) 51 | layer_weights[(i, K_PROJ)] = ( 52 | f"{prefix}.{i}.self_attn.k_proj", 53 | layer.self_attn.query_key_value, 54 | ) 55 | layer_weights[(i, V_PROJ)] = ( 56 | f"{prefix}.{i}.self_attn.v_proj", 57 | layer.self_attn.query_key_value, 58 | ) 59 | layer_weights[(i, O_PROJ)] = (f"{prefix}.{i}.self_attn.o_proj", layer.self_attn.o_proj) 60 | 61 | layer_weights[(i, GATE_PROJ)] = (f"{prefix}.{i}.mlp.gate_proj", layer.mlp.gate_up_proj) 62 | layer_weights[(i, UP_PROJ)] = (f"{prefix}.{i}.mlp.up_proj", layer.mlp.gate_up_proj) 63 | layer_weights[(i, DOWN_PROJ)] = (f"{prefix}.{i}.mlp.down_proj", layer.mlp.down_proj) 64 | 65 | return layer_weights 66 | 67 | @property 68 | def adapter_layers(self) -> List[str]: 69 | return ADAPTER_LAYERS 70 | 71 | @property 72 | def default_traced_adapter_layers(self) -> List[str]: 73 | return [Q_PROJ, V_PROJ] 74 | 75 | def get_num_layers_for_type(self, layer_type: str) -> int: 76 | return len(self.model.model.layers) 77 | 78 | def is_row_parallel(self, layer_type: str) -> bool: 79 | return layer_type in ROW_PARALLEL 80 | -------------------------------------------------------------------------------- 
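Both FlashDbrx and FlashGemma2 above follow the adapter-wiring pattern shared by the Flash* models: ADAPTER_LAYERS lists the LoRA layer types that can be targeted, ROW_PARALLEL marks the types backed by row-parallel (sharded) linears, and adapter_target_to_layer maps (layer_index, layer_type) to the checkpoint weight name plus the module that holds the fused weight. A sketch of walking that mapping for an already-loaded model; the model argument is hypothetical and any Flash* subclass works the same way:

def describe_adapter_targets(model) -> None:
    # `model` is any loaded Flash* instance, e.g. FlashGemma2 (hypothetical here).
    for (layer_idx, layer_type), (weight_name, module) in model.adapter_target_to_layer().items():
        row_parallel = model.is_row_parallel(layer_type)
        print(f"layer {layer_idx:3d} {layer_type:>10s} -> {weight_name} (row_parallel={row_parallel})")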
/server/lorax_server/models/flash_gpt2.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Tuple 2 | 3 | import torch 4 | import torch.distributed 5 | from opentelemetry import trace 6 | 7 | from lorax_server.models import FlashCausalLM 8 | from lorax_server.models.custom_modeling.flash_gpt2_modeling import ( 9 | ATTN_C_ATTN, 10 | ATTN_C_PROJ, 11 | LM_HEAD, 12 | MLP_C_FC, 13 | MLP_C_PROJ, 14 | FlashGPT2ForCausalLM, 15 | GPT2Config, 16 | ) 17 | 18 | tracer = trace.get_tracer(__name__) 19 | 20 | ADAPTER_LAYERS = [ATTN_C_ATTN, ATTN_C_PROJ, MLP_C_FC, MLP_C_PROJ] 21 | ROW_PARALLEL = {ATTN_C_PROJ, MLP_C_PROJ} 22 | 23 | 24 | class FlashGPT2(FlashCausalLM): 25 | def __init__( 26 | self, 27 | model_id: str, 28 | adapter_id: str, 29 | adapter_source: str, 30 | revision: Optional[str] = None, 31 | dtype: Optional[torch.dtype] = None, 32 | **kwargs, 33 | ): 34 | super().__init__( 35 | model_id=model_id, 36 | model_cls=FlashGPT2ForCausalLM, 37 | dtype=dtype, 38 | revision=revision, 39 | adapter_id=adapter_id, 40 | adapter_source=adapter_source, 41 | config_cls=GPT2Config, 42 | **kwargs, 43 | ) 44 | 45 | @property 46 | def supports_adapter_loading(self) -> bool: 47 | return True 48 | 49 | def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: 50 | layer_weights = {} 51 | 52 | prefix = "transformer.h" 53 | for i, layer in enumerate(self.model.transformer.h): 54 | layer_weights[(i, ATTN_C_ATTN)] = (f"{prefix}.{i}.{ATTN_C_ATTN}", layer.attn.c_attn) 55 | layer_weights[(i, ATTN_C_PROJ)] = (f"{prefix}.{i}.{ATTN_C_PROJ}", layer.attn.c_proj) 56 | 57 | layer_weights[(i, MLP_C_FC)] = (f"{prefix}.{i}.{MLP_C_FC}", layer.mlp.c_fc) 58 | layer_weights[(i, MLP_C_PROJ)] = (f"{prefix}.{i}.{MLP_C_PROJ}", layer.mlp.c_proj) 59 | 60 | # TODO: make Embedding layers adapter-compatible 61 | # layer_weights[(0, LM_HEAD)] = ("transformer.wte", self.model.transformer.wte) 62 | return layer_weights 63 | 64 | @property 65 | def adapter_layers(self) -> List[str]: 66 | return ADAPTER_LAYERS 67 | 68 | @property 69 | def default_traced_adapter_layers(self) -> List[str]: 70 | return [ATTN_C_ATTN] 71 | 72 | def get_num_layers_for_type(self, layer_type: str) -> int: 73 | return 1 if layer_type == LM_HEAD else len(self.model.transformer.h) 74 | 75 | def is_row_parallel(self, layer_type: str) -> bool: 76 | return layer_type in ROW_PARALLEL 77 | -------------------------------------------------------------------------------- /server/lorax_server/models/flash_neox.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.distributed 5 | from opentelemetry import trace 6 | 7 | from lorax_server.models import FlashCausalLM 8 | from lorax_server.models.custom_modeling.flash_neox_modeling import ( 9 | FlashGPTNeoXForCausalLM, 10 | ) 11 | 12 | tracer = trace.get_tracer(__name__) 13 | 14 | 15 | class FlashNeoXSharded(FlashCausalLM): 16 | def __init__( 17 | self, 18 | model_id: str, 19 | revision: Optional[str] = None, 20 | dtype: Optional[torch.dtype] = None, 21 | **kwargs, 22 | ): 23 | super().__init__( 24 | model_id=model_id, 25 | model_cls=FlashGPTNeoXForCausalLM, 26 | dtype=dtype, 27 | revision=revision, 28 | **kwargs, 29 | ) 30 | -------------------------------------------------------------------------------- /server/lorax_server/models/flash_phi3.py: -------------------------------------------------------------------------------- 1 | from typing import 
Dict, List, Optional, Tuple 2 | 3 | import torch 4 | import torch.distributed 5 | from opentelemetry import trace 6 | 7 | from lorax_server.models import FlashCausalLM 8 | from lorax_server.models.custom_modeling.flash_phi3_modeling import ( 9 | GATE_UP_PROJ, 10 | QKV_PROJ, 11 | FlashPhi3ForCausalLM, 12 | Phi3Config, 13 | ) 14 | from lorax_server.utils.lora import ( 15 | DOWN_PROJ, 16 | LM_HEAD, 17 | O_PROJ, 18 | ) 19 | 20 | tracer = trace.get_tracer(__name__) 21 | 22 | 23 | # TODO(travis): re-enable LM_HEAD after resolving issues with outputs 24 | ADAPTER_LAYERS = [QKV_PROJ, O_PROJ, GATE_UP_PROJ, DOWN_PROJ] # LM_HEAD 25 | ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD} 26 | 27 | 28 | class FlashPhi3(FlashCausalLM): 29 | def __init__( 30 | self, 31 | model_id: str, 32 | adapter_id: str, 33 | adapter_source: str, 34 | revision: Optional[str] = None, 35 | dtype: Optional[torch.dtype] = None, 36 | **kwargs, 37 | ): 38 | super().__init__( 39 | model_id=model_id, 40 | model_cls=FlashPhi3ForCausalLM, 41 | dtype=dtype, 42 | revision=revision, 43 | adapter_id=adapter_id, 44 | adapter_source=adapter_source, 45 | config_cls=Phi3Config, 46 | **kwargs, 47 | ) 48 | 49 | @property 50 | def supports_adapter_loading(self) -> bool: 51 | return True 52 | 53 | def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: 54 | layer_weights = {} 55 | 56 | prefix = "model.layers" 57 | for i, layer in enumerate(self.model.model.layers): 58 | layer_weights[(i, QKV_PROJ)] = ( 59 | f"{prefix}.{i}.self_attn.qkv_proj", 60 | layer.self_attn.query_key_value, 61 | ) 62 | layer_weights[(i, O_PROJ)] = (f"{prefix}.{i}.self_attn.o_proj", layer.self_attn.o_proj) 63 | 64 | layer_weights[(i, GATE_UP_PROJ)] = (f"{prefix}.{i}.mlp.gate_up_proj", layer.mlp.gate_up_proj) 65 | layer_weights[(i, DOWN_PROJ)] = (f"{prefix}.{i}.mlp.down_proj", layer.mlp.down_proj) 66 | 67 | layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head) 68 | return layer_weights 69 | 70 | @property 71 | def adapter_layers(self) -> List[str]: 72 | return ADAPTER_LAYERS 73 | 74 | @property 75 | def default_traced_adapter_layers(self) -> List[str]: 76 | return [QKV_PROJ] 77 | 78 | def get_num_layers_for_type(self, layer_type: str) -> int: 79 | return 1 if layer_type == LM_HEAD else len(self.model.model.layers) 80 | 81 | def is_row_parallel(self, layer_type: str) -> bool: 82 | return layer_type in ROW_PARALLEL 83 | -------------------------------------------------------------------------------- /server/lorax_server/models/flash_rw.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.distributed 5 | from opentelemetry import trace 6 | 7 | from lorax_server.models import FlashCausalLM 8 | from lorax_server.models.custom_modeling.flash_rw_modeling import ( 9 | FlashRWForCausalLM, 10 | RWConfig, 11 | ) 12 | 13 | tracer = trace.get_tracer(__name__) 14 | 15 | 16 | class FlashRWSharded(FlashCausalLM): 17 | def __init__( 18 | self, 19 | model_id: str, 20 | revision: Optional[str] = None, 21 | dtype: Optional[torch.dtype] = None, 22 | **kwargs, 23 | ): 24 | super().__init__( 25 | model_id=model_id, 26 | model_cls=FlashRWForCausalLM, 27 | dtype=dtype, 28 | revision=revision, 29 | config_cls=RWConfig, 30 | **kwargs, 31 | ) 32 | -------------------------------------------------------------------------------- /server/lorax_server/models/flash_santacoder.py: -------------------------------------------------------------------------------- 1 | from typing import 
List, Optional 2 | 3 | import torch 4 | import torch.distributed 5 | from opentelemetry import trace 6 | from transformers import AutoConfig 7 | 8 | from lorax_server.models import FlashCausalLM 9 | from lorax_server.models.custom_modeling.flash_santacoder_modeling import ( 10 | FlashSantacoderForCausalLM, 11 | ) 12 | 13 | tracer = trace.get_tracer(__name__) 14 | 15 | 16 | class FlashSantacoderSharded(FlashCausalLM): 17 | def __init__( 18 | self, 19 | model_id: str, 20 | revision: Optional[str] = None, 21 | dtype: Optional[torch.dtype] = None, 22 | **kwargs, 23 | ): 24 | super().__init__( 25 | model_id=model_id, 26 | model_cls=FlashSantacoderForCausalLM, 27 | dtype=dtype, 28 | revision=revision, 29 | config_cls=AutoConfig, 30 | **kwargs, 31 | ) 32 | 33 | def decode(self, generated_ids: List[int]) -> str: 34 | # Do not skip special tokens as they are used for custom parsing rules of the generated text 35 | return self.tokenizer.decode(generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False) 36 | -------------------------------------------------------------------------------- /server/lorax_server/models/gpt_neox.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.distributed 5 | from loguru import logger 6 | from transformers import ( 7 | AutoConfig, 8 | AutoTokenizer, 9 | ) 10 | 11 | from lorax_server.models.causal_lm import CausalLM 12 | from lorax_server.models.custom_modeling.neox_modeling import ( 13 | GPTNeoxForCausalLM, 14 | ) 15 | from lorax_server.utils import ( 16 | Weights, 17 | initialize_torch_distributed, 18 | weight_files, 19 | ) 20 | 21 | 22 | class GPTNeoxSharded(CausalLM): 23 | def __init__( 24 | self, 25 | model_id: str, 26 | revision: Optional[str] = None, 27 | quantize: Optional[str] = None, 28 | compile: bool = False, 29 | dtype: Optional[torch.dtype] = None, 30 | trust_remote_code: bool = False, 31 | ): 32 | if compile: 33 | logger.info(f"Model {model_id} does not support CUDA graph compilation. 
Skipping compilation.") 34 | 35 | self.process_group, rank, world_size = initialize_torch_distributed() 36 | if torch.cuda.is_available(): 37 | device = torch.device(f"cuda:{rank}") 38 | dtype = torch.float16 if dtype is None else dtype 39 | else: 40 | device = torch.device("cpu") 41 | dtype = torch.float32 42 | 43 | tokenizer = AutoTokenizer.from_pretrained( 44 | model_id, 45 | revision=revision, 46 | padding_side="left", 47 | truncation_side="left", 48 | trust_remote_code=trust_remote_code, 49 | ) 50 | tokenizer.pad_token = tokenizer.eos_token 51 | 52 | config = AutoConfig.from_pretrained( 53 | model_id, 54 | revision=revision, 55 | trust_remote_code=trust_remote_code, 56 | ) 57 | config.quantize = quantize 58 | 59 | torch.distributed.barrier(group=self.process_group) 60 | filenames = weight_files(model_id, revision=revision, extension=".safetensors") 61 | weights = Weights(filenames, device=device, dtype=dtype, process_group=self.process_group) 62 | weights._set_config(model_id, config) 63 | 64 | model = GPTNeoxForCausalLM(config, weights) 65 | 66 | torch.distributed.barrier(group=self.process_group) 67 | super(CausalLM, self).__init__( 68 | model_id=model_id, 69 | model=model, 70 | tokenizer=tokenizer, 71 | requires_padding=True, 72 | dtype=dtype, 73 | device=device, 74 | rank=rank, 75 | world_size=world_size, 76 | trust_remote_code=trust_remote_code, 77 | ) 78 | 79 | def forward(self, input_ids, attention_mask, position_ids, past_key_values: Optional = None): 80 | outputs = self.model.forward( 81 | input_ids=input_ids, 82 | attention_mask=attention_mask, 83 | position_ids=position_ids, 84 | past_key_values=past_key_values, 85 | use_cache=True, 86 | ) 87 | 88 | logits = outputs.logits 89 | return logits, outputs.past_key_values 90 | -------------------------------------------------------------------------------- /server/lorax_server/models/opt.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.distributed 5 | from loguru import logger 6 | from transformers import ( 7 | AutoConfig, 8 | AutoTokenizer, 9 | ) 10 | 11 | from lorax_server.models.causal_lm import CausalLM 12 | from lorax_server.models.custom_modeling.opt_modeling import OPTForCausalLM 13 | from lorax_server.utils import ( 14 | Weights, 15 | initialize_torch_distributed, 16 | weight_files, 17 | ) 18 | 19 | 20 | class OPTSharded(CausalLM): 21 | def __init__( 22 | self, 23 | model_id: str, 24 | revision: Optional[str] = None, 25 | quantize: Optional[str] = None, 26 | compile: bool = False, 27 | dtype: Optional[torch.dtype] = None, 28 | trust_remote_code: bool = False, 29 | ): 30 | if compile: 31 | logger.info(f"Model {model_id} does not support CUDA graph compilation. 
Skipping compilation.") 32 | 33 | self.process_group, rank, world_size = initialize_torch_distributed() 34 | if torch.cuda.is_available(): 35 | device = torch.device(f"cuda:{rank}") 36 | dtype = torch.float16 if dtype is None else dtype 37 | else: 38 | device = torch.device("cpu") 39 | dtype = torch.float32 40 | 41 | tokenizer = AutoTokenizer.from_pretrained( 42 | model_id, 43 | revision=revision, 44 | padding_side="left", 45 | truncation_side="left", 46 | trust_remote_code=trust_remote_code, 47 | ) 48 | 49 | config = AutoConfig.from_pretrained( 50 | model_id, 51 | revision=revision, 52 | trust_remote_code=trust_remote_code, 53 | ) 54 | config.quantize = quantize 55 | tokenizer.pad_token_id = config.pad_token_id 56 | 57 | torch.distributed.barrier(group=self.process_group) 58 | filenames = weight_files(model_id, revision=revision, extension=".safetensors") 59 | weights = Weights(filenames, device=device, dtype=dtype, process_group=self.process_group) 60 | weights._set_config(model_id, config) 61 | 62 | model = OPTForCausalLM(config, weights) 63 | 64 | torch.distributed.barrier(group=self.process_group) 65 | super(CausalLM, self).__init__( 66 | model_id=model_id, 67 | model=model, 68 | tokenizer=tokenizer, 69 | requires_padding=True, 70 | dtype=dtype, 71 | device=device, 72 | rank=rank, 73 | world_size=world_size, 74 | trust_remote_code=trust_remote_code, 75 | ) 76 | 77 | def forward(self, input_ids, attention_mask, position_ids, past_key_values: Optional = None): 78 | outputs = self.model.forward( 79 | input_ids=input_ids, 80 | attention_mask=attention_mask, 81 | past_key_values=past_key_values, 82 | use_cache=True, 83 | ) 84 | 85 | return outputs.logits, outputs.past_key_values 86 | -------------------------------------------------------------------------------- /server/lorax_server/models/santacoder.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.distributed 5 | from loguru import logger 6 | from transformers import AutoModelForCausalLM, AutoTokenizer 7 | 8 | from lorax_server.models.causal_lm import CausalLM 9 | 10 | FIM_PREFIX = "" 11 | FIM_MIDDLE = "" 12 | FIM_SUFFIX = "" 13 | FIM_PAD = "" 14 | EOD = "<|endoftext|>" 15 | 16 | 17 | class SantaCoder(CausalLM): 18 | def __init__( 19 | self, 20 | model_id: str, 21 | revision: Optional[str] = None, 22 | quantize: Optional[str] = None, 23 | compile: bool = False, 24 | dtype: Optional[torch.dtype] = None, 25 | trust_remote_code: bool = False, 26 | ): 27 | if compile: 28 | logger.info(f"Model {model_id} does not support CUDA graph compilation. 
Skipping compilation.") 29 | 30 | if torch.cuda.is_available(): 31 | device = torch.device("cuda") 32 | dtype = torch.float16 if dtype is None else dtype 33 | else: 34 | if quantize: 35 | raise ValueError("quantization is not available on CPU") 36 | 37 | device = torch.device("cpu") 38 | dtype = torch.float32 39 | 40 | tokenizer = AutoTokenizer.from_pretrained( 41 | model_id, 42 | revision=revision, 43 | padding_side="left", 44 | truncation_side="left", 45 | trust_remote_code=trust_remote_code, 46 | ) 47 | tokenizer.add_special_tokens( 48 | { 49 | "additional_special_tokens": [ 50 | EOD, 51 | FIM_PREFIX, 52 | FIM_MIDDLE, 53 | FIM_SUFFIX, 54 | FIM_PAD, 55 | ], 56 | "pad_token": EOD, 57 | } 58 | ) 59 | with device: 60 | model = AutoModelForCausalLM.from_pretrained( 61 | model_id, 62 | revision=revision, 63 | torch_dtype=dtype, 64 | load_in_8bit=quantize == "bitsandbytes", 65 | trust_remote_code=trust_remote_code, 66 | ) 67 | 68 | super(CausalLM, self).__init__( 69 | model_id=model_id, 70 | model=model, 71 | tokenizer=tokenizer, 72 | requires_padding=True, 73 | dtype=dtype, 74 | device=device, 75 | trust_remote_code=trust_remote_code, 76 | ) 77 | 78 | def decode(self, generated_ids: List[int]) -> str: 79 | # Do not skip special tokens as they are used for custom parsing rules of the generated text 80 | return self.tokenizer.decode(generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False) 81 | -------------------------------------------------------------------------------- /server/lorax_server/pb/.gitignore: -------------------------------------------------------------------------------- 1 | *.py 2 | *.pyi 3 | *.py-e -------------------------------------------------------------------------------- /server/lorax_server/tracing.py: -------------------------------------------------------------------------------- 1 | import grpc 2 | from opentelemetry import trace 3 | from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter 4 | from opentelemetry.instrumentation.grpc._aio_server import ( 5 | OpenTelemetryAioServerInterceptor, 6 | ) 7 | from opentelemetry.sdk.resources import Resource 8 | from opentelemetry.sdk.trace import TracerProvider 9 | from opentelemetry.sdk.trace.export import ( 10 | BatchSpanProcessor, 11 | ) 12 | from opentelemetry.semconv.trace import SpanAttributes 13 | 14 | 15 | class UDSOpenTelemetryAioServerInterceptor(OpenTelemetryAioServerInterceptor): 16 | def __init__(self): 17 | super().__init__(trace.get_tracer(__name__)) 18 | 19 | def _start_span(self, handler_call_details, context, set_status_on_exception=False): 20 | """ 21 | Rewrite _start_span method to support Unix Domain Socket gRPC contexts 22 | """ 23 | 24 | # standard attributes 25 | attributes = { 26 | SpanAttributes.RPC_SYSTEM: "grpc", 27 | SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[0], 28 | } 29 | 30 | # if we have details about the call, split into service and method 31 | if handler_call_details.method: 32 | service, method = handler_call_details.method.lstrip("/").split("/", 1) 33 | attributes.update( 34 | { 35 | SpanAttributes.RPC_METHOD: method, 36 | SpanAttributes.RPC_SERVICE: service, 37 | } 38 | ) 39 | 40 | # add some attributes from the metadata 41 | metadata = dict(context.invocation_metadata()) 42 | if "user-agent" in metadata: 43 | attributes["rpc.user_agent"] = metadata["user-agent"] 44 | 45 | # We use gRPC over a UNIX socket 46 | attributes.update({SpanAttributes.NET_TRANSPORT: "unix"}) 47 | 48 | return self._tracer.start_as_current_span( 
49 | name=handler_call_details.method, 50 | kind=trace.SpanKind.SERVER, 51 | attributes=attributes, 52 | set_status_on_exception=set_status_on_exception, 53 | ) 54 | 55 | 56 | def setup_tracing(shard: int, otlp_endpoint: str): 57 | resource = Resource.create(attributes={"service.name": f"lorax.server-{shard}"}) 58 | span_exporter = OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True) 59 | span_processor = BatchSpanProcessor(span_exporter) 60 | 61 | trace.set_tracer_provider(TracerProvider(resource=resource)) 62 | trace.get_tracer_provider().add_span_processor(span_processor) 63 | -------------------------------------------------------------------------------- /server/lorax_server/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from lorax_server.utils.adapter import ( 2 | load_module_map, 3 | ) 4 | from lorax_server.utils.convert import convert_file, convert_files 5 | from lorax_server.utils.dist import initialize_torch_distributed 6 | from lorax_server.utils.sources import ( 7 | HUB, 8 | LOCAL, 9 | PBASE, 10 | S3, 11 | download_weights, 12 | get_config_path, 13 | get_local_dir, 14 | get_model_source, 15 | map_pbase_model_id_to_s3, 16 | weight_files, 17 | weight_hub_files, 18 | ) 19 | from lorax_server.utils.tokens import ( 20 | FinishReason, 21 | Greedy, 22 | HeterogeneousNextTokenChooser, 23 | NextTokenChooser, 24 | Sampling, 25 | StoppingCriteria, 26 | StopSequenceCriteria, 27 | ) 28 | from lorax_server.utils.weights import Weights, get_start_stop_idxs_for_rank 29 | 30 | __all__ = [ 31 | "load_module_map", 32 | "convert_file", 33 | "convert_files", 34 | "get_model_source", 35 | "get_config_path", 36 | "get_local_dir", 37 | "get_start_stop_idxs_for_rank", 38 | "initialize_torch_distributed", 39 | "map_pbase_model_id_to_s3", 40 | "download_weights", 41 | "weight_files", 42 | "weight_hub_files", 43 | "HeterogeneousNextTokenChooser", 44 | "HUB", 45 | "LOCAL", 46 | "PBASE", 47 | "S3", 48 | "Greedy", 49 | "NextTokenChooser", 50 | "Sampling", 51 | "StoppingCriteria", 52 | "StopSequenceCriteria", 53 | "FinishReason", 54 | "Weights", 55 | ] 56 | -------------------------------------------------------------------------------- /server/lorax_server/utils/attention/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/predibase/lorax/b5a9e38dc9479ca664bbaac9ae949e27e3c30832/server/lorax_server/utils/attention/__init__.py -------------------------------------------------------------------------------- /server/lorax_server/utils/attention/common.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | import torch 5 | 6 | 7 | @dataclass 8 | class Seqlen: 9 | input_lengths: torch.Tensor 10 | cache_lengths: torch.Tensor 11 | cu_seqlen_q: Optional[torch.Tensor] 12 | cu_seqlen_k: Optional[torch.Tensor] 13 | max_q: int 14 | max_k: int 15 | 16 | def __init__( 17 | self, 18 | input_lengths, 19 | cache_lengths, 20 | cu_seqlen_q=None, 21 | max_q=None, 22 | max_k=None, 23 | ): 24 | self.input_lengths = input_lengths 25 | self.cache_lengths = cache_lengths 26 | device = self.input_lengths.device 27 | shape = self.input_lengths.shape 28 | if cu_seqlen_q is None: 29 | cu_seqlen_q = torch.arange( 30 | shape[0] + 1, 31 | device=device, 32 | dtype=torch.int32, 33 | ) 34 | max_q = 1 35 | else: 36 | assert max_q is not None 37 | assert max_k is not None 38 | cu_seqlen_k = 
torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32) 39 | 40 | # cuda graphs don't like this and this is necessary to clamp within mistral 41 | # Although FA2 might not want the clamping 42 | # cu_seqlen_k[0] = 0 43 | total = self.input_lengths + self.cache_lengths 44 | torch.cumsum(total, -1, out=cu_seqlen_k[1:]) 45 | 46 | self.cu_seqlen_q = cu_seqlen_q 47 | self.cu_seqlen_k = cu_seqlen_k 48 | self.max_q = max_q 49 | self.max_k = max_k 50 | 51 | def clamp(self, max): 52 | self.input_lengths.data.clamp_(max=max) 53 | return self 54 | -------------------------------------------------------------------------------- /server/lorax_server/utils/awq/awq.py: -------------------------------------------------------------------------------- 1 | # Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py 2 | 3 | import awq_inference_engine # with CUDA kernels 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class AWQLinear(nn.Module): 9 | def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias): 10 | super().__init__() 11 | 12 | if w_bit != 4: 13 | raise NotImplementedError("Only 4-bit are supported for now.") 14 | 15 | self.in_features = qweight.shape[0] 16 | self.out_features = qweight.shape[1] * 32 // w_bit 17 | 18 | self.split_k_iters = 8 19 | 20 | self.w_bit = w_bit 21 | self.group_size = group_size if group_size != -1 else self.in_features 22 | 23 | assert self.in_features % self.group_size == 0, "in_features must be divisible by group_size" 24 | assert self.out_features % (32 // self.w_bit) == 0, "out_features must be divisible by 32 // w_bit" 25 | 26 | self.qweight = qweight 27 | self.qzeros = qzeros 28 | self.scales = scales 29 | self.bias = bias 30 | 31 | @torch.no_grad() 32 | def forward(self, x): 33 | out_shape = x.shape[:-1] + (self.out_features,) 34 | 35 | input_dtype = x.dtype 36 | if input_dtype != torch.float16: 37 | x = x.half() 38 | 39 | out = awq_inference_engine.gemm_forward_cuda( 40 | x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8 41 | ) 42 | 43 | if input_dtype != torch.float16: 44 | out = out.to(dtype=input_dtype) 45 | 46 | out = out + self.bias if self.bias is not None else out 47 | return out.reshape(out_shape) 48 | 49 | @property 50 | def weight(self) -> torch.Tensor: 51 | return self.qweight 52 | -------------------------------------------------------------------------------- /server/lorax_server/utils/dist.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import timedelta 3 | 4 | import torch 5 | from loguru import logger 6 | 7 | # Tensor Parallelism settings 8 | RANK = int(os.getenv("RANK", "0")) 9 | WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) 10 | 11 | # CUDA memory fraction 12 | MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0")) 13 | 14 | MEMORY_WIGGLE_ROOM = float(os.getenv("MEMORY_WIGGLE_ROOM", "0.9")) 15 | 16 | 17 | class FakeBarrier: 18 | def wait(self): 19 | pass 20 | 21 | 22 | class FakeGroup: 23 | def __init__(self, rank, size): 24 | self._rank = rank 25 | self._size = size 26 | 27 | def allreduce(self, *args, **kwargs): 28 | return FakeBarrier() 29 | 30 | def allgather(self, inputs, local_tensor, **kwargs): 31 | assert ( 32 | len(inputs[0]) == len(local_tensor) == 1 33 | ), f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors" 34 | for input_ in inputs: 35 | input_[0].data = local_tensor[0].data 36 | return 
FakeBarrier() 37 | 38 | def barrier(self, *args, **kwargs): 39 | return FakeBarrier() 40 | 41 | def size(self): 42 | return self._size 43 | 44 | def rank(self): 45 | return self._rank 46 | 47 | 48 | def initialize_torch_distributed(): 49 | if torch.cuda.is_available(): 50 | from torch.distributed import ProcessGroupNCCL 51 | 52 | # Set the device id. 53 | assert WORLD_SIZE <= torch.cuda.device_count(), "Each process is one gpu" 54 | device = RANK % torch.cuda.device_count() 55 | torch.cuda.set_device(device) 56 | torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device) 57 | backend = "nccl" 58 | options = ProcessGroupNCCL.Options() 59 | options.is_high_priority_stream = True 60 | options._timeout = timedelta(seconds=60) 61 | else: 62 | backend = "gloo" 63 | options = None 64 | 65 | if WORLD_SIZE == 1: 66 | return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE 67 | else: 68 | if os.getenv("DEBUG", None) == "1": 69 | return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE 70 | 71 | if not torch.distributed.is_initialized(): 72 | # Call the init process. 73 | torch.distributed.init_process_group( 74 | backend=backend, 75 | world_size=WORLD_SIZE, 76 | rank=RANK, 77 | timeout=timedelta(seconds=60), 78 | pg_options=options, 79 | ) 80 | else: 81 | logger.warning("torch.distributed is already initialized.") 82 | 83 | return torch.distributed.group.WORLD, RANK, WORLD_SIZE 84 | -------------------------------------------------------------------------------- /server/lorax_server/utils/errors.py: -------------------------------------------------------------------------------- 1 | class NanWeightsError(RuntimeError): 2 | pass 3 | 4 | 5 | class InfWeightsError(RuntimeError): 6 | pass 7 | -------------------------------------------------------------------------------- /server/lorax_server/utils/import_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def is_xpu_available(): 5 | try: 6 | import intel_extension_for_pytorch # noqa 7 | except ImportError: 8 | return False 9 | 10 | return hasattr(torch, "xpu") and torch.xpu.is_available() 11 | 12 | 13 | def get_cuda_free_memory(device, memory_fraction): 14 | total_free_memory, _ = torch.cuda.mem_get_info(device) 15 | total_gpu_memory = torch.cuda.get_device_properties(device).total_memory 16 | free_memory = max(0, total_free_memory - (1 - memory_fraction) * total_gpu_memory) 17 | return free_memory 18 | 19 | 20 | def get_xpu_free_memory(device): 21 | total_gpu_memory = torch.xpu.get_device_properties(device).total_memory 22 | free_memory = int(total_gpu_memory * 0.5) 23 | return free_memory 24 | 25 | 26 | SYSTEM = None 27 | if torch.version.hip is not None: 28 | SYSTEM = "rocm" 29 | empty_cache = torch.cuda.empty_cache 30 | synchronize = torch.cuda.synchronize 31 | get_free_memory = get_cuda_free_memory 32 | elif torch.version.cuda is not None and torch.cuda.is_available(): 33 | SYSTEM = "cuda" 34 | empty_cache = torch.cuda.empty_cache 35 | synchronize = torch.cuda.synchronize 36 | get_free_memory = get_cuda_free_memory 37 | elif is_xpu_available(): 38 | SYSTEM = "xpu" 39 | empty_cache = torch.xpu.empty_cache 40 | synchronize = torch.xpu.synchronize 41 | get_free_memory = get_xpu_free_memory 42 | else: 43 | SYSTEM = "cpu" 44 | 45 | def noop(*args, **kwargs): 46 | pass 47 | 48 | empty_cache = noop 49 | synchronize = noop 50 | get_free_memory = noop 51 | -------------------------------------------------------------------------------- /server/lorax_server/utils/lora.py: 
-------------------------------------------------------------------------------- 1 | # Constants 2 | Q_PROJ = "q_proj" 3 | K_PROJ = "k_proj" 4 | V_PROJ = "v_proj" 5 | O_PROJ = "o_proj" 6 | 7 | GATE_PROJ = "gate_proj" 8 | UP_PROJ = "up_proj" 9 | DOWN_PROJ = "down_proj" 10 | 11 | FC1 = "fc1" 12 | FC2 = "fc2" 13 | 14 | LM_HEAD = "lm_head" 15 | -------------------------------------------------------------------------------- /server/lorax_server/utils/merges/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/predibase/lorax/b5a9e38dc9479ca664bbaac9ae949e27e3c30832/server/lorax_server/utils/merges/__init__.py -------------------------------------------------------------------------------- /server/lorax_server/utils/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Source: https://github.com/vllm-project/vllm/tree/main/vllm/lora/ops 2 | -------------------------------------------------------------------------------- /server/lorax_server/utils/ops/utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from typing import Dict 3 | 4 | 5 | @functools.lru_cache 6 | def _get_op_configs(op_type: str, batch: int, hidden_size: int): 7 | # TODO: add optimal configurations 8 | return None 9 | 10 | 11 | def _check_divisibility(hidden_size: int): 12 | # The bgmv_expand kernel requires that the hidden_size be divisible by 13 | # the number below. 14 | divisibility = [2, 4, 8, 16, 32, 64] 15 | divisibility.sort(reverse=True) 16 | for div in divisibility: 17 | if hidden_size % div == 0: 18 | return div 19 | # hidden_size is an odd number 20 | return 1 21 | 22 | 23 | def _get_default_config(op_type: str, batch: int, hidden_size: int): 24 | if op_type == "expand": 25 | return {"BLOCK_N": 256, "SPLIT_N": _check_divisibility(hidden_size), "num_warps": 8} 26 | else: 27 | return {"BLOCK_K": 256, "SPLIT_K": 64, "num_warps": 8} 28 | 29 | 30 | def get_lora_op_configs(op_type: str, batch: int, hidden_size: int) -> Dict[str, int]: 31 | """Inspired by `fused_moe_kernel` 32 | The return value will be a dictionary mapping an irregular grid of batch 33 | sizes and hidden_size to configurations of the bgmv-related kernel. 34 | NOTE: It currently only supports the default configuration. We plan to 35 | generate optimal configurations for different hardware in the future using 36 | scripts similar to `benchmark_moe.py`. 
37 | """ 38 | config = _get_op_configs(op_type, batch, hidden_size) 39 | if not config: 40 | config = _get_default_config(op_type, batch, hidden_size) 41 | return config 42 | -------------------------------------------------------------------------------- /server/lorax_server/utils/segments.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union 2 | 3 | import torch 4 | 5 | 6 | def find_segments(adapter_indices: Union[torch.Tensor, List[int]]) -> Tuple[List[int], List[int]]: 7 | segments = [0] 8 | segment_indices = [] 9 | 10 | if isinstance(adapter_indices, torch.Tensor): 11 | # Calling .item() repeatedly on CUDA tensor is very slow, so we move it to CPU first 12 | adapter_indices = adapter_indices.cpu().tolist() 13 | 14 | start_index = 0 15 | for i in range(1, len(adapter_indices)): 16 | if adapter_indices[i] != adapter_indices[i - 1]: 17 | segments.append(i) 18 | segment_indices.append(adapter_indices[i - 1]) 19 | start_index = i 20 | 21 | # Handle the last segment 22 | if start_index < len(adapter_indices): 23 | segments.append(len(adapter_indices)) 24 | segment_indices.append(adapter_indices[-1]) 25 | 26 | return segments, segment_indices 27 | 28 | 29 | class SegmentConcatBuilder: 30 | def __init__(self): 31 | self.adapter_segment_indices = [] 32 | self.adapter_segment_tensors = [] 33 | 34 | def concat(self, adapter_segments: torch.Tensor, segment_indices: List[int]): 35 | # Update adapter segments 36 | if self.adapter_segment_tensors: 37 | # Because we have already processed at least one batch, remove the 0 start index 38 | # from this batch denoting the beginning of the segment, then offset all segment 39 | # positions by the value of the last segment in the previous batch to account for 40 | # the concatenation. 41 | adapter_segments = adapter_segments[1:] + self.adapter_segment_tensors[-1][-1] 42 | 43 | if self.adapter_segment_indices and self.adapter_segment_indices[-1] == segment_indices[0]: 44 | # If the last segment in the previous batch is the same as the first segment in this batch, 45 | # then we merge them together into a single segment. In effect, this means removing it from 46 | # the segment indices of this batch, and extending the segment span by removing the segment 47 | # end index from the previous batch. 
48 | segment_indices = segment_indices[1:] 49 | self.adapter_segment_tensors[-1] = self.adapter_segment_tensors[-1][:-1] 50 | 51 | self.adapter_segment_indices.extend(segment_indices) 52 | self.adapter_segment_tensors.append(adapter_segments) 53 | 54 | def build(self) -> Tuple[torch.Tensor, List[int]]: 55 | return torch.concat(self.adapter_segment_tensors), self.adapter_segment_indices 56 | -------------------------------------------------------------------------------- /server/lorax_server/utils/sources/local.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import List, Optional 4 | 5 | from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE 6 | 7 | from .source import BaseModelSource 8 | 9 | 10 | def get_model_local_dir(model_id: str) -> Path: 11 | if os.path.isabs(model_id): 12 | return Path(model_id) 13 | 14 | repo_cache = Path(HUGGINGFACE_HUB_CACHE) / model_id 15 | return repo_cache 16 | 17 | 18 | class LocalModelSource(BaseModelSource): 19 | def __init__(self, model_id: str, revision: Optional[str] = "", extension: str = ".safetensors"): 20 | if len(model_id) < 5: 21 | raise ValueError(f"model_id '{model_id}' is too short for prefix filtering") 22 | 23 | # TODO: add support for revisions of the same model 24 | self.model_id = model_id 25 | self.revision = revision 26 | self.extension = extension 27 | 28 | @property 29 | def api_token(self) -> Optional[str]: 30 | return None 31 | 32 | def remote_weight_files(self, extension: str = None): 33 | return [] 34 | 35 | def weight_files(self, extension: str = None): 36 | model_id = self.model_id 37 | extension = extension or self.extension 38 | 39 | local_path = get_model_local_dir(model_id) 40 | if local_path.exists() and local_path.is_dir(): 41 | local_files = list(local_path.glob(f"*{extension}")) 42 | if not local_files: 43 | raise FileNotFoundError(f"No local weights found in {model_id} with extension {extension}") 44 | return local_files 45 | 46 | raise FileNotFoundError(f"No local weights found in {model_id} with extension {extension}") 47 | 48 | def download_weights(self, filenames: List[str]): 49 | return [] 50 | 51 | def download_model_assets(self): 52 | return [] 53 | 54 | def get_local_path(self, model_id: str) -> Path: 55 | return get_model_local_dir(model_id) 56 | 57 | def download_file(self, filename: str, ignore_errors: bool = False) -> Optional[Path]: 58 | path = get_model_local_dir(self.model_id) / filename 59 | if not path.exists(): 60 | if ignore_errors: 61 | return None 62 | raise FileNotFoundError(f"File {filename} of model {self.model_id} not found in {path}") 63 | return path 64 | -------------------------------------------------------------------------------- /server/lorax_server/utils/state.py: -------------------------------------------------------------------------------- 1 | import os 2 | from contextlib import contextmanager 3 | from typing import Optional 4 | 5 | from loguru import logger 6 | 7 | WARMUP = False 8 | SPECULATIVE_TOKENS = 0 9 | NGRAM = False 10 | 11 | 12 | LORAX_PROFILER_DIR = os.environ.get("LORAX_PROFILER_DIR", None) 13 | PREFIX_CACHING = bool(int(os.environ.get("PREFIX_CACHING", "0"))) 14 | CHUNKED_PREFILL = bool(int(os.environ.get("CHUNKED_PREFILL", "0"))) 15 | LORAX_SPECULATION_MAX_BATCH_SIZE = int(os.environ.get("LORAX_SPECULATION_MAX_BATCH_SIZE", 32)) 16 | 17 | # Always use flashinfer when prefix caching is enabled 18 | FLASH_INFER = bool(int(os.environ.get("FLASH_INFER", "0"))) or 
PREFIX_CACHING 19 | if FLASH_INFER: 20 | logger.info("Backend = flashinfer") 21 | else: 22 | logger.info("Backend = fa2") 23 | 24 | logger.info(f"Prefix caching = {PREFIX_CACHING}") 25 | logger.info(f"Chunked prefill = {CHUNKED_PREFILL}") 26 | 27 | if LORAX_PROFILER_DIR: 28 | logger.info(f"Torch profiling enabled, output dir = {LORAX_PROFILER_DIR}") 29 | 30 | SUPPORTS_CHUNKING: Optional[bool] = None 31 | MAX_PREFILL_TOKENS: Optional[int] = None 32 | 33 | 34 | BLOCK_SIZE: int 35 | if FLASH_INFER: 36 | BLOCK_SIZE = 1 37 | else: 38 | BLOCK_SIZE = 16 39 | 40 | 41 | def set_warmup(value: bool): 42 | global WARMUP 43 | WARMUP = value 44 | 45 | 46 | def is_warmup() -> bool: 47 | return WARMUP 48 | 49 | 50 | @contextmanager 51 | def warmup_mode(): 52 | try: 53 | set_warmup(True) 54 | yield 55 | finally: 56 | set_warmup(False) 57 | 58 | 59 | def set_speculative_tokens(value: int, use_ngram: bool): 60 | global SPECULATIVE_TOKENS 61 | global NGRAM 62 | SPECULATIVE_TOKENS = value 63 | NGRAM = use_ngram 64 | 65 | 66 | def get_speculative_tokens() -> int: 67 | return SPECULATIVE_TOKENS 68 | 69 | 70 | def use_ngram() -> bool: 71 | return NGRAM 72 | 73 | 74 | def set_supports_chunking(supports_chunking: bool): 75 | global SUPPORTS_CHUNKING 76 | SUPPORTS_CHUNKING = supports_chunking 77 | 78 | 79 | def get_supports_chunking() -> bool: 80 | global SUPPORTS_CHUNKING 81 | return SUPPORTS_CHUNKING 82 | 83 | 84 | def set_max_prefill_tokens(max_prefill_tokens: int): 85 | global MAX_PREFILL_TOKENS 86 | MAX_PREFILL_TOKENS = max_prefill_tokens 87 | 88 | 89 | def get_max_prefill_tokens() -> int: 90 | global MAX_PREFILL_TOKENS 91 | return MAX_PREFILL_TOKENS 92 | -------------------------------------------------------------------------------- /server/lorax_server/utils/tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from transformers import PreTrainedTokenizerBase 4 | 5 | from lorax_server.pb import generate_pb2 6 | 7 | 8 | class TokenizerManager: 9 | def __init__(self): 10 | self.tokenizers = {} 11 | 12 | def add_tokenizer(self, adapter_idx: int, tokenizer: PreTrainedTokenizerBase): 13 | self.tokenizers[adapter_idx] = tokenizer 14 | 15 | def get_tokenizer(self, adapter_idx: int, default: PreTrainedTokenizerBase) -> Optional[PreTrainedTokenizerBase]: 16 | return self.tokenizers.get(adapter_idx, default) 17 | 18 | def get_inputs( 19 | self, 20 | r: generate_pb2.Request, 21 | base_tokenizer: PreTrainedTokenizerBase, 22 | ) -> str: 23 | return r.inputs 24 | -------------------------------------------------------------------------------- /server/lorax_server/utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def is_bf16_supported() -> bool: 5 | """Check if the current GPU supports bfloat16. 6 | 7 | Returns: 8 | True if supported, False otherwise. 
9 | """ 10 | return torch.cuda.is_available() and torch.cuda.is_bf16_supported() 11 | 12 | 13 | def is_quantized(quantize): 14 | return quantize and quantize in ["gptq", "awq", "fp8", "fp8-kv"] 15 | 16 | 17 | def is_fp8_supported(): 18 | return ( 19 | torch.cuda.is_available() 20 | and (torch.cuda.get_device_capability()[0] >= 9) 21 | or (torch.cuda.get_device_capability()[0] == 8 and torch.cuda.get_device_capability()[1] >= 9) 22 | ) 23 | 24 | 25 | def is_fp8_kv(quantize): 26 | return quantize and quantize == "fp8-kv" 27 | 28 | 29 | def is_fp8(quantize): 30 | return quantize and quantize.startswith("fp8") 31 | -------------------------------------------------------------------------------- /server/punica_kernels/README.md: -------------------------------------------------------------------------------- 1 | These kernels are forked from the [Punica](https://github.com/punica-ai/punica) project. 2 | 3 | Forked from commit: https://github.com/punica-ai/punica/commit/07a40b9d30e98d88963e8a7e140120a25ac0d518 4 | 5 | Modifications to BGMV kernel from vLLM: https://github.com/vllm-project/vllm -------------------------------------------------------------------------------- /server/punica_kernels/punica_kernels/bgmv/bgmv_all.cu: -------------------------------------------------------------------------------- 1 | #include "bgmv_config.h" 2 | #include "bgmv_impl.cuh" 3 | 4 | FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half) 5 | FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16) 6 | -------------------------------------------------------------------------------- /server/punica_kernels/punica_kernels/bgmv/bgmv_config.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | template 4 | void bgmv_kernel(T *__restrict__ Y, const T *__restrict__ X, 5 | T **__restrict__ W, 6 | const int64_t *__restrict__ indicies, int64_t y_offset, 7 | int64_t full_y_size, int64_t batch_size, 8 | int64_t layer_idx, float scale); 9 | 10 | // clang-format off 11 | 12 | #define FOR_BGMV_WIDE(f, T, narrow) \ 13 | f(T, narrow, 256) \ 14 | f(T, narrow, 512) \ 15 | f(T, narrow, 640) \ 16 | f(T, narrow, 768) \ 17 | f(T, narrow, 1024) \ 18 | f(T, narrow, 1152) \ 19 | f(T, narrow, 1280) \ 20 | f(T, narrow, 1536) \ 21 | f(T, narrow, 1728) \ 22 | f(T, narrow, 1792) \ 23 | f(T, narrow, 2048) \ 24 | f(T, narrow, 2304) \ 25 | f(T, narrow, 2560) \ 26 | f(T, narrow, 2752) \ 27 | f(T, narrow, 2816) \ 28 | f(T, narrow, 3072) \ 29 | f(T, narrow, 3456) \ 30 | f(T, narrow, 3584) \ 31 | f(T, narrow, 4096) \ 32 | f(T, narrow, 4480) \ 33 | f(T, narrow, 4608) \ 34 | f(T, narrow, 5120) \ 35 | f(T, narrow, 5504) \ 36 | f(T, narrow, 5632) \ 37 | f(T, narrow, 6144) \ 38 | f(T, narrow, 6848) \ 39 | f(T, narrow, 6912) \ 40 | f(T, narrow, 7168) \ 41 | f(T, narrow, 7680) \ 42 | f(T, narrow, 8192) \ 43 | f(T, narrow, 8960) \ 44 | f(T, narrow, 9216) \ 45 | f(T, narrow, 9472) \ 46 | f(T, narrow, 10240) \ 47 | f(T, narrow, 11008) \ 48 | f(T, narrow, 12288) \ 49 | f(T, narrow, 13696) \ 50 | f(T, narrow, 13824) \ 51 | f(T, narrow, 14336) \ 52 | f(T, narrow, 15360) \ 53 | f(T, narrow, 16384) \ 54 | f(T, narrow, 17920) \ 55 | f(T, narrow, 18944) \ 56 | f(T, narrow, 20480) \ 57 | f(T, narrow, 22016) \ 58 | f(T, narrow, 24576) \ 59 | f(T, narrow, 27392) \ 60 | f(T, narrow, 27648) \ 61 | f(T, narrow, 28672) \ 62 | f(T, narrow, 32000) \ 63 | f(T, narrow, 32256) \ 64 | f(T, narrow, 32512) \ 65 | f(T, narrow, 32768) \ 66 | f(T, narrow, 33024) \ 67 | f(T, narrow, 35840) \ 68 | f(T, narrow, 36864) \ 69 | f(T, narrow, 
43264) \ 70 | f(T, narrow, 49152) \ 71 | f(T, narrow, 64000) \ 72 | f(T, narrow, 64256) \ 73 | f(T, narrow, 64512) \ 74 | f(T, narrow, 102400) \ 75 | f(T, narrow, 102656) \ 76 | f(T, narrow, 102912) \ 77 | f(T, narrow, 128000) \ 78 | f(T, narrow, 128256) \ 79 | f(T, narrow, 128512) \ 80 | 81 | #define FOR_BGMV_WIDE_NARROW(f, T) \ 82 | FOR_BGMV_WIDE(f, T, 8) \ 83 | FOR_BGMV_WIDE(f, T, 16) \ 84 | FOR_BGMV_WIDE(f, T, 32) \ 85 | FOR_BGMV_WIDE(f, T, 64) \ 86 | FOR_BGMV_WIDE(f, T, 128) 87 | 88 | // clang-format on 89 | -------------------------------------------------------------------------------- /server/punica_kernels/punica_kernels/flashinfer_adapter/flashinfer_config.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | template 5 | bool FlashInferBatchPrefillKernel(T* o, T* q, int32_t* qo_indptr, T** kv_ptrs, 6 | int32_t* kv_indptr, int32_t* last_page_offset, 7 | void* tmpbuf, int head_dim, int num_layers, 8 | int layer_idx, int group_size, 9 | int num_kv_heads, int page_size, 10 | int batch_size); 11 | 12 | template 13 | bool FlashInferBatchDecodeKernel(T* o, T* q, T** kv_ptrs, int32_t* kv_indptr, 14 | int32_t* last_page_offset, void* tmpbuf, 15 | int head_dim, int num_layers, int layer_idx, 16 | int group_size, int num_kv_heads, 17 | int page_size, int batch_size); 18 | 19 | template 20 | void FlashInferInitKvKernel(T** kv_ptrs, int32_t* kv_indptr, 21 | int32_t* last_page_offset, T* key, T* value, 22 | int32_t* seqlen_indptr, int num_layers, 23 | int layer_idx, int num_kv_heads, int page_size, 24 | int batch_size); 25 | 26 | template 27 | void FlashInferAppendKvKernel(T** kv_ptrs, int32_t* kv_indptr, 28 | int32_t* last_page_offset, T* key, T* value, 29 | int num_layers, int layer_idx, int num_kv_heads, 30 | int page_size, int batch_size); 31 | 32 | // clang-format off 33 | 34 | #define FOR_FlashInferBatchDecode_D(f, ...) 
\ 35 | f(64, __VA_ARGS__) \ 36 | f(128, __VA_ARGS__) 37 | 38 | // clang-format on 39 | -------------------------------------------------------------------------------- /server/punica_kernels/punica_kernels/flashinfer_adapter/flashinfer_decl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "flashinfer/page.cuh" 3 | #include "flashinfer/rope.cuh" 4 | 5 | namespace flashinfer { 6 | template 10 | cudaError_t BatchPrefillWithPagedKVCacheDispatched( 11 | DTypeIn* q, paged_kv_t paged_kv, 12 | IdType* qo_indptr, DTypeOut* o, float* tmp, uint32_t num_qo_heads, 13 | float rope_scale, float rope_theta, cudaStream_t stream); 14 | } 15 | 16 | #define INST_BatchPrefill(T, PAGE_SIZE, GROUP_SIZE, HEAD_DIM) \ 17 | namespace flashinfer { \ 18 | template cudaError_t BatchPrefillWithPagedKVCacheDispatched< \ 19 | PAGE_SIZE, GROUP_SIZE, HEAD_DIM, PageStorage::kPointer, \ 20 | RotaryMode::kLlama, /* ALLOW_FP16_QK_REDUCTION= */ false, \ 21 | /* CAUSAL= */ true, T, T, int32_t>( \ 22 | T * q, paged_kv_t paged_kv, \ 23 | int32_t* qo_indptr, T* o, float* tmp, uint32_t num_qo_heads, \ 24 | float rope_scale, float rope_theta, cudaStream_t stream); \ 25 | } 26 | 27 | namespace flashinfer { 28 | template 31 | cudaError_t BatchDecodeWithPagedKVCacheDispatched( 32 | DTypeIn* q, paged_kv_t paged_kv, DTypeOut* o, 33 | float* tmp, float rope_scale, float rope_theta, cudaStream_t stream); 34 | } 35 | #define INST_BatchDecode(T, PAGE_SIZE, GROUP_SIZE, HEAD_DIM) \ 36 | namespace flashinfer { \ 37 | template cudaError_t BatchDecodeWithPagedKVCacheDispatched< \ 38 | PAGE_SIZE, GROUP_SIZE, HEAD_DIM, PageStorage::kPointer, \ 39 | RotaryMode::kLlama, T, T, int32_t>( \ 40 | T * q, paged_kv_t paged_kv, T* o, \ 41 | float* tmp, float rope_scale, float rope_theta, cudaStream_t stream); \ 42 | } 43 | -------------------------------------------------------------------------------- /server/punica_kernels/punica_kernels/flashinfer_adapter/generated/batch_decode_p16_g1_h128_bf16.cu: -------------------------------------------------------------------------------- 1 | #include "../flashinfer_decl.h" 2 | 3 | #include "flashinfer/decode.cuh" 4 | 5 | INST_BatchDecode(nv_bfloat16, 16, 1, 128) 6 | -------------------------------------------------------------------------------- /server/punica_kernels/punica_kernels/flashinfer_adapter/generated/batch_decode_p16_g1_h128_fp16.cu: -------------------------------------------------------------------------------- 1 | #include "../flashinfer_decl.h" 2 | 3 | #include "flashinfer/decode.cuh" 4 | 5 | INST_BatchDecode(nv_half, 16, 1, 128) 6 | -------------------------------------------------------------------------------- /server/punica_kernels/punica_kernels/flashinfer_adapter/generated/batch_decode_p16_g2_h128_bf16.cu: -------------------------------------------------------------------------------- 1 | #include "../flashinfer_decl.h" 2 | 3 | #include "flashinfer/decode.cuh" 4 | 5 | INST_BatchDecode(nv_bfloat16, 16, 2, 128) 6 | -------------------------------------------------------------------------------- /server/punica_kernels/punica_kernels/flashinfer_adapter/generated/batch_decode_p16_g2_h128_fp16.cu: -------------------------------------------------------------------------------- 1 | #include "../flashinfer_decl.h" 2 | 3 | #include "flashinfer/decode.cuh" 4 | 5 | INST_BatchDecode(nv_half, 16, 2, 128) 6 | -------------------------------------------------------------------------------- 
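The generated/ sources above and below are mechanical one-line expansions of the INST_BatchDecode and INST_BatchPrefill macros declared in flashinfer_decl.h, one file per (dtype, page_size, group_size, head_dim) combination matching the dispatch grid in dispatch.inc further below. A minimal sketch of how such a file grid could be emitted is shown here; the script and its output layout are an illustration of the pattern only, not the project's actual generator.

# Hypothetical generator sketch: emits one .cu instantiation file per
# (dtype, page_size, group_size, head_dim) combination, mirroring the
# generated/ directory and dispatch.inc shown in this section.
import itertools
from pathlib import Path

DTYPES = {"fp16": "nv_half", "bf16": "nv_bfloat16"}
PAGE_SIZES = [16]
GROUP_SIZES = [1, 2, 4, 8]
HEAD_DIMS = [128]

TEMPLATE = """#include "../flashinfer_decl.h"

#include "flashinfer/{kind}.cuh"

INST_{macro}({ctype}, {page}, {group}, {head})
"""

def emit(out_dir: Path) -> None:
    out_dir.mkdir(parents=True, exist_ok=True)
    for kind, macro in [("decode", "BatchDecode"), ("prefill", "BatchPrefill")]:
        for (name, ctype), page, group, head in itertools.product(
            DTYPES.items(), PAGE_SIZES, GROUP_SIZES, HEAD_DIMS
        ):
            fname = f"batch_{kind}_p{page}_g{group}_h{head}_{name}.cu"
            (out_dir / fname).write_text(
                TEMPLATE.format(kind=kind, macro=macro, ctype=ctype,
                                page=page, group=group, head=head)
            )

if __name__ == "__main__":
    emit(Path("generated"))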
/server/punica_kernels/punica_kernels/flashinfer_adapter/generated/batch_decode_p16_g4_h128_bf16.cu: -------------------------------------------------------------------------------- 1 | #include "../flashinfer_decl.h" 2 | 3 | #include "flashinfer/decode.cuh" 4 | 5 | INST_BatchDecode(nv_bfloat16, 16, 4, 128) 6 | -------------------------------------------------------------------------------- /server/punica_kernels/punica_kernels/flashinfer_adapter/generated/batch_decode_p16_g4_h128_fp16.cu: -------------------------------------------------------------------------------- 1 | #include "../flashinfer_decl.h" 2 | 3 | #include "flashinfer/decode.cuh" 4 | 5 | INST_BatchDecode(nv_half, 16, 4, 128) 6 | -------------------------------------------------------------------------------- /server/punica_kernels/punica_kernels/flashinfer_adapter/generated/batch_decode_p16_g8_h128_bf16.cu: -------------------------------------------------------------------------------- 1 | #include "../flashinfer_decl.h" 2 | 3 | #include "flashinfer/decode.cuh" 4 | 5 | INST_BatchDecode(nv_bfloat16, 16, 8, 128) 6 | -------------------------------------------------------------------------------- /server/punica_kernels/punica_kernels/flashinfer_adapter/generated/batch_decode_p16_g8_h128_fp16.cu: -------------------------------------------------------------------------------- 1 | #include "../flashinfer_decl.h" 2 | 3 | #include "flashinfer/decode.cuh" 4 | 5 | INST_BatchDecode(nv_half, 16, 8, 128) 6 | -------------------------------------------------------------------------------- /server/punica_kernels/punica_kernels/flashinfer_adapter/generated/batch_prefill_p16_g1_h128_bf16.cu: -------------------------------------------------------------------------------- 1 | #include "../flashinfer_decl.h" 2 | 3 | #include "flashinfer/prefill.cuh" 4 | 5 | INST_BatchPrefill(nv_bfloat16, 16, 1, 128) 6 | -------------------------------------------------------------------------------- /server/punica_kernels/punica_kernels/flashinfer_adapter/generated/batch_prefill_p16_g1_h128_fp16.cu: -------------------------------------------------------------------------------- 1 | #include "../flashinfer_decl.h" 2 | 3 | #include "flashinfer/prefill.cuh" 4 | 5 | INST_BatchPrefill(nv_half, 16, 1, 128) 6 | -------------------------------------------------------------------------------- /server/punica_kernels/punica_kernels/flashinfer_adapter/generated/batch_prefill_p16_g2_h128_bf16.cu: -------------------------------------------------------------------------------- 1 | #include "../flashinfer_decl.h" 2 | 3 | #include "flashinfer/prefill.cuh" 4 | 5 | INST_BatchPrefill(nv_bfloat16, 16, 2, 128) 6 | -------------------------------------------------------------------------------- /server/punica_kernels/punica_kernels/flashinfer_adapter/generated/batch_prefill_p16_g2_h128_fp16.cu: -------------------------------------------------------------------------------- 1 | #include "../flashinfer_decl.h" 2 | 3 | #include "flashinfer/prefill.cuh" 4 | 5 | INST_BatchPrefill(nv_half, 16, 2, 128) 6 | -------------------------------------------------------------------------------- /server/punica_kernels/punica_kernels/flashinfer_adapter/generated/batch_prefill_p16_g4_h128_bf16.cu: -------------------------------------------------------------------------------- 1 | #include "../flashinfer_decl.h" 2 | 3 | #include "flashinfer/prefill.cuh" 4 | 5 | INST_BatchPrefill(nv_bfloat16, 16, 4, 128) 6 | -------------------------------------------------------------------------------- 
/server/punica_kernels/punica_kernels/flashinfer_adapter/generated/batch_prefill_p16_g4_h128_fp16.cu: -------------------------------------------------------------------------------- 1 | #include "../flashinfer_decl.h" 2 | 3 | #include "flashinfer/prefill.cuh" 4 | 5 | INST_BatchPrefill(nv_half, 16, 4, 128) 6 | -------------------------------------------------------------------------------- /server/punica_kernels/punica_kernels/flashinfer_adapter/generated/batch_prefill_p16_g8_h128_bf16.cu: -------------------------------------------------------------------------------- 1 | #include "../flashinfer_decl.h" 2 | 3 | #include "flashinfer/prefill.cuh" 4 | 5 | INST_BatchPrefill(nv_bfloat16, 16, 8, 128) 6 | -------------------------------------------------------------------------------- /server/punica_kernels/punica_kernels/flashinfer_adapter/generated/batch_prefill_p16_g8_h128_fp16.cu: -------------------------------------------------------------------------------- 1 | #include "../flashinfer_decl.h" 2 | 3 | #include "flashinfer/prefill.cuh" 4 | 5 | INST_BatchPrefill(nv_half, 16, 8, 128) 6 | -------------------------------------------------------------------------------- /server/punica_kernels/punica_kernels/flashinfer_adapter/generated/dispatch.inc: -------------------------------------------------------------------------------- 1 | #define _DISPATCH_CASES_page_size(...) \ 2 | _DISPATCH_CASE(16, PAGE_SIZE, __VA_ARGS__) \ 3 | // EOL 4 | #define _DISPATCH_CASES_group_size(...) \ 5 | _DISPATCH_CASE(1, GROUP_SIZE, __VA_ARGS__) \ 6 | _DISPATCH_CASE(2, GROUP_SIZE, __VA_ARGS__) \ 7 | _DISPATCH_CASE(4, GROUP_SIZE, __VA_ARGS__) \ 8 | _DISPATCH_CASE(8, GROUP_SIZE, __VA_ARGS__) \ 9 | // EOL 10 | #define _DISPATCH_CASES_head_dim(...) \ 11 | _DISPATCH_CASE(128, HEAD_DIM, __VA_ARGS__) \ 12 | // EOL 13 | 14 | -------------------------------------------------------------------------------- /server/punica_kernels/punica_kernels/rms_norm/rms_norm.h: -------------------------------------------------------------------------------- 1 | template 2 | bool rms_norm(T *__restrict__ output, const T *__restrict__ input, 3 | const T *__restrict__ weight, int rows, int columns, 4 | float epsilon); 5 | -------------------------------------------------------------------------------- /server/punica_kernels/punica_kernels/sgmv/sgmv.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | #include 5 | 6 | template 7 | bool sgmv(DType *y, DType *x, DType **w, int32_t *s_start, int32_t *s_end, 8 | void *tmp_d, int num_problems, int d_in, int d_out, int layer_idx, cudaStream_t stream); 9 | 10 | size_t sgmv_tmp_size(int num_problems); 11 | -------------------------------------------------------------------------------- /server/punica_kernels/punica_kernels/sgmv/sgmv_cutlass.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "sgmv_cutlass.cuh" 5 | 6 | template bool sgmv(nv_half *y, nv_half *x, nv_half **w, 7 | int32_t *s_start, int32_t *s_end, 8 | void *tmp_d, int num_problems, int d_in, int d_out, 9 | int layer_idx, cudaStream_t stream); 10 | 11 | template bool sgmv(nv_bfloat16 *y, nv_bfloat16 *x, nv_bfloat16 **w, 12 | int32_t *s_start, int32_t *s_end, 13 | void *tmp_d, int num_problems, int d_in, int d_out, 14 | int layer_idx, cudaStream_t stream); 15 | -------------------------------------------------------------------------------- 
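The s_start / s_end pointers in sgmv.h above give, for each of num_problems adapter segments, the contiguous row range of the batched activations that the segment covers. Those boundaries are what find_segments in server/lorax_server/utils/segments.py computes from adapter_indices. The sketch below is illustrative only: it shows one plausible way to turn find_segments output into per-problem start/end arrays, not the exact tensors the server hands to the kernel.

# Illustrative only: map find_segments output to per-problem start/end
# ranges in the spirit of the SGMV kernel's s_start / s_end arguments.
import torch

from lorax_server.utils.segments import find_segments

adapter_indices = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 2])
segments, segment_indices = find_segments(adapter_indices)
# segments == [0, 3, 5, 10], segment_indices == [0, 1, 2]

s_start = torch.tensor(segments[:-1], dtype=torch.int32)  # row starts [0, 3, 5]
s_end = torch.tensor(segments[1:], dtype=torch.int32)     # row ends   [3, 5, 10]
num_problems = len(segment_indices)                       # 3 adapter segments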
/server/punica_kernels/punica_kernels/sgmv_flashinfer/sgmv_all.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include 6 | 7 | #include "sgmv_config.h" 8 | #include "sgmv_flashinfer.cuh" 9 | 10 | template 11 | bool sgmv_shrink(T* y, T* x, T** w, int32_t* s_start, int32_t* s_end, void* tmp, 12 | uint32_t num_problems, uint32_t d_in, uint32_t layer_idx, cudaStream_t stream) { 13 | static_assert(d_out % 16 == 0); 14 | 15 | constexpr uint32_t num_warps = 4; 16 | constexpr uint32_t num_stages = 2; 17 | constexpr uint32_t num_k_frags_per_stage = 8; 18 | constexpr uint32_t num_blocks_n = d_out / 16; 19 | uint32_t smem = num_stages * sizeof(T) * num_k_frags_per_stage * 16 * 16 * 20 | (num_warps + num_blocks_n); 21 | auto cooperative_kernel = 22 | flashinfer::sgmv::sgmv_shrink; 23 | auto kernel = flashinfer::sgmv::sgmv_shrink; 24 | 25 | int dev_id = 0; 26 | int num_blocks_per_sm = 0; 27 | int num_sm = 0; 28 | bool use_cooperative = true; 29 | cudaGetDevice(&dev_id); 30 | cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id); 31 | cudaOccupancyMaxActiveBlocksPerMultiprocessor( 32 | &num_blocks_per_sm, cooperative_kernel, num_warps * 32, smem); 33 | 34 | const uint32_t max_grid_size = num_sm * num_blocks_per_sm; 35 | 36 | uint32_t chunk_size = 256; 37 | uint32_t num_chunks = (d_in + chunk_size - 1) / chunk_size; 38 | if (num_chunks * num_problems > max_grid_size) { 39 | use_cooperative = false; 40 | chunk_size = d_in; 41 | num_chunks = 1; 42 | } 43 | 44 | dim3 nthrs(32, num_warps); 45 | dim3 nblks(num_chunks, num_problems); 46 | 47 | void* args[] = {(void*)&y, (void*)&x, (void*)&w, 48 | (void*)&s_start, (void*)&s_end, (void*)&tmp, (void*)&num_problems, 49 | (void*)&d_in, (void*)&layer_idx, (void*)&chunk_size}; 50 | 51 | cudaError_t status; 52 | if (use_cooperative) { 53 | if (smem > 46 * 1024) { 54 | cudaFuncSetAttribute(cooperative_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); 55 | } 56 | status = cudaLaunchCooperativeKernel((void*)cooperative_kernel, nblks, 57 | nthrs, args, smem, stream); 58 | } else { 59 | if (smem > 46 * 1024) { 60 | cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); 61 | } 62 | status = cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem, stream); 63 | } 64 | return status == cudaSuccess; 65 | } 66 | 67 | #define INST(T, d_out) \ 68 | template bool sgmv_shrink(T * y, T * x, T * *w, int32_t * s_start, int32_t * s_end, \ 69 | void* tmp, uint32_t num_problems, \ 70 | uint32_t d_in, uint32_t layer_idx, cudaStream_t stream); 71 | 72 | FOR_SGMV_NARROW(INST, nv_half); 73 | FOR_SGMV_NARROW(INST, nv_bfloat16); -------------------------------------------------------------------------------- /server/punica_kernels/punica_kernels/sgmv_flashinfer/sgmv_config.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | template 5 | bool sgmv_shrink(T* y, T* x, T** w, int32_t* s_start, int32_t* s_end, void* tmp, 6 | uint32_t num_problems, uint32_t d_in, uint32_t layer_idx, cudaStream_t stream); 7 | 8 | // clang-format off 9 | 10 | #define FOR_SGMV_NARROW(f, T) \ 11 | f(T, 16) \ 12 | f(T, 32) \ 13 | f(T, 64) \ 14 | f(T, 96) \ 15 | f(T, 128) 16 | 17 | // clang-format on 18 | -------------------------------------------------------------------------------- /server/tests/adapters/test_medusa.py: -------------------------------------------------------------------------------- 1 | import 
torch 2 | 3 | from lorax_server.adapters.medusa import BatchMedusaWeights, MedusaConfig 4 | from lorax_server.adapters.utils import download_adapter_weights 5 | from lorax_server.adapters.weights import AdapterBatchMetadata 6 | from lorax_server.models.causal_lm import CausalLM 7 | from lorax_server.utils.adapter import load_module_map 8 | from lorax_server.utils.lora import LM_HEAD 9 | from lorax_server.utils.sources import HUB 10 | 11 | model_id = "mistralai/Mistral-7B-Instruct-v0.2" 12 | adapter_id = "predibase/Mistral-7B-Instruct-v0.2-medusa" 13 | 14 | 15 | def test_batched_medusa_weights(default_causal_lm: CausalLM): 16 | download_adapter_weights(adapter_id, HUB) 17 | 18 | module_map, medusa_config, _, _ = load_module_map( 19 | model_id, adapter_id, HUB, tuple(), None 20 | ) 21 | assert isinstance(medusa_config, MedusaConfig) 22 | 23 | medusa_weights = medusa_config.load_batched_adapter_weights( 24 | default_causal_lm, 25 | module_map, 26 | LM_HEAD, 27 | set(), 28 | False, 29 | ) 30 | 31 | meta = AdapterBatchMetadata( 32 | adapter_indices=torch.tensor([0, 0, 1, 1, 0, 0, 1, 1], dtype=torch.int64), 33 | adapter_list=[0, 1, 0, 1], 34 | adapter_set={0, 1}, 35 | adapter_segments=torch.tensor([0, 2, 4, 6, 8], dtype=torch.int64), 36 | segment_indices=[0, 1, 0, 1], 37 | ) 38 | 39 | batch_medusa_weights = BatchMedusaWeights.load( 40 | { 41 | 0: medusa_weights, 42 | 1: medusa_weights, 43 | }, 44 | meta, 45 | layer_name=LM_HEAD, 46 | prefill=False, 47 | prefill_head_indices=None, 48 | ) 49 | 50 | assert batch_medusa_weights is not None 51 | assert batch_medusa_weights.default_medusa == medusa_weights 52 | -------------------------------------------------------------------------------- /server/tests/adapters/test_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from huggingface_hub.utils import RepositoryNotFoundError 5 | 6 | from lorax_server.adapters.utils import download_adapter_weights 7 | from lorax_server.utils.sources import HUB 8 | 9 | 10 | def test_download_private_adapter_hf(): 11 | # store and unset HUGGING_FACE_HUB_TOKEN from the environment 12 | token = os.environ.pop("HUGGING_FACE_HUB_TOKEN", None) 13 | assert token is not None, "HUGGING_FACE_HUB_TOKEN must be set in the environment to run this test" 14 | 15 | # verify download fails without the token set 16 | with pytest.raises(RepositoryNotFoundError): 17 | download_adapter_weights("predibase/test-private-lora", HUB, api_token=None) 18 | 19 | # pass in the token and verify download succeeds 20 | download_adapter_weights("predibase/test-private-lora", HUB, api_token=token) 21 | 22 | # set the token back in the environment 23 | os.environ["HUGGING_FACE_HUB_TOKEN"] = token 24 | -------------------------------------------------------------------------------- /server/tests/models/test_model.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from transformers import AutoTokenizer 4 | 5 | from lorax_server.models.model import Model 6 | 7 | 8 | def get_test_model(): 9 | class TestModel(Model): 10 | def batch_type(self): 11 | raise NotImplementedError 12 | 13 | def generate_token(self, batch): 14 | raise NotImplementedError 15 | 16 | model_id = "meta-llama/Llama-2-7b-hf" 17 | tokenizer = AutoTokenizer.from_pretrained(model_id) 18 | 19 | model = TestModel(model_id, torch.nn.Linear(1, 1), tokenizer, False, torch.float32, torch.device("cpu")) 20 | return model 21 | 22 | 23 | 
@pytest.mark.private 24 | def test_decode_streaming_english_spaces(): 25 | model = get_test_model() 26 | truth = "Hello here, this is a simple test" 27 | all_input_ids = [15043, 1244, 29892, 445, 338, 263, 2560, 1243] 28 | assert all_input_ids == model.tokenizer(truth, add_special_tokens=False)["input_ids"] 29 | 30 | decoded_text = "" 31 | offset = 0 32 | token_offset = 0 33 | for i in range(len(all_input_ids)): 34 | text, offset, token_offset = model.decode_token(all_input_ids[: i + 1], offset, token_offset) 35 | decoded_text += text 36 | 37 | assert decoded_text == truth 38 | 39 | 40 | @pytest.mark.private 41 | def test_decode_streaming_chinese_utf8(): 42 | model = get_test_model() 43 | truth = "我很感谢你的热情" 44 | all_input_ids = [ 45 | 30672, 46 | 232, 47 | 193, 48 | 139, 49 | 233, 50 | 135, 51 | 162, 52 | 235, 53 | 179, 54 | 165, 55 | 30919, 56 | 30210, 57 | 234, 58 | 134, 59 | 176, 60 | 30993, 61 | ] 62 | 63 | decoded_text = "" 64 | offset = 0 65 | token_offset = 0 66 | for i in range(len(all_input_ids)): 67 | text, offset, token_offset = model.decode_token(all_input_ids[: i + 1], offset, token_offset) 68 | decoded_text += text 69 | 70 | assert decoded_text == truth 71 | -------------------------------------------------------------------------------- /server/tests/utils/test_convert.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | import torch 5 | 6 | from lorax_server.utils.convert import convert_files 7 | from lorax_server.utils.errors import NanWeightsError 8 | from lorax_server.utils.sources.hub import ( 9 | download_weights, 10 | weight_files, 11 | weight_hub_files, 12 | ) 13 | 14 | 15 | def test_convert_files(): 16 | model_id = "bigscience/bloom-560m" 17 | pt_filenames = weight_hub_files(model_id, extension=".bin") 18 | local_pt_files = download_weights(pt_filenames, model_id) 19 | local_st_files = [ 20 | p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files 21 | ] 22 | convert_files(local_pt_files, local_st_files, discard_names=[]) 23 | 24 | found_st_files = weight_files(model_id) 25 | 26 | assert all([p in found_st_files for p in local_st_files]) 27 | 28 | 29 | def test_convert_files_nan_error(tmpdir): 30 | model_id = "bigscience/bloom-560m" 31 | pt_filenames = weight_hub_files(model_id, extension=".bin") 32 | local_pt_files = download_weights(pt_filenames, model_id) 33 | local_st_files = [ 34 | p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files 35 | ] 36 | 37 | # Introduce NaN to the first tensor in the first file 38 | pt_file = local_pt_files[0] 39 | with open(pt_file, "rb") as f: 40 | state_dict = torch.load(f, map_location="cpu") 41 | state_dict[list(state_dict.keys())[0]].fill_(float("nan")) 42 | 43 | # Write the corrupted state to a new temporary file 44 | pt_file = Path(tmpdir) / pt_file.name 45 | with open(pt_file, "wb") as f: 46 | torch.save(state_dict, f) 47 | 48 | # Replace the first file with the corrupted file 49 | local_pt_files[0] = pt_file 50 | 51 | with pytest.raises(NanWeightsError): 52 | convert_files(local_pt_files, local_st_files, discard_names=[]) 53 | -------------------------------------------------------------------------------- /server/tests/utils/test_hub.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from huggingface_hub.utils import ( 3 | EntryNotFoundError, 4 | LocalEntryNotFoundError, 5 | RevisionNotFoundError, 6 | ) 7 | 8 | from 
lorax_server.utils.sources.hub import ( 9 | download_weights, 10 | weight_files, 11 | weight_hub_files, 12 | ) 13 | 14 | 15 | def test_weight_hub_files(): 16 | filenames = weight_hub_files("bigscience/bloom-560m") 17 | assert filenames == ["model.safetensors"] 18 | 19 | 20 | def test_weight_hub_files_llm(): 21 | filenames = weight_hub_files("bigscience/bloom") 22 | assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)] 23 | 24 | 25 | def test_weight_hub_files_empty(): 26 | with pytest.raises(EntryNotFoundError): 27 | weight_hub_files("bigscience/bloom", extension=".errors") 28 | 29 | 30 | def test_download_weights(): 31 | model_id = "bigscience/bloom-560m" 32 | filenames = weight_hub_files(model_id) 33 | files = download_weights(filenames, model_id) 34 | local_files = weight_files("bigscience/bloom-560m") 35 | assert files == local_files 36 | 37 | 38 | def test_weight_files_error(): 39 | with pytest.raises(RevisionNotFoundError): 40 | weight_files("bigscience/bloom-560m", revision="error") 41 | with pytest.raises(LocalEntryNotFoundError): 42 | weight_files("bert-base-uncased") 43 | -------------------------------------------------------------------------------- /server/tests/utils/test_logits_process.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import torch 4 | from pydantic import BaseModel, constr 5 | from transformers import AutoTokenizer 6 | 7 | from lorax_server.utils.logits_process import OutlinesLogitsProcessor 8 | 9 | 10 | class Person(BaseModel): 11 | name: constr(max_length=10) 12 | age: int 13 | 14 | 15 | def test_outlines_process(): 16 | torch.manual_seed(42) 17 | 18 | schema = json.dumps(Person.schema()) 19 | tokenizer = AutoTokenizer.from_pretrained("gpt2") 20 | 21 | logit_processor = OutlinesLogitsProcessor(schema, tokenizer) 22 | 23 | B = 1 24 | V = tokenizer.vocab_size 25 | 26 | generated_tokens = [] 27 | max_steps = 1000 28 | for step in range(max_steps): 29 | scores = torch.randn(B, V) 30 | biased_scores = logit_processor(scores) 31 | next_token_id = biased_scores.argmax(dim=-1).item() 32 | if next_token_id == tokenizer.eos_token_id: 33 | break 34 | 35 | generated_tokens.append(next_token_id) 36 | 37 | logit_processor.next_state(next_token_id) 38 | if logit_processor.fsm_state == -1: 39 | break 40 | 41 | if step == max_steps - 1: 42 | raise RuntimeError("Max steps reached") 43 | 44 | text = tokenizer.decode(generated_tokens) 45 | try: 46 | decoded_json = json.loads(text) 47 | except json.JSONDecodeError: 48 | raise RuntimeError(f"Failed to decode JSON: {text}") 49 | 50 | Person(**decoded_json) 51 | -------------------------------------------------------------------------------- /server/tests/utils/test_s3.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import os 3 | from typing import Optional 4 | 5 | import pytest 6 | 7 | from lorax_server.utils.sources.s3 import _get_bucket_and_model_id 8 | 9 | 10 | @contextlib.contextmanager 11 | def with_env_var(key: str, value: Optional[str]): 12 | if value is None: 13 | yield 14 | return 15 | 16 | prev = os.environ.get(key) 17 | try: 18 | os.environ[key] = value 19 | yield 20 | finally: 21 | if prev is None: 22 | del os.environ[key] 23 | else: 24 | os.environ[key] = prev 25 | 26 | 27 | @pytest.mark.parametrize( 28 | "s3_path, env_var, expected_bucket, expected_model_id", 29 | [ 30 | ("s3://loras/foobar", None, "loras", "foobar"), 31 | ("s3://loras/foo/bar", None, "loras", 
"foo/bar"), 32 | ("s3://loras/foo/bar", "bucket", "loras", "foo/bar"), 33 | ("loras/foobar", None, "loras", "foobar"), 34 | ("loras/foo/bar", None, "loras", "foo/bar"), 35 | ("loras/foo/bar", "bucket", "bucket", "loras/foo/bar"), 36 | ] 37 | ) 38 | def test_get_bucket_and_model_id( 39 | s3_path: str, 40 | env_var: Optional[str], 41 | expected_bucket: str, 42 | expected_model_id: str, 43 | ): 44 | with with_env_var("PREDIBASE_MODEL_BUCKET", env_var): 45 | bucket, model_id = _get_bucket_and_model_id(s3_path) 46 | assert bucket == expected_bucket 47 | assert model_id == expected_model_id 48 | -------------------------------------------------------------------------------- /server/tests/utils/test_segments.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from lorax_server.utils.segments import SegmentConcatBuilder, find_segments 5 | 6 | 7 | @pytest.mark.parametrize( 8 | "adapter_indices,expected_segments,expected_segment_indices", 9 | [ 10 | ( 11 | torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 1, 1]), 12 | [0, 3, 5, 10, 12], 13 | [0, 1, 2, 1], 14 | ), 15 | (torch.tensor([]), [0], []), 16 | (torch.tensor([0]), [0, 1], [0]), 17 | (torch.tensor([1]), [0, 1], [1]), 18 | ], 19 | ) 20 | def test_find_segments(adapter_indices, expected_segments, expected_segment_indices): 21 | segments, segment_indices = find_segments(adapter_indices) 22 | assert segments == expected_segments 23 | assert segment_indices == expected_segment_indices 24 | 25 | 26 | @pytest.mark.parametrize( 27 | "batches,expected_segments,expected_segment_indices", 28 | [ 29 | ( 30 | [ 31 | (torch.tensor([0, 1, 4, 7, 8]), [2, 1, 2, 1]), 32 | (torch.tensor([0, 2, 5]), [1, 2]), 33 | ], 34 | [0, 1, 4, 7, 10, 13], 35 | [2, 1, 2, 1, 2], 36 | ), 37 | ( 38 | [ 39 | (torch.tensor([0, 1, 4, 7]), [2, 1, 2]), 40 | (torch.tensor([0, 2, 5]), [1, 2]), 41 | ], 42 | [0, 1, 4, 7, 9, 12], 43 | [2, 1, 2, 1, 2], 44 | ), 45 | ], 46 | ) 47 | def test_concat_segments(batches, expected_segments, expected_segment_indices): 48 | builder = SegmentConcatBuilder() 49 | for segment, indices in batches: 50 | builder.concat(segment, indices) 51 | 52 | segments, segment_indices = builder.build() 53 | assert segments.tolist() == expected_segments 54 | assert segment_indices == expected_segment_indices 55 | -------------------------------------------------------------------------------- /server/tests/utils/test_watermark.py: -------------------------------------------------------------------------------- 1 | # test_watermark_logits_processor.py 2 | 3 | import os 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from lorax_server.utils.watermark import WatermarkLogitsProcessor 9 | 10 | GAMMA = os.getenv("WATERMARK_GAMMA", 0.5) 11 | DELTA = os.getenv("WATERMARK_DELTA", 2.0) 12 | 13 | 14 | def test_seed_rng(): 15 | input_ids = [101, 2036, 3731, 102, 2003, 103] 16 | processor = WatermarkLogitsProcessor() 17 | processor._seed_rng(input_ids) 18 | assert isinstance(processor.rng, torch.Generator) 19 | 20 | 21 | def test_get_greenlist_ids(): 22 | input_ids = [101, 2036, 3731, 102, 2003, 103] 23 | processor = WatermarkLogitsProcessor() 24 | result = processor._get_greenlist_ids(input_ids, 10, torch.device("cpu")) 25 | assert max(result) <= 10 26 | assert len(result) == int(10 * 0.5) 27 | 28 | 29 | def test_calc_greenlist_mask(): 30 | processor = WatermarkLogitsProcessor() 31 | scores = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]]) 32 | greenlist_token_ids = torch.tensor([2, 3]) 33 | result 
= processor._calc_greenlist_mask(scores, greenlist_token_ids) 34 | assert result.tolist() == [[False, False, False, False], [False, False, True, True]] 35 | assert result.shape == scores.shape 36 | 37 | 38 | def test_bias_greenlist_logits(): 39 | processor = WatermarkLogitsProcessor() 40 | scores = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]]) 41 | green_tokens_mask = torch.tensor( 42 | [[False, False, True, True], [False, False, False, True]] 43 | ) 44 | greenlist_bias = 2.0 45 | result = processor._bias_greenlist_logits(scores, green_tokens_mask, greenlist_bias) 46 | assert np.allclose(result.tolist(), [[0.5, 0.3, 2.2, 2.8], [0.1, 0.2, 0.7, 2.9]]) 47 | assert result.shape == scores.shape 48 | 49 | 50 | def test_call(): 51 | input_ids = [101, 2036, 3731, 102, 2003, 103] 52 | processor = WatermarkLogitsProcessor() 53 | scores = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]]) 54 | result = processor(input_ids, scores) 55 | assert result.shape == scores.shape 56 | -------------------------------------------------------------------------------- /server/tests/utils/test_weights.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from transformers.models.qwen2 import Qwen2Config 4 | 5 | from lorax_server.utils.dist import initialize_torch_distributed 6 | from lorax_server.utils.sources.hub import ( 7 | download_weights, 8 | weight_hub_files, 9 | ) 10 | from lorax_server.utils.weights import Weights 11 | 12 | 13 | @pytest.mark.parametrize( 14 | 'model_id', [ 15 | 'neuralmagic/Qwen2-0.5B-Instruct-FP8', 16 | 'Qwen/Qwen2-0.5B-Instruct' 17 | ] 18 | ) 19 | @pytest.mark.parametrize( 20 | 'prefixes', [ 21 | ['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj'], 22 | ['mlp.gate_proj', 'mlp.up_proj'] 23 | ] 24 | ) 25 | def test_get_multi_weights_col(model_id, prefixes): 26 | process_group, _, _ = initialize_torch_distributed() 27 | filenames = weight_hub_files(model_id, 'main', '.safetensors') 28 | local_filenames = download_weights(filenames, model_id, 'main') 29 | config = Qwen2Config.from_pretrained(model_id, revision='main', trust_remote_code=False) 30 | quantize = None 31 | if hasattr(config, 'quantization_config'): 32 | quantize = config.quantization_config['quant_method'] 33 | 34 | weights = Weights(local_filenames, 'cpu', torch.bfloat16, process_group=process_group) 35 | prefix = 'model.layers.0' 36 | prefixes = [f'{prefix}.{k}' for k in prefixes] 37 | weight = weights.get_multi_weights_col( 38 | prefixes=prefixes, 39 | quantize=quantize, 40 | dim=0, 41 | ) 42 | if quantize is not None: 43 | assert type(weight) is tuple 44 | weight, input_scale, weight_scale = weight 45 | assert weight.dtype == torch.float8_e4m3fn 46 | assert input_scale.dtype == torch.float 47 | assert weight_scale.dtype == torch.float 48 | else: 49 | assert weight.dtype == torch.bfloat16 50 | 51 | @pytest.mark.parametrize( 52 | 'model_id', [ 53 | 'neuralmagic/Qwen2-0.5B-Instruct-FP8', 54 | 'Qwen/Qwen2-0.5B-Instruct' 55 | ] 56 | ) 57 | @pytest.mark.parametrize( 58 | 'prefix', ['self_attn.o_proj', 'mlp.down_proj'], 59 | ) 60 | def test_get_multi_weights_row(model_id, prefix): 61 | process_group, _, _ = initialize_torch_distributed() 62 | filenames = weight_hub_files(model_id, 'main', '.safetensors') 63 | local_filenames = download_weights(filenames, model_id, 'main') 64 | config = Qwen2Config.from_pretrained(model_id, revision='main', trust_remote_code=False) 65 | quantize = None 66 | if hasattr(config, 'quantization_config'): 
67 | quantize = config.quantization_config['quant_method'] 68 | 69 | weights = Weights(local_filenames, 'cpu', torch.bfloat16, process_group=process_group) 70 | weight = weights.get_multi_weights_row(f'model.layers.0.{prefix}', quantize=quantize) 71 | if quantize is not None: 72 | assert type(weight) is tuple 73 | weight, input_scale, weight_scale = weight 74 | assert weight.dtype == torch.float8_e4m3fn 75 | assert input_scale.dtype == torch.float 76 | assert weight_scale.dtype == torch.float 77 | else: 78 | assert weight.dtype == torch.bfloat16 79 | -------------------------------------------------------------------------------- /tests/create-pod.sh: -------------------------------------------------------------------------------- 1 | IMAGE_NAME="$1" 2 | 3 | runpodctl create pods \ 4 | --name lorax-tests-new \ 5 | --gpuType "NVIDIA A40" \ 6 | --imageName "$IMAGE_NAME" \ 7 | --containerDiskSize 100 \ 8 | --volumeSize 100 \ 9 | --ports "8080/http" \ 10 | --args "--port 8080 --model-id predibase/Mistral-7B-v0.1-dequantized --adapter-source hub --default-adapter-source pbase --max-batch-prefill-tokens 32768 --max-total-tokens 8192 --max-input-length 8191 --max-concurrent-requests 1024" | awk '{print $2}' > pod_name.txt -------------------------------------------------------------------------------- /tests/test.py: -------------------------------------------------------------------------------- 1 | from lorax import Client 2 | import sys 3 | 4 | pod_id = sys.argv[1] 5 | 6 | client = Client(f"https://{pod_id}-8080.proxy.runpod.net") 7 | 8 | response = client.generate("hello!", max_new_tokens=10) 9 | print(response) --------------------------------------------------------------------------------
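tests/test.py above exercises only base-model generation. As a follow-on smoke test, the sketch below also requests generation against a dynamically loaded adapter; it assumes the Python client's generate method accepts an adapter_id keyword (as in the LoRAX Python client), and the adapter name is a placeholder, not a published artifact.

# Sketch only: extends tests/test.py with an adapter request.
# Assumes Client.generate accepts an adapter_id keyword; the adapter
# name below is a placeholder, not a real artifact.
import sys

from lorax import Client

pod_id = sys.argv[1]
client = Client(f"https://{pod_id}-8080.proxy.runpod.net")

# Base model generation, as in the existing test
print(client.generate("hello!", max_new_tokens=10))

# Hypothetical adapter loaded on demand by the server
print(client.generate("hello!", max_new_tokens=10, adapter_id="org/example-adapter"))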